
Add Metal DLPack zero-copy sharing#3531

Open
XXXXRT666 wants to merge 18 commits into ml-explore:main from XXXXRT666:metal-dlpack-zero-copy-draft

Conversation

@XXXXRT666
Contributor

@XXXXRT666 XXXXRT666 commented May 11, 2026

Proposed changes

This draft adds zero-copy Metal DLPack sharing for MLX arrays and PyTorch MPS tensors.

This PR builds on the merged DLPack import PR #3495 and requires nanobind support that has not yet landed upstream (wjakob/nanobind#1338).

The main changes are:

  • Import Metal DLPack arrays by wrapping the underlying Metal buffer instead of copying through CPU.
  • Export MLX arrays to Metal DLPack using the MLX Metal buffer and DLPack byte_offset.
  • Add mx.from_dlpack(..., copy=...) controls for Metal DLPack inputs.
  • Keep mx.array(...) zero-copy for Metal DLPack inputs unless an explicit different dtype is requested.
  • Document the explicit synchronization requirements between PyTorch MPS and MLX.

The shared lifetime is tied to the exported or imported buffer. Synchronization remains explicit: PyTorch writes require torch.mps.synchronize() before MLX reads, and MLX writes require mx.eval(...) before PyTorch reads.

For MLX arrays exported to PyTorch, later MLX updates may rebind the MLX array to a new buffer while the PyTorch tensor continues to reference the exported buffer.
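
The synchronization rules above can be sketched as follows. This is a minimal illustration, not the PR's test code: it assumes an Apple-silicon machine with a PyTorch MPS build and an MLX build that includes this branch, and the helper names `torch_to_mlx_shared` and `mlx_to_torch_shared` are hypothetical. The imports are deferred into the functions so the module still loads where those packages are missing.

```python
def torch_to_mlx_shared():
    """PyTorch MPS -> MLX: flush PyTorch work before MLX reads the shared buffer."""
    import torch
    import mlx.core as mx

    t = torch.ones(3, device="mps")
    a = mx.from_dlpack(t)        # assumed to share the Metal buffer on this branch
    t.add_(9)                    # in-place write on the PyTorch side
    torch.mps.synchronize()      # PyTorch writes must complete before MLX reads
    return a.tolist()


def mlx_to_torch_shared():
    """MLX -> PyTorch MPS: evaluate the MLX array before PyTorch reads it."""
    import torch
    import mlx.core as mx

    a = mx.arange(3, dtype=mx.float32) + 1
    mx.eval(a)                   # MLX is lazy; materialize the buffer first
    t = torch.utils.dlpack.from_dlpack(a)
    return t.cpu().tolist()
```

Skipping either synchronization call leaves the reader racing against work that the producer framework may not have submitted or finished yet.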

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@megacpp

megacpp commented May 13, 2026

Hi @XXXXRT666 — read through this PR after @awni redirected us here from #3548. The nb::ndarray<nb::ro, nb::c_contig> approach over the in-flight nanobind PR (#1338) is materially cleaner than the manual capsule parsing we had in our downstream PoC, and lifting is_host_accessible() into mlx/allocator.h is the right level of abstraction. Closed the RFC; happy with this being the path forward for #2848.

Wanted to offer some testing help that complements the PyTorch MPS bring-up you have:

We maintain a downstream TileLang fork (https://github.com/DatasunriseOU/tilelang) whose TVM-FFI bridge exports kDLMetal DLPack capsules for tensors backed by id<MTLBuffer>. That gives a non-PyTorch Metal DLPack producer that exercises the same import path you're adding here. Specifically it covers:

  • kDLMetal producers that do not require the PyTorch-MPS workaround for __dlpack__ — exercises the import path directly.
  • Round-trip mx.array → DLPack → TVM-FFI Metal kernel → DLPack → mx.array zero-copy.
  • Custom Metal kernels (via mlx.fast.metal_kernel) consuming an imported mx.array whose underlying MTLBuffer was allocated outside MLX.
  • storageMode matrix (we hit Shared and Managed; Private is the obvious edge case the spec needs to nail down — your is_host_accessible() decision likely answers this implicitly but worth a sanity check from the producer side).
  • byte_offset != 0 cases that we'd previously rejected outright in our PoC — your PR seems to handle these via byte_offset-aware import; happy to write a TileLang-side test for it.

If useful, once the PR converges I can:

  1. Pull this branch into our TileLang test matrix and report back on any rough edges (CI on macOS Metal hardware).
  2. Send a minimal standalone repro (no TileLang dependency) for any of the above scenarios if you'd like them as additions to python/tests/test_array.py.
  3. Beta-test the mx.from_dlpack(..., copy=...) semantics against the dtype-mismatch-shares case (002360faa) once the API stabilizes.

Tag me here when you'd like input — no rush, just don't want this to slip past once it's review-ready.

(For the orthogonal mx.empty() piece that was also in our PoC, opened it as a separate issue per @awni's guidance.)

megacpp pushed a commit to DatasunriseOU/mlx that referenced this pull request May 13, 2026
Required by ml-explore/mlx PR ml-explore#3531 (Metal DLPack zero-copy sharing).
SHA: 33f52e635db5e6229060481d16a167230a1a474b
PR:   wjakob/nanobind#1338
Branch: metal-dlpack-cast
XXXXRT666 force-pushed the metal-dlpack-zero-copy-draft branch from 002360f to 4e16f1d on May 14, 2026 04:39
@McPatate

This would be super cool if it landed for end-to-end "0-copy" support in safetensors! I'm working (safetensors/safetensors#767) on reading bytes from disk into raw MTLBuffers, which can then be handed to the framework via DLPack with 0-copy. Works well with torch, would be happy to see that land in MLX!

Also, support for byte_offset != 0 would be nice (already in the PR, but commenting to flag that it's useful) since we can go one step further: currently the MPS path is pread -> MTLBuffer, but that goes through kernel pages before hitting the userspace buffer. Non-zero byte_offset support would enable mmap-ing the file and creating MTLBuffers that reference specific slices of the mmap, which would demand-fault pages from disk into the page cache on first access and give userspace direct access, leaving only the disk -> kernel-side copy.
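
The mmap-plus-offset idea can be illustrated with the standard library alone. This is only an analogy: a `memoryview` stands in for an MTLBuffer, and `HEADER_BYTES`, `read_slice_through_mapping`, and `demo` are hypothetical names, not part of safetensors or MLX. The sketch shows that a view taken at a non-zero byte offset into a mapping shares the mapped pages rather than copying them.

```python
import mmap
import os
import struct
import tempfile

HEADER_BYTES = 8  # hypothetical file header that precedes the tensor payload


def read_slice_through_mapping(path, byte_offset, count):
    """Expose `count` float32 values at `byte_offset` without copying the file."""
    with open(path, "r+b") as f, mmap.mmap(f.fileno(), 0) as mapped:
        base = memoryview(mapped)
        # Slicing and casting a memoryview shares the mapping; no bytes are copied.
        view = base[byte_offset:byte_offset + 4 * count].cast("f")
        try:
            # Write through the mapping; the offset view observes the change.
            mapped[byte_offset:byte_offset + 4] = struct.pack("f", 42.0)
            return list(view)
        finally:
            # Views must be released before the mapping can be closed.
            view.release()
            base.release()


def demo():
    """Build a small file (header + four floats) and read the payload in place."""
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(b"\0" * HEADER_BYTES + struct.pack("4f", 1.0, 2.0, 3.0, 4.0))
    try:
        return read_slice_through_mapping(tmp.name, HEADER_BYTES, 4)
    finally:
        os.remove(tmp.name)
```

A DLPack consumer that honors byte_offset plays the role of the slice here: the producer hands over one buffer handle plus an offset instead of allocating a buffer per tensor.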

Quick question on the dl_tensor.data convention, torch's mps treats it as id<MTLBuffer>, passing the contents segfaults. Curious to know which direction MLX will be taking, as it impacts us downstream!

megacpp pushed a commit to DatasunriseOU/mlx that referenced this pull request May 14, 2026
Required by ml-explore/mlx PR ml-explore#3531 (Metal DLPack zero-copy sharing).
SHA: 33f52e635db5e6229060481d16a167230a1a474b
PR:   wjakob/nanobind#1338
Branch: metal-dlpack-cast
@XXXXRT666
Contributor Author

> Quick question on the dl_tensor.data convention, torch's mps treats it as id<MTLBuffer>, passing the contents segfaults. Curious to know which direction MLX will be taking, as it impacts us downstream!

https://dmlc.github.io/dlpack/latest/c_api.html#c.DLTensor.data

> The data pointer points to the allocated data. This will be CUDA device pointer, cl_mem handle in OpenCL, or id<MTLBuffer> for Metal.

XXXXRT666 force-pushed the metal-dlpack-zero-copy-draft branch from 4e16f1d to a17cd99 on May 19, 2026 07:44
@XXXXRT666 XXXXRT666 marked this pull request as ready for review May 19, 2026 08:40
Comment thread docs/src/usage/numpy.rst Outdated
Comment thread mlx/backend/cuda/allocator.cpp Outdated
Comment thread mlx/backend/metal/allocator.cpp Outdated
Comment thread CMakeLists.txt
Comment thread python/src/convert.cpp
@XXXXRT666
Contributor Author

Metal DLPack benchmark

Mean over 50 measured iterations after 5 warmups on M4. Each timed iteration synchronizes the producer before timing and synchronizes/evaluates the result before stopping the timer.

  • Branch: mlx 0.32.0.dev20260521+04665e3bf, torch 2.12.0
  • Baseline: PyPI mlx 0.31.2, torch 2.12.0
  • Shapes: 1024x1024, 2048x2048, 4096x4096; dtypes: float32, float16.
  • Bandwidth is effective bandwidth computed as tensor bytes divided by mean time. For zero-copy paths it measures conversion overhead, not physical memory-copy bandwidth.
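
For reference, the effective-bandwidth figure described above can be reproduced with a few lines. This is a sketch of the stated formula (tensor bytes divided by mean time), not the benchmark script itself; the function name and the `ITEMSIZE` table are assumptions.

```python
GIB = 1024 ** 3
ITEMSIZE = {"float32": 4, "float16": 2}  # bytes per element


def effective_bandwidth_gib_s(shape, dtype, mean_ms):
    """Tensor bytes divided by mean time, reported in GiB/s."""
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    tensor_bytes = num_elements * ITEMSIZE[dtype]
    return tensor_bytes / (mean_ms / 1000.0) / GIB


# Baseline PyTorch MPS -> MLX, float32 1024x1024 at 1.0455 ms
print(round(effective_bandwidth_gib_s((1024, 1024), "float32", 1.0455), 1))  # 3.7
```

Applied to a zero-copy path, the same formula yields very large numbers because the denominator is only the conversion overhead, which is why the results report those rows as ratios rather than as memory bandwidth.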

PyTorch MPS -> MLX Metal

Branch uses mx.array(torch_mps_tensor). Baseline uses the legacy path mx.array(torch_mps_tensor.cpu()).

| dtype | shape | branch mean | baseline mean | comparison |
|---|---|---|---|---|
| float32 | 1024x1024 | 0.0017 ms | 1.0455 ms, 3.7 GiB/s | 618x lower latency, 618x bandwidth |
| float32 | 2048x2048 | 0.0015 ms | 1.8094 ms, 8.6 GiB/s | 1167x lower latency, 1167x bandwidth |
| float32 | 4096x4096 | 0.0015 ms | 7.7090 ms, 8.1 GiB/s | 5209x lower latency, 5209x bandwidth |
| float16 | 1024x1024 | 0.0016 ms | 0.6499 ms, 3.0 GiB/s | 399x lower latency, 399x bandwidth |
| float16 | 2048x2048 | 0.0015 ms | 1.9847 ms, 3.9 GiB/s | 1305x lower latency, 1305x bandwidth |
| float16 | 4096x4096 | 0.0018 ms | 4.1107 ms, 7.6 GiB/s | 2306x lower latency, 2306x bandwidth |

MLX Metal -> PyTorch MPS

Both variants call torch.utils.dlpack.from_dlpack(mx_array) and then ensure the result is on MPS with to("mps") if needed.

| dtype | shape | branch mean | baseline mean | comparison |
|---|---|---|---|---|
| float32 | 1024x1024 | 0.0040 ms | 0.5166 ms, 7.6 GiB/s | 128x lower latency, 128x bandwidth |
| float32 | 2048x2048 | 0.0017 ms | 1.5256 ms, 10.2 GiB/s | 923x lower latency, 923x bandwidth |
| float32 | 4096x4096 | 0.0018 ms | 2.3682 ms, 26.4 GiB/s | 1353x lower latency, 1353x bandwidth |
| float16 | 1024x1024 | 0.0014 ms | 0.4257 ms, 4.6 GiB/s | 295x lower latency, 295x bandwidth |
| float16 | 2048x2048 | 0.0018 ms | 0.7522 ms, 10.4 GiB/s | 428x lower latency, 428x bandwidth |
| float16 | 4096x4096 | 0.0014 ms | 1.4215 ms, 22.0 GiB/s | 1036x lower latency, 1036x bandwidth |

Comment thread docs/src/usage/numpy.rst Outdated
torch.mps.synchronize()
print(a.tolist()) # [10.0, 11.0, 12.0]

a.copy(mx.array([4, 5, 6], dtype=mx.float32))

Collaborator

I'm not sure about adding an array.copy op. I think the mental model would be simpler if users could just assume MLX arrays are immutable. Is there a reason for this?

Contributor Author

I removed array.copy and the CopyInto primitive. The docs now keep this PR focused on zero-copy DLPack import/export and avoid introducing a new mutating array API.

Comment thread python/src/convert.cpp Outdated
}

auto out = mx::array(
mx::allocator::Buffer(owner->data_handle()),

Collaborator

We should check whether the alien buffer uses private storage and, if so, create a new buffer and do the copy here. That way we can simply ensure that all Metal buffers in MLX use unified memory.

It is likely going to require a lot more work, but we can then avoid introducing the whole is_host_accessible concept in the Metal backend and isolate the dirty work in one function.

Contributor Author

Private Metal DLPack buffers are now handled in the import path.

Comment thread python/src/convert.cpp Outdated

template <typename T>
mx::array metal_dlpack_to_mlx_contiguous(
std::shared_ptr<nb::ndarray<nb::ro, nb::c_contig>> owner,

Collaborator

I'm not sure we need to store it in a shared_ptr; copying it by value should be enough to retain the reference.

Contributor Author

The nanobind ndarray owner is now kept by value.

Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp Outdated
a.eval();
a.wait();
if (dl_device_type == nb::device::cpu::value) {
data = a.data<T>();

Collaborator

Shouldn't it be fine to use the same code for CPU and GPU devices, i.e. just data = a.buffer().ptr()?

Contributor Author

For CPU DLPack export, a.buffer().ptr() is the allocator-owned buffer handle, not the host data pointer. The CPU allocator stores its size header before the actual data, while a.data<T>() goes through raw_ptr() and returns the correct host data address.

For Metal, a.buffer().ptr() is the expected id<MTLBuffer> handle, so the CPU and Metal branches intentionally use different pointers here.
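
The handle-versus-data distinction can be modeled in a few lines. This is a toy sketch, not MLX's allocator: the 8-byte header layout and the names `HEADER`, `allocate`, and `data_view` are assumptions chosen only to show why the start of the allocation and the start of the payload differ.

```python
import struct

HEADER = struct.Struct("<Q")  # hypothetical layout: one 8-byte size field


def allocate(nbytes):
    """Return the whole allocation (the 'handle'): size header, then payload."""
    block = bytearray(HEADER.size + nbytes)
    HEADER.pack_into(block, 0, nbytes)
    return block


def data_view(block):
    """Return the payload only, skipping the header (the 'data pointer')."""
    (size,) = HEADER.unpack_from(block, 0)
    return memoryview(block)[HEADER.size:HEADER.size + size]


block = allocate(4)
data_view(block)[:] = b"abcd"  # writes land after the header, not at offset 0
```

Exporting the handle where a host data pointer is expected would make a consumer interpret the size header as tensor data, which is the mismatch the reply describes.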

Collaborator

That makes sense, thanks for the clarification. In this case I think we can use data<void>() to avoid the type dispatching.

Comment thread python/src/convert.cpp Outdated
shape.data(),
/* owner= */ owner,
a.strides().data(),
nb::dtype<T>(),

Collaborator

It would be simpler to write an mlx_dtype_to_dl_dtype utility and pass a dlpack::dtype, which avoids the whole type dispatch.

Contributor Author

DLPack export now maps MLX dtype with mlx_dtype_to_dl_dtype and passes the resulting dlpack::dtype directly.
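
The shape of such a mapping can be sketched as a lookup table. This is a Python model rather than the PR's C++ helper: `DLDataType` mirrors the three fields of DLPack's `DLDataType` struct, the type codes are the values from `dlpack.h`, and `dtype_name_to_dl_dtype` together with its string keys is a hypothetical stand-in for `mlx_dtype_to_dl_dtype`.

```python
from typing import NamedTuple


class DLDataType(NamedTuple):
    code: int   # DLDataTypeCode
    bits: int   # bits per lane
    lanes: int  # 1 for scalar element types


# DLDataTypeCode values as defined in dlpack.h
K_DL_INT, K_DL_UINT, K_DL_FLOAT = 0, 1, 2
K_DL_BFLOAT, K_DL_COMPLEX, K_DL_BOOL = 4, 5, 6

_DTYPE_TABLE = {
    "bool": DLDataType(K_DL_BOOL, 8, 1),
    "uint8": DLDataType(K_DL_UINT, 8, 1),
    "uint16": DLDataType(K_DL_UINT, 16, 1),
    "uint32": DLDataType(K_DL_UINT, 32, 1),
    "uint64": DLDataType(K_DL_UINT, 64, 1),
    "int8": DLDataType(K_DL_INT, 8, 1),
    "int16": DLDataType(K_DL_INT, 16, 1),
    "int32": DLDataType(K_DL_INT, 32, 1),
    "int64": DLDataType(K_DL_INT, 64, 1),
    "float16": DLDataType(K_DL_FLOAT, 16, 1),
    "float32": DLDataType(K_DL_FLOAT, 32, 1),
    "float64": DLDataType(K_DL_FLOAT, 64, 1),
    "bfloat16": DLDataType(K_DL_BFLOAT, 16, 1),
    "complex64": DLDataType(K_DL_COMPLEX, 64, 1),
}


def dtype_name_to_dl_dtype(name):
    """Look up the DLPack descriptor for a dtype name; unknown names raise."""
    try:
        return _DTYPE_TABLE[name]
    except KeyError:
        raise TypeError(f"no DLPack mapping for dtype {name!r}") from None
```

A single table lookup replaces a per-type template instantiation, which is the simplification the review asked for.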

Comment thread python/src/convert.cpp Outdated
Comment thread python/src/convert.cpp
}
}

mx::array from_dlpack(nb::object v, std::optional<bool> copy) {

Collaborator

nd_array and dlpack basically mean the same thing, and having both nd_array_to_mlx and from_dlpack is confusing. Maybe move the copy parameter to nd_array_to_mlx and have the Python from_dlpack op just call it?

Contributor Author

from_dlpack now just casts once to the nanobind ndarray and calls nd_array_to_mlx(copy=...); the device and copy handling lives in one path.
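
One plausible reading of a three-state copy argument follows the Python array API convention for from_dlpack: True always copies, False never copies and raises if a copy would be needed, and None shares when possible. Whether this PR adopts exactly these rules is an assumption. The sketch below models them with standard-library buffers; `import_buffer` and `must_copy` are hypothetical names, with `must_copy` standing for cases such as a private-storage buffer or a dtype conversion.

```python
def import_buffer(src, copy=None, must_copy=False):
    """Model copy=True/False/None handling in a single import path."""
    if copy is False and must_copy:
        raise BufferError("a copy is required but copy=False was requested")
    if copy is True or must_copy:
        return bytearray(src)   # fresh storage, detached from the producer
    return memoryview(src)      # shares the producer's storage


src = bytearray(b"abc")
shared = import_buffer(src)             # copy=None: shares when possible
copied = import_buffer(src, copy=True)  # always copies
src[0] = ord("z")                       # visible through `shared`, not `copied`
```

Keeping this decision in one function mirrors the consolidation described in the reply, where device and copy handling live in a single path.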

Comment thread python/src/convert.cpp Outdated
@XXXXRT666
Contributor Author

One API question: should mx.array(...) always copy DLPack inputs, and should zero-copy / copy control live in mx.asarray(..., copy=...) instead?

That would match the mental model used by NumPy/PyTorch more closely: array creates a new array, while asarray may avoid a copy depending on copy. In that design, mx.from_dlpack(..., copy=...) could remain the explicit DLPack entry point, while mx.array(torch_mps_tensor) would not unexpectedly share the underlying Metal buffer by default.


4 participants