Lecture 002

Primitive Tensor Function

Primitive Tensor Function: a tensor function that corresponds to a single unit of computation.

ML Compilation: transformations between primitive tensor functions. For example, a loop can be split into units of length 4, where f32x4 add is a special vector add function that carries out the computation for each unit.

Applying transformations to primitive functions is the simplest kind of transformation we can do. Most ML frameworks map add to cudaAdd simply by calling a CUDA library. This approach is too high-level: it requires a manual mapping from every primitive function (such as add) to a different library call on each platform (cudaAdd, amdAdd, ...).

There are multiple ways to implement a primitive tensor function. Some transformations are valid only if the order of the loop iterations does not matter (for most loops it does not).
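
As a rough illustration of "multiple ways to implement" the same primitive function, here is a minimal pure-Python sketch (not TVM code) of a 128-element add written two ways: as a plain loop, and with the loop split into units of length 4. The names add_plain, add_split and add_f32x4 are hypothetical; add_f32x4 only stands in for the special vector add and is not a real hardware instruction.

```python
def add_plain(a, b):
    """Reference implementation: one scalar add per loop iteration."""
    c = [0.0] * len(a)
    for i in range(len(a)):
        c[i] = a[i] + b[i]
    return c


def add_f32x4(a, b, c, offset):
    """Hypothetical stand-in for a 4-wide vector add starting at `offset`."""
    c[offset:offset + 4] = [
        x + y for x, y in zip(a[offset:offset + 4], b[offset:offset + 4])
    ]


def add_split(a, b):
    """Same computation, with the loop split into units of length 4."""
    assert len(a) % 4 == 0, "this sketch assumes the length is a multiple of 4"
    c = [0.0] * len(a)
    for i_outer in range(len(a) // 4):
        add_f32x4(a, b, c, i_outer * 4)
    return c


a = [float(i) for i in range(128)]
b = [1.0] * 128
result_plain = add_plain(a, b)
result_split = add_split(a, b)
```

Both versions produce the same output because each element of the result depends only on the matching elements of the inputs, so the iterations can be regrouped freely.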

Tensor Program Abstraction

The typical elements in a primitive tensor function

We can use T.axis.spatial() to tell TVM that an axis is spatial: its iterations are independent of one another, so they can be parallelized.

A typical primitive tensor function has three elements: buffers, loops, and computation.

We want to apply transformations automatically rather than rely on hand-crafted optimization rules.

In TVM, a class decorated with @tvm.script.ir_module is a tvm.ir.module.IRModule (that is, type(MyModule) returns tvm.ir.module.IRModule). In the following example, the IRModule contains a single function, main.

To install: !python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels

import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(A: T.Buffer[128, "float32"],
             B: T.Buffer[128, "float32"],
             C: T.Buffer[128, "float32"]):
        # extra annotations for the function
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in range(128):
            with T.block("C"):
                # declare a data-parallel iterator over the spatial domain:
                # vi is bound to loop variable i, and iterations over vi are independent
                vi = T.axis.spatial(128, i)
                C[vi] = A[vi] + B[vi]

Calling MyModule.show() prints the module in TVMScript format, which carries extra information that is useful to the compiler, such as the explicit T.reads and T.writes annotations:

@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[128, "float32"], B: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i in T.serial(128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                T.reads(A[vi], B[vi]) # the block reads A and B
                T.writes(C[vi]) # the block writes C
                C[vi] = A[vi] + B[vi]

Transforming Module in TVMScript

To transform the module, create a schedule with sch = tvm.tir.Schedule(MyModule). Once you are done transforming, call sch.mod.show() to print the transformed module (sch.mod.script() returns the same TVMScript as a string).

Splitting the Loop

To split a loop into multiple nested loops, you can do:

# Get block by its name
block_c = sch.get_block("C")
# Get the loops surrounding the block
(i,) = sch.get_loops(block_c)
# Tile the loop nesting.
i_0, i_1, i_2 = sch.split(i, factors=[None, 4, 4]) # None: this factor is inferred automatically (128 / (4 * 4) = 8)
sch.mod.show()

and then you will get the following IR:

@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[128, "float32"], B: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i_0, i_1, i_2 in T.grid(8, 4, 4): # note that 128 = 8*4*4
            with T.block("C"):
                vi = T.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
                T.reads(A[vi], B[vi])
                T.writes(C[vi])
                C[vi] = A[vi] + B[vi]
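
The index expression i_0 * 16 + i_1 * 4 + i_2 can be checked by hand. The following pure-Python sketch (not TVM code; the names visited and covers_original_loop are only for illustration) walks the split loop nest in the same order as T.grid(8, 4, 4) and records each value of vi.

```python
# Enumerate the split loop nest in the same order as T.grid(8, 4, 4).
visited = []
for i_0 in range(8):
    for i_1 in range(4):
        for i_2 in range(4):
            # Same binding as T.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2).
            vi = i_0 * 16 + i_1 * 4 + i_2
            visited.append(vi)

# The split nest visits exactly the indices of the original loop, in order.
covers_original_loop = visited == list(range(128))
```

The stride of each loop is the product of the extents of the loops nested inside it: 4 * 4 = 16 for i_0, 4 for i_1, and 1 for i_2.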

Reordering the Loop

We can change the loop order with sch.reorder(i_0, i_2, i_1), which gives:

@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[128, "float32"], B: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i_0, i_2, i_1 in T.grid(8, 4, 4): # note that i_1 and i_2 are swapped
            with T.block("C"):
                vi = T.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
                T.reads(A[vi], B[vi])
                T.writes(C[vi])
                C[vi] = A[vi] + B[vi]
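
To see what reordering changes and what it leaves alone, here is a small pure-Python sketch (not TVM code; the variable names are only for illustration) that runs the reordered loop nest and records the order in which elements are visited.

```python
a = [float(i) for i in range(128)]
b = [1.0] * 128
c = [0.0] * 128

order = []
for i_0 in range(8):
    for i_2 in range(4):        # i_2 is now the middle loop
        for i_1 in range(4):    # i_1 is now the innermost loop
            # The index expression itself is unchanged by the reorder.
            vi = i_0 * 16 + i_1 * 4 + i_2
            c[vi] = a[vi] + b[vi]
            order.append(vi)
```

The access pattern is now strided (0, 4, 8, 12, then 1, 5, ...), but every index is still visited exactly once, so the result is the same as before. The reorder is safe precisely because the iterations are independent.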

Parallelizing the Loop

To parallelize one loop, call sch.parallel(i_0), which gives:

@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[128, "float32"], B: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i_0 in T.parallel(8): # the outer loop (extent 8) now runs in parallel
            for i_2, i_1 in T.grid(4, 4):
                with T.block("C"):
                    vi = T.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
                    T.reads(A[vi], B[vi])
                    T.writes(C[vi])
                    C[vi] = A[vi] + B[vi]
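
As a loose analogy only, the sketch below mimics the parallel outer loop with Python's standard-library thread pool. This is not how TVM executes T.parallel (TVM generates native multithreaded code), and because of Python's global interpreter lock no speedup should be expected here. The point is that each outer iteration writes a disjoint 16-element slice of c, so the iterations can run in any order or concurrently. The function name run_outer_iteration is hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor

a = [float(i) for i in range(128)]
b = [1.0] * 128
c = [0.0] * 128


def run_outer_iteration(i_0):
    """Body of one iteration of the parallel outer loop."""
    for i_2 in range(4):
        for i_1 in range(4):
            vi = i_0 * 16 + i_1 * 4 + i_2
            # Each i_0 only touches indices in [16 * i_0, 16 * i_0 + 16).
            c[vi] = a[vi] + b[vi]


# Dispatch the 8 outer iterations to a pool of worker threads.
with ThreadPoolExecutor(max_workers=8) as pool:
    list(pool.map(run_outer_iteration, range(8)))
```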

Compile TVMScript to Function

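Once you are satisfied with the transformations, you can compile the module. tvm.build takes an IRModule (the original MyModule, as below, or the transformed sch.mod) and returns a runtime module, from which the compiled function can be retrieved by name.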

rt_mod = tvm.build(MyModule, target="llvm")  # The module for CPU backends.
print(type(rt_mod)) # <class 'tvm.driver.build_module.OperatorModule'>
func = rt_mod["main"]
func # <tvm.runtime.packed_func.PackedFunc at 0x7fdf00501e40>

a = tvm.nd.array(np.arange(128, dtype="float32"))
b = tvm.nd.array(np.ones(128, dtype="float32"))
c = tvm.nd.empty((128,), dtype="float32")

func(a, b, c)
print(a)
print(b)
print(c) # c now holds a + b, i.e. the values 1.0 through 128.0

Programmatically Construct TensorIR

So far we have transformed an IRModule that was written directly in TVMScript. An IRModule can also be constructed programmatically, using the tensor expression (te) API:

# namespace for tensor expression utility
from tvm import te

# declare the computation using the expression API
A = te.placeholder((128, ), name="A")
B = te.placeholder((128, ), name="B")
C = te.compute((128,), lambda i: A[i] + B[i], name="C")

# create a function with the specified list of arguments.
func = te.create_prim_func([A, B, C])
# mark that the function name is main
func = func.with_attr("global_symbol", "main")
ir_mod_from_te = IRModule({"main": func})

ir_mod_from_te.show()

The above ir_mod_from_te.show() will print out:

@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[128, "float32"], B: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0 in T.serial(128):
            with T.block("C"):
                i = T.axis.spatial(128, i0)
                T.reads(A[i], B[i])
                T.writes(C[i])
                C[i] = A[i] + B[i]
