
Machine Learning, Programming, Python

PyTorch – Internal Architecture Tour

Update 28 Feb 2019: I added a new blog post with a slide deck containing the presentation I did for PyData Montreal.

Introduction

This post is a tour of the PyTorch codebase, meant as a guide to the architectural design of PyTorch and its internals. My main goal is to provide something useful for those who want to understand what happens beyond the user-facing API, and to show something new beyond what other tutorials have already covered.

Note: the PyTorch build system uses code generation extensively, so I won’t repeat here what others have already described. If you’re interested in understanding how this works, please read the following tutorials:

Short intro to Python extension objects in C/C++

As you probably know, you can extend Python using C and C++ and develop what is called an “extension”. All the heavy work in PyTorch is implemented in C/C++ instead of pure Python. To define a new Python object type in C/C++, you define a structure like the example below (which is the base for the autograd Variable class):

// Python object that backs torch.autograd.Variable
struct THPVariable {
    PyObject_HEAD
    torch::autograd::Variable cdata;
    PyObject* backward_hooks;
};

As you can see, there is a macro at the beginning of the definition called PyObject_HEAD. Its goal is to standardize Python objects: it expands to another structure that contains a pointer to a type object (which defines initialization methods, allocators, etc.) and a field with a reference counter.

There are two extra macros in the Python API called Py_INCREF() and Py_DECREF(), which are used to increment and decrement the reference counter of Python objects. Multiple entities can borrow or own a reference to an object (owning a reference increases the counter), and only when this counter reaches zero (when all owned references are released) will Python deallocate the object and free its memory.

You can read more about Python C/C++ extensions here.

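Fun fact: it is very common in many applications to use small integers for indexing, counters, etc. For efficiency, the official CPython interpreter caches the integers from -5 up to 256. For that reason, a = 200 followed by b = 200 at the interactive prompt makes a is b evaluate to True, while a = 300 followed by b = 300, entered as separate statements, makes a is b evaluate to False. (Inside a single compiled block the two constants may be shared, so the second result is an implementation detail that you should not rely on.)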
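To make the reference counting and the small-integer cache a bit more tangible, here is a minimal sketch that uses only the standard library. The helper names refcount_delta and same_object are made up for this illustration, and the results are CPython implementation details rather than language guarantees.

```python
import sys


def refcount_delta():
    """Return how much the reference count grows when one more name is bound."""
    obj = []                        # a fresh object that nothing else references
    before = sys.getrefcount(obj)
    alias = obj                     # a second owned reference to the same object
    after = sys.getrefcount(obj)
    del alias                       # releasing it brings the counter back down
    return after - before


def same_object(text):
    """Parse the same integer twice at runtime and compare object identity."""
    # Using int(text) avoids compile-time constant sharing, which could
    # otherwise make two equal literals refer to the same object.
    return int(text) is int(text)


print(refcount_delta())
print(same_object("200"), same_object("300"))
```

Binding a second name plays the role of Py_INCREF() and deleting it plays the role of Py_DECREF(); the identity check shows whether two equal integers are in fact the same cached object.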

Zero-copy PyTorch Tensor to Numpy and vice-versa

PyTorch has its own Tensor representation, which decouples the PyTorch internal representation from external representations. However, since Numpy arrays are everywhere, especially when data is loaded from a variety of sources, we really need to convert between Numpy arrays and PyTorch tensors. For that reason, PyTorch provides two methods called from_numpy() and numpy(), which convert a Numpy array to a PyTorch tensor and vice-versa, respectively. If we look at the code that is called to convert a Numpy array into a PyTorch tensor, we can get more insight into PyTorch’s internal representation:

at::Tensor tensor_from_numpy(PyObject* obj) {
  if (!PyArray_Check(obj)) {
    throw TypeError("expected np.ndarray (got %s)", Py_TYPE(obj)->tp_name);
  }

  auto array = (PyArrayObject*)obj;
  int ndim = PyArray_NDIM(array);
  auto sizes = to_aten_shape(ndim, PyArray_DIMS(array));
  auto strides = to_aten_shape(ndim, PyArray_STRIDES(array));
  // NumPy strides use bytes. Torch strides use element counts.
  auto element_size_in_bytes = PyArray_ITEMSIZE(array);
  for (auto& stride : strides) {
    stride /= element_size_in_bytes;
  }

  // (...) - omitted for brevity

  void* data_ptr = PyArray_DATA(array);
  auto& type = CPU(dtype_to_aten(PyArray_TYPE(array)));
  Py_INCREF(obj);
  return type.tensorFromBlob(data_ptr, sizes, strides, [obj](void* data) {
    AutoGIL gil;
    Py_DECREF(obj);
  });
}

(code from tensor_numpy.cpp)

As you can see from this code, PyTorch obtains all the information (array metadata) from the Numpy representation and then creates its own. However, as you can see from the line that calls PyArray_DATA(), PyTorch gets a pointer to the raw data of the internal Numpy array instead of copying it. This means that PyTorch will create a reference to this data, sharing the same memory region with the Numpy array object for the raw Tensor data.

There is also an important point here: when the Numpy array object goes out of scope and its reference count drops to zero, it will be destroyed and its memory freed. That’s why the code increments the reference count of the Numpy array object with the Py_INCREF(obj) call.

After this, PyTorch creates a new Tensor object from this Numpy data blob. In creating this new Tensor, it passes the borrowed data pointer together with the sizes and strides, as well as a function that the Tensor Storage (we’ll discuss this in the next section) will use later to release the data: it decrements the reference count of the Numpy array object and lets Python take care of the object’s life cycle.

The tensorFromBlob() method will create a new Tensor, but only after creating a new “Storage” for this Tensor. The storage is where the actual data pointer will be stored (and not in the Tensor structure itself). This takes us to the next section about Tensor Storages.
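The stride conversion done in tensor_from_numpy() (bytes on the Numpy side, element counts on the Torch side) can be sketched in pure Python. This is only an illustration of the arithmetic: the function names are hypothetical, and the array is assumed to be C-contiguous (row-major).

```python
def contiguous_byte_strides(shape, itemsize):
    """Byte strides of a C-contiguous array, as Numpy reports them."""
    strides = []
    step = itemsize
    for dim in reversed(shape):
        strides.append(step)        # bytes to skip to advance along this axis
        step *= dim
    return tuple(reversed(strides))


def to_element_strides(byte_strides, itemsize):
    """Convert byte strides into element-count strides, as Torch uses them."""
    return tuple(stride // itemsize for stride in byte_strides)


# A 3x4 array of float32 values (4 bytes per element)
byte_strides = contiguous_byte_strides((3, 4), itemsize=4)
element_strides = to_element_strides(byte_strides, itemsize=4)
print(byte_strides, element_strides)
```

Dividing by the item size is exactly what the loop over strides does in the C++ code above; nothing about the data itself changes.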

Tensor Storage

The actual raw data of the Tensor is not kept directly in the Tensor structure, but in another structure called Storage, which in turn is part of the Tensor structure.

As we saw in the previous code from tensor_from_numpy(), there is a call for tensorFromBlob() that will create a Tensor from the raw data blob. This last function will call another function called storageFromBlob() that will, in turn, create a storage for this data according to its type. In the case of a CPU float type, it will return a new CPUFloatStorage instance.

The CPUFloatStorage is basically a wrapper with utility functions around the actual storage structure called THFloatStorage that we show below:

typedef struct THStorage
{
    real *data;
    ptrdiff_t size;
    int refcount;
    char flag;
    THAllocator *allocator;
    void *allocatorContext;
    struct THStorage *view;
} THStorage;

(code from THStorage.h)

As you can see, the THStorage holds a pointer to the raw data, its size, flags and also an interesting field called allocator that we’ll soon discuss. It is also important to note that there is no metadata on how to interpret the data inside the THStorage. This is because the storage is “dumb” regarding its contents, and it is the Tensor’s responsibility to know how to “view” or interpret this data.

From this, you have probably already realized that we can have multiple tensors pointing to the same storage but with different views of this data, and that’s why viewing a tensor with a different shape (but keeping the same number of elements) is so efficient. The Python code below shows that the data pointer in the storage is shared after changing the way the Tensor views its data:

>>> tensor_a = torch.ones((3, 3))
>>> tensor_b = tensor_a.view(9)
>>> tensor_a.storage().data_ptr() == tensor_b.storage().data_ptr()
True

As we can see in the example above, the data pointer in the storage of both Tensors is the same, but the Tensors represent different interpretations of the storage data.

Now, as we saw in the allocator field of the THFloatStorage structure, there is a pointer to a THAllocator structure there. This is very important because it brings flexibility regarding the allocator that can be used to allocate the storage data. This structure is represented by the following code:
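The same idea can be modelled without PyTorch at all. The sketch below is a deliberately simplified, hypothetical TinyTensor class (not PyTorch code): the storage is a flat Python list, and each tensor only keeps sizes and strides that describe how to read it.

```python
class TinyTensor:
    """Simplified stand-in for a tensor: metadata plus a shared flat storage."""

    def __init__(self, storage, sizes):
        self.storage = storage              # flat list, shared between views
        self.sizes = tuple(sizes)
        strides, step = [], 1
        for dim in reversed(self.sizes):    # row-major strides in elements
            strides.append(step)
            step *= dim
        self.strides = tuple(reversed(strides))

    def view(self, *sizes):
        """Return a new tensor over the same storage with a different shape."""
        count = 1
        for dim in sizes:
            count *= dim
        if count != len(self.storage):
            raise ValueError("view must keep the same number of elements")
        return TinyTensor(self.storage, sizes)   # no copy of the data

    def _flat(self, index):
        return sum(i * s for i, s in zip(index, self.strides))

    def __getitem__(self, index):
        return self.storage[self._flat(index)]

    def __setitem__(self, index, value):
        self.storage[self._flat(index)] = value


tensor_a = TinyTensor([1.0] * 9, (3, 3))
tensor_b = tensor_a.view(9)
tensor_b[(4,)] = 5.0                        # write through the 1-D view
print(tensor_a[(1, 1)], tensor_a.storage is tensor_b.storage)
```

Because both objects hold the same list, a write through one view is visible through the other, which is the behaviour the data_ptr() comparison above hints at.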

typedef struct THAllocator
{
  void* (*malloc)(void*, ptrdiff_t);
  void* (*realloc)(void*, void*, ptrdiff_t);
  void (*free)(void*, void*);
} THAllocator;

(code from THAllocator.h)

As you can see, there are three function pointer fields in this structure to define what an allocator means: a malloc, a realloc and a free. For CPU-allocated memory, these functions will, of course, relate to the traditional malloc/realloc/free functions. However, when we work with CUDA we’ll end up using the CUDA allocators. One example is cudaMallocHost(), which allocates page-locked (pinned) host memory, as we can see in the THCudaHostAllocator malloc function below:

static void *THCudaHostAllocator_malloc(void* ctx, ptrdiff_t size) {
  void* ptr;
  if (size < 0) THError("Invalid memory size: %ld", size);
  if (size == 0) return NULL;
  THCudaCheck(cudaMallocHost(&ptr, size));
  return ptr;
}

(code from THCAllocator.c)

You have probably noticed a pattern in the repository organization. It is important to keep these conventions in mind when navigating the repository, as summarized here (taken from the PyTorch lib readme):

  • TH = TorcH
  • THC = TorcH Cuda
  • THCS = TorcH Cuda Sparse
  • THCUNN = TorcH CUda Neural Network
  • THD = TorcH Distributed
  • THNN = TorcH Neural Network
  • THS = TorcH Sparse

This convention is also present in the function/class names and other objects, so it is important to always keep these patterns in mind. While you can find CPU allocators in the TH code, you’ll find CUDA allocators in the THC code.

Finally, we can see the composition of the main THTensor structure:

typedef struct THTensor
{
    int64_t *size;
    int64_t *stride;
    int nDimension;
    THStorage *storage;
    ptrdiff_t storageOffset;
    int refcount;
    char flag;
} THTensor;

(Code from THTensor.h)

And as you can see, the main THTensor structure holds the sizes, strides, number of dimensions, storage offset, etc., as well as the storage (THStorage) for the Tensor data.
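To see how these fields work together, here is a small sketch of how an element position can be computed from an index, the strides and the storage offset. The function name storage_index is hypothetical, and the example assumes strides counted in elements, as in the THTensor structure.

```python
def storage_index(index, strides, storage_offset=0):
    """Position in the flat storage of the element at a multi-dimensional index."""
    return storage_offset + sum(i * s for i, s in zip(index, strides))


storage = list(range(12))       # flat storage holding the values 0..11

# Viewed as a 3x4 row-major tensor: strides (4, 1), offset 0
row_major = storage[storage_index((1, 2), (4, 1))]

# Transposed 4x3 view: same storage, strides swapped to (1, 4)
transposed = storage[storage_index((2, 1), (1, 4))]

# A view that skips the first row: same strides, storage offset of 4
sliced = storage[storage_index((0, 0), (4, 1), storage_offset=4)]

print(row_major, transposed, sliced)
```

Element [1][2] of the original and element [2][1] of the transposed view land on the same storage position, and the slice only needs a different offset, so none of these views requires copying data.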

We can summarize all this structure that we saw in the diagram below:

Now, once we have requirements such as multi-processing, where we want to share tensor data among multiple processes, we need a shared memory approach. Otherwise, every time another process needs a tensor, or when you want to implement the Hogwild training procedure, where all the processes write to the same memory region (where the parameters are), you’ll need to make copies between processes, and this is very inefficient. Therefore, in the next section we’ll discuss a special kind of storage for Shared Memory.

Shared Memory

Shared memory can be implemented in many different ways depending on the platform support. PyTorch supports some of them, but for the sake of simplicity, I’ll talk here about what happens on MacOS using the CPU (instead of the GPU). Since PyTorch supports multiple shared memory approaches, this part is a little tricky to grasp, since it involves more levels of indirection in the code.

PyTorch provides a wrapper around the Python multiprocessing module, which can be imported from torch.multiprocessing. The changes implemented in this wrapper around the official Python multiprocessing module make sure that every time a tensor is put on a queue or shared with another process, only a handle for the shared memory is shared instead of an entire new copy of the Tensor.

Now, many people aren’t aware of a Tensor method from PyTorch called share_memory_(). This method is what triggers an entire rebuild of the storage memory for that particular Tensor: it creates a region of shared memory that can be used among different processes. In the end, it will call the following function:

static THStorage* THPStorage_(newFilenameStorage)(ptrdiff_t size)
{
  int flags = TH_ALLOCATOR_MAPPED_SHAREDMEM | TH_ALLOCATOR_MAPPED_EXCLUSIVE;
  std::string handle = THPStorage_(__newHandle)();
  auto ctx = libshm_context_new(NULL, handle.c_str(), flags);
  return THStorage_(newWithAllocator)(size, &THManagedSharedAllocator, (void*)ctx);
}

(Code from StorageSharing.cpp)

And as you can see, this function will create another storage using a special allocator called THManagedSharedAllocator. It first defines some flags, then creates a handle, which is a string in the format /torch_[process id]_[random number], and after that creates a new storage using the special THManagedSharedAllocator. This allocator has function pointers to an internal PyTorch library called libshm, which implements Unix Domain Socket communication to share the shared memory region handles. This allocator is actually a special case, a kind of “smart allocator”: it contains the communication control logic, and it uses another allocator called THRefcountedMapAllocator that is responsible for creating the actual shared memory region and calling mmap() to map this region into the process virtual address space.

Note: when a method ends with an underscore in PyTorch, such as the method called share_memory_(), it means that this method has an in-place effect, and it will change the current object instead of creating a new one with the modifications.

I’ll now show a Python example of one process using the data from a Tensor that was allocated in another process, by manually exchanging the shared memory handle:

This is executed in the process A:

>>> import torch
>>> tensor_a = torch.ones((5, 5))
>>> tensor_a

 1  1  1  1  1
 1  1  1  1  1
 1  1  1  1  1
 1  1  1  1  1
 1  1  1  1  1
[torch.FloatTensor of size 5x5]

>>> tensor_a.is_shared()
False
>>> tensor_a = tensor_a.share_memory_()
>>> tensor_a.is_shared()
True
>>> tensor_a_storage = tensor_a.storage()
>>> tensor_a_storage._share_filename_()
(b'/var/tmp/tmp.0.yowqlr', b'/torch_31258_1218748506', 25)

In this code, executed in the process A, we create a new Tensor of 5×5 filled with ones. After that we make it shared and print the tuple with the Unix Domain Socket address as well as the handle. Now we can access this memory region from another process B as shown below:

Code executed in the process B:

>>> import torch
>>> tensor_a = torch.Tensor()
>>> tuple_info = (b'/var/tmp/tmp.0.yowqlr', b'/torch_31258_1218748506', 25)
>>> storage = torch.Storage._new_shared_filename(*tuple_info)
>>> tensor_a = torch.Tensor(storage).view((5, 5))
>>> tensor_a

 1  1  1  1  1
 1  1  1  1  1
 1  1  1  1  1
 1  1  1  1  1
 1  1  1  1  1
[torch.FloatTensor of size 5x5]

As you can see, using the tuple information about the Unix Domain Socket address and the handle, we were able to access the Tensor storage from another process. If you change the tensor in process B, you’ll also see the change reflected in process A, because these Tensors share the same memory region.
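The idea of exchanging only a handle can also be tried with the Python standard library alone. The sketch below uses multiprocessing.shared_memory (available since Python 3.8) as an analogy; it is not what PyTorch or libshm do internally, and for brevity both "processes" are simulated inside one function whose name, share_by_name, is made up.

```python
from multiprocessing import shared_memory


def share_by_name(payload):
    """Write payload (non-empty bytes) to shared memory and modify it via its name."""
    # "Process A": create a named shared memory block and fill it
    owner = shared_memory.SharedMemory(create=True, size=len(payload))
    try:
        owner.buf[:len(payload)] = payload
        handle = owner.name             # the only thing that must be exchanged

        # "Process B": attach to the same region using just the handle
        peer = shared_memory.SharedMemory(name=handle)
        try:
            peer.buf[0] = ord("X")      # write through the second mapping
        finally:
            peer.close()

        # Back in "process A": the change is visible without any copy
        return bytes(owner.buf[:len(payload)])
    finally:
        owner.close()
        owner.unlink()                  # free the block once nobody needs it


print(share_by_name(b"hello"))
```

In a real setup the handle string would be sent to a separate process, much like the tuple returned by _share_filename_() above.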

DLPack: a hope for the Deep Learning frameworks Babel

Now I would like to talk about something recent in the PyTorch code base called DLPack. DLPack is an open standardization of an in-memory tensor structure that allows tensor data to be exchanged between frameworks. What is quite interesting is that, since this memory representation is standardized and very similar to the memory representation already in use by many frameworks, it allows zero-copy data sharing between frameworks. This is quite an amazing initiative, given the variety of frameworks we have today without inter-communication among them.

This will certainly help to overcome the “island model” that we have today between tensor representations in MXNet, PyTorch, etc., and will allow developers to mix operations from different frameworks, with all the benefits that standardization can bring.

The core of DLPack is a very simple structure called DLTensor, as shown below:

/*!
 * \brief Plain C Tensor object, does not manage memory.
 */
typedef struct {
  /*!
   * \brief The opaque data pointer points to the allocated data.
   *  This will be CUDA device pointer or cl_mem handle in OpenCL.
   *  This pointer is always aligns to 256 bytes as in CUDA.
   */
  void* data;
  /*! \brief The device context of the tensor */
  DLContext ctx;
  /*! \brief Number of dimensions */
  int ndim;
  /*! \brief The data type of the pointer*/
  DLDataType dtype;
  /*! \brief The shape of the tensor */
  int64_t* shape;
  /*!
   * \brief strides of the tensor,
   *  can be NULL, indicating tensor is compact.
   */
  int64_t* strides;
  /*! \brief The offset in bytes to the beginning pointer to data */
  uint64_t byte_offset;
} DLTensor;

(code from dlpack.h)

As you can see, there is a data pointer for the raw data, as well as the shape, strides, offset, device (GPU vs CPU), and other metadata about the data that the DLTensor points to.

There is also a managed version of the tensor called DLManagedTensor, where the frameworks can provide a context and also a “deleter” function that the framework that borrowed the Tensor can call to inform the other framework that the resources are no longer required.

In PyTorch, if you want to convert to or from the DLTensor format, you can find C/C++ methods for doing that, or you can do it in Python as shown below:
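The borrow-then-notify contract behind the “deleter” can be sketched in plain Python. The classes below are hypothetical stand-ins (they are not the DLPack C structures or any framework API): a producer lends out a view of its buffer and counts how many loans are still outstanding.

```python
from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass
class ManagedTensor:
    """Simplified stand-in for a managed tensor: borrowed data plus a deleter."""
    data: memoryview                        # borrowed view, no copy
    shape: Tuple[int, ...]
    deleter: Callable[["ManagedTensor"], None]


class Producer:
    """The framework that owns the memory and lends it out."""

    def __init__(self, values):
        self.buffer = bytearray(values)
        self.outstanding = 0                # exports that are still alive

    def export(self):
        self.outstanding += 1
        view = memoryview(self.buffer)
        return ManagedTensor(view, (len(self.buffer),), self._release)

    def _release(self, tensor):
        tensor.data.release()               # the consumer is done with the view
        self.outstanding -= 1


producer = Producer(b"\x01\x02\x03")
borrowed = producer.export()                # the consumer borrows the data
borrowed.data[0] = 9                        # and writes through, zero-copy
borrowed.deleter(borrowed)                  # then tells the producer it is done
print(producer.buffer[0], producer.outstanding)
```

The producer never has to guess when the consumer is finished; it simply waits for the deleter to be called before reclaiming the memory.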

import torch
from torch.utils import dlpack

t = torch.ones((5, 5))
dl = dlpack.to_dlpack(t)

This Python function will call the toDLPack function from ATen, shown below:

DLManagedTensor* toDLPack(const Tensor& src) {
  ATenDLMTensor * atDLMTensor(new ATenDLMTensor);
  atDLMTensor->handle = src;
  atDLMTensor->tensor.manager_ctx = atDLMTensor;
  atDLMTensor->tensor.deleter = &deleter;
  atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
  int64_t device_id = 0;
  if (src.type().is_cuda()) {
    device_id = src.get_device();
  }
  atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id);
  atDLMTensor->tensor.dl_tensor.ndim = src.dim();
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src.type());
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;
  return &(atDLMTensor->tensor);
}

As you can see, it’s a pretty simple conversion, casting the metadata from the PyTorch format to the DLPack format and assigning a pointer to the internal Tensor data representation.

I really hope that more frameworks adopt this standard that will certainly give benefits to the ecosystem. It is also interesting to note that a potential integration with Apache Arrow would be amazing.

That’s it, I hope you liked this long post!

– Christian S. Perone

Cite this article as: Christian S. Perone, "PyTorch – Internal Architecture Tour," in Terra Incognita, 12/03/2018, https://blog.christianperone.com/2018/03/pytorch-internal-architecture-tour/.
Machine Learning, Math, Programming, Python

Privacy-preserving sentence semantic similarity using InferSent embeddings and secure two-party computation

Privacy-preserving Computation

Privacy-preserving computation, or secure computation, is a sub-field of cryptography where two (two-party, or 2PC) or multiple (multi-party, or MPC) parties can evaluate a function together without revealing information about the parties’ private input data to each other. The problem and the first solution to it were introduced in 1982 in a breakthrough by Andrew Yao, on what later became known as “Yao’s Millionaires’ problem”.

Yao’s Millionaires’ problem involves two millionaires, Alice and Bob, who are interested in knowing which of them is richer without revealing their actual wealth to each other. In other words, what they want can be generalized as follows: Alice and Bob want to jointly compute a function securely, without learning anything other than the result of the computation on the input data (which remains private to them).

To make the problem concrete, Alice has an amount A such as $10, and Bob has an amount B such as $50, and what they want to know is which one is larger, without Bob revealing the amount B to Alice or Alice revealing the amount A to Bob. It is also important to note that we don’t want to rely on a trusted third party, otherwise the problem would just be a simple protocol of information exchange with the trusted party.

Formally what we want is to jointly evaluate the following function:

r = f(A, B)

Such that the private values A and B remain known only to their respective owners, and the result r is revealed to just one or to both of the parties.

It seems very counterintuitive that a problem like that could ever be solved, but to the surprise of many people, it is possible to solve it under some security requirements. Thanks to recent developments in techniques such as FHE (Fully Homomorphic Encryption), Oblivious Transfer, and Garbled Circuits, problems like that have started to become practical for real-life usage, and they are nowadays being used by many companies in applications such as information exchange, secure location, advertisement, satellite orbit collision avoidance, etc.

I’m not going to go into the details of these techniques, but if you’re interested in the intuition behind OT (Oblivious Transfer), you should definitely read the amazing explanation by Craig Gidney here. There are also, of course, many different protocols for doing 2PC or MPC, each of which assumes some security requirements (semi-honest, malicious, etc.). I won’t go into the details, to keep the post focused on the goal, but you should be aware of that.

The problem: sentence similarity

What we want to achieve is to use privacy-preserving computation to calculate the similarity between sentences without disclosing the content of the sentences. To give a concrete example: Bob owns a company and has descriptions of many different projects in sentences such as “This project is about building a deep learning sentiment analysis framework that will be used for tweets”, and Alice, who owns a competitor company, also has different projects described in similar sentences. What they want to do is to jointly compute the similarity between projects in order to find out whether they should partner on a project or not. However, and this is the important point, Bob doesn’t want Alice to know his project descriptions, and Alice doesn’t want Bob to be aware of her projects either. They want to know the closest match between the different projects they run, but without disclosing the project ideas (project descriptions).

Sentence Similarity Comparison

Now, how can we exchange information about Bob’s and Alice’s project sentences without disclosing information about the project descriptions?

One naive way to do that would be to just compute the hashes of the sentences and then compare only the hashes to check if they match. However, this only works if the descriptions are exactly the same, and besides that, if the entropy of the sentences is small (as with short sentences), someone with reasonable computation power can try to recover the sentence.

Another approach to this problem (the one that we’ll be using) is to compare the sentences in the sentence embedding space. We just need to create sentence embeddings using a Machine Learning model (we’ll use InferSent later) and then compare the embeddings of the sentences. However, this approach also raises another concern: what if Bob or Alice trains a Seq2Seq model that goes from the embeddings of the other party back to an approximate description of the project?
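Both weaknesses of the hash-based approach are easy to reproduce with the standard library. In the sketch below, SHA-256 is just one possible hash choice, and the helper names sentence_digest and recover, as well as the tiny candidate list, are made up for the illustration.

```python
import hashlib


def sentence_digest(sentence):
    """Hex digest of a sentence, using SHA-256 as an example hash."""
    return hashlib.sha256(sentence.encode("utf-8")).hexdigest()


def recover(digest, candidates):
    """Guessing attack: hash every candidate until one matches the digest."""
    for candidate in candidates:
        if sentence_digest(candidate) == digest:
            return candidate
    return None


alice = "my cat loves to walk over my keyboard"
bob = "the cat is always walking over my keyboard"

exact_match = sentence_digest(alice) == sentence_digest(alice)
near_match = sentence_digest(alice) == sentence_digest(bob)

# If the space of plausible sentences is small, it can simply be enumerated
guessed = recover(sentence_digest(alice), ["I like to pet my cat", alice])

print(exact_match, near_match, guessed)
```

Two sentences with nearly the same meaning produce unrelated digests, so hashes cannot express similarity, and a low-entropy sentence can be recovered by enumeration.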

It isn’t unreasonable to think that one can recover an approximate description of the sentence given their embeddings. That’s why we’ll use the two-party secure computation for computing the embeddings similarity, in a way that Bob and Alice will compute the similarity of the embeddings without revealing their embeddings, keeping their project ideas safe.

The entire flow is described in the image below, where Bob and Alice share the same Machine Learning model. They use this model to go from sentences to embeddings, followed by a secure computation of the similarity in the embedding space.

Diagram overview of the entire process.

Generating sentence embeddings with InferSent

Bi-LSTM max-pooling network. Source: Supervised Learning of Universal Sentence Representations from Natural Language Inference Data. Alexis Conneau et al.

InferSent is an NLP technique for universal sentence representation developed by Facebook that uses supervised training to produce highly transferable representations.

They used a Bi-directional LSTM with max-pooling that consistently surpassed many unsupervised training methods such as the SkipThought vectors. They also provide a PyTorch implementation that we’ll use to generate sentence embeddings.

Note: even if you don’t have a GPU, you can get reasonable performance when computing embeddings for a few sentences.

The first step to generate the sentence embeddings is to download and load a pre-trained InferSent model:

import numpy as np
import torch

# Trained model from: https://github.com/facebookresearch/InferSent
GLOVE_EMBS = '../dataset/GloVe/glove.840B.300d.txt'
INFERSENT_MODEL = 'infersent.allnli.pickle'

# Load trained InferSent model
model = torch.load(INFERSENT_MODEL,
                   map_location=lambda storage, loc: storage)

model.set_glove_path(GLOVE_EMBS)
model.build_vocab_k_words(K=100000)

Now we need to define a similarity measure to compare two vectors, and for that goal, I’ll use the cosine similarity (I wrote a tutorial about this similarity measure here) since it’s pretty straightforward:

cos(\pmb x, \pmb y) = \frac {\pmb x \cdot \pmb y}{||\pmb x|| \cdot ||\pmb y||}

As you can see, if we have two unit vectors (vectors with norm 1), the two terms in the equation denominator will be 1 and we will be able to remove the entire denominator of the equation, leaving only:

cos(\hat{x}, \hat{y}) =\hat{x} \cdot\hat{y}

So, if we normalize our vectors to have a unit norm (that’s why the vectors are wearing hats in the equation above), we can make the computation of the cosine similarity become just a simple dot product. That will help us a lot in computing the similarity distance later when we’ll use a framework to do the secure computation of this dot product.
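This simplification is easy to check numerically. The sketch below uses plain Python lists and two small made-up vectors; the helper names are only for illustration.

```python
import math


def dot(x, y):
    return sum(a * b for a, b in zip(x, y))


def norm(x):
    return math.sqrt(dot(x, x))


def cosine(x, y):
    """Cosine similarity with the full denominator."""
    return dot(x, y) / (norm(x) * norm(y))


def normalize(x):
    """Scale a vector so that it has unit norm."""
    length = norm(x)
    return [a / length for a in x]


x = [3.0, 4.0]
y = [4.0, 3.0]

full = cosine(x, y)                             # 24 / (5 * 5)
shortcut = dot(normalize(x), normalize(y))      # plain dot product of unit vectors

print(full, shortcut)
```

Both values agree up to floating-point rounding, which is why only a dot product has to be evaluated securely later on.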

So, the next step is to define a function that will take some sentence text and forward it to the model to generate the embeddings and then normalize them to unit vectors:

# This function will forward the text into the model and
# get the embeddings. After that, it will normalize it
# to a unit vector.

def encode(model, text):
    embedding = model.encode([text])[0]
    embedding /= np.linalg.norm(embedding)
    return embedding

As you can see, this function is pretty simple, it feeds the text into the model, and then it will divide the embedding vector by the embedding norm.

Now, for practical reasons, I’ll be using integer computation later for computing the similarity; however, the embeddings generated by InferSent are of course real values. For that reason, you’ll see in the code below that we create another function to scale the float values, removing the radix point and converting them to integers. There is also another important issue: the framework that we’ll be using later for secure computation doesn’t allow signed integers, so we also need to clip the embedding values between 0.0 and 1.0. This will of course cause some approximation errors; however, we can still get very good approximations after clipping and scaling with limited precision (I’m using 14 bits for scaling to avoid overflow issues later during the dot product computation):

# This function will scale the embedding in order to
# remove the radix point.
   
def scale(embedding):
    SCALE = 1 << 14
    scale_embedding = np.clip(embedding, 0.0, 1.0) * SCALE
    return scale_embedding.astype(np.int32)

You can use floating-point in your secure computations and there are a lot of frameworks that support it; however, it is trickier to do that, and for that reason, I used integer arithmetic to simplify the tutorial. The function above is just a hack to make it simple. It’s easy to see that we can recover this embedding later without too much loss of precision.
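To see why 14 bits are enough, note that the clipped vectors still have a norm of at most 1, so the integer dot product can never exceed SCALE squared, which is 2^28 and fits in 32 bits regardless of the vector dimension. The sketch below redoes the scaling with plain Python lists and made-up three-dimensional unit vectors; it mirrors the scale() function above, but the helper names are only for this illustration.

```python
SCALE = 1 << 14


def scale(embedding):
    """Clip each value to [0, 1], scale by 2^14 and truncate to an integer."""
    return [int(min(max(value, 0.0), 1.0) * SCALE) for value in embedding]


def integer_dot(x, y):
    """The dot product as it would be computed on the scaled integers."""
    return sum(a * b for a, b in zip(x, y))


# Two made-up unit vectors without negative components
x = [0.6, 0.8, 0.0]
y = [0.8, 0.6, 0.0]

raw = integer_dot(scale(x), scale(y))
recovered = raw / SCALE ** 2            # undo both scalings at once

print(raw, recovered)
print(scale([-0.5, 0.5]))               # negative values are clipped to zero
```

The recovered value is close to the true dot product of 0.96; with real embeddings the clipping of negative components adds a further error, as the results later in the post show.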

Now we just need to create some sentence samples that we’ll be using:

# The list of Alice sentences
alice_sentences = [
    'my cat loves to walk over my keyboard',
    'I like to pet my cat',
]

# The list of Bob sentences
bob_sentences = [
    'the cat is always walking over my keyboard',
]

And convert them to embeddings:

# Alice sentences
alice_sentence1 = encode(model, alice_sentences[0])
alice_sentence2 = encode(model, alice_sentences[1])

# Bob sentences
bob_sentence1 = encode(model, bob_sentences[0])

Since we now have the sentence embeddings and every one of them is normalized, we can compute the cosine similarity just by doing a dot product between the vectors:

>>> np.dot(bob_sentence1, alice_sentence1)
0.8798542

>>> np.dot(bob_sentence1, alice_sentence2)
0.62976325

As we can see, Bob’s first sentence is more similar to Alice’s first sentence (~0.88) than to Alice’s second sentence (~0.63).

Since we now have the embeddings, we just need to convert them to scaled integers:

# Scale the Alice sentence embeddings
alice_sentence1_scaled = scale(alice_sentence1)
alice_sentence2_scaled = scale(alice_sentence2)

# Scale the Bob sentence embeddings
bob_sentence1_scaled = scale(bob_sentence1)

# This is the unit vector embedding for the sentence
>>> alice_sentence1
array([ 0.01698913, -0.0014404 ,  0.0010993 , ...,  0.00252409,
        0.00828147,  0.00466533], dtype=float32)

# This is the scaled vector as integers
>>> alice_sentence1_scaled
array([278,   0,  18, ...,  41, 135,  76], dtype=int32)

Now with these embeddings as scaled integers, we can proceed to the second part, where we’ll be doing the secure computation between two parties.

Two-party secure computation

In order to perform secure computation between the two parties (Alice and Bob), we’ll use the ABY framework. ABY implements many different secure computation schemes and allows you to describe your computation as a circuit like the one pictured in the image below, where Yao’s Millionaires’ problem is described:

Yao’s Millionaires problem. Taken from ABY documentation (https://github.com/encryptogroup/ABY).

As you can see, we have two inputs entering one GT GATE (greater-than gate) and then an output. This circuit has a bit length of 3 for each input and will compute whether Alice’s input is greater than (GT GATE) Bob’s input. The computing parties then secret share their private data and can use arithmetic sharing, boolean sharing, or Yao sharing to securely evaluate these gates.
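To give some intuition for what secret sharing means, here is a minimal sketch of additive sharing modulo 2^32, which is the general idea behind arithmetic sharing. This is not ABY code or its API; the function names are made up, and only addition is shown, since multiplication and comparison require interaction between the parties.

```python
import secrets

MODULUS = 1 << 32


def share(value):
    """Split a value into two shares that individually look random."""
    first = secrets.randbelow(MODULUS)
    second = (value - first) % MODULUS
    return first, second


def reconstruct(first, second):
    """Combine both shares to recover the original value."""
    return (first + second) % MODULUS


# Alice shares her amount and Bob shares his; each party holds one share of each
a0, a1 = share(10)
b0, b1 = share(50)

# Addition needs no communication: each party adds the shares it holds
sum0 = (a0 + b0) % MODULUS
sum1 = (a1 + b1) % MODULUS

print(reconstruct(a0, a1), reconstruct(b0, b1), reconstruct(sum0, sum1))
```

Neither share alone reveals anything about the value it came from, yet the parties can still compute on the shares and only reconstruct the final result.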

ABY is really easy to use because you can just describe your inputs, shares, gates and it will do the rest for you such as creating the socket communication channel, exchanging data when needed, etc. However, the implementation is entirely written in C++ and I’m not aware of any Python bindings for it (a great contribution opportunity).

Fortunately, there is an implemented example for ABY that can do the dot product calculation for us; the example is here. I won’t replicate the example here, but the only parts that we have to change are to read the embedding vectors that we created before instead of generating random vectors, and to increase the bit length to 32 bits.

After that, we just need to execute the application on two different machines (or by emulating locally like below):

# This will execute the server part, the -r 0 specifies the role (server)
# and the -n 4096 defines the dimension of the vector (InferSent generates
# 4096-dimensional embeddings).
~# ./innerproduct -r 0 -n 4096

# And the same on another process (or another machine, however for another
# machine execution you'll have to obviously specify the IP).
~# ./innerproduct -r 1 -n 4096

And we get the following results:

Inner Product of alice_sentence1 and bob_sentence1  = 226691917
Inner Product of alice_sentence2 and bob_sentence1  = 171746521

Even in the integer representation, you can see that the inner product of Alice’s first sentence and Bob’s sentence is higher, meaning that the similarity is also higher. But let’s now convert these values back to float:

>>> SCALE = 1 << 14

# This is the dot product we should get
>>> np.dot(alice_sentence1, bob_sentence1)
0.8798542

# This is the inner product we got on secure computation
>>> 226691917 / SCALE**2.0
0.8444931

# This is the dot product we should get
>>> np.dot(alice_sentence2, bob_sentence1)
0.6297632

# This is the inner product we got on secure computation
>>> 171746521 / SCALE**2.0
0.6398056

As you can see, we got very good approximations, even in the presence of low-precision math and unsigned integer requirements. Of course, in real life you won’t have both values and both vectors, because they’re supposed to be hidden, but the changes to accommodate that are trivial: you just need to adjust the ABY code so that each party loads only its own vector, and use the correct IP addresses/ports of both parties.
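The conversion between floats and the integer representation can be sketched in plain Python. This is a minimal illustration that assumes non-negative inputs and truncation when scaling; the toy vectors and the helper names `to_fixed` and `fixed_dot` are hypothetical, and nothing here is secure computation.

```python
SCALE = 1 << 14  # same scaling factor used above


def to_fixed(vector):
    """Map non-negative floats to unsigned integers (truncating)."""
    return [int(x * SCALE) for x in vector]


def fixed_dot(u, v):
    """Integer-only inner product, the kind of sum the circuit computes."""
    return sum(a * b for a, b in zip(u, v))


# Hypothetical toy vectors standing in for the 4096-dimensional embeddings.
alice = [0.12, 0.57, 0.31]
bob = [0.44, 0.09, 0.83]

int_result = fixed_dot(to_fixed(alice), to_fixed(bob))
approx = int_result / SCALE ** 2.0  # both factors were scaled, hence SCALE**2
exact = sum(a * b for a, b in zip(alice, bob))
print(approx, exact)
```

Because each operand is multiplied by SCALE before the product, the result has to be divided by SCALE squared, which is exactly what the conversion above does.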

I hope you liked it !

– Christian S. Perone

Cite this article as: Christian S. Perone, "Privacy-preserving sentence semantic similarity using InferSent embeddings and secure two-party computation," in Terra Incognita, 22/01/2018, https://blog.christianperone.com/2018/01/privacy-preserving-infersent/.
Machine Learning

Introduction to Word Embeddings

Just published a deck of slides for an Introduction to Word Embeddings, presented at our Machine Learning Meetup in Porto Alegre / RS / Brazil.

Here is the video (in Portuguese):

Cite this article as: Christian S. Perone, "Introduction to Word Embeddings," in Terra Incognita, 08/02/2017, https://blog.christianperone.com/2017/02/introduction-to-word-embeddings/.
LLVM, Machine Learning, Programming, Python

JIT native code generation for TensorFlow computation graphs using Python and LLVM

Update: Hacker News discussion here.

The TensorFlow Computation Graph


One of the most amazing components of the TensorFlow architecture is the computation graph that can be serialized using Protocol Buffers. This computation graph follows a well-defined format (click here for the proto files) and describes the computation that you specify (it can be a Deep Learning model like a CNN, a simple Logistic Regression or even any computation you want). For instance, here is an example of a very simple TensorFlow computation graph that we will use in this tutorial (using TensorFlow Python API):

import tensorflow as tf

with tf.Session() as sess:
    input_placeholder = tf.placeholder(tf.int32, 1, name="input")
    sub_op = tf.sub(input_placeholder, tf.constant(2, dtype=tf.int32))
    add_op = tf.add(sub_op, tf.constant(5, dtype=tf.int32))
    output = tf.add(add_op, tf.constant(100, dtype=tf.int32),
                    name="output")
    tf.train.write_graph(sess.graph_def, ".", "graph.pb", True)
Representation of the computation graph.

As you can see, this is a very simple computation graph. First, we define the placeholder that will hold the input tensor, and after that we specify the computation that should happen using this input tensor as input data. Here we can also see that we’re defining two important nodes of this graph: one is called “input” (the aforementioned placeholder) and the other is called “output”, which will hold the result of the final computation. This graph is the same as the following formula for a scalar: output = (((input - 2) + 5) + 100), where I intentionally added redundant operations to see LLVM constant propagation later.

In the last line of the code, we’re persisting this computation graph (including the constant values) into a serialized protobuf file. The final True parameter is to output a textual representation instead of binary, so it will produce the following human-readable output protobuf file (I omitted a part of it for brevity):

node {
  name: "input"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 1
        }
      }
    }
  }
}
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 2
      }
    }
  }
}

--- >(omitted for brevity) < ---

node {
  name: "output"
  op: "Add"
  input: "Add"
  input: "Const_2"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
versions {
  producer: 9
}

This is a very simple graph, and real TensorFlow graphs are rarely this simple: models can easily contain more than 300 nodes depending on what you’re specifying, especially Deep Learning models.

We’ll use the above graph to show how we can JIT native code for this simple graph using the LLVM framework.

The LLVM Frontend, IR and Backend


The LLVM framework is a really nice, modular and complete ecosystem for building compilers and toolchains. A very nice description of the LLVM architecture that is important for us is shown in the picture below:

LLVM Compiler Architecture (AOSA/LLVM, Chris Lattner)

(The picture above is just a small part of the LLVM architecture, for a comprehensive description of it, please see the nice article from the AOSA book written by Chris Lattner)

Looking at the image above, we can see that LLVM provides a lot of core functionality. On the left side you see that code written in many languages goes through the respective language frontends; after that, it doesn’t matter in which language you wrote your code, everything is transformed into a very powerful language called LLVM IR (LLVM Intermediate Representation) which is, as you can imagine, an intermediate representation of the code just before the assembly code itself. In my opinion, the IR is the key component of what makes LLVM so amazing, because it doesn’t matter in which language you wrote your code (or even if it was a JIT’ed IR), everything ends in the same representation. This is where the magic happens, because the IR can take advantage of the LLVM optimizations (also known as transform and analysis passes).

After this IR generation, you can feed it into any LLVM backend to generate native code for any architecture supported by LLVM (such as x86, ARM, PPC, etc) and then you can finally execute your code with the native performance and also after LLVM optimization passes.

In order to JIT code using LLVM, all you need is to build the IR programmatically, create an execution engine to convert the IR into native code at execution time, get a pointer to the function you have JIT’ed and then finally execute it. Here I’ll use a Python binding for LLVM called llvmlite, which is very Pythonic and easy to use.

JIT’ing TensorFlow Graph using Python and LLVM


Let’s now use LLVM and Python to JIT the TensorFlow computational graph. This is by no means a comprehensive implementation; it is a very simplistic approach, an oversimplification that assumes some things: an integer closure type, just a few TensorFlow operations, and support for a single scalar instead of high-rank tensors.

So, let’s start building our JIT code; first of all, let’s import the required packages, initialize some LLVM sub-systems and also define the LLVM respective type for the TensorFlow integer type:

from ctypes import CFUNCTYPE, c_int

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import ops

import llvmlite.ir as ll
import llvmlite.binding as llvm

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

TYPE_TF_LLVM = {
    types_pb2.DT_INT32: ll.IntType(32),
}

After that, let’s define a class to open the TensorFlow exported graph and also declare a method to get a node of the graph by name:

class TFGraph(object):
    def __init__(self, filename="graph.pb", binary=False):
        self.graph_def = graph_pb2.GraphDef()
        with open(filename, "rb") as f:
            if binary:
                self.graph_def.ParseFromString(f.read())
            else:
                text_format.Merge(f.read(), self.graph_def)

    def get_node(self, name):
        for node in self.graph_def.node:
            if node.name == name:
                return node

And let’s start by defining our main function that will be the starting point of the code:

def run_main():
    graph = TFGraph("graph.pb", False)
    input_node = graph.get_node("input")
    output_node = graph.get_node("output")

    input_type = TYPE_TF_LLVM[input_node.attr["dtype"].type]
    output_type = TYPE_TF_LLVM[output_node.attr["T"].type]

    module = ll.Module()
    func_type = ll.FunctionType(output_type, [input_type])
    func = ll.Function(module, func_type, name='tensorflow_graph')
    func.args[0].name = 'input'

    bb_entry = func.append_basic_block('entry')
    ir_builder = ll.IRBuilder(bb_entry)

As you can see in the code above, we open the serialized protobuf graph and then get the input and output nodes of this graph. After that we also map the type of both graph nodes (input/output) to the LLVM type (from TensorFlow integer to LLVM integer). We then define an LLVM Module, which is the top-level container for all IR objects. One module in LLVM can contain many different functions; here we will create just one function that will represent the graph. This function receives as its argument the input data, with the same type as the input node, and returns a value with the same type as the output node.

After that we create the entry block of the function, and using this block we instantiate our IR Builder, which is an object that provides the building blocks for JIT’ing the operations of the TensorFlow graph.

Let’s now define the function that will do the real work of converting TensorFlow nodes into LLVM IR:

def build_graph(ir_builder, graph, node):
    if node.op == "Add":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.add(left_op, right_op)

    if node.op == "Sub":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.sub(left_op, right_op)

    if node.op == "Placeholder":
        function_args = ir_builder.function.args
        for arg in function_args:
            if arg.name == node.name:
                return arg
        raise RuntimeError("Input [{}] not found !".format(node.name))

    if node.op == "Const":
        llvm_const_type = TYPE_TF_LLVM[node.attr["dtype"].type]
        const_value = node.attr["value"].tensor.int_val[0]
        llvm_const_value = llvm_const_type(const_value)
        return llvm_const_value

This function receives as parameters the IR Builder, the graph class that we created earlier and the output node. It then recursively builds the LLVM IR by means of the IR Builder. Here you can see that I only implemented the Add/Sub/Placeholder and Const operations from the TensorFlow graph, just to be able to support the graph that we defined earlier.
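The recursive walk is easier to see without LLVM in the way. The sketch below mirrors the structure of build_graph() but evaluates the nodes directly instead of emitting IR. The dict-based graph is a hypothetical stand-in for the GraphDef, and the node names "Sub" and "Const_1" are assumptions, since that part of the protobuf file was omitted above.

```python
# Hypothetical dict-based stand-in for the GraphDef nodes used in this post.
GRAPH = {
    "input":   {"op": "Placeholder"},
    "Const":   {"op": "Const", "value": 2},
    "Sub":     {"op": "Sub", "input": ["input", "Const"]},
    "Const_1": {"op": "Const", "value": 5},
    "Add":     {"op": "Add", "input": ["Sub", "Const_1"]},
    "Const_2": {"op": "Const", "value": 100},
    "output":  {"op": "Add", "input": ["Add", "Const_2"]},
}


def evaluate(graph, name, feed):
    """Recursively evaluate node `name`, mirroring the walk in build_graph()."""
    node = graph[name]
    op = node["op"]
    if op == "Placeholder":
        return feed[name]
    if op == "Const":
        return node["value"]
    if op not in ("Add", "Sub"):
        raise NotImplementedError("Unsupported op: {}".format(op))
    # Binary ops: evaluate both inputs first, then combine them.
    left, right = (evaluate(graph, n, feed) for n in node["input"])
    return left + right if op == "Add" else left - right


print(evaluate(GRAPH, "output", {"input": 10}))  # 113
```

The JIT version follows the same recursion, except that each return value is an LLVM IR value rather than a Python integer.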

After that, we just need to define a function that takes an LLVM Module and creates an execution engine that will run the LLVM optimizations over the LLVM IR before doing the hard work of converting the IR into native x86 code:

def create_engine(module):
    features = llvm.get_host_cpu_features().flatten()
    llvm_module = llvm.parse_assembly(str(module))
    target = llvm.Target.from_default_triple()
    target_machine = target.create_target_machine(opt=3, features=features)
    engine = llvm.create_mcjit_compiler(llvm_module, target_machine)
    engine.finalize_object()
    print(target_machine.emit_assembly(llvm_module))
    return engine

In the code above, you can see that we first get the CPU features (SSE, etc.) into a list, then parse the LLVM IR from the module, and then create an engine using the maximum optimization level (opt=3, roughly equivalent to the GCC -O3 parameter). We’re also printing the assembly code (in my case, the x86 assembly built by LLVM).

And here we just finish our run_main() function:

    # (continuation of run_main(), hence the indentation)
    ret = build_graph(ir_builder, graph, output_node)
    ir_builder.ret(ret)

    with open("output.ir", "w") as f:
        f.write(str(module))

    engine = create_engine(module)

    func_ptr = engine.get_function_address("tensorflow_graph")
    cfunc = CFUNCTYPE(c_int, c_int)(func_ptr)
    ret = cfunc(10)

    print("Execution output: {}".format(ret))

As you can see in the code above, we just call the build_graph() function and then use the IR Builder to add the “ret” LLVM IR instruction (ret = return) to return the output of the IR function we just created based on the TensorFlow graph. We also write the IR output to an external file; I’ll use this LLVM IR file later to create native assembly for other architectures such as ARM. Finally, we get the address of the native code function, create a Python wrapper for this function and call it with the argument “10”, which is the input data, and then print the resulting output value.
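The ctypes wrapper can be tried in isolation. In the sketch below a Python lambda stands in for the JIT’ed native function (an assumption made only so that the snippet runs without LLVM); the prototype is the same CFUNCTYPE(c_int, c_int) used above.

```python
from ctypes import CFUNCTYPE, c_int

# Prototype: a C function taking one int and returning an int.
prototype = CFUNCTYPE(c_int, c_int)

# Stand-in for the JIT'ed code: ctypes builds a C-callable thunk around a
# Python function, so no LLVM is needed to try the wrapper itself.
cfunc = prototype(lambda value: value + 103)

print(cfunc(10))  # 113
```

In the real code the prototype is instantiated with the integer address returned by get_function_address() instead of a Python callable, and the call then jumps straight into the native code.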

And that is it. Of course this is just an oversimplification, but now we understand the advantages of having a JIT for our TensorFlow models.

The output LLVM IR, the advantage of optimizations and multiple architectures (ARM, PPC, x86, etc)

For instance, let’s create the LLVM IR (using the code I showed above) of the following TensorFlow graph:

import tensorflow as tf

with tf.Session() as sess:
    input_placeholder = tf.placeholder(tf.int32, 1, name="input")
    sub_op = tf.sub(input_placeholder, tf.constant(2, dtype=tf.int32))
    add_op = tf.add(sub_op, tf.constant(5, dtype=tf.int32))
    output = tf.add(add_op, tf.constant(100, dtype=tf.int32),
                    name="output")
    tf.train.write_graph(sess.graph_def, ".", "graph.pb", True)

The LLVM IR generated is this one below:

; ModuleID = ""
target triple = "unknown-unknown-unknown"
target datalayout = ""

define i32 @"tensorflow_graph"(i32 %"input") 
{
entry:
  %".3" = sub i32 %"input", 2
  %".4" = add i32 %".3", 5
  %".5" = add i32 %".4", 100
  ret i32 %".5"
}

As you can see, the LLVM IR looks a lot like assembly code, but this is not the final assembly code; it is just the IR before optimization. Just before generating the x86 assembly code, LLVM runs a lot of optimization passes over the LLVM IR, doing things such as dead code elimination, constant propagation, etc. And here is the final native x86 assembly code that LLVM generates for the above LLVM IR of the TensorFlow graph:

    .text
    .file	"<string>"
    .globl	tensorflow_graph
    .align	16, 0x90
    .type	tensorflow_graph,@function
tensorflow_graph:
    .cfi_startproc
    leal	103(%rdi), %eax
    retq
.Lfunc_end0:
    .size	tensorflow_graph, .Lfunc_end0-tensorflow_graph
    .cfi_endproc

    .section	".note.GNU-stack","",@progbits

As you can see, the optimized code removed a lot of redundant operations and ended up doing just a single add of 103, which is the correct simplification of the computation that we defined in the graph ((input - 2) + 5 + 100 = input + 103). For large graphs, these optimizations can be really powerful, because we are reusing, in our Machine Learning model computation, compiler optimizations that have been developed over many years.
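The simplification is easy to check by hand. The sketch below folds the chain of constant additions and subtractions from the graph into a single offset. It is only an illustration of the arithmetic and not how the LLVM passes are implemented; the function name `fold_constants` is made up.

```python
def fold_constants(ops):
    """Collapse a chain of ("add"|"sub", constant) steps into one offset."""
    offset = 0
    for op, constant in ops:
        if op == "add":
            offset += constant
        elif op == "sub":
            offset -= constant
        else:
            raise ValueError("Unsupported op: {}".format(op))
    return offset


# The chain from the graph: ((input - 2) + 5) + 100
chain = [("sub", 2), ("add", 5), ("add", 100)]
offset = fold_constants(chain)
print(offset)  # 103
```

This is the same 103 that shows up as the immediate operand of the leal instruction in the x86 output.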

You can also use an LLVM tool called “llc”, which can take an LLVM IR file and then generate assembly for any other platform you want. For instance, the command line below will generate native code for the ARM architecture:

llc -O3 out.ll -march=arm -o sample.s

The output sample.s file is the one below:

    .text
    .syntax unified
    .eabi_attribute	67, "2.09"	@ Tag_conformance
    .eabi_attribute	6, 1	@ Tag_CPU_arch
    .eabi_attribute	8, 1	@ Tag_ARM_ISA_use
    .eabi_attribute	17, 1	@ Tag_ABI_PCS_GOT_use
    .eabi_attribute	20, 1	@ Tag_ABI_FP_denormal
    .eabi_attribute	21, 1	@ Tag_ABI_FP_exceptions
    .eabi_attribute	23, 3	@ Tag_ABI_FP_number_model
    .eabi_attribute	34, 1	@ Tag_CPU_unaligned_access
    .eabi_attribute	24, 1	@ Tag_ABI_align_needed
    .eabi_attribute	25, 1	@ Tag_ABI_align_preserved
    .eabi_attribute	38, 1	@ Tag_ABI_FP_16bit_format
    .eabi_attribute	14, 0	@ Tag_ABI_PCS_R9_use
    .file	"out.ll"
    .globl	tensorflow_graph
    .align	2
    .type	tensorflow_graph,%function
tensorflow_graph:                       @ @tensorflow_graph
    .fnstart
@ BB#0:                                 @ %entry
    add	r0, r0, #103
    mov	pc, lr
.Lfunc_end0:
    .size	tensorflow_graph, .Lfunc_end0-tensorflow_graph
    .fnend

    .section	".note.GNU-stack","",%progbits

As you can see above, the ARM assembly code is also just an “add” assembly instruction followed by a return instruction.

This is really nice because we can take natural advantage of the LLVM framework. For instance, today ARM just announced the ARMv8-A with Scalable Vector Extensions (SVE) that will support 2048-bit vectors, and they are already working on patches for LLVM. In the future, a really nice addition to LLVM would be the development of LLVM Passes for analysis and transformation that would take into consideration the nature of Machine Learning models.

And that’s it, I hope you liked the post! It is really awesome what you can do with a few lines of Python, LLVM and TensorFlow.

Update 22 Aug 2016: Josh Klontz just pointed out his cool project called Likely in the Hacker News discussion.

Update 22 Aug 2016: TensorFlow team is actually working on a JIT (I don’t know if they are using LLVM, but it seems the most reasonable way to go in my opinion). In their paper, there is also a very important statement regarding Future Work that I cite here:

“We also have a number of concrete directions to improve the performance of TensorFlow. One such direction is our initial work on a just-in-time compiler that can take a subgraph of a TensorFlow execution, perhaps with some runtime profiling information about the typical sizes and shapes of tensors, and can generate an optimized routine for this subgraph. This compiler will understand the semantics of perform a number of optimizations such as loop fusion, blocking and tiling for locality, specialization for particular shapes and sizes, etc.” – TensorFlow White Paper

Full code

from ctypes import CFUNCTYPE, c_int

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import ops

import llvmlite.ir as ll
import llvmlite.binding as llvm

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

TYPE_TF_LLVM = {
    types_pb2.DT_INT32: ll.IntType(32),
}


class TFGraph(object):
    def __init__(self, filename="graph.pb", binary=False):
        self.graph_def = graph_pb2.GraphDef()
        with open(filename, "rb") as f:
            if binary:
                self.graph_def.ParseFromString(f.read())
            else:
                text_format.Merge(f.read(), self.graph_def)

    def get_node(self, name):
        for node in self.graph_def.node:
            if node.name == name:
                return node


def build_graph(ir_builder, graph, node):
    if node.op == "Add":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.add(left_op, right_op)

    if node.op == "Sub":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.sub(left_op, right_op)

    if node.op == "Placeholder":
        function_args = ir_builder.function.args
        for arg in function_args:
            if arg.name == node.name:
                return arg
        raise RuntimeError("Input [{}] not found !".format(node.name))

    if node.op == "Const":
        llvm_const_type = TYPE_TF_LLVM[node.attr["dtype"].type]
        const_value = node.attr["value"].tensor.int_val[0]
        llvm_const_value = llvm_const_type(const_value)
        return llvm_const_value


def create_engine(module):
    features = llvm.get_host_cpu_features().flatten()
    llvm_module = llvm.parse_assembly(str(module))
    target = llvm.Target.from_default_triple()
    target_machine = target.create_target_machine(opt=3, features=features)
    engine = llvm.create_mcjit_compiler(llvm_module, target_machine)
    engine.finalize_object()
    print(target_machine.emit_assembly(llvm_module))
    return engine


def run_main():
    graph = TFGraph("graph.pb", False)
    input_node = graph.get_node("input")
    output_node = graph.get_node("output")

    input_type = TYPE_TF_LLVM[input_node.attr["dtype"].type]
    output_type = TYPE_TF_LLVM[output_node.attr["T"].type]

    module = ll.Module()
    func_type = ll.FunctionType(output_type, [input_type])
    func = ll.Function(module, func_type, name='tensorflow_graph')
    func.args[0].name = 'input'

    bb_entry = func.append_basic_block('entry')
    ir_builder = ll.IRBuilder(bb_entry)

    ret = build_graph(ir_builder, graph, output_node)
    ir_builder.ret(ret)

    with open("output.ir", "w") as f:
        f.write(str(module))

    engine = create_engine(module)

    func_ptr = engine.get_function_address("tensorflow_graph")
    cfunc = CFUNCTYPE(c_int, c_int)(func_ptr)
    ret = cfunc(10)

    print("Execution output: {}".format(ret))


if __name__ == "__main__":
    run_main()
Cite this article as: Christian S. Perone, "JIT native code generation for TensorFlow computation graphs using Python and LLVM," in Terra Incognita, 22/08/2016, https://blog.christianperone.com/2016/08/jit-native-code-generation-for-tensorflow-computation-graphs-using-python-and-llvm/.
Machine Learning

Convolutional Neural Networks – Architectural Zoo

Presentation about an “Architectural Zoo” of different applications and architectures of CNNs. Presented at the Machine Learning Meetup in Porto Alegre yesterday.

Video (there are English subtitles available):

Cite this article as: Christian S. Perone, "Convolutional Neural Networks – Architectural Zoo," in Terra Incognita, 02/06/2016, https://blog.christianperone.com/2016/06/convolutional-neural-networks-architectural-zoo/.
Machine Learning, Python

Voynich Manuscript: word vectors and t-SNE visualization of some patterns

Update 17/01: reddit discussion thread.

Update 19/01: hacker news thread.

The codex

The Voynich Manuscript is a hand-written codex written in an unknown writing system and carbon-dated to the early 15th century (1404–1438). Although the manuscript has been studied by some famous cryptographers of World War I and II, nobody has deciphered it yet. The manuscript is known to be written in two different languages (Language A and Language B) and it is also known to be written by a group of people. The manuscript has been the subject of many different hypotheses, including the one that I like the most, the “culture extinction” hypothesis, supported in 2014 by Stephen Bax. This hypothesis states that the codex isn’t ciphered; it was just written in an unknown language that disappeared due to a culture extinction. In 2014, Stephen Bax proposed a provisional, partial decoding of the manuscript; the video of his presentation is very interesting and I really recommend watching it if you like this codex. There is also a transcription of the manuscript, done thanks to the hard work of many folks who have been working on it for many moons.

Word vectors

My idea when I heard about the work of Stephen Bax was to try to capture the patterns of the text using word2vec. Word embeddings are created by using a shallow neural network architecture. It is an unsupervised technique that uses supervised learning tasks to learn the linguistic context of the words. Here is a visualization of this architecture from the TensorFlow site:

softmax-nplm

These word vectors, once trained, carry a lot of semantic meaning. For instance:

word2vecqueen

We can see that those vectors can be used in vector operations to extract information about the regularities of the captured linguistic semantics. These vectors also place words with similar meanings close together, allowing similarity queries like in the example below:

>>> model.most_similar("man")
[(u'woman', 0.6056041121482849), (u'guy', 0.4935004413127899), (u'boy', 0.48933547735214233), (u'men', 0.4632953703403473), (u'person', 0.45742249488830566), (u'lady', 0.4487500488758087), (u'himself', 0.4288588762283325), (u'girl', 0.4166809320449829), (u'his', 0.3853422999382019), (u'he', 0.38293731212615967)]

>>> model.most_similar("queen")
[(u'princess', 0.519856333732605), (u'latifah', 0.47644317150115967), (u'prince', 0.45914226770401), (u'king', 0.4466976821422577), (u'elizabeth', 0.4134873151779175), (u'antoinette', 0.41033703088760376), (u'marie', 0.4061327874660492), (u'stepmother', 0.4040161967277527), (u'belle', 0.38827288150787354), (u'lovely', 0.38668593764305115)]
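Similarity queries of this kind rank words by the cosine similarity between their vectors. Here is a minimal sketch of that ranking in plain Python. The 3-dimensional vectors are made up for illustration and are not taken from any trained model, and this `most_similar` is a toy reimplementation rather than the library method used above.

```python
import math


def cosine(u, v):
    """Cosine of the angle between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def most_similar(vectors, word, topn=3):
    """Rank every other word by cosine similarity to `word`."""
    query = vectors[word]
    scores = [(other, cosine(query, vec))
              for other, vec in vectors.items() if other != word]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)[:topn]


# Hypothetical 3-dimensional vectors, made up for illustration only.
toy_vectors = {
    "man":   [0.9, 0.1, 0.0],
    "woman": [0.8, 0.3, 0.1],
    "queen": [0.1, 0.9, 0.2],
    "car":   [0.0, 0.1, 0.9],
}
print(most_similar(toy_vectors, "man", topn=2))
```

With these toy vectors, "woman" ranks first and "queen" second for the query "man".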

Word vectors can also be used (surprise) for translation, and I think this is their most important feature when the goal is to understand a text where we know the translations of only some of the words. In the future I intend to use the words found by Stephen Bax to check whether it is possible to capture some transformation that could help find similar structures in other languages. A nice visualization of this feature is the one below from the paper “Exploiting Similarities among Languages for Machine Translation”:

transl

This visualization was made using gradient descent to optimize a linear transformation between the source and destination language word vectors. As you can see, the structure in Spanish is really close to the structure in English.
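The idea of learning a linear transformation between two vector spaces with gradient descent can be sketched on a toy problem. Everything below is an assumption made for illustration: the 2-D vectors, the rotation used to generate the "target language" side, the learning rate and the function names. It minimizes the mean squared error of W x - z over the known word pairs.

```python
def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def fit_linear_map(sources, targets, lr=0.1, steps=2000):
    """Gradient descent on the mean squared error of W x - z."""
    dim_out, dim_in = len(targets[0]), len(sources[0])
    W = [[0.0] * dim_in for _ in range(dim_out)]
    n = float(len(sources))
    for _ in range(steps):
        grad = [[0.0] * dim_in for _ in range(dim_out)]
        for x, z in zip(sources, targets):
            pred = matvec(W, x)
            for i in range(dim_out):
                err = pred[i] - z[i]
                for j in range(dim_in):
                    grad[i][j] += 2.0 * err * x[j] / n
        for i in range(dim_out):
            for j in range(dim_in):
                W[i][j] -= lr * grad[i][j]
    return W


# Hypothetical 2-D "source language" vectors and their "target language"
# counterparts, generated here by a 90-degree rotation.
source_vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]
target_vecs = [[-x[1], x[0]] for x in source_vecs]

W = fit_linear_map(source_vecs, target_vecs)
print(matvec(W, [3.0, 2.0]))  # close to the rotated vector [-2.0, 3.0]
```

Once W is learned from a few known pairs, it can be applied to any other source vector to look for its nearest neighbours in the target space.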

EVA Transcription

To train this model, I had to parse and extract the transcription from the EVA (European Voynich Alphabet) to be able to feed the Voynich sentences into the word2vec model. This EVA transcription has the following format:

<f1r.P1.1;H>       fachys.ykal.ar.ataiin.shol.shory.cth!res.y.kor.sholdy!-
<f1r.P1.1;C>       fachys.ykal.ar.ataiin.shol.shory.cthorys.y.kor.sholdy!-
<f1r.P1.1;F>       fya!ys.ykal.ar.ytaiin.shol.shory.*k*!res.y!kor.sholdy!-
<f1r.P1.1;N>       fachys.ykal.ar.ataiin.shol.shory.cth!res.y,kor.sholdy!-
<f1r.P1.1;U>       fya!ys.ykal.ar.ytaiin.shol.shory.***!r*s.y.kor.sholdo*-
#
<f1r.P1.2;H>       sory.ckhar.o!r.y.kair.chtaiin.shar.are.cthar.cthar.dan!-
<f1r.P1.2;C>       sory.ckhar.o.r.y.kain.shtaiin.shar.ar*.cthar.cthar.dan!-
<f1r.P1.2;F>       sory.ckhar.o!r!y.kair.chtaiin.shor.ar!.cthar.cthar.dana-
<f1r.P1.2;N>       sory.ckhar.o!r,y.kair.chtaiin.shar.are.cthar.cthar,dan!-
<f1r.P1.2;U>       sory.ckhar.o!r!y.kair.chtaiin.shor.ary.cthar.cthar.dan*-

The first data between “<” and “>” has information about the folio (page), line and author of the transcription. The transcription block above is the transcription for the first two lines of the first folio of the manuscript below:

Part of the “f1r”

As you can see, the EVA contains some code characters, for instance “!” and “*”, and they all have some meaning, such as indicating that the author doing that transcription is not sure about the character in that position. EVA also contains transcriptions from different authors for the same line of the folio.

To convert this transcription to sentences I used only lines where the authors were sure about the entire line, and for each line of the folio I used the first transcription that satisfied this condition. I also did some cleaning on the transcription to remove the drawing names from the text, like: “text.text.text-{plant}text” -> “text text texttext”.
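A rough sketch of this conversion is shown below. The exact rules I used are not spelled out here, so the sketch makes loud assumptions: it treats “*” as the marker of an unreadable character (such lines are skipped), drops “!” as a filler, removes “-” and anything between braces, and splits words on “.” and “,”. The function name `eva_to_sentences` is hypothetical.

```python
import re

# <folio.paragraph.line;author>   transcription
LINE_RE = re.compile(r"^<(?P<locus>[^;>]+);(?P<author>\w)>\s+(?P<text>\S+)$")


def eva_to_sentences(lines, uncertain="*", filler="!"):
    """Keep the first fully readable transcription of each manuscript line."""
    sentences = {}
    for raw in lines:
        match = LINE_RE.match(raw.strip())
        if match is None:  # comment lines ("#") and blanks
            continue
        locus, text = match.group("locus"), match.group("text")
        if locus in sentences or any(ch in text for ch in uncertain):
            continue
        text = re.sub(r"\{[^}]*\}", "", text)  # drop drawing names like {plant}
        for ch in filler + "-=":
            text = text.replace(ch, "")
        sentences[locus] = [w for w in re.split(r"[.,]", text) if w]
    return list(sentences.values())


sample = """\
<f1r.P1.1;H>       fachys.ykal.ar.ataiin.shol.shory.cth!res.y.kor.sholdy!-
<f1r.P1.1;C>       fachys.ykal.ar.ataiin.shol.shory.cthorys.y.kor.sholdy!-
#
<f1r.P1.2;H>       sory.ckhar.o!r.y.kair.chtaiin.shar.are.cthar.cthar.dan!-
<f1r.P1.2;C>       sory.ckhar.o.r.y.kain.shtaiin.shar.ar*.cthar.cthar.dan!-
""".splitlines()

for sentence in eva_to_sentences(sample):
    print(" ".join(sentence))
```

Each returned sentence is a list of words, which is the input format that word2vec implementations usually expect.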

After this conversion from the EVA transcript to sentences compatible with the word2vec model, I trained the model to provide 100-dimensional word vectors for the words of the manuscript.

Vector space visualizations using t-SNE

After training the word vectors, I created a visualization of the 100-dimensional vectors in a 2D embedding space using the t-SNE algorithm:

tsne-vis1

As you can see, there are a lot of small clusters and there are visually two big clusters, probably accounting for the two different languages used in the codex (I still need to confirm this regarding the two languages aspect). After clustering it with DBSCAN (using the original word vectors, not the t-SNE transformed vectors), we can clearly see the two major clusters:

tsne-vis-dbscan

Now comes the really interesting and useful part of the word vectors. If we use a star name from the folio below (it’s pretty obvious why it is known that this is probably a star name):

>>> w2v_model.most_similar("octhey")

[('qoekaiin', 0.6402825713157654),
 ('otcheody', 0.6389687061309814),
 ('ytchos', 0.566596269607544),
 ('ocphy', 0.5415685176849365),
 ('dolchedy', 0.5343093872070312),
 ('aiicthy', 0.5323750376701355),
 ('odchecthy', 0.5235849022865295),
 ('okeeos', 0.5187858939170837),
 ('cphocthy', 0.5159749388694763),
 ('oteor', 0.5050544738769531)]

I get really interesting similar words, like for instance the ocphy and other close star names:

stars

It also returns the word “qoekaiin” from the folio 48, that precedes the same star name:

foliostars

As you can see, word vectors are really useful for finding linguistic structures. We can also create another plot showing how close the star names are in the 2D embedding space visualization created using t-SNE:

star_clus

As you can see, we zoomed into the major cluster of stars and we can see that they are really all grouped together in the vector space. These representations can be used, for instance, to infer plant names from the herbal section, etc.

My idea was to show how useful word vectors are for analyzing unknown codex texts. I hope you liked it, and I hope that this could be somehow useful for other people who are also interested in this amazing manuscript.

– Christian S. Perone

Cite this article as: Christian S. Perone, "Voynich Manuscript: word vectors and t-SNE visualization of some patterns," in Terra Incognita, 16/01/2016, https://blog.christianperone.com/2016/01/voynich-manuscript-word-vectors-and-t-sne-visualization-of-some-patterns/.

References

Voynich Digitalization

Stephen Bax Site

René Zandbergen Site

Machine Learning, Python

Convolutional hypercolumns in Python

If you are following some Machine Learning news, you certainly saw the work done by Ryan Dahl on Automatic Colorization (Hacker News comments, Reddit comments). This amazing work uses pixel hypercolumn information extracted from the VGG-16 network in order to colorize images. Samim also used the network to process Black & White video frames and produced the amazing video below:

https://www.youtube.com/watch?v=_MJU8VK2PI4

Colorizing Black&White Movies with Neural Networks (video by Samim, network by Ryan)

But how do these hypercolumns work? How can we extract them to use on such a variety of pixel classification problems? The main idea of this post is to use the VGG-16 pre-trained network together with Keras and Scikit-Learn in order to extract the pixel hypercolumns and take a superficial look at the information present in them. I’m writing this because I haven’t found anything in Python to do that, and this may be really useful for others working on pixel classification, segmentation, etc.

Hypercolumns

Many algorithms using features from CNNs (Convolutional Neural Networks) use the last FC (fully-connected) layer features in order to extract information about a certain input. However, the information in the last FC layer may be too coarse spatially to allow precise localization (due to sequences of maxpooling, etc.); on the other hand, the first layers may be spatially precise but will lack semantic information. To get the best of both worlds, the authors of the hypercolumn paper define the hypercolumn of a pixel as the vector of activations of all CNN units “above” that pixel.

Hypercolumn Extraction (by Hypercolumns for Object Segmentation and Fine-grained Localization)

The first step in extracting the hypercolumns is to feed the image into the CNN (Convolutional Neural Network) and extract the feature map activations for each location of the image. The tricky part is when the feature maps are smaller than the input image, for instance after a pooling operation. In that case, the authors of the paper apply bilinear upsampling to the feature map so that it has the same size as the input. There is also an issue with the FC (fully-connected) layers: you can’t isolate units semantically tied to only one pixel of the image, so the FC activations are treated as 1×1 feature maps, which means that all locations share the same information for the FC part of the hypercolumn. All these activations are then concatenated to create the hypercolumn. For instance, if we take the VGG-16 architecture and use only the convolutional layers right before the first two max pooling operations, we will have a hypercolumn with the size of:

64 filters (first conv layer before pooling) + 128 filters (second conv layer before pooling) = 192 features

This means that each pixel of the image will have a 192-dimensional hypercolumn vector. Hypercolumns are really interesting because, when built from layers across the whole network, they contain information from the first layers (where we have a lot of spatial information but little semantics) and also from the final layers (with little spatial information and lots of semantics). This should help in a lot of pixel classification tasks, such as the automatic colorization mentioned earlier, because the hypercolumn at each location carries information about what that pixel represents semantically and spatially. It is also very helpful in segmentation tasks (you can see more about that in the original paper introducing the hypercolumn concept).
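
The upsample-and-concatenate idea described above can be sketched with NumPy alone. This is only an illustration under stated assumptions: the helper names bilinear_upsample and build_hypercolumns are hypothetical, the feature maps are random stand-ins rather than real VGG-16 activations, the toy input is 32×32 instead of 224×224, and the interpolation aligns the corner pixels, which may differ slightly from what a given image library does.

```python
import numpy as np

def bilinear_upsample(fmap, out_h, out_w):
    """Resize a 2D feature map to (out_h, out_w) with bilinear interpolation."""
    in_h, in_w = fmap.shape
    # Positions of the output pixels in input coordinates (corners aligned)
    ys = np.linspace(0, in_h - 1, out_h)
    xs = np.linspace(0, in_w - 1, out_w)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, in_h - 1)
    x1 = np.minimum(x0 + 1, in_w - 1)
    wy = (ys - y0)[:, None]   # vertical weights, shape (out_h, 1)
    wx = (xs - x0)[None, :]   # horizontal weights, shape (1, out_w)
    top = fmap[y0][:, x0] * (1 - wx) + fmap[y0][:, x1] * wx
    bottom = fmap[y1][:, x0] * (1 - wx) + fmap[y1][:, x1] * wx
    return top * (1 - wy) + bottom * wy

def build_hypercolumns(feature_maps, out_h, out_w):
    """Stack the upsampled maps of every layer into (features, out_h, out_w)."""
    columns = []
    for layer_maps in feature_maps:      # each entry has shape (channels, h, w)
        for fmap in layer_maps:
            columns.append(bilinear_upsample(fmap, out_h, out_w))
    return np.asarray(columns)

rng = np.random.default_rng(0)
conv_a = rng.random((64, 32, 32))    # stand-in: 64 filters at full resolution
conv_b = rng.random((128, 16, 16))   # stand-in: 128 filters after one 2x2 pooling
hc_toy = build_hypercolumns([conv_a, conv_b], 32, 32)
print(hc_toy.shape)   # (192, 32, 32): 64 + 128 features per pixel
```

Each pixel of the toy image ends up with a 192-dimensional vector, matching the 64 + 128 count above.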

Everything sounds cool, but how do we extract hypercolumns in practice?

VGG-16

Before being able to extract the hypercolumns, we’ll set up the pre-trained VGG-16 network, because, you know, a good GPU (I can’t even imagine many of them) is very expensive here in Brazil and I don’t want to sell my kidney to buy one.

VGG16 Network Architecture (by Zhicheng Yan et al.)

To set up a pre-trained VGG-16 network in Keras, you’ll need to download the weights file from here (the vgg16_weights.h5 file, approximately 500MB) and then define the architecture and load the downloaded weights using Keras (more information about the weights file and architecture here):

from matplotlib import pyplot as plt

import theano
import cv2
import numpy as np
import scipy as sp

from keras.models import Sequential
from keras.layers.core import Flatten, Dense, Dropout
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers.convolutional import ZeroPadding2D
from keras.optimizers import SGD

from sklearn.manifold import TSNE
from sklearn import manifold
from sklearn import cluster
from sklearn.preprocessing import StandardScaler

def VGG_16(weights_path=None):
    model = Sequential()
    model.add(ZeroPadding2D((1,1),input_shape=(3,224,224)))
    model.add(Convolution2D(64, 3, 3, activation='relu'))
    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(64, 3, 3, activation='relu'))
    model.add(MaxPooling2D((2,2), stride=(2,2)))

    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(128, 3, 3, activation='relu'))
    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(128, 3, 3, activation='relu'))
    model.add(MaxPooling2D((2,2), stride=(2,2)))

    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(256, 3, 3, activation='relu'))
    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(256, 3, 3, activation='relu'))
    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(256, 3, 3, activation='relu'))
    model.add(MaxPooling2D((2,2), stride=(2,2)))

    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(512, 3, 3, activation='relu'))
    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(512, 3, 3, activation='relu'))
    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(512, 3, 3, activation='relu'))
    model.add(MaxPooling2D((2,2), stride=(2,2)))

    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(512, 3, 3, activation='relu'))
    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(512, 3, 3, activation='relu'))
    model.add(ZeroPadding2D((1,1)))
    model.add(Convolution2D(512, 3, 3, activation='relu'))
    model.add(MaxPooling2D((2,2), stride=(2,2)))

    model.add(Flatten())
    model.add(Dense(4096, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(4096, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(1000, activation='softmax'))

    if weights_path:
        model.load_weights(weights_path)

    return model

As you can see, this is very simple code that declares the VGG-16 architecture and loads the pre-trained weights (together with the Python imports for the required packages). After that, we’ll compile the Keras model:

model = VGG_16('vgg16_weights.h5')
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss='categorical_crossentropy')
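
The rest of this post refers to layers by their position in model.layers (3, 8, 15, 22 and 29). As a rough aid, the sketch below rebuilds only the ordering of the convolutional part of the VGG_16 function above (padding, convolution, pooling) in plain Python and finds the last convolutional layer before each max pooling. The names blocks and layer_kinds are hypothetical and no Keras objects are involved.

```python
# (number of conv layers, filters) for each block of the VGG_16 function
blocks = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]

layer_kinds = []   # layer_kinds[i] describes model.layers[i]
for n_convs, n_filters in blocks:
    for _ in range(n_convs):
        layer_kinds.append(("pad", None))
        layer_kinds.append(("conv", n_filters))
    layer_kinds.append(("pool", None))

# Index of the convolutional layer that sits right before each max pooling
last_conv_before_pool = [i - 1 for i, (kind, _) in enumerate(layer_kinds)
                         if kind == "pool"]
filters_per_layer = {i: layer_kinds[i][1] for i in last_conv_before_pool}

print(last_conv_before_pool)   # [3, 8, 15, 22, 29]
print(filters_per_layer)       # {3: 64, 8: 128, 15: 256, 22: 512, 29: 512}
```

Summing the filter counts of any subset of these layers gives the hypercolumn dimension for that subset, for example 64 + 128 = 192 for layers 3 and 8.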

Now let’s test the network using an image:

im_original = cv2.resize(cv2.imread('madruga.jpg'), (224, 224))
im = im_original.transpose((2,0,1))
im = np.expand_dims(im, axis=0)
im_converted = cv2.cvtColor(im_original, cv2.COLOR_BGR2RGB)
plt.imshow(im_converted)

Image used
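
The axis handling in the snippet above can be checked on a synthetic array. This is a small sketch that assumes nothing beyond NumPy: a zero-filled array stands in for the resized image, and one marked pixel shows where its three channel values end up after the transpose and the added batch axis.

```python
import numpy as np

# Stand-in for the resized image from OpenCV: height x width x channels
im_fake = np.zeros((224, 224, 3), dtype=np.uint8)
im_fake[10, 20] = (1, 2, 3)   # mark one pixel so we can follow it

im_chw = im_fake.transpose((2, 0, 1))      # channels first: (3, 224, 224)
im_batch = np.expand_dims(im_chw, axis=0)  # add the batch axis: (1, 3, 224, 224)

print(im_batch.shape)            # (1, 3, 224, 224)
print(im_batch[0, :, 10, 20])    # [1 2 3]
```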

As we can see, we loaded the image and fixed the axes; now we can feed the image into the VGG-16 to get the predictions:

out = model.predict(im)
plt.plot(out.ravel())

Predictions

As you can see, these are the final activations of the softmax layer; the class with the highest activation is the “jersey, T-shirt, tee shirt” category.
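
Reading the winning class off a plot is awkward, so here is a minimal sketch of picking the top entries of a softmax vector with NumPy. The helper top_k is hypothetical, the vector is synthetic (shaped like the (1, 1000) output of the network), and the class indices 7 and 42 are arbitrary rather than real ImageNet labels.

```python
import numpy as np

def top_k(probabilities, k=5):
    """Return (class_index, probability) pairs for the k largest entries."""
    flat = np.asarray(probabilities).ravel()
    order = np.argsort(flat)[::-1][:k]   # indices sorted by decreasing probability
    return [(int(i), float(flat[i])) for i in order]

# Synthetic stand-in for `out`: a softmax over 1000 classes
scores = np.zeros((1, 1000))
scores[0, 7] = 3.0
scores[0, 42] = 1.0
fake_out = np.exp(scores) / np.exp(scores).sum()

best = top_k(fake_out, k=2)
print([index for index, _ in best])   # [7, 42]
```

Mapping the resulting indices to category names requires the label list that matches the weights file, which is not shown here.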

Extracting arbitrary feature maps

Now, to extract the feature map activations, we need to be able to extract feature maps from arbitrary convolutional layers of the network. We can do that by compiling a Theano function using the get_output() method of Keras, as in the example below:

get_feature = theano.function([model.layers[0].input], model.layers[3].get_output(train=False), allow_input_downcast=False)
feat = get_feature(im)
plt.imshow(feat[0][2])

Feature Map

In the example above, I’m compiling a Theano function to get the feature maps of layer 3 (a convolutional layer) and then showing only the 3rd feature map. Here we can see the intensity of the activations. If we look at feature maps from layers closer to the end of the network, we can see that the extracted features are more abstract, like eyes, etc. Look at the example below from layer 15 (also a convolutional layer):

get_feature = theano.function([model.layers[0].input], model.layers[15].get_output(train=False), allow_input_downcast=False)
feat = get_feature(im)
plt.imshow(feat[0][13])

More semantic feature maps.

As you can see, this second feature map extracts more abstract features. You can also note that the image seems more stretched when compared with the feature map we saw earlier. That is because the first feature map has a size of 224×224 and this one has 56×56, due to the downscaling operations of the layers before this convolutional layer, and that is why we lose a lot of spatial information.
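
The 224×224 versus 56×56 sizes follow directly from the number of 2×2 max poolings that come before each layer, since the padded 3×3 convolutions keep the spatial size unchanged. A small sketch of that arithmetic (the helper feature_map_size is hypothetical, and the pooling counts are read off the VGG_16 function above):

```python
def feature_map_size(input_size, n_poolings, pool=2):
    """Side length of a feature map after n_poolings non-overlapping poolings."""
    size = input_size
    for _ in range(n_poolings):
        size //= pool
    return size

# Number of max pooling layers that precede each layer index used in this post
poolings_before = {3: 0, 8: 1, 15: 2, 22: 3, 29: 4}
sizes = {layer: feature_map_size(224, n) for layer, n in poolings_before.items()}
print(sizes)   # {3: 224, 8: 112, 15: 56, 22: 28, 29: 14}
```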

Extracting hypercolumns

Now, finally, let’s extract the hypercolumns of an arbitrary set of layers. To do that, we will define a function to extract these hypercolumns:

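def extract_hypercolumn(model, layer_indexes, instance):
    # Symbolic outputs (inference mode) of each requested layer
    layers = [model.layers[li].get_output(train=False) for li in layer_indexes]
    # One compiled Theano function that returns all requested activations
    get_feature = theano.function([model.layers[0].input], layers,
                                  allow_input_downcast=False)
    feature_maps = get_feature(instance)
    hypercolumns = []
    for convmap in feature_maps:
        # convmap[0] holds the feature maps of the single image in the batch
        for fmap in convmap[0]:
            # Bilinear upsampling of each feature map to the input size
            upscaled = sp.misc.imresize(fmap, size=(224, 224),
                                        mode="F", interp='bilinear')
            hypercolumns.append(upscaled)

    # Shape: (total number of feature maps, 224, 224)
    return np.asarray(hypercolumns)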

As we can see, this function expects three parameters: the model itself, a list of layer indexes that will be used to extract the hypercolumn features, and an image instance from which the hypercolumns will be extracted. Let’s now test the hypercolumn extraction for layers 3 and 8, the convolutional layers right before the first two max pooling operations:

layers_extract = [3, 8]
hc = extract_hypercolumn(model, layers_extract, im)

That’s it, we extracted the hypercolumn vectors for each pixel. The shape of this “hc” variable is (192L, 224L, 224L), which means that we have a 192-dimensional hypercolumn for each one of the 224×224 pixels (a total of 50176 pixels with 192 hypercolumn features each).
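
To make the layout of this array concrete, the sketch below uses a zero-filled stand-in with the same shape and shows two equivalent ways of reaching the hypercolumn of one pixel: slicing the (features, rows, columns) array directly, or looking up a row of the flattened pixels-by-features matrix. The pixel position (100, 50) is arbitrary.

```python
import numpy as np

hc_fake = np.zeros((192, 224, 224), dtype=np.float32)   # same shape as hc
hc_fake[:, 100, 50] = np.arange(192)                     # mark one pixel

# Hypercolumn of the pixel at row 100, column 50
pixel_vector = hc_fake[:, 100, 50]

# One row per pixel, one column per feature
per_pixel = hc_fake.transpose(1, 2, 0).reshape(224 * 224, -1)
row_index = 100 * 224 + 50   # row-major position of pixel (100, 50)

print(pixel_vector.shape)   # (192,)
print(per_pixel.shape)      # (50176, 192)
```

The flattened form is the one that per-pixel classifiers or clustering algorithms usually expect.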

Let’s plot the average of the hypercolumn activations for each pixel:

ave = np.average(hc.transpose(1, 2, 0), axis=2)
plt.imshow(ave)
Hypercolumn average for layers 3 and 8.

As you can see, those first hypercolumn activations all look like edge detectors. Let’s see what these hypercolumns look like for layers 22 and 29:

layers_extract = [22, 29]
hc = extract_hypercolumn(model, layers_extract, im)
ave = np.average(hc.transpose(1, 2, 0), axis=2)
plt.imshow(ave)
Hypercolumn average for the layers 22 and 29.

As we can see now, the features are much more abstract and semantically interesting, but the spatial information is a little fuzzy.

Remember that you can extract the hypercolumns using all the initial layers and also the final layers, including the FC layers. Here I’m extracting them separately to show how they differ in the visualization plots.
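
For the FC layers, the 1×1 feature map idea means every location receives the same values. A minimal NumPy sketch of that concatenation, with toy sizes and a hypothetical helper name append_fc_features (real FC activations would come from the network, which is not run here):

```python
import numpy as np

def append_fc_features(hypercolumns, fc_activations):
    """Concatenate FC activations to every location of a (features, h, w) array."""
    _, h, w = hypercolumns.shape
    fc = np.asarray(fc_activations).reshape(-1, 1, 1)    # a stack of 1x1 maps
    fc_maps = np.broadcast_to(fc, (fc.shape[0], h, w))   # same value everywhere
    return np.concatenate([hypercolumns, fc_maps], axis=0)

conv_part = np.random.default_rng(0).random((6, 4, 4))   # toy conv hypercolumns
fc_part = np.array([0.1, 0.9, 0.5])                      # toy FC activations
full = append_fc_features(conv_part, fc_part)
print(full.shape)   # (9, 4, 4)
```

With the real 4096-unit FC layers and a 224×224 image, the concatenated array becomes large, so it may be worth keeping the FC part as a single vector and appending it only when a pixel is actually used.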

Simple hypercolumn pixel clustering

Now, you can do a lot of things: you can use these hypercolumns to classify pixels for some task, to do automatic pixel colorization, segmentation, etc. What I’m going to do here, just as an experiment, is to take the hypercolumns (from the VGG-16 layers 3, 8, 15, 22, 29) and cluster them using KMeans with 2 clusters:

# Extract the hypercolumns for the layers mentioned above
layers_extract = [3, 8, 15, 22, 29]
hc = extract_hypercolumn(model, layers_extract, im)

# One row per pixel, one column per hypercolumn feature
m = hc.transpose(1,2,0).reshape(50176, -1)
kmeans = cluster.KMeans(n_clusters=2, max_iter=300)
cluster_labels = kmeans.fit_predict(m)

# Show the cluster assigned to each pixel as a 224x224 image
imcluster = cluster_labels.reshape(224, 224)
plt.imshow(imcluster, cmap="hot")
KMeans clustering using hypercolumns.
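
To see what the clustering step does without running the network, here is a self-contained sketch on a toy 8×8 "image" with 5 features per pixel. It is not the scikit-learn implementation: kmeans_labels is a hypothetical helper implementing plain Lloyd iterations with a deterministic farthest-point initialization (an assumption made here so the result is reproducible), and the data is synthetic, with the right half of the image given a clearly different signature.

```python
import numpy as np

def kmeans_labels(points, k=2, n_iter=20):
    """Plain Lloyd's algorithm; returns one cluster label per row of `points`."""
    # Farthest-point initialization: start at the first point, then keep
    # adding the point that is farthest from the centers chosen so far
    centers = [points[0]]
    for _ in range(1, k):
        d = np.min([((points - c) ** 2).sum(axis=1) for c in centers], axis=0)
        centers.append(points[d.argmax()])
    centers = np.array(centers, dtype=float)

    for _ in range(n_iter):
        # Squared distance of every point to every center: shape (n, k)
        dists = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        for j in range(k):
            members = points[labels == j]
            if len(members) > 0:   # keep the old center if a cluster is empty
                centers[j] = members.mean(axis=0)
    return labels

h, w, n_features = 8, 8, 5
rng = np.random.default_rng(1)
hc_toy = rng.normal(0.0, 0.1, size=(n_features, h, w))
hc_toy[:, :, w // 2:] += 10.0   # right half gets a different signature

m_toy = hc_toy.transpose(1, 2, 0).reshape(h * w, -1)   # one row per pixel
label_image = kmeans_labels(m_toy, k=2).reshape(h, w)
print(label_image[0])   # [0 0 0 0 1 1 1 1]
```

The label image splits the toy picture into its left and right halves, which is the same kind of per-pixel grouping shown in the plot above.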

Now you can imagine how useful hypercolumns can be for tasks like keypoint extraction, segmentation, etc. It’s a very elegant, simple and useful concept.

I hope you liked it!

– Christian S. Perone

Cite this article as: Christian S. Perone, "Convolutional hypercolumns in Python," in Terra Incognita, 11/01/2016, https://blog.christianperone.com/2016/01/convolutional-hypercolumns-in-python/.