# Lecture 006

## Tensor Expression - Creating Primitive Tensor Function from High Level

Tensor Expression (te) is a domain-specific language for describing tensor computations at a high level. From these descriptions we can generate TensorIR functions.

We first create placeholder (input) nodes A and B of the computational graph:

# Declare two symbolic input tensors (shape (128, 128), dtype float32).
# They are graph inputs only; the tensor expressions below are built on them.
A = te.placeholder((128, 128), name="A", dtype="float32")
B = te.placeholder((128, 128), name="B", dtype="float32")
type(A)
# tvm.te.tensor.Tensor
A.shape
# [128, 128]


We can then build high-level tensor expressions on top of them, for example a matrix multiplication:

def te_matmul(A: te.Tensor, B: te.Tensor) -> te.Tensor:
    """Describe the matrix product ``A @ B`` as a tensor expression.

    Computes ``matmul[i, j] = sum_k A[i, k] * B[k, j]``. The summation axis
    ``k`` is declared with ``te.reduce_axis`` so that ``te.sum`` knows which
    axis to accumulate over.

    Args:
        A: tensor of shape (n, K).
        B: tensor of shape (K, m).

    Returns:
        A tensor expression of shape (n, m) named "matmul".
    """
    assert A.shape[1] == B.shape[0]
    n = A.shape[0]
    m = B.shape[1]
    k = te.reduce_axis((0, A.shape[1]), name="k")
    return te.compute((n, m), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="matmul")


def te_relu(A: te.Tensor) -> te.Tensor:
    """Describe the elementwise ReLU ``max(A, 0)`` as a tensor expression.

    Works for any rank: ``*i`` receives one index per dimension of ``A``.
    The result has the same shape as ``A`` and is named "relu".
    """
    return te.compute(A.shape, lambda *i: te.max(A(*i), 0), name="relu")


We can then turn the tensor expressions into a TensorIR function:

C = te_matmul(A, B)
D = te_relu(C)
# Only the inputs (A, B) and the final output (D) become function parameters.
# The intermediate C is allocated inside the generated function.
te.create_prim_func([A, B, D]).show()
# You could call te.create_prim_func([A, B, C, D]).show(), but then the generated
# function expects the caller to pass a buffer for the temporary value C. In
# general we do not want to pass intermediate values in, so that they remain
# available for optimization.


This prints the following TensorIR function:

# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], relu: T.Buffer[(128, 128), "float32"]) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    # Buffer for the intermediate C; it is local, not a function parameter.
    matmul = T.alloc_buffer([128, 128], dtype="float32")
    for i0, i1, i2 in T.grid(128, 128, 128):
        with T.block("matmul"):
            # "SSR": i and j are spatial axes; k is the reduction axis.
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            T.writes(matmul[i, j])
            with T.init():
                matmul[i, j] = T.float32(0)
            matmul[i, j] = matmul[i, j] + A[i, k] * B[k, j]
    for i0, i1 in T.grid(128, 128):
        with T.block("relu"):
            i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
            T.writes(relu[i0_1, i1_1])
            relu[i0_1, i1_1] = T.max(matmul[i0_1, i1_1], T.float32(0))


## Relax Block Builder - Creating Computational Graph from High Level

We first create the input variables of the Relax function:

# Relax function inputs: rank-2 float32 tensors with shape (128, 128).
A = relax.Var("A", (128, 128), relax.DynTensorType(2, "float32"))
B = relax.Var("B", (128, 128), relax.DynTensorType(2, "float32"))


We then write a high-level description of the algorithm using the BlockBuilder:

bb = relax.BlockBuilder()

with bb.function("main"):
    with bb.dataflow():
        # Each emit_te adds a TensorIR function to the module and emits a
        # call_tir to it inside this dataflow block.
        C = bb.emit_te(te_matmul, A, B)
        D = bb.emit_te(te_relu, C)
        # emit_output exposes D as the value that leaves the dataflow block.
        R = bb.emit_output(D)
    # params are the input params; together with R this acts like a function
    # specification (a function header that tells you the inputs and output).
    bb.emit_func_output(R, params=[A, B])

MyModule = bb.get()
MyModule.show()


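This generates an IRModule containing one TensorIR function per `emit_te` call, plus the Relax `main` function that calls them: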

@tvm.script.ir_module
class Module:
    # One prim_func per bb.emit_te call; the TE placeholders become the
    # rxplaceholder* buffer parameters.
    @T.prim_func
    def te_matmul(rxplaceholder: T.Buffer[(128, 128), "float32"], rxplaceholder_1: T.Buffer[(128, 128), "float32"], matmul: T.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "te_matmul", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2 in T.grid(128, 128, 128):
            with T.block("matmul"):
                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
                T.writes(matmul[i, j])
                with T.init():
                    matmul[i, j] = T.float32(0)
                matmul[i, j] = matmul[i, j] + rxplaceholder[i, k] * rxplaceholder_1[k, j]

    @T.prim_func
    def te_relu(rxplaceholder: T.Buffer[(128, 128), "float32"], relu: T.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "te_relu", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1 in T.grid(128, 128):
            with T.block("relu"):
                i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
                T.writes(relu[i0_1, i1_1])
                relu[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float32(0))

    @R.function
    def main(A: Tensor((128, 128), "float32"), B: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            # call_tir(func, inputs, output_shape, dtype): only the inputs are
            # passed; the output buffer (last prim_func parameter) comes from
            # the given shape/dtype rather than from the caller.
            lv = R.call_tir(te_matmul, (A, B), (128, 128), dtype="float32")
            lv1 = R.call_tir(te_relu, (lv,), (128, 128), dtype="float32")
            gv: Tensor((128, 128), "float32") = lv1
            R.output(gv)
        return gv


When you write `C = bb.emit_te(te_matmul, A, B)`, under the hood the block builder:

1. Creates an input te.placeholder for each of A and B.
2. Runs them through the te_matmul function.
3. Calls te.create_prim_func to create a TensorIR function and adds it to the module.
4. Generates a call to that function via call_tir.

## PyTorch Interface

TorchFX is a graph-capture tool for PyTorch: it traces a model into a computational graph. The captured graph cannot contain control flow, only dataflow. You can still trace the parts of a model that are pure dataflow graphs and stitch the control flow together around them.

### Simple PyTorch Interface

# NOTE(review): MyModel is defined outside this excerpt. Judging from the
# traced graph, its forward computes relu(matmul(x, weight)).
model = MyModel()
# Trace the model into a torch.fx GraphModule whose .graph lists every op.
fx_module = fx.symbolic_trace(model)
type(fx_module)
# torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl
fx_module.graph.print_tabular()


`print_tabular()` prints the following graph:

opcode         name    target                                                     args         kwargs
-------------  ------  ---------------------------------------------------------  -----------  --------
placeholder    x       x                                                          ()           {}
get_attr       weight  weight                                                     ()           {}
call_function  matmul  <built-in method matmul of type object at 0x7efeb81d5980>  (x, weight)  {}
call_function  relu    <built-in method relu of type object at 0x7efeb81d5980>    (matmul,)    {}
output         output  output                                                     (relu,)      {}


We now want to map each node in the TorchFX graph to a corresponding node in Relax. The following code translates a TorchFX graph directly into an IRModule: a Relax `main` function plus the TensorIR functions it calls.

def map_param(param: nn.Parameter):
    """Turn a PyTorch parameter into a Relax constant.

    The parameter's data is copied to host memory as a NumPy array. The
    constant is typed as a float32 tensor with the same rank as the parameter.
    NOTE(review): the dtype is always declared float32, which presumes float32
    parameters; confirm for other dtypes.
    """
    rank = param.data.dim()
    host_array = param.data.cpu().numpy()
    tensor_type = relax.DynTensorType(rank, "float32")
    return relax.const(host_array, tensor_type)

def fetch_attr(fx_mod, target: str):
    """Resolve a dotted attribute path such as ``"layer1.weight"`` on ``fx_mod``.

    Args:
        fx_mod: object to start from (the traced ``GraphModule`` in ``from_fx``).
        target: dot-separated attribute path, as stored in an FX ``get_attr`` node.

    Returns:
        The object found at the end of the path.

    Raises:
        RuntimeError: if some component of the path does not exist. The message
            names the path up to and including the missing component.
    """
    target_atoms = target.split(".")
    attr_itr = fx_mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Include the failing component (i + 1) so the message is never empty.
            missing = ".".join(target_atoms[: i + 1])
            raise RuntimeError(f"Node referenced nonexistent target {missing}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr

def from_fx(fx_mod, input_shapes, call_function_map, call_module_map):
    """Translate a traced TorchFX ``GraphModule`` into a Relax ``IRModule``.

    Visits the FX graph nodes in order and maps each one to a Relax value
    inside a single dataflow block of a function named ``main``.

    Args:
        fx_mod: the ``GraphModule`` produced by ``fx.symbolic_trace``.
        input_shapes: one shape per ``placeholder`` node, in graph order.
        call_function_map: maps a ``call_function`` target (e.g. ``torch.matmul``)
            to a translator ``f(bb, node_map, node)``.
        call_module_map: maps a module type (e.g. ``torch.nn.Linear``) to a
            translator ``f(bb, node_map, node, module)``.

    Returns:
        The ``IRModule`` built by the ``BlockBuilder``.

    Raises:
        NotImplementedError: if the graph contains a node op that is not
            handled here (e.g. ``call_method``).
        ValueError: if the graph does not have exactly one ``output`` node.
    """
    input_index = 0
    node_map = {}
    named_modules = dict(fx_mod.named_modules())

    bb = relax.BlockBuilder()

    fn_inputs = []
    fn_output = None
    with bb.function("main"):
        with bb.dataflow():
            for node in fx_mod.graph.nodes:
                if node.op == "placeholder":
                    # create input placeholder
                    shape = input_shapes[input_index]
                    input_index += 1
                    input_var = relax.Var(
                        node.target, shape, relax.DynTensorType(len(shape), "float32")
                    )
                    fn_inputs.append(input_var)
                    node_map[node] = input_var
                elif node.op == "get_attr":
                    node_map[node] = map_param(fetch_attr(fx_mod, node.target))
                elif node.op == "call_function":
                    node_map[node] = call_function_map[node.target](bb, node_map, node)
                elif node.op == "call_module":
                    named_module = named_modules[node.target]
                    node_map[node] = call_module_map[type(named_module)](
                        bb, node_map, node, named_module
                    )
                elif node.op == "output":
                    if fn_output is not None:
                        raise ValueError("FX graph has more than one output node")
                    # emit_output makes the value visible outside the dataflow block.
                    fn_output = bb.emit_output(node_map[node.args[0]])
                else:
                    # Fail loudly instead of silently dropping the node; a
                    # later lookup of it would otherwise fail with a KeyError.
                    raise NotImplementedError(f"Unsupported FX node op: {node.op!r}")
        if fn_output is None:
            raise ValueError("FX graph has no output node")
        # output and finalize the function: return the var created by
        # emit_output, not the dataflow var that is local to the block.
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()


We then supply a translation for each primitive tensor function and run the conversion:

# Translation for the primitive `torch.matmul` function.
def map_matmul(bb, node_map, node: fx.Node):
    """Lower an FX ``torch.matmul`` node to ``te_matmul`` via ``bb.emit_te``.

    ``node.args[0]`` and ``node.args[1]`` are the operand nodes. ``node_map``
    gives the Relax value already emitted for each of them.
    """
    operands = [node_map[arg] for arg in (node.args[0], node.args[1])]
    return bb.emit_te(te_matmul, *operands)

# Translation for the primitive `torch.relu` function.
def map_relu(bb, node_map, node: fx.Node):
    """Lower a ReLU node to ``te_relu`` via ``bb.emit_te``, using only its first argument."""
    return bb.emit_te(te_relu, node_map[node.args[0]])

# Run the translation. The model's single input has shape (1, 128), and its
# graph only contains call_function nodes, so no module translators are needed.
MyModule = from_fx(
    fx_module,
    input_shapes=[(1, 128)],
    call_function_map={
        torch.matmul: map_matmul,
        torch.relu: map_relu,
    },
    call_module_map={},
)

MyModule.show()  # prints the resulting IRModule (Relax main + TensorIR functions)


### Real PyTorch Interface

To translate real-world `nn.Module`s (here `nn.Linear` and `nn.ReLU`), we register translators for module types through `call_module_map`:

from tvm import topi

def map_nn_linear(bb, node_map, node, nn_mod):  # nn_mod is an nn.Linear
    """Translate an ``nn.Linear`` call_module node into Relax.

    Emits ``topi.nn.dense(x, weight)`` and, if the layer has a bias, a
    following ``topi.add`` (topi is TVM's predefined Tensor Expression operator
    library). Layers built with ``bias=False`` (``nn_mod.bias is None``) emit
    only the dense call.

    Returns:
        The Relax value produced by the last emitted call.
    """
    x = node_map[node.args[0]]  # Relax value of the layer's input
    w = map_param(nn_mod.weight)  # weight as a Relax constant
    b = map_param(nn_mod.bias) if nn_mod.bias is not None else None
    y = bb.emit_te(topi.nn.dense, x, w)
    if b is None:
        return y
    return bb.emit_te(topi.add, y, b)

def map_nn_relu(bb, node_map, node, nn_mod):
    """Translate an ``nn.ReLU`` call_module node by reusing the function-level ReLU translator."""
    del nn_mod  # unused: the translation only needs the node's input
    return map_relu(bb, node_map, node)

# NOTE(review): mlp_model is defined outside this excerpt. Its layers are
# nn.Linear / nn.ReLU modules, so only module translators are registered.
MLPModule = from_fx(
    fx.symbolic_trace(mlp_model),
    input_shapes=[(1, 784)],
    call_function_map={},
    call_module_map={
        torch.nn.Linear: map_nn_linear,
        torch.nn.ReLU: map_nn_relu,
    },
)

MLPModule.show()


In summary, we translate a high-level PyTorch model into high-level Tensor Expressions, organized into a computational graph by Relax. We then lower the Tensor Expressions to TensorIR and do the optimizations there.

Table of Contents