@@ -61,7 +61,6 @@ class FusedMHABlockAttention : public HpuFusedOperator {
std::vector<int64_t> src_dims = std::vector<int64_t>(ins[src_index].dims);

int64_t batch_size = src_dims[0];
- int64_t seq_length = src_dims[1];
int64_t hidden_size = ins[linear_weights_index].dims[0];
int64_t block_size = ins[key_cache_index].dims[1];
int64_t num_of_block = ins[block_list_index].dims[0];
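Note: dropping seq_length here (and in the GQA class below) reflects the decode-only path, where each request contributes exactly one token, so the source activations are 2D {batch_size, hidden_size} rather than 3D {batch_size, 1, hidden_size}. A minimal standalone sketch of that reading of src_dims, with illustrative names only (not the operator's helper API):

#include <cstdint>
#include <vector>

// Illustration only: how the leading dims are read once the singleton
// seq_length axis is dropped from the decode input.
struct DecodeSrcShape {
  int64_t batch_size;
  int64_t hidden_size;
};

inline DecodeSrcShape ReadSrcDims(const std::vector<int64_t>& src_dims) {
  // old 3D layout: {batch_size, seq_length(=1), hidden_size}
  // new 2D layout: {batch_size, hidden_size}
  return {src_dims[0], src_dims[1]};
}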
@@ -93,15 +92,15 @@ class FusedMHABlockAttention : public HpuFusedOperator {
}
auto tmp_dims = src_dims;
auto wt_dims = ins[qkv_weights_index].dims;
- tmp_dims[2] = wt_dims[0];
+ tmp_dims[1] = wt_dims[0];

auto qkv_out = createTensorNoPresist("qkv_out", dtype_, tmp_dims);
std::vector<synTensor> linear_outputs;
linear_outputs.push_back(qkv_out);
AddNodeLinear<T>(linear_inputs, linear_outputs, guid_ + "linear");

auto reshape_dims = src_dims;
- reshape_dims[2] = num_head + 2 * num_kv_head;
+ reshape_dims[1] = num_head + 2 * num_kv_head;
reshape_dims.push_back(head_dim);

std::vector<synTensor> reshape_outputs;
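Because the input lost its middle axis, the slot that receives the fused-QKV width (wt_dims[0]) and the head count moves from index 2 to index 1. A hedged sketch of the same bookkeeping in plain C++ (function names are invented for illustration):

#include <cstdint>
#include <vector>

// Illustration only: derive the QKV projection output dims and the
// per-head reshape dims from a 2D decode input {batch, hidden}.
std::vector<int64_t> QkvOutDims(std::vector<int64_t> src_dims, int64_t qkv_width) {
  src_dims[1] = qkv_width;  // was src_dims[2] when src was {batch, 1, hidden}
  return src_dims;          // -> {batch, qkv_width}
}

std::vector<int64_t> HeadReshapeDims(std::vector<int64_t> src_dims,
                                     int64_t num_head, int64_t num_kv_head,
                                     int64_t head_dim) {
  src_dims[1] = num_head + 2 * num_kv_head;  // Q heads + K heads + V heads
  src_dims.push_back(head_dim);              // -> {batch, total_heads, head_dim}
  return src_dims;
}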
@@ -112,12 +111,10 @@ class FusedMHABlockAttention : public HpuFusedOperator {

std::vector<int64_t> q_dims;
q_dims.push_back(batch_size);
- q_dims.push_back(seq_length);
q_dims.push_back(num_head);
q_dims.push_back(head_dim);
std::vector<int64_t> kv_dims;
kv_dims.push_back(batch_size);
- kv_dims.push_back(seq_length);
kv_dims.push_back(num_kv_head);
kv_dims.push_back(head_dim);

@@ -169,7 +166,7 @@ class FusedMHABlockAttention : public HpuFusedOperator {
sin_squeezed.push_back(sin_sq);

synSqueezeParams squeezeParams;
- squeezeParams.axis = 4;
+ squeezeParams.axis = 3;
AddNodeSqueeze(
    sin_inputs, sin_squeezed, squeezeParams, guid_ + "squeeze_sin");

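With one axis fewer in the surrounding Q/K layout, the rotary sin/cos tensors also lose a rank, so the unit axis to squeeze sits at index 3 instead of 4. A generic, library-agnostic illustration of removing a size-1 axis (this is not the Synapse squeeze call itself, whose axis convention may differ):

#include <cassert>
#include <cstdint>
#include <vector>

// Illustration only: drop a size-1 axis from a dims vector.
std::vector<int64_t> SqueezeAxis(std::vector<int64_t> dims, size_t axis) {
  assert(axis < dims.size() && dims[axis] == 1);
  dims.erase(dims.begin() + axis);
  return dims;
}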
@@ -205,20 +202,6 @@ class FusedMHABlockAttention : public HpuFusedOperator {
AddNodeRope<T>(inputs_k, outputs_k, ropeParams, guid_ + "rope_k");

// ////////////////////////////////////////////////////////////////
- kv_dims.erase(kv_dims.begin() + 1);
-
- std::vector<synTensor> outputs_k_squeeze;
- auto k_squeeze = createTensorNoPresist("k_squeeze", dtype_, kv_dims);
- outputs_k_squeeze.push_back(k_squeeze);
- AddNodeReshape(outputs_k, outputs_k_squeeze, guid_ + "squeeze_k");
-
- std::vector<synTensor> inputs_v_squeeze;
- inputs_v_squeeze.push_back(v_split);
- std::vector<synTensor> outputs_v_squeeze;
- auto v_squeeze = createTensorNoPresist("v_squeeze", dtype_, kv_dims);
- outputs_v_squeeze.push_back(v_squeeze);
- AddNodeReshape(inputs_v_squeeze, outputs_v_squeeze, guid_ + "squeeze_v");
-
std::vector<int64_t> indices_concat_dims =
    std::vector<int64_t>(ins[block_indices_index].dims);
indices_concat_dims.emplace_back(1);
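The deleted block existed only to drop the singleton seq axis (the kv_dims.erase at position 1) before scattering into the cache. With the decode input already 2D, k_rope and v_split leave the RoPE/split stage with the cache-update layout {batch, num_kv_head, head_dim}, so the intermediate "squeeze" tensors are unnecessary and the scatter nodes below consume k_rope and v_split directly. A small sketch of the shape assumption this relies on, with invented names:

#include <cstdint>
#include <vector>

// Illustration only: the per-token K/V update now already matches the slot
// layout expected by the cache scatter, so no squeeze/reshape is needed.
inline bool MatchesCacheSlotLayout(const std::vector<int64_t>& kv_update_dims,
                                   int64_t batch, int64_t num_kv_head,
                                   int64_t head_dim) {
  return kv_update_dims == std::vector<int64_t>{batch, num_kv_head, head_dim};
}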
@@ -256,7 +239,7 @@ class FusedMHABlockAttention : public HpuFusedOperator {
std::vector<synTensor> inputs_scatter_k;
inputs_scatter_k.push_back(key_cache);
inputs_scatter_k.push_back(indices_concat);
- inputs_scatter_k.push_back(k_squeeze);
+ inputs_scatter_k.push_back(k_rope);
std::vector<synTensor> outputs_scatter_k;
outputs_scatter_k.push_back(kCache_out);
AddNodeScatter<T>(
@@ -269,7 +252,7 @@ class FusedMHABlockAttention : public HpuFusedOperator {
std::vector<synTensor> inputs_scatter_v;
inputs_scatter_v.push_back(value_cache);
inputs_scatter_v.push_back(indices_concat);
- inputs_scatter_v.push_back(v_squeeze);
+ inputs_scatter_v.push_back(v_split);
std::vector<synTensor> outputs_scatter_v;
outputs_scatter_v.push_back(vCache_out);
AddNodeScatter<T>(
@@ -702,19 +685,9 @@ class FusedMHABlockAttention : public HpuFusedOperator {
AddNodeGemm(
    map_attn_in, map_attn_out, gemm_params_t_f, guid_ + "gemm_map_attn");

- std::vector<int64_t> reshape_attn_dims;
- reshape_attn_dims.push_back(batch_size);
- reshape_attn_dims.push_back(1);
- reshape_attn_dims.push_back(hidden_size);
- auto attn = createTensorNoPresist("attn", dtype_, reshape_attn_dims);
- std::vector<synTensor> attn_out;
- attn_out.push_back(attn);
-
- AddNodeReshape(map_attn_out, attn_out, guid_ + "attn");
-
std::vector<synTensor> proj_in;
auto linear_weights = createTensorFromCT(&ct, linear_weights_index);
- proj_in.push_back(attn);
+ proj_in.push_back(mapped_attn);
proj_in.push_back(linear_weights);

auto linear_out = createTensorFromCT(&ct, 0, false);
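Likewise on the output side: the mapped-attention GEMM already produces a 2D {batch_size, hidden_size} tensor, so the reshape to {batch_size, 1, hidden_size} is dropped and mapped_attn feeds the output projection directly (the host-side allocation at the end of this diff drops the matching singleton dim). A hedged sketch of the invariant, with invented names:

#include <cstdint>
#include <vector>

// Illustration only: the mapped-attention output is already {batch, hidden},
// which is exactly what the output projection consumes, so the old reshape
// to {batch, 1, hidden} was a no-op that can be removed.
inline bool ReadyForProjection(const std::vector<int64_t>& mapped_attn_dims,
                               int64_t batch_size, int64_t hidden_size) {
  return mapped_attn_dims == std::vector<int64_t>{batch_size, hidden_size};
}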
@@ -756,7 +729,6 @@ class FusedGQABlockAttention : public HpuFusedOperator {
std::vector<int64_t> src_dims = std::vector<int64_t>(ins[src_index].dims);

int64_t batch_size = src_dims[0];
- int64_t seq_length = src_dims[1];
int64_t hidden_size = ins[linear_weights_index].dims[0];
int64_t block_size = ins[key_cache_index].dims[1];
int64_t num_of_block = ins[block_list_index].dims[0];
@@ -789,15 +761,15 @@ class FusedGQABlockAttention : public HpuFusedOperator {
}
auto tmp_dims = src_dims;
auto wt_dims = ins[qkv_weights_index].dims;
- tmp_dims[2] = wt_dims[0];
+ tmp_dims[1] = wt_dims[0];

auto qkv_out = createTensorNoPresist("qkv_out", dtype_, tmp_dims);
std::vector<synTensor> linear_outputs;
linear_outputs.push_back(qkv_out);
AddNodeLinear<T>(linear_inputs, linear_outputs, guid_ + "linear");

auto reshape_dims = src_dims;
- reshape_dims[2] = num_head + 2 * num_kv_head;
+ reshape_dims[1] = num_head + 2 * num_kv_head;
reshape_dims.push_back(head_dim);

std::vector<synTensor> reshape_outputs;
@@ -808,12 +780,10 @@ class FusedGQABlockAttention : public HpuFusedOperator {

std::vector<int64_t> q_dims;
q_dims.push_back(batch_size);
- q_dims.push_back(seq_length);
q_dims.push_back(num_head);
q_dims.push_back(head_dim);
std::vector<int64_t> kv_dims;
kv_dims.push_back(batch_size);
- kv_dims.push_back(seq_length);
kv_dims.push_back(num_kv_head);
kv_dims.push_back(head_dim);

@@ -865,7 +835,7 @@ class FusedGQABlockAttention : public HpuFusedOperator {
sin_squeezed.push_back(sin_sq);

synSqueezeParams squeezeParams;
- squeezeParams.axis = 4;
+ squeezeParams.axis = 3;
AddNodeSqueeze(
    sin_inputs, sin_squeezed, squeezeParams, guid_ + "squeeze_sin");

@@ -901,20 +871,6 @@ class FusedGQABlockAttention : public HpuFusedOperator {
AddNodeRope<T>(inputs_k, outputs_k, ropeParams, guid_ + "rope_k");

// ////////////////////////////////////////////////////////////////
- kv_dims.erase(kv_dims.begin() + 1);
-
- std::vector<synTensor> outputs_k_squeeze;
- auto k_squeeze = createTensorNoPresist("k_squeeze", dtype_, kv_dims);
- outputs_k_squeeze.push_back(k_squeeze);
- AddNodeReshape(outputs_k, outputs_k_squeeze, guid_ + "squeeze_k");
-
- std::vector<synTensor> inputs_v_squeeze;
- inputs_v_squeeze.push_back(v_split);
- std::vector<synTensor> outputs_v_squeeze;
- auto v_squeeze = createTensorNoPresist("v_squeeze", dtype_, kv_dims);
- outputs_v_squeeze.push_back(v_squeeze);
- AddNodeReshape(inputs_v_squeeze, outputs_v_squeeze, guid_ + "squeeze_v");
-
std::vector<int64_t> indices_concat_dims =
    std::vector<int64_t>(ins[block_indices_index].dims);
indices_concat_dims.emplace_back(1);
@@ -952,7 +908,7 @@ class FusedGQABlockAttention : public HpuFusedOperator {
std::vector<synTensor> inputs_scatter_k;
inputs_scatter_k.push_back(key_cache);
inputs_scatter_k.push_back(indices_concat);
- inputs_scatter_k.push_back(k_squeeze);
+ inputs_scatter_k.push_back(k_rope);
std::vector<synTensor> outputs_scatter_k;
outputs_scatter_k.push_back(kCache_out);
AddNodeScatter<T>(
@@ -965,7 +921,7 @@ class FusedGQABlockAttention : public HpuFusedOperator {
std::vector<synTensor> inputs_scatter_v;
inputs_scatter_v.push_back(value_cache);
inputs_scatter_v.push_back(indices_concat);
- inputs_scatter_v.push_back(v_squeeze);
+ inputs_scatter_v.push_back(v_split);
std::vector<synTensor> outputs_scatter_v;
outputs_scatter_v.push_back(vCache_out);
AddNodeScatter<T>(
@@ -1443,19 +1399,9 @@ class FusedGQABlockAttention : public HpuFusedOperator {
AddNodeGemm(
    map_attn_in, map_attn_out, gemm_params_t_f, guid_ + "gemm_map_attn");

- std::vector<int64_t> reshape_attn_dims;
- reshape_attn_dims.push_back(batch_size);
- reshape_attn_dims.push_back(1);
- reshape_attn_dims.push_back(hidden_size);
- auto attn = createTensorNoPresist("attn", dtype_, reshape_attn_dims);
- std::vector<synTensor> attn_out;
- attn_out.push_back(attn);
-
- AddNodeReshape(map_attn_out, attn_out, guid_ + "attn");
-
std::vector<synTensor> proj_in;
auto linear_weights = createTensorFromCT(&ct, linear_weights_index);
- proj_in.push_back(attn);
+ proj_in.push_back(mapped_attn);
proj_in.push_back(linear_weights);

auto linear_out = createTensorFromCT(&ct, 0, false);
@@ -1696,7 +1642,7 @@ std::vector<paddle::Tensor> FusedBlockAttentionForward(

std::shared_ptr<phi::DenseTensor> out_linear =
    std::make_shared<phi::DenseTensor>();
- out_linear->Resize(phi::make_ddim({batch_size, 1, out_features}));
+ out_linear->Resize(phi::make_ddim({batch_size, out_features}));
dev_ctx->Alloc(out_linear.get(), src_tensor->dtype());

CallFusedBlockAttentionKernel(*dev_ctx,