@@ -54,8 +54,6 @@ def build(self, x):
         # repeat 8 times
         y0, w0 = ops.repeat(add_weight_graph0, 8, x0, inputs_dict={add_weight0.w: w0})
 
-    See also `PyTorch Tensor.repeat <https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html>`__, `NumPy repeat <https://numpy.org/doc/stable/reference/generated/numpy.repeat.html>`__.
-
     Args:
         graph (Graph): User defined graph to repeat `repeat_count` times.
         repeat_count (int): Number of times to repeat calling the graph.
@@ -74,7 +72,10 @@ def build(self, x):
         Tuple[Tensor, ...]:
             Tuple of the output tensors of the call in the parent graph.
     """
-    loop_info = repeat_with_info(graph, repeat_count, *inputs, inputs_dict=inputs_dict)
+    loop_info = repeat_with_info(graph,
+                                 repeat_count,
+                                 *inputs,
+                                 inputs_dict=inputs_dict)
 
     out_tensors = loop_info.outputs
     return out_tensors
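
The docstring example above is the public entry point this function implements. For orientation, here is a minimal end-to-end sketch of that usage, assuming the `Module`, `graph_input`, `ir.create_graph` and `ops.repeat` names from the surrounding codebase (illustrative only, not part of this diff):

    class AddWeight(Module):
        def __init__(self):
            self.w: Tensor = None

        def build(self, x):
            # Create a subgraph input for the weight so it can be supplied
            # through inputs_dict at the repeat call site.
            self.w = graph_input(x.shape, x.dtype, "w")
            return self.w + x, self.w

    add_weight0 = AddWeight()
    add_weight_graph0 = ir.create_graph(add_weight0, x0)
    # repeat 8 times, threading w0 through the iterations via inputs_dict
    y0, w0 = ops.repeat(add_weight_graph0, 8, x0, inputs_dict={add_weight0.w: w0})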
@@ -193,8 +194,7 @@ def build(self, x):
     if total_inputs < total_outputs:
         raise ValueError(
             f"To repeat the subgraph ({graph.id}) the number of inputs must be greater than or equal to the number of outputs."
-            f" {total_inputs} < {total_outputs}"
-        )
+            f" {total_inputs} < {total_outputs}")
 
     # For clarity, we rename our graphs:
     # - Bottom: The user provided bottom level graph. We call this with a call op. This has gone
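
The check above exists because repeat outputs are loop-carried: output i of one iteration becomes input i of the next, so every output needs a matching input slot. A behavioural sketch of that contract in plain Python (`graph_fn` is a hypothetical stand-in for the repeated subgraph):

    def repeat_semantics(graph_fn, repeat_count, *inputs):
        # Each iteration's outputs overwrite the leading inputs of the next
        # iteration, which is only well-defined when
        # len(outputs) <= len(inputs).
        state = list(inputs)
        outputs = tuple(state)
        for _ in range(repeat_count):
            outputs = graph_fn(*state)
            state[:len(outputs)] = outputs
        return outputs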
@@ -215,16 +215,14 @@ def build(self, x):
 
     # Create the middle graph, call and loop ops
     pb_middle_graph, pb_callop, pb_loop_op = _setup_call_and_repeat(
-        pb_ir, pb_top_graph, pb_bottom_graph
-    )
+        pb_ir, pb_top_graph, pb_bottom_graph)
 
     # set the number of times to loop
     pb_loop_op.setTripCountValue(repeat_count)
 
     # Prep and validate inputs
-    inputs_all = _prep_and_validate_inputs(
-        check_inputs, top_graph, graph, "called", inputs, inputs_dict
-    )
+    inputs_all = _prep_and_validate_inputs(check_inputs, top_graph, graph,
+                                           "called", inputs, inputs_dict)
 
     # 1, 2. Connect inputs.
     _setup_inputs(
@@ -236,9 +234,8 @@ def build(self, x):
     )
 
     # 3. Connect outputs.
-    _ = _setup_outputs(
-        pb_top_graph, pb_bottom_graph, pb_middle_graph, pb_callop, pb_loop_op
-    )
+    _ = _setup_outputs(pb_top_graph, pb_bottom_graph, pb_middle_graph,
+                       pb_callop, pb_loop_op)
 
     pb_callop.setup()
     pb_loop_op.setup()
@@ -250,13 +247,14 @@ def build(self, x):
     loop_carried_inputs = pb_loop_op.getNumExplicitInputs()
     for bottom_t in bottom_graph._by_ref_inputs:
         middle_t = c_info.graph_to_parent(bottom_t)
-        loop_carried = pb_middle_graph.getInputIndex(middle_t.id) < loop_carried_inputs
+        loop_carried = pb_middle_graph.getInputIndex(
+            middle_t.id) < loop_carried_inputs
         # If a tensor was set as a by_ref_input, we should also do the same for the looped subgraph.
         c_info.set_parent_input_modified(
-            middle_t, infer_modified_regions=not loop_carried
-        )
+            middle_t, infer_modified_regions=not loop_carried)
         top_t = r_info.graph_to_parent(middle_t)
-        r_info.set_parent_input_modified(top_t, infer_modified_regions=not loop_carried)
+        r_info.set_parent_input_modified(
+            top_t, infer_modified_regions=not loop_carried)
         r_info.called_graph._by_ref_inputs.add(middle_t)
 
     return r_info
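
The `loop_carried` test above relies on an ordering assumption: the loop's explicit (loop-carried) inputs occupy the leading input indices of the middle graph, and implicit by-reference inputs follow them. Distilled into a standalone predicate, purely for illustration:

    def is_loop_carried(input_index: int, num_explicit_inputs: int) -> bool:
        # Mirrors `pb_middle_graph.getInputIndex(middle_t.id) < loop_carried_inputs`
        # above: inputs before the explicit-input count are loop-carried;
        # later ones are implicit, so their modified regions are inferred.
        return input_index < num_explicit_inputs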
@@ -280,34 +278,35 @@ def _setup_call_and_repeat(
     # This is the graph we will repeat.
     pb_middle_graph = pb_ir.createGraph(
         _ir.GraphId(
-            pb_ir.createUniqueSubgraphId(f"{pb_bottom_graph.id.str()}__loop_wrapper")
-        )
-    )
+            pb_ir.createUniqueSubgraphId(
+                f"{pb_bottom_graph.id.str()}__loop_wrapper")))
 
-    opid = _ir.OperatorIdentifier("ai.graphcore", "Call", 1, _ir.NumInputs(), 0)
+    opid = _ir.OperatorIdentifier("ai.graphcore", "Call", 1, _ir.NumInputs(),
+                                  0)
     op_name = pb_middle_graph.id.str() + "__call__" + pb_bottom_graph.id.str()
 
     ctx = get_current_context()
     # Call the bottom_graph
-    pb_callop = pb_middle_graph.createOp_CallOp(
-        opid, pb_bottom_graph, ctx._get_op_settings(op_name)
-    )
+    pb_callop = pb_middle_graph.createOp_CallOp(opid, pb_bottom_graph,
+                                                ctx._get_op_settings(op_name))
 
     opid = _ir.OperatorIdentifier("ai.onnx", "Loop", 11, _ir.NumInputs(), 0)
     op_name = pb_top_graph.id.str() + "__loop__" + pb_middle_graph.id.str()
 
     # Loop the middle_graph
-    pb_loop_op = pb_top_graph.createOp_LoopOp(
-        opid, ctx._get_op_settings(op_name), pb_middle_graph
-    )
+    pb_loop_op = pb_top_graph.createOp_LoopOp(opid,
+                                              ctx._get_op_settings(op_name),
+                                              pb_middle_graph)
 
     # Add mandatory loop iterator tensor to graph (is not an output)
     repeatIterId = _ir.addScope(pb_middle_graph, "Iterator___")
-    pb_middle_graph.addInput(repeatIterId, _ir.TensorInfo(_ir.DataType.INT32, ()))
+    pb_middle_graph.addInput(repeatIterId,
+                             _ir.TensorInfo(_ir.DataType.INT32, ()))
 
     # Add mandatory loop condition tensor to graph (is also an output)
     repeatCondId = _ir.addScope(pb_middle_graph, "LoopCond___")
-    pb_middle_graph.addInput(repeatCondId, _ir.TensorInfo(_ir.DataType.BOOL, ()))
+    pb_middle_graph.addInput(repeatCondId,
+                             _ir.TensorInfo(_ir.DataType.BOOL, ()))
     pb_middle_graph.markAsOutput(repeatCondId)
 
     return pb_middle_graph, pb_callop, pb_loop_op
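
The iterator and condition tensors are mandatory because the middle graph has to satisfy the body-graph contract of the ai.onnx Loop op (opset 11): an INT32 iteration counter and a BOOL condition come first, and the condition must also be produced as an output. A schematic of that contract in plain Python (`call_bottom_graph` is a hypothetical stand-in for the CallOp created above):

    def middle_graph_body(iteration_num, keep_going, *loop_carried):
        # Loop-11 body: consume the counter and condition, return the
        # condition followed by the loop-carried values, which here are
        # produced by calling the bottom graph.
        new_carried = call_bottom_graph(*loop_carried)
        return (keep_going, *new_carried)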
@@ -354,8 +353,7 @@ def _setup_inputs(
             False,
         )
         pb_callop.connectInTensor(
-            call_in_idx, _ir.addScope(pb_middle_graph, parent_tensor.name)
-        )
+            call_in_idx, _ir.addScope(pb_middle_graph, parent_tensor.name))
 
 
 def _setup_outputs(
@@ -385,33 +383,31 @@ def _setup_outputs(
 
     for pb_subgraph_out_id in pb_bottom_graph.getOutputIds():
         top_tensor_id = _ir.addScope(
-            pb_top_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
-        )
+            pb_top_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id))
         # Already has scope added
         middle_tensor_id = _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
         bottom_tensor_id = _ir.addScope(
-            pb_bottom_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
-        )
+            pb_bottom_graph,
+            _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id))
 
         sgOutIdx = pb_bottom_graph.getOutputIndex(bottom_tensor_id)
         callOutIdx = pb_callop.subgraphOutToOpOutIndex(sgOutIdx)
 
         # Avoid tensor name collisions
         middle_tensor_id = pb_middle_graph.getIr().createIntermediateTensorId(
-            middle_tensor_id
-        )
+            middle_tensor_id)
         pb_callop.createAndConnectOutTensor(callOutIdx, middle_tensor_id)
 
         pb_middle_graph.markAsOutput(middle_tensor_id)
         sgOutIdx = pb_middle_graph.getOutputIndex(middle_tensor_id)
         repeatOutIdx = pb_loop_op.subgraphOutToOpOutIndex(sgOutIdx)
         # Avoid tensor name collisions
         top_tensor_id = pb_middle_graph.getIr().createIntermediateTensorId(
-            top_tensor_id
-        )
+            top_tensor_id)
         # We overwrite here as we added the middle_tensor_id as an output above, but we want to make
         # sure the loop op is setup correctly.
-        pb_loop_op.addLoopOutput(repeatOutIdx, top_tensor_id, middle_tensor_id, True)
+        pb_loop_op.addLoopOutput(repeatOutIdx, top_tensor_id, middle_tensor_id,
+                                 True)
 
         outnames.append(top_tensor_id)
     return outnames
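
In diff form the output wiring above is hard to follow. Each bottom-graph output is surfaced in two hops, sketched below with hypothetical tensor names:

    # For each output t of the bottom graph:
    #   1. CallOp:  t (bottom) -> t_middle (middle graph), via
    #      createAndConnectOutTensor; t_middle is then marked as a graph output.
    #   2. LoopOp:  t_middle (middle) -> t_top (top graph), via addLoopOutput
    #      with overwrite=True, since markAsOutput already registered t_middle.
    #   3. t_top is appended to outnames and returned to the caller.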