Skip to content

Commit a2c5a64

Browse files
committed
If all neighbors are outside of the filter distance threshold, zero gradients should be returned
1 parent 324b82e commit a2c5a64

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

GAF_microbatch_pytorch/GAF.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def filter_gradients_by_agreement(
7373
else:
7474
raise ValueError(f'unknown strategy {strategy}')
7575

76-
if not accept_mask.any():
76+
if accept_mask.sum().item() <= 1:
7777
return torch.zeros_like(grads)
7878

7979
if accept_mask.all():

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "GAF-microbatch-pytorch"
3-
version = "0.0.4"
3+
version = "0.0.5"
44
description = "Gradient Agreement Filtering"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_gaf.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from GAF_microbatch_pytorch import GAFWrapper, set_filter_gradients_
99

10-
def test_gaf():
10+
def test_unfiltered_gaf():
1111

1212
net = nn.Sequential(
1313
nn.Linear(512, 256),
@@ -47,7 +47,7 @@ def test_gaf():
4747

4848
gaf_net = GAFWrapper(
4949
deepcopy(net),
50-
filter_distance_thres = 0.
50+
filter_distance_thres = 0.7
5151
)
5252

5353
x = torch.randn(8, 1024, 512)
@@ -65,4 +65,27 @@ def test_gaf():
6565
grad = net[0].weight.grad
6666
grad_filtered = gaf_net.net[0].weight.grad
6767

68-
assert not torch.allclose(grad, grad_filtered, atol = 1e-6)
68+
assert not (grad_filtered == 0.).all() and not torch.allclose(grad, grad_filtered, atol = 1e-6)
69+
70+
def test_all_filtered_gaf():
71+
72+
net = nn.Sequential(
73+
nn.Linear(512, 256),
74+
nn.SiLU(),
75+
nn.Linear(256, 128)
76+
)
77+
78+
gaf_net = GAFWrapper(
79+
deepcopy(net),
80+
filter_distance_thres = 0.
81+
)
82+
83+
x = torch.randn(8, 1024, 512)
84+
x.requires_grad_()
85+
86+
out = gaf_net(x)
87+
out.sum().backward()
88+
89+
grad_filtered = gaf_net.net[0].weight.grad
90+
91+
assert (grad_filtered == 0.).all()

0 commit comments

Comments
 (0)