
Tutorial on using LLVM to JIT PyTorch fx graphs to native code (x86/arm/risc-v/wasm) (Part I – Scalars)

In 2009 I started playing with LLVM for some projects (a data structure JIT, genetic programming, a JIT for TensorFlow graphs, etc.), and in these projects I realized how powerful LLVM's design was at the time (and still is): an elegant IR (intermediate representation) with a user-facing API, and modularized front-ends and back-ends with plenty of transformation and optimization passes. Nowadays, LLVM is the main engine behind many compilers and JIT compilation, and where much of the modern development in compilers is happening.

On a related note, PyTorch is doing an amazing job of exposing more of the torch tracing system and its IR and graphs, not to mention their work on recent fusers and TorchDynamo. In this context, I was doing a small test to re-implement Shine, but using ATen ops for tensors, and realized that there were not many educational tutorials on how to use LLVM to JIT PyTorch graphs. So this is a quick series (if time permits there will be more posts following) on how to use LLVM (through its Python bindings) to go from PyTorch graphs (as traced by torch.fx) to LLVM IR and native code.

Detour – PyTorch NNC (Neural Net Compiler)

PyTorch itself also has a compiler that uses LLVM to generate native code for the subgraphs that the fuser identifies. It is called NNC (Neural Net Compiler) or Tensor Expressions (TE); you can read more about them in the C++ API tutorial. One thing to note, though, is that the default binaries you get from PyTorch do not come linked against the LLVM libraries, so you need to compile PyTorch yourself with the LLVM backend enabled, otherwise it won't use LLVM to JIT compile/optimize the subgraphs. Let's take a look at it first before starting our tutorial.

Enabling LLVM backend support in PyTorch

The first thing you need to do to enable the LLVM backend in NNC is to build PyTorch with LLVM support enabled (and point it to an LLVM installation). In my case I used LLVM 14.0.6 from Homebrew on macOS (on Apple Silicon it is a little trickier to compile PyTorch + LLVM, as you will need to manually edit some CMake files), but you can get LLVM basically anywhere. To compile PyTorch with LLVM enabled, while disabling what we don't need for this tutorial (and saving a lot of compilation time), you can use the following:

USE_KINETO=0 BUILD_CAFFE2=0 USE_DISTRIBUTED=0 USE_NCCL=0 BUILD_TEST=0 USE_XNNPACK=0 USE_FBGEMM=0 USE_QNNPACK=0 USE_MKLDNN=0 USE_MIOPEN=0 USE_NNPACK=0 BUILD_CAFFE2_OPS=0 USE_TENSORPIPE=0 CC=clang CXX=clang++ USE_LLVM=[the path to your LLVM] python setup.py develop

After that, you should get True for the following call to check if LLVM is enabled:

>>> torch._C._llvm_enabled()
True

Enabling NNC fuser

Now we can try a small "fuse-opportunistic" (you can see how bad I am at creating new technical terms) function, to check the LLVM IR that NNC generates using LLVM:

import torch

# Our example function
def our_function(x):
    a = torch.mul(x, x)
    b = torch.sin(a)
    c = torch.cos(b)
    d = torch.mul(c, c)
    f = torch.mul(d, d)
    return d + f

# Setting flags to enable NNC (some are not strictly necessary)
torch._C._jit_set_texpr_fuser_enabled(True)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_set_te_must_use_llvm_cpu(True)

# Script the function; note that this is not yet
# where the LLVM IR will be emitted
scripted = torch.jit.script(our_function)

# Execute the JIT compiled function
arg = torch.randn(1)
scripted(arg)

And that's it: we have a very simple function and we execute it with a random scalar. PyTorch will now use LLVM to JIT compile the function to native code. Now, if we want to see the LLVM IR output, we need to run:

PYTORCH_JIT_LOG_LEVEL=">>llvm_codegen" python example.py

The "llvm_codegen" flag enables dumping of the LLVM IR and debugging information. However, there is a catch here: if you execute the command above, you will see that nothing is shown. Why is that happening? PyTorch uses profiling to specialize the graphs, so if we execute the function just once, it will only capture profiling information (especially about types and shapes) and then, in the next run, generate the JIT compiled function using LLVM. In the first run, some profiling nodes are added into the graph; later, after it specializes the graph, it adds guards to make sure it can always execute the original function as a fallback, or re-specialize the function for different shapes/types. To correct the code, you just have to do:

# We get the number of profiled runs and add one to it
num_profiled = torch._C._jit_get_num_profiled_runs() + 1
for _ in range(num_profiled):
    arg = torch.randn(1)
    scripted(arg)

Now if you execute it again with the debugging flag enabled, you will see a lot of debugging output. The important part for us here is the LLVM IR after optimization (the optimization is done by LLVM as well, with many passes such as DCE (Dead Code Elimination), inlining, and instruction combining, among many other optimization techniques that are also applied to regularly compiled code). Let's see the emitted LLVM IR:

; ModuleID = 'pytorch'
source_filename = "pytorch"
target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-apple-darwin20.6.0"
; Function Attrs: mustprogress nofree nosync nounwind willreturn
define i32 @fused_mul_sin_cos_mul_mul_add(i8** nocapture readonly %0) local_unnamed_addr #0 {
wrapBB:
  %1 = bitcast i8** %0 to float**
  %2 = load float*, float** %1, align 8
  %3 = getelementptr i8*, i8** %0, i64 1
  %4 = bitcast i8** %3 to float**
  %5 = load float*, float** %4, align 8
  %.val = load float, float* %2, align 4
  call void @llvm.experimental.noalias.scope.decl(metadata !0)
  %6 = fmul float %.val, %.val
  %7 = tail call float @sinf(float %6) #3
  %8 = tail call float @cosf(float %7) #3
  %9 = fmul float %8, %8
  %10 = fmul float %9, %9
  %11 = fadd float %9, %10
  store float %11, float* %5, align 4, !alias.scope !0
  ret i32 0
}

In the IR above I omitted some LLVM metadata to keep it short. As you can see, the code performs the same operations seen in the function, but replaces the ops with IR instructions such as fmul (float multiplication), etc. Also, note the name of the generated function, fused_mul_sin_cos_mul_mul_add: it is basically saying that it fuses mul/sin/cos/mul/mul/add together, which is exactly the order of operations we have in our function.

Generated native code and static compilation with LLVM LLC

There is also a target triple, which tells LLVM what system it has to emit native code for; in this case I'm on an x86 architecture. This LLVM IR was also already optimized with passes similar to those you see when you use clang on regular C/C++ code, for example. The powerful thing about having this LLVM IR is that LLVM can now emit assembly code for the target architecture, which can be seen below:

	.section	__TEXT,__text,regular,pure_instructions
	.build_version macos, 11, 0
	.globl	_fused_mul_sin_cos_mul_mul_add
	.p2align	4, 0x90
_fused_mul_sin_cos_mul_mul_add:
	pushq	%rbx
	movq	(%rdi), %rax
	movq	8(%rdi), %rbx
	vmovss	(%rax), %xmm0
	vmulss	%xmm0, %xmm0, %xmm0
	movabsq	$_sinf, %rax
	callq	*%rax
	movabsq	$_cosf, %rax
	callq	*%rax
	vmulss	%xmm0, %xmm0, %xmm0
	vfmadd213ss	%xmm0, %xmm0, %xmm0
	vmovss	%xmm0, (%rbx)
	xorl	%eax, %eax
	popq	%rbx
	retq

The code above is the assembly generated for the x86_64 target triple, and this is what will ultimately get executed. You can also do some interesting things with the LLVM IR: you can take the IR, save it into a file, and compile it with the LLVM static compiler (llc) into an object that can be linked together with any other application:

llc -filetype=obj code.ll -o code.o

In the example above we compiled it into an object file called code.o that can be linked together with other applications. If you list the symbols, you will see the function there:

~# llvm-nm code.o

                 U _cosf
0000000000000000 T _fused_mul_sin_cos_mul_mul_add
                 U _sinf

You can see that we have the fused function with the code "T" (it can be found in the text section of the object, so it is provided by the object itself), and _cosf and _sinf with the code "U" (undefined), as you need to link against another library that provides these symbols used by the fused function.
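To make this concrete, here is a hedged sketch of calling that fused function from Python via ctypes. It assumes code.o was first linked into a shared library, e.g. clang -shared code.o -o libfused.so -lm (the library name is just an example), and it follows the i8** signature from the IR above, where the argument is an array of buffer pointers [input, output]:

import ctypes

# Assumption: code.o was linked into a shared library beforehand, e.g.:
#   clang -shared code.o -o libfused.so -lm
lib = ctypes.CDLL("./libfused.so")

inp = ctypes.c_float(0.5)
out = ctypes.c_float(0.0)
# The fused kernel receives an i8** pointing to [input buffer, output buffer]
bufs = (ctypes.c_void_p * 2)(
    ctypes.cast(ctypes.byref(inp), ctypes.c_void_p),
    ctypes.cast(ctypes.byref(out), ctypes.c_void_p),
)
lib.fused_mul_sin_cos_mul_mul_add(bufs)
print(out.value)  # d + d*d, where d = cos(sin(0.5 * 0.5))^2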

We will do some other fun things later with the LLVM IR, but one thing you should realize is that once you have LLVM IR, it opens the door to really interesting and useful things. Since we have talked enough for a quick detour on the PyTorch NNC compiler, let's jump into our own implementation of a simple JIT for float scalars only, which is the main goal of this tutorial.

Building our own scalar JIT compiler with LLVM

Simplifying assumptions

The reason I decided to start with simple scalars is that it simplifies the LLVM IR generation a lot, as we don't have to deal with tensor shapes. We also assume that the scalars are all floats, just like when we have a closure in Genetic Programming graphs, and this also keeps our code minimal and understandable, as we don't have to support operations on multiple types. A full-fledged JIT compiler is as complex as PyTorch's NNC, but the goal here is to show the main intuition of how we can build such compilers with LLVM.

Unlike PyTorch's NNC, which creates the LLVM IR in C++, we will create the IR in Python using the Python bindings called llvmlite (also used by the Numba project).
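If you have never used llvmlite before, here is a minimal, self-contained sketch (the names example and add2 are arbitrary) of its IR-building API, which is essentially all we will need: a module, a function, a basic block and an IRBuilder:

from llvmlite import ir

module = ir.Module(name="example")
double = ir.DoubleType()
# A function taking two doubles and returning a double
fnty = ir.FunctionType(double, (double, double))
func = ir.Function(module, fnty, name="add2")
block = func.append_basic_block(name="entry")
builder = ir.IRBuilder(block)
a, b = func.args
builder.ret(builder.fadd(a, b, name="sum"))
print(module)  # prints the textual LLVM IR of the module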

Getting PyTorch’s graph with torch.fx

To get PyTorch graphs, we will use the new torch.fx module and its symbolic tracer, which does a symbolic execution of the code and produces a graph for us. We will then traverse this graph and emit LLVM IR instructions for it.
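Just to illustrate what "traversing the graph" means, here is a small sketch (the function double_it is only an illustrative example) that traces a toy function and walks the recorded nodes:

import torch.fx

def double_it(x):
    return x * 2.0

gm = torch.fx.symbolic_trace(double_it)
for node in gm.graph.nodes:
    # Each node carries an opcode, a name, a target and its arguments
    print(node.op, node.name, node.target, node.args)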

Let's first define a function that is not very useful, but that will leave plenty of room for optimization techniques to improve it:

def fn(x):
    a = x + 2.0
    b = a + 2.0
    b += b
    c = b - a
    e = a * 3
    e = e / c
    d = b + c + a
    return a

As you can see, it is a very boring function, but the code already exercises multiple operations: multiplication, subtraction, addition and division. Let's now see the graph that the symbolic tracer from torch.fx produced:

>>> gm = torch.fx.symbolic_trace(fn)
>>> type(gm)
<class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>
>>> type(gm.graph)
<class 'torch.fx.graph.Graph'>
>>> gm.graph.print_tabular()

opcode         name     target                       args            kwargs
-------------  -------  ---------------------------  --------------  --------
placeholder    x        x                            ()              {}
call_function  add      <built-in function add>      (x, 2.0)        {}
call_function  add_1    <built-in function add>      (add, 2.0)      {}
call_function  add_2    <built-in function add>      (add_1, add_1)  {}
call_function  sub      <built-in function sub>      (add_2, add)    {}
call_function  mul      <built-in function mul>      (add, 3)        {}
call_function  truediv  <built-in function truediv>  (mul, sub)      {}
call_function  add_3    <built-in function add>      (add_2, sub)    {}
call_function  add_4    <built-in function add>      (add_3, add)    {}
output         output   output                       (add,)          {}

As you can see above, the symbolic tracer returns a graph module with the graph generated from the symbolic execution of the function. We are also showing a tabular version of the graph, where we have the different types of opcodes that we will have to support, together with different function targets and constants. The first node type (opcode) we can see here is the placeholder: this opcode represents a function input, i.e. a parameter of the function; in this case we have "x", which is the input of the traced function we provided.

The next opcode we can see is call_function, which summons Cthulhu every time the interpreter sees it. Just kidding, it really just calls a function. At the end we have the output opcode, which says that the output of our function is the node called add. Note that this refers not to the add function but to the node with that name (the second row in the tabular representation shown above). You can already see that there is a lot of dead code in the function: assuming no external effects from the other ops, we only need the result of the add node, which doesn't require most of the computation below it, so we will see LLVM optimize that away for us later.
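As a quick aside, torch.fx itself can also prune unused nodes at the graph level, independently of LLVM, via Graph.eliminate_dead_code(); a short sketch reusing the fn defined above:

gm = torch.fx.symbolic_trace(fn)
gm.graph.eliminate_dead_code()  # drops nodes whose results are never used
gm.recompile()
gm.graph.print_tabular()  # only the placeholder, the first add and output remain

Note that for the rest of the tutorial we deliberately keep the dead code in place, so we can watch LLVM remove it instead.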

Emitting IR for the function definition and main block

We will create a class that receives the torch.fx graph and then generates the LLVM IR. The first step is to create a function called forward() (trying to stay as close as possible to the PyTorch conventions) that will contain all the instructions. To create a function in LLVM IR we need to define its return type and argument types as well, i.e. the complete function declaration. Here is the initial part of our LLVMCompiler class:

import operator as ops

from llvmlite import ir


class LLVMCompiler:
    def __init__(self, gm):
        self.graph = gm.graph
        self.vars = {}
        self.closure_type = ir.DoubleType()
        self.llvm_module = ir.Module()
        self.ir_builder = self.build_forward()

        self.fn_map = {
            ops.add: self.ir_builder.fadd,
            ops.sub: self.ir_builder.fsub,
            ops.truediv: self.ir_builder.fdiv,
            ops.mul: self.ir_builder.fmul,
        }


    def build_forward(self):
        placeholders = [node for node in self.graph.nodes
                        if node.op == "placeholder"]
        arg_types = [self.closure_type for _ in placeholders]
        func_forward_type = ir.FunctionType(self.closure_type,
                                            arg_types)
        func_forward = ir.Function(self.llvm_module,
                                   func_forward_type,
                                   name="forward")

        for arg, placeholder in zip(func_forward.args, placeholders):
            arg.name = placeholder.name
            self.vars[arg.name] = arg

        block = func_forward.append_basic_block(name="entry")
        return ir.IRBuilder(block)

Let’s unpack this code:

  • In the constructor we instantiate a few useful things, such as a dict called self.vars that will hold a map from a variable name to the variable itself (in case it is referenced later);
  • We also create an LLVM Module. This module is the top-level container of all other LLVM IR objects;
  • We also create a closure type, which is a DoubleType; this is the type that will be used for all arguments, variables and return types;
  • We also have a mapping from the Python operators +, -, / and * to their LLVM IR instruction counterparts for float types; we need that because we will be converting these Python operator calls to LLVM IR instruction calls;
  • Next we call the build_forward() method:
    • In this method we get all the placeholder nodes of the torch.fx graph and create a function called forward() that will hold our instructions;
    • We then proceed by adding the function arguments to the self.vars dict, so they can be referenced later (these are the inputs, the placeholders);
    • We finally create a “basic block”, which in LLVM IR represents a single-entry single-exit section of code. After that we return an IRBuilder that will be used to add the function instructions; note that this builder is tied to the entry block we added to the function, so any instruction built with this IR builder will be connected to this basic block of our function;
    • In summary, we have: LLVM Module -> Function -> Entry block

Emitting IR instructions for the torch.fx nodes

Now that we have our LLVM Module with the function declaration and an entry block, we can traverse the torch.fx graph and emit the IR instructions for the body of the function. We need to support two more opcodes: call_function and output. Let's now show the complete code for the LLVMCompiler:

import operator as ops

from llvmlite import ir
from torch.fx.node import Node


class LLVMCompiler:
    def __init__(self, gm):
        self.graph = gm.graph
        self.vars = {}
        self.closure_type = ir.DoubleType()
        self.llvm_module = ir.Module()
        self.ir_builder = self.build_forward()

        self.fn_map = {
            ops.add: self.ir_builder.fadd,
            ops.sub: self.ir_builder.fsub,
            ops.truediv: self.ir_builder.fdiv,
            ops.mul: self.ir_builder.fmul,
        }


    def build_forward(self):
        placeholders = [node for node in self.graph.nodes
                        if node.op == "placeholder"]
        arg_types = [self.closure_type for _ in placeholders]
        func_forward_type = ir.FunctionType(self.closure_type,
                                            arg_types)
        func_forward = ir.Function(self.llvm_module,
                                   func_forward_type,
                                   name="forward")

        for arg, placeholder in zip(func_forward.args, placeholders):
            arg.name = placeholder.name
            self.vars[arg.name] = arg

        block = func_forward.append_basic_block(name="entry")
        return ir.IRBuilder(block)

    def jit(self):
        for node in self.graph.nodes:
            self.emit_node(node)

    def emit_node(self, node):
        try:
            emit_fn = getattr(self, f"emit_{node.op}")
        except AttributeError:
            raise RuntimeError(f"Unknown node type: {node}, {node.op}.")

        emit_fn(node)

    def emit_placeholder(self, node):
        # Placeholders (function inputs) were already registered
        # in build_forward(), so there is nothing to emit here
        return

    def emit_call_function(self, node):
        args = []
        for arg in node.args:
            if isinstance(arg, Node):
                args.append(self.vars[arg.name])
            else:
                constant_arg = self.closure_type(arg)
                args.append(constant_arg)

        fn = self.fn_map[node.target]
        call_inst = fn(args[0], args[1], name=node.name)
        self.vars[node.name] = call_inst
        return

    def emit_output(self, node):
        return_arg = node.args[0].name
        return_var = self.vars[return_arg]
        self.ir_builder.ret(return_var)
        return

We have now added a few new methods; let's go through each one of them:

jit(): this method will traverse the torch.fx graph and will call the emit_node() method for each node;

emit_node(): here we dynamically find the method named emit_[opcode] and call it with the node; it is just a simple dynamic dispatch to keep the processing of each node type separate;

emit_call_function(): this method emits the LLVM IR for the opcode that calls a function; when you have something like 2 + 2, we are calling the Python add() built-in on two constants. This method first checks whether each argument of the node is a node itself or a constant: if it is a constant, it emits an LLVM IR constant with the closure type (DoubleType); if the argument is a node, it looks it up in the self.vars dictionary. After that, we use the mapping from Python built-in operators to LLVM IR instructions to get the right instruction, and we add it to the IR by calling the builder with the created arguments. The last thing we do is add the result of this instruction to self.vars, as it can be referenced later by following operators.

emit_output(): here we are dealing with the output opcode, which tells us we should return a value. What we do is use the LLVM IR “ret” instruction to return the argument, and that is it.

That's pretty much it: with this code we can already emit LLVM IR for the function we showed earlier. Even within the simplifying assumptions we made, it is impressive that we can build an IR emitter in a few lines of Python code, but that's the power of torch.fx and LLVM.
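As an illustration of how this could be extended beyond the four arithmetic operators, here is a hypothetical sketch (not part of the compiler above) that adds torch.sin support by declaring the llvm.sin.f64 intrinsic; llvmlite's declare_intrinsic returns the existing declaration if it was already added:

import torch

class LLVMCompilerWithSin(LLVMCompiler):
    def emit_call_function(self, node):
        if node.target is torch.sin:
            arg = self.vars[node.args[0].name]
            # Declare (or fetch) the llvm.sin intrinsic for doubles
            sin_fn = self.llvm_module.declare_intrinsic(
                "llvm.sin", [self.closure_type])
            self.vars[node.name] = self.ir_builder.call(
                sin_fn, [arg], name=node.name)
            return
        super().emit_call_function(node)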

Viewing the LLVM IR

Let's now run our code on the function we defined earlier and show the generated LLVM IR. For that we will use the ModuleRef from llvmlite:

import torch.fx
import llvmlite.binding as llvm

# Initialization calls required by LLVM
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

# Symbolic tracing and compilation
gm = torch.fx.symbolic_trace(fn)

# Run our compiler on the torch.fx graph
compiler = LLVMCompiler(gm)
compiler.jit()

# llvmlite has two module types: Module and ModuleRef; here we use the
# ModuleRef as it will allow us to generate native code
modref = llvm.parse_assembly(str(compiler.llvm_module))
print(modref)

The output of this will be:

; ModuleID = '<string>'
source_filename = "<string>"
target triple = "unknown-unknown-unknown"

define double @forward(double %x) {
entry:
  %add = fadd double %x, 2.000000e+00
  %add_1 = fadd double %add, 2.000000e+00
  %add_2 = fadd double %add_1, %add_1
  %sub = fsub double %add_2, %add
  %mul = fmul double %add, 3.000000e+00
  %truediv = fdiv double %mul, %sub
  %add_3 = fadd double %add_2, %sub
  %add_4 = fadd double %add_3, %add
  ret double %add
}

This shows exactly the function that we used to emit the LLVM IR, but written with LLVM IR instructions and format. Note that this is the IR before optimization; we will now see the effect of optimization on this function.

Optimizing the LLVM IR

We will now run the LLVM optimization passes on top of our generated IR, which should shrink our function considerably, given that it contains a lot of dead code. For that we will use optimization level 3 (similar to the passes used when you run clang -O3):

pass_manager = llvm.PassManagerBuilder() 
pass_manager.opt_level = 3 

pass_module = llvm.ModulePassManager()
pass_manager.populate(pass_module)

pass_module.run(modref)
print(modref)

This will optimize the IR and show the following result after the optimization passes:

; ModuleID = '<string>'
source_filename = "<string>"
target triple = "unknown-unknown-unknown"

; Function Attrs: norecurse nounwind readnone
define double @forward(double %x) local_unnamed_addr #0 {
entry:
  %add = fadd double %x, 2.000000e+00
  ret double %add
}

Note that it removed all the dead code using the dead code elimination pass, and we now have just the sum of the argument with 2, which is what the function returns. This was a very simple case, but LLVM has an enormous set of optimization and analysis passes that deal with much more complex code.
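If you want finer-grained control than the whole -O3 pipeline, llvmlite also exposes individual passes on the pass managers (a sketch using llvmlite's legacy pass-manager bindings; the exact set of add_*_pass methods depends on your llvmlite version):

import llvmlite.binding as llvm

pm = llvm.ModulePassManager()
pm.add_dead_code_elimination_pass()  # remove instructions with unused results
pm.add_cfg_simplification_pass()     # merge and clean up basic blocks
pm.run(modref)
print(modref)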

Executing native code from Python interpreter

Now that we have the LLVM IR module, we can use LLVM MCJIT (Machine Code JIT) to generate native code and hand us a pointer to the function, for which we will create a Python binding via ctypes and then call it:

from ctypes import CFUNCTYPE, c_double

def create_execution_engine():
    target = llvm.Target.from_default_triple()
    target_machine = target.create_target_machine()
    backing_mod = llvm.parse_assembly("")
    engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
    return engine

llvm_engine = create_execution_engine()
llvm_engine.add_module(modref)
llvm_engine.finalize_object()
func_ptr = llvm_engine.get_function_address("forward")
function = CFUNCTYPE(c_double, c_double)(func_ptr)

>>> function(2.0)
4.0

What we are doing here is creating an LLVM execution engine (EE) using the target triple that we mentioned earlier (which tells LLVM the architecture we need native code for). We then use this EE to generate the native code for the function in the process memory. After that, we get the native function pointer, but since we are in the realm of the Python interpreter, we need to create the equivalent of a binding for that native function before being able to call it. Finally, we just call the function with the argument 2.0 and get 4.0 back.

Note that we are not interpreting the function anymore: we are calling native code for the architecture and platform we are on, so the performance boost is quite significant, especially if you need to execute the function many times (like a forward/backward pass of a neural network).
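As a rough illustration (not a rigorous benchmark, and at this tiny size the Python call overhead dominates; numbers will vary by machine), you could time the interpreted function against the native one:

import timeit

# Interpreted Python function vs. the ctypes-wrapped native function
py_time = timeit.timeit(lambda: fn(2.0), number=100_000)
native_time = timeit.timeit(lambda: function(2.0), number=100_000)
print(f"python: {py_time:.4f}s  native: {native_time:.4f}s")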

Linking function to C/C++ code

You can also compile the LLVM IR module into a native object file and link it into your C/C++ program. You can use clang to compile directly from LLVM IR to your current architecture:

clang -c forward.ll -o forward.o
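Alternatively, you can emit the object file straight from Python with llvmlite's target machine, without shelling out to clang or llc (a sketch; the forward.o file name matches the linking step below):

import llvmlite.binding as llvm

target = llvm.Target.from_default_triple()
target_machine = target.create_target_machine()
# modref is the (optimized) ModuleRef from earlier
with open("forward.o", "wb") as f:
    f.write(target_machine.emit_object(modref))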

Then you just need to call it from your C or C++ code:

#include <stdio.h>

// Our forward() method declaration
double forward(double x);

int main(int argc, char **argv)
{
    printf("forward(2.0) = %.2f\n", forward(2.0));
    return 0;
}

Compile and link everything together:

clang -c main.c -o main.o
clang forward.o main.o -o main

./main
forward(2.0) = 4.00

And that is it, you are now using the torch.fx graph in C or C++.

Compiling from torch.fx to WASM (WebAssembly) / ARM64 / RISC-V

As I promised, once you have LLVM IR there is a whole new set of options for where you can run your code. One interesting aspect is that you can go from the LLVM IR to any LLVM backend, such as WebAssembly:

llc --march=wasm32 -filetype=asm forward.ll -o forward.s

That will generate the following WASM32 assembly:

    .text
    .file	"<string>"
    .section	.text.forward,"",@
    .globl	forward                         # -- Begin function forward
    .type	forward,@function
forward:                                # @forward
    .functype	forward (f64) -> (f64)
# %bb.0:                                # %entry
    local.get	0
    f64.const	0x1p1
    f64.add
                                        # fallthrough-return
    end_function
.Lfunc_end0:
    .size	forward, .Lfunc_end0-forward
                                        # -- End function

You can now run that in the browser. You can also generate object or assembly code for ARM or AVR (the Atmel chip used in most Arduinos). Let's see the assembly code generated for the arm64 target (same llc invocation as before, with --march=arm64):

    .text
    .file	"<string>"
    .globl	forward                         // -- Begin function forward
    .p2align	2
    .type	forward,@function
forward:                                // @forward
// %bb.0:                               // %entry
    fmov	d1, #2.00000000
    fadd	d0, d0, d1
    ret
.Lfunc_end0:
    .size	forward, .Lfunc_end0-forward
                                        // -- End function
    .section	".note.GNU-stack","",@progbits

Or even for RISC-V 64 (with --march=riscv64; note that since this targets the base rv64i ISA without the floating-point extension, the double addition is lowered to a call to the soft-float routine __adddf3):

    .text
    .attribute	4, 16
    .attribute	5, "rv64i2p0"
    .file	"<string>"
    .globl	forward                         # -- Begin function forward
    .p2align	2
    .type	forward,@function
forward:                                # @forward
# %bb.0:                                # %entry
    addi	sp, sp, -16
    sd	ra, 8(sp)                       # 8-byte Folded Spill
    li	a1, 1
    slli	a1, a1, 62
    call	__adddf3@plt
    ld	ra, 8(sp)                       # 8-byte Folded Reload
    addi	sp, sp, 16
    ret
.Lfunc_end0:
    .size	forward, .Lfunc_end0-forward
                                        # -- End function
    .section	".note.GNU-stack","",@progbits

I hope you enjoyed this tutorial !

– Christian S. Perone

Cite this article as: Christian S. Perone, "Tutorial on using LLVM to JIT PyTorch fx graphs to native code (x86/arm/risc-v/wasm) (Part I – Scalars)," in Terra Incognita, 01/09/2022, https://blog.christianperone.com/2022/09/tutorial-on-using-llvm-to-jit-pytorch-fx-graphs-to-native-code-x86-arm-risc-v-wasm-part-i-scalars/.