I really like to peek into different ML codebases for distributed training and this is a very short post on some things I found interesting in Torch Titan:
Disabling and controlling Python’s garbage collector (GC): the titan codebase disables the Python GC and then manually forces a collection at the beginning of every training step (a minimal sketch of this pattern is shown after this list). This makes sense, but I’m not sure what the gains are: collecting at every step may be too much, and I’m not sure taking control of the GC is worth it, especially depending on the complexity of the other dependencies you use, as this could cause unintended behavior that would be difficult to associate with GC collection;
Custom GPU memory monitoring: titan has a custom class to monitor GPU memory that is quite nice; it resets peak stats and empties the CUDA caching allocator upon initialization. At every step it then collects the peak stats for both the small and large pools, capturing active and reserved memory as well as failed retries and the number of OOMs. It is very common for people to just monitor max GPU usage externally through NVML; however, this ignores the fact that PyTorch uses a caching allocator and that you need to look at the internal memory management mechanism inside PyTorch. If you don’t do that, you will certainly be misled by what you are getting from NVML;
Custom profiling context manager: they wrote a context manager for profiling, where they measure the time it takes to dump the profiling data per rank. It is interesting that there is a barrier at the end, which makes sense, but this is often the pain point of distributed training with PyTorch + NCCL;
Measuring data loading: this is of minor interest, but I liked the idea of not iterating over the data loader in the loop statement itself but manually calling next() to get the batches. That makes it easier to measure data loading time, which they average at the end of each epoch;
Logging MFU (model FLOPS utilization): they also compute and log MFU, which is quite helpful;
Delete predictions before backward: titan also deletes the model predictions before the backward() call to avoid memory peaks. This can be quite effective since you really don’t need this tensor anymore and you can delete it immediately before the backward pass;
Reduction of the NCCL timeout: after the first training step, they reduce the NCCL timeout from the default 10 minutes to 100 seconds. This is nice if your replicas are well behaved and you don’t need to do anything more complex, but 100 seconds is a very short timeout that I would be careful with. It might be a good fit for your workload, but if your replicas drift a bit more, you will need to keep adding barriers to avoid timeouts that can be incredibly difficult to debug and cause a lot of headaches;
Distributed checkpointing with mid-epoch checkpoint support: this is a very cool implementation, it uses distributed checkpointing from PyTorch. They create some wrappers (e.g. for optimizer) where they implement the Stateful protocol to support checkpointing. They also use the StatefulDataLoader from torchdata to do checkpointing of mid-epoch data loader state;
Misc: there are of course other interesting things, but it is cool to mention that they also implemented a no-frills LLaMA model without relying on thousands of different libs (it seems it became fashionable nowadays to keep adding dependencies), so kudos to them for keeping it simple.
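Here is a minimal sketch of the GC-control, data-loading timing and “delete predictions before backward” patterns mentioned above (the function and names below are mine, not torchtitan’s):
import gc
import time
import torch

def train(model, optimizer, data_loader, num_steps):
    gc.disable()                          # take manual control of the garbage collector
    data_iter = iter(data_loader)
    data_times = []
    for _ in range(num_steps):
        gc.collect()                      # force a collection at the start of every step
        t0 = time.perf_counter()
        batch, target = next(data_iter)   # manual next() makes data loading easy to time
        data_times.append(time.perf_counter() - t0)
        pred = model(batch)
        loss = torch.nn.functional.cross_entropy(pred, target)
        del pred                          # predictions are no longer needed: drop them before backward()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    gc.enable()
    return sum(data_times) / len(data_times)   # average data-loading time per step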
This is just a fun experiment to answer the question: how can I share a memory-mapped tensor from PyTorch with Numpy, Jax and TensorFlow on CPU, without copies, while making sure that changes done in memory by torch are reflected in all these shared tensors ?
One approach is shown below:
import torch
import tensorflow as tf
import numpy as np
import jax.numpy as jnp
import jax.dlpack
# Create the tensor and persist
t = torch.randn(10, dtype=torch.float32)
t.numpy().tofile("tensor.pt")
# Memory-map the file with PyTorch
t_mapped = torch.from_file("tensor.pt", shared=True, size=10, dtype=torch.float32)
# Memory-map it with numpy, the same tensor
n_mapped = np.memmap("tensor.pt", dtype='float32', mode='r+', shape=(10))
# Convert it to Jax, will reuse the same buffer
j_mapped = jnp.asarray(n_mapped)
# Convert it to dlpack capsule and load in TensorFlow
dlcapsule = jax.dlpack.to_dlpack(j_mapped)
tf_mapped = tf.experimental.dlpack.from_dlpack(dlcapsule)
Now the fun part begins: I will change the tensor in PyTorch and we will check what happens to the Numpy, Jax and TensorFlow tensors:
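A sketch of what that check could look like (whether the Jax and TensorFlow views actually observe the change depends on the framework versions really performing zero-copy conversions):
# Change the first element through the memory-mapped PyTorch tensor
t_mapped[0] = 42.0

# The numpy memmap maps the same file, so it should see the new value
print(n_mapped[0])    # expected: 42.0

# If the Jax and TensorFlow conversions were truly zero-copy,
# the same value should show up here as well
print(j_mapped[0])
print(tf_mapped[0])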
In 2009 I started playing with LLVM for some projects (a data structure JIT for genetic programming, a JIT for TensorFlow graphs, etc), and in these projects I realized how powerful LLVM’s design was at the time (and still is): an elegant IR (intermediate representation) with a user-facing API, and modularized front-ends and backends with plenty of transformation and optimization passes. Nowadays, LLVM is the main engine behind many compilers and JIT compilation, and it is where most of the modern developments in compilers are happening.
On a related note, PyTorch is doing an amazing job of exposing more of the torch tracing system, its IR and graphs, not to mention their work on the recent fusers and TorchDynamo. In this context, I was doing a small test to re-implement Shine, but using ATen ops for tensors, and realized that there were not many educational tutorials on how to use LLVM to JIT PyTorch graphs. So this is a quick series (time permitting, there will be more posts) on how to use LLVM (through its Python bindings) to go from PyTorch graphs (as traced by torch.fx) to LLVM IR and native code.
Detour – PyTorch NNC (Neural Net Compiler)
PyTorch itself also has a compiler that uses LLVM to generate native code for the subgraphs that the fuser identifies. This is also called NNC (Neural Net Compiler) or Tensor Expressions (TE); you can read more about them in the C++ API tutorial. One thing to note, though, is that the default binaries you get from PyTorch do not come linked against the LLVM libraries, so you need to compile PyTorch yourself with the LLVM backend enabled, otherwise it won’t use LLVM to do the JIT compilation/optimization of the subgraphs. Let’s take a look at it first before starting our tutorial.
This is the first post of 2020, so happy new year to you all !
I’ve been a huge fan of LLVM since 11 years ago, when I started playing with it to JIT data structures such as AVLs, then later to JIT restricted AST trees and to JIT native code from TensorFlow graphs. Since then, LLVM has evolved into one of the most important compiler framework ecosystems and is used nowadays by a lot of important open-source projects.
One cool project that I recently became aware of is Gandiva. Gandiva was developed by Dremio and then later donated to Apache Arrow (kudos to Dremio team for that). The main idea of Gandiva is that it provides a compiler to generate LLVM IR that can operate on batches of Apache Arrow. Gandiva was written in C++ and comes with a lot of different functions implemented to build an expression tree that can be JIT’ed using LLVM. One nice feature of this design is that it can use LLVM to automatically optimize complex expressions, add native target platform vectorization such as AVX while operating on Arrow batches and execute native code to evaluate the expressions.
The image below gives an overview of Gandiva:
In this post I’ll build a very simple expression parser supporting a limited set of operations that I will use to filter a Pandas DataFrame.
Building simple expression with Gandiva
In this section I’ll show how to create a simple expression manually using the tree builder from Gandiva.
Using the Gandiva Python bindings to JIT an expression
Before building our parser and expression builder for expressions, let’s manually build a simple expression with Gandiva. First, we will create a simple Pandas DataFrame with numbers from 0.0 to 9.0:
import pandas as pd
import pyarrow as pa
import pyarrow.gandiva as gandiva
# Create a simple Pandas DataFrame
df = pd.DataFrame({"x": [1.0 * i for i in range(10)]})
table = pa.Table.from_pandas(df)
schema = pa.Schema.from_pandas(df)
We converted the DataFrame to an Arrow Table; it is important to note that in this case it was a zero-copy operation, Arrow isn’t copying data from Pandas and duplicating the DataFrame. Later we get the schema from the table, which contains column types and other metadata.
After that, we want to use Gandiva to build the following expression to filter the data:
(x > 2.0) and (x < 6.0)
This expression will be built using nodes from Gandiva:
builder = gandiva.TreeExprBuilder()
# Reference the column "x"
node_x = builder.make_field(table.schema.field("x"))
# Make two literals: 2.0 and 6.0
two = builder.make_literal(2.0, pa.float64())
six = builder.make_literal(6.0, pa.float64())
# Create a function for "x > 2.0"
gt_five_node = builder.make_function("greater_than",
                                     [node_x, two],
                                     pa.bool_())
# Create a function for "x < 6.0"
lt_ten_node = builder.make_function("less_than",
                                    [node_x, six],
                                    pa.bool_())
# Create an "and" node, for "(x > 2.0) and (x < 6.0)"
and_node = builder.make_and([gt_five_node, lt_ten_node])
# Make the expression a condition and create a filter
condition = builder.make_condition(and_node)
filter_ = gandiva.make_filter(table.schema, condition)
This code now looks a little more complex but it is easy to understand. We are basically creating the nodes of a tree that will represent the expression we showed earlier. Here is a graphical representation of what it looks like:
Inspecting the generated LLVM IR
Unfortunately, I haven’t found a way to dump the LLVM IR that was generated using Arrow’s Python bindings; however, we can just use the C++ API to build the same tree and then look at the generated LLVM IR:
auto field_x = field("x", float32());
auto schema = arrow::schema({field_x});
auto node_x = TreeExprBuilder::MakeField(field_x);
auto two = TreeExprBuilder::MakeLiteral((float_t)2.0);
auto six = TreeExprBuilder::MakeLiteral((float_t)6.0);
auto gt_five_node = TreeExprBuilder::MakeFunction("greater_than",
                                                  {node_x, two}, arrow::boolean());
auto lt_ten_node = TreeExprBuilder::MakeFunction("less_than",
                                                 {node_x, six}, arrow::boolean());
auto and_node = TreeExprBuilder::MakeAnd({gt_five_node, lt_ten_node});
auto condition = TreeExprBuilder::MakeCondition(and_node);
std::shared_ptr<Filter> filter;
auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
The code above is the same as the Python code, but using the C++ Gandiva API. Now that we have built the tree in C++, we can get the LLVM Module and dump its IR code. The generated IR is full of boilerplate code and of the JIT’ed functions from the Gandiva registry; however, the important parts are shown below:
As you can see, in the IR there are calls to the functions less_than_float32_float32 and greater_than_float32_float32, which are the (in this case very simple) Gandiva functions that do the float comparisons. Note the specialization of the functions by looking at the type names appended to the function names.
What is quite interesting is that LLVM will apply all of its optimizations to this code and will generate efficient native code for the target platform, while Gandiva and LLVM take care of making sure that memory alignment is correct for extensions such as AVX to be used for vectorization.
The IR code I showed isn’t actually the one that is executed; the optimized one is. In the optimized code we can see that LLVM inlined the functions, as shown in the excerpt of the optimized code below:
You can see that the expression is now much simpler after optimization, as LLVM applied its powerful optimizations and inlined a lot of the Gandiva functions.
Building a Pandas filter expression JIT with Gandiva
Now we want to implement something similar to the Pandas DataFrame.query() function using Gandiva. The first problem we will face is that we need to parse a string such as (x > 2.0) and (x < 6.0); later, we will have to build the Gandiva expression tree using the tree builder from Gandiva and then evaluate that expression on Arrow data.
Now, instead of implementing a full parsing of the expression string, I’ll use the Python AST module to parse valid Python code and build an Abstract Syntax Tree (AST) of that expression, that I’ll be later using to emit the Gandiva/LLVM nodes.
The heavy work of parsing the string will be delegated to Python AST module and our work will be mostly walking on this tree and emitting the Gandiva nodes based on that syntax tree. The code for visiting the nodes of this Python AST tree and emitting Gandiva nodes is shown below:
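Here is a minimal sketch of what such a visitor could look like, assuming a builder and schema like the ones created above (the class and method names are illustrative, not the exact code from the post):
import ast
import pyarrow as pa

class GandivaVisitor(ast.NodeVisitor):
    """Walks a Python AST and emits the corresponding Gandiva nodes."""
    def __init__(self, schema, builder):
        self.schema = schema
        self.builder = builder
        # Map Python comparison operators to Gandiva registry function names
        self.cmp_ops = {
            ast.Gt: "greater_than",
            ast.Lt: "less_than",
            ast.Eq: "equal",
        }

    def visit_Module(self, node):
        # The parsed query is a single expression statement
        return self.visit(node.body[0].value)

    def visit_Name(self, node):
        # A bare name is a column reference, e.g. "x"
        return self.builder.make_field(self.schema.field(node.id))

    def visit_Constant(self, node):
        # Literals such as 2.0 become Gandiva literal nodes
        return self.builder.make_literal(float(node.value), pa.float64())

    def visit_Compare(self, node):
        left = self.visit(node.left)
        right = self.visit(node.comparators[0])
        fn = self.cmp_ops[type(node.ops[0])]
        return self.builder.make_function(fn, [left, right], pa.bool_())

    def visit_BinOp(self, node):
        # '&' and '|' take AND/OR semantics, as in Pandas query()
        left, right = self.visit(node.left), self.visit(node.right)
        if isinstance(node.op, ast.BitAnd):
            return self.builder.make_and([left, right])
        if isinstance(node.op, ast.BitOr):
            return self.builder.make_or([left, right])
        raise NotImplementedError("unsupported operator")
The entry point then just calls ast.parse() on the query string and visits the resulting tree to get the root Gandiva node, which is wrapped with make_condition() and gandiva.make_filter() exactly as in the manual example above.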
As you can see, the code is pretty straightforward, as I’m not supporting every possible Python expression but only a minor subset of it. What we do in this class is basically a conversion of Python AST nodes such as Comparators and BinOps (binary operations) into Gandiva nodes. I’m also changing the semantics of the & and the | operators to represent AND and OR respectively, as in the Pandas query() function.
Register as a Pandas extension
The next step is to create a simple Pandas extension using the gandiva_query() method that we created:
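A minimal sketch of what registering such an extension could look like; gandiva_query here stands for the helper built on top of the visitor above, so the exact implementation in the post may differ:
import pandas as pd

@pd.api.extensions.register_dataframe_accessor("gandiva")
class GandivaAccessor:
    """Exposes df.gandiva.query(...) on any DataFrame."""
    def __init__(self, pandas_obj):
        self._df = pandas_obj

    def query(self, expression):
        # gandiva_query() is assumed to parse the expression, build the
        # Gandiva filter and evaluate it on the Arrow table of the DataFrame
        return gandiva_query(self._df, expression)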
And that is it, now we can use this extension to do things such as:
df = pd.DataFrame({"a": [1.0 * i for i in range(nsize)]})
results = df.gandiva.query("a > 10.0")
Since we have registered a Pandas extension called gandiva, it is now a first-class citizen of Pandas DataFrames.
Let’s now create a DataFrame with 50 million floats and use the new query() method to filter it:
df = pd.DataFrame({"a": [1.0 * i for i in range(50000000)]})
df.gandiva.query("a < 4.0")
# This will output:
# array([0, 1, 2, 3], dtype=uint32)
Note that the returned values are the indexes satisfying the condition we implemented, so it is different from the Pandas query(), which returns the data already filtered.
I did some benchmarks and found that Gandiva is usually faster than Pandas; however, I’ll leave proper benchmarks for a later post on Gandiva, as this post was about showing how you can use it to JIT expressions.
That’s it ! I hope you liked the post as much as I enjoyed exploring Gandiva. It seems that we will probably have more and more tools coming up with Gandiva acceleration, especially for SQL parsing/projection/JITing. Gandiva is much more than what I just showed, but you can get started now to understand more of its architecture and how to build expression trees.
This post is a tour around the PyTorch codebase, it is meant to be a guide for the architectural design of PyTorch and its internals. My main goal is to provide something useful for those who are interested in understanding what happens beyond the user-facing API and show something new beyond what was already covered in other tutorials.
Note: PyTorch build system uses code generation extensively so I won’t repeat here what was already described by others. If you’re interested in understanding how this works, please read the following tutorials:
As you probably know, you can extend Python using C and C++ and develop what is called an “extension”. All the PyTorch heavy work is implemented in C/C++ instead of pure Python. To define a new Python object type in C/C++, you define a structure like the example below (which is the base for the autograd Variable class):
As you can see, there is a macro at the beginning of the definition, called PyObject_HEAD. This macro’s goal is the standardization of Python objects: it expands to another structure that contains a pointer to a type object (which defines initialization methods, allocators, etc.) and also a field with a reference counter.
There are two extra macros in the Python API, called Py_INCREF() and Py_DECREF(), which are used to increment and decrement the reference counter of Python objects. Multiple entities can borrow or own a reference to other objects (the reference counter is increased), and only when this reference counter reaches zero (when all references get destroyed) will Python automatically delete that object’s memory using its garbage collector.
You can read more about Python C/C++ extensions here.
Fun fact: it is very common in many applications to use small integer numbers for indexing, counters, etc. For efficiency, the official CPython interpreter caches the integers from -5 up to 256. For that reason, the statement a = 200; b = 200; a is b will be True, while the statement a = 300; b = 300; a is b will be False.
Zero-copy PyTorch Tensor to Numpy and vice-versa
PyTorch has its own Tensor representation, which decouples the PyTorch internal representation from external representations. However, since it is very common to have Numpy arrays everywhere, especially when data is loaded from a variety of sources, we really need to make conversions between Numpy and PyTorch tensors. For that reason, PyTorch provides two methods called from_numpy() and numpy(), which convert a Numpy array to a PyTorch tensor and vice-versa, respectively. If we look at the code that is called to convert a Numpy array into a PyTorch tensor, we can get more insight into PyTorch’s internal representation:
at::Tensor tensor_from_numpy(PyObject* obj) {
  if (!PyArray_Check(obj)) {
    throw TypeError("expected np.ndarray (got %s)", Py_TYPE(obj)->tp_name);
  }
  auto array = (PyArrayObject*)obj;
  int ndim = PyArray_NDIM(array);
  auto sizes = to_aten_shape(ndim, PyArray_DIMS(array));
  auto strides = to_aten_shape(ndim, PyArray_STRIDES(array));
  // NumPy strides use bytes. Torch strides use element counts.
  auto element_size_in_bytes = PyArray_ITEMSIZE(array);
  for (auto& stride : strides) {
    stride /= element_size_in_bytes;
  }
  // (...) - omitted for brevity
  void* data_ptr = PyArray_DATA(array);
  auto& type = CPU(dtype_to_aten(PyArray_TYPE(array)));
  Py_INCREF(obj);
  return type.tensorFromBlob(data_ptr, sizes, strides, [obj](void* data) {
    AutoGIL gil;
    Py_DECREF(obj);
  });
}
As you can see from this code, PyTorch is obtaining all the information (array metadata) from the Numpy representation and then creating its own. However, as you can note from the PyArray_DATA(array) call, PyTorch is getting a pointer to the internal Numpy array raw data instead of copying it. This means that PyTorch will create a reference to this data, sharing the same memory region with the Numpy array object for the raw Tensor data.
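A quick sketch from the Python side shows this sharing in action:
import numpy as np
import torch

np_array = np.ones((2, 2), dtype=np.float32)
torch_tensor = torch.from_numpy(np_array)  # no copy, shares the same memory
np_array[0, 0] = 42.0
print(torch_tensor[0, 0])                  # the change done in numpy is visible from the tensor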
There is also an important point here: when the Numpy array object goes out of scope and gets a zero reference count, it will be garbage collected and destroyed. That’s why there is an increment of the reference count of the Numpy array object (the Py_INCREF(obj) call).
After this, PyTorch will create a new Tensor object from this Numpy data blob, and in the creation of this new Tensor it passes the borrowed memory data pointer, together with the memory size and strides as well as a function that will be used later by the Tensor Storage (we’ll discuss this in the next section) to release the data by decrementing the reference counting to the Numpy array object and let Python take care of this object life cycle.
The tensorFromBlob() method will create a new Tensor, but only after creating a new “Storage” for this Tensor. The storage is where the actual data pointer will be stored (and not in the Tensor structure itself). This takes us to the next section about Tensor Storages.
Tensor Storage
The actual raw data of the Tensor is not directly kept in the Tensor structure, but on another structure called Storage, which in turn is part of the Tensor structure.
As we saw in the previous code from tensor_from_numpy(), there is a call for tensorFromBlob() that will create a Tensor from the raw data blob. This last function will call another function called storageFromBlob() that will, in turn, create a storage for this data according to its type. In the case of a CPU float type, it will return a new CPUFloatStorage instance.
The CPUFloatStorage is basically a wrapper with utility functions around the actual storage structure called THFloatStorage that we show below:
As you can see, the THStorage holds a pointer to the raw data, its size, flags and also an interesting field called allocator that we’ll soon discuss. It is also important to note that there is no metadata regarding how to interpret the data inside the THStorage; the storage is “dumb” regarding its contents, and it is the Tensor’s responsibility to know how to “view” or interpret this data.
From this, you already probably realized that we can have multiple tensors pointing to the same storage but with different views of this data, and that’s why viewing a tensor with a different shape (but keeping the same number of elements) is so efficient. This Python code below shows that the data pointer in the storage is being shared after changing the way Tensor views its data:
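For instance, a minimal example along these lines:
import torch

tensor_a = torch.rand(4, 4)
tensor_b = tensor_a.view(2, 8)  # same storage, different shape/strides

# Both tensors point to the same underlying storage data pointer
print(tensor_a.storage().data_ptr() == tensor_b.storage().data_ptr())  # True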
As we can see in the example above, the data pointer on the storage of both Tensors are the same, but the Tensors represent a different interpretation of the storage data.
Now, as we saw in the THFloatStorage structure, there is a pointer to a THAllocator structure there. And this is very important, because it brings flexibility regarding the allocator that can be used to allocate the storage data. This structure is represented by the following code:
As you can see, there are three function pointer fields in this structure to define what an allocator means: a malloc, realloc and free. For CPU-allocated memory, these functions will, of course, relate to the traditional malloc/realloc/free POSIX functions, however, when we want a storage allocated on GPUs we’ll end up using the CUDA allocators such as the cudaMallocHost(), like we can see in the THCudaHostAllocator malloc function below:
You probably noticed a pattern in the repository organization, but it is important to keep in mind these conventions when navigating the repository, as summarized here (taken from the PyTorch lib readme):
TH = TorcH
THC = TorcHCuda
THCS = TorcHCuda Sparse
THCUNN = TorcHCUda Neural Network
THD = TorcHDistributed
THNN = TorcHNeural Network
THS = TorcH Sparse
This convention is also present in the function/class names and other objects, so it is important to always keep these patterns in mind. While you can find CPU allocators in the TH code, you’ll find CUDA allocators in the THC code.
Finally, we can see the composition of the main Tensor THTensor structure:
typedef struct THTensor
{
    int64_t *size;
    int64_t *stride;
    int nDimension;
    THStorage *storage;
    ptrdiff_t storageOffset;
    int refcount;
    char flag;
} THTensor;
And as you can see, the main THTensor structure holds the size/strides/dimensions/offsets/etc as well as the storage (THStorage) for the Tensor data.
We can summarize all this structure that we saw in the diagram below:
Now, once we have requirements such as multi-processing, where we want to share tensor data among multiple different processes, we need a shared memory approach. Otherwise, every time another process needs a tensor, or when you want to implement a Hogwild training procedure where all the different processes write to the same memory region (where the parameters are), you would need to make copies between processes, and this is very inefficient. Therefore, we’ll discuss in the next section a special kind of storage for Shared Memory.
Shared Memory
Shared memory can be implemented in many different ways depending on the platform support. PyTorch supports some of them, but for the sake of simplicity I’ll talk here about what happens on MacOS using the CPU (instead of the GPU). Since PyTorch supports multiple shared memory approaches, this part is a little tricky to grasp, since it involves more levels of indirection in the code.
PyTorch provides a wrapper around the Python multiprocessing module that can be imported from torch.multiprocessing. The changes they implemented in this wrapper around the official Python multiprocessing were done to make sure that every time a tensor is put on a queue or shared with another process, only a handle for the shared memory is shared, instead of an entire new copy of the Tensor.
Now, many people aren’t aware of a Tensor method from PyTorch called share_memory_(); however, this function is what triggers an entire rebuild of the storage memory for that particular Tensor. What this method does is create a region of shared memory that can be used among different processes. This function will, in the end, call the following function:
And as you can see, this function creates another storage using a special allocator called THManagedSharedAllocator. The function first defines some flags and then creates a handle, which is a string in the format /torch_[process id]_[random number]; after that, it creates a new storage using the special THManagedSharedAllocator. This allocator has function pointers to an internal PyTorch library called libshm, which implements Unix Domain Socket communication to share the shared memory region handles. This allocator is actually a special case, a kind of “smart allocator”: it contains the communication control logic and it uses another allocator called THRefcountedMapAllocator, which is responsible for creating the actual shared memory region and calling mmap() to map this region into the process virtual address space.
Note: when a method ends with an underscore in PyTorch, such as the method called share_memory_(), it means that this method has an in-place effect: it will change the current object instead of creating a new one with the modifications.
I’ll now show a Python example of one process using the data from a Tensor that was allocated on another process, by manually exchanging the shared memory handle:
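A sketch of the process A side, using the internal storage-sharing API of the PyTorch version of that time (the handle values printed here are just illustrative):
# Process A
import torch

tensor_a = torch.ones((5, 5))
tensor_a.share_memory_()  # rebuilds the storage in shared memory

# Internal API: returns (manager socket address, shared memory handle, storage size)
print(tensor_a.storage()._share_filename_())
# e.g. (b'/var/tmp/tmp.0.yowqlr', b'/torch_31258_1218748506', 25)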
In this code, executed in process A, we create a new Tensor of 5×5 filled with ones. After that we make it shared and print the tuple with the Unix Domain Socket address as well as the handle. Now we can access this memory region from another process B, as shown below:
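And a sketch of the process B side, rebuilding a tensor from that handle (again relying on internal APIs of the PyTorch version of that time, so treat the exact names as an assumption):
# Process B
import torch

# Tuple printed by process A: (manager address, handle, size)
tuple_info = (b'/var/tmp/tmp.0.yowqlr', b'/torch_31258_1218748506', 25)
storage = torch.Storage._new_shared_filename(*tuple_info)
tensor_b = torch.Tensor(storage).view((5, 5))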
As you can see, using the tuple information about the Unix Domain Socket address and the handle we were able to access the Tensor storage from another process. If you change the tensor in this process B, you’ll also see that it will reflect in the process A because these Tensors are sharing the same memory region.
DLPack: a hope for the Deep Learning frameworks Babel
Now I would like to talk about something recent in the PyTorch code base, called DLPack. DLPack is an open standardization of an in-memory tensor structure that allows exchanging tensor data between frameworks, and what is quite interesting is that, since this memory representation is standardized and very similar to the memory representation already in use by many frameworks, it allows zero-copy data sharing between frameworks, which is a quite amazing initiative given the variety of frameworks we have today without inter-communication among them.
This will certainly help to overcome the “island model” that we have today between tensor representations in MXNet, PyTorch, etc, and will allow developers to mix operations from different frameworks, with all the benefits that a standardization can bring.
The core of DLPack is a very simple structure called DLTensor, as shown below:
/*!
 * \brief Plain C Tensor object, does not manage memory.
 */
typedef struct {
  /*!
   * \brief The opaque data pointer points to the allocated data.
   *  This will be CUDA device pointer or cl_mem handle in OpenCL.
   *  This pointer is always aligned to 256 bytes as in CUDA.
   */
  void* data;
  /*! \brief The device context of the tensor */
  DLContext ctx;
  /*! \brief Number of dimensions */
  int ndim;
  /*! \brief The data type of the pointer*/
  DLDataType dtype;
  /*! \brief The shape of the tensor */
  int64_t* shape;
  /*!
   * \brief strides of the tensor,
   *  can be NULL, indicating tensor is compact.
   */
  int64_t* strides;
  /*! \brief The offset in bytes to the beginning pointer to data */
  uint64_t byte_offset;
} DLTensor;
As you can see, there is a data pointer for the raw data as well as shape/strides/offset and device (GPU vs CPU) information, plus other metadata about the data that the DLTensor is pointing to.
There is also a managed version of the tensor that is called DLManagedTensor, where the frameworks can provide a context and also a “deleter” function that can be called by the framework who borrowed the Tensor to inform the other framework that the resources are no longer required.
In PyTorch, if you want to convert to or from a DLTensor format, you can find both C/C++ methods for doing that or even in Python you can do that as shown below:
import torch
from torch.utils import dlpack
t = torch.ones((5, 5))
dl = dlpack.to_dlpack(t)
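The conversion also works the other way around; the capsule can be consumed back into a tensor that shares the same memory:
t2 = dlpack.from_dlpack(dl)  # no copy: t2 shares memory with t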
The to_dlpack() function above will call the toDLPack function from ATen, shown below:
As you can see, it’s a pretty simple conversion, casting the metadata from the PyTorch format to the DLPack format and assigning a pointer to the internal Tensor data representation.
I really hope that more frameworks adopt this standard that will certainly give benefits to the ecosystem. It is also interesting to note that a potential integration with Apache Arrow would be amazing.
Privacy-preserving computation, or secure computation, is a sub-field of cryptography where two parties (two-party computation, or 2PC) or multiple parties (multi-party computation, or MPC) can jointly evaluate a function without revealing their private input data to each other. The problem and the first solution to it were introduced in 1982 in an amazing breakthrough by Andrew Yao, in what later became known as “Yao’s Millionaires’ problem“.
In Yao’s Millionaires’ Problem, two millionaires, Alice and Bob, are interested in knowing which of them is richer, but without revealing their actual wealth to each other. In other words, what they want can be generalized as follows: Alice and Bob want to jointly compute a function securely, learning nothing other than the result of the computation on the input data (which remains private to each of them).
To make the problem concrete, Alice has an amount A, such as $10, and Bob has an amount B, such as $50, and what they want to know is which one is larger, without Bob revealing the amount B to Alice or Alice revealing the amount A to Bob. It is also important to note that we don’t want to rely on a trusted third party, otherwise the problem would reduce to a simple protocol of information exchange with that trusted party.
Formally, what we want is to jointly evaluate a function r = f(A, B), such that the private values A and B remain private to their sole owners and the result r is known to just one or both of the parties.
It seems very counterintuitive that a problem like that could ever be solved but, to the surprise of many people, it is possible to solve it under certain security assumptions. Thanks to recent developments in techniques such as FHE (Fully Homomorphic Encryption), Oblivious Transfer and Garbled Circuits, problems like that started to become practical for real-life usage, and they are nowadays used by many companies in applications such as information exchange, secure location, advertisement, satellite orbit collision avoidance, etc.
I’m not going to enter into details of these techniques, but if you’re interested in the intuition behind the OT (Oblivious Transfer), you should definitely read the amazing explanation done by Craig Gidney here. There are also, of course, many different protocols for doing 2PC or MPC, where each one of them assumes some security requirements (semi-honest, malicious, etc), I’m not going to enter into the details to keep the post focused on the goal, but you should be aware of that.
The problem: sentence similarity
What we want to achieve is to use privacy-preserving computation to calculate the similarity between sentences without disclosing their content. Just to give a concrete example: Bob owns a company and has the descriptions of many different projects in sentences such as “This project is about building a deep learning sentiment analysis framework that will be used for tweets“, and Alice, who owns a competitor company, also has different projects described in similar sentences. What they want to do is to jointly compute the similarity between projects in order to find out whether they should partner on a project or not. However, and this is the important point: Bob doesn’t want Alice to know the project descriptions, nor does Alice want Bob to be aware of their projects. They want to know the closest match between the different projects they run, but without disclosing the project ideas (project descriptions).
Sentence Similarity Comparison
Now, how can we exchange information about Bob’s and Alice’s project sentences without disclosing information about the project descriptions ?
One naive way to do that would be to just compute the hashes of the sentences and then compare only the hashes to check if they match. However, this would assume that the descriptions are exactly the same, and besides that, if the entropy of the sentences is small (like small sentences), someone with reasonable computation power can try to recover the sentence.
Another approach for this problem (this is the approach that we’ll be using), is to compare the sentences in the sentence embeddings space. We just need to create sentence embeddings using a Machine Learning model (we’ll use InferSent later) and then compare the embeddings of the sentences. However, this approach also raises another concern: what if Bob or Alice trains a Seq2Seq model that would go from the embeddings of the other party back to an approximate description of the project ?
It isn’t unreasonable to think that one can recover an approximate description of the sentence given their embeddings. That’s why we’ll use the two-party secure computation for computing the embeddings similarity, in a way that Bob and Alice will compute the similarity of the embeddings without revealing their embeddings, keeping their project ideas safe.
The entire flow is described in the image below, where Bob and Alice share the same Machine Learning model. They use this model to go from sentences to embeddings, followed by a secure computation of the similarity in the embedding space.
Generating sentence embeddings with InferSent
InferSent is an NLP technique for universal sentence representation developed by Facebook that uses supervised training to produce highly transferable representations.
They used a Bi-directional LSTM with attention that consistently surpassed many unsupervised training methods such as the SkipThought vectors. They also provide a Pytorch implementation that we’ll use to generate sentence embeddings.
Note: even if you don’t have GPU, you can have reasonable performance doing embeddings for a few sentences.
The first step to generate the sentence embeddings is to download and load a pre-trained InferSent model:
import numpy as np
import torch
# Trained model from: https://github.com/facebookresearch/InferSent
GLOVE_EMBS = '../dataset/GloVe/glove.840B.300d.txt'
INFERSENT_MODEL = 'infersent.allnli.pickle'
# Load trained InferSent model
model = torch.load(INFERSENT_MODEL,
map_location=lambda storage, loc: storage)
model.set_glove_path(GLOVE_EMBS)
model.build_vocab_k_words(K=100000)
The cosine similarity between two embedding vectors x and y is defined as cos(x, y) = (x · y) / (||x|| ||y||). As you can see, if we have two unit vectors (vectors with norm 1), the two terms in the denominator become 1 and we can remove the entire denominator from the equation, leaving only the dot product x · y.
So, if we normalize our vectors to have unit norm, the computation of the cosine similarity becomes just a simple dot product. That will help us a lot in computing the similarity distance later, when we use a framework to do the secure computation of this dot product.
So, the next step is to define a function that will take some sentence text and forward it to the model to generate the embeddings and then normalize them to unit vectors:
# This function will forward the text into the model and
# get the embeddings. After that, it will normalize it
# to a unit vector.
def encode(model, text):
    embedding = model.encode([text])[0]
    embedding /= np.linalg.norm(embedding)
    return embedding
As you can see, this function is pretty simple, it feeds the text into the model, and then it will divide the embedding vector by the embedding norm.
Now, for practical reasons, I’ll be using integer computation later for computing the similarity; however, the embeddings generated by InferSent are of course real values. For that reason, you’ll see in the code below that we create another function to scale the float values, remove the radix point and convert them to integers. There is also another important issue: the framework that we’ll be using later for secure computation doesn’t allow signed integers, so we also need to clip the embedding values between 0.0 and 1.0. This will of course cause some approximation errors; however, we can still get very good approximations after clipping and scaling with limited precision (I’m using 14 bits for scaling to avoid overflow issues later during the dot product computations):
# This function will scale the embedding in order to
# remove the radix point.
def scale(embedding):
    SCALE = 1 << 14
    scale_embedding = np.clip(embedding, 0.0, 1.0) * SCALE
    return scale_embedding.astype(np.int32)
You can use floating-point in your secure computations and there are a lot of frameworks that support them, however, it is more tricky to do that, and for that reason, I used integer arithmetic to simplify the tutorial. The function above is just a hack to make it simple. It’s easy to see that we can recover this embedding later without too much loss of precision.
Now we just need to create some sentence samples that we’ll be using:
# The list of Alice sentences
alice_sentences = [
    'my cat loves to walk over my keyboard',
    'I like to pet my cat',
]

# The list of Bob sentences
bob_sentences = [
    'the cat is always walking over my keyboard',
]
And convert them to embeddings:
# Alice sentences
alice_sentence1 = encode(model, alice_sentences[0])
alice_sentence2 = encode(model, alice_sentences[1])
# Bob sentences
bob_sentence1 = encode(model, bob_sentences[0])
Since we have now the sentences and every sentence is also normalized, we can compute cosine similarity just by doing a dot product between the vectors:
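In plain NumPy this is just (the values below match the comparison shown later in the post):
>>> np.dot(alice_sentence1, bob_sentence1)
0.8798542
>>> np.dot(alice_sentence2, bob_sentence1)
0.6297632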
As we can see, Bob’s first sentence is more similar (~0.87) to Alice’s first sentence than to Alice’s second sentence (~0.62).
Since we have now the embeddings, we just need to convert them to scaled integers:
# Scale the Alice sentence embeddings
alice_sentence1_scaled = scale(alice_sentence1)
alice_sentence2_scaled = scale(alice_sentence2)
# Scale the Bob sentence embeddings
bob_sentence1_scaled = scale(bob_sentence1)
# This is the unit vector embedding for the sentence
>>> alice_sentence1
array([ 0.01698913, -0.0014404 , 0.0010993 , ..., 0.00252409,
0.00828147, 0.00466533], dtype=float32)
# This is the scaled vector as integers
>>> alice_sentence1_scaled
array([278, 0, 18, ..., 41, 135, 76], dtype=int32)
Now with these embeddings as scaled integers, we can proceed to the second part, where we’ll be doing the secure computation between two parties.
Two-party secure computation
In order to perform secure computation between the two parties (Alice and Bob), we’ll use the ABY framework. ABY implements many different secure computation schemes and allows you to describe your computation as a circuit, as pictured in the image below, where Yao’s Millionaires’ problem is described:
As you can see, we have two inputs entering a GT GATE (greater-than gate) and then an output. This circuit has a bit length of 3 for each input and will compute whether Alice’s input is greater than (GT GATE) Bob’s input. The computing parties then secret-share their private data and can use arithmetic sharing, boolean sharing, or Yao sharing to securely evaluate these gates.
ABY is really easy to use because you can just describe your inputs, shares, gates and it will do the rest for you such as creating the socket communication channel, exchanging data when needed, etc. However, the implementation is entirely written in C++ and I’m not aware of any Python bindings for it (a great contribution opportunity).
Fortunately, there is an implemented example for ABY that can do dot product calculation for us, the example is here. I won’t replicate the example here, but the only part that we have to change is to read the embedding vectors that we created before instead of generating random vectors and increasing the bit length to 32-bits.
After that, we just need to execute the application on two different machines (or by emulating locally like below):
# This will execute the server part, the -r 0 specifies the role (server)
# and the -n 4096 defines the dimension of the vector (InferSent generates
# 4096-dimensional embeddings).
~# ./innerproduct -r 0 -n 4096
# And the same on another process (or another machine, however for another
# machine execution you'll have to obviously specify the IP).
~# ./innerproduct -r 1 -n 4096
And we get the following results:
Inner Product of alice_sentence1 and bob_sentence1 = 226691917
Inner Product of alice_sentence2 and bob_sentence1 = 171746521
Even in the integer representation, you can see that the inner product of Alice’s first sentence and Bob’s sentence is higher, meaning that the similarity is also higher. Let’s now convert this value back to float:
>>> SCALE = 1 << 14
# This is the dot product we should get
>>> np.dot(alice_sentence1, bob_sentence1)
0.8798542
# This is the inner product we got on secure computation
>>> 226691917 / SCALE**2.0
0.8444931
# This is the dot product we should get
>>> np.dot(alice_sentence2, bob_sentence1)
0.6297632
# This is the inner product we got on secure computation
>>> 171746521 / SCALE**2.0
0.6398056
As you can see, we got very good approximations, even in the presence of low-precision math and unsigned integer requirements. Of course, in real life you won’t have both values and vectors, because they’re supposed to be hidden, but the changes to accommodate that are trivial: you just need to adjust the ABY code to load only the vector of the party that is executing it and use the correct IP addresses/ports of both parties.
Hello everyone, I just released the Nanopipe project. Nanopipe is a library that allows you to connect different messaging systems (message queue brokers, pub/sub channels, and more) together. Nanopipe was built to avoid the glue code between different types of communication protocols/channels that is very common nowadays. An example: you have an application that is listening for messages on an AMQP broker (i.e. RabbitMQ), but you also have a Redis pub/sub source of messages and also an MQTT source from a weird IoT device you may have. Using Nanopipe, you can connect both MQTT and Redis to RabbitMQ without writing any glue code for that. You can also build any kind of complex connection scheme using Nanopipe.
One of the most amazing components of the TensorFlow architecture is the computation graph that can be serialized using Protocol Buffers. This computation graph follows a well-defined format (click here for the proto files) and describes the computation that you specify (it can be a Deep Learning model like a CNN, a simple Logistic Regression or even any computation you want). For instance, here is an example of a very simple TensorFlow computation graph that we will use in this tutorial (using TensorFlow Python API):
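A sketch of what that graph definition could look like with a TF 1.x-style API (the constants here are illustrative; what matters is that they fold down to a single +103, as we will see later):
import tensorflow as tf

with tf.Session() as sess:
    # Scalar int32 placeholder named "input"
    input_op = tf.placeholder(tf.int32, shape=[], name="input")
    # Redundant constant arithmetic that LLVM will later fold into a single +103
    sub_op = tf.subtract(input_op, tf.constant(2, dtype=tf.int32))
    add_op = tf.add(sub_op, tf.constant(5, dtype=tf.int32))
    output_op = tf.add(add_op, tf.constant(100, dtype=tf.int32), name="output")
    # Persist the graph definition as a human-readable protobuf text file
    tf.train.write_graph(sess.graph_def, ".", "graph.pb", True)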
As you can see, this is a very simple computation graph. First, we define the placeholder that will hold the input tensor, and after that we specify the computation that should happen using this input tensor as input data. Here we can also see that we’re defining two important nodes of this graph: one is called “input“ (the aforementioned placeholder) and the other is called “output“, which will hold the result of the final computation. For a scalar input, this graph amounts to a chain of constant additions and subtractions that folds down to a single add of 103; I intentionally added redundant operations so we can see LLVM constant propagation later.
In the last line of the code, we’re persisting this computation graph (including the constant values) into a serialized protobuf file. The final True parameter is to output a textual representation instead of binary, so it will produce the following human-readable output protobuf file (I omitted a part of it for brevity):
This is a very simple graph, and TensorFlow graphs are actually never that simple, because TensorFlow models can easily contain more than 300 nodes depending on the model you’re specifying, especially for Deep Learning models.
We’ll use the above graph to show how we can JIT native code for this simple graph using LLVM framework.
The LLVM Frontend, IR and Backend
The LLVM framework is a really nice, modular and complete ecosystem for building compilers and toolchains. A very nice description of the LLVM architecture that is important for us is shown in the picture below:
(The picture above is just a small part of the LLVM architecture, for a comprehensive description of it, please see the nice article from the AOSA book written by Chris Lattner)
Looking at the image above, we can see that LLVM provides a lot of core functionality. On the left side, you see that many languages have their respective language frontends; after that, it doesn’t matter in which language you wrote your code, everything is transformed into a very powerful language called LLVM IR (LLVM Intermediate Representation), which is, as you can imagine, an intermediate representation of the code just before the assembly code itself. In my opinion, the IR is the key component of what makes LLVM so amazing, because it doesn’t matter in which language you wrote your code (or even if it was a JIT’ed IR), everything ends up in the same representation, and then here is where the magic happens, because the IR can take advantage of the LLVM optimizations (also known as transform and analysis passes).
After this IR generation, you can feed it into any LLVM backend to generate native code for any architecture supported by LLVM (such as x86, ARM, PPC, etc), and then you can finally execute your code with native performance after the LLVM optimization passes.
In order to JIT code using LLVM, all you need is to build the IR programmatically, create an execution engine to convert (during execution time) the IR into native code, get a pointer to the function you have JIT’ed, and then finally execute it. I’ll use here a Python binding for LLVM called llvmlite, which is very Pythonic and easy to use.
JIT’ing TensorFlow Graph using Python and LLVM
Let’s now use LLVM and Python to JIT the TensorFlow computational graph. This is by no means a comprehensive implementation; it is a very simplistic approach, an oversimplification that assumes a few things: an integer-only data type, just a handful of TensorFlow operations, and a single scalar input instead of higher-rank tensors.
So, let’s start building our JIT code. First of all, let’s import the required packages, initialize some LLVM sub-systems and also define the LLVM type corresponding to the TensorFlow integer type:
from ctypes import CFUNCTYPE, c_int
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import ops
import llvmlite.ir as ll
import llvmlite.binding as llvm
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
TYPE_TF_LLVM = {
types_pb2.DT_INT32: ll.IntType(32),
}
After that, let’s define a class to open the TensorFlow exported graph and also declare a method to get a node of the graph by name:
class TFGraph(object):
    def __init__(self, filename="graph.pb", binary=False):
        self.graph_def = graph_pb2.GraphDef()
        with open(filename, "rb") as f:
            if binary:
                self.graph_def.ParseFromString(f.read())
            else:
                text_format.Merge(f.read(), self.graph_def)

    def get_node(self, name):
        for node in self.graph_def.node:
            if node.name == name:
                return node
And let’s start by defining our main function that will be the starting point of the code:
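Here is the beginning of run_main(), as it appears in the full code listing at the end of the post:
def run_main():
    graph = TFGraph("graph.pb", False)
    input_node = graph.get_node("input")
    output_node = graph.get_node("output")

    input_type = TYPE_TF_LLVM[input_node.attr["dtype"].type]
    output_type = TYPE_TF_LLVM[output_node.attr["T"].type]

    module = ll.Module()
    func_type = ll.FunctionType(output_type, [input_type])
    func = ll.Function(module, func_type, name='tensorflow_graph')
    func.args[0].name = 'input'

    bb_entry = func.append_basic_block('entry')
    ir_builder = ll.IRBuilder(bb_entry)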
As you can see in the code above, we open the serialized protobuf graph and then get the input and output nodes of this graph. After that, we also map the types of both graph nodes (input/output) to LLVM types (from a TensorFlow integer to an LLVM integer). We then define an LLVM Module, which is the top-level container for all IR objects. One module in LLVM can contain many different functions; here we will create just one function that will represent the graph. This function will receive as its input argument the input data, with the same type as the input node, and it will return a value with the same type as the output node.
After that, we create the entry block of the function and use this block to instantiate our IR Builder, an object that provides us with the building blocks for JIT’ing the operations of the TensorFlow graph.
Let’s now define the function that will do the real work of converting TensorFlow nodes into LLVM IR:
def build_graph(ir_builder, graph, node):
    if node.op == "Add":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.add(left_op, right_op)

    if node.op == "Sub":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.sub(left_op, right_op)

    if node.op == "Placeholder":
        function_args = ir_builder.function.args
        for arg in function_args:
            if arg.name == node.name:
                return arg
        raise RuntimeError("Input [{}] not found !".format(node.name))

    if node.op == "Const":
        llvm_const_type = TYPE_TF_LLVM[node.attr["dtype"].type]
        const_value = node.attr["value"].tensor.int_val[0]
        llvm_const_value = llvm_const_type(const_value)
        return llvm_const_value
In this function, we receive as parameters the IR Builder, the graph class that we created earlier and the output node. This function will then recursively build the LLVM IR by means of the IR Builder. Here you can see that I only implemented the Add/Sub/Placeholder and Const operations from the TensorFlow graph, just to be able to support the graph that we defined earlier.
After that, we just need to define a function that will take an LLVM Module and then create an execution engine that will run the LLVM optimizations over the LLVM IR before doing the hard work of converting the IR into native x86 code:
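This is the create_engine() function from the full listing at the end of the post:
def create_engine(module):
    features = llvm.get_host_cpu_features().flatten()
    llvm_module = llvm.parse_assembly(str(module))
    target = llvm.Target.from_default_triple()
    target_machine = target.create_target_machine(opt=3, features=features)
    engine = llvm.create_mcjit_compiler(llvm_module, target_machine)
    engine.finalize_object()
    print(target_machine.emit_assembly(llvm_module))
    return engine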
In the code above, you can see that we first get the CPU features (SSE, etc.) into a list; after that, we parse the LLVM IR from the module and then create an engine using the maximum optimization level (opt=3, roughly equivalent to GCC’s -O3 flag). We’re also printing the assembly code (in my case, the x86 assembly built by LLVM).
And here we just finish our run_main() function:
    ret = build_graph(ir_builder, graph, output_node)
    ir_builder.ret(ret)

    with open("output.ir", "w") as f:
        f.write(str(module))

    engine = create_engine(module)
    func_ptr = engine.get_function_address("tensorflow_graph")
    cfunc = CFUNCTYPE(c_int, c_int)(func_ptr)
    ret = cfunc(10)
    print("Execution output: {}".format(ret))
As you can see in the code above, we just call the build_graph() method and then use the IR Builder to add the “ret” LLVM IR instruction (ret = return) to return the output of the IR function we just created based on the TensorFlow graph. We’re also writing the IR output to an external file; I’ll use this LLVM IR file later to create native assembly for other architectures such as ARM. And finally, we get the native code function address, create a Python wrapper for this function and then call it with the argument “10”, which is the input data, and print the resulting output value.
And that is it. Of course this is just an oversimplification, but now we understand the advantages of having a JIT for our TensorFlow models.
The output LLVM IR, the advantage of optimizations and multiple architectures (ARM, PPC, x86, etc)
For instance, let’s create the LLVM IR (using the code I showed above) of the following TensorFlow graph:
As you can see, the LLVM IR looks a lot like assembly code, but this is not the final assembly code; this is just the non-optimized IR. Just before generating the x86 assembly code, LLVM runs a lot of optimization passes over the LLVM IR, doing things such as dead code elimination, constant propagation, etc. And here is the final native x86 assembly code that LLVM generates for the above LLVM IR of the TensorFlow graph:
As you can see, the optimized code removed a lot of redundant operations and ended up just doing an add of 103, which is the correct simplification of the computation that we defined in the graph. For large graphs, you can see that these optimizations can be really powerful, because we are reusing compiler optimizations that have been developed for years in our Machine Learning model computation.
You can also use an LLVM tool called “llc”, which can take an LLVM IR file and generate assembly for any other platform you want. For instance, a command line along the lines of llc -O3 -march=arm output.ir -o output-arm.s will generate native code for the ARM architecture:
As you can see above, the ARM assembly code is also just a “add” assembly instruction followed by a return instruction.
This is really nice because we can take natural advantage of the LLVM framework. For instance, ARM just announced the ARMv8-A with Scalable Vector Extensions (SVE), which will support 2048-bit vectors, and they are already working on patches for LLVM. In the future, a really nice addition to LLVM would be the development of LLVM passes for analysis and transformation that take into consideration the nature of Machine Learning models.
And that’s it, I hope you liked the post ! Is really awesome what you can do with a few lines of Python, LLVM and TensorFlow.
Update 22 Aug 2016: TensorFlow team is actually working on a JIT (I don’t know if they are using LLVM, but it seems the most reasonable way to go in my opinion). In their paper, there is also a very important statement regarding Future Work that I cite here:
“We also have a number of concrete directions to improve the performance of TensorFlow. One such direction is our initial work on a just-in-time compiler that can take a subgraph of a TensorFlow execution, perhaps with some runtime profiling information about the typical sizes and shapes of tensors, and can generate an optimized routine for this subgraph. This compiler will understand the semantics of perform a number of optimizations such as loop fusion, blocking and tiling for locality, specialization for particular shapes and sizes, etc.” – TensorFlow White Paper
Full code
from ctypes import CFUNCTYPE, c_int
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import ops
import llvmlite.ir as ll
import llvmlite.binding as llvm
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
TYPE_TF_LLVM = {
types_pb2.DT_INT32: ll.IntType(32),
}
class TFGraph(object):
    def __init__(self, filename="graph.pb", binary=False):
        self.graph_def = graph_pb2.GraphDef()
        with open(filename, "rb") as f:
            if binary:
                self.graph_def.ParseFromString(f.read())
            else:
                text_format.Merge(f.read(), self.graph_def)

    def get_node(self, name):
        for node in self.graph_def.node:
            if node.name == name:
                return node

def build_graph(ir_builder, graph, node):
    if node.op == "Add":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.add(left_op, right_op)

    if node.op == "Sub":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.sub(left_op, right_op)

    if node.op == "Placeholder":
        function_args = ir_builder.function.args
        for arg in function_args:
            if arg.name == node.name:
                return arg
        raise RuntimeError("Input [{}] not found !".format(node.name))

    if node.op == "Const":
        llvm_const_type = TYPE_TF_LLVM[node.attr["dtype"].type]
        const_value = node.attr["value"].tensor.int_val[0]
        llvm_const_value = llvm_const_type(const_value)
        return llvm_const_value

def create_engine(module):
    features = llvm.get_host_cpu_features().flatten()
    llvm_module = llvm.parse_assembly(str(module))
    target = llvm.Target.from_default_triple()
    target_machine = target.create_target_machine(opt=3, features=features)
    engine = llvm.create_mcjit_compiler(llvm_module, target_machine)
    engine.finalize_object()
    print(target_machine.emit_assembly(llvm_module))
    return engine

def run_main():
    graph = TFGraph("graph.pb", False)
    input_node = graph.get_node("input")
    output_node = graph.get_node("output")

    input_type = TYPE_TF_LLVM[input_node.attr["dtype"].type]
    output_type = TYPE_TF_LLVM[output_node.attr["T"].type]

    module = ll.Module()
    func_type = ll.FunctionType(output_type, [input_type])
    func = ll.Function(module, func_type, name='tensorflow_graph')
    func.args[0].name = 'input'

    bb_entry = func.append_basic_block('entry')
    ir_builder = ll.IRBuilder(bb_entry)

    ret = build_graph(ir_builder, graph, output_node)
    ir_builder.ret(ret)

    with open("output.ir", "w") as f:
        f.write(str(module))

    engine = create_engine(module)
    func_ptr = engine.get_function_address("tensorflow_graph")
    cfunc = CFUNCTYPE(c_int, c_int)(func_ptr)
    ret = cfunc(10)
    print("Execution output: {}".format(ret))

if __name__ == "__main__":
    run_main()