|
7 | 7 | import pytensor |
8 | 8 | from pytensor import function |
9 | 9 | from pytensor.configdefaults import config |
10 | | -from pytensor.tensor.basic import as_tensor_variable |
| 10 | +from pytensor.tensor.basic import arange, as_tensor_variable |
11 | 11 | from pytensor.tensor.math import _allclose |
12 | 12 | from pytensor.tensor.nlinalg import ( |
13 | 13 | SVD, |
|
41 | 41 | vector, |
42 | 42 | ) |
43 | 43 | from tests import unittest_tools as utt |
| 44 | +from tests.test_rop import break_op |
44 | 45 |
|
45 | 46 |
|
46 | 47 | def test_pseudoinverse_correctness(): |
@@ -101,6 +102,53 @@ def test_infer_shape(self): |
101 | 102 |
|
102 | 103 | self._compile_and_check([x], [xi], [r], self.op_class, warn=False) |
103 | 104 |
|
| 105 | + def test_rop_lop(self): |
| 106 | + rtol = 1e-7 if config.floatX == "float64" else 1e-5 |
| 107 | + mx = matrix("mx") |
| 108 | + mv = matrix("mv") |
| 109 | + v = vector("v") |
| 110 | + y = MatrixInverse()(mx).sum(axis=0) |
| 111 | + |
| 112 | + yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True) |
| 113 | + rop_f = function([mx, mv], yv) |
| 114 | + |
| 115 | + yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False) |
| 116 | + rop_via_lop_f = function([mx, mv], yv_via_lop) |
| 117 | + |
| 118 | + sy, _ = pytensor.scan( |
| 119 | + lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(), |
| 120 | + sequences=arange(y.shape[0]), |
| 121 | + non_sequences=[y, mx, mv], |
| 122 | + ) |
| 123 | + scan_f = function([mx, mv], sy) |
| 124 | + |
| 125 | + rng = np.random.default_rng(utt.fetch_seed()) |
| 126 | + vx = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) |
| 127 | + vv = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) |
| 128 | + |
| 129 | + v_ref = scan_f(vx, vv) |
| 130 | + np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol) |
| 131 | + np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref, rtol=rtol) |
| 132 | + |
| 133 | + with pytest.raises(ValueError): |
| 134 | + pytensor.gradient.Rop( |
| 135 | + pytensor.clone_replace(y, replace={mx: break_op(mx)}), |
| 136 | + mx, |
| 137 | + mv, |
| 138 | + use_op_rop_implementation=True, |
| 139 | + ) |
| 140 | + |
| 141 | + vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX) |
| 142 | + yv = pytensor.gradient.Lop(y, mx, v) |
| 143 | + lop_f = function([mx, v], yv) |
| 144 | + |
| 145 | + sy = pytensor.gradient.grad((v * y).sum(), mx) |
| 146 | + scan_f = function([mx, v], sy) |
| 147 | + |
| 148 | + v_ref = scan_f(vx, vv) |
| 149 | + v = lop_f(vx, vv) |
| 150 | + np.testing.assert_allclose(v, v_ref, rtol=rtol) |
| 151 | + |
104 | 152 |
|
105 | 153 | def test_matrix_dot(): |
106 | 154 | rng = np.random.default_rng(utt.fetch_seed()) |
|
0 commit comments