Lecture 005

Stochastic Schedule Transformation

Recall our matrix-multiplication module (C = A @ B on 128x128 float32 matrices):

@tvm.script.ir_module
class MyModule:
    # TensorIR module with a single function `main` computing C = A @ B
    # for 128x128 float32 matrices.
    @T.prim_func
    def main(
        A: T.Buffer[(128, 128), "float32"],
        B: T.Buffer[(128, 128), "float32"],
        C: T.Buffer[(128, 128), "float32"],
    ):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # i and j are spatial axes over C; k is the reduction axis ("SSR").
        for i, j, k in T.grid(128, 128, 128):
            # The block name "C" is what stochastic_schedule_mm looks up
            # via sch.get_block("C", "main").
            with T.block("C"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Initialize the accumulator to zero before reducing over k.
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

Randomized Re-order

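A stochastic schedule transformation leaves some scheduling decisions — here, how to split the j loop — to random sampling, so each time it runs it may propose a different (but equally valid) way to split the computation: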

def stochastic_schedule_mm(sch: tvm.tir.Schedule):
    """Apply a randomly tiled loop split to the matmul block "C" of `main`.

    The j loop is split into two tiles whose sizes are drawn by
    ``sample_perfect_tile``. The loops are then reordered to
    (i, j_outer, k, j_inner), and the reduction initialization is split
    out at the k loop.

    Args:
        sch: Schedule of the module to transform; the transformations are
            recorded on it.

    Returns:
        The schedule after the transformations have been applied.
    """
    # Only the "main" function is scheduled.
    matmul = sch.get_block(name="C", func_name="main")
    loop_i, loop_j, loop_k = sch.get_loops(block=matmul)
    # The sampled tile sizes are random variables recorded in the trace,
    # not concrete integers.
    tile_sizes = sch.sample_perfect_tile(loop=loop_j, n=2)
    j_outer, j_inner = sch.split(loop=loop_j, factors=tile_sizes)
    sch.reorder(loop_i, j_outer, loop_k, j_inner)
    sch.decompose_reduction(block=matmul, loop=loop_k)
    return sch

Stochastic vs. Classical

Stochastic vs. Classical

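The sch.sample_perfect_tile function samples j_factors from the two-factor splits of 128 whose innermost factor is at most 16 (the default max_innermost_factor, visible in the trace below): [128, 1], [64, 2], [32, 4], [16, 8], or [8, 16]. (Notice $128 \times 1 = 64 \times 2 = 32 \times 4 = 16 \times 8 = 8 \times 16 = 128$: every split divides the loop exactly, which is why it is called "perfect".)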

So each time we run the code below, the sampled tiling — and therefore the generated program and its execution behavior — may be different:

# Schedule a fresh copy of MyModule stochastically and show the recorded
# trace. One possible output (here the sampler chose [8, 16]):
sch = stochastic_schedule_mm(tvm.tir.Schedule(MyModule))
print(sch.trace)
# b0 = sch.get_block(name="C", func_name="main")
# l1, l2, l3 = sch.get_loops(block=b0)
# v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16])
# l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
# sch.reorder(l1, l6, l3, l7)
# b8 = sch.decompose_reduction(block=b0, loop=l3)

The stochastic schedule generates a space of possible programs

The stochastic schedule generates a space of possible programs

Now we can write a simple random search over this space:

def random_search(mod: tvm.IRModule, num_trials=5, args=None):
    """Randomly sample schedules of ``mod`` and return the fastest one.

    Each trial applies ``stochastic_schedule_mm`` to a fresh schedule of
    ``mod``, builds it for the "llvm" target, times its ``main`` function
    on the CPU, and prints the mean latency together with the trace.

    Args:
        mod: IRModule whose ``main`` function is scheduled and timed.
        num_trials: Number of random schedules to try.
        args: Arguments passed to ``main`` when timing it. If None, the
            globals ``a_nd``, ``b_nd`` and ``c_nd`` are used (assumed to be
            defined elsewhere in the notebook), matching the original
            behavior.

    Returns:
        The schedule with the lowest mean runtime, or None if no trial ran.
    """
    if num_trials <= 0:
        # Nothing to measure; don't require the default globals to exist.
        return None
    if args is None:
        # Backward-compatible default: the notebook-level input arrays.
        args = (a_nd, b_nd, c_nd)

    dev = tvm.cpu()
    best_result = None
    best_sch = None

    for trial in range(num_trials):
        sch = stochastic_schedule_mm(tvm.tir.Schedule(mod))
        lib = tvm.build(sch.mod, target="llvm")
        timer = lib.time_evaluator("main", dev)
        result = timer(*args).mean

        print(f"=====Attempt {trial}, time-cost: {result * 1000:.3f} ms====")
        print(sch.trace)

        # Keep the fastest schedule seen so far.
        if best_result is None or result < best_result:
            best_result = result
            best_sch = sch

    return best_sch

# Try 5 random schedules of MyModule and keep the fastest one.
sch = random_search(mod=MyModule, num_trials=5)

Randomized Tuning

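Instead of plain random search, TVM's MetaSchedule implements evolutionary search, guided by a learned cost model (note the XGBModel saved in the log below):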

from tvm import meta_schedule as ms

# Tune MyModule with MetaSchedule, searching over the design space defined
# by stochastic_schedule_mm. The result is a tuned schedule whose .trace and
# .mod are inspected below.
sch_tuned = ms.tune_tir(
    mod=MyModule,
    # CPU (LLVM) target restricted to a single core.
    target="llvm --num-cores=1",
    config=ms.TuneConfig(
      # Trial budget: at most 64 in total, up to 64 per search iteration.
      # (The log below shows only 5 trials were actually run for this small space.)
      max_trials_global=64,
      num_trials_per_iter=64,
    ),
    # Optional: if no space is given, TVM generates the design space
    # automatically instead of using our hand-written stochastic schedule.
    space=ms.space_generator.ScheduleFn(stochastic_schedule_mm),
    # Tuning artifacts (e.g. the XGBoost cost model, per the log) go here.
    work_dir="./tune_tmp",
    task_name="main"
)

After tuning, our code runs about 10x faster than the original (tuning log below):

2022-08-22 20:32:55.548 INFO Scheduler picks Task #0: "main"
2022-08-22 20:33:00.443 INFO Sending 0 sample(s) to builder
2022-08-22 20:33:00.452 INFO Sending 0 sample(s) to runner
2022-08-22 20:33:00.454 INFO [Updated] Task #0: "main"
 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated
------------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |         2.8342 |    1479.8785 |             1479.8785 |      5 |
------------------------------------------------------------------------------------------------------------
Total trials: 5
Total latency (us): 1479.88

2022-08-22 20:33:00.455 INFO Scheduler picks Task #0: "main"
2022-08-22 20:33:05.147 INFO Task #0 has finished. Remaining task(s): 0
2022-08-22 20:33:05.176 INFO Saved XGBModel to ./tune_tmp/cost_model.xgb
# Print the trace of the best schedule found by tuning (sample output below).
print(sch_tuned.trace)
# b0 = sch.get_block(name="C", func_name="main")
# l1, l2, l3 = sch.get_loops(block=b0)
# v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16])
# l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
# sch.reorder(l1, l6, l3, l7)
# b8 = sch.decompose_reduction(block=b0, loop=l3)
# sch.enter_postproc()
# Render the tuned module's TVMScript as HTML in the notebook;
# code2html is a helper defined elsewhere (not shown here).
IPython.display.HTML(code2html(sch_tuned.mod.script()))

Here is the generated code:

@tvm.script.ir_module
class Module:
    # TVMScript of the tuned module: j is split into j_0 (extent 8) x j_1
    # (extent 16), loops are ordered (i, j_0, k, j_1), and the zero
    # initialization has been decomposed into a separate "C_init" block.
    @T.prim_func
    def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i, j_0 in T.grid(128, 8):
            # Zero one 16-wide tile of row i of C before accumulating into it.
            for j_1_init in T.serial(16):
                with T.block("C_init"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 16 + j_1_init)
                    T.reads()
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.float32(0)
            # Accumulate A[i, k] * B[k, j] over k for the same 16-wide tile.
            for k, j_1 in T.grid(128, 16):
                with T.block("C_update"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 16 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

Note that only functions implemented in TVM (such as TensorIR prim_funcs) can be optimized this way; imported library functions cannot.

Replacing Primitive Functions in Neural Network

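Say we have the following neural network, which mixes a TensorIR function (linear0) with calls to external library functions (env.relu and env.linear):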

@tvm.script.ir_module
class MyModuleMixture:
    # linear0: Z = X @ W^T + B for a single 784-dim input row, implemented in
    # TensorIR (and therefore tunable with MetaSchedule).
    @T.prim_func
    def linear0(X: T.Buffer[(1, 784), "float32"],
                W: T.Buffer[(128, 784), "float32"],
                B: T.Buffer[(128,), "float32"],
                Z: T.Buffer[(1, 128), "float32"]):
        T.func_attr({"global_symbol": "linear0", "tir.noalias": True})
        # Intermediate buffer holding X @ W^T.
        Y = T.alloc_buffer((1, 128), "float32")
        for i, j, k in T.grid(1, 128, 784):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                # W is indexed [vj, vk], i.e. it is used transposed.
                Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]

        # Add the bias, broadcast along the row.
        for i, j in T.grid(1, 128):
            with T.block("Z"):
                vi, vj = T.axis.remap("SS", [i, j])
                Z[vi, vj] =  Y[vi, vj] + B[vj]

    # main: linear0 -> relu -> linear, producing a (1, 10) output.
    # "env.relu" and "env.linear" are external functions referenced by name;
    # presumably registered elsewhere (e.g. via tvm.register_func) — not shown here.
    @R.function
    def main(x: Tensor((1, 784), "float32"),
             w0: Tensor((128, 784), "float32"),
             b0: Tensor((128,), "float32"),
             w1: Tensor((10, 128), "float32"),
             b1: Tensor((10,), "float32")):
        with R.dataflow():
            lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32")
            lv1 = R.call_tir("env.relu", (lv0,), (1, 128), dtype="float32")
            out = R.call_tir("env.linear", (lv1, w1, b1), (1, 10), dtype="float32")
            R.output(out)
        return out

We want to optimize the linear0 layer and replace it with the tuned version.

First, we extract linear0 into a standalone module, renaming its global symbol to "main":

# Wrap linear0 alone in a new IRModule with its global_symbol set to "main"
# (the tuned result is later read back as sch_tuned_linear.mod["main"]).
mod_linear = tvm.IRModule.from_expr(MyModuleMixture["linear0"].with_attr("global_symbol", "main"))

Then we tune it:

# Tune the extracted linear0. No `space` is passed, so TVM generates the
# design space automatically rather than using a hand-written schedule.
sch_tuned_linear = ms.tune_tir(
    mod=mod_linear,
    # CPU (LLVM) target restricted to a single core.
    target="llvm --num-cores=1",
    config=ms.TuneConfig(
      # Trial budget: at most 64 in total, up to 64 per search iteration.
      max_trials_global=64,
      num_trials_per_iter=64,
    ),
    work_dir="./tune_tmp",
    task_name="main",
)

Then we replace the original linear0 with the tuned function:

# Bind the parameter arrays in nd_params (defined elsewhere) to the "main"
# function of MyModuleMixture, producing a new module.
MyModuleWithParams2 = relax.transform.BindParams("main", nd_params)(MyModuleMixture)
# The tuned function is named "main"; give it back the symbol "linear0".
new_func = sch_tuned_linear.mod["main"].with_attr("global_symbol", "linear0")
# Swap the tuned function in for the original linear0 in the bound module.
gv = MyModuleWithParams2.get_global_var("linear0")
MyModuleWithParams2.update_func(gv, new_func)

Finally, we build and run the updated module:

# Compile the module for the LLVM target and load it into a Relax VM on the CPU.
ex = relax.vm.build(MyModuleWithParams2, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

# Only the input is passed: the weights were bound into "main" by BindParams.
# data_nd is defined elsewhere (presumably a (1, 784) float32 array — confirm).
nd_res = vm["main"](data_nd)

# Pick the highest-scoring class per row; class_names is defined elsewhere.
pred_kind = np.argmax(nd_res.numpy(), axis=1)
print("MyModuleWithParams2 Prediction:", class_names[pred_kind[0]])

Table of Contents