Tutorial on using LLVM to JIT PyTorch fx graphs to native code (x86/arm/risc-v/wasm) (Part I – Scalars)
In 2009 I started playing with LLVM for some projects (a data structure JIT, genetic programming, a JIT for TensorFlow graphs, etc.), and in these projects I realized how powerful the LLVM design was at the time (and still is): an elegant IR (intermediate representation) with a user-facing API, and modularized front-ends and backends with plenty of transformation and optimization passes. Nowadays, LLVM is the main engine behind many compilers and JIT compilers, and where most of the modern development in compilers is happening.
On a related note, PyTorch is doing an amazing job of exposing more of the torch tracing system and its IR and graphs, not to mention their work on recent fusers and TorchDynamo. In this context, I was doing a small test to re-implement Shine, but using ATen ops for tensors, and realized that there were not many educational tutorials on how to use LLVM to JIT PyTorch graphs. So this is a quick series (time permitting, there will be more posts) on how to use LLVM (through its Python bindings) to go from PyTorch graphs (as traced by torch.fx) to LLVM IR and native code.
Detour – PyTorch NNC (Neural Net Compiler)
PyTorch itself also has a compiler that uses LLVM to generate native code for the subgraphs that the fuser identifies. It is called NNC (Neural Net Compiler), also known as Tensor Expressions (TE); you can read more about it in the C++ API tutorial. One thing to note, though, is that the default binaries you get from PyTorch are not linked against the LLVM libraries, so you need to compile PyTorch yourself with the LLVM backend enabled, otherwise it won’t use LLVM to JIT compile/optimize the subgraphs. Let’s take a look at it first before starting our tutorial.
Enabling LLVM backend support in PyTorch
The first thing you need to do to enable the LLVM backend in the NNC is to build PyTorch with LLVM support (and point it to an LLVM installation). In my case I used LLVM 14.0.6 from homebrew on macOS (on Apple Silicon it is a little trickier to compile PyTorch + LLVM, as you will need to manually edit some CMake files), but you can get LLVM basically anywhere. To compile PyTorch with LLVM enabled, and disabling what we don’t need for this tutorial (saving a lot of compilation time), you can use the following:
USE_KINETO=0 BUILD_CAFFE2=0 USE_DISTRIBUTED=0 USE_NCCL=0 BUILD_TEST=0 \
USE_XNNPACK=0 USE_FBGEMM=0 USE_QNNPACK=0 USE_MKLDNN=0 USE_MIOPEN=0 \
USE_NNPACK=0 BUILD_CAFFE2_OPS=0 USE_TENSORPIPE=0 \
CC=clang CXX=clang++ USE_LLVM=[the path to your LLVM] \
python setup.py develop
After that, the following call should return True, confirming that LLVM is enabled:
>>> torch._C._llvm_enabled()
True
Enabling NNC fuser
Now we can try a small “fuse-opportunistic” function (you can see how bad I am at creating new technical terms) to check the LLVM IR that NNC generates through LLVM:
import torch

# Our example function
def our_function(x):
    a = torch.mul(x, x)
    b = torch.sin(a)
    c = torch.cos(b)
    d = torch.mul(c, c)
    f = torch.mul(d, d)
    return d + f

# Setting flags to enable NNC (some are not strictly necessary)
torch._C._jit_set_texpr_fuser_enabled(True)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_set_te_must_use_llvm_cpu(True)

# Script the function, this is not where the
# LLVM IR will be emitted
scripted = torch.jit.script(our_function)

# Execute the JIT compiled function
arg = torch.randn(1)
scripted(arg)
And that’s it: we have a very simple function and we execute it with a random scalar. PyTorch will now use LLVM to JIT compile the function to native code. Now, if we want to see the LLVM IR output, we need to run:
PYTORCH_JIT_LOG_LEVEL=">>llvm_codegen" python example.py
The “llvm_codegen” value enables dumping of the LLVM IR and debugging information. However, there is a catch here: if you execute the command above, you will see that nothing is shown. Why is that happening? PyTorch uses profiling to specialize the graphs, so if we execute the function just once, it will only capture profiling information (especially about types and shapes) and will generate the LLVM JIT compiled function on a later run. During the first runs, profiling nodes are added into the graph; after the graph is specialized, guards are added to make sure the original function can always be executed as a fallback, or the function can be re-specialized for different shapes/types. To correct the code, you just have to do:
# We get the number of profiled runs and add one to it
num_profiled = torch._C._jit_get_num_profiled_runs() + 1
for _ in range(num_profiled):
    arg = torch.randn(1)
    scripted(arg)
Now if you execute it again with the debugging flag enabled, you will see a lot of debugging output. The important part for us here is the LLVM IR after optimization (the optimization is done by LLVM as well, with many passes such as DCE (Dead Code Elimination), inlining and instruction combining, among many other techniques that are also applied to regularly compiled code). Let’s see the emitted LLVM IR:
; ModuleID = 'pytorch'
source_filename = "pytorch"
target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-apple-darwin20.6.0"

; Function Attrs: mustprogress nofree nosync nounwind willreturn
define i32 @fused_mul_sin_cos_mul_mul_add(i8** nocapture readonly %0) local_unnamed_addr #0 {
wrapBB:
  %1 = bitcast i8** %0 to float**
  %2 = load float*, float** %1, align 8
  %3 = getelementptr i8*, i8** %0, i64 1
  %4 = bitcast i8** %3 to float**
  %5 = load float*, float** %4, align 8
  %.val = load float, float* %2, align 4
  call void @llvm.experimental.noalias.scope.decl(metadata !0)
  %6 = fmul float %.val, %.val
  %7 = tail call float @sinf(float %6) #3
  %8 = tail call float @cosf(float %7) #3
  %9 = fmul float %8, %8
  %10 = fmul float %9, %9
  %11 = fadd float %9, %10
  store float %11, float* %5, align 4, !alias.scope !0
  ret i32 0
}
In the IR above I omitted some LLVM metadata to make it short. As you can see, the code performs the same operations seen in the function, but mapped to IR instructions such as fmul (float multiplication). Also, note the name of the generated function, fused_mul_sin_cos_mul_mul_add: it is basically saying that mul/sin/cos/mul/mul/add were fused together, which is exactly the order of operations we have in our function.
Generated native code and static compilation with LLVM LLC
There is also a target triple, which instructs LLVM what system it has to emit native code for; in this case I’m on an x86 architecture. This LLVM IR was also already optimized with passes similar to the ones you see when you use clang on regular C/C++ code, for example. The powerful thing about having this LLVM IR is that we can now emit assembly code for the architecture, which can be seen below:
	.section	__TEXT,__text,regular,pure_instructions
	.build_version macos, 11, 0
	.globl	_fused_mul_sin_cos_mul_mul_add
	.p2align	4, 0x90
_fused_mul_sin_cos_mul_mul_add:
	pushq	%rbx
	movq	(%rdi), %rax
	movq	8(%rdi), %rbx
	vmovss	(%rax), %xmm0
	vmulss	%xmm0, %xmm0, %xmm0
	movabsq	$_sinf, %rax
	callq	*%rax
	movabsq	$_cosf, %rax
	callq	*%rax
	vmulss	%xmm0, %xmm0, %xmm0
	vfmadd213ss	%xmm0, %xmm0, %xmm0
	vmovss	%xmm0, (%rbx)
	xorl	%eax, %eax
	popq	%rbx
	retq
The code above is the assembly generated for the target triple x86_64, and this is what will ultimately get executed. You can also do some interesting things with the LLVM IR: you can take it, save it into a file and compile it with the LLVM static compiler (llc) into an object that can be linked with any other application:
llc -filetype=obj code.ll -o code.o
In the example above we compiled it into an object file called code.o that can be linked with other applications. If you list its symbols, you will see the function there:
~# llvm-nm code.o
                 U _cosf
0000000000000000 T _fused_mul_sin_cos_mul_mul_add
                 U _sinf
You can see that we have the fused function with the code “T” (it is in the text section of the object, so it is provided by the object itself), while _cosf and _sinf are “U” (undefined): you need to link against another library to provide these symbols used by the fused function.
We will do some other fun things later with the LLVM IR, but one thing you need to realize is that once you have LLVM IR, it opens the door to really interesting and useful things. Since we have talked enough for a quick detour on the PyTorch NNC compiler, let’s jump into our own implementation of a simple JIT for float scalars only, which is the main goal of this tutorial.
Building our own scalar JIT compiler with LLVM
Simplifying assumptions
The reason I decided to start with simple scalars is that it simplifies the LLVM IR generation a lot, as we don’t have to deal with tensor shapes. We also assume that the scalars are all floats, just like when we have a closure in Genetic Programming graphs; this also keeps our code minimal and understandable, as we don’t have to support operations on multiple types. A full-fledged JIT compiler is as complex as PyTorch’s NNC, but the goal here is to show the main intuition of how we can build such compilers with LLVM.
Unlike PyTorch’s NNC, which creates the LLVM IR in C++, we will create the IR in Python using the llvmlite bindings (also used by the Numba project).
Getting PyTorch’s graph with torch.fx
To get PyTorch graphs, we will use the new torch.fx module and its symbolic tracer, which performs a symbolic execution of the code and produces a graph for us. We will then traverse this graph and emit the corresponding LLVM IR instructions.
Let’s first define a not very useful function, but one that will give many optimization techniques a chance to improve it:
def fn(x):
    a = x + 2.0
    b = a + 2.0
    b += b
    c = b - a
    e = a * 3
    e = e / c
    d = b + c + a
    return a
As you can see, it is a very boring function, but it already uses multiple operations: multiplication, subtraction, addition and division. Let’s now see the graph that the symbolic tracer from torch.fx produced:
>>> gm = torch.fx.symbolic_trace(fn)
>>> type(gm)
<class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>
>>> type(gm.graph)
<class 'torch.fx.graph.Graph'>
>>> gm.graph.print_tabular()
opcode         name     target                       args            kwargs
-------------  -------  ---------------------------  --------------  --------
placeholder    x        x                            ()              {}
call_function  add      <built-in function add>      (x, 2.0)        {}
call_function  add_1    <built-in function add>      (add, 2.0)      {}
call_function  add_2    <built-in function add>      (add_1, add_1)  {}
call_function  sub      <built-in function sub>      (add_2, add)    {}
call_function  mul      <built-in function mul>      (add, 3)        {}
call_function  truediv  <built-in function truediv>  (mul, sub)      {}
call_function  add_3    <built-in function add>      (add_2, sub)    {}
call_function  add_4    <built-in function add>      (add_3, add)    {}
output         output   output                       (add,)          {}
As you can see above, the symbolic tracer returns a graph module with the graph generated from the symbolic execution of the function. We are also showing a tabular version of the graph, with the different types of opcodes we will have to support, together with the different function targets and constants. The first node type (opcode) we can see here is placeholder, which represents a function input (a parameter of the function); in this case it is the “x”, the input of the traced function we provided.
The next opcode we can see is call_function, which summons Cthulhu every time the interpreter sees it. Just kidding, it really just calls a function. At the end we have the output opcode, which says that the output of our function is the node called add. Note that this is not a function, but the node that has this name (the second row in the tabular representation shown above). You can see already that there is a lot of dead code in the function: assuming the other ops have no external effects, we only need the result of the add node, which doesn’t require all the computation below it, so we will see LLVM optimize that away for us later.
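As an aside, torch.fx itself can also prune dead nodes at the graph level, before we ever reach LLVM. A minimal sketch using the gm we just traced (assuming, as here, that the ops have no side effects; in this tutorial we will let LLVM do the elimination instead):

gm.graph.eliminate_dead_code()  # drops nodes whose results are never used
gm.recompile()                  # regenerate gm.code from the pruned graph
print(gm.code)                  # only the "x + 2.0" needed for the output remains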
Emitting IR for the function definition and main block
We will create a class that receives the torch.fx graph and generates the LLVM IR. The first step is to create a function called forward() (trying to keep as close as possible to the PyTorch conventions) that will contain all the instructions. To create a function in LLVM IR we need to define its return type and argument types as well, i.e., the complete function declaration. Here is the initial part of our LLVMCompiler class:
import torch
import llvmlite.binding as llvm
from llvmlite import ir
from torch.fx.node import Node
import operator as ops


class LLVMCompiler:
    def __init__(self, gm):
        self.graph = gm.graph
        self.vars = {}
        self.closure_type = ir.DoubleType()
        self.llvm_module = ir.Module()
        self.ir_builder = self.build_forward()
        self.fn_map = {
            ops.add: self.ir_builder.fadd,
            ops.sub: self.ir_builder.fsub,
            ops.truediv: self.ir_builder.fdiv,
            ops.mul: self.ir_builder.fmul,
        }

    def build_forward(self):
        placeholders = [node for node in self.graph.nodes
                        if node.op == "placeholder"]
        arg_types = [self.closure_type for _ in placeholders]
        func_forward_type = ir.FunctionType(self.closure_type, arg_types)
        func_forward = ir.Function(self.llvm_module, func_forward_type,
                                   name="forward")

        for arg, placeholder in zip(func_forward.args, placeholders):
            arg.name = placeholder.name
            self.vars[arg.name] = arg

        block = func_forward.append_basic_block(name="entry")
        return ir.IRBuilder(block)
Let’s unpack this code:
- In the constructor we instantiate a few useful things, such as a dict called self.vars that holds a map from a variable name to the variable itself (in case they are referenced later);
- We also create an LLVM Module; this module is the top-level container of all other LLVM IR objects;
- We also create a closure type, a DoubleType; this is the type that will be used for all arguments, variables and return values;
- We also have a mapping from the Python operators +, -, / and * to their LLVM IR instruction counterparts for float types; we need that because we will be converting these Python operation calls into LLVM IR instruction calls;
- Next we call the build_forward() method:
  - In this method we get all the placeholder nodes of the torch.fx graph and create a function called forward() that will contain our instructions;
  - We then proceed by adding the function arguments into the self.vars dict, so they can be referenced later (these are the inputs, the placeholders);
  - We finally create a “basic block”, which in LLVM IR represents a single-entry single-exit section of code. After that we return an IRBuilder that will be used to add the function instructions; note that this builder is tied to the entry block we appended to the function, so any instruction built with it will be attached to this basic block of our function;
  - In summary, we have: LLVM Module -> Function -> Entry block (see the minimal standalone sketch after this list).
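To make the Module -> Function -> Entry block hierarchy concrete, here is a minimal standalone llvmlite sketch (independent of torch.fx; the square function here is just an illustration, not part of our compiler):

from llvmlite import ir

module = ir.Module()                        # top-level container
fn_type = ir.FunctionType(ir.DoubleType(), [ir.DoubleType()])
fn = ir.Function(module, fn_type, name="square")
builder = ir.IRBuilder(fn.append_basic_block(name="entry"))
x = fn.args[0]
builder.ret(builder.fmul(x, x, name="sq"))  # return x * x
print(module)                               # textual LLVM IR for @square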
Emitting IR instructions for the torch.fx nodes
Now that we have our LLVM Module with the function declaration and an entry block, we can traverse the torch.fx graph and emit the IR instructions for the body of the function. We need to support two more opcodes: call_function and output. Let’s now show the complete code for the LLVMCompiler:
class LLVMCompiler:
    def __init__(self, gm):
        self.graph = gm.graph
        self.vars = {}
        self.closure_type = ir.DoubleType()
        self.llvm_module = ir.Module()
        self.ir_builder = self.build_forward()
        self.fn_map = {
            ops.add: self.ir_builder.fadd,
            ops.sub: self.ir_builder.fsub,
            ops.truediv: self.ir_builder.fdiv,
            ops.mul: self.ir_builder.fmul,
        }

    def build_forward(self):
        placeholders = [node for node in self.graph.nodes
                        if node.op == "placeholder"]
        arg_types = [self.closure_type for _ in placeholders]
        func_forward_type = ir.FunctionType(self.closure_type, arg_types)
        func_forward = ir.Function(self.llvm_module, func_forward_type,
                                   name="forward")

        for arg, placeholder in zip(func_forward.args, placeholders):
            arg.name = placeholder.name
            self.vars[arg.name] = arg

        block = func_forward.append_basic_block(name="entry")
        return ir.IRBuilder(block)

    def jit(self):
        for node in self.graph.nodes:
            self.emit_node(node)

    def emit_node(self, node):
        try:
            emit_fn = getattr(self, f"emit_{node.op}")
        except AttributeError:
            raise RuntimeError(f"Unknown node type: {node}, {node.op}.")
        emit_fn(node)

    def emit_placeholder(self, node):
        return

    def emit_call_function(self, node):
        args = []
        for arg in node.args:
            if isinstance(arg, Node):
                # Argument is another node: reuse its emitted value
                args.append(self.vars[arg.name])
            else:
                # Argument is a constant: emit an IR constant of the closure type
                constant_arg = self.closure_type(arg)
                args.append(constant_arg)

        fn = self.fn_map[node.target]
        call_inst = fn(args[0], args[1], name=node.name)
        self.vars[node.name] = call_inst

    def emit_output(self, node):
        return_arg = node.args[0].name
        return_var = self.vars[return_arg]
        self.ir_builder.ret(return_var)
We have now added a few new methods; let’s go through each one of them:

- jit(): traverses the torch.fx graph and calls the emit_node() method for each node;
- emit_node(): dynamically finds the method named emit_[opcode] and calls it with the node; it is just a simple dynamic dispatch that makes it easier to separate the node processing;
- emit_call_function(): emits the LLVM IR for the opcode that calls a function (when you write something like 2 + 2, you are calling the Python add() built-in on two constants). It first checks whether each argument of the node is itself a node or a constant: if it is a constant, it emits an LLVM IR constant with the closure type (DoubleType); if it is a node, it looks it up in the self.vars dictionary. After that, we use the mapping from Python built-in operators to LLVM IR instructions to get the right instruction and add it to the IR with the created arguments. The last thing we do is add the result of this instruction to self.vars, as it can be referenced later by following operators;
- emit_output(): deals with the output opcode, which tells us that we should return a value. What we do is use the LLVM IR “ret” instruction to return the argument, and that is it.
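As a sketch of how this design could be extended (a hypothetical addition, not part of the tutorial’s compiler), a unary operator such as Python’s operator.neg could be handled by mapping it to the IR builder’s fneg and special-casing the single-argument call:

class LLVMCompilerWithNeg(LLVMCompiler):
    def __init__(self, gm):
        super().__init__(gm)
        # Hypothetical extension: map Python's unary minus to LLVM's fneg
        self.fn_map[ops.neg] = self.ir_builder.fneg

    def emit_call_function(self, node):
        if node.target is ops.neg:
            arg = self.vars[node.args[0].name]
            self.vars[node.name] = self.ir_builder.fneg(arg, name=node.name)
            return
        # All binary operators go through the original path
        super().emit_call_function(node)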
That’s pretty much it: with this code we can already create LLVM IR for the function we showed earlier. Even within the simplifying assumptions we made, it is impressive that we can build IR in a few lines of Python code, but that’s the power of torch.fx and LLVM.
Viewing the LLVM IR
Let’s now run our compiler on the function we defined and show the generated LLVM IR. For that we will use the ModuleRef from llvmlite:
# Initialization calls required by LLVM
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

# Symbolic tracing and compilation
gm = torch.fx.symbolic_trace(fn)

# Run our compiler on the torch.fx graph
compiler = LLVMCompiler(gm)
compiler.jit()

# llvmlite has two modules: Module and ModuleRef; here we use the
# ModuleRef as it will allow us to generate native code
modref = llvm.parse_assembly(str(compiler.llvm_module))
print(modref)
The output of this will be:
; ModuleID = '<string>'
source_filename = "<string>"
target triple = "unknown-unknown-unknown"

define double @forward(double %x) {
entry:
  %add = fadd double %x, 2.000000e+00
  %add_1 = fadd double %add, 2.000000e+00
  %add_2 = fadd double %add_1, %add_1
  %sub = fsub double %add_2, %add
  %mul = fmul double %add, 3.000000e+00
  %truediv = fdiv double %mul, %sub
  %add_3 = fadd double %add_2, %sub
  %add_4 = fadd double %add_3, %add
  ret double %add
}
This shows exactly the function we used to emit the LLVM IR, but written with LLVM IR instructions and format. Note that this is the IR before optimization; we will now see the effect of optimization on this function.
Optimizing the LLVM IR
We will now run the LLVM optimization passes on top of our generated IR, which should shrink our function considerably given all the dead code in it. For that we will use optimization level 3 (similar to the passes used by clang -O3):
pass_manager = llvm.PassManagerBuilder()
pass_manager.opt_level = 3
pass_module = llvm.ModulePassManager()
pass_manager.populate(pass_module)
pass_module.run(modref)
print(modref)
This will optimize the IR and show the following result after the optimization passes:
; ModuleID = '<string>'
source_filename = "<string>"
target triple = "unknown-unknown-unknown"

; Function Attrs: norecurse nounwind readnone
define double @forward(double %x) local_unnamed_addr #0 {
entry:
  %add = fadd double %x, 2.000000e+00
  ret double %add
}
Note that it removed all the dead code using the dead code elimination pass, and we are left with just the addition of 2.0 to the argument, which is what the function returns. This was a very simple case, but LLVM has an enormous set of optimization and analysis passes that deal with much more complex code.
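The opt_level preset is convenient, but llvmlite also lets you schedule individual passes yourself. A small sketch (the exact set of add_*_pass methods available varies across llvmlite versions, so treat this as an assumption to check against your install):

pm = llvm.ModulePassManager()
pm.add_instruction_combining_pass()  # combine redundant instructions
pm.add_gvn_pass()                    # global value numbering
pm.add_cfg_simplification_pass()     # simplify the control-flow graph
pm.run(modref)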
Executing native code from the Python interpreter
Now that we have the LLVM IR module, we can use LLVM MCJIT (Machine Code JIT) to generate native code and give us a pointer to the function, for which we will create a Python binding so we can call it:
from ctypes import CFUNCTYPE, c_double

def create_execution_engine():
    target = llvm.Target.from_default_triple()
    target_machine = target.create_target_machine()
    backing_mod = llvm.parse_assembly("")
    engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
    return engine

llvm_engine = create_execution_engine()
llvm_engine.add_module(modref)
llvm_engine.finalize_object()

func_ptr = llvm_engine.get_function_address("forward")
function = CFUNCTYPE(c_double, c_double)(func_ptr)

function(2.0)  # returns 4.0
What we are doing here is creating an LLVM execution engine (EE) using the default target triple mentioned earlier (which tells LLVM which architecture to emit native code for). We then use this EE to generate the native code for the function in the process memory and get a pointer to that native function; since we are in the realm of the Python interpreter, we need to create the equivalent of a binding for that native function before being able to call it. Finally, we call the function with the argument 2.0 and get 4.0 back.
Note that we are not interpreting the function anymore: we are calling native code for the architecture and platform we are on, so the performance boost can be quite significant, especially if you need to execute the function many times (like a forward/backward pass of a neural network).
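To get a rough feeling for the difference, you can time the pure Python function against the native one. A hedged micro-benchmark sketch (numbers vary by machine, and for a function this tiny the ctypes call overhead eats part of the gain):

import timeit
print(timeit.timeit(lambda: fn(2.0), number=100_000))        # interpreted Python
print(timeit.timeit(lambda: function(2.0), number=100_000))  # JIT-compiled native code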
Linking function to C/C++ code
You can also compile the LLVM IR module into an object file and link it into your C/C++ program.
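First we need the IR on disk; a small sketch (the forward.ll filename is just a choice for this example):

with open("forward.ll", "w") as f:
    f.write(str(modref))  # dump the optimized LLVM IR to a file

Now clang can compile it directly from LLVM IR into an object for your current architecture: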
clang -c forward.ll -o forward.o
Then you just need to call it from your C or C++ code:
#include <stdio.h>

// Our forward() method declaration
double forward(double x);

int main(int argc, char **argv)
{
    printf("forward(2.0) = %.2f\n", forward(2.0));
    return 0;
}
Compile everything and link it together:
clang -c main.c -o main.o
clang forward.o main.o -o main
./main
forward(2.0) = 4.00
And that is it, you are now using the torch.fx graph in C or C++.
Compiling from torch.fx to WASM (WebAssembly) / ARM64 / RISC-V
As I promised, once you have LLVM IR there is a whole new set of options for where you can run your code. One interesting aspect is that you can go from the LLVM IR to any LLVM backend (llc --version lists the targets registered in your LLVM build), such as WebAssembly:
llc --march=wasm32 -filetype=asm forward.ll -o forward.s
That will generate the following WASM32 assembly:
	.text
	.file	"<string>"
	.section	.text.forward,"",@
	.globl	forward                         # -- Begin function forward
	.type	forward,@function
forward:                                # @forward
	.functype	forward (f64) -> (f64)
# %bb.0:                                # %entry
	local.get	0
	f64.const	0x1p1
	f64.add
                                        # fallthrough-return
	end_function
.Lfunc_end0:
	.size	forward, .Lfunc_end0-forward    # -- End function
You can now run this in the browser. You can also generate object or assembly code for ARM or AVR (the Atmel chip used in most Arduinos). Let’s see the assembly code generated for the arm64 target with:

llc --march=aarch64 -filetype=asm forward.ll -o forward.s
	.text
	.file	"<string>"
	.globl	forward                         // -- Begin function forward
	.p2align	2
	.type	forward,@function
forward:                                // @forward
// %bb.0:                               // %entry
	fmov	d1, #2.00000000
	fadd	d0, d0, d1
	ret
.Lfunc_end0:
	.size	forward, .Lfunc_end0-forward    // -- End function
	.section	".note.GNU-stack","",@progbits
Or even for RISC-V 64:
	.text
	.attribute	4, 16
	.attribute	5, "rv64i2p0"
	.file	"<string>"
	.globl	forward                         # -- Begin function forward
	.p2align	2
	.type	forward,@function
forward:                                # @forward
# %bb.0:                                # %entry
	addi	sp, sp, -16
	sd	ra, 8(sp)                       # 8-byte Folded Spill
	li	a1, 1
	slli	a1, a1, 62
	call	__adddf3@plt
	ld	ra, 8(sp)                       # 8-byte Folded Reload
	addi	sp, sp, 16
	ret
.Lfunc_end0:
	.size	forward, .Lfunc_end0-forward    # -- End function
	.section	".note.GNU-stack","",@progbits
I hope you enjoyed this tutorial!
– Christian S. Perone