Memory-mapped CPU tensor between PyTorch, NumPy, JAX and TensorFlow
This is just a fun experiment to answer the question: how can I share a memory-mapped tensor from PyTorch with NumPy, JAX and TensorFlow on CPU, without any copies, while making sure that changes made in memory by PyTorch are reflected in all of these shared tensors?
One approach is shown below:
import torch
import tensorflow as tf
import numpy as np
import jax.numpy as jnp
import jax.dlpack

# Create the tensor and persist it to disk
t = torch.randn(10, dtype=torch.float32)
t.numpy().tofile("tensor.pt")

# Memory-map the file with PyTorch
t_mapped = torch.from_file("tensor.pt", shared=True, size=10, dtype=torch.float32)

# Memory-map the same file with NumPy
n_mapped = np.memmap("tensor.pt", dtype='float32', mode='r+', shape=(10,))

# Convert it to JAX, which will reuse the same buffer
j_mapped = jnp.asarray(n_mapped)

# Convert it to a DLPack capsule and load it in TensorFlow
dlcapsule = jax.dlpack.to_dlpack(j_mapped)
tf_mapped = tf.experimental.dlpack.from_dlpack(dlcapsule)
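As a quick sanity check (not part of the original snippet), you can also get a zero-copy NumPy view straight from the mapped PyTorch tensor with Tensor.numpy(). The sketch below introduces a hypothetical n_view variable and just confirms that the view and the tensor point at the same buffer:

# Zero-copy NumPy view of the buffer PyTorch memory-mapped above
n_view = t_mapped.numpy()

# Same base address, i.e. the same underlying memory
print(t_mapped.data_ptr() == n_view.ctypes.data)   # expected: True
print(np.shares_memory(n_view, t_mapped.numpy()))  # expected: True

Note that np.memmap above creates its own mapping of the same file, so it shares pages with the PyTorch mapping through the operating system rather than through a single Python-level buffer.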
Now the fun part begins: I will change the tensor in PyTorch, and then we will check what happens to the NumPy, JAX and TensorFlow tensors:
>>> t_mapped.fill_(42.0)  # Changing only the PyTorch tensor
tensor([42., 42., 42., 42., 42., 42., 42., 42., 42., 42.])
>>> n_mapped  # NumPy array
memmap([42., 42., 42., 42., 42., 42., 42., 42., 42., 42.], dtype=float32)
>>> j_mapped  # JAX array
Array([42., 42., 42., 42., 42., 42., 42., 42., 42., 42.], dtype=float32)
>>> tf_mapped  # TensorFlow tensor
<tf.Tensor: shape=(10,), dtype=float32, numpy=array([42., 42., 42., 42., 42., 42., 42., 42., 42., 42.], dtype=float32)>
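Because the file was mapped in shared mode, the new values should also be visible when the file itself is read back. This is a small additional check, not part of the original post, and it assumes a platform where regular file reads go through the same page cache as the mapping (the usual case on Linux and macOS):

# Read "tensor.pt" back through the regular file API;
# the 42s written through the mapping should already be visible
reloaded = np.fromfile("tensor.pt", dtype=np.float32)
print(reloaded)

# To explicitly force the dirty pages out to disk, flush the NumPy memmap
n_mapped.flush()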
As you can see above, the change made through the PyTorch tensor is reflected in the NumPy, JAX and TensorFlow tensors; that's the magic of memmap().
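The sharing also works in the other direction, at least between the two file mappings. Here is a minimal sketch continuing with the same variables; the value 7.0 is arbitrary and the expected result is stated in the comments rather than captured from a real session:

# Write through the NumPy memmap this time
n_mapped[:] = 7.0

# The PyTorch mapping of the same file should now show the 7s as well
print(t_mapped)

Whether j_mapped and tf_mapped pick this change up as well depends on whether the JAX conversion really avoided a copy, which is exactly what the experiment above was probing in the first place.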