Skip to content

Commit 7dbc034

Browse files
committed
Move non-rewrite test
1 parent d5a8c0a commit 7dbc034

File tree

2 files changed

+50
-52
lines changed

2 files changed

+50
-52
lines changed

tests/tensor/rewriting/test_linalg.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -39,57 +39,7 @@
3939
solve,
4040
solve_triangular,
4141
)
42-
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
43-
from tests import unittest_tools as utt
44-
from tests.test_rop import break_op
45-
46-
47-
def test_matrix_inverse_rop_lop():
48-
rtol = 1e-7 if config.floatX == "float64" else 1e-5
49-
mx = matrix("mx")
50-
mv = matrix("mv")
51-
v = vector("v")
52-
y = MatrixInverse()(mx).sum(axis=0)
53-
54-
yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True)
55-
rop_f = function([mx, mv], yv)
56-
57-
yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False)
58-
rop_via_lop_f = function([mx, mv], yv_via_lop)
59-
60-
sy, _ = pytensor.scan(
61-
lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(),
62-
sequences=pt.arange(y.shape[0]),
63-
non_sequences=[y, mx, mv],
64-
)
65-
scan_f = function([mx, mv], sy)
66-
67-
rng = np.random.default_rng(utt.fetch_seed())
68-
vx = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX)
69-
vv = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX)
70-
71-
v_ref = scan_f(vx, vv)
72-
np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
73-
np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref, rtol=rtol)
74-
75-
with pytest.raises(ValueError):
76-
pytensor.gradient.Rop(
77-
pytensor.clone_replace(y, replace={mx: break_op(mx)}),
78-
mx,
79-
mv,
80-
use_op_rop_implementation=True,
81-
)
82-
83-
vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX)
84-
yv = pytensor.gradient.Lop(y, mx, v)
85-
lop_f = function([mx, v], yv)
86-
87-
sy = pytensor.gradient.grad((v * y).sum(), mx)
88-
scan_f = function([mx, v], sy)
89-
90-
v_ref = scan_f(vx, vv)
91-
v = lop_f(vx, vv)
92-
np.testing.assert_allclose(v, v_ref, rtol=rtol)
42+
from pytensor.tensor.type import dmatrix, matrix, tensor
9343

9444

9545
def test_transinv_to_invtrans():

tests/tensor/test_nlinalg.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytensor
88
from pytensor import function
99
from pytensor.configdefaults import config
10-
from pytensor.tensor.basic import as_tensor_variable
10+
from pytensor.tensor.basic import arange, as_tensor_variable
1111
from pytensor.tensor.math import _allclose
1212
from pytensor.tensor.nlinalg import (
1313
SVD,
@@ -41,6 +41,7 @@
4141
vector,
4242
)
4343
from tests import unittest_tools as utt
44+
from tests.test_rop import break_op
4445

4546

4647
def test_pseudoinverse_correctness():
@@ -101,6 +102,53 @@ def test_infer_shape(self):
101102

102103
self._compile_and_check([x], [xi], [r], self.op_class, warn=False)
103104

105+
def test_rop_lop(self):
106+
rtol = 1e-7 if config.floatX == "float64" else 1e-5
107+
mx = matrix("mx")
108+
mv = matrix("mv")
109+
v = vector("v")
110+
y = MatrixInverse()(mx).sum(axis=0)
111+
112+
yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True)
113+
rop_f = function([mx, mv], yv)
114+
115+
yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False)
116+
rop_via_lop_f = function([mx, mv], yv_via_lop)
117+
118+
sy, _ = pytensor.scan(
119+
lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(),
120+
sequences=arange(y.shape[0]),
121+
non_sequences=[y, mx, mv],
122+
)
123+
scan_f = function([mx, mv], sy)
124+
125+
rng = np.random.default_rng(utt.fetch_seed())
126+
vx = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX)
127+
vv = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX)
128+
129+
v_ref = scan_f(vx, vv)
130+
np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
131+
np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref, rtol=rtol)
132+
133+
with pytest.raises(ValueError):
134+
pytensor.gradient.Rop(
135+
pytensor.clone_replace(y, replace={mx: break_op(mx)}),
136+
mx,
137+
mv,
138+
use_op_rop_implementation=True,
139+
)
140+
141+
vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX)
142+
yv = pytensor.gradient.Lop(y, mx, v)
143+
lop_f = function([mx, v], yv)
144+
145+
sy = pytensor.gradient.grad((v * y).sum(), mx)
146+
scan_f = function([mx, v], sy)
147+
148+
v_ref = scan_f(vx, vv)
149+
v = lop_f(vx, vv)
150+
np.testing.assert_allclose(v, v_ref, rtol=rtol)
151+
104152

105153
def test_matrix_dot():
106154
rng = np.random.default_rng(utt.fetch_seed())

0 commit comments

Comments (0)