@@ -115,6 +115,7 @@
     simplify_mul,
 )
 from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -4745,3 +4746,121 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         out.eval({a: a_test, b: b_test}, mode=test_mode),
         rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
     )
+
+
+@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+@pytest.mark.parametrize(
+    "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
+)
+@pytest.mark.parametrize(
+    "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
+)
+def test_local_block_diag_dot_to_dot_block_diag(
+    left_multiply, batch_blockdiag, batch_other
+):
+    """
+    Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
+    """
+
+    def has_blockdiag(graph):
+        return any(
+            (
+                var.owner
+                and (
+                    isinstance(var.owner.op, BlockDiagonal)
+                    or (
+                        isinstance(var.owner.op, Blockwise)
+                        and isinstance(var.owner.op.core_op, BlockDiagonal)
+                    )
+                )
+            )
+            for var in ancestors([graph])
+        )
+
+    a = tensor("a", shape=(4, 2))
+    b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
+    c = tensor("c", shape=(4, 4))
+    x = pt.linalg.block_diag(a, b, c)
+
+    d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
+
+    # Multiply the block-diagonal matrix from the requested side
+    if left_multiply:
+        out = x @ d
+    else:
+        out = d @ x
+
+    assert has_blockdiag(out)
+    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
+    assert not has_blockdiag(fn.maker.fgraph.outputs[0])
+
+    n_dots_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn.maker.fgraph.apply_nodes
+    )
+    assert n_dots_rewrite == 3
+
+    fn_expected = pytensor.function(
+        [a, b, c, d],
+        out,
+        mode=Mode(linker="py", optimizer=None),
+    )
+    assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
+
+    n_dots_no_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn_expected.maker.fgraph.apply_nodes
+    )
+    assert n_dots_no_rewrite == 1
+
+    rng = np.random.default_rng()
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    rewrite_out = fn(a_val, b_val, c_val, d_val)
+    expected_out = fn_expected(a_val, b_val, c_val, d_val)
+    np.testing.assert_allclose(
+        rewrite_out,
+        expected_out,
+        atol=1e-6 if config.floatX == "float32" else 1e-12,
+        rtol=1e-6 if config.floatX == "float32" else 1e-12,
+    )
+
+
+@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
+    rng = np.random.default_rng()
+    a_size = int(rng.uniform(0, size))
+    b_size = int(rng.uniform(0, size - a_size))
+    c_size = size - a_size - b_size
+
+    a = tensor("a", shape=(a_size, a_size))
+    b = tensor("b", shape=(b_size, b_size))
+    c = tensor("c", shape=(c_size, c_size))
+    d = tensor("d", shape=(size,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    mode = get_default_mode()
+    if not rewrite:
+        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+    fn = pytensor.function([a, b, c, d], out, mode=mode)
+
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    benchmark(
+        fn,
+        a_val,
+        b_val,
+        c_val,
+        d_val,
+    )
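
For reference, the identity stated in the docstring of test_local_block_diag_dot_to_dot_block_diag, which the local_block_diag_dot_to_dot_block_diag rewrite relies on, can be checked directly with NumPy/SciPy. The sketch below is not part of the diff: the block shapes mirror the test, but the seed, the second operands d and e, and their shapes are illustrative assumptions.

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 2))
b = rng.normal(size=(2, 4))
c = rng.normal(size=(4, 4))
x = block_diag(a, b, c)  # shape (10, 10); entries outside the blocks are zero

# Left multiplication: split the other operand's rows by each block's column
# count (2, 4, 4) and stack the per-block products vertically.
d = rng.normal(size=(10, 7))
left = np.concatenate([a @ d[:2], b @ d[2:6], c @ d[6:]], axis=0)
np.testing.assert_allclose(x @ d, left)

# Right multiplication: split the other operand's columns by each block's row
# count (4, 2, 4) and stack the per-block products horizontally.
e = rng.normal(size=(7, 10))
right = np.concatenate([e[:, :4] @ a, e[:, 4:6] @ b, e[:, 6:] @ c], axis=1)
np.testing.assert_allclose(e @ x, right)

Splitting the product this way is why the rewrite can pay off: each per-block dot touches only a dense block instead of the mostly-zero (size, size) block-diagonal matrix, which is the effect the benchmark above measures by toggling the rewrite on and off.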