
Commit 352f6d9

[INTEL_HPU] change hidden states to 2D (#1824)
1 parent 9a1b9f2 commit 352f6d9

7 files changed, +85 -108 lines changed
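
This commit switches the hidden states handled by these decode-path fused ops from 3D [batch_size, seq_length, hidden_size] to 2D [batch_size, hidden_size], so the seq_length axis and the squeeze/reshape nodes that existed only to remove it are dropped. The following is a minimal illustrative sketch (not code from this repository; all concrete shape values are hypothetical) of how the QKV reshape target changes from 4D to 3D as a result:

// Illustrative sketch of the shape bookkeeping changed by this commit.
// Names mirror the diff (reshape_dims, num_head, num_kv_head, head_dim),
// but the values below are made up and this is not repository code.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  int64_t batch_size = 4, hidden_size = 4096;
  int64_t num_head = 32, num_kv_head = 32, head_dim = 128;

  // Before: 3D hidden states [batch, seq_len(=1), hidden]; the QKV reshape
  // patched axis 2 and produced a 4D tensor.
  std::vector<int64_t> old_reshape = {batch_size, 1, hidden_size};
  old_reshape[2] = num_head + 2 * num_kv_head;
  old_reshape.push_back(head_dim);  // [batch, 1, heads_total, head_dim]

  // After: 2D hidden states [batch, hidden]; the same logic patches axis 1
  // and produces a 3D tensor, so the later squeeze nodes become unnecessary.
  std::vector<int64_t> new_reshape = {batch_size, hidden_size};
  new_reshape[1] = num_head + 2 * num_kv_head;
  new_reshape.push_back(head_dim);  // [batch, heads_total, head_dim]

  std::cout << "old rank: " << old_reshape.size()
            << ", new rank: " << new_reshape.size() << std::endl;
  return 0;
}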

backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc

Lines changed: 13 additions & 67 deletions
@@ -61,7 +61,6 @@ class FusedMHABlockAttention : public HpuFusedOperator {
     std::vector<int64_t> src_dims = std::vector<int64_t>(ins[src_index].dims);

     int64_t batch_size = src_dims[0];
-    int64_t seq_length = src_dims[1];
     int64_t hidden_size = ins[linear_weights_index].dims[0];
     int64_t block_size = ins[key_cache_index].dims[1];
     int64_t num_of_block = ins[block_list_index].dims[0];
@@ -93,15 +92,15 @@ class FusedMHABlockAttention : public HpuFusedOperator {
     }
     auto tmp_dims = src_dims;
     auto wt_dims = ins[qkv_weights_index].dims;
-    tmp_dims[2] = wt_dims[0];
+    tmp_dims[1] = wt_dims[0];

     auto qkv_out = createTensorNoPresist("qkv_out", dtype_, tmp_dims);
     std::vector<synTensor> linear_outputs;
     linear_outputs.push_back(qkv_out);
     AddNodeLinear<T>(linear_inputs, linear_outputs, guid_ + "linear");

     auto reshape_dims = src_dims;
-    reshape_dims[2] = num_head + 2 * num_kv_head;
+    reshape_dims[1] = num_head + 2 * num_kv_head;
     reshape_dims.push_back(head_dim);

     std::vector<synTensor> reshape_outputs;
@@ -112,12 +111,10 @@ class FusedMHABlockAttention : public HpuFusedOperator {

     std::vector<int64_t> q_dims;
     q_dims.push_back(batch_size);
-    q_dims.push_back(seq_length);
     q_dims.push_back(num_head);
     q_dims.push_back(head_dim);
     std::vector<int64_t> kv_dims;
     kv_dims.push_back(batch_size);
-    kv_dims.push_back(seq_length);
     kv_dims.push_back(num_kv_head);
     kv_dims.push_back(head_dim);

@@ -169,7 +166,7 @@ class FusedMHABlockAttention : public HpuFusedOperator {
     sin_squeezed.push_back(sin_sq);

     synSqueezeParams squeezeParams;
-    squeezeParams.axis = 4;
+    squeezeParams.axis = 3;
     AddNodeSqueeze(
         sin_inputs, sin_squeezed, squeezeParams, guid_ + "squeeze_sin");

@@ -205,20 +202,6 @@ class FusedMHABlockAttention : public HpuFusedOperator {
     AddNodeRope<T>(inputs_k, outputs_k, ropeParams, guid_ + "rope_k");

     //////////////////////////////////////////////////////////////////
-    kv_dims.erase(kv_dims.begin() + 1);
-
-    std::vector<synTensor> outputs_k_squeeze;
-    auto k_squeeze = createTensorNoPresist("k_squeeze", dtype_, kv_dims);
-    outputs_k_squeeze.push_back(k_squeeze);
-    AddNodeReshape(outputs_k, outputs_k_squeeze, guid_ + "squeeze_k");
-
-    std::vector<synTensor> inputs_v_squeeze;
-    inputs_v_squeeze.push_back(v_split);
-    std::vector<synTensor> outputs_v_squeeze;
-    auto v_squeeze = createTensorNoPresist("v_squeeze", dtype_, kv_dims);
-    outputs_v_squeeze.push_back(v_squeeze);
-    AddNodeReshape(inputs_v_squeeze, outputs_v_squeeze, guid_ + "squeeze_v");
-
     std::vector<int64_t> indices_concat_dims =
         std::vector<int64_t>(ins[block_indices_index].dims);
     indices_concat_dims.emplace_back(1);
@@ -256,7 +239,7 @@ class FusedMHABlockAttention : public HpuFusedOperator {
     std::vector<synTensor> inputs_scatter_k;
     inputs_scatter_k.push_back(key_cache);
     inputs_scatter_k.push_back(indices_concat);
-    inputs_scatter_k.push_back(k_squeeze);
+    inputs_scatter_k.push_back(k_rope);
     std::vector<synTensor> outputs_scatter_k;
     outputs_scatter_k.push_back(kCache_out);
     AddNodeScatter<T>(
@@ -269,7 +252,7 @@ class FusedMHABlockAttention : public HpuFusedOperator {
     std::vector<synTensor> inputs_scatter_v;
     inputs_scatter_v.push_back(value_cache);
     inputs_scatter_v.push_back(indices_concat);
-    inputs_scatter_v.push_back(v_squeeze);
+    inputs_scatter_v.push_back(v_split);
     std::vector<synTensor> outputs_scatter_v;
     outputs_scatter_v.push_back(vCache_out);
     AddNodeScatter<T>(
@@ -702,19 +685,9 @@ class FusedMHABlockAttention : public HpuFusedOperator {
     AddNodeGemm(
         map_attn_in, map_attn_out, gemm_params_t_f, guid_ + "gemm_map_attn");

-    std::vector<int64_t> reshape_attn_dims;
-    reshape_attn_dims.push_back(batch_size);
-    reshape_attn_dims.push_back(1);
-    reshape_attn_dims.push_back(hidden_size);
-    auto attn = createTensorNoPresist("attn", dtype_, reshape_attn_dims);
-    std::vector<synTensor> attn_out;
-    attn_out.push_back(attn);
-
-    AddNodeReshape(map_attn_out, attn_out, guid_ + "attn");
-
     std::vector<synTensor> proj_in;
     auto linear_weights = createTensorFromCT(&ct, linear_weights_index);
-    proj_in.push_back(attn);
+    proj_in.push_back(mapped_attn);
     proj_in.push_back(linear_weights);

     auto linear_out = createTensorFromCT(&ct, 0, false);
@@ -756,7 +729,6 @@ class FusedGQABlockAttention : public HpuFusedOperator {
     std::vector<int64_t> src_dims = std::vector<int64_t>(ins[src_index].dims);

     int64_t batch_size = src_dims[0];
-    int64_t seq_length = src_dims[1];
     int64_t hidden_size = ins[linear_weights_index].dims[0];
     int64_t block_size = ins[key_cache_index].dims[1];
     int64_t num_of_block = ins[block_list_index].dims[0];
@@ -789,15 +761,15 @@ class FusedGQABlockAttention : public HpuFusedOperator {
     }
     auto tmp_dims = src_dims;
     auto wt_dims = ins[qkv_weights_index].dims;
-    tmp_dims[2] = wt_dims[0];
+    tmp_dims[1] = wt_dims[0];

     auto qkv_out = createTensorNoPresist("qkv_out", dtype_, tmp_dims);
     std::vector<synTensor> linear_outputs;
     linear_outputs.push_back(qkv_out);
     AddNodeLinear<T>(linear_inputs, linear_outputs, guid_ + "linear");

     auto reshape_dims = src_dims;
-    reshape_dims[2] = num_head + 2 * num_kv_head;
+    reshape_dims[1] = num_head + 2 * num_kv_head;
     reshape_dims.push_back(head_dim);

     std::vector<synTensor> reshape_outputs;
@@ -808,12 +780,10 @@ class FusedGQABlockAttention : public HpuFusedOperator {

     std::vector<int64_t> q_dims;
     q_dims.push_back(batch_size);
-    q_dims.push_back(seq_length);
     q_dims.push_back(num_head);
     q_dims.push_back(head_dim);
     std::vector<int64_t> kv_dims;
     kv_dims.push_back(batch_size);
-    kv_dims.push_back(seq_length);
     kv_dims.push_back(num_kv_head);
     kv_dims.push_back(head_dim);

@@ -865,7 +835,7 @@ class FusedGQABlockAttention : public HpuFusedOperator {
     sin_squeezed.push_back(sin_sq);

     synSqueezeParams squeezeParams;
-    squeezeParams.axis = 4;
+    squeezeParams.axis = 3;
     AddNodeSqueeze(
         sin_inputs, sin_squeezed, squeezeParams, guid_ + "squeeze_sin");

@@ -901,20 +871,6 @@ class FusedGQABlockAttention : public HpuFusedOperator {
     AddNodeRope<T>(inputs_k, outputs_k, ropeParams, guid_ + "rope_k");

     //////////////////////////////////////////////////////////////////
-    kv_dims.erase(kv_dims.begin() + 1);
-
-    std::vector<synTensor> outputs_k_squeeze;
-    auto k_squeeze = createTensorNoPresist("k_squeeze", dtype_, kv_dims);
-    outputs_k_squeeze.push_back(k_squeeze);
-    AddNodeReshape(outputs_k, outputs_k_squeeze, guid_ + "squeeze_k");
-
-    std::vector<synTensor> inputs_v_squeeze;
-    inputs_v_squeeze.push_back(v_split);
-    std::vector<synTensor> outputs_v_squeeze;
-    auto v_squeeze = createTensorNoPresist("v_squeeze", dtype_, kv_dims);
-    outputs_v_squeeze.push_back(v_squeeze);
-    AddNodeReshape(inputs_v_squeeze, outputs_v_squeeze, guid_ + "squeeze_v");
-
     std::vector<int64_t> indices_concat_dims =
         std::vector<int64_t>(ins[block_indices_index].dims);
     indices_concat_dims.emplace_back(1);
@@ -952,7 +908,7 @@ class FusedGQABlockAttention : public HpuFusedOperator {
     std::vector<synTensor> inputs_scatter_k;
     inputs_scatter_k.push_back(key_cache);
     inputs_scatter_k.push_back(indices_concat);
-    inputs_scatter_k.push_back(k_squeeze);
+    inputs_scatter_k.push_back(k_rope);
     std::vector<synTensor> outputs_scatter_k;
     outputs_scatter_k.push_back(kCache_out);
     AddNodeScatter<T>(
@@ -965,7 +921,7 @@ class FusedGQABlockAttention : public HpuFusedOperator {
     std::vector<synTensor> inputs_scatter_v;
     inputs_scatter_v.push_back(value_cache);
     inputs_scatter_v.push_back(indices_concat);
-    inputs_scatter_v.push_back(v_squeeze);
+    inputs_scatter_v.push_back(v_split);
     std::vector<synTensor> outputs_scatter_v;
     outputs_scatter_v.push_back(vCache_out);
     AddNodeScatter<T>(
@@ -1443,19 +1399,9 @@ class FusedGQABlockAttention : public HpuFusedOperator {
     AddNodeGemm(
         map_attn_in, map_attn_out, gemm_params_t_f, guid_ + "gemm_map_attn");

-    std::vector<int64_t> reshape_attn_dims;
-    reshape_attn_dims.push_back(batch_size);
-    reshape_attn_dims.push_back(1);
-    reshape_attn_dims.push_back(hidden_size);
-    auto attn = createTensorNoPresist("attn", dtype_, reshape_attn_dims);
-    std::vector<synTensor> attn_out;
-    attn_out.push_back(attn);
-
-    AddNodeReshape(map_attn_out, attn_out, guid_ + "attn");
-
     std::vector<synTensor> proj_in;
     auto linear_weights = createTensorFromCT(&ct, linear_weights_index);
-    proj_in.push_back(attn);
+    proj_in.push_back(mapped_attn);
     proj_in.push_back(linear_weights);

     auto linear_out = createTensorFromCT(&ct, 0, false);
@@ -1696,7 +1642,7 @@ std::vector<paddle::Tensor> FusedBlockAttentionForward(

   std::shared_ptr<phi::DenseTensor> out_linear =
       std::make_shared<phi::DenseTensor>();
-  out_linear->Resize(phi::make_ddim({batch_size, 1, out_features}));
+  out_linear->Resize(phi::make_ddim({batch_size, out_features}));
   dev_ctx->Alloc(out_linear.get(), src_tensor->dtype());

   CallFusedBlockAttentionKernel(*dev_ctx,
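
With the seq_length axis gone, the q/k/v tensors built by these ops appear to already have the rank the KV-cache scatter expects, which is why the squeeze_k/squeeze_v reshape nodes are deleted and k_rope/v_split feed the scatter directly, and why the attention output no longer needs a reshape to [batch, 1, hidden] before the output projection. Below is a rough standalone sketch, not repository code, of the new dim construction under the assumption that this decode path processes one token per sequence (all shape values hypothetical):

// Hedged sketch of the q/kv dim construction implied by the diff above.
// BuildHeadDims is a hypothetical helper, not a function from this repo.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> BuildHeadDims(int64_t batch_size, int64_t heads,
                                   int64_t head_dim) {
  std::vector<int64_t> dims;
  dims.push_back(batch_size);
  dims.push_back(heads);     // previously seq_length sat between these two
  dims.push_back(head_dim);
  return dims;               // e.g. [batch, num_kv_head, head_dim]
}

int main() {
  const int64_t batch_size = 4, num_head = 32, num_kv_head = 8, head_dim = 128;
  auto q_dims = BuildHeadDims(batch_size, num_head, head_dim);
  auto kv_dims = BuildHeadDims(batch_size, num_kv_head, head_dim);
  // kv_dims already matches what the cache scatter consumes, so no
  // kv_dims.erase(...) / squeeze reshape step is needed anymore.
  std::cout << "q rank: " << q_dims.size()
            << ", kv rank: " << kv_dims.size() << std::endl;
  return 0;
}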

backends/intel_hpu/custom_ops/llama_infer/fused_mlp.cc

Lines changed: 5 additions & 4 deletions
@@ -191,13 +191,14 @@ class FusedGateUpMlp : public HpuOperator {
         ins[0].dims.size(), dtype_, ins[0].dims, true, ins[0].name);
     synTensor proj_weight = createTensor(
         ins[1].dims.size(), dtype_, ins[1].dims, true, ins[1].name);
-    std::vector<int64_t> proj_dims = {
-        ins[0].dims[0], ins[0].dims[1], ins[1].dims[1]};
+    std::vector<int64_t> proj_dims = ins[0].dims;
+    proj_dims[ins[0].dims.size() - 1] = ins[1].dims[1];
     synTensor proj_out =
         createTensor(proj_dims.size(), dtype_, proj_dims, false, "proj_out");

-    std::vector<int64_t> split_out_dims = {
-        proj_dims[0], proj_dims[1], proj_dims[2] / 2};
+    std::vector<int64_t> split_out_dims = proj_dims;
+    split_out_dims[proj_dims.size() - 1] = proj_dims[proj_dims.size() - 1] / 2;
+
     synTensor gate_out = createTensor(
         split_out_dims.size(), dtype_, split_out_dims, false, "gate_out");
     synTensor up_out = createTensor(
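
The fused_mlp change makes the projection/split shape computation rank-agnostic: rather than hard-coding a 3D shape from ins[0].dims[0], ins[0].dims[1], and ins[1].dims[1], it copies the input dims and overwrites only the last axis, so the op can accept the new 2D hidden states as well as 3D ones. A small standalone sketch of that computation follows (not repository code; shape values are hypothetical):

// Standalone sketch of the rank-agnostic dim computation used above.
// ins0/ins1 stand in for the activation and gate_up weight dims.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> ins0 = {16, 4096};        // 2D activation [tokens, hidden]
  std::vector<int64_t> ins1 = {4096, 2 * 11008}; // gate_up weight [hidden, 2*ffn]

  std::vector<int64_t> proj_dims = ins0;
  proj_dims[ins0.size() - 1] = ins1[1];          // last axis -> 2*ffn

  std::vector<int64_t> split_out_dims = proj_dims;
  split_out_dims[proj_dims.size() - 1] = proj_dims[proj_dims.size() - 1] / 2;

  std::cout << "proj last dim: " << proj_dims.back()
            << ", split last dim: " << split_out_dims.back() << std::endl;
  return 0;
}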
