Torch integration working with CUDA, Vulkan and D3D12 #362

Open
ccummingsNV wants to merge 51 commits into main

Conversation

Contributor

@ccummingsNV ccummingsNV commented Jul 21, 2025

This is the first version of the rewritten torch integration, which includes fixed interop with Vulkan/D3D and direct device sharing with CUDA. Key additions:

  • Lots of fixes for CUDA interop and the CUDA backend
  • The special 'torch module/function' has been removed - use of torch is now auto-detected from the parameters passed to a function (see the sketch after this list)
  • Reduced the complexity of the autograd integration + moved some of it to native code
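
To make the auto-detection point concrete, here is a hypothetical usage sketch. The device/module setup calls are illustrative assumptions, not APIs confirmed by this PR; the point is only that passing torch tensors as arguments is what triggers the torch interop path, with no dedicated torch wrapper involved.

    # Hypothetical sketch: device/module setup names are illustrative assumptions.
    import torch
    import slangpy as spy

    device = spy.Device()                                     # assumed constructor
    module = spy.Module.load_from_file(device, "ops.slang")   # assumed loader

    a = torch.randn(1024, device="cuda", requires_grad=True)
    b = torch.randn(1024, device="cuda")

    # No special 'torch module/function' is needed any more: the torch tensors
    # in the argument list are detected automatically and marshalled via interop.
    result = module.add(a, b)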

What's missing and will come in the next PR:

  • Shifting more code to the native side, such as the marshalling of TensorRef and aspects of the torch call
  • A heap for interop buffers, to avoid constantly re-allocating buffers for passing data in/out of torch

I'm still not entirely happy with the hoops we have to jump through for the autograd hook, but PyTorch has some pretty rock-solid rules about what you can and can't do when storing/clearing/copying tensors, and I can't find a simpler way to work around them.
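
For readers who have not hit these constraints: custom hooks plug into PyTorch via torch.autograd.Function, and tensors needed by the backward pass must be handed to ctx.save_for_backward rather than stored freely, which is the kind of rule referred to above. A minimal, generic illustration (not this PR's actual hook) follows.

    # Generic torch.autograd.Function example, not taken from this PR.
    import torch

    class Square(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)  # tensors needed in backward must go through here
            return x * x

        @staticmethod
        def backward(ctx, grad_out):
            (x,) = ctx.saved_tensors
            return 2.0 * x * grad_out

    x = torch.randn(4, requires_grad=True)
    Square.apply(x).sum().backward()  # x.grad == 2 * x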

@ccummingsNV ccummingsNV requested a review from a team as a code owner July 25, 2025 14:11
@ccummingsNV ccummingsNV requested a review from Copilot July 29, 2025 13:28
@ccummingsNV ccummingsNV requested a review from Copilot July 30, 2025 13:19
@ccummingsNV ccummingsNV changed the title from "Draft: Working torch integration" to "Torch integration working with CUDA, Vulkan and D3D12" on Jul 30, 2025
@ccummingsNV ccummingsNV requested a review from Copilot July 30, 2025 13:48
@ccummingsNV ccummingsNV requested a review from Copilot July 30, 2025 14:12
@ccummingsNV ccummingsNV requested a review from tunabrain July 30, 2025 21:26
Contributor

@skallweitNV skallweitNV left a comment


LGTM. Some of the comments are on the nitpicky side, so feel free to ignore them. I didn't dig super deep into the marshalling and pytorch autograd implementation, but reading through it, it looked good to me.

@@ -91,8 +91,16 @@ jobs:
run: |
sudo apt update && sudo apt install -y libxinerama-dev libxcursor-dev xorg-dev libglu1-mesa-dev pkg-config

# Setup Python.
- name: Setup Python ${{ matrix.python }}
# Setup Python (no pip cache on unit test windows runners - massive slow down).
Contributor

Why can't we use the cache?

if (command_encoder) {
SGL_CHECK(
!cuda_stream.is_valid(),
"Can not specify cuda stream if appending to a command encoder."
Contributor

nitpick: in general I think we should use CUDA all-caps in text/strings/output

@@ -40,6 +40,8 @@ SGL_PY_EXPORT(device_kernel)
uint3 thread_count,
nb::dict vars,
CommandEncoder* command_encoder,
CommandQueueType queue,
Contributor

we don't support multiple queues, so is this needed?

{
uint8_t buffer[8];
for (int i = 0; i < 8; ++i) {
buffer[7 - i] = HEX_CHARS[(value >> (i * 4)) & 0xF];
Contributor

minor: there is sgl::string::hexlify, but it returns a std::string, whereas this version has no allocation overhead. We could also introduce a hexlify overload that takes an output buffer and use that here.

nb::arg("args"),
nb::arg("kwargs"),
D_NA(NativeCallData, _py_torch_call)
)
.def_prop_rw(
"call_group_shape",
&NativeCallData::get_call_group_shape,
Contributor

unrelated but this should be just call_group_shape

Comment on lines +821 to +840
std::optional<nb::ndarray<nb::pytorch, nb::device::cuda>> tensor() const { return m_tensor; }

void set_tensor(const std::optional<nb::ndarray<nb::pytorch, nb::device::cuda>> tensor) { m_tensor = tensor; }

ref<Buffer> interop_buffer() const { return m_interop_buffer; }

void set_interop_buffer(const ref<Buffer>& interop_buffer) { m_interop_buffer = interop_buffer; }

int32_t id() const { return m_id; }

void set_id(int32_t id) { m_id = id; }

ref<TensorRef> grad_in() const { return m_grad_in; }
void set_grad_in(const ref<TensorRef>& grad_in) { m_grad_in = grad_in; }

ref<TensorRef> grad_out() const { return m_grad_out; }
void set_grad_out(const ref<TensorRef>& grad_out) { m_grad_out = grad_out; }

std::pair<AccessType, AccessType> last_access() const { return m_last_access; }
void set_last_access(const std::pair<AccessType, AccessType>& last_access) { m_last_access = last_access; }
Contributor

it looks a lot like these could all just be public fields instead of getter/setters

Specify a CUDA stream to use for the function. This is useful for synchronizing with other
CUDA operations or ensuring that the function runs on a specific stream.
"""
if stream.type != NativeHandleType.CUstream:
Contributor

you also do that check in FunctionNodeCUDAStream constructor

@@ -82,6 +80,17 @@ def get_device(
"Please set use_cache=False if you want to use existing_device_handles."
)

selected_adaptor_luid = None
Contributor

slightly different than in sglhelpers.py, maybe worth consolidating - e.g. we could select the adapter in conftest.py?

Comment on lines +146 to +147
torch.cuda.current_device()
torch.cuda.current_stream()
Contributor

are these calls necessary?
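
For context on the question: both calls trigger PyTorch's lazy CUDA initialization as a side effect, which is a common reason to invoke them even when the return values are unused. If that is the only intent here (an assumption, not something this PR states), a more explicit alternative would be:

    # Sketch: explicit CUDA initialization instead of relying on side effects.
    import torch

    if torch.cuda.is_available():
        torch.cuda.init()                      # forces lazy CUDA state initialization
        stream = torch.cuda.current_stream()   # keep only if the handle is actually used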

from typing import Any, Optional, cast
from numpy import ScalarType
from slangpy import DataType, Device, BufferUsage, TypeReflection, DeviceType
import torch
Contributor

are we guaranteed that this import is not run at import slangpy time?
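
One common way to guarantee that (a sketch only; the real module may already handle it differently) is to keep torch out of the import-time path with a typing-only import plus a deferred runtime import:

    # Sketch of a deferred torch import so that 'import slangpy' stays torch-free.
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import torch  # evaluated only by static type checkers

    def _require_torch():
        # Imported on first use at runtime; fails with a clear error if torch is missing.
        import torch
        return torch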
