File tree Expand file tree Collapse file tree 3 files changed +28
-5
lines changed Expand file tree Collapse file tree 3 files changed +28
-5
lines changed Original file line number Diff line number Diff line change @@ -73,7 +73,7 @@ def filter_gradients_by_agreement(
73
73
else :
74
74
raise ValueError (f'unknown strategy { strategy } ' )
75
75
76
- if not accept_mask .any () :
76
+ if accept_mask .sum (). item () <= 1 :
77
77
return torch .zeros_like (grads )
78
78
79
79
if accept_mask .all ():
Original file line number Diff line number Diff line change 1
1
[project ]
2
2
name = " GAF-microbatch-pytorch"
3
- version = " 0.0.4 "
3
+ version = " 0.0.5 "
4
4
description = " Gradient Agreement Filtering"
5
5
authors = [
6
6
{ name = " Phil Wang" , email = " lucidrains@gmail.com" }
Original file line number Diff line number Diff line change 7
7
8
8
from GAF_microbatch_pytorch import GAFWrapper , set_filter_gradients_
9
9
10
- def test_gaf ():
10
+ def test_unfiltered_gaf ():
11
11
12
12
net = nn .Sequential (
13
13
nn .Linear (512 , 256 ),
@@ -47,7 +47,7 @@ def test_gaf():
47
47
48
48
gaf_net = GAFWrapper (
49
49
deepcopy (net ),
50
- filter_distance_thres = 0.
50
+ filter_distance_thres = 0.7
51
51
)
52
52
53
53
x = torch .randn (8 , 1024 , 512 )
@@ -65,4 +65,27 @@ def test_gaf():
65
65
grad = net [0 ].weight .grad
66
66
grad_filtered = gaf_net .net [0 ].weight .grad
67
67
68
- assert not torch .allclose (grad , grad_filtered , atol = 1e-6 )
68
+ assert not (grad_filtered == 0. ).all () and not torch .allclose (grad , grad_filtered , atol = 1e-6 )
69
+
70
+ def test_all_filtered_gaf ():
71
+
72
+ net = nn .Sequential (
73
+ nn .Linear (512 , 256 ),
74
+ nn .SiLU (),
75
+ nn .Linear (256 , 128 )
76
+ )
77
+
78
+ gaf_net = GAFWrapper (
79
+ deepcopy (net ),
80
+ filter_distance_thres = 0.
81
+ )
82
+
83
+ x = torch .randn (8 , 1024 , 512 )
84
+ x .requires_grad_ ()
85
+
86
+ out = gaf_net (x )
87
+ out .sum ().backward ()
88
+
89
+ grad_filtered = gaf_net .net [0 ].weight .grad
90
+
91
+ assert (grad_filtered == 0. ).all ()
You can’t perform that action at this time.
0 commit comments