File tree Expand file tree Collapse file tree 2 files changed +92
-0
lines changed Expand file tree Collapse file tree 2 files changed +92
-0
lines changed Original file line number Diff line number Diff line change
1
+ name : Tests the examples in README
2
+ on : [push, pull_request]
3
+
4
+ env :
5
+ TYPECHECK : True
6
+
7
+ jobs :
8
+ test :
9
+ runs-on : ubuntu-latest
10
+ steps :
11
+ - uses : actions/checkout@v4
12
+ - name : Install Python
13
+ uses : actions/setup-python@v5
14
+ with :
15
+ python-version : " 3.11"
16
+ - name : Install dependencies
17
+ run : |
18
+ python -m pip install uv
19
+ python -m uv pip install --upgrade pip
20
+ python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
21
+ python -m uv pip install -e .[test]
22
+ - name : Test with pytest
23
+ run : |
24
+ python -m pytest tests/
Original file line number Diff line number Diff line change
1
+ import pytest
2
+ from copy import deepcopy
3
+
4
+ import torch
5
+ from torch import nn
6
+ torch .set_default_dtype (torch .float64 )
7
+
8
+ from GAF_microbatch_pytorch import GAFWrapper , set_filter_gradients_
9
+
10
+ def test_gaf ():
11
+
12
+ net = nn .Sequential (
13
+ nn .Linear (512 , 256 ),
14
+ nn .SiLU (),
15
+ nn .Linear (256 , 128 )
16
+ )
17
+
18
+ gaf_net = GAFWrapper (
19
+ deepcopy (net ),
20
+ filter_distance_thres = 2.
21
+ )
22
+
23
+ x = torch .randn (8 , 1024 , 512 )
24
+ y = x .clone ()
25
+
26
+ x .requires_grad_ ()
27
+ y .requires_grad_ ()
28
+
29
+ out1 = net (x )
30
+ out2 = gaf_net (y )
31
+
32
+ out1 .sum ().backward ()
33
+ out2 .sum ().backward ()
34
+
35
+ grad = net [0 ].weight .grad
36
+ grad_filtered = gaf_net .net [0 ].weight .grad
37
+
38
+ assert torch .allclose (grad , grad_filtered , atol = 1e-6 )
39
+
40
+ def test_gaf ():
41
+
42
+ net = nn .Sequential (
43
+ nn .Linear (512 , 256 ),
44
+ nn .SiLU (),
45
+ nn .Linear (256 , 128 )
46
+ )
47
+
48
+ gaf_net = GAFWrapper (
49
+ deepcopy (net ),
50
+ filter_distance_thres = 0.
51
+ )
52
+
53
+ x = torch .randn (8 , 1024 , 512 )
54
+ y = x .clone ()
55
+
56
+ x .requires_grad_ ()
57
+ y .requires_grad_ ()
58
+
59
+ out1 = net (x )
60
+ out2 = gaf_net (y )
61
+
62
+ out1 .sum ().backward ()
63
+ out2 .sum ().backward ()
64
+
65
+ grad = net [0 ].weight .grad
66
+ grad_filtered = gaf_net .net [0 ].weight .grad
67
+
68
+ assert not torch .allclose (grad , grad_filtered , atol = 1e-6 )
You can’t perform that action at this time.
0 commit comments