Recall the matrix multiplication module we defined earlier:
@tvm.script.ir_module
class MyModule:
@T.prim_func
def main(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"],
C: T.Buffer[(128, 128), "float32"],
):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i, j, k in T.grid(128, 128, 128):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
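As a reference point, we can build and time this untuned module. This is a minimal sketch; it assumes a_nd, b_nd and c_nd are the tvm.nd.array inputs and output prepared earlier in these notes:
lib_base = tvm.build(MyModule, target="llvm")
f_timer_before = lib_base.time_evaluator("main", tvm.cpu())
# a_nd, b_nd, c_nd are assumed to hold the 128x128 input/output arrays from the earlier setup
print("baseline time-cost: %.3f ms" % (f_timer_before(a_nd, b_nd, c_nd).mean * 1000))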
A stochastic schedule transformation does not hard-code choices such as tiling factors; instead it samples them as random variables, so a search procedure can later decide how best to split the computation:
def stochastic_schedule_mm(sch: tvm.tir.Schedule):
block_C = sch.get_block("C", "main") # we only select main to optimize
i, j, k = sch.get_loops(block=block_C)
    j_factors = sch.sample_perfect_tile(loop=j, n=2)  # j_factors are random variables, not concrete integers yet
j_0, j_1 = sch.split(loop=j, factors=j_factors)
sch.reorder(i, j_0, k, j_1)
sch.decompose_reduction(block_C, k)
return sch
The sch.sample_perfect_tile function samples j_factors from the factorizations of the loop extent, such as [8, 16], [32, 4], or [2, 64]. (Notice that 8 × 16 = 32 × 4 = 2 × 64 = 128, which is why the tiling is called "perfect".)
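For intuition, here is a plain-Python sketch (not a TVM API) that enumerates the perfect factor pairs for a loop extent of 128:
extent = 128
# every pair (f0, f1) with f0 * f1 == extent is a valid "perfect" split of the loop
perfect_tiles = [(f0, extent // f0) for f0 in range(1, extent + 1) if extent % f0 == 0]
print(perfect_tiles)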
Because the tiling factors are sampled, the resulting schedule can differ each time we run the transformation:
sch = tvm.tir.Schedule(MyModule)
sch = stochastic_schedule_mm(sch)
print(sch.trace)
# b0 = sch.get_block(name="C", func_name="main")
# l1, l2, l3 = sch.get_loops(block=b0)
# v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16])
# l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
# sch.reorder(l1, l6, l3, l7)
# b8 = sch.decompose_reduction(block=b0, loop=l3)
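To see the sampling in action, we can re-run the stochastic transformation a few times and compare the recorded decisions; this sketch just reuses MyModule and stochastic_schedule_mm from above:
for _ in range(3):
    sch = stochastic_schedule_mm(tvm.tir.Schedule(MyModule))
    # the decision=[...] recorded for sample_perfect_tile may differ between runs
    print(sch.trace)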
With this stochastic transformation in hand, we can write a simple random-search algorithm that keeps the best schedule found so far:
def random_search(mod: tvm.IRModule, num_trials=5):
best_result = None
best_sch = None
for i in range(num_trials):
sch = stochastic_schedule_mm(tvm.tir.Schedule(mod))
lib = tvm.build(sch.mod, target="llvm")
f_timer_after = lib.time_evaluator("main", tvm.cpu())
result = f_timer_after(a_nd, b_nd, c_nd).mean
print("=====Attempt %d, time-cost: %.3f ms====" % (i, result * 1000))
print(sch.trace)
        # keep track of the best result found so far
if best_result is None or result < best_result:
best_result = result
best_sch = sch
return best_sch
sch = random_search(MyModule)
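The returned schedule carries both the winning trace and the transformed module, which we can inspect directly (a quick sketch):
print(sch.trace)         # decisions of the best schedule found
print(sch.mod.script())  # the TensorIR after applying those decisions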
Random search does not learn from past trials. TVM's meta_schedule instead implements an evolutionary search, guided by a learned cost model, to explore the space more efficiently:
from tvm import meta_schedule as ms
sch_tuned = ms.tune_tir(
mod=MyModule,
target="llvm --num-cores=1",
config=ms.TuneConfig(
max_trials_global=64,
num_trials_per_iter=64,
),
    space=ms.space_generator.ScheduleFn(stochastic_schedule_mm),  # optional: leave this out and TVM searches its default space of transformations
work_dir="./tune_tmp",
task_name="main"
)
After tuning, we can see that our code runs roughly 10x faster:
2022-08-22 20:32:55.548 INFO Scheduler picks Task #0: "main"
2022-08-22 20:33:00.443 INFO Sending 0 sample(s) to builder
2022-08-22 20:33:00.452 INFO Sending 0 sample(s) to runner
2022-08-22 20:33:00.454 INFO [Updated] Task #0: "main"
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated
------------------------------------------------------------------------------------------------------------
0 | main | 4194304 | 1 | 2.8342 | 1479.8785 | 1479.8785 | 5 |
------------------------------------------------------------------------------------------------------------
Total trials: 5
Total latency (us): 1479.88
2022-08-22 20:33:00.455 INFO Scheduler picks Task #0: "main"
2022-08-22 20:33:05.147 INFO Task #0 has finished. Remaining task(s): 0
2022-08-22 20:33:05.176 INFO Saved XGBModel to ./tune_tmp/cost_model.xgb
print(sch_tuned.trace)
# b0 = sch.get_block(name="C", func_name="main")
# l1, l2, l3 = sch.get_loops(block=b0)
# v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16])
# l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
# sch.reorder(l1, l6, l3, l7)
# b8 = sch.decompose_reduction(block=b0, loop=l3)
# sch.enter_postproc()
IPython.display.HTML(code2html(sch_tuned.mod.script()))
Here is the generated code:
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
for i, j_0 in T.grid(128, 8):
for j_1_init in T.serial(16):
with T.block("C_init"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 16 + j_1_init)
T.reads()
T.writes(C[vi, vj])
C[vi, vj] = T.float32(0)
for k, j_1 in T.grid(128, 16):
with T.block("C_update"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 16 + j_1)
vk = T.axis.reduce(128, k)
T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
T.writes(C[vi, vj])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
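To check the speedup ourselves, we can build the tuned module and time it on the same inputs; as before, this sketch assumes a_nd, b_nd and c_nd from the earlier setup:
lib_tuned = tvm.build(sch_tuned.mod, target="llvm")
f_timer_tuned = lib_tuned.time_evaluator("main", tvm.cpu())
print("tuned time-cost: %.3f ms" % (f_timer_tuned(a_nd, b_nd, c_nd).mean * 1000))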
Note that only functions implemented using TVM are optimizable. Imported library functions are not.
Say we have this neural network:
@tvm.script.ir_module
class MyModuleMixture:
@T.prim_func
def linear0(X: T.Buffer[(1, 784), "float32"],
W: T.Buffer[(128, 784), "float32"],
B: T.Buffer[(128,), "float32"],
Z: T.Buffer[(1, 128), "float32"]):
T.func_attr({"global_symbol": "linear0", "tir.noalias": True})
Y = T.alloc_buffer((1, 128), "float32")
for i, j, k in T.grid(1, 128, 784):
with T.block("Y"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]
for i, j in T.grid(1, 128):
with T.block("Z"):
vi, vj = T.axis.remap("SS", [i, j])
Z[vi, vj] = Y[vi, vj] + B[vj]
@R.function
def main(x: Tensor((1, 784), "float32"),
w0: Tensor((128, 784), "float32"),
b0: Tensor((128,), "float32"),
w1: Tensor((10, 128), "float32"),
b1: Tensor((10,), "float32")):
with R.dataflow():
lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32")
lv1 = R.call_tir("env.relu", (lv0,), (1, 128), dtype="float32")
out = R.call_tir("env.linear", (lv1, w1, b1), (1, 10), dtype="float32")
R.output(out)
return out
We want to optimize the linear0 layer and replace it with a tuned version.
First, we extract the linear layer into its own module:
mod_linear = tvm.IRModule.from_expr(MyModuleMixture["linear0"].with_attr("global_symbol", "main"))
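The extracted module now exposes the linear kernel under the global symbol "main", which is the entry point the tuner expects; printing it is a quick sanity check:
print(mod_linear.script())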
Next, we tune the extracted module:
sch_tuned_linear = ms.tune_tir(
mod=mod_linear,
target="llvm --num-cores=1",
config=ms.TuneConfig(
max_trials_global=64,
num_trials_per_iter=64,
),
work_dir="./tune_tmp",
task_name="main",
)
Then we bind the parameters into the module and replace the original linear0 with the tuned function:
MyModuleWithParams2 = relax.transform.BindParams("main", nd_params)(MyModuleMixture)
new_func = sch_tuned_linear.mod["main"].with_attr("global_symbol", "linear0")
gv = MyModuleWithParams2.get_global_var("linear0")
MyModuleWithParams2.update_func(gv, new_func)
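We can display the updated module to confirm that linear0 now carries the tuned loop structure (a quick sketch, reusing the code2html helper from earlier):
IPython.display.HTML(code2html(MyModuleWithParams2.script()))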
Finally, we build and run the updated module:
ex = relax.vm.build(MyModuleWithParams2, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
nd_res = vm["main"](data_nd)
pred_kind = np.argmax(nd_res.numpy(), axis=1)
print("MyModuleWithParams2 Prediction:", class_names[pred_kind[0]])