Machine Learning, Programming

Memory-mapped CPU tensor shared between PyTorch, NumPy, JAX and TensorFlow

This is just a fun experiment to answer a question: how can I share a memory-mapped tensor from PyTorch with NumPy, JAX and TensorFlow on the CPU, without copying, while making sure that changes torch makes in memory are reflected in all of the shared tensors?

One approach is shown below:

import torch
import tensorflow as tf
import numpy as np
import jax.numpy as jnp
import jax.dlpack

# Create the tensor and persist its raw float32 bytes to a file
t = torch.randn(10, dtype=torch.float32)
t.numpy().tofile("tensor.pt")

# Memory-map the file with PyTorch
t_mapped = torch.from_file("tensor.pt", shared=True, size=10, dtype=torch.float32)

# Memory-map the same file with NumPy
n_mapped = np.memmap("tensor.pt", dtype='float32', mode='r+', shape=(10,))

# Convert it to JAX, which will reuse the same buffer
j_mapped = jnp.asarray(n_mapped)

# Convert it to a DLPack capsule and load it in TensorFlow
dlcapsule = jax.dlpack.to_dlpack(j_mapped)
tf_mapped = tf.experimental.dlpack.from_dlpack(dlcapsule)
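
Before changing anything, a quick sanity check that all four views read the same values (this only confirms equal contents, not that memory is shared, but it is reassuring):

# Sanity check: all four views should report identical values at this point
print(torch.equal(t_mapped, torch.from_numpy(np.asarray(n_mapped))))   # expected: True
print(np.allclose(np.asarray(j_mapped), n_mapped))                     # expected: True
print(np.allclose(tf_mapped.numpy(), np.asarray(n_mapped)))            # expected: True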

Now the fun part begins: I will change the tensor in PyTorch and check what happens to the NumPy, JAX and TensorFlow tensors:

>>> t_mapped.fill_(42.0) # Changing only the PyTorch tensor
tensor([42., 42., 42., 42., 42., 42., 42., 42., 42., 42.])

>>> n_mapped # NumPy array
memmap([42., 42., 42., 42., 42., 42., 42., 42., 42., 42.], dtype=float32)

>>> j_mapped # JAX array
Array([42., 42., 42., 42., 42., 42., 42., 42., 42., 42.], dtype=float32)

>>> tf_mapped # TensorFlow Tensor
<tf.Tensor: shape=(10,), dtype=float32, numpy=array([42., 42., 42., 42., 42., 42., 42., 42., 42., 42.], dtype=float32)>

As you can see, the change made through the torch tensor is reflected in NumPy, JAX and TensorFlow; that's the magic of memory-mapping the same file.
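
The sharing also works in the other direction. Here is a minimal sketch, reusing the same tensor.pt file and the mappings from above, and assuming both sides keep shared mappings of that file (which is what shared=True and mode='r+' give you): writing through the NumPy memmap is visible to the PyTorch tensor, and flushing persists the values to disk.

# Write through the NumPy memmap this time
n_mapped[:] = 7.0

# The PyTorch tensor maps the same file, so it sees the new values
print(t_mapped)   # expected: tensor([7., 7., ..., 7.])

# Flush the memmap so the values are also written back to the file on disk
n_mapped.flush()
print(np.fromfile("tensor.pt", dtype=np.float32))   # expected: [7. 7. ... 7.]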