
Commit 66e5591

[Group Partitioner] Optimize Speed (#12844)
We make some optimizations to the group partitioner to improve its speed. Theoretically, since the partitions are already pre-grouped, we should be faster than the capability-based partitioner. For example, a dynamically quantized group could contain these 9 nodes:

```
[choose_q_params, quantize, dequantize, op, dequantize_weight, weight, bias, get_item_scale, get_item_zp]
```

The capability-based partitioner has to run DFS on all 9 nodes in order to group them together. Given the hints and the purpose of the group-based partitioner, we skip those checks and group all 9 nodes directly, saving the time the checks would cost. Some stats when partitioning the MobileBERT model:

```
elapsed time old partitioner: 65.3421
old_partitioner num partitions: 170
elapsed time new partitioner: 5.1964
new_partitioner num partitions: 170
```

We see a roughly 13x improvement in partitioning time with the group-based partitioner, while still producing the same number of partitions.
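For reference, here is a minimal sketch of how such a comparison could be timed. The partitioner objects and the `propose_partitions()` entry point are assumptions for illustration, not a statement of the actual API in `group_partitioner.py`:

```python
# Hypothetical timing harness: the partitioner argument and its
# propose_partitions() entry point are assumed for illustration.
import time

def time_partitioner(partitioner, label):
    start = time.perf_counter()
    partitions = partitioner.propose_partitions()  # assumed entry point
    elapsed = time.perf_counter() - start
    print(f"elapsed time {label}: {elapsed:.4f}")
    print(f"{label} num partitions: {len(partitions)}")
```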
1 parent 2120135 commit 66e5591

1 file changed, 44 insertions(+), 16 deletions(-)

exir/backend/canonical_partitioners/group_partitioner.py

```diff
@@ -86,7 +86,7 @@ def __init__(
         )
         self.node_to_group = collections.defaultdict(int)
         self.all_nodes_in_groups = set()
-        if node_groups:
+        if self.node_groups:
             for i, group in enumerate(self.node_groups):
                 for node in group:
                     # Node is in multiple groups - not allowed
```
```diff
@@ -101,19 +101,25 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
         p2_nodes = set(partitions_by_id[p2].nodes.keys())
         combined_nodes = p1_nodes.union(p2_nodes)
 
-        for node in combined_nodes:
-            # Get all downstream nodes that are not in the combined partition
-            external_downstreams = {
-                n
-                for n in self.dependency_viewer.downstreams_of(node)
-                if n not in combined_nodes
-            }
+        user_nodes = []
+        # topologically, p2_nodes comes before p1_nodes, so we only
+        # need to check the downstream nodes of p2.
+        # Additionally, we don't need to check all the downstream nodes
+        # of p2, we only need to check the nodes directly outside of p2.
+        # example:
+        # partition[a --> b --> c] --> d --> e --> f
+        # we don't need to check [d, e, f] we only need to check [d] because
+        # the downstream users of [d] will include [e, f]
+        for node in p2_nodes:
+            for user in node.users:
+                if user not in combined_nodes:
+                    user_nodes.append(user)
 
+        for external_node in user_nodes:
             # Check if any external downstream nodes have downstream nodes in the combined partition
-            for external_node in external_downstreams:
-                downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
-                if any(n in combined_nodes for n in downstream_nodes):
-                    return False
+            downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
+            if any(n in combined_nodes for n in downstream_nodes):
+                return False
 
         return True
```
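To see why walking only the direct users of p2 suffices, here is a standalone toy version of the same cycle check, using plain dicts in place of the real `Partition` and `dependency_viewer` objects (the graph and helper names below are made up for illustration):

```python
# Toy DAG: a -> b -> c -> d. Node names map to their direct users.
graph = {"a": ["b"], "b": ["c"], "c": ["d"], "d": []}

def downstreams_of(node):
    # All nodes reachable from `node` (stand-in for dependency_viewer).
    seen, stack = set(), list(graph[node])
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(graph[n])
    return seen

def can_merge(p1_nodes, p2_nodes):
    combined = p1_nodes | p2_nodes
    # Only the direct users of p2 that fall outside the merged set matter;
    # their downstream sets already cover everything further away.
    external_users = {u for n in p2_nodes for u in graph[n] if u not in combined}
    # If any external user can reach back into the merged set, merging
    # would create a cycle between the partition and that user.
    return not any(downstreams_of(u) & combined for u in external_users)

print(can_merge({"d"}, {"a", "b"}))  # False: c sits between b and d
print(can_merge({"c"}, {"a", "b"}))  # True: no external user re-enters
```

Merging `{a, b}` with `{d}` is rejected because the external node `c` consumes `b` and feeds `d`, so the merged partition would both produce for and depend on `c`.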

```diff
@@ -133,13 +139,30 @@ def _process_node_groups(
         if not self.node_groups:
             return group_to_partition_id
 
-        for i, group in enumerate(self.node_groups):
-            # Create a partition for each group
+        processed_nodes = set()
+
+        # We have to create the partitions in reverse topological order
+        # so we find the groups as we traverse backwards in the graph.
+        # This likely needs to be combined with _process_remaining_nodes.
+        # TODO: this currently doesn't work with _process_remaining_nodes, so
+        # if a user provides grouped nodes with operator support, then this will
+        # fail
+        for node in reversed(self.graph_module.graph.nodes):
+            if node not in self.node_to_group:
+                continue
+
+            if node in processed_nodes:
+                continue
+
+            group_idx = self.node_to_group[node]
+            group = self.node_groups[group_idx]
+
+            # Create a partition for the group
             partition_id = next(new_partition_id)
             partition = Partition(id=partition_id, nodes=set())
             partitions_by_id[partition_id] = partition
             partitions_order[partition_id] = partition_id
-            group_to_partition_id[i] = partition_id
+            group_to_partition_id[group_idx] = partition_id
 
             # Add all supported nodes from the group to the partition
             for node in group:
```
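A standalone miniature of this sweep, with toy data standing in for the real FX graph and `Partition` types, to show the first-member-seen bookkeeping:

```python
# Nodes are assumed to be listed in topological order already.
nodes = ["q", "dq", "op", "dq_w", "out"]
node_groups = [["q", "dq", "op", "dq_w"]]  # one pre-formed group
node_to_group = {n: i for i, g in enumerate(node_groups) for n in g}

processed_nodes, partitions = set(), {}
for node in reversed(nodes):
    if node not in node_to_group or node in processed_nodes:
        continue
    group_idx = node_to_group[node]
    # The whole group becomes one partition the first time any of its
    # members is encountered on the backwards walk.
    partitions[group_idx] = list(node_groups[group_idx])
    processed_nodes.update(node_groups[group_idx])

print(partitions)  # {0: ['q', 'dq', 'op', 'dq_w']}
```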
```diff
@@ -164,6 +187,12 @@ def _process_node_groups(
                 partition_map[partition_id].add(target_id)
                 partition_map[partition_id].update(partition_map[target_id])
 
+            # all the nodes in the group have now been processed,
+            # so skip them if we encounter them again in our reverse
+            # topological iteration
+            for node in group:
+                processed_nodes.add(node)
+
         return group_to_partition_id
 
     def _process_remaining_nodes(
```
```diff
@@ -209,7 +238,6 @@ def _merge_partitions(
 
         # Set to track removed partitions from initial static list so we can skip them
         already_merged = set()
-
         # Try to merge each pair of partitions
         for i, p1 in enumerate(partition_ids):
             # Skip if this partition has been already merged
```
