Tensor Expression (te): a domain-specific language for building TensorIR functions.
We can make dummy nodes in a computational graph, called A and B:
A = te.placeholder((128, 128), name="A", dtype="float32")
B = te.placeholder((128, 128), name="B", dtype="float32")
type(A)
# tvm.te.tensor.Tensor
A.shape
# [128, 128]
We can build a high-level tensor expression:
def te_matmul(A: te.Tensor, B: te.Tensor) -> te.Tensor:
    assert A.shape[1] == B.shape[0]
    n = A.shape[0]
    m = B.shape[1]
    k = te.reduce_axis((0, A.shape[1]), name="k")
    return te.compute((n, m), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="matmul")
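te_relu is used below but never defined above; a minimal sketch in the same style (an elementwise max against zero) would be:

def te_relu(A: te.Tensor) -> te.Tensor:
    # elementwise max(x, 0), same shape as the input
    return te.compute(A.shape, lambda *i: te.max(A(*i), 0), name="relu")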
We can compile the tensor expressions into a TensorIR function:
C = te_matmul(A, B)
D = te_relu(C)
te.create_prim_func([A, B, D]).show()
# te.create_prim_func([A, B, C, D]).show() also works, but the generated function would expect the caller to pass a buffer for the intermediate C; in general we do not want to pass in intermediate values, for optimization purposes
It generates the following:
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], relu: T.Buffer[(128, 128), "float32"]) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    matmul = T.alloc_buffer([128, 128], dtype="float32")
    for i0, i1, i2 in T.grid(128, 128, 128):
        with T.block("matmul"):
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            T.reads(A[i, k], B[k, j])
            T.writes(matmul[i, j])
            with T.init():
                matmul[i, j] = T.float32(0)
            matmul[i, j] = matmul[i, j] + A[i, k] * B[k, j]
    for i0, i1 in T.grid(128, 128):
        with T.block("relu"):
            i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
            T.reads(matmul[i0_1, i1_1])
            T.writes(relu[i0_1, i1_1])
            relu[i0_1, i1_1] = T.max(matmul[i0_1, i1_1], T.float32(0))
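To sanity-check the result, we can build and run the PrimFunc; a sketch assuming the standard tvm.build / tvm.nd runtime APIs (the entry name "main" comes from the global_symbol attribute above):

import numpy as np
import tvm

# build the generated PrimFunc for CPU
rt_lib = tvm.build(te.create_prim_func([A, B, D]), target="llvm")
a_np = np.random.rand(128, 128).astype("float32")
b_np = np.random.rand(128, 128).astype("float32")
out_nd = tvm.nd.empty((128, 128), dtype="float32")
rt_lib["main"](tvm.nd.array(a_np), tvm.nd.array(b_np), out_nd)
# compare against a NumPy reference of relu(matmul(A, B))
np.testing.assert_allclose(out_nd.numpy(), np.maximum(a_np @ b_np, 0), rtol=1e-5)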
We initialize our input variables:
A = relax.Var("A", (128, 128), relax.DynTensorType(2, "float32"))
B = relax.Var("B", (128, 128), relax.DynTensorType(2, "float32"))
We write a high-level description of the algorithm:
bb = relax.BlockBuilder()
with bb.function("main"):
    with bb.dataflow():
        C = bb.emit_te(te_matmul, A, B)
        D = bb.emit_te(te_relu, C)
        R = bb.emit_output(D)
    bb.emit_func_output(R, params=[A, B])  # params are the function's inputs; this call acts like a function specification (a header that fixes the inputs and output)
MyModule = bb.get()
MyModule.show()
It generates an IRModule containing the TensorIR functions and a Relax main function:
@tvm.script.ir_module
class Module:
    @T.prim_func
    def te_matmul(rxplaceholder: T.Buffer[(128, 128), "float32"], rxplaceholder_1: T.Buffer[(128, 128), "float32"], matmul: T.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "te_matmul", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2 in T.grid(128, 128, 128):
            with T.block("matmul"):
                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
                T.reads(rxplaceholder[i, k], rxplaceholder_1[k, j])
                T.writes(matmul[i, j])
                with T.init():
                    matmul[i, j] = T.float32(0)
                matmul[i, j] = matmul[i, j] + rxplaceholder[i, k] * rxplaceholder_1[k, j]

    @T.prim_func
    def te_relu(rxplaceholder: T.Buffer[(128, 128), "float32"], relu: T.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "te_relu", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1 in T.grid(128, 128):
            with T.block("relu"):
                i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
                T.reads(rxplaceholder[i0_1, i1_1])
                T.writes(relu[i0_1, i1_1])
                relu[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float32(0))

    @R.function
    def main(A: Tensor((128, 128), "float32"), B: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv = R.call_tir(te_matmul, (A, B), (128, 128), dtype="float32")
            lv1 = R.call_tir(te_relu, (lv,), (128, 128), dtype="float32")
            gv: Tensor((128, 128), "float32") = lv1
            R.output(gv)
        return gv
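To execute MyModule end to end, one can build it for the Relax virtual machine. A sketch assuming the relax VM API of the TVM build these notes use (relax.vm.build / relax.VirtualMachine; names may differ in other versions):

import numpy as np
import tvm
from tvm import relax

ex = relax.vm.build(MyModule, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
a_np = np.random.rand(128, 128).astype("float32")
b_np = np.random.rand(128, 128).astype("float32")
res = vm["main"](tvm.nd.array(a_np), tvm.nd.array(b_np))
# reference: relu(matmul(A, B))
np.testing.assert_allclose(res.numpy(), np.maximum(a_np @ b_np, 0), rtol=1e-5)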
When you write C = bb.emit_te(te_matmul, A, B), under the hood, the converter:
- creates te.placeholder nodes for A and B
- runs them through the te_matmul function
- calls te.create_prim_func to create a TensorIR function
- generates a call into that TensorIR function via call_tir

TorchFX: a graph-generation tool for PyTorch. The graph cannot contain control flow, only dataflow; but you can trace the parts of the model that are pure dataflow graphs and stitch the control flow together around them.
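MyModel is not shown in these notes; judging from the printed graph below, it is presumably a matmul-plus-relu module along these lines:

import torch
import torch.nn as nn
from torch import fx

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # shows up as a get_attr node in the traced graph
        self.weight = nn.Parameter(torch.randn(128, 128))

    def forward(self, x):
        x = torch.matmul(x, self.weight)  # call_function: torch.matmul
        x = torch.relu(x)                 # call_function: torch.relu
        return x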
model = MyModel()
fx_module = fx.symbolic_trace(model)
type(fx_module)
# torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl
fx_module.graph.print_tabular()
will generate the following graph
opcode name target args kwargs
------------- ------ --------------------------------------------------------- ----------- --------
placeholder x x () {}
get_attr weight weight () {}
call_function matmul <built-in method matmul of type object at 0x7efeb81d5980> (x, weight) {}
call_function relu <built-in method relu of type object at 0x7efeb81d5980> (matmul,) {}
output output output (relu,) {}
We now want to map each node in the TorchFX graph to a corresponding construct in Relax. The following code directly translates a TorchFX graph into an IRModule:
def map_param(param: nn.Parameter):
    ndim = len(param.data.shape)
    return relax.const(
        param.data.cpu().numpy(), relax.DynTensorType(ndim, "float32")
    )

def fetch_attr(fx_mod, target: str):
    """Helper function to fetch an attr"""
    target_atoms = target.split('.')
    attr_itr = fx_mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
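For instance, on the fx_module traced earlier, fetching the dotted path "weight" walks one getattr step and returns the parameter itself (a hypothetical usage sketch):

param = fetch_attr(fx_module, "weight")  # resolves fx_module.weight
assert isinstance(param, nn.Parameter)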
def from_fx(fx_mod, input_shapes, call_function_map, call_module_map):
    input_index = 0
    node_map = {}
    named_modules = dict(fx_mod.named_modules())

    bb = relax.BlockBuilder()

    fn_inputs = []
    fn_output = None
    with bb.function("main"):
        with bb.dataflow():
            for node in fx_mod.graph.nodes:
                if node.op == "placeholder":
                    # create input placeholder
                    shape = input_shapes[input_index]
                    input_index += 1
                    input_var = relax.Var(
                        node.target, shape, relax.DynTensorType(len(shape), "float32")
                    )
                    fn_inputs.append(input_var)
                    node_map[node] = input_var
                elif node.op == "get_attr":
                    node_map[node] = map_param(fetch_attr(fx_mod, node.target))
                elif node.op == "call_function":
                    node_map[node] = call_function_map[node.target](bb, node_map, node)
                elif node.op == "call_module":
                    named_module = named_modules[node.target]
                    node_map[node] = call_module_map[type(named_module)](bb, node_map, node, named_module)
                elif node.op == "output":
                    output = node_map[node.args[0]]
                    assert fn_output is None
                    fn_output = bb.emit_output(output)
        # output and finalize the function
        bb.emit_func_output(output, fn_inputs)
    return bb.get()
And we run the translation, after supplying a translation rule for each primitive tensor function:
# specify translation for primitive matmul function
def map_matmul(bb, node_map, node: fx.Node):
    A = node_map[node.args[0]]
    B = node_map[node.args[1]]
    return bb.emit_te(te_matmul, A, B)

# specify translation for primitive relu function
def map_relu(bb, node_map, node: fx.Node):
    A = node_map[node.args[0]]
    return bb.emit_te(te_relu, A)

# do translation
MyModule = from_fx(
    fx_module,
    input_shapes = [(1, 128)],
    call_function_map = {
        torch.matmul: map_matmul,
        torch.relu: map_relu,
    },
    call_module_map={},
)

MyModule.show()
# prints the translated IRModule
To translate a real-world nn.Module, we do the following:
from tvm import topi

def map_nn_linear(bb, node_map, node, nn_mod):  # nn_mod is the nn.Module instance
    x = node_map[node.args[0]]    # get the input node
    w = map_param(nn_mod.weight)  # convert the weight to a constant
    # topi is a library of predefined Tensor Expression operators;
    # Linear translates into two primitive functions, "dense" and "add"
    y = bb.emit_te(topi.nn.dense, x, w)
    if nn_mod.bias is not None:   # add the bias only if the layer has one
        b = map_param(nn_mod.bias)
        y = bb.emit_te(topi.add, y, b)
    return y

def map_nn_relu(bb, node_map, node, nn_mod):
    return map_relu(bb, node_map, node)
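mlp_model is assumed here to be a two-layer MLP over flattened 28x28 inputs (matching the (1, 784) input shape below); a sketch:

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(128, 10)

    def forward(self, x):
        return self.linear1(self.relu(self.linear0(x)))

mlp_model = MLP()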
MLPModule = from_fx(
    fx.symbolic_trace(mlp_model),
    input_shapes = [(1, 784)],
    call_function_map={},
    call_module_map={
        torch.nn.Linear: map_nn_linear,
        torch.nn.ReLU: map_nn_relu,
    },
)

MLPModule.show()
In summary, we translate the high-level PyTorch model into a high-level Relax program, using Tensor Expression to emit the primitive functions. The Tensor Expression is then lowered to TensorIR, where we do the optimization.