@tvm.script.ir_module
class MyModule:
    @R.function
    def main(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):
        with relax.dataflow():
            lv0 = relax.multiply(x, y)
            gv0 = relax.add(lv0, y)
            relax.output(gv0)
        return gv0
In the above program, we can see a few layers in the IRModule.
SeqExpr
: a sequence of expressions that may contain multiple dataflow blocks or control-flow blocks. The most common situation is a single dataflow block.
DataflowBlock
: computation inside a dataflow block must be pure (side-effect free), and therefore corresponds to a DAG.
Now our goal is to fuse the multiply and add operations together by fusing their two VarBindings into one.
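A quick way to peek at this nesting in Python (a small illustrative snippet; the exact field names assume the relax object API used in these notes):

# Function body -> SeqExpr -> DataflowBlock -> VarBinding
fn = MyModule["main"]
print(type(fn.body))            # relax.SeqExpr
block = fn.body.blocks[0]
print(type(block))              # relax.DataflowBlock
for binding in block.bindings:  # each VarBinding pairs a variable with its value
    print(binding.var, "=", binding.value)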
One approach to rewrite the program would be to traverse MyModule's Abstract Syntax Tree (AST) recursively and generate a transformed AST.
The visitor pattern allows us to visit each AST node and rewrite it to a transformed version.
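A minimal skeleton of such a mutator looks like the following (IdentityRewriter is just an illustrative name; it uses the same relax.PyExprMutator API as the fusion pass below and returns every call unchanged):

@relax.expr_functor.mutator
class IdentityRewriter(relax.PyExprMutator):
    def visit_call_(self, call):
        # rewrite the children first and remap variables
        call = self.visit_expr_post_order(call)
        # a real pass would pattern-match `call` here and return a rewritten expression
        return call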
Here is the code we use to do fusion:
# find the pattern add(mul(a, b), c) and replace it with the fused ewise_fma(a, b, c)
@relax.expr_functor.mutator
class EwiseFMARewriter(relax.PyExprMutator):
    def visit_call_(self, call):
        call = self.visit_expr_post_order(call)  # rewrite the children first and remap variables
        add_op = tvm.ir.Op.get("relax.add")  # operator handles used for pattern matching
        multiply_op = tvm.ir.Op.get("relax.multiply")
        ewise_fma_op = tvm.ir.Op.get("relax.ewise_fma")

        # since our pattern looks like add(mul(a, b), c),
        # the add corresponds to the outer call
        if call.op != add_op:
            return call  # the outer node must be add

        # lookup_binding returns a non-None value if the first argument of call
        # is itself a variable computed by another binding (i.e. not a constant)
        value = self.lookup_binding(call.args[0])  # call.args[0] is the first input argument of call
        if not isinstance(value, relax.Call) or value.op != multiply_op:
            return call  # pattern matching unsuccessful

        # construct the new fused call
        fma_call = relax.Call(
            ewise_fma_op, [value.args[0], value.args[1], call.args[1]], None, None
        )
        return fma_call  # replace the old call with the new one


# in the above example, only gv0 = relax.add(lv0, y) triggers the rewrite;
# lv0 = relax.multiply(x, y) does not
updated_fn = EwiseFMARewriter().visit_expr(MyModule["main"])
updated_fn.show()
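After the rewrite, only the gv0 binding changes; the original lv0 = relax.multiply(x, y) binding becomes dead code. It can be removed with the same relax.analysis.remove_all_unused helper that the fusion pass below uses:

# drop the now-unused multiply binding left behind by the rewrite
relax.analysis.remove_all_unused(updated_fn).show()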
Next, here is an MLP model with linear (dense) and ReLU layers, built with the relax BlockBuilder API:
def create_model():
    bb = relax.BlockBuilder()
    x = relax.Var("x", (1, 784), relax.DynTensorType(2, "float32"))
    # mlp_params: dict of MLP weight/bias arrays (defined elsewhere in these notes)
    w0 = relax.const(mlp_params["w0"], "float32")
    b0 = relax.const(mlp_params["b0"], "float32")
    w1 = relax.const(mlp_params["w1"], "float32")
    b1 = relax.const(mlp_params["b1"], "float32")

    with bb.function("main", [x]):
        with bb.dataflow():
            # note: bb.emit(relax.op.*) emits high-level ops for ease of teaching;
            # in practice, we suggest generating low-level TensorIR directly
            lv0 = bb.emit(relax.op.dense(x, w0))
            lv1 = bb.emit(relax.op.add(lv0, b0))
            lv2 = bb.emit(relax.op.relu(lv1))
            lv3 = bb.emit(relax.op.dense(lv2, w1))
            lv4 = bb.emit(relax.op.add(lv3, b1))
            gv = bb.emit_output(lv4)
        bb.emit_func_output(gv)
    return bb.get()


MLPModel = create_model()
MLPModel.show()
The following code pattern matches dense followed by add, generates a fused sub-function fused_dense_add that calls into the dense and add operators, and then replaces the original dense and add calls with a call to fused_dense_add.
@relax.expr_functor.mutator
class DenseAddFusor(relax.PyExprMutator):
    def __init__(self, mod: IRModule) -> None:
        super().__init__()
        self.mod_ = mod
        # cache pre-defined ops
        self.add_op = tvm.ir.Op.get("relax.add")
        self.dense_op = tvm.ir.Op.get("relax.nn.dense")
        self.counter = 0

    def transform(self) -> IRModule:
        for global_var, func in self.mod_.functions.items():
            if not isinstance(func, relax.Function):
                continue  # only transform relax functions (skip TensorIR prim funcs)
            # skip functions that are already fused primitives
            if func.attrs is not None and "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
                continue
            updated_func = self.visit_expr(func)  # transform (update) this particular function
            updated_func = relax.analysis.remove_all_unused(updated_func)  # remove unused local bindings after updating
            self.builder_.update_func(global_var, updated_func)
        return self.builder_.get()  # return the IRModule after the transform

    def visit_call_(self, call):
        call = self.visit_expr_post_order(call)

        def match_call(node, op):  # helper function for pattern matching
            if not isinstance(node, relax.Call):
                return False
            return node.op == op

        # pattern match dense => add
        if not match_call(call, self.add_op):
            return call
        value = self.lookup_binding(call.args[0])
        if value is None:
            return call
        if not match_call(value, self.dense_op):
            return call

        # extract the matched values
        x = value.args[0]
        w = value.args[1]
        b = call.args[1]

        # construct the parameters of the new fused primitive function
        param_x = relax.Var("x", x.shape_, x._checked_type_)
        param_w = relax.Var("w", w.shape_, w._checked_type_)
        param_b = relax.Var("b", b.shape_, b._checked_type_)

        # build the new function, named fused_dense_add{counter}
        bb = relax.BlockBuilder()
        fn_name = "fused_dense_add%d" % (self.counter)
        self.counter += 1
        with bb.function(fn_name, [param_x, param_w, param_b]):
            with bb.dataflow():
                lv0 = bb.emit(relax.op.nn.dense(param_x, param_w))
                gv = bb.emit_output(relax.op.add(lv0, param_b))
            bb.emit_func_output(gv)

        # add the Primitive attribute to the fused function
        fused_fn = bb.get()[fn_name].with_attr("Primitive", 1)
        # add it to the current IRModule, which gives us a global variable
        global_var = self.builder_.add_func(fused_fn, fn_name)

        # construct a call into the fused function
        return relax.Call(global_var, [x, w, b], None, None)


# wrap the above procedure as one "pass"
@tvm.ir.transform.module_pass(opt_level=2, name="DenseAddFuse")
class FuseDenseAddPass:
    """The wrapper for the DenseAddFusor pass."""
    def transform_module(self, mod, ctx):
        return DenseAddFusor(mod).transform()


MLPFused = FuseDenseAddPass()(MLPModel)
MLPFused.show()
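As a quick sanity check (an illustrative snippet, not part of the original pass), we can print the attributes of each function in the fused module and confirm that the new sub-functions carry the "Primitive" attribute set above:

for gvar, func in MLPFused.functions.items():
    if isinstance(func, relax.Function) and func.attrs is not None:
        print(gvar.name_hint, func.attrs)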
Done!
Now that we have performed fusion at the high level, we can map the high-level primitive operators down to libraries or TensorIR implementations that target the hardware.
Here is the code to lower the fused Relax functions to TensorIR:
@relax.expr_functor.mutator
class LowerToTensorIR(relax.PyExprMutator):
    def __init__(self, mod: IRModule, op_map) -> None:
        super().__init__()
        self.mod_ = mod
        self.op_map = {
            tvm.ir.Op.get(k): v for k, v in op_map.items()
        }

    def visit_call_(self, call):
        call = self.visit_expr_post_order(call)
        if call.op in self.op_map:
            return self.op_map[call.op](self.builder_, call)
        return call

    def transform(self) -> IRModule:
        for global_var, func in self.mod_.functions.items():
            if not isinstance(func, relax.Function):
                continue
            updated_func = self.visit_expr(func)
            self.builder_.update_func(global_var, updated_func)
        return self.builder_.get()


def map_dense(bb, call):
    x, w = call.args
    # call_te (rather than emit_te) is used because the binding for this value
    # already exists; we only need to produce the replacement expression
    return bb.call_te(topi.nn.dense, x, w)

def map_add(bb, call):
    a, b = call.args
    return bb.call_te(topi.add, a, b)

def map_relu(bb, call):
    return bb.call_te(topi.nn.relu, call.args[0])

op_map = {
    "relax.nn.dense": map_dense,
    "relax.add": map_add,
    "relax.nn.relu": map_relu
}

# packaging it into a pass
@tvm.ir.transform.module_pass(opt_level=0, name="LowerToTensorIR")
class LowerToTensorIRPass:
    """The wrapper for the LowerToTensorIR pass."""
    def transform_module(self, mod, ctx):
        return LowerToTensorIR(mod, op_map).transform()


MLPModelTIR = LowerToTensorIRPass()(MLPFused)
MLPModelTIR.show()
Note that in the above code, fused_dense_add0 and fused_dense_add1 are still high-level Relax functions that call into the corresponding TensorIR dense and add functions. We can turn each of them into a single TensorIR function (a bit more involved, therefore not shown here), which can then be used in the follow-up optimization and code generation phases.
MLPModelFinal = relax.transform.FuseTIR()(MLPModelTIR)
MLPModelFinal.show()
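At this point the module contains only TensorIR functions plus the top-level Relax main, so it can be built and executed. A rough sketch, assuming the relax virtual machine API from this version of TVM and random input data:

import numpy as np

# build the lowered module for CPU and run it on the relax VM
ex = relax.vm.build(MLPModelFinal, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
data = tvm.nd.array(np.random.rand(1, 784).astype("float32"))
print(vm["main"](data))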
In summary, we rewrote the computational graph with a pattern-based fusion pass and then lowered the fused operators to TensorIR.
Possible improvements:
Our fusion rule is not smart enough: if a single matrix multiplication is consumed by two additions, the pass creates two fused functions that each recompute the matrix multiplication.
We currently use pattern-based fusion. A more advanced approach would inspect the properties of the TensorIR computation (element-wise? broadcasting? reduction?) to decide whether to fuse.