We start with a NumPy reference program for mm_relu (a matrix multiplication followed by a ReLU):
dtype = "float32"
a_np = np.random.rand(128, 128).astype(dtype)
b_np = np.random.rand(128, 128).astype(dtype)
# a @ b is equivalent to np.matmul(a, b)
c_mm_relu = np.maximum(a_np @ b_np, 0)
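Equivalently, mm_relu can be written in low-level NumPy with explicit loops; this loop form is what the TensorIR code below mirrors, and it is the function the comments refer to as lnumpy_mm_relu (a sketch, following those comments):
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)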
The function mm_relu is implemented in TVMScript, a domain-specific dialect embedded in the Python AST. The TVMScript below is annotated with its Python equivalents in comments:
@tvm.script.ir_module
class MyModule:
    @T.prim_func  # so that type(MyModule["mm_relu"]) is tvm.tir.function.PrimFunc
    def mm_relu(A: T.Buffer[(128, 128), "float32"],
                B: T.Buffer[(128, 128), "float32"],
                C: T.Buffer[(128, 128), "float32"]):  # def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
        # "mm_relu" is the function handle name; "tir.noalias": True means the buffers do not overlap in memory
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer((128, 128), dtype="float32")   # Y = np.empty((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):             # for i in range(128): for j in range(128): for k in range(128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)                 # vi, vj, vk = i, j, k
                with T.init():                             # if vk == 0:
                    Y[vi, vj] = T.float32(0)               #     Y[vi, vj] = 0
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
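We can print the module back out to verify it round-trips, using the same printer the document uses later for sch.mod (a quick check, assuming the standard IRModule.script() method):
IPython.display.Code(MyModule.script(), language="python")  # pretty-print the TVMScript
type(MyModule)             # tvm.ir.module.IRModule
type(MyModule["mm_relu"])  # tvm.tir.function.PrimFunc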
The block axis bindings below (in the format [block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])) carry more information than the plain Python code:
They define what vi, vj, vk are bound to (in this case i, j, k).
They declare the original ranges that vi, vj, vk iterate over (the 128 in T.axis.spatial(128, i)).
They declare the properties of the iterators (spatial, reduce).
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
# above code is equivalent to
vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # SSR = "spatial", "spatial", "reduce"
To obtain a specific scalar of Y, say Y[0, 1], we set vi=0 and vj=1, while the left-over vk has to loop over all of its values. This is why vi and vj are called spatial axes and vk is called a reduce axis. Observe that we can compute Y[0, 1] in parallel with Y[1, 2], since vi and vj are spatial dimensions.
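A minimal NumPy sketch of this point: computing a single element of Y fixes the spatial indices but still loops over the whole reduce axis, while different spatial positions are independent of each other:
# Y[0, 1]: vi = 0 and vj = 1 are fixed, k must run over its full range
acc = np.float32(0)
for k in range(128):
    acc += a_np[0, k] * b_np[k, 1]
# acc matches (a_np @ b_np)[0, 1] up to float rounding; Y[1, 2] could be computed the same way in parallel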
We now want to transform the loop structure (splitting the j loop) to generate code that looks like the following:
# The order of iterations changes slightly
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j0 in range(32):          # notice here
            for k in range(128):
                for j1 in range(4):   # and here
                    j = j0 * 4 + j1   # j is split: 128 = 32 * 4
                    if k == 0:
                        Y[i, j] = 0
                    Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)
c_np = np.empty((128, 128), dtype=dtype)
lnumpy_mm_relu_v2(a_np, b_np, c_np)
np.testing.assert_allclose(c_mm_relu, c_np, rtol=1e-5)
To do so, we create a schedule with our MyModule as input: sch = tvm.tir.Schedule(MyModule). Then we perform the following operations to obtain references to block Y and its corresponding loops, and split the j loop:
block_Y = sch.get_block("Y", func_name="mm_relu")
i, j, k = sch.get_loops(block_Y)
j0, j1 = sch.split(j, factors=[None, 4])
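Passing factors=[None, 4] lets the schedule infer the outer extent from the loop's total extent of 128; the explicit equivalent would look like the sketch below (shown for illustration only, not an extra step to run):
# 128 = 32 * 4, so factors=[None, 4] infers the outer factor 32
# j0, j1 = sch.split(j, factors=[32, 4])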
If you now run IPython.display.Code(sch.mod.script(), language="python"), you will see the generated code:
@tvm.script.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([128, 128], dtype="float32")
        for i, j_0, j_1, k in T.grid(128, 32, 4, 128):  # notice j_0, j_1 appear instead of j
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j_0 * 4 + j_1)  # calculation changes, but vj is the same. This does not destroy info of our original version
                vk = T.axis.reduce(128, k)
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(Y[vi, vj])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Y[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
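Note that lnumpy_mm_relu_v2 iterates k before j1, while the listing above still has j_1 before k. The next listing shows k ahead of j_1, which implies a loop reorder was applied in between; a minimal sketch of that step using the schedule's reorder primitive:
sch.reorder(j0, k, j1)  # move k outside j1 so the loop nest matches lnumpy_mm_relu_v2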
We can also move block C so that it is computed together with block Y under the same outer loops (i and j_0), using reverse_compute_at:
block_C = sch.get_block("C", "mm_relu")
sch.reverse_compute_at(block_C, j0)
Printing the module again with IPython.display.Code(sch.mod.script(), language="python"), we get:
@tvm.script.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([128, 128], dtype="float32")
        for i, j_0 in T.grid(128, 32):
            for k, j_1 in T.grid(128, 4):
                with T.block("Y"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(A[vi, vk], B[vk, vj])
                    T.writes(Y[vi, vj])
                    with T.init():
                        Y[vi, vj] = T.float32(0)
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in T.serial(4):
                with T.block("C"):  # block C is now nested inside the i and j_0 loops, right after block Y
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + ax0)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
We can compile and run the untransformed version:
rt_lib = tvm.build(MyModule, target="llvm") # compile the untransformed version to target
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype="float32")
type(c_nd) # tvm.runtime.ndarray.NDArray
func_mm_relu = rt_lib["mm_relu"] # get function by name
func_mm_relu(a_nd, b_nd, c_nd) # run it
np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5) # test correctness
As well as the transformed version:
rt_lib_after = tvm.build(sch.mod, target="llvm") # compile the transformed version to the same target
rt_lib_after["mm_relu"](a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5)
We can run a benchmark: the transformed version performs significantly better, largely due to the better cache behavior of the split loop:
f_timer_before = rt_lib.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of MyModule %g sec" % f_timer_before(a_nd, b_nd, c_nd).mean)
f_timer_after = rt_lib_after.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of transformed sch.mod %g sec" % f_timer_after(a_nd, b_nd, c_nd).mean)
# Time cost of MyModule 0.0061771 sec
# Time cost of transformed sch.mod 0.00197015 sec
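For more stable measurements, time_evaluator also accepts repetition arguments (a sketch assuming its standard number/repeat parameters):
f_timer = rt_lib_after.time_evaluator("mm_relu", tvm.cpu(), number=10, repeat=3)
report = f_timer(a_nd, b_nd, c_nd)  # returns a benchmark result with aggregate statistics
print("mean %g s, std %g s" % (report.mean, report.std))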
There are a few ways to obtain TensorIR:
write TensorIR via TVMScript directly
automatically generate TensorIR using the high-level Tensor Expression (te) API
obtain it by transforming some other TensorIR (as we did with the schedule above)
Here is an example of automatically generating TensorIR with te:
from tvm import te
A = te.placeholder((128, 128), "float32", name="A")
B = te.placeholder((128, 128), "float32", name="B")
k = te.reduce_axis((0, 128), "k")
Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C")
# two inputs A, B. one output C.
te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
MyModuleFromTE = tvm.IRModule({"mm_relu": te_func})
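The resulting module can be inspected and compiled exactly like the hand-written one, reusing the arrays and the c_mm_relu reference from above (the variable name rt_lib_te is ours):
IPython.display.Code(MyModuleFromTE.script(), language="python")  # print the generated TensorIR
rt_lib_te = tvm.build(MyModuleFromTE, target="llvm")
rt_lib_te["mm_relu"](a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5)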