Torch integration working with CUDA, Vulkan and D3D12 #362
base: main
Conversation
…hader-slang/slangpy into dev/ccummings/torchintegration
LGTM, some of the comments are on the nitpicky side, so feel free to ignore. I didn't go super deep into the marshalling and PyTorch autograd implementation, but reading through it looked good to me.
@@ -91,8 +91,16 @@ jobs:
run: |
sudo apt update && sudo apt install -y libxinerama-dev libxcursor-dev xorg-dev libglu1-mesa-dev pkg-config

# Setup Python.
- name: Setup Python ${{ matrix.python }}
# Setup Python (no pip cache on unit test windows runners - massive slow down).
Why can't we use the cache?
if (command_encoder) {
    SGL_CHECK(
        !cuda_stream.is_valid(),
        "Can not specify cuda stream if appending to a command encoder."
nitpick: in general I think we should use CUDA all-caps in text/strings/output
@@ -40,6 +40,8 @@ SGL_PY_EXPORT(device_kernel)
    uint3 thread_count,
    nb::dict vars,
    CommandEncoder* command_encoder,
    CommandQueueType queue,
we don't support multiple queues, so is this needed?
{
    uint8_t buffer[8];
    for (int i = 0; i < 8; ++i) {
        buffer[7 - i] = HEX_CHARS[(value >> (i * 4)) & 0xF];
minor: there is sgl::string::hexlify, but it returns a std::string, whereas this has no allocation overhead. We could also introduce a hexlify overload that takes an output buffer and use that.
nb::arg("args"), | ||
nb::arg("kwargs"), | ||
D_NA(NativeCallData, _py_torch_call) | ||
) | ||
.def_prop_rw( | ||
"call_group_shape", | ||
&NativeCallData::get_call_group_shape, |
unrelated but this should be just call_group_shape
std::optional<nb::ndarray<nb::pytorch, nb::device::cuda>> tensor() const { return m_tensor; }

void set_tensor(const std::optional<nb::ndarray<nb::pytorch, nb::device::cuda>> tensor) { m_tensor = tensor; }

ref<Buffer> interop_buffer() const { return m_interop_buffer; }

void set_interop_buffer(const ref<Buffer>& interop_buffer) { m_interop_buffer = interop_buffer; }

int32_t id() const { return m_id; }

void set_id(int32_t id) { m_id = id; }

ref<TensorRef> grad_in() const { return m_grad_in; }
void set_grad_in(const ref<TensorRef>& grad_in) { m_grad_in = grad_in; }

ref<TensorRef> grad_out() const { return m_grad_out; }
void set_grad_out(const ref<TensorRef>& grad_out) { m_grad_out = grad_out; }

std::pair<AccessType, AccessType> last_access() const { return m_last_access; }
void set_last_access(const std::pair<AccessType, AccessType>& last_access) { m_last_access = last_access; }
it looks a lot like these could all just be public fields instead of getter/setters
Specify a CUDA stream to use for the function. This is useful for synchronizing with other
CUDA operations or ensuring that the function runs on a specific stream.
"""
if stream.type != NativeHandleType.CUstream:
you also do that check in the FunctionNodeCUDAStream constructor
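To illustrate the consolidation, a minimal sketch of factoring the duplicated check into one helper that both the cuda_stream property and the FunctionNodeCUDAStream constructor could call; require_cuda_stream is a hypothetical name, and the NativeHandle/NativeHandleType imports are assumed to be the slangpy types used in the snippet above:

from slangpy import NativeHandle, NativeHandleType  # assumed public exports

def require_cuda_stream(stream: NativeHandle) -> NativeHandle:
    # Single shared validation so the property setter and the autograd node
    # constructor cannot drift apart.
    if stream.type != NativeHandleType.CUstream:
        raise ValueError(f"Expected a CUstream native handle, got {stream.type}")
    return stream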
@@ -82,6 +80,17 @@ def get_device(
    "Please set use_cache=False if you want to use existing_device_handles."
)

selected_adaptor_luid = None
slightly different than in sglhelpers.py, maybe worth consolidating, i.e. we could maybe select the adapter in conftest.py?
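One possible shape for that consolidation, sketched as a session-scoped pytest fixture in conftest.py; Device.enumerate_adapters() and the luid field are assumptions based on the sgl adapter API, and the fixture name is hypothetical:

import pytest
import slangpy as spy

@pytest.fixture(scope="session")
def selected_adapter_luid():
    # Pick one adapter up front so get_device() and sglhelpers.py agree on it.
    adapters = spy.Device.enumerate_adapters()  # assumed API
    return adapters[0].luid if adapters else None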
torch.cuda.current_device()
torch.cuda.current_stream()
are these calls necessary?
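For context on the question, a small sketch (not from the PR) of what these calls return; if the results are discarded, the bare calls mostly just force lazy CUDA initialization, whereas interop code normally wants the raw stream handle:

import torch

device_index = torch.cuda.current_device()   # index of the active CUDA device
stream = torch.cuda.current_stream()         # torch.cuda.Stream for that device
raw_custream = stream.cuda_stream            # integer CUstream handle usable for interop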
from typing import Any, Optional, cast
from numpy import ScalarType
from slangpy import DataType, Device, BufferUsage, TypeReflection, DeviceType
import torch
are we guaranteed that this import is not run at import slangpy time?
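On the import question, a minimal sketch of the usual way to keep torch off the import slangpy path: defer the import into the functions that need it and only expose the name to type checkers (wrap_tensor is a hypothetical example, not from the PR):

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    import torch  # seen by type checkers only, never executed at runtime

def wrap_tensor(tensor: "torch.Tensor") -> Any:
    import torch  # deferred, so importing slangpy alone never pulls in torch
    assert isinstance(tensor, torch.Tensor)
    return tensor.data_ptr()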
This is the first version of the rewritten torch integration, which includes fixed interop with Vulkan/D3D and direct device sharing with CUDA. Key additions:
What's missing and will be in the next PR:
I'm still not entirely happy with the hoops we have to jump through for the autograd hook, but PyTorch has some pretty rock-solid rules about what you can/can't do involving the storing/clearing/copying of tensors, and I can't find a simpler way to work around them.
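For readers unfamiliar with those rules, a generic illustration (not the PR's actual hook) of the constraints a custom autograd function lives under: tensors needed in backward() must be routed through ctx.save_for_backward() rather than stored on ctx directly, and backward() must return one gradient per forward() input:

import torch

class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # the sanctioned way to keep tensors for backward
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2.0      # exactly one gradient per forward input

x = torch.ones(4, requires_grad=True)
ScaleByTwo.apply(x).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])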