
Commit e6762a9

support silu vectorization
1 parent 5a19a6c commit e6762a9

File tree

csrc/activation_kernels.cu
tests/kernels/core/test_activation.py

2 files changed: +70 -16 lines


csrc/activation_kernels.cu

Lines changed: 69 additions & 15 deletions
@@ -17,6 +17,50 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
 }
 // Activation and gating kernel template.

+template <typename T, uint32_t N>
+struct __align__(16) vec_t {
+  T values[N];
+};
+
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
+__global__ void act_and_mul_kernel_vec(
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const int d) {
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
+     (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
+  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
+  const int64_t token_idx = blockIdx.x;
+  scalar_t* __restrict__ out_ptr = out + token_idx * d;
+  vec_t<scalar_t, vec_size>* __restrict__ out_vec_ptr =
+      reinterpret_cast<vec_t<scalar_t, vec_size>*>(out_ptr);
+  vec_t<scalar_t, vec_size> out_vec;
+  const int64_t stride = blockDim.x;
+  const int64_t offset = token_idx * 2 * d;
+#pragma unroll 1
+  for (int64_t idx = threadIdx.x; idx < d / vec_size; idx += stride) {
+    const vec_t<scalar_t, vec_size> x_vec =
+        reinterpret_cast<const vec_t<scalar_t, vec_size>*>(
+            input)[offset / vec_size + idx];
+    const vec_t<scalar_t, vec_size> y_vec =
+        reinterpret_cast<const vec_t<scalar_t, vec_size>*>(
+            input)[(offset + d) / vec_size + idx];
+#pragma unroll
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      out_vec.values[i] = compute<scalar_t, ACT_FN, act_first>(x_vec.values[i],
+                                                               y_vec.values[i]);
+    }
+    out_vec_ptr[idx] = out_vec;
+  }
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
+     (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
 template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
           bool act_first>
 __global__ void act_and_mul_kernel(
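A note on what the new kernel does: each block handles one token (row), and each thread moves 16 bytes per operand per loop iteration, so vec_size is 8 elements for half/bfloat16 and 4 for float; the griddepcontrol asm lines are programmatic-dependent-launch hints that are only compiled in for CUDA 12+ on SM90+ and do not change the math. A minimal standalone sketch of the width rule (hypothetical snippet, not part of the commit):

#include <cstdint>
#include <cstdio>

// Same width rule as the new kernel: a 16-byte vector holds 16 / sizeof(T)
// elements of type T.
template <typename T>
constexpr uint32_t vec_elems = 16 / sizeof(T);

int main() {
  // uint16_t stands in here for half/bfloat16 (2 bytes) -> 8 elements per
  // load; float (4 bytes) -> 4 elements per load.
  std::printf("2-byte dtype: %u elems, 4-byte dtype: %u elems\n",
              vec_elems<uint16_t>, vec_elems<float>);
  return 0;
}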
@@ -65,21 +109,31 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 // Launch activation and gating kernel.
 // Use ACT_FIRST (bool) indicating whether to apply the activation function
 // first.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)                    \
-  int d = input.size(-1) / 2;                                               \
-  int64_t num_tokens = input.numel() / input.size(-1);                      \
-  dim3 grid(num_tokens);                                                    \
-  dim3 block(std::min(d, 1024));                                            \
-  if (num_tokens == 0) {                                                    \
-    return;                                                                 \
-  }                                                                         \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));         \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();             \
-  VLLM_DISPATCH_FLOATING_TYPES(                                             \
-      input.scalar_type(), "act_and_mul_kernel", [&] {                      \
-        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>     \
-            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),          \
-                                         input.data_ptr<scalar_t>(), d);    \
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)                      \
+  int d = input.size(-1) / 2;                                                 \
+  int64_t num_tokens = input.numel() / input.size(-1);                        \
+  if (num_tokens == 0) {                                                      \
+    return;                                                                   \
+  }                                                                           \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));           \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();               \
+  VLLM_DISPATCH_FLOATING_TYPES(                                               \
+      input.scalar_type(), "act_and_mul_kernel", [&] {                        \
+        uint32_t vec_size = 16 / sizeof(scalar_t);                            \
+        dim3 grid(num_tokens);                                                \
+        dim3 block_vec(std::min(d / vec_size, 1024U));                        \
+        dim3 block(std::min(d, 1024));                                        \
+        if (d % vec_size == 0 &&                                              \
+            (reinterpret_cast<uintptr_t>(input.data_ptr()) % 16 == 0) &&      \
+            (reinterpret_cast<uintptr_t>(out.data_ptr()) % 16 == 0)) {        \
+          vllm::act_and_mul_kernel_vec<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
+              <<<grid, block_vec, 0, stream>>>(out.data_ptr<scalar_t>(),      \
+                                               input.data_ptr<scalar_t>(), d); \
+        } else {                                                              \
+          vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>     \
+              <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),          \
+                                           input.data_ptr<scalar_t>(), d);    \
+        }                                                                     \
       });

 void silu_and_mul(torch::Tensor& out,  // [..., d]
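The reworked macro keeps the original scalar kernel as a fallback and only takes the vectorized path when a launch-time check passes. Read as a plain host-side predicate (hypothetical helper, assuming it mirrors the condition in the diff above):

#include <cstdint>

// Hedged sketch of the eligibility check in LAUNCH_ACTIVATION_GATE_KERNEL:
// the vectorized kernel is picked only when the row length is a multiple of
// the vector width and both tensors start on a 16-byte boundary.
bool can_use_vectorized_path(const void* input, const void* out, int d,
                             size_t elem_bytes) {
  const uint32_t vec_size = 16 / elem_bytes;  // elements per 16-byte load
  return d % vec_size == 0 &&
         reinterpret_cast<uintptr_t>(input) % 16 == 0 &&
         reinterpret_cast<uintptr_t>(out) % 16 == 0;
}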

tests/kernels/core/test_activation.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
-D = [512, 13824]  # Arbitrary values for testing
+D = [512, 13824, 16385]  # Arbitrary values for testing
 SEEDS = [0]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
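The extra D value appears to target the fallback branch: 16385 is odd, so it is divisible by neither vector width (8 for half/bfloat16, 4 for float), and those cases run the original act_and_mul_kernel rather than act_and_mul_kernel_vec. A quick check of which test sizes qualify (hypothetical snippet, not part of the test):

#include <cstdint>
#include <cstdio>

// Which of the test's D values can take the 16-byte vectorized path?
// (Assumption: this mirrors the d % vec_size check in the launch macro and
// ignores the pointer-alignment part of the condition.)
int main() {
  const int dims[] = {512, 13824, 16385};
  const uint32_t elem_bytes[] = {2, 4};  // half/bfloat16 and float
  for (uint32_t bytes : elem_bytes) {
    const uint32_t vec_size = 16 / bytes;
    for (int d : dims) {
      std::printf("elem=%uB d=%5d -> %s\n", bytes, d,
                  d % vec_size == 0 ? "vectorized" : "scalar fallback");
    }
  }
  return 0;
}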
