TVM

An Overview of TVM and Model Optimization

TVM Overview


Steps:

  1. Import a model from a framework such as TensorFlow, PyTorch, or ONNX into TVM. ONNX is the most general option.
  2. Translate the model to Relay for graph-level optimization passes. Relay ("A High-Level Compiler for Deep Learning", developed by the Catalyst Group) is a functional language and IR for neural networks that:
    1. supports data-flow-style representations;
    2. is a functional-style, fully featured differentiable language;
    3. allows the user to mix the two programming styles.
  3. Lower to the Tensor Expression (TE) representation.
    1. Relay runs the FuseOps pass to partition the model into many small subgraphs.
    2. Automatic lowering: each subgraph is lowered to TE, a language for describing tensor computations.
    3. Manual lowering: TVM includes a Tensor Operator Inventory (TOPI) with predefined templates of common tensor operators (e.g., conv2d, transpose).
    4. TE also provides several schedule primitives to specify low-level loop optimizations (tiling, vectorization, parallelization, unrolling, and fusion).
  4. Search for the best schedule with AutoTVM or AutoScheduler. Both are operator- or subgraph-level auto-tuning modules guided by a cost model and on-device measurements.
    1. AutoTVM: a template-based auto-tuning module that tunes the parameters of a schedule template. Templates for common operators are already provided in TOPI.
    2. AutoScheduler (Ansor): a template-free auto-tuning module. It does not require predefined schedule templates; it generates the search space by analyzing the computation definition.
  5. Choose the best schedule. Auto-tuning generates tuning records in JSON format, and this step picks the best schedule for each subgraph.
  6. Lower TE to Tensor Intermediate Representation (TIR, TVM's low-level IR) and run low-level optimization passes. The optimized TIR is then handed to a backend for code generation:
    1. LLVM: for x86, ARM, AMDGPU, and NVPTX
    2. Specialized compilers: NVCC
    3. Embedded and specialized targets, implemented through TVM’s Bring Your Own Codegen (BYOC) framework.
  7. Compile to machine code.
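The schedule primitives in step 3 change the loop structure without changing what is computed. As a plain-Python illustration (not TVM code; the function names are hypothetical), splitting the row and column loops of a matrix multiply into tile-sized blocks visits every output element exactly once and gives the same result as the naive loop nest:

```python
# Plain-Python illustration of loop tiling (split + reorder); not TVM code.

def matmul_naive(A, B):
    """C = A @ B with the straightforward i, j, r loop nest."""
    n, k, m = len(A), len(B), len(B[0])
    C = [[0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            for r in range(k):
                C[i][j] += A[i][r] * B[r][j]
    return C


def matmul_tiled(A, B, tile=2):
    """Same computation, with the i and j loops split into tile-sized blocks."""
    n, k, m = len(A), len(B), len(B[0])
    C = [[0] * m for _ in range(n)]
    for io in range(0, n, tile):  # outer loop over row tiles
        for jo in range(0, m, tile):  # outer loop over column tiles
            # inner loops stay inside the current tile; min() handles partial tiles
            for i in range(io, min(io + tile, n)):
                for j in range(jo, min(jo + tile, m)):
                    for r in range(k):
                        C[i][j] += A[i][r] * B[r][j]
    return C


A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
B = [[9, 8, 7], [6, 5, 4], [3, 2, 1]]
print(matmul_tiled(A, B, tile=2) == matmul_naive(A, B))  # True
```

On real hardware the point of tiling is to keep small blocks of A, B, and C in cache; in TE the same transformation is expressed with primitives such as split and reorder on a schedule rather than by rewriting loops by hand.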

Installation

I choose to install from source.

git clone --recursive https://github.com/apache/tvm tvm

Install the minimum build dependencies. For Ubuntu:

sudo apt-get update
sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

Then, inside the cloned tvm directory:

mkdir build
cp cmake/config.cmake build

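Then edit build/config.cmake to enable the features you need (for example, the LLVM backend, which the llvm target used below requires). Instructions can be found in the official documentation.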
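If you prefer to script that edit, here is a minimal sketch. It assumes config.cmake sets options with uncommented lines of the form set(USE_LLVM OFF), as it does for options such as USE_LLVM and USE_CUDA; set_cmake_option is a hypothetical helper, not part of TVM.

```python
import re


def set_cmake_option(text, name, value):
    """Return `text` with every uncommented `set(NAME ...)` line set to `value`.

    Hypothetical helper for editing build/config.cmake; raises KeyError if
    the option does not appear.
    """
    pattern = re.compile(rf"^set\({re.escape(name)}\s+[^)]*\)", re.MULTILINE)
    # use a function replacement so backslashes in `value` are not interpreted
    new_text, count = pattern.subn(lambda _m: f"set({name} {value})", text)
    if count == 0:
        raise KeyError(f"{name} not found in config text")
    return new_text


# Example usage (run from the tvm directory after copying config.cmake):
# from pathlib import Path
# cfg = Path("build/config.cmake")
# cfg.write_text(set_cmake_option(cfg.read_text(), "USE_LLVM", "ON"))
```

Commented-out lines such as "# set(USE_LLVM ON)" are left alone because the pattern is anchored at the start of the line.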

cd build
cmake ..
make -j4

If you get an error about gtest, disable the tests in CMakeLists.txt by commenting out every if(GTEST_FOUND) block (a temporary workaround).

Now set up the Python environment. In .zshrc, add:

export TVM_HOME=/path/to/tvm # the one you cloned
export PYTHONPATH=$TVM_HOME/python:${PYTHONPATH}
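To confirm the two variables took effect (after opening a new shell or running source ~/.zshrc), a small standard-library check is enough; check_tvm_env is a hypothetical helper name.

```python
import os


def check_tvm_env(environ=None):
    """Return a list of problems with TVM_HOME / PYTHONPATH; empty means OK."""
    env = os.environ if environ is None else environ
    home = env.get("TVM_HOME")
    if not home:
        return ["TVM_HOME is not set"]
    expected = os.path.normpath(os.path.join(home, "python"))
    entries = {os.path.normpath(p) for p in env.get("PYTHONPATH", "").split(os.pathsep) if p}
    if expected not in entries:
        return [f"PYTHONPATH does not contain {expected}"]
    return []


print(check_tvm_env() or "TVM environment looks OK")
```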

Then install the Python dependencies:

conda create -n "tvm" python=3.8
conda activate tvm
pip3 install numpy decorator attrs
pip3 install tornado psutil 'xgboost<1.6.0' cloudpickle

That should be it.

Now you should be able to do the following:

Python 3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:49:35)
[GCC 10.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import tvm
>>> dir(tvm)

You don't need to install NNPACK Contrib; NNPACK is mainly for reference and comparison purposes.

In addition to the packages listed in the official documentation, you might need to run pip install typing-extensions.

Then you should be able to run python -m tvm.driver.tvmc --version.

Then you can add to your .zshrc:

alias tvmc="python -m tvm.driver.tvmc "

Using TVMC

First install the ONNX packages and download the model:

conda activate tvm
pip install onnx onnxoptimizer

wget https://github.com/onnx/models/raw/b9a54e89508f101a1611cd64f4ef56b9cb62c7cf/vision/classification/resnet/model/resnet50-v2-7.onnx
# you will get resnet50-v2-7.onnx

The onnxoptimizer dependency is optional and is only used for onnx>=1.9 (we are using onnx 1.13).

# This may take several minutes depending on your machine
tvmc compile \
--target "llvm" \
--input-shapes "data:[1,3,224,224]" \
--output resnet50-v2-7-tvm.tar \
resnet50-v2-7.onnx

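It should produce resnet50-v2-7-tvm.tar, which contains 3 files:

  1. mod.so: the model, represented as a C++ library that can be loaded by the TVM runtime.
  2. mod.json: a text representation of the TVM Relay computation graph.
  3. mod.params: the parameters of the pre-trained model.

The --target option specifies the hardware target for the generated code.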
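To check what tvmc produced without unpacking it by hand, Python's tarfile module can list the package; for this example it should show mod.so, mod.json, and mod.params. list_package_files is a hypothetical helper.

```python
import os
import tarfile


def list_package_files(path):
    """Return the sorted names of the regular files inside a .tar package."""
    with tarfile.open(path) as tar:
        # normpath turns "./mod.so" into "mod.so"
        return sorted(os.path.normpath(m.name) for m in tar.getmembers() if m.isfile())


# Example usage (assumes resnet50-v2-7-tvm.tar is in the current directory):
# print(list_package_files("resnet50-v2-7-tvm.tar"))
```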

Prepare your input as imagenet_cat.npz with np.savez("imagenet_cat", data=img_data), where the key data matches the model's input name. Then run the following to get the prediction:

tvmc run \
--inputs imagenet_cat.npz \
--output predictions.npz \
resnet50-v2-7-tvm.tar

Then load predictions.npz, apply a softmax to the scores, and map the top predictions to human-readable labels.
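The post-processing is a softmax followed by a top-k lookup. With NumPy you would first read the score array from predictions.npz (the official tvmc tutorial reads the key output_0; print data.files to confirm yours), and the labels come from an ImageNet synset file. The sketch below shows the math in plain Python; softmax and top_k are hypothetical helper names, and the 4 labels are a toy stand-in for ImageNet's 1000.

```python
import math


def softmax(scores):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(labels[i], probs[i]) for i in ranked]


# Toy example with 4 classes instead of ImageNet's 1000
probs = softmax([2.0, 0.5, 3.0, -1.0])
for label, p in top_k(probs, ["tabby", "tiger cat", "Egyptian cat", "lynx"], k=2):
    print(f"{label}: {p:.4f}")
```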

To tune for a specific CPU, give a more precise target; for example, on an Intel i7 processor you could use --target "llvm -mcpu=skylake". In this tuning example we tune locally on the CPU, using LLVM as the compiler for the specified architecture. The search can take hours.

tvmc tune \
--target "llvm" \
--output resnet50-v2-7-autotuner_records.json \
resnet50-v2-7.onnx

The generated resnet50-v2-7-autotuner_records.json can be used for further tuning, or passed to tvmc compile (via --tuning-records) for an optimized build.
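Each line of an AutoTVM tuning-record file is one JSON object. Assuming AutoTVM's log format, where "input" holds [target, task name, args, kwargs] and "result" holds [costs, error_no, all_cost, timestamp], picking the best schedule per task boils down to the sketch below. best_per_task is a hypothetical helper; TVM's own autotvm.apply_history_best does the real lookup, and records written by tvmc tune --enable-autoscheduler use a different format.

```python
import json


def best_per_task(lines):
    """Map (task_name, args) -> (mean_cost_s, config_index) of the fastest valid record."""
    best = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        rec = json.loads(line)
        _target, task_name, args, _kwargs = rec["input"]
        costs, error_no = rec["result"][0], rec["result"][1]
        if error_no != 0 or not costs:  # skip failed measurements
            continue
        mean_cost = sum(costs) / len(costs)
        key = (task_name, json.dumps(args))  # same op name can appear with different shapes
        if key not in best or mean_cost < best[key][0]:
            best[key] = (mean_cost, rec["config"]["index"])
    return best


# Example usage:
# with open("resnet50-v2-7-autotuner_records.json") as f:
#     for (name, _), (cost, idx) in best_per_task(f).items():
#         print(f"{name}: config {idx}, {cost * 1e3:.3f} ms")
```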

By default this search is guided using an XGBoost Grid algorithm. Depending on your model complexity and amount of time available, you might want to choose a different algorithm.

tvmc compile \
--target "llvm" \
--tuning-records resnet50-v2-7-autotuner_records.json  \
--output resnet50-v2-7-tvm_autotuned.tar \
resnet50-v2-7.onnx

For example, on a test Intel i7 system, we see that the tuned model runs 47% faster than the untuned model.
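"X% faster" can be read two ways: higher throughput or lower latency. Given mean latencies for both packages (for example from tvmc run --print-time), both readings are easy to compute; compare_latency is a hypothetical helper and the numbers below are illustrative, not measurements.

```python
def compare_latency(untuned_ms, tuned_ms):
    """Return (speedup, latency_reduction) for two mean latencies."""
    speedup = untuned_ms / tuned_ms  # >1 means the tuned model is faster
    latency_reduction = 1 - tuned_ms / untuned_ms  # fraction of time saved
    return speedup, latency_reduction


# Hypothetical latencies, chosen only to illustrate the arithmetic
speedup, reduction = compare_latency(untuned_ms=100.0, tuned_ms=68.0)
print(f"speedup {speedup:.2f}x, latency reduced by {reduction:.0%}")
```

With these numbers a 1.47x speedup corresponds to roughly 32% less time per inference.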

TVMC supports many more features, including cross-compilation, remote execution, and profiling/benchmarking.

TVM Python Interface

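In the TVMC Python interface (tvmc.tune), you can pass enable_autoscheduler=True to tune with AutoScheduler's template-free search space instead of hand-crafted templates. You can also set further tuning parameters, such as trials=10000 and timeout=10. The example below uses the lower-level Relay and AutoTVM APIs directly.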

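import onnx
import numpy as np
from PIL import Image
import tvm
import tvm.relay as relay
from tvm.contrib import graph_executor
from tvm.contrib.download import download_testdata

# Seed numpy's RNG to get consistent results
np.random.seed(0)

# download the model (same ResNet-50 ONNX file as in the TVMC section)
model_url = "https://github.com/onnx/models/raw/b9a54e89508f101a1611cd64f4ef56b9cb62c7cf/vision/classification/resnet/model/resnet50-v2-7.onnx"
model_path = download_testdata(model_url, "resnet50-v2-7.onnx", module="onnx")
onnx_model = onnx.load(model_path)

# prepare the input: resize, HWC -> CHW, ImageNet normalization, add batch dim
# (img_path is a placeholder: point it at any RGB image, e.g. a cat picture)
img_path = "imagenet_cat.png"
resized_image = Image.open(img_path).convert("RGB").resize((224, 224))
img_data = np.asarray(resized_image).astype("float32")
img_data = np.transpose(img_data, (2, 0, 1))
imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
img_data = ((img_data / 255 - imagenet_mean) / imagenet_stddev).astype("float32")
img_data = np.expand_dims(img_data, axis=0)  # shape (1, 3, 224, 224)

# settings
target = "llvm"
input_name = "data"  # The input name may vary across model types. You can use a tool like Netron to check input names
shape_dict = {input_name: img_data.shape}

# import the ONNX model into Relay
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

# compile the Relay module for the target
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

dev = tvm.device(str(target), 0)
module = graph_executor.GraphModule(lib["default"](dev))

# run the model
module.set_input(input_name, img_data)
module.run()
output_shape = (1, 1000)
tvm_output = module.get_output(0, tvm.nd.empty(output_shape)).numpy()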

# tune the model for platform-specific optimization
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner

# create a runner that measures candidate schedules on the local device
runner = autotvm.LocalRunner(
    number=10,  # runs averaged into one measurement (one "repeat")
    repeat=1,  # how many measurements to take of each configuration
    timeout=10,  # in seconds, for each tested configuration
    min_repeat_ms=0,  # minimum duration of one repeat (number is raised to reach it); 0 is fine on CPU
    enable_cpu_cache_flush=True,  # flush the CPU cache between measurements
)

tuning_option = {
    "tuner": "xgb",  # XGBoost-based tuner: a cost model guides which configs to measure
    "trials": 20,  # max configs measured per task (in production, for CPU set to 1500, GPU set to 3000~4000)
    "early_stopping": 100,  # stop a task after this many trials without improvement (never triggers here, since trials=20)
    "measure_option": autotvm.measure_option(
        builder=autotvm.LocalBuilder(build_func="default"), runner=runner
    ),
    "tuning_records": "resnet-50-v2-autotuning.json",  # where to save the tuning records
}

# extract the tunable tasks from the Relay program: each task is one tunable
# operator (e.g. a conv2d with specific shapes) whose search space comes from
# a hand-written TOPI schedule template
tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)

# tune the extracted tasks sequentially
for i, task in enumerate(tasks):
    prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
    tuner_obj = XGBTuner(task, loss_type="rank")
    tuner_obj.tune(
        n_trial=min(tuning_option["trials"], len(task.config_space)),
        early_stopping=tuning_option["early_stopping"],
        measure_option=tuning_option["measure_option"],
        callbacks=[
            autotvm.callback.progress_bar(tuning_option["trials"], prefix=prefix),
            autotvm.callback.log_to_file(tuning_option["tuning_records"]),
        ],
    )

# recompile with the best schedules found during tuning
with autotvm.apply_history_best(tuning_option["tuning_records"]):
    with tvm.transform.PassContext(opt_level=3, config={}):
        lib = relay.build(mod, target=target, params=params)

dev = tvm.device(str(target), 0)
module = graph_executor.GraphModule(lib["default"](dev))

# run the tuned model on the same input
module.set_input(input_name, img_data)
module.run()
tvm_output = module.get_output(0, tvm.nd.empty(output_shape)).numpy()
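Conceptually, each tuner.tune call is a measure-and-keep-the-best loop bounded by n_trial and cut short by early_stopping. The toy below uses random sampling and a made-up cost function (both hypothetical); XGBTuner instead trains an XGBoost cost model on past measurements to decide which configurations to measure next, and times them on the device through the runner.

```python
import random


def toy_tune(config_space, measure, n_trial, early_stopping, seed=0):
    """Randomly sample configs; stop after `early_stopping` trials without improvement."""
    rng = random.Random(seed)
    candidates = list(config_space)
    rng.shuffle(candidates)  # sample without replacement
    best_cfg, best_cost = None, float("inf")
    last_improvement = trials = 0
    for trials, cfg in enumerate(candidates[:n_trial], start=1):
        cost = measure(cfg)  # stands in for building and timing on the device
        if cost < best_cost:
            best_cfg, best_cost, last_improvement = cfg, cost, trials
        elif trials - last_improvement >= early_stopping:
            break
    return best_cfg, best_cost, trials


# Hypothetical search space: tile factors 1..64, with a made-up cost minimized at 16
space = range(1, 65)
cfg, cost, used = toy_tune(space, lambda f: abs(f - 16) + 1, n_trial=64, early_stopping=100)
print(cfg, cost, used)  # prints: 16 1 64
```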

Working with Operators Using Tensor Expression

We will use Tensor Expression (TE) to define tensor computations and apply loop optimizations.

Example 1: Writing and Scheduling Vector Addition in TE for CPU

import tvm
import tvm.testing
from tvm import te
import numpy as np

# Run `llc --version` to see the host CPU type (the "Host CPU" line), and check
# /proc/cpuinfo for additional extensions. Pass the CPU to TVM with -mcpu inside
# the target string, e.g. "llvm -mcpu=skylake-avx512" for CPUs with AVX-512.
tgt = tvm.target.Target(target="llvm", host="llvm")

# declare the computation (nothing is computed yet)
n = te.var("n")  # symbolic size variable
A = te.placeholder((n,), name="A")  # A is a 1-D tensor of length n
B = te.placeholder((n,), name="B")  # B is a 1-D tensor of length n
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")  # C has A's shape; element i is A[i] + B[i]
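te.compute only records a shape and a per-element rule; nothing runs until a schedule is built and called. A rough plain-Python analogy (toy_compute is hypothetical and unrelated to TVM's implementation) makes the split between declaring and evaluating explicit:

```python
import itertools


def toy_compute(shape, fcompute):
    """Record a shape and an element rule; return a function that evaluates it later."""
    def realize():
        # visit every index tuple in the (row-major) index space
        return {idx: fcompute(*idx) for idx in itertools.product(*(range(d) for d in shape))}
    return realize


a, b = [1.0, 2.0, 3.0], [10.0, 20.0, 30.0]
c = toy_compute((3,), lambda i: a[i] + b[i])  # declaration only: nothing computed yet
print(c())  # {(0,): 11.0, (1,): 22.0, (2,): 33.0}
```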

Sequential Scheduler

# create a default (naive) schedule, equivalent to the C loop below
# for (int i = 0; i < n; ++i) {
#   C[i] = A[i] + B[i];
# }
s = te.create_schedule(C.op)

# compile everything into a callable function
fadd = tvm.build(s,  # the schedule
                 [A, B, C],  # function signature: inputs A, B and output C
                 tgt,  # target
                 name="myadd",
)

# get the device matching the target (CPU 0 for an llvm target)
dev = tvm.device(tgt.kind.name, 0)

# initialize actual data
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=1024).astype(B.dtype), dev)
c = tvm.nd.array(np.zeros(1024, dtype=C.dtype), dev)

# do computation
fadd(a, b, c)

# check for correct answer
tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())

# you might need profiling: https://tvm.apache.org/docs/tutorial/tensor_expr_get_started.html
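For timing, TVM's time_evaluator (and AutoTVM's LocalRunner) follow a number/repeat convention: number runs are averaged into one measurement, and that measurement is repeated repeat times. A standard-library sketch of the same idea for any Python callable (time_callable is a hypothetical helper; Python-side timing adds overhead that TVM's evaluator avoids):

```python
import time


def time_callable(fn, *args, number=10, repeat=3):
    """Return `repeat` measurements, each the mean time (s) of `number` calls."""
    fn(*args)  # warm-up run, discarded
    results = []
    for _ in range(repeat):
        start = time.perf_counter()
        for _ in range(number):
            fn(*args)
        results.append((time.perf_counter() - start) / number)
    return results


# Example with a pure-Python vector add standing in for fadd
def py_add(a, b, c):
    for i in range(len(a)):
        c[i] = a[i] + b[i]


x, y, z = [1.0] * 1024, [2.0] * 1024, [0.0] * 1024
times = time_callable(py_add, x, y, z, number=10, repeat=3)
print(f"mean {sum(times) / len(times):.2e} s per call")
```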

Parallel Scheduler

# s: the schedule
# C: the output tensor created by te.compute
s[C].parallel(C.op.axis[0])

# generate IR of TE (simple_mode=True for readability)
print(tvm.lower(s, [A, B, C], simple_mode=True))
# @main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
#   attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
#   buffers = {A: Buffer(A_2: Pointer(float32), float32, [n: int32], [stride: int32], type="auto"),
#              B: Buffer(B_2: Pointer(float32), float32, [n], [stride_1: int32], type="auto"),
#              C: Buffer(C_2: Pointer(float32), float32, [n], [stride_2: int32], type="auto")}
#   buffer_map = {A_1: A, B_1: B, C_1: C} {
#   for (i: int32, 0, n) "parallel" {
#     C_3: Buffer(C_2, float32, [(stride_2*n)], [], type="auto")[(i*stride_2)] = (A_3: Buffer(A_2, float32, [(stride*n)], [], type="auto")[(i*stride)] + B_3: Buffer(B_2, float32, [(stride_1*n)], [], type="auto")[(i*stride_1)])
#   }
# }

# pack everything into a function
fadd_parallel = tvm.build(s, [A, B, C], tgt, name="myadd_parallel")

# run the function and check correctness
fadd_parallel(a, b, c)
tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())
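s[C].parallel marks the loop's iterations as independent, so the generated code can hand chunks of them to different threads. The same idea in plain Python with concurrent.futures (parallel_add is hypothetical; because of CPython's GIL this shows how the work is split, not a real speedup):

```python
from concurrent.futures import ThreadPoolExecutor


def parallel_add(a, b, num_workers=4):
    """Element-wise a + b, with the index range split into one chunk per worker."""
    n = len(a)
    c = [0.0] * n
    chunk = max(1, -(-n // num_workers))  # ceiling division, at least 1

    def work(start):
        # each worker writes a disjoint slice of c, so no locking is needed
        for i in range(start, min(start + chunk, n)):
            c[i] = a[i] + b[i]

    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        list(pool.map(work, range(0, n, chunk)))  # list() surfaces worker exceptions
    return c


print(parallel_add([1.0, 2.0, 3.0, 4.0, 5.0], [10.0, 20.0, 30.0, 40.0, 50.0], num_workers=2))
# [11.0, 22.0, 33.0, 44.0, 55.0]
```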

Vectorization Scheduler

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")

s = te.create_schedule(C.op)

# The inner loop (length `factor`) is vectorized, so a factor matching your
# CPU's SIMD width (e.g. 4 float32 lanes for SSE, 8 for AVX2, 16 for AVX-512)
# is a reasonable starting point; the outer loop is spread across CPU threads.
factor = 4

# split the loop, run outer chunks in parallel, and vectorize each inner chunk
outer, inner = s[C].split(C.op.axis[0], factor=factor)
s[C].parallel(outer)
s[C].vectorize(inner)

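fadd_vector = tvm.build(s, [A, B, C], tgt, name="myadd_vector")

# check correctness, then time the vectorized function (10 runs averaged)
fadd_vector(a, b, c)
tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())
evaluator = fadd_vector.time_evaluator(fadd_vector.entry_name, dev, number=10)
print("vector: %f s" % evaluator(a, b, c).mean)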
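When n is not a multiple of factor, the split leaves a partial last chunk, so the lowered loop needs a bounds check (printing tvm.lower(s, [A, B, C], simple_mode=True) for this schedule should show it). The plain-Python model below (split_indices is hypothetical) lists which element indices each outer iteration covers, i.e. which groups of elements a vector instruction would process together:

```python
def split_indices(n, factor):
    """Return, for each outer iteration, the inner indices that pass the i < n guard."""
    chunks = []
    for outer in range((n + factor - 1) // factor):  # ceil(n / factor) outer iterations
        lane = []
        for inner in range(factor):
            i = outer * factor + inner
            if i < n:  # guard for the partial last chunk
                lane.append(i)
        chunks.append(lane)
    return chunks


print(split_indices(10, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```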
