Relax: Relay-next. A high-level computational graph abstraction
Say we want to implement a really simple machine learning model that looks like the following:
The corresponding numpy code looks like this:
def numpy_mlp(data, w0, b0, w1, b1):
    """Run a two-layer MLP (linear -> ReLU -> linear) with high-level NumPy ops.

    Args:
        data: Input batch; each row is one sample.
        w0: First-layer weights, applied as ``data @ w0.T``.
        b0: First-layer bias, broadcast-added to the first linear output.
        w1: Second-layer weights, applied as ``lv1 @ w1.T``.
        b1: Second-layer bias, broadcast-added to the second linear output.

    Returns:
        The output of the second linear layer, one row per input row.
    """
    lv0 = data @ w0.T + b0       # first linear layer
    lv1 = np.maximum(lv0, 0)     # ReLU: clamp negative activations to zero
    lv2 = lv1 @ w1.T + b1        # second linear layer
    return lv2
To run the model, we can do this:
```python
import pickle as pkl

mlp_params = pkl.load(open("fasionmnist_mlp_params.pkl", "rb"))
res = numpy_mlp(img.reshape(1, 784),
                mlp_params["w0"],
                mlp_params["b0"],
                mlp_params["w1"],
                mlp_params["b1"])
print(res)
pred_kind = res.argmax(axis=1)
print(pred_kind)
print("NumPy-MLP Prediction:", class_names[pred_kind[0]])
```
The corresponding numpy low-level implementation may look like the following:
def lnumpy_linear0(X: np.ndarray, W: np.ndarray, B: np.ndarray, Z: np.ndarray):
    """Low-level NumPy form of the first linear layer: ``Z = X @ W.T + B``.

    The loop bounds are hard-coded, so the function indexes ``X`` as
    ``(1, 784)``, ``W`` as ``(128, 784)``, ``B`` as ``(128,)`` and ``Z`` as
    ``(1, 128)``.

    Args:
        X: Input activations.
        W: Weight matrix; row ``j`` holds the weights of output ``j``.
        B: Bias vector.
        Z: Pre-allocated output buffer, overwritten in place.

    Returns:
        None. The result is written into ``Z``.
    """
    # Scratch buffer holding the matmul result before the bias is added.
    Y = np.empty((1, 128), dtype="float32")
    for i in range(1):
        for j in range(128):
            for k in range(784):
                # Y comes from np.empty (uninitialised), so zero the
                # accumulator on the first step of the reduction over k.
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + X[i, k] * W[j, k]

    # Second pass: add the bias and write into the caller's buffer.
    for i in range(1):
        for j in range(128):
            Z[i, j] = Y[i, j] + B[j]
def lnumpy_relu0(X: np.ndarray, Y: np.ndarray):
    """Low-level NumPy form of ReLU: ``Y[i, j] = max(X[i, j], 0)``.

    The loop bounds are hard-coded, so both arrays are indexed as ``(1, 128)``.

    Args:
        X: Input activations.
        Y: Pre-allocated output buffer, overwritten in place.

    Returns:
        None. The result is written into ``Y``.
    """
    for i in range(1):
        for j in range(128):
            Y[i, j] = np.maximum(X[i, j], 0)
def lnumpy_linear1(X: np.ndarray, W: np.ndarray, B: np.ndarray, Z: np.ndarray):
    """Low-level NumPy form of the second linear layer: ``Z = X @ W.T + B``.

    The loop bounds are hard-coded, so the function indexes ``X`` as
    ``(1, 128)``, ``W`` as ``(10, 128)``, ``B`` as ``(10,)`` and ``Z`` as
    ``(1, 10)``.

    Args:
        X: Input activations.
        W: Weight matrix; row ``j`` holds the weights of output ``j``.
        B: Bias vector.
        Z: Pre-allocated output buffer, overwritten in place.

    Returns:
        None. The result is written into ``Z``.
    """
    # Scratch buffer holding the matmul result before the bias is added.
    Y = np.empty((1, 10), dtype="float32")
    for i in range(1):
        for j in range(10):
            for k in range(128):
                # Y comes from np.empty (uninitialised), so zero the
                # accumulator on the first step of the reduction over k.
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + X[i, k] * W[j, k]

    # Second pass: add the bias and write into the caller's buffer.
    for i in range(1):
        for j in range(10):
            Z[i, j] = Y[i, j] + B[j]
def lnumpy_mlp(data, w0, b0, w1, b1):
    """Run the MLP by chaining the low-level NumPy primitives.

    Each stage allocates its own output buffer and passes it to the primitive
    as the last argument; the primitive fills that buffer in place.

    Args:
        data: Input passed to ``lnumpy_linear0``.
        w0: Weights for ``lnumpy_linear0``.
        b0: Bias for ``lnumpy_linear0``.
        w1: Weights for ``lnumpy_linear1``.
        b1: Bias for ``lnumpy_linear1``.

    Returns:
        A ``(1, 10)`` float32 array filled by ``lnumpy_linear1``.
    """
    lv0 = np.empty((1, 128), dtype="float32")
    lnumpy_linear0(data, w0, b0, lv0)

    lv1 = np.empty((1, 128), dtype="float32")
    lnumpy_relu0(lv0, lv1)

    out = np.empty((1, 10), dtype="float32")
    lnumpy_linear1(lv1, w1, b1, out)
    return out
# Run the low-level NumPy MLP on the image flattened to shape (1, 784),
# passing the four parameters in the order the function expects.
result = lnumpy_mlp(
    img.reshape(1, 784),
    *(mlp_params[name] for name in ("w0", "b0", "w1", "b1")),
)
# Index of the largest output in each row.
pred_kind = result.argmax(axis=1)
print("Low-level Numpy MLP Prediction:", class_names[pred_kind[0]])
Now we show you how you may implement this in TVMScript:
@tvm.script.ir_module
class MyModule:
    # IRModule holding three TensorIR primitive functions (relu0, linear0,
    # linear1) and one Relax function (main) that chains them together.

    @T.prim_func
    def relu0(X: T.Buffer[(1, 128), "float32"],
              Y: T.Buffer[(1, 128), "float32"]):
        # Element-wise ReLU: Y = max(X, 0).
        # function attr dict
        T.func_attr({"global_symbol": "relu0", "tir.noalias": True})
        for i, j in T.grid(1, 128):
            with T.block("Y"):
                # "SS": both block axes are spatial.
                vi, vj = T.axis.remap("SS", [i, j])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))

    @T.prim_func
    def linear0(X: T.Buffer[(1, 784), "float32"],
                W: T.Buffer[(128, 784), "float32"],
                B: T.Buffer[(128,), "float32"],
                Z: T.Buffer[(1, 128), "float32"]):
        # First linear layer: Z = X @ W.T + B.
        T.func_attr({"global_symbol": "linear0", "tir.noalias": True})
        # Intermediate buffer for the matmul result before the bias is added.
        Y = T.alloc_buffer((1, 128), "float32")
        for i, j, k in T.grid(1, 128, 784):
            with T.block("Y"):
                # "SSR": vi and vj are spatial axes, vk is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Counterpart of the `if k == 0: Y[i, j] = 0` step in the
                # low-level numpy version.
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]

        for i, j in T.grid(1, 128):
            with T.block("Z"):
                vi, vj = T.axis.remap("SS", [i, j])
                Z[vi, vj] = Y[vi, vj] + B[vj]

    @T.prim_func
    def linear1(X: T.Buffer[(1, 128), "float32"],
                W: T.Buffer[(10, 128), "float32"],
                B: T.Buffer[(10,), "float32"],
                Z: T.Buffer[(1, 10), "float32"]):
        # Second linear layer: Z = X @ W.T + B.
        T.func_attr({"global_symbol": "linear1", "tir.noalias": True})
        # Intermediate buffer for the matmul result before the bias is added.
        Y = T.alloc_buffer((1, 10), "float32")
        for i, j, k in T.grid(1, 10, 128):
            with T.block("Y"):
                # "SSR": vi and vj are spatial axes, vk is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Counterpart of the `if k == 0: Y[i, j] = 0` step in the
                # low-level numpy version.
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]

        for i, j in T.grid(1, 10):
            with T.block("Z"):
                vi, vj = T.axis.remap("SS", [i, j])
                Z[vi, vj] = Y[vi, vj] + B[vj]

    @R.function
    def main(x: Tensor((1, 784), "float32"),
             w0: Tensor((128, 784), "float32"),
             b0: Tensor((128,), "float32"),
             w1: Tensor((10, 128), "float32"),
             b1: Tensor((10,), "float32")):
        # Graph-level function: linear0 -> relu0 -> linear1.
        with R.dataflow():
            # R.call_tir(prim_func, inputs, output_shape, dtype=...)
            # `inputs` is a tuple of every input tensor, data and weights alike.
            lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32")
            lv1 = R.call_tir(relu0, (lv0,), (1, 128), dtype="float32")
            out = R.call_tir(linear1, (lv1, w1, b1), (1, 10), dtype="float32")
            # Only `out` is exposed outside the dataflow block.
            R.output(out)
        return out
Here, `with R.dataflow()` marks the portion of the code that is pure and has no side effects, so it can be represented in computational-graph form. (Don't treat all computation as a graph!)
Functions that are not inside a dataflow block must be executed in sequential order.
Within each dataflow block, `R.output(out)` specifies the output (here the single variable `out`) that can be seen by later functions. Other variables, such as `lv0`, are not available outside the block.
The `R.call_tir` function is similar to the following numpy implementation:
def lnumpy_call_tir(prim_func, inputs, shape, dtype):
    """NumPy analogue of ``R.call_tir``: allocate the output, then call.

    Args:
        prim_func: Callable that takes the unpacked ``inputs`` followed by the
            output buffer as its last argument and fills that buffer in place.
        inputs: Sequence of input arrays, unpacked as positional arguments.
        shape: Shape of the output array to allocate.
        dtype: Data type of the output array.

    Returns:
        The newly allocated array after ``prim_func`` has written to it.
    """
    # The output is allocated here, so the caller sees a function that
    # returns a fresh array rather than one that mutates an argument.
    res = np.empty(shape, dtype=dtype)
    prim_func(*inputs, res)
    return res
Notice that the `@R.function` above corresponds to a computational graph, with the inputs and output specified for each node in the graph.
You may notice that the output allocation happens inside `call_tir`, which then hands the buffer to the primitive function, rather than in the `@R.function` body; this is by design. Otherwise, if you allocate at the `@R.function` level, you will end up with a computational graph that looks like the following, which is not desirable. In addition, the function would look non-pure and rely on side effects.
To run the prediction, you can do
# NOTE(review): `MyModuleWithExternCall` is not defined in the text above; the
# module shown so far is `MyModule`. Confirm which module is meant here.
ex = relax.vm.build(MyModuleWithExternCall, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
# Call the module's "main" function with the input followed by the four
# parameters in signature order.
nd_res = vm["main"](
    data_nd,
    *(nd_params[name] for name in ("w0", "b0", "w1", "b1")),
)
# Convert to numpy and take the index of the largest output in each row.
pred_kind = nd_res.numpy().argmax(axis=1)
print("MyModuleWithExternCall Prediction:", class_names[pred_kind[0]])
We want TVM to support the execution of external functions. For example, we want to execute `env.linear` (a different linear layer implementation provided by the `env` package) by calling `out = R.call_tir("env.linear", (lv1, w1, b1), (1, 10), dtype="float32")`.
We can do so by registering a new function:
@tvm.register_func("env.linear", override=True)  # this registers the function. Not necessarily from Python, can be from other languages
def torch_linear(x: tvm.nd.NDArray,
                 w: tvm.nd.NDArray,
                 b: tvm.nd.NDArray,
                 out: tvm.nd.NDArray):
    """Linear layer ``out = x @ w.T + b`` implemented with PyTorch.

    Registered under the name "env.linear" so it can be referenced by that
    string. The output buffer is passed in as the last argument and is
    written in place.

    Args:
        x: Input activations.
        w: Weight matrix; transposed before the matrix multiply.
        b: Bias added to the matrix-multiply result.
        out: Pre-allocated output buffer that receives the result.
    """
    # casting TVM tensor to Torch tensor (without copy), so they refer to
    # the same piece of memory
    x_torch = torch.from_dlpack(x)
    w_torch = torch.from_dlpack(w)
    b_torch = torch.from_dlpack(b)
    out_torch = torch.from_dlpack(out)
    # Both ops write through `out=`, so the result lands in the buffer that
    # `out` refers to.
    torch.mm(x_torch, w_torch.T, out=out_torch)
    torch.add(out_torch, b_torch, out=out_torch)
When we have a huge model, we don't want to pass the model weights to each function call, so we bind the model weights to the function instead.
We use `MyModuleWithParams = relax.transform.BindParams("main", nd_params)(MyModuleMixture)`
to bind the parameters. This is possible because the parameters are stored in a dictionary, and each parameter is associated with a dictionary key.
After binding, you will find that our main function no longer needs the weight parameters as inputs:
# Build the parameter-bound module and run it; "main" is called with the
# input data only.
ex = relax.vm.build(MyModuleWithParams, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
nd_res = vm["main"](data_nd)
# Convert to numpy and take the index of the largest output in each row.
pred_kind = nd_res.numpy().argmax(axis=1)
print("MyModuleWithParams Prediction:", class_names[pred_kind[0]])
Table of Contents