Skip to content

Commit f40ff58

Browse files
committed
Simplify how binary operators are defined
1 parent 3da5ba0 commit f40ff58

File tree

2 files changed

+130
-204
lines changed

2 files changed

+130
-204
lines changed

include/kernel_float/binops.h

Lines changed: 64 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,36 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
8686
return zip_common(ops::NAME<C> {}, std::forward<L>(left), std::forward<R>(right)); \
8787
}
8888

89-
#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR) \
90-
namespace ops { \
91-
template<typename T> \
92-
struct NAME { \
93-
KERNEL_FLOAT_INLINE T operator()(T left, T right) { \
94-
return ops::cast<decltype(EXPR), T> {}(EXPR); \
95-
} \
96-
}; \
97-
} \
98-
\
89+
// Defines the functor `ops::NAME` for a binary operation and then invokes
// KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME).
//
// Each expression may refer to its two operands as `left` and `right`:
//   EXPR     - evaluated by the primary template for a generic T (operands of type T).
//   EXPR_F64 - evaluated by the full specialization for `double`.
//   EXPR_F32 - evaluated by the partial specialization that is enabled when
//              `detail::allow_float_fallback<T>::value` is true; both operands are
//              first cast to `float` and the result is cast back to T.
// In every variant the result goes through `ops::cast<decltype(expr), ...>` to
// reach the functor's return type.
#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32) \
90+
namespace ops { \
91+
template<typename T, typename = void> \
92+
struct NAME { \
93+
KERNEL_FLOAT_INLINE T operator()(T left, T right) { \
94+
return ops::cast<decltype(EXPR), T> {}(EXPR); \
95+
} \
96+
}; \
97+
\
98+
template<> \
99+
struct NAME<double> { \
100+
KERNEL_FLOAT_INLINE double operator()(double left, double right) { \
101+
return ops::cast<decltype(EXPR_F64), double> {}(EXPR_F64); \
102+
} \
103+
}; \
104+
\
105+
template<typename T> \
106+
struct NAME<T, enable_if_t<detail::allow_float_fallback<T>::value>> { \
107+
KERNEL_FLOAT_INLINE T operator()(T left_, T right_) { \
108+
float left = ops::cast<T, float> {}(left_); \
109+
float right = ops::cast<T, float> {}(right_); \
110+
return ops::cast<decltype(EXPR_F32), T> {}(EXPR_F32); \
111+
} \
112+
}; \
113+
} \
114+
\
99115
KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME)
100116

101-
#define KERNEL_FLOAT_DEFINE_BINARY_OP(NAME, OP) \
102-
KERNEL_FLOAT_DEFINE_BINARY(NAME, left OP right) \
117+
#define KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(NAME, OP, EXPR_F64, EXPR_F32) \
118+
KERNEL_FLOAT_DEFINE_BINARY(NAME, left OP right, EXPR_F64, EXPR_F32) \
103119
\
104120
template<typename L, typename R, typename C = promote_t<L, R>, typename E1, typename E2> \
105121
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, vector<L, E1>, vector<R, E2>> operator OP( \
@@ -120,11 +136,14 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
120136
return zip_common(ops::NAME<C> {}, left, right); \
121137
}
122138

139+
// Shorthand for operators that need no special floating-point handling: forwards to
// KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK with `left OP right` as the expression for
// both the double variant and the float-fallback variant.
#define KERNEL_FLOAT_DEFINE_BINARY_OP(NAME, OP) \
140+
KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(NAME, OP, left OP right, left OP right)
141+
123142
KERNEL_FLOAT_DEFINE_BINARY_OP(add, +)
124143
KERNEL_FLOAT_DEFINE_BINARY_OP(subtract, -)
125144
KERNEL_FLOAT_DEFINE_BINARY_OP(divide, /)
126145
KERNEL_FLOAT_DEFINE_BINARY_OP(multiply, *)
127-
KERNEL_FLOAT_DEFINE_BINARY_OP(modulo, %)
146+
// modulo: the generic variant uses `left % right`; the double variant calls ::fmod
// and the float-fallback variant calls ::fmodf.
KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(modulo, %, ::fmod(left, right), ::fmodf(left, right))
128147

129148
KERNEL_FLOAT_DEFINE_BINARY_OP(equal_to, ==)
130149
KERNEL_FLOAT_DEFINE_BINARY_OP(not_equal_to, !=)
@@ -133,9 +152,11 @@ KERNEL_FLOAT_DEFINE_BINARY_OP(less_equal, <=)
133152
KERNEL_FLOAT_DEFINE_BINARY_OP(greater, >)
134153
KERNEL_FLOAT_DEFINE_BINARY_OP(greater_equal, >=)
135154

136-
KERNEL_FLOAT_DEFINE_BINARY_OP(bit_and, &)
137-
KERNEL_FLOAT_DEFINE_BINARY_OP(bit_or, |)
138-
KERNEL_FLOAT_DEFINE_BINARY_OP(bit_xor, ^)
155+
// Bitwise operators: the generic variant applies `left OP right` directly. In the
// double and float-fallback variants both operands are converted to bool first, so
// the operation is carried out on truth values rather than on the bit pattern.
// clang-format off
156+
KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_and, &, bool(left) && bool(right), bool(left) && bool(right))
157+
KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_or, |, bool(left) | bool(right), bool(left) | bool(right))
158+
KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_xor, ^, bool(left) ^ bool(right), bool(left) ^ bool(right))
159+
// clang-format on
139160

140161
// clang-format off
141162
template<template<typename> typename F, typename T, typename E, typename R>
@@ -247,56 +268,40 @@ KERNEL_FLOAT_DEFINE_BINARY_MATH(nextafter)
247268
KERNEL_FLOAT_DEFINE_BINARY_MATH(pow)
248269
KERNEL_FLOAT_DEFINE_BINARY_MATH(remainder)
249270

250-
KERNEL_FLOAT_DEFINE_BINARY(hypot, (ops::sqrt<T>()(left * left + right * right)))
251-
KERNEL_FLOAT_DEFINE_BINARY(rhypot, (T(1) / ops::hypot<T>()(left, right)))
252-
253-
namespace ops {
254-
template<>
255-
struct hypot<double> {
256-
KERNEL_FLOAT_INLINE double operator()(double left, double right) {
257-
return ::hypot(left, right);
258-
};
259-
};
260-
261-
template<>
262-
struct hypot<float> {
263-
KERNEL_FLOAT_INLINE float operator()(float left, float right) {
264-
return ::hypotf(left, right);
265-
};
266-
};
271+
// hypot: the generic variant evaluates ops::sqrt<T> on `left * left + right * right`;
// the double variant calls ::hypot and the float-fallback variant calls ::hypotf.
KERNEL_FLOAT_DEFINE_BINARY(
272+
hypot,
273+
ops::sqrt<T>()(left* left + right * right),
274+
::hypot(left, right),
275+
::hypotf(left, right))
267276

268-
// rhypot is only support on the GPU
269277
#if KERNEL_FLOAT_IS_DEVICE
270-
template<>
271-
struct rhypot<double> {
272-
KERNEL_FLOAT_INLINE double operator()(double left, double right) {
273-
return ::rhypot(left, right);
274-
};
275-
};
276-
277-
template<>
278-
struct rhypot<float> {
279-
KERNEL_FLOAT_INLINE float operator()(float left, float right) {
280-
return ::rhypotf(left, right);
281-
};
282-
};
278+
KERNEL_FLOAT_DEFINE_BINARY(
279+
rhypot,
280+
(T(1) / ops::hypot<T>()(left, right)),
281+
::rhypot(left, right),
282+
::rhypotf(left, right))
283+
#else
284+
KERNEL_FLOAT_DEFINE_BINARY(
285+
rhypot,
286+
(T(1) / ops::hypot<T>()(left, right)),
287+
(double(1) / ::hypot(left, right)),
288+
(float(1) / ::hypotf(left, right)))
283289
#endif
284-
}; // namespace ops
285290

286291
#if KERNEL_FLOAT_IS_DEVICE
287-
#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \
288-
KERNEL_FLOAT_DEFINE_BINARY(FUN_NAME, ops::OP_NAME<T> {}(left, right)) \
289-
namespace ops { \
290-
template<> \
291-
struct OP_NAME<float> { \
292-
KERNEL_FLOAT_INLINE float operator()(float left, float right) { \
293-
return FLOAT_FUN(left, right); \
294-
} \
295-
}; \
296-
}
292+
// Device variant of KERNEL_FLOAT_DEFINE_BINARY_FAST.
//
// Defines the functor `ops::FUN_NAME` as a "fast" counterpart of `ops::OP_NAME`:
//   FUN_NAME  - name of the functor/function to define (e.g. fast_div).
//   OP_NAME   - the regular functor used for generic T and for double.
//   FLOAT_FUN - function called as FLOAT_FUN(left, right) on two floats; it is used
//               as the EXPR_F32 expression of KERNEL_FLOAT_DEFINE_BINARY.
//
// FLOAT_FUN must actually be forwarded here: if the float expression delegates to
// ops::OP_NAME<float> instead, this branch is identical to the non-device one and
// the fast function passed by the caller is silently ignored.
#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \
    KERNEL_FLOAT_DEFINE_BINARY(                                       \
        FUN_NAME,                                                     \
        ops::OP_NAME<T> {}(left, right),                              \
        ops::OP_NAME<double> {}(left, right),                         \
        FLOAT_FUN(left, right))
297298
#else
298299
#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \
299-
KERNEL_FLOAT_DEFINE_BINARY(FUN_NAME, ops::OP_NAME<T> {}(left, right))
300+
KERNEL_FLOAT_DEFINE_BINARY( \
301+
FUN_NAME, \
302+
ops::OP_NAME<T> {}(left, right), \
303+
ops::OP_NAME<double> {}(left, right), \
304+
ops::OP_NAME<float> {}(left, right))
300305
#endif
301306

302307
KERNEL_FLOAT_DEFINE_BINARY_FAST(fast_div, divide, __fdividef)
@@ -316,48 +321,6 @@ struct multiply<bool> {
316321
return left && right;
317322
}
318323
};
319-
320-
template<>
321-
struct bit_and<float> {
322-
KERNEL_FLOAT_INLINE float operator()(float left, float right) {
323-
return float(bool(left) && bool(right));
324-
}
325-
};
326-
327-
template<>
328-
struct bit_or<float> {
329-
KERNEL_FLOAT_INLINE float operator()(float left, float right) {
330-
return float(bool(left) || bool(right));
331-
}
332-
};
333-
334-
template<>
335-
struct bit_xor<float> {
336-
KERNEL_FLOAT_INLINE float operator()(float left, float right) {
337-
return float(bool(left) ^ bool(right));
338-
}
339-
};
340-
341-
template<>
342-
struct bit_and<double> {
343-
KERNEL_FLOAT_INLINE double operator()(double left, double right) {
344-
return double(bool(left) && bool(right));
345-
}
346-
};
347-
348-
template<>
349-
struct bit_or<double> {
350-
KERNEL_FLOAT_INLINE double operator()(double left, double right) {
351-
return double(bool(left) || bool(right));
352-
}
353-
};
354-
355-
template<>
356-
struct bit_xor<double> {
357-
KERNEL_FLOAT_INLINE double operator()(double left, double right) {
358-
return double(bool(left) ^ bool(right));
359-
}
360-
};
361324
}; // namespace ops
362325

363326
namespace detail {

0 commit comments

Comments
 (0)