Skip to content

Commit 324b82e

Browse files
committed
tests
1 parent fd6e001 commit 324b82e

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

.github/workflows/test.yaml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
name: Tests the examples in README
2+
on: [push, pull_request]
3+
4+
env:
5+
TYPECHECK: True
6+
7+
jobs:
8+
test:
9+
runs-on: ubuntu-latest
10+
steps:
11+
- uses: actions/checkout@v4
12+
- name: Install Python
13+
uses: actions/setup-python@v5
14+
with:
15+
python-version: "3.11"
16+
- name: Install dependencies
17+
run: |
18+
python -m pip install uv
19+
python -m uv pip install --upgrade pip
20+
python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
21+
python -m uv pip install -e .[test]
22+
- name: Test with pytest
23+
run: |
24+
python -m pytest tests/

tests/test_gaf.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import pytest
2+
from copy import deepcopy
3+
4+
import torch
5+
from torch import nn
6+
torch.set_default_dtype(torch.float64)
7+
8+
from GAF_microbatch_pytorch import GAFWrapper, set_filter_gradients_
9+
10+
def test_gaf():
11+
12+
net = nn.Sequential(
13+
nn.Linear(512, 256),
14+
nn.SiLU(),
15+
nn.Linear(256, 128)
16+
)
17+
18+
gaf_net = GAFWrapper(
19+
deepcopy(net),
20+
filter_distance_thres = 2.
21+
)
22+
23+
x = torch.randn(8, 1024, 512)
24+
y = x.clone()
25+
26+
x.requires_grad_()
27+
y.requires_grad_()
28+
29+
out1 = net(x)
30+
out2 = gaf_net(y)
31+
32+
out1.sum().backward()
33+
out2.sum().backward()
34+
35+
grad = net[0].weight.grad
36+
grad_filtered = gaf_net.net[0].weight.grad
37+
38+
assert torch.allclose(grad, grad_filtered, atol = 1e-6)
39+
40+
def test_gaf():
41+
42+
net = nn.Sequential(
43+
nn.Linear(512, 256),
44+
nn.SiLU(),
45+
nn.Linear(256, 128)
46+
)
47+
48+
gaf_net = GAFWrapper(
49+
deepcopy(net),
50+
filter_distance_thres = 0.
51+
)
52+
53+
x = torch.randn(8, 1024, 512)
54+
y = x.clone()
55+
56+
x.requires_grad_()
57+
y.requires_grad_()
58+
59+
out1 = net(x)
60+
out2 = gaf_net(y)
61+
62+
out1.sum().backward()
63+
out2.sum().backward()
64+
65+
grad = net[0].weight.grad
66+
grad_filtered = gaf_net.net[0].weight.grad
67+
68+
assert not torch.allclose(grad, grad_filtered, atol = 1e-6)

0 commit comments

Comments
 (0)