Add Metal DLPack zero-copy sharing #3531
Hi @XXXXRT666 — read through this PR after @awni redirected us here from #3548. The Wanted to offer some testing help that complements the PyTorch MPS bring-up you have: We maintain a downstream TileLang fork (https://github.com/DatasunriseOU/tilelang) whose TVM-FFI bridge exports
If useful, once the PR converges I can:
Tag me here when you'd like input — no rush, just don't want this to slip past once it's review-ready. (For the orthogonal |
Required by ml-explore/mlx PR ml-explore#3531 (Metal DLPack zero-copy sharing).
SHA: 33f52e635db5e6229060481d16a167230a1a474b
PR: wjakob/nanobind#1338
Branch: metal-dlpack-cast
Force-pushed from 002360f to 4e16f1d.
This would be super cool if it landed for end to end "0-copy" support in safetensors! I'm working (safetensors/safetensors#767) on adding reading bytes from disk in raw Also, support for Quick question on the |
https://dmlc.github.io/dlpack/latest/c_api.html#c.DLTensor.data The data pointer points to the allocated data. This will be CUDA device pointer, |
Force-pushed from 4e16f1d to a17cd99.
Metal DLPack benchmark
Mean over 50 measured iterations after 5 warmups on an M4. Each timed iteration synchronizes the producer before timing and synchronizes/evaluates the result before stopping the timer.
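The timing procedure described above can be sketched as a small harness. This is a minimal sketch, not the script used for the numbers in this PR: `bench_mean_ms` and its parameters are hypothetical names, and the producer/result synchronization steps are passed in as callables so the same loop could serve either direction.

```python
import statistics
import time
from typing import Callable


def bench_mean_ms(
    convert: Callable[[], object],
    sync_producer: Callable[[], None],
    sync_result: Callable[[object], None],
    warmup: int = 5,
    iters: int = 50,
) -> float:
    """Return the mean wall-clock time of convert() in milliseconds."""
    samples = []
    for i in range(warmup + iters):
        sync_producer()                  # drain pending producer work (not timed)
        start = time.perf_counter()
        out = convert()                  # the import/export call under test
        sync_result(out)                 # force the result to materialize (timed)
        elapsed = time.perf_counter() - start
        if i >= warmup:                  # discard the warmup iterations
            samples.append(elapsed * 1e3)
    return statistics.fmean(samples)
```

For the PyTorch MPS -> MLX direction, `sync_producer` would plausibly be `torch.mps.synchronize` and `sync_result` a wrapper around `mx.eval`, matching the synchronization rules stated in the PR description.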
PyTorch MPS -> MLX Metal
Branch uses
MLX Metal -> PyTorch MPS
Both variants call
torch.mps.synchronize()
print(a.tolist())  # [10.0, 11.0, 12.0]

a.copy(mx.array([4, 5, 6], dtype=mx.float32))
I'm not sure about adding an array.copy op. I think the mental model would be simpler if users could just assume MLX arrays are immutable. Is there a reason for this?
I removed array.copy and the CopyInto primitive. The docs now keep this PR focused on zero-copy DLPack import/export and avoid introducing a new mutating array API.
}

auto out = mx::array(
    mx::allocator::Buffer(owner->data_handle()),
We should check whether the alien buffer uses private storage and, if so, create a new buffer and do the copy here. This way we can ensure that all Metal buffers in MLX use unified memory.
It will likely require a lot more work, but we can then avoid introducing the whole is_host_accessible concept in the Metal backend and isolate the dirty work in one function.
Private Metal DLPack buffers are now handled in the import path.
template <typename T>
mx::array metal_dlpack_to_mlx_contiguous(
    std::shared_ptr<nb::ndarray<nb::ro, nb::c_contig>> owner,
I'm not sure we need to store it in a shared_ptr; copying by value should be enough to hold the reference.
The nanobind ndarray owner is now kept by value.
a.eval();
a.wait();
if (dl_device_type == nb::device::cpu::value) {
  data = a.data<T>();
It should be fine to use the same code for CPU/GPU devices, i.e. just data = a.buffer().ptr()?
For CPU DLPack export, a.buffer().ptr() is the allocator-owned buffer handle, not the host data pointer. The CPU allocator stores its size header before the actual data, while a.data<T>() goes through raw_ptr() and returns the correct host data address.
For Metal, a.buffer().ptr() is the expected id<MTLBuffer> handle, so the CPU and Metal branches intentionally use different pointers here.
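The handle-versus-data distinction can be illustrated with a toy model in plain Python. This is only a sketch of the general "size header before payload" layout: the 8-byte header, `toy_alloc`, and `toy_raw_ptr` are hypothetical stand-ins and do not reproduce MLX's actual CPU allocator.

```python
import struct

HEADER = struct.Struct("<Q")  # hypothetical 8-byte size header (assumption)


def toy_alloc(nbytes: int) -> bytearray:
    """Allocate header + payload in one block and record the payload size."""
    block = bytearray(HEADER.size + nbytes)
    HEADER.pack_into(block, 0, nbytes)
    return block


def toy_raw_ptr(block: bytearray) -> memoryview:
    """Return a view that skips the header, like a host data pointer."""
    return memoryview(block)[HEADER.size:]


block = toy_alloc(12)                      # room for three float32 values
data = toy_raw_ptr(block).cast("f")        # payload viewed as float32
data[0], data[1], data[2] = 1.0, 2.0, 3.0

# Reading 12 bytes from the start of the block (the "handle" address)
# picks up the header bytes first, so the values are shifted and wrong.
wrong = memoryview(block)[:12].cast("f")
```

In this model, exporting the address of `block` rather than the address of `data` would hand the consumer the header bytes as if they were tensor contents, which is the failure the CPU branch avoids.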
That makes sense, thanks for the clarification. In this case I think we can use data<void>() to avoid the type dispatching.
shape.data(),
/* owner= */ owner,
a.strides().data(),
nb::dtype<T>(),
It would be simpler to write a mlx_dtype_to_dl_dtype utility and pass a dlpack::dtype, avoiding the whole type dispatch.
DLPack export now maps MLX dtype with mlx_dtype_to_dl_dtype and passes the resulting dlpack::dtype directly.
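A table-driven mapping of this kind might look like the following. It is written in Python purely as an illustration: the PR's utility is C++ and its signature is not shown here, so `sketch_mlx_dtype_to_dl_dtype`, the string keys, and the `DLDataType` tuple are all hypothetical. The type codes follow the `DLDataTypeCode` enum in the DLPack header.

```python
from enum import IntEnum
from typing import NamedTuple


class DLDataTypeCode(IntEnum):
    """Subset of DLPack's DLDataTypeCode values."""
    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6


class DLDataType(NamedTuple):
    code: DLDataTypeCode
    bits: int
    lanes: int = 1  # scalar element types use a single lane


# Keys are MLX dtype names used as plain strings (an assumption of this sketch).
_TABLE = {
    "bool_": DLDataType(DLDataTypeCode.kDLBool, 8),
    "uint8": DLDataType(DLDataTypeCode.kDLUInt, 8),
    "uint16": DLDataType(DLDataTypeCode.kDLUInt, 16),
    "uint32": DLDataType(DLDataTypeCode.kDLUInt, 32),
    "uint64": DLDataType(DLDataTypeCode.kDLUInt, 64),
    "int8": DLDataType(DLDataTypeCode.kDLInt, 8),
    "int16": DLDataType(DLDataTypeCode.kDLInt, 16),
    "int32": DLDataType(DLDataTypeCode.kDLInt, 32),
    "int64": DLDataType(DLDataTypeCode.kDLInt, 64),
    "float16": DLDataType(DLDataTypeCode.kDLFloat, 16),
    "float32": DLDataType(DLDataTypeCode.kDLFloat, 32),
    "float64": DLDataType(DLDataTypeCode.kDLFloat, 64),
    "bfloat16": DLDataType(DLDataTypeCode.kDLBfloat, 16),
    "complex64": DLDataType(DLDataTypeCode.kDLComplex, 64),
}


def sketch_mlx_dtype_to_dl_dtype(name: str) -> DLDataType:
    """Look up the DLPack (code, bits, lanes) triple for an MLX dtype name."""
    try:
        return _TABLE[name]
    except KeyError:
        raise TypeError(f"no DLPack mapping for dtype {name!r}") from None
```

A single lookup like this replaces a per-type template dispatch, which is the simplification the review comment asks for.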
  }
}

mx::array from_dlpack(nb::object v, std::optional<bool> copy) {
nd_array and dlpack basically mean the same thing, and having both nd_array_to_mlx and from_dlpack is confusing. Maybe move the copy parameter to nd_array_to_mlx and have the Python from_dlpack op just call it?
from_dlpack now just casts once to the nanobind ndarray and calls nd_array_to_mlx(copy=...); the device and copy handling lives in one path.
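The three states of an optional `copy` flag could be resolved in one place along these lines. This sketch follows the `copy` semantics that the Python array API standard defines for `from_dlpack`; whether this PR matches them in every detail, including the exception type, is an assumption. `plan_import` and `can_share` are hypothetical names, with `can_share` standing for "zero-copy is possible" (for example, false for a private-storage buffer or when a different dtype is requested).

```python
from typing import Optional


def plan_import(copy: Optional[bool], can_share: bool) -> str:
    """Decide whether an import shares the producer's buffer or copies it."""
    if copy is True:
        return "copy"        # caller explicitly asked for a fresh buffer
    if can_share:
        return "share"       # copy is None or False and zero-copy works
    if copy is False:
        # The caller forbade copying but sharing is impossible.
        raise BufferError("zero-copy import is not possible for this input")
    return "copy"            # copy is None: fall back to copying
```

Keeping this decision in a single function is one way to realize "the device and copy handling lives in one path".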
One API question: should That would match the mental model used by NumPy/PyTorch more closely: |
Proposed changes
This draft adds zero-copy Metal DLPack sharing for MLX arrays and PyTorch MPS tensors.
This PR builds on the merged DLPack import PR #3495 and requires nanobind support.
The main changes are:
- byte_offset.
- mx.from_dlpack(..., copy=...) controls for Metal DLPack inputs.
- mx.array(...) zero-copy for Metal DLPack inputs unless an explicit different dtype is requested.

The shared lifetime is tied to the exported or imported buffer. Synchronization remains explicit: PyTorch writes require torch.mps.synchronize() before MLX reads, and MLX writes require mx.eval(...) before PyTorch reads.

For MLX arrays exported to PyTorch, later MLX updates may rebind the MLX array to a new buffer while the PyTorch tensor continues to reference the exported buffer.
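The rebinding caveat can be illustrated with an analogy in standard-library Python. This is not MLX or PyTorch code: an `array.array` stands in for the exported buffer and a `memoryview` for the consumer's zero-copy tensor, so it only models the aliasing behavior, not device synchronization.

```python
from array import array

exported = array("f", [1.0, 2.0, 3.0])   # stands in for the exported buffer
consumer = memoryview(exported)          # zero-copy view held by the consumer

exported[0] = 10.0                       # in-place write: the view sees it

# An out-of-place "update" builds a new buffer and rebinds the producer's
# name to it; the consumer keeps referencing the original buffer.
producer = exported
producer = array("f", [x + 1.0 for x in producer])

print(consumer.tolist())   # [10.0, 2.0, 3.0]
print(producer.tolist())   # [11.0, 3.0, 4.0]
```

After the rebind the two sides no longer share storage, so the consumer would need a fresh export to observe the producer's new values.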
Checklist
Put an x in the boxes that apply.
- pre-commit run --all-files to format my code / installed pre-commit prior to committing changes