1
+ #ifndef KERNEL_FLOAT_APPLY_H
2
+ #define KERNEL_FLOAT_APPLY_H
3
+
4
+ #include " base.h"
5
+
6
+ namespace kernel_float {
7
+ namespace detail {
8
+
9
+ template <typename ... Es>
10
+ struct broadcast_extent_helper ;
11
+
12
+ template <typename E>
13
+ struct broadcast_extent_helper <E> {
14
+ using type = E;
15
+ };
16
+
17
+ template <size_t N>
18
+ struct broadcast_extent_helper <extent<N>, extent<N>> {
19
+ using type = extent<N>;
20
+ };
21
+
22
+ template <size_t N>
23
+ struct broadcast_extent_helper <extent<1 >, extent<N>> {
24
+ using type = extent<N>;
25
+ };
26
+
27
+ template <size_t N>
28
+ struct broadcast_extent_helper <extent<N>, extent<1 >> {
29
+ using type = extent<N>;
30
+ };
31
+
32
+ template <>
33
+ struct broadcast_extent_helper <extent<1 >, extent<1 >> {
34
+ using type = extent<1 >;
35
+ };
36
+
37
+ template <typename A, typename B, typename C, typename ... Rest>
38
+ struct broadcast_extent_helper <A, B, C, Rest...>:
39
+ broadcast_extent_helper<typename broadcast_extent_helper<A, B>::type, C, Rest...> {};
40
+
41
+ } // namespace detail
42
+
43
+ template <typename ... Es>
44
+ using broadcast_extent = typename detail::broadcast_extent_helper<Es...>::type;
45
+
46
+ template <typename ... Vs>
47
+ using broadcast_vector_extent_type = broadcast_extent<vector_extent_type<Vs>...>;
48
+
49
+ template <typename From, typename To>
50
+ static constexpr bool is_broadcastable = is_same_type<broadcast_extent<From, To>, To>;
51
+
52
+ template <typename V, typename To>
53
+ static constexpr bool is_vector_broadcastable = is_broadcastable<vector_extent_type<V>, To>;
54
+
55
+ namespace detail {
56
+
57
+ template <typename T, typename From, typename To>
58
+ struct broadcast_impl ;
59
+
60
+ template <typename T, size_t N>
61
+ struct broadcast_impl <T, extent<1 >, extent<N>> {
62
+ KERNEL_FLOAT_INLINE static vector_storage<T, N> call (const vector_storage<T, 1 >& input) {
63
+ vector_storage<T, N> output;
64
+ for (size_t i = 0 ; i < N; i++) {
65
+ output.data ()[i] = input.data ()[0 ];
66
+ }
67
+ return output;
68
+ }
69
+ };
70
+
71
+ template <typename T, size_t N>
72
+ struct broadcast_impl <T, extent<N>, extent<N>> {
73
+ KERNEL_FLOAT_INLINE static vector_storage<T, N> call (vector_storage<T, N> input) {
74
+ return input;
75
+ }
76
+ };
77
+
78
+ template <typename T>
79
+ struct broadcast_impl <T, extent<1 >, extent<1 >> {
80
+ KERNEL_FLOAT_INLINE static vector_storage<T, 1 > call (vector_storage<T, 1 > input) {
81
+ return input;
82
+ }
83
+ };
84
+
85
+ } // namespace detail
86
+
87
+ /* *
88
+ * Takes the given vector `input` and extends its size to a length of `N`. This is only valid if the size of `input`
89
+ * is 1 or `N`.
90
+ *
91
+ * Example
92
+ * =======
93
+ * ```
94
+ * vec<float, 1> a = {1.0f};
95
+ * vec<float, 5> x = broadcast<5>(a); // Returns [1.0f, 1.0f, 1.0f, 1.0f, 1.0f]
96
+ *
97
+ * vec<float, 5> b = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
98
+ * vec<float, 5> y = broadcast<5>(b); // Returns [1.0f, 2.0f, 3.0f, 4.0f, 5.0f]
99
+ * ```
100
+ */
101
+ template <size_t N, typename V>
102
+ KERNEL_FLOAT_INLINE vector<vector_value_type<V>, extent<N>>
103
+ broadcast (const V& input, extent<N> new_size = {}) {
104
+ using T = vector_value_type<V>;
105
+ return detail::broadcast_impl<T, vector_extent_type<V>, extent<N>>::call (
106
+ into_vector_storage (input));
107
+ }
108
+
109
+ /* *
110
+ * Takes the given vector `input` and extends its size to the same length as vector `other`. This is only valid if the
111
+ * size of `input` is 1 or the same as `other`.
112
+ */
113
+ template <typename V, typename R>
114
+ KERNEL_FLOAT_INLINE vector<vector_value_type<V>, vector_extent_type<R>>
115
+ broadcast_like (const V& input, const R& other) {
116
+ return broadcast (input, vector_extent_type<R> {});
117
+ }
118
+
119
+ namespace detail {
120
+
121
+ template <size_t N>
122
+ struct apply_recur_impl ;
123
+
124
+ template <typename F, size_t N, typename Output, typename ... Args>
125
+ struct apply_impl {
126
+ KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {
127
+ apply_recur_impl<N>::call (fun, result, inputs...);
128
+ }
129
+ };
130
+
131
+ template <size_t N>
132
+ struct apply_recur_impl {
133
+ static constexpr size_t K = round_up_to_power_of_two(N) / 2 ;
134
+
135
+ template <typename F, typename Output, typename ... Args>
136
+ KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {
137
+ apply_impl<F, K, Output, Args...>::call (fun, result, inputs...);
138
+ apply_impl<F, N - K, Output, Args...>::call (fun, result + K, (inputs + K)...);
139
+ }
140
+ };
141
+
142
+ template <>
143
+ struct apply_recur_impl <0 > {
144
+ template <typename F, typename Output, typename ... Args>
145
+ KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {}
146
+ };
147
+
148
+ template <>
149
+ struct apply_recur_impl <1 > {
150
+ template <typename F, typename Output, typename ... Args>
151
+ KERNEL_FLOAT_INLINE static void call (F fun, Output* result, const Args*... inputs) {
152
+ result[0 ] = fun (inputs[0 ]...);
153
+ }
154
+ };
155
+ } // namespace detail
156
+
157
+ template <typename F, typename ... Args>
158
+ using map_type =
159
+ vector<result_t <F, vector_value_type<Args>...>, broadcast_vector_extent_type<Args...>>;
160
+
161
+ /* *
162
+ * Apply the function `F` to each element from the vector `input` and return the results as a new vector.
163
+ *
164
+ * Examples
165
+ * ========
166
+ * ```
167
+ * vec<float, 4> input = {1.0f, 2.0f, 3.0f, 4.0f};
168
+ * vec<float, 4> squared = map([](auto x) { return x * x; }, input); // [1.0f, 4.0f, 9.0f, 16.0f]
169
+ * ```
170
+ */
171
+ template <typename F, typename ... Args>
172
+ KERNEL_FLOAT_INLINE map_type<F, Args...> map (F fun, const Args&... args) {
173
+ using Output = result_t <F, vector_value_type<Args>...>;
174
+ using E = broadcast_vector_extent_type<Args...>;
175
+ vector_storage<Output, E::value> result;
176
+
177
+ detail::apply_impl<F, E::value, Output, vector_value_type<Args>...>::call (
178
+ fun,
179
+ result.data (),
180
+ (detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call (
181
+ into_vector_storage (args))
182
+ .data ())...);
183
+
184
+ return result;
185
+ }
186
+
187
+ } // namespace kernel_float
188
+
189
+ #endif // KERNEL_FLOAT_APPLY_H
0 commit comments