Skip to content

Commit 28f811a

Browse files
committed
Refactor how apply_impl is performed to simplify fp16/bf16
1 parent a35b9f6 commit 28f811a

File tree

9 files changed

+749
-769
lines changed

9 files changed

+749
-769
lines changed

include/kernel_float/apply.h

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
#ifndef KERNEL_FLOAT_APPLY_H
2+
#define KERNEL_FLOAT_APPLY_H
3+
4+
#include "base.h"
5+
6+
namespace kernel_float {
7+
namespace detail {
8+
9+
template<typename... Es>
10+
struct broadcast_extent_helper;
11+
12+
template<typename E>
13+
struct broadcast_extent_helper<E> {
14+
using type = E;
15+
};
16+
17+
template<size_t N>
18+
struct broadcast_extent_helper<extent<N>, extent<N>> {
19+
using type = extent<N>;
20+
};
21+
22+
template<size_t N>
23+
struct broadcast_extent_helper<extent<1>, extent<N>> {
24+
using type = extent<N>;
25+
};
26+
27+
template<size_t N>
28+
struct broadcast_extent_helper<extent<N>, extent<1>> {
29+
using type = extent<N>;
30+
};
31+
32+
template<>
33+
struct broadcast_extent_helper<extent<1>, extent<1>> {
34+
using type = extent<1>;
35+
};
36+
37+
template<typename A, typename B, typename C, typename... Rest>
38+
struct broadcast_extent_helper<A, B, C, Rest...>:
39+
broadcast_extent_helper<typename broadcast_extent_helper<A, B>::type, C, Rest...> {};
40+
41+
} // namespace detail
42+
43+
template<typename... Es>
44+
using broadcast_extent = typename detail::broadcast_extent_helper<Es...>::type;
45+
46+
template<typename... Vs>
47+
using broadcast_vector_extent_type = broadcast_extent<vector_extent_type<Vs>...>;
48+
49+
template<typename From, typename To>
50+
static constexpr bool is_broadcastable = is_same_type<broadcast_extent<From, To>, To>;
51+
52+
template<typename V, typename To>
53+
static constexpr bool is_vector_broadcastable = is_broadcastable<vector_extent_type<V>, To>;
54+
55+
namespace detail {
56+
57+
template<typename T, typename From, typename To>
58+
struct broadcast_impl;
59+
60+
template<typename T, size_t N>
61+
struct broadcast_impl<T, extent<1>, extent<N>> {
62+
KERNEL_FLOAT_INLINE static vector_storage<T, N> call(const vector_storage<T, 1>& input) {
63+
vector_storage<T, N> output;
64+
for (size_t i = 0; i < N; i++) {
65+
output.data()[i] = input.data()[0];
66+
}
67+
return output;
68+
}
69+
};
70+
71+
template<typename T, size_t N>
72+
struct broadcast_impl<T, extent<N>, extent<N>> {
73+
KERNEL_FLOAT_INLINE static vector_storage<T, N> call(vector_storage<T, N> input) {
74+
return input;
75+
}
76+
};
77+
78+
template<typename T>
79+
struct broadcast_impl<T, extent<1>, extent<1>> {
80+
KERNEL_FLOAT_INLINE static vector_storage<T, 1> call(vector_storage<T, 1> input) {
81+
return input;
82+
}
83+
};
84+
85+
} // namespace detail
86+
87+
/**
88+
* Takes the given vector `input` and extends its size to a length of `N`. This is only valid if the size of `input`
89+
* is 1 or `N`.
90+
*
91+
* Example
92+
* =======
93+
* ```
94+
* vec<float, 1> a = {1.0f};
95+
* vec<float, 5> x = broadcast<5>(a); // Returns [1.0f, 1.0f, 1.0f, 1.0f, 1.0f]
96+
*
97+
* vec<float, 5> b = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
98+
* vec<float, 5> y = broadcast<5>(b); // Returns [1.0f, 2.0f, 3.0f, 4.0f, 5.0f]
99+
* ```
100+
*/
101+
template<size_t N, typename V>
102+
KERNEL_FLOAT_INLINE vector<vector_value_type<V>, extent<N>>
103+
broadcast(const V& input, extent<N> new_size = {}) {
104+
using T = vector_value_type<V>;
105+
return detail::broadcast_impl<T, vector_extent_type<V>, extent<N>>::call(
106+
into_vector_storage(input));
107+
}
108+
109+
/**
110+
* Takes the given vector `input` and extends its size to the same length as vector `other`. This is only valid if the
111+
* size of `input` is 1 or the same as `other`.
112+
*/
113+
template<typename V, typename R>
114+
KERNEL_FLOAT_INLINE vector<vector_value_type<V>, vector_extent_type<R>>
115+
broadcast_like(const V& input, const R& other) {
116+
return broadcast(input, vector_extent_type<R> {});
117+
}
118+
119+
namespace detail {
120+
121+
template<size_t N>
122+
struct apply_recur_impl;
123+
124+
template<typename F, size_t N, typename Output, typename... Args>
125+
struct apply_impl {
126+
KERNEL_FLOAT_INLINE static void call(F fun, Output* result, const Args*... inputs) {
127+
apply_recur_impl<N>::call(fun, result, inputs...);
128+
}
129+
};
130+
131+
template<size_t N>
132+
struct apply_recur_impl {
133+
static constexpr size_t K = round_up_to_power_of_two(N) / 2;
134+
135+
template<typename F, typename Output, typename... Args>
136+
KERNEL_FLOAT_INLINE static void call(F fun, Output* result, const Args*... inputs) {
137+
apply_impl<F, K, Output, Args...>::call(fun, result, inputs...);
138+
apply_impl<F, N - K, Output, Args...>::call(fun, result + K, (inputs + K)...);
139+
}
140+
};
141+
142+
template<>
143+
struct apply_recur_impl<0> {
144+
template<typename F, typename Output, typename... Args>
145+
KERNEL_FLOAT_INLINE static void call(F fun, Output* result, const Args*... inputs) {}
146+
};
147+
148+
template<>
149+
struct apply_recur_impl<1> {
150+
template<typename F, typename Output, typename... Args>
151+
KERNEL_FLOAT_INLINE static void call(F fun, Output* result, const Args*... inputs) {
152+
result[0] = fun(inputs[0]...);
153+
}
154+
};
155+
} // namespace detail
156+
157+
template<typename F, typename... Args>
158+
using map_type =
159+
vector<result_t<F, vector_value_type<Args>...>, broadcast_vector_extent_type<Args...>>;
160+
161+
/**
162+
* Apply the function `F` to each element from the vector `input` and return the results as a new vector.
163+
*
164+
* Examples
165+
* ========
166+
* ```
167+
* vec<float, 4> input = {1.0f, 2.0f, 3.0f, 4.0f};
168+
* vec<float, 4> squared = map([](auto x) { return x * x; }, input); // [1.0f, 4.0f, 9.0f, 16.0f]
169+
* ```
170+
*/
171+
template<typename F, typename... Args>
172+
KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
173+
using Output = result_t<F, vector_value_type<Args>...>;
174+
using E = broadcast_vector_extent_type<Args...>;
175+
vector_storage<Output, E::value> result;
176+
177+
detail::apply_impl<F, E::value, Output, vector_value_type<Args>...>::call(
178+
fun,
179+
result.data(),
180+
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(
181+
into_vector_storage(args))
182+
.data())...);
183+
184+
return result;
185+
}
186+
187+
} // namespace kernel_float
188+
189+
#endif // KERNEL_FLOAT_APPLY_H

0 commit comments

Comments
 (0)