File tree Expand file tree Collapse file tree 3 files changed +22
-2
lines changed Expand file tree Collapse file tree 3 files changed +22
-2
lines changed Original file line number Diff line number Diff line change @@ -224,3 +224,20 @@ def forward(
224
224
225
225
out = gaf_function (tree_spec , * tree_nodes )
226
226
return out
227
+
228
+ # helper functions for disabling GAF wrappers within a network
229
+ # for handy ablation, in the case subnetworks within a neural network were wrapped
230
+
231
+ def set_filter_gradients_ (
232
+ m : Module ,
233
+ filter_gradients : bool ,
234
+ filter_distance_thres = None
235
+ ):
236
+ for module in m .modules ():
237
+ if not isinstance (module , GAFWrapper ):
238
+ continue
239
+
240
+ module .filter_gradients = filter_gradients
241
+
242
+ if exists (filter_distance_thres ):
243
+ module .filter_distance_thres = filter_distance_thres
Original file line number Diff line number Diff line change 1
- from GAF_microbatch_pytorch .GAF import GAFWrapper
1
+ from GAF_microbatch_pytorch .GAF import (
2
+ GAFWrapper ,
3
+ set_filter_gradients_
4
+ )
Original file line number Diff line number Diff line change 1
1
[project ]
2
2
name = " GAF-microbatch-pytorch"
3
- version = " 0.0.2 "
3
+ version = " 0.0.3 "
4
4
description = " Gradient Agreement Filtering"
5
5
authors = [
6
6
{ name = " Phil Wang" , email = " lucidrains@gmail.com" }
You can’t perform that action at this time.
0 commit comments