Lecture 009

Computational Graph Optimization

Fusing Multiply with Add

@tvm.script.ir_module
class MyModule:
    # NOTE(review): `R` is assumed to come from the lecture setup
    # (`from tvm.script import relax as R`), and `Tensor` from the TVMScript
    # relax namespace; neither import is visible here — confirm.
    @R.function
    def main(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):
        # Computes x * y + y as two bindings inside one dataflow block.
        with relax.dataflow():
            lv0 = relax.multiply(x, y)
            gv0 = relax.add(lv0, y)
            # Mark gv0 as a block output so it is visible after the block
            # and can be returned.
            relax.output(gv0)
        return gv0

Data and Blocks

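Figure: Data and Blocks (the image is not reproduced in these notes)

The figure shows the nested layers of an IRModule: a function contains dataflow blocks, and each dataflow block is a sequence of variable bindings (VarBinding).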

Our goal is to fuse the multiply and the add by merging their two VarBindings (`lv0 = multiply(x, y)` and `gv0 = add(lv0, y)`) into a single call to `ewise_fma`.

One way to rewrite the program is to traverse the abstract syntax tree (AST) of MyModule recursively and generate a transformed AST.

This is done with the visitor pattern, which lets us visit each AST node and rewrite it into a transformed version.

Here is the code we use to do fusion:

# Rewrite every `add(multiply(a, b), c)` into a single `ewise_fma(a, b, c)` call.
@relax.expr_functor.mutator
class EwiseFMARewriter(relax.PyExprMutator):
    """Fuse an element-wise multiply that feeds an add into one ``relax.ewise_fma`` call.

    Only the pattern ``add(multiply(a, b), c)`` is rewritten, i.e. the multiply
    must produce the *first* operand of the add. Every other call is returned
    unchanged.
    """

    def visit_call_(self, call):
        # Visit the children first (post-order) so the arguments are already remapped.
        call = self.visit_expr_post_order(call)
        # Op handles used to match the pattern and to build the fused call.
        add_op = tvm.ir.Op.get("relax.add")
        multiply_op = tvm.ir.Op.get("relax.multiply")
        ewise_fma_op = tvm.ir.Op.get("relax.ewise_fma")

        # In add(multiply(a, b), c) the outermost call must be the add.
        if call.op != add_op:
            return call

        # lookup_binding maps a variable back to the expression bound to it
        # (e.g. lv0 -> relax.multiply(x, y)). Anything that is not a Call to
        # multiply (including a missing binding) is rejected below.
        value = self.lookup_binding(call.args[0])
        if not isinstance(value, relax.Call) or value.op != multiply_op:
            return call  # first operand is not produced by a multiply

        # Replace add(multiply(a, b), c) with ewise_fma(a, b, c).
        fma_call = relax.Call(
            ewise_fma_op, [value.args[0], value.args[1], call.args[1]], None, None
        )
        return fma_call


# visit_call_ runs on both calls in MyModule["main"], but only
# `gv0 = relax.add(lv0, y)` matches (its first operand lv0 is bound to
# relax.multiply(x, y)). It becomes relax.ewise_fma(x, y, y); the multiply
# call itself is returned unchanged.
updated_fn = EwiseFMARewriter().visit_expr(MyModule["main"])

Fusing Linear with ReLU

Here is a model built from linear layers (a dense followed by a bias add) and a ReLU:

def create_model():
    """Build the two-layer MLP used in this section as a Relax IRModule.

    The module has a single function ``main(x)`` with ``x`` of shape (1, 784),
    float32. It computes ``dense(relu(dense(x, w0) + b0), w1) + b1``, with the
    weights taken from the global ``mlp_params`` dict and embedded as constants.
    """
    bb = relax.BlockBuilder()
    x = relax.Var("x", (1, 784), relax.DynTensorType(2, "float32"))
    w0 = relax.const(mlp_params["w0"], "float32")
    b0 = relax.const(mlp_params["b0"], "float32")
    w1 = relax.const(mlp_params["w1"], "float32")
    b1 = relax.const(mlp_params["b1"], "float32")

    with bb.function("main", [x]):
        with bb.dataflow():
            # bb.emit(relax.op.*) builds the graph from high-level operators,
            # which keeps the teaching example short; in practice a low-level
            # TensorIR generator is suggested instead.
            # The nn.* ops match the "relax.nn.dense"/"relax.nn.relu" names
            # that the fusion and lowering passes look for.
            lv0 = bb.emit(relax.op.nn.dense(x, w0))
            lv1 = bb.emit(relax.op.add(lv0, b0))
            lv2 = bb.emit(relax.op.nn.relu(lv1))
            lv3 = bb.emit(relax.op.nn.dense(lv2, w1))
            lv4 = bb.emit(relax.op.add(lv3, b1))
            gv = bb.emit_output(lv4)
        # Close the function scope by declaring its return value.
        bb.emit_func_output(gv)

    return bb.get()


MLPModel = create_model()

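The following code fuses every `dense` followed by an `add` into a single primitive function named `fused_dense_add<N>`: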

@relax.expr_functor.mutator
class DenseAddFusor(relax.PyExprMutator):
    """Fuse every ``add(dense(x, w), b)`` into a call to a new primitive function.

    Each match creates a Relax function ``fused_dense_add<N>`` marked with the
    ``Primitive`` attribute and adds it to the module. The original ``add``
    is then replaced by a call to that function.
    """

    def __init__(self, mod: IRModule) -> None:
        super().__init__()
        self.mod_ = mod  # module to transform
        # Cache the ops used for pattern matching.
        self.add_op = tvm.ir.Op.get("relax.add")
        self.dense_op = tvm.ir.Op.get("relax.nn.dense")
        # Suffix for fresh function names: fused_dense_add0, fused_dense_add1, ...
        self.counter = 0

    def transform(self) -> IRModule:
        """Fuse dense + add in every non-primitive Relax function; return the new module."""
        for global_var, func in self.mod_.functions.items():
            # Only Relax functions are rewritten; other function kinds are skipped.
            if not isinstance(func, relax.Function):
                continue
            # Skip functions that are already fused primitives (including the
            # ones this pass creates).
            attrs = func.attrs
            if (
                attrs is not None
                and "Primitive" in attrs.keys()
                and attrs["Primitive"] != 0
            ):
                continue
            updated_func = self.visit_expr(func)
            # Dense bindings consumed by fused calls are now unused; drop them.
            updated_func = relax.analysis.remove_all_unused(updated_func)
            self.builder_.update_func(global_var, updated_func)

        return self.builder_.get()

    @staticmethod
    def _match_call(node, op) -> bool:
        """Return True if ``node`` is a ``relax.Call`` whose operator is ``op``."""
        return isinstance(node, relax.Call) and node.op == op

    def visit_call_(self, call):
        call = self.visit_expr_post_order(call)

        # Pattern: add(dense(x, w), b) -- the outer call must be the add.
        if not self._match_call(call, self.add_op):
            return call

        # The add's first operand must be bound to a dense call; a missing
        # binding (None) is rejected by _match_call as well.
        value = self.lookup_binding(call.args[0])
        if not self._match_call(value, self.dense_op):
            return call

        # Extract the matched operands.
        x = value.args[0]
        w = value.args[1]
        b = call.args[1]

        # Parameters of the fused function mirror the matched operands.
        param_x = relax.Var("x", x.shape_, x._checked_type_)
        param_w = relax.Var("w", w.shape_, w._checked_type_)
        param_b = relax.Var("b", b.shape_, b._checked_type_)

        # Build the new function fused_dense_add<N>.
        bb = relax.BlockBuilder()

        fn_name = f"fused_dense_add{self.counter}"
        self.counter += 1
        with bb.function(fn_name, [param_x, param_w, param_b]):
            with bb.dataflow():
                lv0 = bb.emit(relax.op.nn.dense(param_x, param_w))
                gv = bb.emit_output(relax.op.add(lv0, param_b))
            # Close the function scope by declaring its return value.
            bb.emit_func_output(gv)

        # Mark it Primitive so transform() never tries to fuse inside it again.
        fused_fn = bb.get()[fn_name].with_attr("Primitive", 1)

        # Register it in the module being built; this yields its global variable.
        global_var = self.builder_.add_func(fused_fn, fn_name)

        # Replace the add with a call into the fused function.
        return relax.Call(global_var, [x, w, b], None, None)

# Package DenseAddFusor as a module pass (IRModule -> IRModule).
@tvm.ir.transform.module_pass(opt_level=2, name="DenseAddFuse")
class FuseDenseAddPass:
    """Module pass that fuses dense + add pairs using DenseAddFusor."""

    def transform_module(self, mod, ctx):
        """Return a new module in which every add(dense(x, w), b) is fused."""
        return DenseAddFusor(mod).transform()


MLPFused = FuseDenseAddPass()(MLPModel)


Map to TensorIR Calls

Now that fusion has been performed at the high level, we can map each high-level primitive operation to a lower-level implementation — either a library call or a TensorIR function — that runs on the target hardware.

Here is the code that transforms the fused Relax functions into calls to TensorIR functions:

@relax.expr_functor.mutator
class LowerToTensorIR(relax.PyExprMutator):
    """Replace high-level Relax operator calls with calls to TensorIR functions.

    ``op_map`` maps an operator name (e.g. ``"relax.add"``) to a callback
    ``(block_builder, call) -> expr`` that returns the lowered replacement
    for ``call``.
    """

    def __init__(self, mod: IRModule, op_map) -> None:
        super().__init__()
        self.mod_ = mod
        # Key the callbacks by Op object so visit_call_ can look up call.op directly.
        self.op_map = {tvm.ir.Op.get(k): v for k, v in op_map.items()}

    def visit_call_(self, call):
        call = self.visit_expr_post_order(call)

        # Calls whose op has no lowering rule (e.g. calls into fused functions,
        # whose op is a global variable) are kept unchanged.
        if call.op in self.op_map:
            return self.op_map[call.op](self.builder_, call)
        return call

    def transform(self) -> IRModule:
        """Lower every Relax function in the module (including fused primitives)."""
        for global_var, func in self.mod_.functions.items():
            # Only Relax functions are rewritten; other function kinds are skipped.
            if not isinstance(func, relax.Function):
                continue
            updated_func = self.visit_expr(func)
            self.builder_.update_func(global_var, updated_func)

        return self.builder_.get()

def map_dense(bb, call):
    """Lower a ``relax.nn.dense`` call to a call of ``topi.nn.dense``.

    ``call_te`` is used instead of ``emit_te`` because the result only replaces
    the value of an existing binding; no new binding is needed.
    """
    data, weight = call.args
    lowered = bb.call_te(topi.nn.dense, data, weight)
    return lowered

def map_add(bb, call):
    """Lower a ``relax.add`` call to a call of ``topi.add``."""
    lhs, rhs = call.args
    lowered = bb.call_te(topi.add, lhs, rhs)
    return lowered

def map_relu(bb, call):
    """Lower a ``relax.nn.relu`` call to a call of ``topi.nn.relu`` on its first argument."""
    operand = call.args[0]
    lowered = bb.call_te(topi.nn.relu, operand)
    return lowered

# Operator name -> lowering callback, consumed by LowerToTensorIR.
op_map = {
    "relax.nn.dense": map_dense,
    "relax.add": map_add,
    "relax.nn.relu": map_relu,
}

# Expose LowerToTensorIR as a module pass so it can be applied like any other pass.
@tvm.ir.transform.module_pass(opt_level=0, name="LowerToTensorIR")
class LowerToTensorIRPass:
    """Module pass that lowers the operators listed in ``op_map`` to TensorIR calls."""

    def transform_module(self, mod, ctx):
        """Return ``mod`` with every mapped operator call replaced by a TensorIR call."""
        lowering = LowerToTensorIR(mod, op_map)
        return lowering.transform()


MLPModelTIR = LowerToTensorIRPass()(MLPFused)

Note that in the code above, fused_dense_add0 and fused_dense_add1 are still high-level Relax functions that call into the corresponding TensorIR dense and add functions.

We can turn each of them into a single TensorIR function (the implementation is more involved, so it is not shown here). The result can then be used in the subsequent optimization and code-generation phases. The FuseTIR pass does this:

# Merge each fused primitive function (e.g. fused_dense_add0), which still calls
# separate TensorIR dense and add functions, into a single TensorIR function.
MLPModelFinal = relax.transform.FuseTIR()(MLPModelTIR)

So in summary:

  1. start from high level model
  2. transform using operator fusion
  3. lower to TensorIR
  4. fuse the lower-level operators into a single TensorIR function
  5. get a runnable program


Table of Contents