@@ -86,20 +86,36 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
86
86
return zip_common (ops::NAME<C> {}, std::forward<L>(left), std::forward<R>(right)); \
87
87
}
88
88
89
- #define KERNEL_FLOAT_DEFINE_BINARY (NAME, EXPR ) \
90
- namespace ops { \
91
- template <typename T> \
92
- struct NAME { \
93
- KERNEL_FLOAT_INLINE T operator ()(T left, T right) { \
94
- return ops::cast<decltype (EXPR), T> {}(EXPR); \
95
- } \
96
- }; \
97
- } \
98
- \
89
+ #define KERNEL_FLOAT_DEFINE_BINARY (NAME, EXPR, EXPR_F64, EXPR_F32 ) \
90
+ namespace ops { \
91
+ template <typename T, typename = void > \
92
+ struct NAME { \
93
+ KERNEL_FLOAT_INLINE T operator ()(T left, T right) { \
94
+ return ops::cast<decltype (EXPR), T> {}(EXPR); \
95
+ } \
96
+ }; \
97
+ \
98
+ template <> \
99
+ struct NAME <double > { \
100
+ KERNEL_FLOAT_INLINE double operator ()(double left, double right) { \
101
+ return ops::cast<decltype (EXPR_F64), double > {}(EXPR_F64); \
102
+ } \
103
+ }; \
104
+ \
105
+ template <typename T> \
106
+ struct NAME <T, enable_if_t <detail::allow_float_fallback<T>::value>> { \
107
+ KERNEL_FLOAT_INLINE T operator ()(T left_, T right_) { \
108
+ float left = ops::cast<T, float > {}(left_); \
109
+ float right = ops::cast<T, float > {}(right_); \
110
+ return ops::cast<decltype (EXPR_F32), T> {}(EXPR_F32); \
111
+ } \
112
+ }; \
113
+ } \
114
+ \
99
115
KERNEL_FLOAT_DEFINE_BINARY_FUN (NAME)
100
116
101
- #define KERNEL_FLOAT_DEFINE_BINARY_OP (NAME, OP ) \
102
- KERNEL_FLOAT_DEFINE_BINARY (NAME, left OP right) \
117
+ #define KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK (NAME, OP, EXPR_F64, EXPR_F32 ) \
118
+ KERNEL_FLOAT_DEFINE_BINARY (NAME, left OP right, EXPR_F64, EXPR_F32) \
103
119
\
104
120
template <typename L, typename R, typename C = promote_t <L, R>, typename E1 , typename E2 > \
105
121
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, vector<L, E1 >, vector<R, E2 >> operator OP ( \
@@ -120,11 +136,14 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
120
136
return zip_common (ops::NAME<C> {}, left, right); \
121
137
}
122
138
139
+ #define KERNEL_FLOAT_DEFINE_BINARY_OP (NAME, OP ) \
140
+ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK (NAME, OP, left OP right, left OP right)
141
+
123
142
KERNEL_FLOAT_DEFINE_BINARY_OP (add, +)
124
143
KERNEL_FLOAT_DEFINE_BINARY_OP (subtract, -)
125
144
KERNEL_FLOAT_DEFINE_BINARY_OP (divide, /)
126
145
KERNEL_FLOAT_DEFINE_BINARY_OP (multiply, *)
127
- KERNEL_FLOAT_DEFINE_BINARY_OP (modulo, %)
146
+ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK (modulo, %, ::fmod(left, right), ::fmodf(left, right) )
128
147
129
148
KERNEL_FLOAT_DEFINE_BINARY_OP (equal_to, ==)
130
149
KERNEL_FLOAT_DEFINE_BINARY_OP(not_equal_to, !=)
@@ -133,9 +152,11 @@ KERNEL_FLOAT_DEFINE_BINARY_OP(less_equal, <=)
133
152
KERNEL_FLOAT_DEFINE_BINARY_OP(greater, >)
134
153
KERNEL_FLOAT_DEFINE_BINARY_OP(greater_equal, >=)
135
154
136
- KERNEL_FLOAT_DEFINE_BINARY_OP (bit_and, &)
137
- KERNEL_FLOAT_DEFINE_BINARY_OP (bit_or, |)
138
- KERNEL_FLOAT_DEFINE_BINARY_OP (bit_xor, ^)
155
+ // clang-format off
156
+ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_and, &, bool (left) && bool(right), bool(left) && bool(right))
157
+ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_or, |, bool (left) | bool(right), bool(left) | bool(right))
158
+ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_xor, ^, bool (left) ^ bool(right), bool(left) ^ bool(right))
159
+ // clang-format on
139
160
140
161
// clang-format off
141
162
template<template<typename> typename F, typename T, typename E, typename R>
@@ -247,56 +268,40 @@ KERNEL_FLOAT_DEFINE_BINARY_MATH(nextafter)
247
268
KERNEL_FLOAT_DEFINE_BINARY_MATH(pow)
248
269
KERNEL_FLOAT_DEFINE_BINARY_MATH(remainder)
249
270
250
- KERNEL_FLOAT_DEFINE_BINARY (hypot, (ops::sqrt<T>()(left * left + right * right)))
251
- KERNEL_FLOAT_DEFINE_BINARY (rhypot, (T(1 ) / ops::hypot<T>()(left, right)))
252
-
253
- namespace ops {
254
- template <>
255
- struct hypot <double > {
256
- KERNEL_FLOAT_INLINE double operator ()(double left, double right) {
257
- return ::hypot (left, right);
258
- };
259
- };
260
-
261
- template <>
262
- struct hypot <float > {
263
- KERNEL_FLOAT_INLINE float operator ()(float left, float right) {
264
- return ::hypotf (left, right);
265
- };
266
- };
271
+ KERNEL_FLOAT_DEFINE_BINARY(
272
+ hypot,
273
+ ops::sqrt<T>()(left* left + right * right),
274
+ ::hypot(left, right),
275
+ ::hypotf(left, right))
267
276
268
- // rhypot is only support on the GPU
269
277
#if KERNEL_FLOAT_IS_DEVICE
270
- template <>
271
- struct rhypot <double > {
272
- KERNEL_FLOAT_INLINE double operator ()(double left, double right) {
273
- return ::rhypot (left, right);
274
- };
275
- };
276
-
277
- template <>
278
- struct rhypot <float > {
279
- KERNEL_FLOAT_INLINE float operator ()(float left, float right) {
280
- return ::rhypotf (left, right);
281
- };
282
- };
278
+ KERNEL_FLOAT_DEFINE_BINARY (
279
+ rhypot,
280
+ (T(1 ) / ops::hypot<T>()(left, right)),
281
+ ::rhypot(left, right),
282
+ ::rhypotf(left, right))
283
+ #else
284
+ KERNEL_FLOAT_DEFINE_BINARY (
285
+ rhypot,
286
+ (T(1 ) / ops::hypot<T>()(left, right)),
287
+ (double (1 ) / ::hypot(left, right)),
288
+ (float (1 ) / ::hypotf(left, right)))
283
289
#endif
284
- }; // namespace ops
285
290
286
291
#if KERNEL_FLOAT_IS_DEVICE
287
- #define KERNEL_FLOAT_DEFINE_BINARY_FAST (FUN_NAME, OP_NAME, FLOAT_FUN ) \
288
- KERNEL_FLOAT_DEFINE_BINARY (FUN_NAME, ops::OP_NAME<T> {}(left, right)) \
289
- namespace ops { \
290
- template <> \
291
- struct OP_NAME <float > { \
292
- KERNEL_FLOAT_INLINE float operator ()(float left, float right) { \
293
- return FLOAT_FUN (left, right); \
294
- } \
295
- }; \
296
- }
292
+ #define KERNEL_FLOAT_DEFINE_BINARY_FAST (FUN_NAME, OP_NAME, FLOAT_FUN ) \
293
+ KERNEL_FLOAT_DEFINE_BINARY ( \
294
+ FUN_NAME, \
295
+ ops::OP_NAME<T> {}(left, right), \
296
+ ops::OP_NAME<double> {}(left, right), \
297
+ ops::OP_NAME<float > {}(left, right))
297
298
#else
298
299
#define KERNEL_FLOAT_DEFINE_BINARY_FAST (FUN_NAME, OP_NAME, FLOAT_FUN ) \
299
- KERNEL_FLOAT_DEFINE_BINARY (FUN_NAME, ops::OP_NAME<T> {}(left, right))
300
+ KERNEL_FLOAT_DEFINE_BINARY ( \
301
+ FUN_NAME, \
302
+ ops::OP_NAME<T> {}(left, right), \
303
+ ops::OP_NAME<double> {}(left, right), \
304
+ ops::OP_NAME<float > {}(left, right))
300
305
#endif
301
306
302
307
KERNEL_FLOAT_DEFINE_BINARY_FAST (fast_div, divide, __fdividef)
@@ -316,48 +321,6 @@ struct multiply<bool> {
316
321
return left && right;
317
322
}
318
323
};
319
-
320
- template <>
321
- struct bit_and <float > {
322
- KERNEL_FLOAT_INLINE float operator ()(float left, float right) {
323
- return float (bool (left) && bool (right));
324
- }
325
- };
326
-
327
- template <>
328
- struct bit_or <float > {
329
- KERNEL_FLOAT_INLINE float operator ()(float left, float right) {
330
- return float (bool (left) || bool (right));
331
- }
332
- };
333
-
334
- template <>
335
- struct bit_xor <float > {
336
- KERNEL_FLOAT_INLINE float operator ()(float left, float right) {
337
- return float (bool (left) ^ bool (right));
338
- }
339
- };
340
-
341
- template <>
342
- struct bit_and <double > {
343
- KERNEL_FLOAT_INLINE double operator ()(double left, double right) {
344
- return double (bool (left) && bool (right));
345
- }
346
- };
347
-
348
- template <>
349
- struct bit_or <double > {
350
- KERNEL_FLOAT_INLINE double operator ()(double left, double right) {
351
- return double (bool (left) || bool (right));
352
- }
353
- };
354
-
355
- template <>
356
- struct bit_xor <double > {
357
- KERNEL_FLOAT_INLINE double operator ()(double left, double right) {
358
- return double (bool (left) ^ bool (right));
359
- }
360
- };
361
324
}; // namespace ops
362
325
363
326
namespace detail {
0 commit comments