In this section we leverage the GPU architecture. We start by writing a window-sum module:
@tvm.script.ir_module
class MyModuleWindowSum:
    @T.prim_func
    def main(A: T.Buffer[(1027,), "float32"],
             B: T.Buffer[(1024,), "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in T.grid(1024):
            with T.block("C"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi] + A[vi + 1] + A[vi + 2]
We can split the loop and bind it to GPU threads:
# split
sch = tvm.tir.Schedule(MyModuleWindowSum)
nthread = 128
block_C = sch.get_block("C")
i, = sch.get_loops(block=block_C)
i0, i1 = sch.split(i, [None, nthread])
# bind
sch.bind(i0, "blockIdx.x")
sch.bind(i1, "threadIdx.x")
sch.mod.show()
Here is the resulting code:
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[1027, "float32"], B: T.Buffer[1024, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i_0 in T.thread_binding(8, thread="blockIdx.x"):
            for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                with T.block("C"):
                    vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                    T.reads(A[vi : vi + 3])
                    T.writes(B[vi])
                    B[vi] = A[vi] + A[vi + 1] + A[vi + 2]
The above example allows us to split the loop and bind the resulting loops to blockIdx.x and threadIdx.x. With nthread = 128, the 1024 iterations become 8 thread blocks of 128 threads each, matching the loop extents in the printed module.
Next, we leverage shared memory by using sch.cache_read to stage the input in a shared buffer:
A_shared = sch.cache_read(block_C, read_buffer_index=0, storage_scope="shared")
sch.compute_at(A_shared, i1)
sch.mod.show()
This is the output:
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[1027, "float32"], B: T.Buffer[1024, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        A_shared = T.alloc_buffer([1027], dtype="float32", scope="shared")
        for i_0 in T.thread_binding(8, thread="blockIdx.x"):
            for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                for ax0 in T.serial(130):
                    with T.block("A_shared"):
                        v0 = T.axis.spatial(1027, i_0 * 128 + ax0)
                        T.reads(A[v0])
                        T.writes(A_shared[v0])
                        A_shared[v0] = A[v0]
                with T.block("C"):
                    vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                    T.reads(A_shared[vi : vi + 3])
                    T.writes(B[vi])
                    B[vi] = A_shared[vi] + A_shared[vi + 1] + A_shared[vi + 2]
The above code constructs the shared memory buffer A_shared from the input A with a sequential loop, which is slow.
We want to split that loop and bind it to threads as well, implementing cooperative fetching: multiple threads work together to bring the data into shared memory.
ax = sch.get_loops(A_shared)[-1]
ax0, ax1 = sch.split(ax, [None, nthread])
sch.bind(ax1, "threadIdx.x")
sch.mod.show()
This will output the following. Note the T.where predicate: the split produces 2 * 128 = 256 fetch iterations per block while only 130 elements are needed, so the extra iterations are masked out:
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[1027, "float32"], B: T.Buffer[1024, "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        A_shared = T.alloc_buffer([1027], dtype="float32", scope="shared")
        for i_0 in T.thread_binding(8, thread="blockIdx.x"):
            for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                for ax0_0 in T.serial(2):
                    for ax0_1 in T.thread_binding(128, thread="threadIdx.x"):
                        with T.block("A_shared"):
                            v0 = T.axis.spatial(1027, i_0 * 128 + (ax0_0 * 128 + ax0_1))
                            T.where(ax0_0 * 128 + ax0_1 < 130)
                            T.reads(A[v0])
                            T.writes(A_shared[v0])
                            A_shared[v0] = A[v0]
                with T.block("C"):
                    vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                    T.reads(A_shared[vi : vi + 3])
                    T.writes(B[vi])
                    B[vi] = A_shared[vi] + A_shared[vi + 1] + A_shared[vi + 2]
To build the module for CUDA and print the generated kernel source, we can use:
rt_mod = tvm.build(sch.mod, target="cuda")
print(rt_mod.imported_modules[0].get_source())
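To sanity-check the kernel, we can run the built module and compare against a NumPy reference. This is a minimal sketch, assuming a CUDA-capable device; the array names a_np, a_nd, b_nd are just illustrative:
import numpy as np

dev = tvm.cuda(0)
a_np = np.random.uniform(size=(1027,)).astype("float32")
a_nd = tvm.nd.array(a_np, dev)
b_nd = tvm.nd.array(np.zeros(1024, dtype="float32"), dev)
rt_mod["main"](a_nd, b_nd)
# reference: B[i] = A[i] + A[i + 1] + A[i + 2] for i in [0, 1024)
np.testing.assert_allclose(b_nd.numpy(), a_np[:1024] + a_np[1:1025] + a_np[2:1026], rtol=1e-5)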
We can generate Metal code (another GPU programming model) by changing the target parameter:
rt_mod = tvm.build(sch.mod, target="metal")
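We can inspect the generated Metal source the same way as above (a sketch, assuming TVM is built with Metal support):
print(rt_mod.imported_modules[0].get_source())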
Here is a simple matrix multiplication module:
@tvm.script.ir_module
class MyModuleMatmul:
    @T.prim_func
    def main(A: T.Buffer[(1024, 1024), "float32"],
             B: T.Buffer[(1024, 1024), "float32"],
             C: T.Buffer[(1024, 1024), "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("C"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
We optimize it by applying blocking, tiling the loops and caching the accumulation in a local buffer:
def blocking(sch,
             tile_local_y,
             tile_local_x,
             tile_block_y,
             tile_block_x,
             tile_k):
    block_C = sch.get_block("C")
    C_local = sch.cache_write(block_C, 0, "local")
    i, j, k = sch.get_loops(block=block_C)
    i0, i1, i2 = sch.split(loop=i, factors=[None, tile_block_y, tile_local_y])
    j0, j1, j2 = sch.split(loop=j, factors=[None, tile_block_x, tile_local_x])
    k0, k1 = sch.split(loop=k, factors=[None, tile_k])
    sch.unroll(k1)
    sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2)
    sch.reverse_compute_at(C_local, j1)
    sch.bind(i0, "blockIdx.y")
    sch.bind(j0, "blockIdx.x")
    sch.bind(i1, "threadIdx.y")
    sch.bind(j1, "threadIdx.x")
    sch.decompose_reduction(block_C, k0)
    return sch

sch = tvm.tir.Schedule(MyModuleMatmul)
sch = blocking(sch, 8, 8, 8, 8, 4)
sch.mod.show()
and here is the resulting code:
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        C_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local")
        for i_0 in T.thread_binding(16, thread="blockIdx.y"):
            for j_0 in T.thread_binding(16, thread="blockIdx.x"):
                for i_1 in T.thread_binding(8, thread="threadIdx.y"):
                    for j_1 in T.thread_binding(8, thread="threadIdx.x"):
                        for i_2_init, j_2_init in T.grid(8, 8):
                            with T.block("C_init"):
                                vi = T.axis.spatial(1024, i_0 * 64 + i_1 * 8 + i_2_init)
                                vj = T.axis.spatial(1024, j_0 * 64 + j_1 * 8 + j_2_init)
                                T.reads()
                                T.writes(C_local[vi, vj])
                                C_local[vi, vj] = T.float32(0)
                        for k_0 in T.serial(256):
                            for k_1 in T.unroll(4):
                                for i_2, j_2 in T.grid(8, 8):
                                    with T.block("C_update"):
                                        vi = T.axis.spatial(1024, i_0 * 64 + i_1 * 8 + i_2)
                                        vj = T.axis.spatial(1024, j_0 * 64 + j_1 * 8 + j_2)
                                        vk = T.axis.reduce(1024, k_0 * 4 + k_1)
                                        T.reads(C_local[vi, vj], A[vi, vk], B[vk, vj])
                                        T.writes(C_local[vi, vj])
                                        C_local[vi, vj] = C_local[vi, vj] + A[vi, vk] * B[vk, vj]
                        for ax0, ax1 in T.grid(8, 8):
                            with T.block("C_local"):
                                v0 = T.axis.spatial(1024, i_0 * 64 + i_1 * 8 + ax0)
                                v1 = T.axis.spatial(1024, j_0 * 64 + j_1 * 8 + ax1)
                                T.reads(C_local[v0, v1])
                                T.writes(C[v0, v1])
                                C[v0, v1] = C_local[v0, v1]
and we can build and benchmark it like this:
rt_mod = tvm.build(sch.mod, target="cuda")
dev = tvm.cuda(0)
A_np = np.random.uniform(size=(1024, 1024)).astype("float32")
B_np = np.random.uniform(size=(1024, 1024)).astype("float32")
A_nd = tvm.nd.array(A_np, dev)
B_nd = tvm.nd.array(B_np, dev)
C_nd = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)
num_flop = 2 * 1024 * 1024 * 1024
evaluator = rt_mod.time_evaluator("main", dev, number=10)
print("GEMM-Blocking: %f GFLOPS" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean / 1e9))
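Besides measuring throughput, we can check correctness against NumPy. This is a small sketch reusing the arrays defined above; the tolerance is a loose float32 choice:
rt_mod["main"](A_nd, B_nd, C_nd)
np.testing.assert_allclose(C_nd.numpy(), A_np @ B_np, rtol=1e-4)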
Our first attempt did not take advantage of neighboring threads sitting in the same GPU thread block: within a 64x64 output tile, each element of the corresponding A and B tiles is needed by 8 different threads, so we can load that commonly used data into shared memory once per block instead of reading it repeatedly from global memory. We apply the optimization again, this time with cooperative fetching into shared memory:
def cache_read_and_coop_fetch(sch, block, nthread, read_idx, read_loc):
    read_cache = sch.cache_read(block=block, read_buffer_index=read_idx, storage_scope="shared")
    sch.compute_at(block=read_cache, loop=read_loc)
    # vectorized cooperative fetch
    inner0, inner1 = sch.get_loops(block=read_cache)[-2:]
    inner = sch.fuse(inner0, inner1)
    _, tx, vec = sch.split(loop=inner, factors=[None, nthread, 4])
    sch.vectorize(vec)
    sch.bind(tx, "threadIdx.x")

def blocking_with_shared(
        sch,
        tile_local_y,
        tile_local_x,
        tile_block_y,
        tile_block_x,
        tile_k):
    block_C = sch.get_block("C")
    C_local = sch.cache_write(block_C, 0, "local")
    i, j, k = sch.get_loops(block=block_C)
    i0, i1, i2 = sch.split(loop=i, factors=[None, tile_block_y, tile_local_y])
    j0, j1, j2 = sch.split(loop=j, factors=[None, tile_block_x, tile_local_x])
    k0, k1 = sch.split(loop=k, factors=[None, tile_k])
    sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2)
    sch.reverse_compute_at(C_local, j1)
    sch.bind(i0, "blockIdx.y")
    sch.bind(j0, "blockIdx.x")
    tx = sch.fuse(i1, j1)  # fuse the two thread axes into a single threadIdx.x
    sch.bind(tx, "threadIdx.x")
    nthread = tile_block_y * tile_block_x
    cache_read_and_coop_fetch(sch, block_C, nthread, 0, k0)  # cooperative fetch of A into shared memory
    cache_read_and_coop_fetch(sch, block_C, nthread, 1, k0)  # cooperative fetch of B into shared memory
    sch.decompose_reduction(block_C, k0)
    return sch

sch = tvm.tir.Schedule(MyModuleMatmul)
sch = blocking_with_shared(sch, 8, 8, 8, 8, 8)
sch.mod.show()
and get the following code:
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        C_local = T.alloc_buffer([1024, 1024], dtype="float32", scope="local")
        A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
        B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
        for i_0 in T.thread_binding(16, thread="blockIdx.y"):
            for j_0 in T.thread_binding(16, thread="blockIdx.x"):
                for i_1_j_1_fused in T.thread_binding(64, thread="threadIdx.x"):
                    for i_2_init, j_2_init in T.grid(8, 8):
                        with T.block("C_init"):
                            vi = T.axis.spatial(1024, i_0 * 64 + i_1_j_1_fused // 8 * 8 + i_2_init)
                            vj = T.axis.spatial(1024, j_0 * 64 + i_1_j_1_fused % 8 * 8 + j_2_init)
                            T.reads()
                            T.writes(C_local[vi, vj])
                            C_local[vi, vj] = T.float32(0)
                    for k_0 in T.serial(128):
                        for ax0_ax1_fused_0 in T.serial(2):
                            for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                                for ax0_ax1_fused_2 in T.vectorized(4):
                                    with T.block("A_shared"):
                                        v0 = T.axis.spatial(1024, i_0 * 64 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 8)
                                        v1 = T.axis.spatial(1024, k_0 * 8 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 8)
                                        T.reads(A[v0, v1])
                                        T.writes(A_shared[v0, v1])
                                        A_shared[v0, v1] = A[v0, v1]
                        for ax0_ax1_fused_0 in T.serial(2):
                            for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                                for ax0_ax1_fused_2 in T.vectorized(4):
                                    with T.block("B_shared"):
                                        v0 = T.axis.spatial(1024, k_0 * 8 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 64)
                                        v1 = T.axis.spatial(1024, j_0 * 64 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 64)
                                        T.reads(B[v0, v1])
                                        T.writes(B_shared[v0, v1])
                                        B_shared[v0, v1] = B[v0, v1]
                        for k_1, i_2, j_2 in T.grid(8, 8, 8):
                            with T.block("C_update"):
                                vi = T.axis.spatial(1024, i_0 * 64 + i_1_j_1_fused // 8 * 8 + i_2)
                                vj = T.axis.spatial(1024, j_0 * 64 + i_1_j_1_fused % 8 * 8 + j_2)
                                vk = T.axis.reduce(1024, k_0 * 8 + k_1)
                                T.reads(C_local[vi, vj], A_shared[vi, vk], B_shared[vk, vj])
                                T.writes(C_local[vi, vj])
                                C_local[vi, vj] = C_local[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj]
                    for ax0, ax1 in T.grid(8, 8):
                        with T.block("C_local"):
                            v0 = T.axis.spatial(1024, i_0 * 64 + i_1_j_1_fused // 8 * 8 + ax0)
                            v1 = T.axis.spatial(1024, j_0 * 64 + i_1_j_1_fused % 8 * 8 + ax1)
                            T.reads(C_local[v0, v1])
                            T.writes(C[v0, v1])
                            C[v0, v1] = C_local[v0, v1]
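We can build and benchmark this version in the same way as before (a sketch mirroring the earlier benchmark and reusing dev, the arrays, and num_flop; the printed label is just illustrative):
rt_mod = tvm.build(sch.mod, target="cuda")
evaluator = rt_mod.time_evaluator("main", dev, number=10)
print("GEMM-Blocking-Shared: %f GFLOPS" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean / 1e9))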
In the above examples, we wrote the transformations by hand to optimize the TensorIR program on the GPU. We can also optimize the code automatically with MetaSchedule:
from tvm import meta_schedule as ms

sch_tuned = ms.tune_tir(
    mod=MyModuleMatmul,
    target="nvidia/tesla-p100",
    config=ms.TuneConfig(
        max_trials_global=64,
        num_trials_per_iter=64,
    ),
    work_dir="./tune_tmp",
    task_name="main"
)
sch_tuned.mod.show()
and we can build and benchmark it:
rt_mod = tvm.build(sch_tuned.mod, target="nvidia/tesla-p100")
dev = tvm.cuda(0)
evaluator = rt_mod.time_evaluator("main", dev, number=10)
print("MetaSchedule: %f GFLOPS" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean / 1e9))