Lecture 003

Python Low-level vs TVMScript

The task of MLC is to transform from the left to the right

Say we want to compute mm_relu: a matrix multiplication followed by a ReLU, i.e. C = max(A @ B, 0).

We can write the following program:

import numpy as np

dtype = "float32"
a_np = np.random.rand(128, 128).astype(dtype)
b_np = np.random.rand(128, 128).astype(dtype)
# a @ b is equivalent to np.matmul(a, b)
c_mm_relu = np.maximum(a_np @ b_np, 0)

The function mm_relu is implemented in a language called TVMScript, a domain-specific dialect embedded in the Python AST.
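
For reference, here is a sketch of a low-level NumPy version with explicit loops (the lnumpy_mm_relu referenced in the comments below); it spells out the loop structure that the TensorIR describes:

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")  # intermediate buffer for the matmul result
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):  # apply ReLU elementwise
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)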

TensorIR vs Typical Implementation

Below, we annotate the TVMScript with its Python/NumPy equivalents as comments:

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class MyModule:
    @T.prim_func # to achieve type(MyModule['mm_relu']) => tvm.tir.function.PrimFunc
    def mm_relu(A: T.Buffer[(128, 128), "float32"],
                B: T.Buffer[(128, 128), "float32"],
                C: T.Buffer[(128, 128), "float32"]): # def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True}) # "global_symbol" is the function's handle name; "tir.noalias": True means the buffers do not overlap in memory
        Y = T.alloc_buffer((128, 128), dtype="float32") # Y = np.empty((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128): # for i in range(128): for j in range(128): for k in range(128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k) # vi, vj, vk = i, j, k
                with T.init(): # if vk == 0:
                    Y[vi, vj] = T.float32(0) # Y[vi, vj] = 0
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
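
A quick sanity check (a sketch, assuming MyModule is defined as above) shows what the decorators produce: an IRModule containing a PrimFunc.

print(type(MyModule))             # tvm.ir.module.IRModule
print(type(MyModule["mm_relu"]))  # tvm.tir.function.PrimFunc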

The following block axis bindings (in the format [block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])) specify more information than the plain Python code: each block axis has an explicit range and a type (spatial or reduce), which TVM uses to reason about legal transformations:

vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)

# above code is equivalent to
vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # SSR = "spatial", "spatial", "reduce"

Spatial and Reduce Axes

In order to get a specific scalar of Y, say Y[0, 1], we need to set vi=0 and vj=1, while the left-over vk has to loop over all of its values. This is why vi and vj are called spatial axes and vk is called a reduce axis.

Observe that we can calculate Y[0, 1] in parallel with Y[1, 2], since vi and vj are spatial dimensions.
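
As a minimal NumPy sketch of this idea (reusing a_np and b_np from above), fixing the spatial coordinates selects one output element, while the reduce axis must be looped over in full:

vi, vj = 0, 1  # fix the spatial axes to pick the element Y[0, 1]
y_01 = 0.0
for vk in range(128):  # the reduce axis covers its full range
    y_01 += a_np[vi, vk] * b_np[vk, vj]
np.testing.assert_allclose(y_01, (a_np @ b_np)[0, 1], rtol=1e-5)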

More Transforms

Split a Loop Along a Dimension

We want to split the j loop into two nested loops (j0 of size 32 and j1 of size 4), to generate code that looks like the following:

# The order of iterations changes slightly
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j0 in range(32): # notice here
            for k in range(128):
                for j1 in range(4): # and here
                    j = j0 * 4 + j1 # j is split from 128 = 32*4
                    if k == 0:
                        Y[i, j] = 0
                    Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)

c_np = np.empty((128, 128), dtype=dtype)
lnumpy_mm_relu_v2(a_np, b_np, c_np)
np.testing.assert_allclose(c_mm_relu, c_np, rtol=1e-5)

To do so, we create a schedule with MyModule as input: sch = tvm.tir.Schedule(MyModule)

Then we perform the following operations to obtain a reference to block Y and corresponding loops:

block_Y = sch.get_block("Y", func_name="mm_relu")
i, j, k = sch.get_loops(block_Y)
j0, j1 = sch.split(j, factors=[None, 4])

If you now do IPython.display.Code(sch.mod.script(), language="python"), you will see the generated code:

@tvm.script.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([128, 128], dtype="float32")
        for i, j_0, j_1, k in T.grid(128, 32, 4, 128): # notice j_0, j_1 appear instead of j
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j_0 * 4 + j_1) # the binding expression changes, but vj is still the same spatial axis; the split does not destroy information from the original version
                vk = T.axis.reduce(128, k)
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(Y[vi, vj])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Y[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

Merge Loops

We can also move block C under the outer loops of block Y, so that the two blocks (Y and C) share the i and j_0 loops, using:

block_C = sch.get_block("C", "mm_relu")
sch.reverse_compute_at(block_C, j0)
IPython.display.Code(sch.mod.script(), language="python")

and we get:

@tvm.script.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([128, 128], dtype="float32")
        for i, j_0 in T.grid(128, 32):
            for k, j_1 in T.grid(128, 4):
                with T.block("Y"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(A[vi, vk], B[vk, vj])
                    T.writes(Y[vi, vj])
                    with T.init():
                        Y[vi, vj] = T.float32(0)
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in T.serial(4):
                with T.block("C"): # block C is now inside the i and j_0 loops, right after block Y
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + ax0)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

Build and Run

We can compile and run the untransformed version:

rt_lib = tvm.build(MyModule, target="llvm") # compile the untransformed version to target
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype="float32")
type(c_nd) # tvm.runtime.ndarray.NDArray

func_mm_relu = rt_lib["mm_relu"] # get function by name
func_mm_relu(a_nd, b_nd, c_nd) # run it

np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5) # test correctness

As well as the transformed version:

rt_lib_after = tvm.build(sch.mod, target="llvm") # compile the transformed version to the same target
rt_lib_after["mm_relu"](a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5)

We can run a benchmark: the transformed version performs significantly better, largely because the reordered loops have better data locality (fewer cache misses).

f_timer_before = rt_lib.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of MyModule %g sec" % f_timer_before(a_nd, b_nd, c_nd).mean)
f_timer_after = rt_lib_after.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of transformed sch.mod %g sec" % f_timer_after(a_nd, b_nd, c_nd).mean)

# Time cost of MyModule             0.0061771 sec
# Time cost of transformed sch.mod  0.00197015 sec

Obtain TensorIR

There are two ways to obtain TensorIR: write TVMScript directly (as we did with MyModule above), or generate it automatically from a tensor expression (te).

Here is an example of automatically generating TensorIR from tensor expressions:

import tvm
from tvm import te

A = te.placeholder((128, 128), "float32", name="A")
B = te.placeholder((128, 128), "float32", name="B")
k = te.reduce_axis((0, 128), "k")
Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C")

# two inputs A, B. one output C.
te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
MyModuleFromTE = tvm.IRModule({"mm_relu": te_func})
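
As a quick check (a sketch reusing the arrays from the build-and-run step above), the generated module can be inspected and compiled just like the hand-written one:

print(MyModuleFromTE.script())  # inspect the generated TensorIR
rt_lib_te = tvm.build(MyModuleFromTE, target="llvm")
rt_lib_te["mm_relu"](a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5)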

Development Cycle

Traditional Development Cycle: engineers manually create different versions of the same function.

TensorIR Development Cycle: engineers create TensorIR, transform the TensorIR automatically, and compile it down to the target language.
