Lecture 006

Tensor Expression - Creating Primitive Tensor Function from High Level

Tensor Expression (te): a domain-specific language for building TensorIR functions.

We can create placeholder nodes A and B that stand for the inputs of the computational graph.

# Two symbolic 128x128 float32 inputs, named "A" and "B", that stand in for real data.
A, B = (te.placeholder((128, 128), name=label, dtype="float32") for label in ("A", "B"))
# type(A) -> tvm.te.tensor.Tensor
# A.shape -> [128, 128]

We can then describe matrix multiplication as a high-level tensor expression:

def te_matmul(A: te.Tensor, B: te.Tensor) -> te.Tensor:
    """Describe ``matmul[i, j] = sum_k A[i, k] * B[k, j]`` as a tensor expression.

    Args:
        A: 2-D tensor of shape (n, k).
        B: 2-D tensor of shape (k, m); its first dimension must equal A's second.

    Returns:
        An (n, m) te.Tensor named "matmul", computed by a sum-reduction over k.
    """
    inner_extent = A.shape[1]
    assert inner_extent == B.shape[0]
    out_shape = (A.shape[0], B.shape[1])
    # Reduction axis; it appears as "k" in the generated IR.
    red_k = te.reduce_axis((0, inner_extent), name="k")
    # te.compute takes the output axis names from the lambda's parameter names,
    # so `i` and `j` are deliberately kept.
    return te.compute(
        out_shape,
        lambda i, j: te.sum(A[i, red_k] * B[red_k, j], axis=red_k),
        name="matmul",
    )

We can compile tensor expressions into a TensorIR function:

C = te_matmul(A, B)
D = te_relu(C)
# Only the inputs A, B and the final result D become PrimFunc parameters, so the
# intermediate C is turned into a buffer allocated inside the function.
# te.create_prim_func([A, B, C, D]) would also work, but callers would then have
# to pass a buffer for the temporary C, which we usually want to keep internal
# so it can be optimized.
prim_func = te.create_prim_func([A, B, D])
prim_func.show()

It generates the following TensorIR function:

# from tvm.script import tir as T
# Printed TensorIR for te.create_prim_func([A, B, D]).
# Parameters: inputs A and B plus the output buffer `relu`. The intermediate
# matmul result is not a parameter; it is allocated inside the body instead.
# NOTE(review): the @T.prim_func decorator seems to be omitted from this printout.
def func(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], relu: T.Buffer[(128, 128), "float32"]) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    # Internal buffer holding C = A @ B (the te_matmul result).
    matmul = T.alloc_buffer([128, 128], dtype="float32")
    # Stage 1: matmul[i, j] = sum over k of A[i, k] * B[k, j].
    for i0, i1, i2 in T.grid(128, 128, 128):
        with T.block("matmul"):
            # "SSR": i and j are spatial block axes, k is the reduction axis.
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            T.reads(A[i, k], B[k, j])
            T.writes(matmul[i, j])
            # Initial value of the reduction, before accumulating over k.
            with T.init():
                matmul[i, j] = T.float32(0)
            matmul[i, j] = matmul[i, j] + A[i, k] * B[k, j]
    # Stage 2: elementwise relu = max(matmul, 0).
    for i0, i1 in T.grid(128, 128):
        with T.block("relu"):
            i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
            T.reads(matmul[i0_1, i1_1])
            T.writes(relu[i0_1, i1_1])
            relu[i0_1, i1_1] = T.max(matmul[i0_1, i1_1], T.float32(0))

Relax Block Builder - Creating Computational Graph from High Level

We first create the input variables:

# Relax-level inputs "A" and "B": rank-2 float32 tensors of shape (128, 128).
A, B = (relax.Var(label, (128, 128), relax.DynTensorType(2, "float32")) for label in ("A", "B"))

Then we write a high-level description of the algorithm:

# Build a Relax function step by step. Each emit_* call below adds a binding,
# so the statement order determines the generated IR.
bb = relax.BlockBuilder()

# Open a Relax function named "main".
with bb.function("main"):
    # Bindings emitted in this scope form a single dataflow block.
    with bb.dataflow():
        # Turn te_matmul into a TensorIR function and bind a call_tir to it (lv).
        C = bb.emit_te(te_matmul, A, B)
        # Same for te_relu, applied to the matmul result (lv1).
        D = bb.emit_te(te_relu, C)
        # Expose D outside the dataflow block (bound to gv in the printed IR).
        # NOTE(review): `R` is commonly the alias for tvm.script.relax; if that alias
        # is imported, this assignment shadows it. Consider renaming.
        R = bb.emit_output(D)
    bb.emit_func_output(R, params=[A, B]) # params lists the function inputs; together with the returned R this fixes the signature (inputs and output)

# The finished IRModule: the generated TensorIR functions plus "main".
MyModule = bb.get()

It generates the following IRModule, which contains two TensorIR functions plus a Relax `main` function:

# Printed IRModule for MyModule: two TensorIR functions generated by emit_te
# (te_matmul, te_relu) plus the Relax function `main` that calls them via call_tir.
# NOTE(review): the @tvm.script.ir_module / @T.prim_func / @R.function decorators
# seem to be missing from this printout. Confirm against the real MyModule.show() output.
class Module:
    # TensorIR function created from te_matmul. rxplaceholder / rxplaceholder_1 are
    # the te.placeholders emit_te created for the Relax arguments A and B (see the
    # call_tir in main). `matmul` is the output buffer the function writes.
    def te_matmul(rxplaceholder: T.Buffer[(128, 128), "float32"], rxplaceholder_1: T.Buffer[(128, 128), "float32"], matmul: T.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "te_matmul", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2 in T.grid(128, 128, 128):
            with T.block("matmul"):
                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
                T.reads(rxplaceholder[i, k], rxplaceholder_1[k, j])
                T.writes(matmul[i, j])
                with T.init():
                    matmul[i, j] = T.float32(0)
                matmul[i, j] = matmul[i, j] + rxplaceholder[i, k] * rxplaceholder_1[k, j]

    # TensorIR function created from te_relu: relu = max(input, 0), elementwise.
    def te_relu(rxplaceholder: T.Buffer[(128, 128), "float32"], relu: T.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "te_relu", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1 in T.grid(128, 128):
            with T.block("relu"):
                i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
                T.reads(rxplaceholder[i0_1, i1_1])
                T.writes(relu[i0_1, i1_1])
                relu[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float32(0))

    # Relax entry point. Each call_tir names the TensorIR function, its argument
    # tuple, and the output shape and dtype the callee writes into.
    def main(A: Tensor((128, 128), "float32"), B: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv = R.call_tir(te_matmul, (A, B), (128, 128), dtype="float32")
            lv1 = R.call_tir(te_relu, (lv,), (128, 128), dtype="float32")
            # gv is the value produced by bb.emit_output, visible outside the block.
            gv: Tensor((128, 128), "float32") = lv1
        return gv

Comparing the Relax BlockBuilder with TensorIR

When you write C = bb.emit_te(te_matmul, A, B), under the hood emit_te:

  1. Creates te.placeholder inputs for A and B.
  2. Runs them through the te_matmul function.
  3. Calls te.create_prim_func to create a TensorIR function.
  4. Generates a call to that function via call_tir.

PyTorch Interface

TorchFX: a graph-capture tool for PyTorch. The captured graph cannot contain control flow, only dataflow. However, you can trace the parts of a model that are pure graphs and stitch the control flow together around them.

Simple PyTorch Interface

model = MyModel()
# Trace the model into a torch.fx GraphModule. Its printed type is
# torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl
fx_module = fx.symbolic_trace(root=model)

Tracing produces the following graph:

opcode         name    target                                                     args         kwargs
-------------  ------  ---------------------------------------------------------  -----------  --------
placeholder    x       x                                                          ()           {}
get_attr       weight  weight                                                     ()           {}
call_function  matmul  <built-in method matmul of type object at 0x7efeb81d5980>  (x, weight)  {}
call_function  relu    <built-in method relu of type object at 0x7efeb81d5980>    (matmul,)    {}
output         output  output                                                     (relu,)      {}

We now want to map each node in the TorchFX graph to a corresponding node in Relax. The following code translates a TorchFX graph directly into an IRModule whose Relax `main` function calls TensorIR functions:

def map_param(param: nn.Parameter):
    """Convert a PyTorch parameter into a Relax constant.

    The parameter's data is moved to CPU, converted to a NumPy array, and wrapped
    in ``relax.const`` with a DynTensorType of the parameter's rank.

    NOTE(review): the dtype is always declared as "float32", so this assumes
    float32 parameters. TODO confirm for other dtypes.
    """
    ndim = len(param.data.shape)
    return relax.const(
        param.data.cpu().numpy(), relax.DynTensorType(ndim, "float32")
    )

def fetch_attr(fx_mod, target: str):
    """Helper function to fetch an attr.

    Resolves a dot-separated attribute path (for example ``"layer.weight"``,
    as stored in the target of an fx ``get_attr`` node) starting at ``fx_mod``.

    Args:
        fx_mod: Object to start the lookup from (typically the traced GraphModule).
        target: Dot-separated attribute names.

    Returns:
        The object found at the end of the path.

    Raises:
        RuntimeError: If some component of the path does not exist. The message
            names the path up to and including the missing component.
    """
    target_atoms = target.split('.')
    attr_itr = fx_mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Slice up to i + 1 so the message includes the missing atom. Slicing
            # [:i] dropped it and gave an empty name when the first atom was missing.
            missing = '.'.join(target_atoms[: i + 1])
            raise RuntimeError(f"Node referenced nonexistant target {missing}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr

def from_fx(fx_mod, input_shapes, call_function_map, call_module_map):
    """Translate a traced torch.fx GraphModule into a Relax IRModule.

    Walks ``fx_mod.graph.nodes`` in order and maps each node to a Relax
    expression inside one dataflow block of a function named ``"main"``.

    Args:
        fx_mod: GraphModule produced by ``fx.symbolic_trace``.
        input_shapes: One shape per ``placeholder`` node, consumed in order. Each
            input becomes a float32 ``relax.Var`` of that shape and a parameter
            of ``main``.
        call_function_map: Maps the target of a ``call_function`` node (e.g.
            ``torch.matmul``) to a translator ``f(bb, node_map, node)``.
        call_module_map: Maps the type of the submodule referenced by a
            ``call_module`` node to a translator ``f(bb, node_map, node, module)``.

    Returns:
        The IRModule built by the BlockBuilder.

    Raises:
        KeyError: If a function target or module type has no translator.
        IndexError: If there are more placeholder nodes than ``input_shapes``.
        RuntimeError: If the graph has zero or more than one ``output`` node.
    """
    input_index = 0
    node_map = {}
    named_modules = dict(fx_mod.named_modules())

    bb = relax.BlockBuilder()

    fn_inputs = []
    fn_output = None
    with bb.function("main"):
        with bb.dataflow():
            for node in fx_mod.graph.nodes:
                if node.op == "placeholder":
                    # create input placeholder
                    shape = input_shapes[input_index]
                    input_index += 1
                    input_var = relax.Var(
                        node.target, shape, relax.DynTensorType(len(shape), "float32")
                    )
                    # Record the input so it becomes a parameter of "main".
                    fn_inputs.append(input_var)
                    node_map[node] = input_var
                elif node.op == "get_attr":
                    node_map[node] = map_param(fetch_attr(fx_mod, node.target))
                elif node.op == "call_function":
                    node_map[node] = call_function_map[node.target](bb, node_map, node)
                elif node.op == "call_module":
                    named_module = named_modules[node.target]
                    node_map[node] = call_module_map[type(named_module)](bb, node_map, node, named_module)
                elif node.op == "output":
                    if fn_output is not None:
                        raise RuntimeError("fx graph has more than one output node")
                    output = node_map[node.args[0]]
                    # emit_output makes the value visible outside the dataflow block.
                    fn_output = bb.emit_output(output)
        if fn_output is None:
            raise RuntimeError("fx graph has no output node")
        # Return the emitted output (not the dataflow-local value), as in
        # bb.emit_output / bb.emit_func_output usage elsewhere.
        bb.emit_func_output(fn_output, fn_inputs)
    return bb.get()

We then supply a translation for each primitive tensor function and run the translation:

# specify translation for primitive matmul function
def map_matmul(bb, node_map, node: fx.Node):
    """Translate an fx ``torch.matmul`` node into a call of the ``te_matmul`` TE function.

    The node's first two arguments are looked up in ``node_map`` and passed to
    ``bb.emit_te``. The Relax var it returns is the result.
    """
    lhs_node, rhs_node = node.args[0], node.args[1]
    return bb.emit_te(te_matmul, node_map[lhs_node], node_map[rhs_node])

# specify translation for primitive relu function
def map_relu(bb, node_map, node: fx.Node):
    """Translate an fx ``torch.relu`` node into a call of the ``te_relu`` TE function."""
    return bb.emit_te(te_relu, node_map[node.args[0]])

# do translation
MyModule = from_fx(
    fx_module,
    input_shapes=[(1, 128)],
    call_function_map={
        torch.matmul: map_matmul,
        torch.relu: map_relu,
    },
    # The traced graph has only call_function nodes (matmul, relu), so no
    # module translators are needed.
    call_module_map={},
)

# will print out TensorIR
MyModule.show()

Real PyTorch Interface

To translate a real-world nn.Module, we also map each submodule type (e.g. nn.Linear, nn.ReLU) to a translation function:

from tvm import topi

def map_nn_linear(bb, node_map, node, nn_mod): # nn_mod is nn.Module
    """Translate a ``torch.nn.Linear`` call into Relax ``dense`` (+ ``add``) calls.

    Args:
        bb: The relax.BlockBuilder being populated.
        node_map: Maps already-translated fx nodes to Relax expressions.
        node: The fx ``call_module`` node. Its first argument is the input.
        nn_mod: The ``nn.Linear`` instance the node refers to.

    Returns:
        The Relax var holding ``dense(x, weight)``, plus ``bias`` when the layer has one.
    """
    x = node_map[node.args[0]]  # input already translated to Relax
    w = map_param(nn_mod.weight)  # weight as a Relax constant
    # topi is the predefined Tensor Expression operator library. topi.nn.dense
    # computes x @ w^T, which matches nn.Linear's (out_features, in_features) weight.
    y = bb.emit_te(topi.nn.dense, x, w)
    if nn_mod.bias is None:
        # Linear(bias=False): there is nothing to add. The previous code fell
        # through to use the unbound `b` (UnboundLocalError).
        return y
    b = map_param(nn_mod.bias)  # bias as a Relax constant
    return bb.emit_te(topi.add, y, b)

def map_nn_relu(bb, node_map, node, nn_mod):
    """Translate a ``torch.nn.ReLU`` module call.

    ``nn_mod`` is not consulted. The node is translated exactly like a
    functional ``torch.relu`` call.
    """
    del nn_mod  # unused
    return map_relu(bb, node_map, node)

# NOTE(review): the original snippet omitted the traced module argument.
# `mlp_model` is assumed to be an nn.Module built from nn.Linear / nn.ReLU layers
# and defined elsewhere. TODO confirm the name.
MLPModule = from_fx(
    fx.symbolic_trace(mlp_model),
    input_shapes=[(1, 784)],
    # No call_function translators: the model is expected to use only module layers.
    call_function_map={},
    call_module_map={
        torch.nn.Linear: map_nn_linear,
        torch.nn.ReLU: map_nn_relu,
    },
)


In summary, we translate high-level PyTorch code into a Relax computational graph built from Tensor Expressions. The Tensor Expressions are then lowered to TensorIR, where we perform optimizations.

