Understanding Triton kernels and TTIR
# Triton kernels
I had been trying to see how Triton’s GPU kernels compile down to LLVM IR. It seemed especially interesting given that Triton lets you write GPU kernels in Python! So, let’s start with a very, very simple example, which I am picking from this repo.
```python
import time
import torch
import triton
import triton.language as tl

@triton.jit
def constant_add_kernel(
    x_ptr,                     # Pointer to the input vector x
    constant,                  # The constant value to add
    y_ptr,                     # Pointer to the output vector y
    N0: tl.constexpr,          # Total number of elements in vector x (and y)
    BLOCK_SIZE: tl.constexpr,  # Block size, set equal to N0
):
    # Each kernel instance processes a block of elements.
    # With BLOCK_SIZE equal to N0, only one instance is launched.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N0  # Ensure we don't access out-of-bound indices

    # Load x values, add the constant, and store the result in y
    x = tl.load(x_ptr + offsets, mask=mask)
    y = x + constant
    tl.store(y_ptr + offsets, y, mask=mask)
```
```python
def constant_add_triton(x: torch.Tensor, constant: float) -> torch.Tensor:
    """
    Adds a constant to each element of the input vector x using a Triton kernel.

    The block size is set equal to the vector length (N0), meaning that only one
    kernel instance is launched.

    Args:
        x (torch.Tensor): Input vector on CUDA.
        constant (float): The constant to add to each element.

    Returns:
        torch.Tensor: Output vector with the constant added.
    """
    N0 = x.numel()
    BLOCK_SIZE = N0  # Block size equals the vector length
    y = torch.empty_like(x)
    # With BLOCK_SIZE = N0, our grid consists of a single block.
    grid = lambda meta: (1,)
    # Launch the Triton kernel
    constant_add_kernel[grid](x, constant, y, N0, BLOCK_SIZE=BLOCK_SIZE)
    return y
```
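Before we look at what the compiler produces, here is a minimal usage sketch of the wrapper (my own addition, assuming a CUDA-capable device):

```python
# Quick sanity check of the wrapper defined above (assumes CUDA is available).
x = torch.randn(1024, device="cuda")
y = constant_add_triton(x, 3.0)
assert torch.allclose(y, x + 3.0)
```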
You probably can’t have a simpler kernel (that’s also non-trivial) than this. In the final step of the wrapper function, right after launching the kernel, let’s dump all the compilation IRs:
```python
def constant_add_triton(x: torch.Tensor, constant: float) -> torch.Tensor:
    ...
    # Launch the Triton kernel; the launch returns a compiled-kernel handle
    triton_kernel = constant_add_kernel[grid](x, constant, y, N0, BLOCK_SIZE=BLOCK_SIZE)
    with open('triton_IR_S.txt', 'w') as f:
        print(triton_kernel.asm['ttir'], file=f)
    with open('triton_TTGIR_S.txt', 'w') as f:
        print(triton_kernel.asm['ttgir'], file=f)
    with open('triton_LLVMIR_S.txt', 'w') as f:
        print(triton_kernel.asm['llir'], file=f)
    with open('triton_PTX_S.ptx', 'w') as f:
        print(triton_kernel.asm['ptx'], file=f)
    with open('triton_cubin_S.txt', 'wb') as f:  # the cubin is raw bytes
        f.write(triton_kernel.asm['cubin'])
    return y
```
So, the normal compilation process goes like:

- The Python code is read by the interpreter, and the JIT compiler “compiles” the kernel. You see that decorator `@triton.jit`? That’s what tells Triton that this function is the kernel implementation. It’s a JIT (Just-In-Time) compiler because… there really is no other way: you need to run the Python code to start the interpreter, and the kernel is specialized on runtime information (the `tl.constexpr` values, the argument dtypes), so you can’t have an AOT (Ahead-of-Time) compiler here. This is kind of good too, because it means fewer headaches, especially since you specify exactly what to compile.
- The compilation first lowers the kernel through MLIR. MLIR is kind of an umbrella here, covering:
  - Triton IR (ttir)
  - Triton GPU IR (ttgir)

  both of which are MLIR dialects, before translating to LLVM IR (llir).
- Next, the LLVM IR is used to compile the code into device assembly (PTX) -> this enables all of LLVM’s optimisations.
- And finally this becomes a binary (cubin, or CUDA binary).
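As a quick recap of that ordering, here is a small sketch (mine, not from the original post) that walks the `asm` dict on the compiled-kernel handle we captured above; the keys assume the CUDA backend:

```python
# Walk the stages in pipeline order; triton_kernel is the handle returned
# by the launch in the wrapper above. 'cubin' is raw bytes, the rest are text.
for stage in ("ttir", "ttgir", "llir", "ptx", "cubin"):
    ir = triton_kernel.asm[stage]
    kind = "bytes" if isinstance(ir, bytes) else "chars"
    print(f"{stage}: {len(ir)} {kind}")
```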
# Understanding TTIR
Let’s first look at TTIR and see what we’ve got. This is the right documentation if you wish to go through it yourself.
```mlir
// ~20 #loc source-location attributes, trimmed here
#loc = loc(...)
module {
  tt.func public @constant_add_kernel(%x_ptr: !tt.ptr<f32>, %constant: f32, %y_ptr: !tt.ptr<f32>) {
    %c1024_i32 = arith.constant 1024 : i32
    %mask = arith.constant dense<1024> : tensor<1024xi32>
    %pid = tt.get_program_id x : i32
    %offsets = arith.muli %pid, %c1024_i32 : i32
    %offsets_0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %offsets_1 = tt.splat %offsets : i32 -> tensor<1024xi32>
    %offsets_2 = arith.addi %offsets_1, %offsets_0 : tensor<1024xi32>
    %mask_3 = arith.cmpi slt, %offsets_2, %mask : tensor<1024xi32>
    %x = tt.splat %x_ptr : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %x_4 = tt.addptr %x, %offsets_2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %x_5 = tt.load %x_4, %mask_3 : tensor<1024x!tt.ptr<f32>>
    %y = tt.splat %constant : f32 -> tensor<1024xf32>
    %y_6 = arith.addf %x_5, %y : tensor<1024xf32>
    %0 = tt.splat %y_ptr : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %1 = tt.addptr %0, %offsets_2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    tt.store %1, %y_6, %mask_3 : tensor<1024x!tt.ptr<f32>>
    tt.return
  }
}
```
From a blog post on Triton’s compilation cycle, I could understand that all of this is based on the open-source LLVM compiler project (MLIR itself is an LLVM subproject), which I am assuming means that the syntax and semantics are “similar”.
Let’s ignore the fluff for now (we’ll come back to all the loc lines later) and look at the center of attraction:
```mlir
module {
  tt.func public @constant_add_kernel(%x_ptr: !tt.ptr<f32>, %constant: f32, %y_ptr: !tt.ptr<f32>) {
    ...
  }
}
```
In LLVM IR, all the %-prefixed strings are identifiers, and I am assuming that’s the case here too. We’re defining a function (`tt.func`) that is `public` (which I am assuming means it is accessible outside of this module definition as well; LLVM IR doesn’t have a `public` linkage though, the closest thing is called `external`). The name is a global definition (`@constant_add_kernel`), and we can see that we’re passing in the same params we had defined in our function: the two pointers (`!tt.ptr<f32>`) and the scalar `constant` (`f32`).
One interesting thing to note is that the IR doesn’t include the number of elements in the vectors and the block size. If you look at these lines carefully:
```mlir
%c1024_i32 = arith.constant 1024 : i32
%pid = tt.get_program_id x : i32
%offsets = arith.muli %pid, %c1024_i32 : i32
```
It just baked the block size in as the constant 1024 (our `N0`) and proceeds from there. It never takes these as function parameters, because `N0` and `BLOCK_SIZE` were declared `tl.constexpr`: they are compile-time constants that get specialized into the kernel. (The `arith.muli` above is the same `pid * BLOCK_SIZE` multiplication we used in the kernel.)
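To see the specialization in action, here is a hedged illustration (my sketch, not from the post): launching the same wrapper on an input of a different length forces a fresh compile with a different constant baked in.

```python
# Because N0 and BLOCK_SIZE are tl.constexpr, each distinct value produces a
# separate specialization: this launch's TTIR would contain
# `arith.constant 2048`, not 1024.
x = torch.randn(2048, device="cuda")
y = constant_add_triton(x, 1.0)
```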
The `arith` dialect is intended to hold basic integer and floating-point mathematical operations. You can see here that it has a lot of functionality, and it’s quite similar to basic assembly instructions like `add`, `imul`, etc. Note that what most operations do ALSO depends on the type annotation that follows them. Look at this example from the documentation:
```mlir
// Scalar addition.
%a = arith.addf %b, %c : f64

// SIMD vector addition, e.g. for Intel SSE.
%f = arith.addf %g, %h : vector<4xf32>

// Tensor addition.
%x = arith.addf %y, %z : tensor<4x?xbf16>

// Scalar addition with rounding mode.
%a = arith.addf %b, %c to_nearest_even : f64
```
In our case, for example, we have this `addf`:

```mlir
%y_6 = arith.addf %x_5, %y : tensor<1024xf32>
```
This means a tensor addition! Nice.
Let’s look at the TTIR analogs for the rest of the kernel body now:

```python
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < N0
x = tl.load(x_ptr + offsets, mask=mask)
y = x + constant
tl.store(y_ptr + offsets, y, mask=mask)
```

which we’ll find here:
```mlir
%offsets_0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
%offsets_1 = tt.splat %offsets : i32 -> tensor<1024xi32>
%offsets_2 = arith.addi %offsets_1, %offsets_0 : tensor<1024xi32>
%mask_3 = arith.cmpi slt, %offsets_2, %mask : tensor<1024xi32>
%x = tt.splat %x_ptr : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%x_4 = tt.addptr %x, %offsets_2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%x_5 = tt.load %x_4, %mask_3 : tensor<1024x!tt.ptr<f32>>
%y = tt.splat %constant : f32 -> tensor<1024xf32>
%y_6 = arith.addf %x_5, %y : tensor<1024xf32>
%0 = tt.splat %y_ptr : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%1 = tt.addptr %0, %offsets_2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
tt.store %1, %y_6, %mask_3 : tensor<1024x!tt.ptr<f32>>
tt.return
```
We had already seen what `%offsets` is storing (the `arith.muli` instruction); now we can trace the rest out.
- The `tt.make_range` op is defined here. It creates a 1D int32 tensor whose values span from `$start` to `$end` (exclusive), with step = 1.
- So our `%offsets_0` is a 1D tensor, a vector (which is what the `: tensor<1024xi32>` at the end denotes), spanning 0 up to (but not including) 1024 (the `N0` value).
The reason the code is doing that is because we had defined something like this:

```python
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
```

So, it has already calculated `pid * BLOCK_SIZE`, and now it needs the vector to add that to. But it also needs the two operands to have matching shapes:
```mlir
%offsets_1 = tt.splat %offsets : i32 -> tensor<1024xi32>
%offsets_2 = arith.addi %offsets_1, %offsets_0 : tensor<1024xi32>
```
which is exactly what `%offsets_1` is doing. `%offsets_2` is the analog of the `offsets` variable in Python, the final result of the addition; the output is a tensor. The `tt.splat` op broadcasts a single scalar into a tensor filled with that value.
The syntax for the `tt.splat` operation is as follows:

```
operation ::= `tt.splat` $src attr-dict `:` type($src) `->` type($result)
```
which means, in our case, it broadcasts the single i32 value we got from the `%offsets` calculation. Recall that

```mlir
%offsets = arith.muli %pid, %c1024_i32 : i32
```

`%offsets` was therefore a 32-bit integer, which we have now turned into a 1024-element tensor for the addition.
- Then it performs the addition, which also results in a `tensor<1024xi32>` (basically a 1D tensor, i.e. a vector, of 32-bit integers).
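To make the splat-then-add concrete, here is a rough host-side analogy in plain PyTorch (my sketch, not anything Triton emits), taking `pid = 0` and `BLOCK_SIZE = 1024`:

```python
import torch

pid, BLOCK_SIZE = 0, 1024
offsets_0 = torch.arange(0, BLOCK_SIZE, dtype=torch.int32)                  # tt.make_range
offsets_1 = torch.full((BLOCK_SIZE,), pid * BLOCK_SIZE, dtype=torch.int32)  # tt.splat
offsets_2 = offsets_1 + offsets_0                                           # arith.addi
```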
Then we see it create the mask here:
```mlir
%mask = arith.constant dense<1024> : tensor<1024xi32>
...
%mask_3 = arith.cmpi slt, %offsets_2, %mask : tensor<1024xi32>
```
I am not sure why it’s `%mask_3` and not `%mask_1`, or even `%mask_0`. I am assuming this is also SSA, like LLVM IR; the disambiguating suffix looks like a single running counter shared across the whole function (`%offsets_0` through `%offsets_2` came first), rather than a per-variable count.
- The `arith.cmpi` instruction is documented here; it follows this syntax:

```
operation ::= `arith.cmpi` $predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)
```
In our case, the `$predicate` is `slt`, which stands for signed less than; the `$lhs` is `%offsets_2` (basically the `offsets` variable in the Python code); and the `$rhs` is the earlier `%mask`, which is simply a constant dense tensor (every element is 1024).
- Recall that when we were making the mask, we wrote `mask = offsets < N0`; that’s the comparison we’re going to do here. Each resulting element is `1` if the comparison is true and `0` otherwise.
- Note that the syntax doesn’t show us the output type of the result, only the type of `$lhs`, so let’s not get confused about why a tensor appears there: the actual result is a `tensor<1024xi1>`, a vector of one-bit booleans.
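Continuing the host-side analogy from the offsets sketch above (again mine): `arith.cmpi slt` is just an elementwise signed less-than that yields a boolean vector.

```python
# %mask is the dense<1024> constant; the comparison gives the i1-like tensor.
N0 = 1024
mask_3 = offsets_2 < N0  # all True here, since BLOCK_SIZE == N0
```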
So each element of our `%mask_3` is either a 0 or a 1; let’s move on:
```mlir
%x = tt.splat %x_ptr : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%x_4 = tt.addptr %x, %offsets_2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%x_5 = tt.load %x_4, %mask_3 : tensor<1024x!tt.ptr<f32>>
%y = tt.splat %constant : f32 -> tensor<1024xf32>
%y_6 = arith.addf %x_5, %y : tensor<1024xf32>
%0 = tt.splat %y_ptr : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%1 = tt.addptr %0, %offsets_2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
tt.store %1, %y_6, %mask_3 : tensor<1024x!tt.ptr<f32>>
tt.return
```
- Recall how we wanted to add `x_ptr + offsets` before passing it to `tl.load` (which is `tt.load` here). To do that, we need both operands to have compatible types.
- The `tt.splat` here converts the single `!tt.ptr<f32>` (a pointer to 32-bit floats) into a tensor of such pointers. Likewise, `%y = tt.splat %constant` broadcasts the `constant` scalar we use in the Python code to compute `y`.
- Then we call `arith.addf`, and from the previous discussion we know this is a tensor addition: it takes the output of `tl.load` (`tt.load` here) and adds the constant to it, giving the `y` value from the Python code.
- Then we take `y_ptr` via `%0` and likewise splat it from a single pointer into a tensor of pointers.
- We add `y_ptr` to the `offsets` (`%offsets_2` here) with `tt.addptr` and then store through the result, along with the `mask`, just like we do in Python.
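To close out the host-side analogy (my sketch; fancy indexing stands in for the pointer arithmetic of `tt.addptr` plus the masked `tt.load`/`tt.store`, reusing `offsets_2` and `mask_3` from the earlier snippets):

```python
x_buf = torch.randn(BLOCK_SIZE)     # the memory x_ptr points at
y_buf = torch.empty_like(x_buf)     # the memory y_ptr points at
constant = 3.0

idx = offsets_2[mask_3].long()      # the offsets that survive the mask
y_buf[idx] = x_buf[idx] + constant  # tt.load -> arith.addf -> tt.store
```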
That brings us to the end of this introduction. You can see that it’s NOT very different from the Python code itself, except for a few more details on how things actually work. But it’s not “there” yet, as in not sufficiently low-level.
So next we’ll look at TTGIR, the next stage in the compilation process, and understand how the IR maps to the GPU architecture.
Thanks for reading!