
JIT native code generation for TensorFlow computation graphs using Python and LLVM

Update: Hacker News discussion here.

The TensorFlow Computation Graph


One of the most amazing components of the TensorFlow architecture is the computation graph, which can be serialized using Protocol Buffers. This computation graph follows a well-defined format (see the .proto files in the TensorFlow repository) and describes the computation that you specify: a Deep Learning model like a CNN, a simple Logistic Regression, or almost any other computation you want. For instance, here is a very simple TensorFlow computation graph that we will use in this tutorial (built with the TensorFlow Python API):

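import tensorflow as tf

# TensorFlow 0.x API, as used at the time of writing
# (in TensorFlow 1.0, tf.sub was renamed to tf.subtract).
with tf.Session() as sess:
    # 1-element int32 input tensor, named so we can find it later
    input_placeholder = tf.placeholder(tf.int32, 1, name="input")
    sub_op = tf.sub(input_placeholder, tf.constant(2, dtype=tf.int32))
    add_op = tf.add(sub_op, tf.constant(5, dtype=tf.int32))
    output = tf.add(add_op, tf.constant(100, dtype=tf.int32),
                    name="output")
    # Serialize the graph as text (as_text=True) to ./graph.pb
    tf.train.write_graph(sess.graph_def, ".", "graph.pb", True)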
Representation of the computation graph.

As you can see, this is a very simple computation graph. First, we define the placeholder that will hold the input tensor, and then we specify the computation to perform on that input. Note that we give names to two important nodes of this graph: “input” (the aforementioned placeholder) and “output”, which will hold the result of the final computation. For a scalar input, this graph computes output = ((input - 2) + 5) + 100, which simplifies to input + 103. I intentionally added redundant operations so that we can see LLVM fold the constants later.

In the last line of the code, we persist this computation graph (including the constant values) into a serialized protobuf file. The final True argument (as_text) asks for a textual representation instead of a binary one, so it produces the following human-readable protobuf file (part of it is omitted for brevity):

node {
  name: "input"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 1
        }
      }
    }
  }
}
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 2
      }
    }
  }
}

--- >(omitted for brevity) < ---

node {
  name: "output"
  op: "Add"
  input: "Add"
  input: "Const_2"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
versions {
  producer: 9
}
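Each node refers to its inputs by name, so the serialized graph is a DAG that can be walked backwards from "output". The minimal, stdlib-only sketch below uses a hypothetical plain-Python dict that mirrors the GraphDef above. The names of the omitted nodes (Sub, Const_1, Add) are an assumption based on TensorFlow's default naming, and they are consistent with the "Add" and "Const_2" inputs of the output node.

```python
# Hypothetical plain-Python mirror of graph.pb: every node names its op
# and lists its inputs *by node name*, just like the protobuf does.
# The names of the omitted nodes follow TensorFlow's default naming.
GRAPH = {
    "input":   {"op": "Placeholder", "inputs": []},
    "Const":   {"op": "Const", "value": 2, "inputs": []},
    "Sub":     {"op": "Sub", "inputs": ["input", "Const"]},
    "Const_1": {"op": "Const", "value": 5, "inputs": []},
    "Add":     {"op": "Add", "inputs": ["Sub", "Const_1"]},
    "Const_2": {"op": "Const", "value": 100, "inputs": []},
    "output":  {"op": "Add", "inputs": ["Add", "Const_2"]},
}


def topological_order(graph, root):
    """Depth-first walk from `root`; each node appears after its inputs."""
    order, seen = [], set()

    def visit(name):
        if name in seen:          # shared nodes are listed only once
            return
        seen.add(name)
        for dep in graph[name]["inputs"]:
            visit(dep)
        order.append(name)

    visit(root)
    return order


print(topological_order(GRAPH, "output"))
# ['input', 'Const', 'Sub', 'Const_1', 'Add', 'Const_2', 'output']
```

The recursive build_graph() function later in this post follows the same inputs-first order. It has no `seen` set, though, so an Add or Sub node feeding two consumers would be emitted twice.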

This is a very simple graph. Real TensorFlow graphs are rarely that simple: depending on the model you specify, they can easily contain more than 300 nodes, especially in the case of Deep Learning models.

We’ll use the graph above to show how to JIT native code for it using the LLVM framework.

The LLVM Frontend, IR and Backend


The LLVM framework is a really nice, modular and complete ecosystem for building compilers and toolchains. A very nice description of the LLVM architecture that is important for us is shown in the picture below:

LLVM Compiler Architecture (AOSA/LLVM, Chris Lattner)

(The picture above shows just a small part of the LLVM architecture; for a comprehensive description, please see the nice article Chris Lattner wrote for the AOSA book.)

Looking at the image above, we can see that LLVM provides a lot of core functionality. On the left side, each language has its own frontend. After that point it doesn't matter which language you wrote your code in: everything is transformed into a very powerful language called LLVM IR (LLVM Intermediate Representation). As the name suggests, this is an intermediate representation of the code that sits just before the assembly code itself. In my opinion, the IR is the key component of what makes LLVM so amazing. No matter which language your code came from (or even if the IR was generated by a JIT), everything ends up in the same representation. That is where the magic happens, because the IR can then take advantage of the LLVM optimizations (also known as transform and analysis passes).

After the IR is generated, you can feed it into any LLVM backend to generate native code for any architecture supported by LLVM (such as x86, ARM, PPC, etc.). You can then finally execute your code at native speed, with the benefit of the LLVM optimizations.

To JIT code using LLVM, you need to do four things: build the IR programmatically, create an execution engine that converts the IR into native code at run time, get a pointer to the function you have JIT’ed, and finally call it. I’ll use a Python binding for LLVM called llvmlite, which is very Pythonic and easy to use.

JIT’ing TensorFlow Graph using Python and LLVM


Let’s now use LLVM and Python to JIT the TensorFlow computation graph. This is by no means a comprehensive implementation. It is a very simplistic approach that makes a few assumptions: an integer closure type, only a handful of TensorFlow operations, and support for a single scalar instead of higher-rank tensors.

So, let’s start building our JIT code. First of all, let’s import the required packages, initialize some LLVM subsystems and define the LLVM type that corresponds to the TensorFlow integer type:

from ctypes import CFUNCTYPE, c_int

from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2

import llvmlite.ir as ll
import llvmlite.binding as llvm

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

TYPE_TF_LLVM = {
    types_pb2.DT_INT32: ll.IntType(32),
}

After that, let’s define a class to open the TensorFlow exported graph and also declare a method to get a node of the graph by name:

class TFGraph(object):
    def __init__(self, filename="graph.pb", binary=False):
        self.graph_def = graph_pb2.GraphDef()
        # Open the requested file (not a hard-coded name)
        with open(filename, "rb") as f:
            if binary:
                self.graph_def.ParseFromString(f.read())
            else:
                text_format.Merge(f.read(), self.graph_def)

    def get_node(self, name):
        for node in self.graph_def.node:
            if node.name == name:
                return node
        raise KeyError("Node [{}] not found !".format(name))

And let’s start by defining our main function that will be the starting point of the code:

def run_main():
    graph = TFGraph("graph.pb", False)
    input_node = graph.get_node("input")
    output_node = graph.get_node("output")

    input_type = TYPE_TF_LLVM[input_node.attr["dtype"].type]
    output_type = TYPE_TF_LLVM[output_node.attr["T"].type]

    module = ll.Module()
    func_type = ll.FunctionType(output_type, [input_type])
    func = ll.Function(module, func_type, name='tensorflow_graph')
    func.args[0].name = 'input'

    bb_entry = func.append_basic_block('entry')
    ir_builder = ll.IRBuilder(bb_entry)

As you can see in the code above, we open the serialized protobuf graph and then get its input and output nodes. After that, we map the types of both nodes (input/output) to the corresponding LLVM types (from TensorFlow integer to LLVM integer). We then define an LLVM Module, which is the top-level container for all IR objects. One LLVM module can contain many different functions. Here we create just one function that represents the graph. It receives an argument of the same type as the input node and returns a value of the same type as the output node.

Next, we create the entry basic block of the function and use it to instantiate our IR Builder. The IR Builder is an object that provides the building blocks for JIT’ing the operations of the TensorFlow graph.

Let’s now define the function that will do the real work of converting TensorFlow nodes into LLVM IR:

def build_graph(ir_builder, graph, node):
    if node.op == "Add":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.add(left_op, right_op)

    if node.op == "Sub":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.sub(left_op, right_op)

    if node.op == "Placeholder":
        function_args = ir_builder.function.args
        for arg in function_args:
            if arg.name == node.name:
                return arg
        raise RuntimeError("Input [{}] not found !".format(node.name))

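    if node.op == "Const":
        llvm_const_type = TYPE_TF_LLVM[node.attr["dtype"].type]
        const_value = node.attr["value"].tensor.int_val[0]
        llvm_const_value = llvm_const_type(const_value)
        return llvm_const_value

    # Any other TensorFlow op is outside the scope of this toy JIT
    raise NotImplementedError("Op [{}] not supported !".format(node.op))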

This function receives three parameters: the IR Builder, the graph class that we created earlier, and the node to build (starting with the output node). It then recursively builds the LLVM IR through the IR Builder, emitting a node only after its inputs. I only implemented the Add, Sub, Placeholder and Const operations, which is just enough to support the graph we defined earlier.
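The recursion can be tried without llvmlite. The hypothetical, stdlib-only sketch below follows the same structure as build_graph() but emits LLVM IR as plain text. The dict-based graph and the register names (%t0, %t1, ...) are assumptions for illustration; llvmlite numbers its values differently (%".3", %".4", ...).

```python
import itertools

# Hypothetical dict mirror of graph.pb (omitted node names assumed).
GRAPH = {
    "input":   {"op": "Placeholder", "inputs": []},
    "Const":   {"op": "Const", "value": 2, "inputs": []},
    "Sub":     {"op": "Sub", "inputs": ["input", "Const"]},
    "Const_1": {"op": "Const", "value": 5, "inputs": []},
    "Add":     {"op": "Add", "inputs": ["Sub", "Const_1"]},
    "Const_2": {"op": "Const", "value": 100, "inputs": []},
    "output":  {"op": "Add", "inputs": ["Add", "Const_2"]},
}


def emit(graph, name, lines, regs):
    """Return the IR operand for node `name`, appending instructions."""
    node = graph[name]
    op = node["op"]
    if op == "Placeholder":
        return "%" + name                 # the function argument
    if op == "Const":
        return str(node["value"])         # an i32 immediate
    if op in ("Add", "Sub"):
        lhs = emit(graph, node["inputs"][0], lines, regs)
        rhs = emit(graph, node["inputs"][1], lines, regs)
        reg = "%t{}".format(next(regs))
        lines.append("  {} = {} i32 {}, {}".format(reg, op.lower(), lhs, rhs))
        return reg
    raise NotImplementedError("Op [{}] not supported".format(op))


def emit_function(graph, output="output"):
    """Wrap the emitted instructions in an i32 (i32) function."""
    lines = []
    result = emit(graph, output, lines, itertools.count())
    return "\n".join(
        ["define i32 @tensorflow_graph(i32 %input) {", "entry:"]
        + lines
        + ["  ret i32 " + result, "}"])


print(emit_function(GRAPH))
```

The printed text is a valid LLVM IR function with the same sub/add/add chain as the llvmlite output shown later in the post. Inputs are visited before the node that uses them, exactly as in build_graph().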

After that, we just need a function that takes an LLVM Module and creates an execution engine. The engine compiles the IR into native code for the host (x86 in my case), with the code generator’s optimization level set to its maximum:

def create_engine(module):
    features = llvm.get_host_cpu_features().flatten()
    llvm_module = llvm.parse_assembly(str(module))
    target = llvm.Target.from_default_triple()
    target_machine = target.create_target_machine(opt=3, features=features)
    engine = llvm.create_mcjit_compiler(llvm_module, target_machine)
    engine.finalize_object()
    print(target_machine.emit_assembly(llvm_module))
    return engine

In the code above, we first get the host CPU features (SSE, etc.) as a feature string. We then parse the LLVM IR from the module and create an engine using the maximum optimization level for code generation (opt=3). We also print the assembly code built by LLVM (x86 assembly, in my case).

And here we finish the run_main() function:

    ret = build_graph(ir_builder, graph, output_node)
    ir_builder.ret(ret)

    with open("output.ir", "w") as f:
        f.write(str(module))

    engine = create_engine(module)

    func_ptr = engine.get_function_address("tensorflow_graph")
    cfunc = CFUNCTYPE(c_int, c_int)(func_ptr)
    ret = cfunc(10)

    print("Execution output: {}".format(ret))

As you can see in the code above, we call build_graph() and then use the IR Builder to add the “ret” LLVM IR instruction (ret = return). This instruction returns the output of the function we just created from the TensorFlow graph. We also write the IR to an external file; I’ll use this LLVM IR file later to generate native assembly for other architectures, such as ARM. Finally, we get the address of the native function and wrap it in a Python callable with ctypes. We then call it with the argument 10 as the input data, which should print 113 (10 - 2 + 5 + 100).
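Turning a raw address into a Python callable is plain ctypes. The stdlib-only sketch below fakes the JIT’ed function with a Python callback, so the CFUNCTYPE(c_int, c_int)(address) pattern can be tried without llvmlite. `python_graph` is a hypothetical stand-in, not LLVM output.

```python
from ctypes import CFUNCTYPE, c_int, c_void_p, cast

# Same prototype as in run_main(): int f(int)
GraphFunc = CFUNCTYPE(c_int, c_int)


def python_graph(x):
    # Hypothetical stand-in for the JIT'ed code: ((x - 2) + 5) + 100
    return ((x - 2) + 5) + 100


# Wrap the Python function as a C function pointer; this object must
# stay alive for as long as the pointer is used.
callback = GraphFunc(python_graph)

# Raw integer address, playing the role of get_function_address()
func_ptr = cast(callback, c_void_p).value

# Build a callable from the bare address, as run_main() does
cfunc = GraphFunc(func_ptr)
print(cfunc(10))  # 113
```

With the real engine, the memory behind the address belongs to the MCJIT engine, so the engine object must also stay alive while `cfunc` is in use.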

And that is it. Of course, this is just an oversimplification, but it already shows the advantages of having a JIT for our TensorFlow models.

The output LLVM IR, the advantage of optimizations and multiple architectures (ARM, PPC, x86, etc)

For instance, let’s create the LLVM IR (using the code shown above) of the following TensorFlow graph:

import tensorflow as tf

with tf.Session() as sess:
    input_placeholder = tf.placeholder(tf.int32, 1, name="input")
    sub_op = tf.sub(input_placeholder, tf.constant(2, dtype=tf.int32))
    add_op = tf.add(sub_op, tf.constant(5, dtype=tf.int32))
    output = tf.add(add_op, tf.constant(100, dtype=tf.int32),
                    name="output")
    tf.train.write_graph(sess.graph_def, ".", "graph.pb", True)

The generated LLVM IR is shown below:

; ModuleID = ""
target triple = "unknown-unknown-unknown"
target datalayout = ""

define i32 @"tensorflow_graph"(i32 %"input") 
{
entry:
  %".3" = sub i32 %"input", 2
  %".4" = add i32 %".3", 5
  %".5" = add i32 %".4", 100
  ret i32 %".5"
}

As you can see, the LLVM IR looks a lot like assembly code, but it is not the final assembly code; it is still unoptimized IR. While generating the x86 assembly, LLVM optimizes the code, doing things such as folding constants and removing redundant instructions. Here is the final native x86 assembly code that LLVM generates for the LLVM IR of the TensorFlow graph above:

    .text
    .file	"<string>"
    .globl	tensorflow_graph
    .align	16, 0x90
    .type	tensorflow_graph,@function
tensorflow_graph:
    .cfi_startproc
    leal	103(%rdi), %eax
    retq
.Lfunc_end0:
    .size	tensorflow_graph, .Lfunc_end0-tensorflow_graph
    .cfi_endproc

    .section	".note.GNU-stack","",@progbits

As you can see, the optimized code removed the redundant operations and ended up with a single addition of 103 (the leal 103(%rdi), %eax instruction). This is the correct simplification of the computation defined in the graph. For large graphs these optimizations can be really powerful, because we apply compiler optimizations that have been developed over many years to our Machine Learning model computation.
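The folding that LLVM performed here can be imitated on a tiny scale. The sketch below is a hypothetical toy pass, not how LLVM implements it. It collapses a chain of constant additions and subtractions applied to one variable into a single offset, which for this graph yields the same input + 103 seen in the assembly.

```python
# Expression trees as tuples: ("var", name), ("const", n),
# ("add", lhs, rhs) and ("sub", lhs, rhs).
GRAPH_EXPR = ("add",
              ("add",
               ("sub", ("var", "input"), ("const", 2)),
               ("const", 5)),
              ("const", 100))


def fold_offset(expr):
    """Return (name, k) such that expr == name + k.

    Toy restriction: the right operand of every add/sub must be a
    constant and the left-most leaf must be a single variable.
    """
    kind = expr[0]
    if kind == "var":
        return expr[1], 0
    if kind in ("add", "sub") and expr[2][0] == "const":
        name, offset = fold_offset(expr[1])
        delta = expr[2][1]
        if kind == "add":
            return name, offset + delta
        return name, offset - delta
    raise ValueError("expression outside the toy subset: {!r}".format(expr))


def evaluate(expr, env):
    """Straightforward interpreter, used to check the folded form."""
    kind = expr[0]
    if kind == "var":
        return env[expr[1]]
    if kind == "const":
        return expr[1]
    lhs, rhs = evaluate(expr[1], env), evaluate(expr[2], env)
    return lhs + rhs if kind == "add" else lhs - rhs


name, offset = fold_offset(GRAPH_EXPR)
print("{} + {}".format(name, offset))  # input + 103
for x in (-7, 0, 10):
    assert evaluate(GRAPH_EXPR, {name: x}) == x + offset
```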

You can also use an LLVM tool called “llc”, which takes an LLVM IR file and generates assembly for any other platform you want. For instance, with the IR written by run_main() (output.ir) saved as out.ll, the command line below generates native code for the ARM architecture:

llc -O3 out.ll -march=arm -o sample.s

The output sample.s file is the one below:

    .text
    .syntax unified
    .eabi_attribute	67, "2.09"	@ Tag_conformance
    .eabi_attribute	6, 1	@ Tag_CPU_arch
    .eabi_attribute	8, 1	@ Tag_ARM_ISA_use
    .eabi_attribute	17, 1	@ Tag_ABI_PCS_GOT_use
    .eabi_attribute	20, 1	@ Tag_ABI_FP_denormal
    .eabi_attribute	21, 1	@ Tag_ABI_FP_exceptions
    .eabi_attribute	23, 3	@ Tag_ABI_FP_number_model
    .eabi_attribute	34, 1	@ Tag_CPU_unaligned_access
    .eabi_attribute	24, 1	@ Tag_ABI_align_needed
    .eabi_attribute	25, 1	@ Tag_ABI_align_preserved
    .eabi_attribute	38, 1	@ Tag_ABI_FP_16bit_format
    .eabi_attribute	14, 0	@ Tag_ABI_PCS_R9_use
    .file	"out.ll"
    .globl	tensorflow_graph
    .align	2
    .type	tensorflow_graph,%function
tensorflow_graph:                       @ @tensorflow_graph
    .fnstart
@ BB#0:                                 @ %entry
    add	r0, r0, #103
    mov	pc, lr
.Lfunc_end0:
    .size	tensorflow_graph, .Lfunc_end0-tensorflow_graph
    .fnend

    .section	".note.GNU-stack","",%progbits

As you can see above, the ARM code is also just a single add instruction followed by a return (mov pc, lr).

This is really nice because we naturally get the benefits of the LLVM framework. For instance, ARM just announced the ARMv8-A Scalable Vector Extension (SVE), which supports vectors of up to 2048 bits, and they are already working on patches for LLVM. In the future, a really nice addition to LLVM would be analysis and transformation passes that take into account the nature of Machine Learning models.

And that’s it, I hope you liked the post! It is really awesome what you can do with a few lines of Python, LLVM and TensorFlow.

Update 22 Aug 2016: Josh Klontz pointed out his cool project, Likely, in the Hacker News discussion.

Update 22 Aug 2016: The TensorFlow team is actually working on a JIT (I don’t know if they are using LLVM, but in my opinion it seems the most reasonable way to go). Their paper also contains a very important statement in its Future Work section, which I quote here:

“We also have a number of concrete directions to improve the performance of TensorFlow. One such direction is our initial work on a just-in-time compiler that can take a subgraph of a TensorFlow execution, perhaps with some runtime profiling information about the typical sizes and shapes of tensors, and can generate an optimized routine for this subgraph. This compiler will understand the semantics of perform a number of optimizations such as loop fusion, blocking and tiling for locality, specialization for particular shapes and sizes, etc.” – TensorFlow White Paper

Full code

from ctypes import CFUNCTYPE, c_int

from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2

import llvmlite.ir as ll
import llvmlite.binding as llvm

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

TYPE_TF_LLVM = {
    types_pb2.DT_INT32: ll.IntType(32),
}


class TFGraph(object):
    def __init__(self, filename="graph.pb", binary=False):
        self.graph_def = graph_pb2.GraphDef()
        # Open the requested file (not a hard-coded name)
        with open(filename, "rb") as f:
            if binary:
                self.graph_def.ParseFromString(f.read())
            else:
                text_format.Merge(f.read(), self.graph_def)

    def get_node(self, name):
        for node in self.graph_def.node:
            if node.name == name:
                return node
        raise KeyError("Node [{}] not found !".format(name))


def build_graph(ir_builder, graph, node):
    if node.op == "Add":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.add(left_op, right_op)

    if node.op == "Sub":
        left_op_node = graph.get_node(node.input[0])
        right_op_node = graph.get_node(node.input[1])
        left_op = build_graph(ir_builder, graph, left_op_node)
        right_op = build_graph(ir_builder, graph, right_op_node)
        return ir_builder.sub(left_op, right_op)

    if node.op == "Placeholder":
        function_args = ir_builder.function.args
        for arg in function_args:
            if arg.name == node.name:
                return arg
        raise RuntimeError("Input [{}] not found !".format(node.name))

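    if node.op == "Const":
        llvm_const_type = TYPE_TF_LLVM[node.attr["dtype"].type]
        const_value = node.attr["value"].tensor.int_val[0]
        llvm_const_value = llvm_const_type(const_value)
        return llvm_const_value

    # Any other TensorFlow op is outside the scope of this toy JIT
    raise NotImplementedError("Op [{}] not supported !".format(node.op))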


def create_engine(module):
    features = llvm.get_host_cpu_features().flatten()
    llvm_module = llvm.parse_assembly(str(module))
    target = llvm.Target.from_default_triple()
    target_machine = target.create_target_machine(opt=3, features=features)
    engine = llvm.create_mcjit_compiler(llvm_module, target_machine)
    engine.finalize_object()
    print(target_machine.emit_assembly(llvm_module))
    return engine


def run_main():
    graph = TFGraph("graph.pb", False)
    input_node = graph.get_node("input")
    output_node = graph.get_node("output")

    input_type = TYPE_TF_LLVM[input_node.attr["dtype"].type]
    output_type = TYPE_TF_LLVM[output_node.attr["T"].type]

    module = ll.Module()
    func_type = ll.FunctionType(output_type, [input_type])
    func = ll.Function(module, func_type, name='tensorflow_graph')
    func.args[0].name = 'input'

    bb_entry = func.append_basic_block('entry')
    ir_builder = ll.IRBuilder(bb_entry)

    ret = build_graph(ir_builder, graph, output_node)
    ir_builder.ret(ret)

    with open("output.ir", "w") as f:
        f.write(str(module))

    engine = create_engine(module)

    func_ptr = engine.get_function_address("tensorflow_graph")
    cfunc = CFUNCTYPE(c_int, c_int)(func_ptr)
    ret = cfunc(10)

    print("Execution output: {}".format(ret))


if __name__ == "__main__":
    run_main()

 

Genetic Programming and a LLVM JIT for restricted Python AST expressions

A small intro on the rationale

So, I’m working on a Symbolic Regression Machine written in C/C++ called Shine, which is intended to be a JIT for Genetic Programming libraries (like Pyevolve, for instance). The main rationale behind Shine is this: today there is a lot of research on speeding up Genetic Programming using GPUs (the GPU fever!) or other special hardware. However, there are not many papers about optimizing GP using state-of-the-art compiler optimizations like the ones we have in clang, gcc, etc.

The “hot spot” in Genetic Programming today, the component that consumes most of the CPU resources, is the evaluation of each individual in order to calculate the fitness of its program tree. This evaluation is often executed for each set of parameters of the “training” set. Suppose you want to make a symbolic regression of a single expression like the Pythagorean theorem, using a linear space of parameters from 1.0 to 1000.0 with a step of 0.1. That gives you about 10,000 evaluations (9,991, to be exact) for each individual (program tree) of your population!
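To see why evaluation dominates, here is a minimal sketch of a fitness function over that parameter grid. The target (a hypotenuse with one leg fixed at 1.0) and the candidate individual are hypothetical examples. The point is that every fitness computation calls the individual once per sample.

```python
import math


def make_grid(start, stop, step):
    """Evenly spaced samples from start to stop (inclusive)."""
    n = int(round((stop - start) / step)) + 1
    return [start + i * step for i in range(n)]


class CountingIndividual(object):
    """Wraps a candidate program and counts how often it is evaluated."""

    def __init__(self, func):
        self.func = func
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return self.func(x)


def target(x):
    # Hypothetical target: hypotenuse of a right triangle with legs x and 1.0
    return math.hypot(x, 1.0)


def fitness(individual, target, samples):
    """Sum of absolute errors over the training samples (lower is better)."""
    return sum(abs(individual(x) - target(x)) for x in samples)


samples = make_grid(1.0, 1000.0, 0.1)
candidate = CountingIndividual(lambda x: x + 0.5)  # hypothetical individual

error = fitness(candidate, target, samples)
print(len(samples), candidate.calls)  # 9991 9991
```

With a population of P individuals, a single generation already costs P x 9,991 such calls. That is why compiling the individual itself to native code is attractive.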

What Shine does is described in the image below:

It takes an individual from the Genetic Programming engine and converts it to LLVM Intermediate Representation (LLVM assembly language). After that, it runs the LLVM transformation passes; this is where the true power of modern compilers enters the GP context. Finally, the LLVM JIT converts the optimized LLVM IR into native code for the specified target (x86, PowerPC, etc.).

You can see below the Shine architecture:

This architecture brings a lot of flexibility to Genetic Programming. For instance, you can write functions for your individuals in any language supported by LLVM, because what matters to Shine is the LLVM IR. You can mix code from C, C++, Ada, Fortran, D, etc. and use your functions as non-terminal nodes of your Genetic Programming trees.

Shine is still in early development. It looks like a simple idea, but I still have a lot of problems to solve, such as how to JIT the evaluation process itself instead of calling the JIT’ed trees from Python through ctypes bindings.

Doing Genetic Programming on the Python AST itself

During the development of Shine, an idea occurred to me: I could use a restricted Python Abstract Syntax Tree (AST) as the representation of individuals in a Genetic Programming engine. The main advantages are flexibility and the possibility of reusing a lot of existing machinery. Of course, a shared library written in C/C++ would be useful for many Genetic Programming engines that don’t use Python. However, my spare time to work on this is becoming rarer and rarer, so I started to rethink the approach using Python and the Python bindings for LLVM (LLVMPY). I discovered that it is pretty easy to JIT a restricted subset of the Python AST to native code using LLVM, and this is what this post is going to show.

JIT’ing a restricted Python AST

The most amazing parts of LLVM are obviously the number of transformation passes, the JIT and, of course, the ability to use the entire framework through a simple API (OK, not so simple sometimes). To simplify this example, I’m going to use an arbitrary restricted subset of the Python AST that supports only subtraction (-), addition (+), multiplication (*) and division (/).
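One way to enforce such a restricted subset before JIT’ing is to reject every other node type. This minimal sketch targets Python 3.8+, where integer literals parse as ast.Constant rather than the ast.Num used later in this post. The allowed-node list is an assumption that matches the four operators above plus integer literals and variable names; unary minus is deliberately excluded.

```python
import ast

# Operators of the restricted subset: -, +, * and /
ALLOWED_OPS = (ast.Sub, ast.Add, ast.Mult, ast.Div)
# Everything else a valid expression in the subset may contain
ALLOWED_NODES = (ast.Expression, ast.BinOp, ast.Name, ast.Load,
                 ast.Constant) + ALLOWED_OPS


def is_restricted(source):
    """True if `source` only uses integer literals, names and - + * /."""
    try:
        tree = ast.parse(source, mode="eval")
    except SyntaxError:
        return False
    for node in ast.walk(tree):
        if not isinstance(node, ALLOWED_NODES):
            return False          # e.g. Call, Pow, UnaryOp, Attribute
        if isinstance(node, ast.Constant) and type(node.value) is not int:
            return False          # floats, strings, booleans, None
    return True


print(is_restricted("(2+3*b+33*(10/2)+1+3/3+a)/2"))  # True
print(is_restricted("2**a"))                         # False
```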

To understand the Python AST, you can use the parser from the ast module, which converts source code into an AST:

>>> import ast
>>> astp = ast.parse("2*7")
>>> ast.dump(astp)
'Module(
    body=[Expr(
        value=BinOp(
            left=Num(n=2), op=Mult(), right=Num(n=7)
        )
    )]
)'

What the parser created is an Abstract Syntax Tree containing a BinOp (binary operation). Its left operand is the number 2, its right operand is the number 7, and its operation is multiplication (Mult). Very easy to understand. To create the LLVM IR, we are going to write a visitor that visits each node of the tree, by subclassing the NodeVisitor class from Python’s ast module. NodeVisitor walks the tree and, for each node, calls the method ‘visit_<NodeClass>’ if it exists. For example, when it visits a BinOp node, it calls the method ‘visit_BinOp’, passing the BinOp node itself as the parameter.
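The dispatch rule is easy to observe with a tiny visitor that only records which handler is called. This sketch targets Python 3.8+, where literals arrive as ast.Constant nodes, so it defines visit_Constant instead of the visit_Num used in this post. It calls generic_visit to keep walking into the children.

```python
import ast


class TraceVisitor(ast.NodeVisitor):
    """Records which visit_<NodeClass> handlers NodeVisitor dispatches to."""

    def __init__(self):
        self.trace = []

    def visit_BinOp(self, node):
        self.trace.append("BinOp:" + type(node.op).__name__)
        self.generic_visit(node)  # continue into left, op and right

    def visit_Constant(self, node):
        self.trace.append("Constant:" + repr(node.value))


tree = ast.parse("2*7", mode="eval")
v = TraceVisitor()
v.visit(tree)
print(v.trace)  # ['BinOp:Mult', 'Constant:2', 'Constant:7']
```

Nodes without a handler (here the Expression root and the Mult operator) fall back to generic_visit, which simply visits their children.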

The structure of the class for the JIT visitor will look like the code below:

# Import the ast and the llvm Python bindings
import ast
from llvm import *
from llvm.core import *
from llvm.ee import *
import llvm.passes as lp

class AstJit(ast.NodeVisitor):
    def __init__(self):
        pass

Now we need an initialization method that keeps the last state of the JIT visitor. We need it because we are going to JIT the content of the Python AST into a function, and the last instruction of that function must return the result of the last node visited by the JIT. The method also receives the LLVM Module object in which our function will be created, plus the list of parameter names, and it defines the closure type. For the sake of simplicity I’m not doing any typing: I just assume that all numbers in the expression are integers, so the closure type is the LLVM integer type.

    def __init__(self, module, parameters):
        self.last_state = None
        self.module = module
        # Parameters that will be created on the IR function
        self.parameters = parameters
        self.closure_type = Type.int()
        # An attribute to hold a link to the created function
        # so we can use it to JIT later
        self.func_obj = None
        self._create_builder()

    def _create_builder(self):
        # How many parameters of integer type
        params = [self.closure_type] * len(self.parameters)

        # The prototype of the function, returning a integer
        # and receiving the integer parameters
        ty_func = Type.function(self.closure_type, params)

        # Add the function to the module with the name 'func_ast_jit'
        self.func_obj = self.module.add_function(ty_func, 'func_ast_jit')

        # Create an argument in the function for each parameter specified
        for index, pname in enumerate(self.parameters):
            self.func_obj.args[index].name = pname

        # Create a basic block and the builder
        bb = self.func_obj.append_basic_block("entry")
        self.builder = Builder.new(bb)

Next, our visitor needs ‘visit_<NodeClass>’ methods for the BinOp, Num and Name nodes. We will also implement the method that creates the return instruction for the last state.

    # A 'Name' is a node produced in the AST when you
    # access a variable, like '2+x+y', 'x' and 'y' are
    # the two names created on the AST for the expression.
    def visit_Name(self, node):
        # This variable is what function argument ?
        index = self.parameters.index(node.id)
        self.last_state = self.func_obj.args[index]
        return self.last_state

    # Here we create a LLVM IR integer constant using the
    # Num node, on the expression '2+3' you'll have two
    # Num nodes, the Num(n=2) and the Num(n=3).
    def visit_Num(self, node):
        self.last_state = Constant.int(self.closure_type, node.n)
        return self.last_state
 
    # The visitor for the binary operation
    def visit_BinOp(self, node):
        # Get the operation, left and right arguments
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        op = node.op

        # Convert each operation (Sub, Add, Mult, Div) to their
        # LLVM IR integer instruction equivalent
        if isinstance(op, ast.Sub):
            op = self.builder.sub(lhs, rhs, 'sub_t')
        elif isinstance(op, ast.Add):
            op = self.builder.add(lhs, rhs, 'add_t')
        elif isinstance(op, ast.Mult):
            op = self.builder.mul(lhs, rhs, 'mul_t')
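        elif isinstance(op, ast.Div):
            # sdiv truncates toward zero (C semantics), which differs
            # from Python's floor division for negative operands
            op = self.builder.sdiv(lhs, rhs, 'sdiv_t')
        else:
            raise NotImplementedError(
                "Operator [{}] not supported".format(type(op).__name__))

        self.last_state = op
        return self.last_state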

    # Build the return (ret) statement with the last state
    def build_return(self):
        self.builder.ret(self.last_state)

And that is it: our visitor is ready to convert a Python AST into LLVM IR assembly language. To run it, we first create an LLVM module and an expression:

module = Module.new('ast_jit_module')
# Note that I'm using two variables 'a' and 'b'
expr = "(2+3*b+33*(10/2)+1+3/3+a)/2"
node = ast.parse(expr)
print(ast.dump(node))

Will output:

Module(body=[Expr(value=BinOp(left=BinOp(left=BinOp(left=BinOp(
left=BinOp(left=BinOp(left=Num(n=2), op=Add(), right=BinOp(
left=Num(n=3), op=Mult(), right=Name(id='b', ctx=Load()))), op=Add(),
right=BinOp(left=Num(n=33), op=Mult(), right=BinOp(left=Num(n=10),
op=Div(), right=Num(n=2)))), op=Add(), right=Num(n=1)), op=Add(),
right=BinOp(left=Num(n=3), op=Div(), right=Num(n=3))), op=Add(),
right=Name(id='a', ctx=Load())), op=Div(), right=Num(n=2)))])

Now we can finally run our visitor on the generated AST and check the LLVM IR output:

visitor = AstJit(module, ['a', 'b'])
visitor.visit(node)
visitor.build_return()
print(module)

Will output the LLVM IR:

; ModuleID = 'ast_jit_module'

define i32 @func_ast_jit(i32 %a, i32 %b) {
entry:
  %mul_t = mul i32 3, %b
  %add_t = add i32 2, %mul_t
  %add_t1 = add i32 %add_t, 165
  %add_t2 = add i32 %add_t1, 1
  %add_t3 = add i32 %add_t2, 1
  %add_t4 = add i32 %add_t3, %a
  %sdiv_t = sdiv i32 %add_t4, 2
  ret i32 %sdiv_t
}

Now the real fun begins: we want to run LLVM optimization passes over our code, at an optimization level equivalent to GCC’s -O2. To do that, we create a PassManagerBuilder and a PassManager. The PassManagerBuilder is the component that adds the passes to the PassManager. You can also manually add arbitrary transformations such as dead code elimination, function inlining, etc.:

pmb = lp.PassManagerBuilder.new()
# Optimization level
pmb.opt_level = 2

pm = lp.PassManager.new()
pmb.populate(pm)

# Run the passes into the module
pm.run(module)
print(module)

Will output:

; ModuleID = 'ast_jit_module'

define i32 @func_ast_jit(i32 %a, i32 %b) nounwind readnone {
entry:
  %mul_t = mul i32 %b, 3
  %add_t3 = add i32 %a, 169
  %add_t4 = add i32 %add_t3, %mul_t
  %sdiv_t = sdiv i32 %add_t4, 2
  ret i32 %sdiv_t
}

And here we have the optimized LLVM IR of the Python AST expression. The next step is to JIT that IR into native code and then execute it with some parameters:

ee = ExecutionEngine.new(module)
arg_a = GenericValue.int(Type.int(), 100)
arg_b = GenericValue.int(Type.int(), 42)

retval = ee.run_function(visitor.func_obj, [arg_a, arg_b])
print("Return: %d" % retval.as_int())

Will output:

Return: 197
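Before trusting JIT’ed results, it helps to have a reference to compare against. Here is a hypothetical pure-Python evaluator for the same restricted subset (Python 3.8+, using ast.Constant). It implements division with truncation toward zero to match LLVM’s sdiv, since Python 3’s / yields floats and // rounds toward negative infinity.

```python
import ast


def trunc_div(a, b):
    """Integer division truncating toward zero, like LLVM's sdiv."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q


class Evaluator(ast.NodeVisitor):
    """Evaluates the restricted +, -, *, / integer subset of Python."""

    def __init__(self, env):
        self.env = env  # variable name -> integer value

    def visit_Expression(self, node):
        return self.visit(node.body)

    def visit_Constant(self, node):
        return node.value

    def visit_Name(self, node):
        return self.env[node.id]

    def visit_BinOp(self, node):
        lhs, rhs = self.visit(node.left), self.visit(node.right)
        if isinstance(node.op, ast.Sub):
            return lhs - rhs
        if isinstance(node.op, ast.Add):
            return lhs + rhs
        if isinstance(node.op, ast.Mult):
            return lhs * rhs
        if isinstance(node.op, ast.Div):
            return trunc_div(lhs, rhs)
        raise NotImplementedError(type(node.op).__name__)

    def generic_visit(self, node):
        # Any node without a handler is outside the restricted subset
        raise NotImplementedError(type(node).__name__)


expr = "(2+3*b+33*(10/2)+1+3/3+a)/2"
tree = ast.parse(expr, mode="eval")
print(Evaluator({"a": 100, "b": 42}).visit(tree))  # 197
```

For a = 100 and b = 42, the evaluator agrees with the JIT’ed function. It also gives a cheap way to cross-check many random inputs.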

And that’s it: you have created an AST->LLVM IR converter, optimized the LLVM IR with the transformation passes and then converted it to native code using the LLVM execution engine. I hope you liked it =)

The future can be written in RPython now

Following the recent article arguing that PyPy is the future of Python, I must say: PyPy is not the future of Python, it is the present. The last time I tested it (PyPy-c 1.1.0), I used Pyevolve to optimize a simple Sphere function, and PyPy was at least 2x slower than Unladen Swallow Q2. At that time, however, PyPy was not able to JIT. Now, with the new release of PyPy and its JIT support, the scenario has changed.

PyPy has evolved a lot (actually, you can see this evolution here). Good work was done on the GC system, saving 8 bytes per allocated object when compared to CPython, which is very interesting for applications that make heavy use of object allocation (GP systems are a strong example of this, since when they are implemented in object-oriented languages, each syntax tree node is an object). Efforts are also being made to improve support for CPython extensions (written in C/C++). One of them is a little tricky: using RPyC to proxy the remote calls to CPython over TCP. The other seems far more effective: the creation of the CPyExt subsystem. With CPyExt, all you need is for the CPython API functions your extension uses to be implemented in CPyExt. A lot of people are working on this right now, and you can help too; it is a long road to good API coverage, but when you think about the advantages, the road looks short.

To benchmark CPython, Jython, CPython+Psyco, Unladen Swallow and PyPy, I used the optimization of the Rastrigin function (an example implementation is here, in Example 7 of Pyevolve 0.6rc1):

f(x) = 10n + \sum_{i=1}^{n}\left(x_{i}^{2} - 10\cos(2\pi x_{i})\right)

Due to its large search space and number of local minima, the Rastrigin function is often used to measure the performance of Genetic Algorithms. It has a global minimum at x = 0, where f(x) = 0. To increase the search space and the required resources, I used 40 variables (n = 40) and 10k generations.
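For reference, the formula above translates directly into Python. This is a minimal sketch, not the Pyevolve Example 7 code; the function name rastrigin is hypothetical and the candidate solution is assumed to be a plain list of floats.

```python
import math


def rastrigin(xs):
    """Rastrigin function: f(x) = 10n + sum_i (x_i^2 - 10 cos(2 pi x_i))."""
    n = len(xs)
    return 10 * n + sum(x * x - 10 * math.cos(2 * math.pi * x) for x in xs)


# Global minimum at x = 0, here with n = 40 as in the benchmark
print(rastrigin([0.0] * 40))  # 0.0
# Far from the minimum: each variable contributes 0.25 + 10 on top of 10n
print(rastrigin([0.5] * 40))  # 810.0
```

A GA minimizing this function has to escape the many cosine-induced local minima around every integer point, which is why n = 40 already makes the run expensive.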

Here is the information about the versions used in this benchmark:

No warmup was performed for the JVM or for PyPy. The PyPy translator was run with the “-Ojit” option in order to get the JIT version of the Python interpreter. The JVM was executed in server mode: I tested both client and server modes for Sun JVM and IcedTea6, and the best results came from server mode on Sun JVM. However, when comparing the client mode of IcedTea6 with the client mode of Sun JVM, IcedTea6 gave the better results (the same as IcedTea6 in server mode). Unladen Swallow was compiled following the project wiki instructions for building an optimized binary.

The machine used was an Intel(R) Core(TM) 2 Duo E4500 (2×2.20 GHz) with 2 GB of RAM.

The benchmark results (wall-clock time, in seconds) for each setup, taking the best of 3 sequential runs:

As you can see, PyPy with JIT achieved a 2.57x speedup over CPython 2.6.5 and was 2.0x faster than the current Unladen Swallow trunk.

PyPy is not only the future of Python; it is becoming the present right now. PyPy will bring us not only an implementation of Python in Python (which in itself is the valuable result of great efforts), but will also bring the performance back (many doubted this at the beginning, wondering how an implementation of Python in Python could possibly be faster than an implementation in C; this is where the translation and JIT magic comes in). Once the Python interpreter can be written entirely in a high-level language (actually, almost the same language, which is really weird), the Python community can focus on improving the language itself instead of spending time dealing with the complexity of lower-level languages. Isn't that the great point of these efforts?

By the way, note that PyPy isn't only a translator for the Python interpreter written in RPython; it is a translator for RPython itself, which means that PyPy isn't only the future of Python but probably the future of many interpreters.

A method for JIT’ing algorithms and data structures with LLVM


Hello folks, I always post about Python and EvoComp (Pyevolve), but this time it’s about C, LLVM, search algorithms and data structures. This post describes the efforts to implement an idea: to JIT (verb) algorithms and the data structures used by them, together.

AVL Tree Intro

Here is a short intro to AVL Trees from Wikipedia:

In computer science, an AVL tree is a self-balancing binary search tree, and it is the first such data structure to be invented. In an AVL tree, the heights of the two child subtrees of any node differ by at most one; therefore, it is also said to be height-balanced. Lookup, insertion, and deletion all take O(log n) time in both the average and worst cases, where n is the number of nodes in the tree prior to the operation. Insertions and deletions may require the tree to be rebalanced by one or more tree rotations.
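To make the quoted definition concrete, here is a minimal Python sketch of AVL insertion with the four rotation cases (left-left, left-right, right-right, right-left) plus an iterative lookup. It is a textbook-style illustration, not the C code this post works with; the names Node, insert, contains and the helpers are hypothetical.

```python
class Node:
    """One AVL node: a key, two child links and the cached subtree height."""
    __slots__ = ("key", "left", "right", "height")

    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.height = 1


def height(node):
    return node.height if node is not None else 0


def update_height(node):
    node.height = 1 + max(height(node.left), height(node.right))


def rotate_right(y):
    # y's left child x becomes the new subtree root
    x = y.left
    y.left = x.right
    x.right = y
    update_height(y)
    update_height(x)
    return x


def rotate_left(x):
    # x's right child y becomes the new subtree root
    y = x.right
    x.right = y.left
    y.left = x
    update_height(x)
    update_height(y)
    return y


def insert(node, key):
    """Insert key into the subtree rooted at node; return the new subtree root."""
    if node is None:
        return Node(key)
    if key < node.key:
        node.left = insert(node.left, key)
    elif key > node.key:
        node.right = insert(node.right, key)
    else:
        return node  # duplicate key: nothing to do
    update_height(node)
    balance = height(node.left) - height(node.right)
    if balance > 1:                  # left subtree too tall
        if key > node.left.key:      # left-right case: rotate the child first
            node.left = rotate_left(node.left)
        return rotate_right(node)
    if balance < -1:                 # right subtree too tall
        if key < node.right.key:     # right-left case: rotate the child first
            node.right = rotate_right(node.right)
        return rotate_left(node)
    return node


def contains(node, key):
    """Iterative lookup: O(log n) steps because the tree stays height-balanced."""
    while node is not None:          # "do we really have a node here?" check
        if key == node.key:
            return True
        node = node.left if key < node.key else node.right
    return False


root = None
for k in [10, 20, 30, 40, 50, 25]:
    root = insert(root, k)
print(root.key, contains(root, 25), contains(root, 35))  # 30 True False
```

The last insertion (25) triggers a right-left double rotation at the root, which is why 30 ends up on top. Note that every step of contains tests node is None and loads a child link from the current node; that per-step work is the kind of overhead described in the next section.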

The problem and the idea

When we have a data structure and algorithms to handle it (insert, remove and lookup), the native code of our algorithm is usually full of overhead. For example, in an AVL Tree (a balanced binary tree), the overhead appears in checking whether we really have a left or right node while traversing the nodes during lookups, accessing nodes inside nodes, etc. This overhead produces unnecessary assembly instructions even after the compiler optimizes the code, and it directly impacts the performance of our algorithm (this traditional approach, of course, gives us a very flexible structure whose complexity (not Big-O) is easy to handle, but we pay for it with a loss of performance).
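A toy Python illustration of the general idea of specializing code for a known structure (not the LLVM-based method this post sets out to develop): when the key set is fixed, a generator can bake the tree shape into nested comparisons, so the generated lookup has no node objects, no child-link loads and no "is there a left/right node?" checks. Here Python's built-in exec stands in for the JIT, build_source and compile_lookup are hypothetical names, keys are assumed to be integers, and the balanced shape is obtained by taking medians of the sorted keys (a simplification that does not replay AVL insertions).

```python
def build_source(keys):
    """Generate Python source for lookup(key) specialized to a fixed key set.

    The tree shape is decided here, at generation time, by picking the median
    of each sorted range as the subtree root. The emitted function contains
    only comparisons against constants and constant returns.
    """
    keys = sorted(set(keys))
    lines = ["def lookup(key):"]

    def emit(lo, hi, depth):
        pad = "    " * depth
        if lo > hi:                   # empty subtree: the key is absent
            lines.append(pad + "return False")
            return
        mid = (lo + hi) // 2          # median key becomes this subtree's root
        lines.append(f"{pad}if key == {keys[mid]!r}:")
        lines.append(f"{pad}    return True")
        lines.append(f"{pad}if key < {keys[mid]!r}:")
        emit(lo, mid - 1, depth + 1)  # left subtree, inside the if-block
        emit(mid + 1, hi, depth)      # right subtree, reached when key > root

    emit(0, len(keys) - 1, 1)
    return "\n".join(lines)


def compile_lookup(keys):
    """Compile the specialized source; exec() stands in for a real JIT here."""
    namespace = {}
    exec(build_source(keys), namespace)
    return namespace["lookup"]


lookup = compile_lookup([10, 20, 30, 40, 50, 25])
print(lookup(25), lookup(35))  # True False
print(build_source([1, 2, 3]))  # inspect the generated nested comparisons
```

Every branch of the generated code ends in a return, so the statements after an if key < ... block are only reached when the key is greater than that subtree's root. The price is flexibility: any insertion or removal means regenerating and recompiling the function.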
