Skip to content

Commit de0554d

Browse files
Arm backend: Add bitwise scalar ops (#12857)
Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent 45846c8 commit de0554d

File tree

5 files changed

+267
-0
lines changed

5 files changed

+267
-0
lines changed

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def __init__(self, exported_program):
5454
exir_ops.edge.aten.le.Tensor,
5555
exir_ops.edge.aten.pow.Tensor_Tensor,
5656
exir_ops.edge.aten.where.self,
57+
exir_ops.edge.aten.bitwise_and.Tensor,
58+
exir_ops.edge.aten.bitwise_xor.Tensor,
59+
exir_ops.edge.aten.bitwise_or.Tensor,
5760
]
5861

5962
def _match_op_rank(self, graph_module, node, arg, max_rank):

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
3535
exir_ops.edge.aten.le.Scalar: exir_ops.edge.aten.le.Tensor,
3636
exir_ops.edge.aten.ne.Scalar: exir_ops.edge.aten.ne.Tensor,
37+
exir_ops.edge.aten.bitwise_and.Scalar: exir_ops.edge.aten.bitwise_and.Tensor,
38+
exir_ops.edge.aten.bitwise_or.Scalar: exir_ops.edge.aten.bitwise_or.Tensor,
39+
exir_ops.edge.aten.bitwise_xor.Scalar: exir_ops.edge.aten.bitwise_xor.Tensor,
3740
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3841
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
3942
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
@@ -46,6 +49,9 @@
4649
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
4750
torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor,
4851
torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
52+
torch.ops.aten.bitwise_and.Scalar: torch.ops.aten.bitwise_and.Tensor,
53+
torch.ops.aten.bitwise_or.Scalar: torch.ops.aten.bitwise_or.Tensor,
54+
torch.ops.aten.bitwise_xor.Scalar: torch.ops.aten.bitwise_xor.Tensor,
4955
}
5056

5157

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ class EthosU55NotSupported(OperatorSupportBase):
125125
exir_ops.edge.aten.bitwise_and.Tensor,
126126
exir_ops.edge.aten.bitwise_or.Tensor,
127127
exir_ops.edge.aten.bitwise_xor.Tensor,
128+
exir_ops.edge.aten.bitwise_and.Scalar,
129+
exir_ops.edge.aten.bitwise_or.Scalar,
130+
exir_ops.edge.aten.bitwise_xor.Scalar,
128131
exir_ops.edge.aten.bitwise_not,
129132
exir_ops.edge.aten.logical_and.default,
130133
exir_ops.edge.aten.logical_or.default,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ def is_node_supported(
164164
exir_ops.edge.aten.bitwise_and.Tensor,
165165
exir_ops.edge.aten.bitwise_or.Tensor,
166166
exir_ops.edge.aten.bitwise_xor.Tensor,
167+
exir_ops.edge.aten.bitwise_and.Scalar,
168+
exir_ops.edge.aten.bitwise_or.Scalar,
169+
exir_ops.edge.aten.bitwise_xor.Scalar,
167170
exir_ops.edge.aten.expand_copy.default,
168171
exir_ops.edge.aten.cat.default,
169172
exir_ops.edge.aten.ceil.default,

backends/arm/test/ops/test_bitwise.py

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,27 @@ class BitwiseBinary(torch.nn.Module):
5656
}
5757

5858

59+
class BitwiseBinaryScalar(torch.nn.Module):
60+
test_data = {
61+
"zeros": lambda: (torch.zeros(1, 10, 10, 10, dtype=torch.int32), 0),
62+
"ones_int8": lambda: (torch.ones(10, 10, 10, dtype=torch.int8), 1),
63+
"pattern_int8": lambda: (0xAA * torch.ones(1, 2, 2, 2, dtype=torch.int8), 0x77),
64+
"pattern_int16": lambda: (
65+
0xAAAA * torch.ones(1, 2, 2, 2, dtype=torch.int16),
66+
0x7777,
67+
),
68+
"pattern_int32": lambda: (
69+
0xAAAAAAAA * torch.ones(1, 2, 2, 2, dtype=torch.int32),
70+
0x77777777,
71+
),
72+
"rand_rank2": lambda: (torch.randint(-128, 127, (10, 10), dtype=torch.int8), 5),
73+
"rand_rank4": lambda: (
74+
torch.randint(-128, 127, (1, 10, 10, 10), dtype=torch.int8),
75+
-7,
76+
),
77+
}
78+
79+
5980
class And(BitwiseBinary):
6081
aten_op = "torch.ops.aten.bitwise_and.Tensor"
6182
exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor"
@@ -80,6 +101,36 @@ def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
80101
return tensor1.bitwise_or(tensor2)
81102

82103

104+
class AndScalar(BitwiseBinaryScalar):
105+
aten_op = "torch.ops.aten.bitwise_and.Scalar"
106+
# Tensor because it gets converted from Scalar -> Tensor in lowering
107+
exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor"
108+
109+
def forward(self, tensor: torch.Tensor, scalar: int):
110+
return tensor.bitwise_and(scalar)
111+
112+
113+
class XorScalar(BitwiseBinaryScalar):
114+
aten_op = "torch.ops.aten.bitwise_xor.Scalar"
115+
# Tensor because it gets converted from Scalar -> Tensor in lowering
116+
exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_xor_Tensor"
117+
118+
def forward(self, tensor: torch.Tensor, scalar: int):
119+
return tensor.bitwise_xor(scalar)
120+
121+
122+
class OrScalar(BitwiseBinaryScalar):
123+
aten_op = "torch.ops.aten.bitwise_or.Scalar"
124+
# Tensor because it gets converted from Scalar -> Tensor in lowering
125+
exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_or_Tensor"
126+
127+
def forward(self, tensor: torch.Tensor, scalar: int):
128+
return tensor.bitwise_or(scalar)
129+
130+
131+
# Bitwise AND
132+
133+
83134
@common.parametrize("test_data", And().test_data)
84135
def test_bitwise_and_tensor_tosa_MI(test_data: input_t2):
85136
pipeline = TosaPipelineMI[input_t2](
@@ -94,6 +145,20 @@ def test_bitwise_and_tensor_tosa_MI(test_data: input_t2):
94145
pipeline.run()
95146

96147

148+
@common.parametrize("test_data", AndScalar.test_data)
149+
def test_bitwise_and_scalar_tosa_MI(test_data: input_t2):
150+
pipeline = TosaPipelineMI[input_t2](
151+
AndScalar(),
152+
test_data(),
153+
AndScalar.aten_op,
154+
AndScalar.exir_op,
155+
atol=0,
156+
rtol=0,
157+
qtol=0,
158+
)
159+
pipeline.run()
160+
161+
97162
@common.parametrize("test_data", And().test_data)
98163
def test_bitwise_and_tensor_tosa_BI(test_data: input_t2):
99164
pipeline = TosaPipelineBI[input_t2](
@@ -110,6 +175,22 @@ def test_bitwise_and_tensor_tosa_BI(test_data: input_t2):
110175
pipeline.run()
111176

112177

178+
@common.parametrize("test_data", AndScalar.test_data)
179+
def test_bitwise_and_scalar_tosa_BI(test_data: input_t2):
180+
pipeline = TosaPipelineBI[input_t2](
181+
AndScalar(),
182+
test_data(),
183+
AndScalar.aten_op,
184+
AndScalar.exir_op,
185+
atol=0,
186+
rtol=0,
187+
qtol=0,
188+
)
189+
pipeline.pop_stage("quantize")
190+
pipeline.pop_stage("check.quant_nodes")
191+
pipeline.run()
192+
193+
113194
@common.parametrize("test_data", And().test_data)
114195
def test_bitwise_and_tensor_u55_BI(test_data: input_t2):
115196
# Tests that we don't delegate these ops since they are not supported on U55.
@@ -123,6 +204,43 @@ def test_bitwise_and_tensor_u55_BI(test_data: input_t2):
123204
pipeline.run()
124205

125206

207+
@common.parametrize("test_data", AndScalar.test_data)
208+
def test_bitwise_and_scalar_u55_BI(test_data: input_t2):
209+
# There will be one full op which will be delegated.
210+
num_delegates = 1
211+
num_exir = 0
212+
pipeline = OpNotSupportedPipeline[input_t2](
213+
AndScalar(),
214+
test_data(),
215+
{
216+
AndScalar.exir_op: 1,
217+
"executorch_exir_dialects_edge__ops_aten_full_default": num_exir,
218+
},
219+
num_delegates,
220+
quantize=True,
221+
u55_subset=True,
222+
)
223+
pipeline.run()
224+
225+
226+
@common.parametrize("test_data", AndScalar.test_data)
227+
@common.XfailIfNoCorstone320
228+
def test_bitwise_and_scalar_u85_BI(test_data: input_t2):
229+
pipeline = EthosU85PipelineBI[input_t2](
230+
AndScalar(),
231+
test_data(),
232+
AndScalar.aten_op,
233+
AndScalar.exir_op,
234+
run_on_fvp=True,
235+
atol=0,
236+
rtol=0,
237+
qtol=0,
238+
)
239+
pipeline.pop_stage("quantize")
240+
pipeline.pop_stage("check.quant_nodes")
241+
pipeline.run()
242+
243+
126244
@common.parametrize("test_data", And().test_data)
127245
@common.XfailIfNoCorstone320
128246
def test_bitwise_and_tensor_u85_BI(test_data: input_t2):
@@ -155,6 +273,20 @@ def test_bitwise_xor_tensor_tosa_MI(test_data: input_t2):
155273
pipeline.run()
156274

157275

276+
@common.parametrize("test_data", XorScalar.test_data)
277+
def test_bitwise_xor_scalar_tosa_MI(test_data: input_t2):
278+
pipeline = TosaPipelineMI[input_t2](
279+
XorScalar(),
280+
test_data(),
281+
XorScalar.aten_op,
282+
XorScalar.exir_op,
283+
atol=0,
284+
rtol=0,
285+
qtol=0,
286+
)
287+
pipeline.run()
288+
289+
158290
@common.parametrize("test_data", Xor().test_data)
159291
def test_bitwise_xor_tensor_tosa_BI(test_data: input_t2):
160292
pipeline = TosaPipelineBI[input_t2](
@@ -171,6 +303,22 @@ def test_bitwise_xor_tensor_tosa_BI(test_data: input_t2):
171303
pipeline.run()
172304

173305

306+
@common.parametrize("test_data", XorScalar.test_data)
307+
def test_bitwise_xor_scalar_tosa_BI(test_data: input_t2):
308+
pipeline = TosaPipelineBI[input_t2](
309+
XorScalar(),
310+
test_data(),
311+
XorScalar.aten_op,
312+
XorScalar.exir_op,
313+
atol=0,
314+
rtol=0,
315+
qtol=0,
316+
)
317+
pipeline.pop_stage("quantize")
318+
pipeline.pop_stage("check.quant_nodes")
319+
pipeline.run()
320+
321+
174322
@common.parametrize("test_data", Xor().test_data)
175323
def test_bitwise_xor_tensor_u55_BI(test_data: input_t2):
176324
# Tests that we don't delegate these ops since they are not supported on U55.
@@ -184,6 +332,25 @@ def test_bitwise_xor_tensor_u55_BI(test_data: input_t2):
184332
pipeline.run()
185333

186334

335+
@common.parametrize("test_data", XorScalar.test_data)
336+
def test_bitwise_xor_scalar_u55_BI(test_data: input_t2):
337+
# There will be one full op which will be delegated.
338+
num_delegates = 1
339+
num_exir = 0
340+
pipeline = OpNotSupportedPipeline[input_t2](
341+
XorScalar(),
342+
test_data(),
343+
{
344+
XorScalar.exir_op: 1,
345+
"executorch_exir_dialects_edge__ops_aten_full_default": num_exir,
346+
},
347+
num_delegates,
348+
quantize=True,
349+
u55_subset=True,
350+
)
351+
pipeline.run()
352+
353+
187354
@common.parametrize("test_data", Xor().test_data)
188355
@common.XfailIfNoCorstone320
189356
def test_bitwise_xor_tensor_u85_BI(test_data: input_t2):
@@ -202,6 +369,24 @@ def test_bitwise_xor_tensor_u85_BI(test_data: input_t2):
202369
pipeline.run()
203370

204371

372+
@common.parametrize("test_data", XorScalar.test_data)
373+
@common.XfailIfNoCorstone320
374+
def test_bitwise_xor_scalar_u85_BI(test_data: input_t2):
375+
pipeline = EthosU85PipelineBI[input_t2](
376+
XorScalar(),
377+
test_data(),
378+
XorScalar.aten_op,
379+
XorScalar.exir_op,
380+
run_on_fvp=True,
381+
atol=0,
382+
rtol=0,
383+
qtol=0,
384+
)
385+
pipeline.pop_stage("quantize")
386+
pipeline.pop_stage("check.quant_nodes")
387+
pipeline.run()
388+
389+
205390
@common.parametrize("test_data", Or().test_data)
206391
def test_bitwise_or_tensor_tosa_MI(test_data: input_t2):
207392
pipeline = TosaPipelineMI[input_t2](
@@ -216,6 +401,20 @@ def test_bitwise_or_tensor_tosa_MI(test_data: input_t2):
216401
pipeline.run()
217402

218403

404+
@common.parametrize("test_data", OrScalar.test_data)
405+
def test_bitwise_or_scalar_tosa_MI(test_data: input_t2):
406+
pipeline = TosaPipelineMI[input_t2](
407+
OrScalar(),
408+
test_data(),
409+
OrScalar.aten_op,
410+
OrScalar.exir_op,
411+
atol=0,
412+
rtol=0,
413+
qtol=0,
414+
)
415+
pipeline.run()
416+
417+
219418
@common.parametrize("test_data", Or().test_data)
220419
def test_bitwise_or_tensor_tosa_BI(test_data: input_t2):
221420
pipeline = TosaPipelineBI[input_t2](
@@ -232,6 +431,22 @@ def test_bitwise_or_tensor_tosa_BI(test_data: input_t2):
232431
pipeline.run()
233432

234433

434+
@common.parametrize("test_data", OrScalar.test_data)
435+
def test_bitwise_or_scalar_tosa_BI(test_data: input_t2):
436+
pipeline = TosaPipelineBI[input_t2](
437+
OrScalar(),
438+
test_data(),
439+
OrScalar.aten_op,
440+
OrScalar.exir_op,
441+
atol=0,
442+
rtol=0,
443+
qtol=0,
444+
)
445+
pipeline.pop_stage("quantize")
446+
pipeline.pop_stage("check.quant_nodes")
447+
pipeline.run()
448+
449+
235450
@common.parametrize("test_data", Or().test_data)
236451
def test_bitwise_or_tensor_u55_BI(test_data: input_t2):
237452
# Tests that we don't delegate these ops since they are not supported on U55.
@@ -245,6 +460,25 @@ def test_bitwise_or_tensor_u55_BI(test_data: input_t2):
245460
pipeline.run()
246461

247462

463+
@common.parametrize("test_data", OrScalar.test_data)
464+
def test_bitwise_or_scalar_u55_BI(test_data: input_t2):
465+
# There will be one full op which will be delegated.
466+
num_delegates = 1
467+
num_exir = 0
468+
pipeline = OpNotSupportedPipeline[input_t2](
469+
OrScalar(),
470+
test_data(),
471+
{
472+
OrScalar.exir_op: 1,
473+
"executorch_exir_dialects_edge__ops_aten_full_default": num_exir,
474+
},
475+
num_delegates,
476+
quantize=True,
477+
u55_subset=True,
478+
)
479+
pipeline.run()
480+
481+
248482
@common.parametrize("test_data", Or().test_data)
249483
@common.XfailIfNoCorstone320
250484
def test_bitwise_or_tensor_u85_BI(test_data: input_t2):
@@ -261,3 +495,21 @@ def test_bitwise_or_tensor_u85_BI(test_data: input_t2):
261495
pipeline.pop_stage("quantize")
262496
pipeline.pop_stage("check.quant_nodes")
263497
pipeline.run()
498+
499+
500+
@common.parametrize("test_data", OrScalar.test_data)
501+
@common.XfailIfNoCorstone320
502+
def test_bitwise_or_scalar_u85_BI(test_data: input_t2):
503+
pipeline = EthosU85PipelineBI[input_t2](
504+
OrScalar(),
505+
test_data(),
506+
OrScalar.aten_op,
507+
OrScalar.exir_op,
508+
run_on_fvp=True,
509+
atol=0,
510+
rtol=0,
511+
qtol=0,
512+
)
513+
pipeline.pop_stage("quantize")
514+
pipeline.pop_stage("check.quant_nodes")
515+
pipeline.run()

0 commit comments

Comments
 (0)