@@ -17,6 +17,50 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
 }
 // Activation and gating kernel template.
 
+template <typename T, uint32_t N>
+struct __align__(16) vec_t {
+  T values[N];
+};
+
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
+__global__ void act_and_mul_kernel_vec(
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const int d) {
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
+     (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
+  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
+  const int64_t token_idx = blockIdx.x;
+  scalar_t* __restrict__ out_ptr = out + token_idx * d;
+  vec_t<scalar_t, vec_size>* __restrict__ out_vec_ptr =
+      reinterpret_cast<vec_t<scalar_t, vec_size>*>(out_ptr);
+  vec_t<scalar_t, vec_size> out_vec;
+  const int64_t stride = blockDim.x;
+  const int64_t offset = token_idx * 2 * d;
+#pragma unroll 1
+  for (int64_t idx = threadIdx.x; idx < d / vec_size; idx += stride) {
+    const vec_t<scalar_t, vec_size> x_vec =
+        reinterpret_cast<const vec_t<scalar_t, vec_size>*>(
+            input)[offset / vec_size + idx];
+    const vec_t<scalar_t, vec_size> y_vec =
+        reinterpret_cast<const vec_t<scalar_t, vec_size>*>(
+            input)[(offset + d) / vec_size + idx];
+#pragma unroll
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      out_vec.values[i] = compute<scalar_t, ACT_FN, act_first>(x_vec.values[i],
+                                                               y_vec.values[i]);
+    }
+    out_vec_ptr[idx] = out_vec;
+  }
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
+     (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
 template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
 __global__ void act_and_mul_kernel(
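For readers unfamiliar with the pattern, the new act_and_mul_kernel_vec is built on 16-byte (128-bit) loads and stores through an aligned wrapper struct. Below is a minimal standalone sketch of that access pattern, not part of the diff; the kernel name copy_vectorized and the grid-stride loop are illustrative assumptions, and it presumes the element count is a multiple of the vector width and that both pointers are 16-byte aligned, mirroring the launch-side check added further down.

// Editorial sketch (not part of the diff): 16-byte vectorized copy using an
// aligned wrapper struct, the same trick act_and_mul_kernel_vec relies on.
#include <cstdint>

template <typename T, uint32_t N>
struct __align__(16) vec_t {
  T values[N];
};

template <typename scalar_t>
__global__ void copy_vectorized(scalar_t* __restrict__ dst,
                                const scalar_t* __restrict__ src,
                                int64_t n) {
  // Each vec_t spans 16 bytes: 8 halves, 4 floats, etc.
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  auto* dst_vec = reinterpret_cast<vec_t<scalar_t, vec_size>*>(dst);
  const auto* src_vec = reinterpret_cast<const vec_t<scalar_t, vec_size>*>(src);
  // One 16-byte chunk per iteration; assumes n % vec_size == 0 and that both
  // pointers are 16-byte aligned.
  for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n / vec_size;
       idx += int64_t(gridDim.x) * blockDim.x) {
    dst_vec[idx] = src_vec[idx];
  }
}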
@@ -65,21 +109,31 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 // Launch activation and gating kernel.
 // Use ACT_FIRST (bool) indicating whether to apply the activation function
 // first.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
-  int d = input.size(-1) / 2; \
-  int64_t num_tokens = input.numel() / input.size(-1); \
-  dim3 grid(num_tokens); \
-  dim3 block(std::min(d, 1024)); \
-  if (num_tokens == 0) { \
-    return; \
-  } \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-  VLLM_DISPATCH_FLOATING_TYPES( \
-      input.scalar_type(), "act_and_mul_kernel", [&] { \
-        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
-            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
-                                         input.data_ptr<scalar_t>(), d); \
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
+  int d = input.size(-1) / 2; \
+  int64_t num_tokens = input.numel() / input.size(-1); \
+  if (num_tokens == 0) { \
+    return; \
+  } \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+  VLLM_DISPATCH_FLOATING_TYPES( \
+      input.scalar_type(), "act_and_mul_kernel", [&] { \
+        uint32_t vec_size = 16 / sizeof(scalar_t); \
+        dim3 grid(num_tokens); \
+        dim3 block_vec(std::min(d / vec_size, 1024U)); \
+        dim3 block(std::min(d, 1024)); \
+        if (d % vec_size == 0 && \
+            (reinterpret_cast<uintptr_t>(input.data_ptr()) % 16 == 0) && \
+            (reinterpret_cast<uintptr_t>(out.data_ptr()) % 16 == 0)) { \
+          vllm::act_and_mul_kernel_vec<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
+              <<<grid, block_vec, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                               input.data_ptr<scalar_t>(), d); \
+        } else { \
+          vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
+              <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                           input.data_ptr<scalar_t>(), d); \
+        } \
       });
 
 void silu_and_mul(torch::Tensor& out,  // [..., d]
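The updated macro takes the vectorized path only when a 16-byte access is safe and falls back to the scalar kernel otherwise. A hedged host-side sketch of that eligibility test, factored into a hypothetical helper can_use_vectorized for clarity (the macro itself performs the check inline), is:

// Editorial sketch (not part of the diff): the launch-side eligibility test,
// written as a hypothetical standalone helper.
#include <cstddef>
#include <cstdint>

inline bool can_use_vectorized(const void* in, const void* out, int d,
                               size_t elem_size) {
  const size_t vec_size = 16 / elem_size;  // elements per 16-byte chunk
  return d % vec_size == 0 &&              // each row splits into whole chunks
         reinterpret_cast<uintptr_t>(in) % 16 == 0 &&   // input 16-byte aligned
         reinterpret_cast<uintptr_t>(out) % 16 == 0;    // output 16-byte aligned
}

For fp16/bf16 inputs (elem_size == 2) this requires d to be a multiple of 8, and for fp32 a multiple of 4; the alignment conditions typically hold for freshly allocated contiguous tensors, so the scalar fallback mainly covers odd hidden sizes and misaligned views.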