Commit d2f20d7

popxl.ops.repeat - remove 'see also' links (#4)
Removing "See also" links in `popxl.ops.repeat`.
1 parent 1aa2053 commit d2f20d7
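
For context, the op touched here repeats a subgraph a fixed number of times. A condensed version of the docstring example this commit edits, reconstructed from the identifiers visible in the diff below (`add_weight_graph0`, `add_weight0.w`, `x0`, `w0`); the `AddWeight` module body and the variable setup are assumptions, not verbatim source:

import popxl
import popxl.ops as ops

ir = popxl.Ir()

class AddWeight(popxl.Module):
    def __init__(self):
        self.w: popxl.Tensor = None

    def build(self, x):
        # Assumed body: add a graph-input weight and also return it,
        # so the weight is loop-carried across repeats.
        self.w = popxl.graph_input(x.shape, x.dtype, "w")
        return self.w + x, self.w

with ir.main_graph:
    w0 = popxl.variable(1, popxl.float32)
    x0 = popxl.variable(3, popxl.float32)
    add_weight0 = AddWeight()
    add_weight_graph0 = ir.create_graph(add_weight0, x0)

    # repeat 8 times
    y0, w0 = ops.repeat(add_weight_graph0, 8, x0,
                        inputs_dict={add_weight0.w: w0})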

1 file changed: python/popxl/python_files/ops/repeat.py (36 additions, 40 deletions)
@@ -54,8 +54,6 @@ def build(self, x):
         # repeat 8 times
         y0, w0 = ops.repeat(add_weight_graph0, 8, x0, inputs_dict={add_weight0.w: w0})

-    See also `PyTorch Tensor.repeat <https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html>`__, `NumPy repeat <https://numpy.org/doc/stable/reference/generated/numpy.repeat.html>`__.
-
     Args:
         graph (Graph): User defined graph to repeat `repeat_count` times.
         repeat_count (int): Number of times to repeat calling the graph.
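
Worth noting for readers who relied on the removed links: `torch.Tensor.repeat` and `numpy.repeat` tile tensor data, whereas `popxl.ops.repeat` executes a whole graph `repeat_count` times, presumably why the cross-references were dropped. A quick contrast (the NumPy behaviour is standard; the popxl description follows the loop-carried wiring later in this file):

import numpy as np

np.repeat(np.array([1, 2]), 3)   # array([1, 1, 1, 2, 2, 2]) - repeats data
# popxl.ops.repeat(graph, 3, ...) instead runs `graph` three times,
# feeding loop-carried outputs back into its inputs.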
@@ -74,7 +72,10 @@ def build(self, x):
         Tuple[Tensor, ...]:
             Tuple of the output tensors of the call in the parent graph.
     """
-    loop_info = repeat_with_info(graph, repeat_count, *inputs, inputs_dict=inputs_dict)
+    loop_info = repeat_with_info(graph,
+                                 repeat_count,
+                                 *inputs,
+                                 inputs_dict=inputs_dict)

     out_tensors = loop_info.outputs
     return out_tensors
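
This hunk also shows that `repeat` is a thin wrapper: it delegates to `repeat_with_info` and returns only the output tensors. When the call-site info object is needed, the underlying function can be called directly; a sketch, under the same assumed setup as the example above:

# repeat_with_info returns an info object rather than raw outputs
# (name per the `loop_info` variable in the diff; treat as a sketch).
loop_info = ops.repeat_with_info(add_weight_graph0, 8, x0,
                                 inputs_dict={add_weight0.w: w0})
y0, w0 = loop_info.outputs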
@@ -193,8 +194,7 @@ def build(self, x):
     if total_inputs < total_outputs:
         raise ValueError(
             f"To repeat the subgraph ({graph.id}) the number of inputs must be greater than or equal to the number of outputs."
-            f" {total_inputs} < {total_outputs}"
-        )
+            f" {total_inputs} < {total_outputs}")

     # For clarity, we rename our graphs:
     # - Bottom: The user provided bottom level graph. We call this with a call op. This has gone
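
The check above reflects the loop-carried wiring set up later in the file: each subgraph output is carried back into an input slot on the next iteration, so outputs cannot outnumber inputs. A minimal failing case, assuming the same `ir` and `x0` as the earlier sketch:

def two_outputs(x):
    return x + 1, x * 2          # 1 input, 2 outputs

g = ir.create_graph(two_outputs, x0)
ops.repeat(g, 8, x0)             # raises ValueError: "... 1 < 2"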
@@ -215,16 +215,14 @@ def build(self, x):

     # Create the middle graph, call and loop ops
     pb_middle_graph, pb_callop, pb_loop_op = _setup_call_and_repeat(
-        pb_ir, pb_top_graph, pb_bottom_graph
-    )
+        pb_ir, pb_top_graph, pb_bottom_graph)

     # set the number of times to loop
     pb_loop_op.setTripCountValue(repeat_count)

     # Prep and validate inputs
-    inputs_all = _prep_and_validate_inputs(
-        check_inputs, top_graph, graph, "called", inputs, inputs_dict
-    )
+    inputs_all = _prep_and_validate_inputs(check_inputs, top_graph, graph,
+                                           "called", inputs, inputs_dict)

     # 1, 2. Connect inputs.
     _setup_inputs(
@@ -236,9 +234,8 @@ def build(self, x):
     )

     # 3. Connect outputs.
-    _ = _setup_outputs(
-        pb_top_graph, pb_bottom_graph, pb_middle_graph, pb_callop, pb_loop_op
-    )
+    _ = _setup_outputs(pb_top_graph, pb_bottom_graph, pb_middle_graph,
+                       pb_callop, pb_loop_op)

     pb_callop.setup()
     pb_loop_op.setup()
@@ -250,13 +247,14 @@ def build(self, x):
     loop_carried_inputs = pb_loop_op.getNumExplicitInputs()
     for bottom_t in bottom_graph._by_ref_inputs:
         middle_t = c_info.graph_to_parent(bottom_t)
-        loop_carried = pb_middle_graph.getInputIndex(middle_t.id) < loop_carried_inputs
+        loop_carried = pb_middle_graph.getInputIndex(
+            middle_t.id) < loop_carried_inputs
         # If a tensor was set as a by_ref_input, we should also do the same for the looped subgraph.
         c_info.set_parent_input_modified(
-            middle_t, infer_modified_regions=not loop_carried
-        )
+            middle_t, infer_modified_regions=not loop_carried)
         top_t = r_info.graph_to_parent(middle_t)
-        r_info.set_parent_input_modified(top_t, infer_modified_regions=not loop_carried)
+        r_info.set_parent_input_modified(
+            top_t, infer_modified_regions=not loop_carried)
         r_info.called_graph._by_ref_inputs.add(middle_t)

     return r_info
@@ -280,34 +278,35 @@ def _setup_call_and_repeat(
     # This is the graph we will repeat.
     pb_middle_graph = pb_ir.createGraph(
         _ir.GraphId(
-            pb_ir.createUniqueSubgraphId(f"{pb_bottom_graph.id.str()}__loop_wrapper")
-        )
-    )
+            pb_ir.createUniqueSubgraphId(
+                f"{pb_bottom_graph.id.str()}__loop_wrapper")))

-    opid = _ir.OperatorIdentifier("ai.graphcore", "Call", 1, _ir.NumInputs(), 0)
+    opid = _ir.OperatorIdentifier("ai.graphcore", "Call", 1, _ir.NumInputs(),
+                                  0)
     op_name = pb_middle_graph.id.str() + "__call__" + pb_bottom_graph.id.str()

     ctx = get_current_context()
     # Call the bottom_graph
-    pb_callop = pb_middle_graph.createOp_CallOp(
-        opid, pb_bottom_graph, ctx._get_op_settings(op_name)
-    )
+    pb_callop = pb_middle_graph.createOp_CallOp(opid, pb_bottom_graph,
+                                                ctx._get_op_settings(op_name))

     opid = _ir.OperatorIdentifier("ai.onnx", "Loop", 11, _ir.NumInputs(), 0)
     op_name = pb_top_graph.id.str() + "__loop__" + pb_middle_graph.id.str()

     # Loop the middle_graph
-    pb_loop_op = pb_top_graph.createOp_LoopOp(
-        opid, ctx._get_op_settings(op_name), pb_middle_graph
-    )
+    pb_loop_op = pb_top_graph.createOp_LoopOp(opid,
+                                              ctx._get_op_settings(op_name),
+                                              pb_middle_graph)

     # Add mandatory loop iterator tensor to graph (is not an output)
     repeatIterId = _ir.addScope(pb_middle_graph, "Iterator___")
-    pb_middle_graph.addInput(repeatIterId, _ir.TensorInfo(_ir.DataType.INT32, ()))
+    pb_middle_graph.addInput(repeatIterId,
+                             _ir.TensorInfo(_ir.DataType.INT32, ()))

     # Add mandatory loop condition tensor to graph (is also an output)
     repeatCondId = _ir.addScope(pb_middle_graph, "LoopCond___")
-    pb_middle_graph.addInput(repeatCondId, _ir.TensorInfo(_ir.DataType.BOOL, ()))
+    pb_middle_graph.addInput(repeatCondId,
+                             _ir.TensorInfo(_ir.DataType.BOOL, ()))
     pb_middle_graph.markAsOutput(repeatCondId)

     return pb_middle_graph, pb_callop, pb_loop_op
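
For orientation, the graph nesting that `_setup_call_and_repeat` builds, summarised from the hunk above as a comment sketch (structure only, no new API claims):

# top graph
#   LoopOp ("ai.onnx" Loop, opset 11), trip count set via setTripCountValue
#     middle graph "<bottom_graph_id>__loop_wrapper"
#       mandatory inputs: Iterator___ (scalar INT32),
#                         LoopCond___ (scalar BOOL, also marked as output)
#       CallOp ("ai.graphcore" Call)
#         bottom graph: the user-provided graph being repeated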
@@ -354,8 +353,7 @@ def _setup_inputs(
             False,
         )
         pb_callop.connectInTensor(
-            call_in_idx, _ir.addScope(pb_middle_graph, parent_tensor.name)
-        )
+            call_in_idx, _ir.addScope(pb_middle_graph, parent_tensor.name))


 def _setup_outputs(
@@ -385,33 +383,31 @@ def _setup_outputs(

     for pb_subgraph_out_id in pb_bottom_graph.getOutputIds():
         top_tensor_id = _ir.addScope(
-            pb_top_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
-        )
+            pb_top_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id))
         # Already has scope added
         middle_tensor_id = _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
         bottom_tensor_id = _ir.addScope(
-            pb_bottom_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
-        )
+            pb_bottom_graph,
+            _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id))

         sgOutIdx = pb_bottom_graph.getOutputIndex(bottom_tensor_id)
         callOutIdx = pb_callop.subgraphOutToOpOutIndex(sgOutIdx)

         # Avoid tensor name collisions
         middle_tensor_id = pb_middle_graph.getIr().createIntermediateTensorId(
-            middle_tensor_id
-        )
+            middle_tensor_id)
         pb_callop.createAndConnectOutTensor(callOutIdx, middle_tensor_id)

         pb_middle_graph.markAsOutput(middle_tensor_id)
         sgOutIdx = pb_middle_graph.getOutputIndex(middle_tensor_id)
         repeatOutIdx = pb_loop_op.subgraphOutToOpOutIndex(sgOutIdx)
         # Avoid tensor name collisions
         top_tensor_id = pb_middle_graph.getIr().createIntermediateTensorId(
-            top_tensor_id
-        )
+            top_tensor_id)
         # We overwrite here as we added the middle_tensor_id as an output above, but we want to make
         # sure the loop op is setup correctly.
-        pb_loop_op.addLoopOutput(repeatOutIdx, top_tensor_id, middle_tensor_id, True)
+        pb_loop_op.addLoopOutput(repeatOutIdx, top_tensor_id, middle_tensor_id,
+                                 True)

         outnames.append(top_tensor_id)
     return outnames
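
The output plumbing in this last hunk, stated compactly: each bottom-graph output gets a fresh intermediate id at the middle level to avoid name collisions, is emitted by the CallOp, marked as a middle-graph output, and finally exposed at the top level through `addLoopOutput`, whose final `True` argument the in-code comment describes as overwriting the earlier registration. A comment trace for one output tensor, names as in the diff:

# bottom graph out  -> callOutIdx via subgraphOutToOpOutIndex
# CallOp emits      -> middle_tensor_id (fresh intermediate id)
# middle graph      -> markAsOutput(middle_tensor_id)
# LoopOp            -> addLoopOutput(repeatOutIdx, top_tensor_id,
#                                    middle_tensor_id, True)
# top graph         -> top_tensor_id collected into outnames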
