
Commit 7696bc6

able to wrap a bunch of subnetworks and then control whether filtering is on/off, in case torch.func breaks for some large, complicated network
1 parent 0815584 commit 7696bc6
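As a hedged illustration of what the commit message describes (not part of this diff): wrapping individual subnetworks rather than the whole model, so gradient agreement filtering is applied only where the torch.func transform behaves. The constructor form GAFWrapper(module) and every module name below are assumptions made for this sketch, not taken from the diff.

import torch
from torch import nn

from GAF_microbatch_pytorch import GAFWrapper

class BigModel(nn.Module):
    # hypothetical large network - only the subnetworks that play nicely with
    # torch.func are wrapped, the rest of the model is left untouched
    def __init__(self):
        super().__init__()
        self.encoder = GAFWrapper(nn.Sequential(nn.Linear(512, 512), nn.ReLU()))  # assumed constructor form
        self.decoder = GAFWrapper(nn.Sequential(nn.Linear(512, 512), nn.ReLU()))
        self.head = nn.Linear(512, 10)  # left unwrapped

    def forward(self, x):
        return self.head(self.decoder(self.encoder(x)))

model = BigModel()
out = model(torch.randn(4, 512))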

File tree

3 files changed: +22 -2 lines changed


GAF_microbatch_pytorch/GAF.py

Lines changed: 17 additions & 0 deletions

@@ -224,3 +224,20 @@ def forward(
         out = gaf_function(tree_spec, *tree_nodes)
         return out
+
+# helper functions for disabling GAF wrappers within a network
+# for handy ablation, in the case subnetworks within a neural network were wrapped
+
+def set_filter_gradients_(
+    m: Module,
+    filter_gradients: bool,
+    filter_distance_thres = None
+):
+    for module in m.modules():
+        if not isinstance(module, GAFWrapper):
+            continue
+
+        module.filter_gradients = filter_gradients
+
+        if exists(filter_distance_thres):
+            module.filter_distance_thres = filter_distance_thres
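A short usage sketch for the new helper: set_filter_gradients_ walks m.modules() and, on every GAFWrapper it finds, sets filter_gradients, optionally overriding filter_distance_thres as well. The calls below follow the signature added in this diff; the wrapped toy network and the threshold value are hypothetical.

from torch import nn
from GAF_microbatch_pytorch import GAFWrapper, set_filter_gradients_

# hypothetical network with GAF-wrapped subnetworks (constructor form assumed)
net = nn.Sequential(
    GAFWrapper(nn.Linear(256, 256)),
    nn.ReLU(),
    GAFWrapper(nn.Linear(256, 10))
)

# ablation: switch gradient agreement filtering off for every GAFWrapper in the network
set_filter_gradients_(net, False)

# switch it back on, also overriding the filtering distance threshold (value is illustrative)
set_filter_gradients_(net, True, filter_distance_thres = 0.97)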

GAF_microbatch_pytorch/__init__.py

Lines changed: 4 additions & 1 deletion

@@ -1 +1,4 @@
-from GAF_microbatch_pytorch.GAF import GAFWrapper
+from GAF_microbatch_pytorch.GAF import (
+    GAFWrapper,
+    set_filter_gradients_
+)

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "GAF-microbatch-pytorch"
-version = "0.0.2"
+version = "0.0.3"
 description = "Gradient Agreement Filtering"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
