# Lecture 009

## Computational Graph Optimization

# Toy module computing gv0 = x * y + y. The multiply/add pair below is the
# target of the EwiseFMARewriter fusion pass.
@tvm.script.ir_module
class MyModule:
    @R.function
    def main(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):
        with relax.dataflow():
            lv0 = relax.multiply(x, y)
            # the add consumes lv0; this binding was missing, leaving gv0 undefined
            gv0 = relax.add(lv0, y)
            relax.output(gv0)
        return gv0


In the above image, we can see several layers of an IRModule:

• SeqExpr: a sequence of expressions, which may contain multiple dataflow blocks or control-flow blocks. The most common situation is that it contains only one dataflow block.

• DataflowBlock: computation inside a dataflow block must be pure (side-effect free), and therefore corresponds to a DAG.

Now our goal is to fuse the multiply and the add into a single ewise_fma call by merging their two VarBindings.

One approach to rewrite the program would be to traverse MyModule's Abstract Syntax Tree (AST) recursively and generate a transformed AST.

Visitor pattern: allows us to visit each AST node and rewrite it into a transformed version.

Here is the code we use to do fusion:

# Find the pattern add(multiply(a, b), c) and replace it with ewise_fma(a, b, c).
@relax.expr_functor.mutator
class EwiseFMARewriter(relax.PyExprMutator):
    """Rewrite ``add(multiply(a, b), c)`` into ``ewise_fma(a, b, c)``."""

    def visit_call_(self, call):
        """Visit ``call`` and return its replacement (fused or unchanged)."""
        # Visit children first (post order) so arguments refer to remapped variables.
        call = self.visit_expr_post_order(call)
        add_op = tvm.ir.Op.get("relax.add")
        multiply_op = tvm.ir.Op.get("relax.multiply")
        ewise_fma_op = tvm.ir.Op.get("relax.ewise_fma")

        # Our function looks like
        #   lv0 = multiply(x, y)
        #   gv0 = add(lv0, y)
        # so the node being visited must be the add.
        if call.op != add_op:
            return call

        # value is non-None if the first argument of the add is a variable bound
        # earlier in the block (i.e. a computed value, not a constant).
        value = self.lookup_binding(call.args[0])
        if not isinstance(value, relax.Call) or value.op != multiply_op:
            return call  # pattern matching unsuccessful

        # Construct the new call from multiply's operands and add's second operand.
        fma_call = relax.Call(
            ewise_fma_op, [value.args[0], value.args[1], call.args[1]], None, None
        )
        return fma_call  # replace old call with new call


# In the example above, only gv0 = relax.add(lv0, y) triggers a rewrite;
# relax.multiply(x, y) does not, because it is not an add.
updated_fn = EwiseFMARewriter().visit_expr(MyModule["main"])
updated_fn.show()


### Fusing Linear with ReLU

Here is a model built from Linear layers (dense followed by a bias add) and a ReLU:

def create_model():
    """Build a two-layer MLP: dense -> add -> relu -> dense -> add.

    Weights and biases are read from the global ``mlp_params`` dict (keys
    ``"w0"``, ``"b0"``, ``"w1"``, ``"b1"``). Returns the built IRModule, whose
    ``main`` takes a (1, 784) float32 input ``x``.
    """
    bb = relax.BlockBuilder()
    x = relax.Var("x", (1, 784), relax.DynTensorType(2, "float32"))
    w0 = relax.const(mlp_params["w0"], "float32")
    b0 = relax.const(mlp_params["b0"], "float32")
    w1 = relax.const(mlp_params["w1"], "float32")
    b1 = relax.const(mlp_params["b1"], "float32")

    with bb.function("main", [x]):
        with bb.dataflow():
            # bb.emit(relax.op.*) is a high-level API that is convenient for teaching;
            # in practice, a lower-level TensorIR generator is recommended.
            lv0 = bb.emit(relax.op.nn.dense(x, w0))
            lv1 = bb.emit(relax.op.add(lv0, b0))
            lv2 = bb.emit(relax.op.relu(lv1))
            lv3 = bb.emit(relax.op.nn.dense(lv2, w1))
            lv4 = bb.emit(relax.op.add(lv3, b1))
            gv = bb.emit_output(lv4)
        bb.emit_func_output(gv)

    return bb.get()


MLPModel = create_model()
MLPModel.show()


The following code:

• pattern-matches a dense followed by an add

• generates a new fused sub-function, fused_dense_add, that calls into the dense and add operators

• replaces the original dense and add with a call to the fused sub-function

@relax.expr_functor.mutator
class DenseAddFusor(relax.PyExprMutator):
    """Fuse ``add(dense(x, w), b)`` into a call to a new primitive function.

    Each match creates a Relax function ``fused_dense_add<N>`` carrying the
    attribute ``Primitive=1``. It is added to the module, and the original add is
    replaced by a call to it.
    """

    def __init__(self, mod: IRModule) -> None:
        super().__init__()
        self.mod_ = mod
        # cache pre-defined ops
        self.add_op = tvm.ir.Op.get("relax.add")
        self.dense_op = tvm.ir.Op.get("relax.nn.dense")
        # numeric suffix for generated names: fused_dense_add0, fused_dense_add1, ...
        self.counter = 0

    def transform(self) -> IRModule:
        """Fuse every non-primitive Relax function of the module; return the new module."""
        for global_var, func in self.mod_.functions.items():
            # only Relax functions are rewritten; other functions (e.g. TensorIR) are skipped
            if not isinstance(func, relax.Function):
                continue
            # avoid already fused primitive functions (attrs may be absent)
            if (
                func.attrs is not None
                and "Primitive" in func.attrs.keys()
                and func.attrs["Primitive"] != 0
            ):
                continue
            updated_func = self.visit_expr(func)  # transform (update) this function
            # remove local variables (e.g. the old dense binding) left unused by fusion
            updated_func = relax.analysis.remove_all_unused(updated_func)
            self.builder_.update_func(global_var, updated_func)

        return self.builder_.get()  # return IRModule after transform

    def visit_call_(self, call):
        """Replace ``add(dense(x, w), b)`` with a call to a fused primitive function."""
        call = self.visit_expr_post_order(call)

        def match_call(node, op):  # helper function for pattern matching
            if not isinstance(node, relax.Call):
                return False
            return node.op == op

        # pattern match dense => add: the visited node must be the add ...
        if not match_call(call, self.add_op):
            return call

        # ... whose first argument is bound to a dense call
        value = self.lookup_binding(call.args[0])
        if value is None:
            return call

        if not match_call(value, self.dense_op):
            return call

        # extract matched values
        x = value.args[0]
        w = value.args[1]
        b = call.args[1]

        # construct a new fused primitive function with fresh parameters
        param_x = relax.Var("x", x.shape_, x._checked_type_)
        param_w = relax.Var("w", w.shape_, w._checked_type_)
        param_b = relax.Var("b", b.shape_, b._checked_type_)

        # build the new function, named fused_dense_add<counter>, with its own builder
        bb = relax.BlockBuilder()

        fn_name = f"fused_dense_add{self.counter}"
        self.counter += 1
        with bb.function(fn_name, [param_x, param_w, param_b]):
            with bb.dataflow():
                lv0 = bb.emit(relax.op.nn.dense(param_x, param_w))
                gv = bb.emit_output(relax.op.add(lv0, param_b))
            bb.emit_func_output(gv)

        # Add Primitive attribute so transform() treats it as already fused
        fused_fn = bb.get()[fn_name].with_attr("Primitive", 1)

        # add it to the current IRModule, which gives us a global variable
        global_var = self.builder_.add_func(fused_fn, fn_name)

        # construct call into the fused function
        return relax.Call(global_var, [x, w, b], None, None)


# mark the above procedure as one module-level "pass"
@tvm.ir.transform.module_pass(opt_level=0, name="DenseAddFuse")
class FuseDenseAddPass:
    """The wrapper for the DenseAddFusor pass."""

    def transform_module(self, mod, ctx):
        return DenseAddFusor(mod).transform()


MLPFused = FuseDenseAddPass()(MLPModel)
MLPFused.show()


Done!

### Map to TensorIR Calls

Now that we have performed fusion at the high level, we can map the high-level primitive operations to library calls or to TensorIR functions for the target hardware.

Here is the code that transforms the fused Relax functions into TensorIR:

@relax.expr_functor.mutator
class LowerToTensorIR(relax.PyExprMutator):
    """Lower high-level Relax operator calls via per-operator callbacks.

    ``op_map`` maps an operator name (such as ``"relax.nn.dense"``) to a callable
    ``(block_builder, call) -> expr`` that builds the replacement for the call.
    """

    def __init__(self, mod: IRModule, op_map) -> None:
        super().__init__()
        self.mod_ = mod
        # Key the table by Op objects so visit_call_ can test ``call.op`` directly.
        self.op_map = {
            tvm.ir.Op.get(op_name): lower_fn for op_name, lower_fn in op_map.items()
        }

    def visit_call_(self, call):
        """Return the lowered form of ``call`` if its op is registered, else ``call``."""
        call = self.visit_expr_post_order(call)
        if call.op not in self.op_map:
            return call
        lower_fn = self.op_map[call.op]
        return lower_fn(self.builder_, call)

    def transform(self) -> IRModule:
        """Rewrite every Relax function of the wrapped module and return the result."""
        for gvar, fn in self.mod_.functions.items():
            if isinstance(fn, relax.Function):
                lowered = self.visit_expr(fn)
                self.builder_.update_func(gvar, lowered)
        return self.builder_.get()

def map_dense(bb, call):
    """Lower ``relax.nn.dense`` to a call of TOPI's dense."""
    x, w = call.args
    # The binding already exists, so use call_te (which does not create a new
    # binding) instead of emit_te.
    return bb.call_te(topi.nn.dense, x, w)


def map_add(bb, call):
    """Lower ``relax.add`` to a call of TOPI's element-wise add."""
    a, b = call.args
    return bb.call_te(topi.add, a, b)


def map_relu(bb, call):
    """Lower ``relax.nn.relu`` to a call of TOPI's relu on its single operand."""
    return bb.call_te(topi.nn.relu, call.args[0])


# Relax operator name -> lowering callback used by LowerToTensorIR.
op_map = {
    "relax.nn.dense": map_dense,
    "relax.add": map_add,
    "relax.nn.relu": map_relu,
}

# Package LowerToTensorIR as a module-level pass.
@tvm.ir.transform.module_pass(opt_level=0, name="LowerToTensorIR")
class LowerToTensorIRPass:
    """Module pass that lowers Relax operators using the global ``op_map``."""

    def transform_module(self, mod, ctx):
        lowerer = LowerToTensorIR(mod, op_map)
        return lowerer.transform()


MLPModelTIR = LowerToTensorIRPass()(MLPFused)
MLPModelTIR.show()


Note that in the above code, fused_dense_add0 and fused_dense_add1 are still high-level Relax functions that call into the corresponding TensorIR dense and add functions.

We can turn each of them into a single TensorIR function (this is a bit more complicated, so the implementation is not shown here). The result can then be used in follow-up optimization and code generation phases.

# Apply the built-in FuseTIR pass to the lowered module. Per the text above, this
# turns the fused Relax functions (fused_dense_add0/1), which still call separate
# TensorIR dense and add functions, into single TensorIR functions.
MLPModelFinal = relax.transform.FuseTIR()(MLPModelTIR)
MLPModelFinal.show()


So in summary:

1. start from high level model
2. transform using operator fusion
3. lower to TensorIR
4. fuse the lowered operators into a single TensorIR function
5. get a runnable program

Improvements:

• Our fusion rule is not smart enough: given one matrix multiplication followed by two additions, it creates two fused functions, so the matrix multiplication is computed twice.

• we are currently using pattern-based fusion. A more advanced option will look at properties of TensorIR (element-wise? broadcasting? reduction?) and decide whether to fuse.

Table of Contents