
torch._dynamo.exc.TorchRuntimeError: Failed running call_function torch_sparse.ind2ptr #400

@alexfanqi


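torch.compile fails on SparseTensor.matmul. Tracing first graph-breaks on int(row.max()) in SparseStorage.__init__. It then fails when Dynamo runs the custom op torch.ops.torch_sparse.ind2ptr on a FakeTensor (full log below).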

torch version: 2.6.0
pytorch_sparse version: 0.6.18 py312hf276b08_7 from conda-forge

reproducer:

import torch
from torch_sparse import SparseTensor
import scipy.sparse as sp
import numpy as np


def test_compile_sparse(size=100, density=0.01):
    coo = sp.random_array(shape=(size, size), density=density, dtype=np.float32)
    x_np = np.random.rand(size).astype(np.float32)
    values = torch.tensor(coo.data, device="cuda")
    x = torch.tensor(x_np, device="cuda")
    @torch.compile
    def fn(values, x):
        coo_tensor = SparseTensor(
            row=torch.tensor(coo.row, device="cuda", dtype=torch.long),
            col=torch.tensor(coo.col, device="cuda", dtype=torch.long),
            value=values,
            sparse_sizes=coo.shape,
        )
        coo_tensor.matmul(x.unsqueeze(1)).sum()
    fn(values, x)

test_compile_sparse()
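The op that fails is called as torch.ops.torch_sparse.ind2ptr(row, self._sparse_sizes[0]), which is the last frame of the log. It appears to compress sorted COO row indices into CSR row pointers of length M + 1. Below is a minimal pure-Python sketch of that conversion. It assumes row is sorted ascending and every entry lies in [0, M). The name ind2ptr_reference is hypothetical and not part of torch_sparse. The sketch only illustrates what output a fake (meta) implementation of the op would have to describe.

```python
from bisect import bisect_left


def ind2ptr_reference(row, M):
    """Hypothetical pure-Python model of COO -> CSR row-pointer compression.

    Assumes `row` is sorted ascending and all entries lie in [0, M).
    rowptr[i] is the number of entries whose row index is < i, so the
    entries of row i occupy positions rowptr[i]:rowptr[i + 1].
    """
    return [bisect_left(row, i) for i in range(M + 1)]


# Nonzeros in rows 0, 0 and 2 of a 3-row matrix (row 1 is empty).
print(ind2ptr_reference([0, 0, 2], 3))  # [0, 2, 2, 3]
```

The output length depends only on M, not on the values in row. A fake implementation can therefore report the result shape without reading the tensor's data.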
Full error message
/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/utils.py:2586: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return node.target(*args, **kwargs)
/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/utils.py:2586: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return node.target(*args, **kwargs)
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] Graph break from `Tensor.item()`, consider setting:
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] or:
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] to include these operations in the captured graph.
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] 
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] Graph break: from user code at:
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]   File "/tmp/test.py", line 14, in fn
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]     coo_tensor = SparseTensor(
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]   File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch_sparse/tensor.py", line 26, in __init__
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]     self.storage = SparseStorage(
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]   File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch_sparse/storage.py", line 69, in __init__
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0]     assert trust_data or int(row.max()) < M
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] 
W0519 22:51:51.846000 206303 site-packages/torch/_dynamo/variables/tensor.py:869] [0/0] 
/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/utils.py:2586: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return node.target(*args, **kwargs)
/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/utils.py:2586: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return node.target(*args, **kwargs)
/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/fx/interpreter.py:310: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return target(*args, **kwargs)
/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1241: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return func(*args, **kwargs)
Traceback (most recent call last):
  File "/tmp/test.py", line 23, in <module>
    test_compile_sparse()
  File "/tmp/test.py", line 21, in test_compile_sparse
    fn(values, x)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/test.py", line 12, in fn
    @torch.compile
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1164, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py", line 953, in call_function
    tensor_variable = wrap_fx_proxy(
                      ^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
           ^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 2536, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
    return fn()
           ^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 2604, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_function torch_sparse.ind2ptr(*(FakeTensor(..., device='cuda:0', size=(100,), dtype=torch.int64), 100), **{}):
The tensor has a non-zero number of elements, but its data is not allocated yet.
If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
If you're using Caffe2, Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

from user code:
   File "/tmp/test.py", line 20, in torch_dynamo_resume_in_fn_at_14
    coo_tensor.matmul(x.unsqueeze(1)).sum()
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch_sparse/matmul.py", line 169, in <lambda>
    SparseTensor.matmul = lambda self, other, reduce="sum": matmul(
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch_sparse/matmul.py", line 160, in matmul
    return spmm(src, other, reduce)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch_sparse/matmul.py", line 83, in spmm
    return spmm_sum(src, other)
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch_sparse/matmul.py", line 10, in spmm_sum
    rowptr, col, value = src.csr()
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch_sparse/tensor.py", line 237, in csr
    return self.storage.rowptr(), self.storage.col(), self.storage.value()
  File "/home/alexfanqi/micromamba/envs/ml-py312/lib/python3.12/site-packages/torch_sparse/storage.py", line 209, in rowptr
    rowptr = torch.ops.torch_sparse.ind2ptr(row, self._sparse_sizes[0])

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
