Quantization Aware Training of Deep Learning Frameworks and the Implementation in OneFlow

OneFlow
Oct 15, 2021

Written by Zhang Xiaoyu; Translated by Dong Wenwen, Wang Kaiyan

The project shared in this article is what I have been working on in OneFlow recently: porting PyTorch FX to OneFlow and using it to perform automatic quantization aware training on dynamic computation graphs (called nn.Module in both PyTorch and OneFlow). Starting from an nn.Module they built themselves, users can now go through the complete process, from quantization aware training of the nn.Module to deploying the trained model on a GPU with TensorRT, with only a few code changes. The TensorRT inference pipeline is: OneFlow dynamic computation graph (nn.Module) -> OneFlow quantization aware training model (nn.Module) -> OneFlow static computation graph (nn.Graph) -> ONNX -> TensorRT. Quantization aware training is built on the FX module, which supports writing Passes in Eager mode (FX was first proposed by PyTorch, and the author ported its infrastructure to OneFlow).

0x0. Introduction

Most of you have probably heard of or used deep learning frameworks such as PyTorch, TensorFlow or OneFlow (whose development the author participates in). If you have found the quantization schemes of these frameworks perplexing, this article may interest you.

In this article, I will start with PyTorch's two generations of quantization schemes and discuss their strengths and weaknesses. Then I will describe OneFlow's quantization aware training, which combines PyTorch's FX module with some of my own ideas. This article covers:

  • PyTorch’s FX module
  • Eager Pass
  • Quantization Aware Training
  • Fusion of Conv and BN
  • Conversion of nn.Module and nn.Graph in OneFlow
  • ONNX
  • TensorRT

It doesn't matter if you are not familiar with any of the above points. In fact, you can quickly get started with the quantization aware training described in this article even if you only know how to build a model with PyTorch, because the code for quantization aware training, for converting the model to ONNX, and for deploying it with TensorRT has all been open-sourced in the OneFlow community.

Users can build a dynamic graph model (i.e. an nn.Module) with OneFlow, whose operator APIs are almost identical to PyTorch's. Calling the following lines of code then automatically inserts quantization modules at the appropriate positions of this model, producing a quantization model (also an nn.Module). Quantization aware training is then conducted on this quantization model.

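```python
import oneflow as flow

# net is the user's FP32 nn.Module, device is e.g. "cuda";
# quantization_aware_training is the quantization Pass provided by this project
gm: flow.fx.GraphModule = flow.fx.symbolic_trace(net)
qconfig = {
    'quantization_bit': 8,
    'quantization_scheme': "symmetric",
    'quantization_formula': "cambricon",
    'per_layer_quantization': True,
    'momentum': 0.95,
}
net = quantization_aware_training(gm, flow.randn(1, 3, 32, 32), qconfig)
net = net.to(device)
```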

After training is complete, call the following code to convert the trained quantization model to ONNX and use TensorRT to run inference on the GPU.

```python
import tempfile

import oneflow as flow

# gm and qconfig come from the training code above; resnet18 is the original FP32 model.
# convert_to_onnx_and_check comes from oneflow_onnx; run_onnx, get_onnx_provider,
# run_tensorrt and compare_result are helpers from the example script linked in 0x6.
quantization_resnet18 = quantization_aware_training(gm, flow.randn(1, 3, 32, 32).to("cuda"), qconfig)
quantization_resnet18 = quantization_resnet18.to("cuda")
quantization_resnet18.eval()
checkpoint = flow.load('/home/zhangxiaoyu/oneflow-cifar/checkpoint/epoch_11_val_acc_83.280000')
quantization_resnet18.load_state_dict(checkpoint)

origin_gm: flow.fx.GraphModule = flow.fx.symbolic_trace(resnet18)
dequantization_resnet18 = dequantization_aware_training(origin_gm, gm, flow.randn(1, 3, 32, 32).to("cuda"), qconfig)
dequantization_resnet18 = dequantization_resnet18.to("cuda")
dequantization_resnet18.eval()


class ResNet18Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.m = dequantization_resnet18

    def build(self, x):
        out = self.m(x)
        return out


def test_resnet():
    resnet_graph = ResNet18Graph()
    resnet_graph._compile(flow.randn(1, 3, 32, 32).to("cuda"))
    with tempfile.TemporaryDirectory() as tmpdirname:
        flow.save(dequantization_resnet18.state_dict(), tmpdirname)
        convert_to_onnx_and_check(resnet_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp", print_outlier=True)
        ipt_dict, onnx_res = run_onnx("/tmp/model.onnx", get_onnx_provider("cpu"))
        trt_res = run_tensorrt("/tmp/model.onnx", ipt_dict[list(ipt_dict.keys())[0]])
        compare_result(onnx_res, trt_res, atol=1e-4, print_outlier=True)


test_resnet()
```

With just a few dozen lines of code, users can complete the entire end-to-end process from quantization aware training to GPU deployment, which I think is interesting and concise! All the code of this project is open-sourced in OneFlow's community, and the corresponding links are given in the sections below.

0x1. Development of PyTorch’s Quantization Schemes

This chapter is mainly based on PyTorch's official documentation https://pytorch.org/docs/1.9.0/quantization.html. PyTorch's first-generation quantization mode is Eager Mode Quantization; FX Graph Mode Quantization was introduced in version 1.8. Eager Mode Quantization requires users to modify the model and manually specify the operators to be fused, while FX Graph Mode Quantization automates this: users neither modify the model manually nor need to care about its internal operations. The change is illustrated in the following figure.

Comparison of PyTorch’s two generations of quantization schemes

The differences between these two quantization schemes are as follows.

Eager Mode Quantization

```python
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, num_channels=1):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.conv2 = nn.Conv2d(40, 40, 3, 1)
        self.fc = nn.Linear(5*5*40, 10)  # 5*5*40 assumes a 1x28x28 input (e.g. MNIST)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.reshape(-1, 5*5*40)
        x = self.fc(x)
        return x
```

In the forward of an nn.Module, PyTorch lets you construct the network, call other nn.Modules, call nn.functional.xxx, and even write control logic such as if statements. But this flexibility also makes the model's graph structure difficult to obtain at the Eager level. So in Eager Mode Quantization, the network must be modified manually before it can be quantized:

```python
import torch
import torch.nn as nn


class NetQuant(nn.Module):
    def __init__(self, num_channels=1):
        super(NetQuant, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(40, 40, 3, 1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(5*5*40, 10)
        # Mark where tensors enter and leave the quantized region
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu1(self.conv1(x))
        x = self.pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.pool2(x)
        x = x.reshape(-1, 5*5*40)
        x = self.fc(x)
        x = self.dequant(x)
        return x
```

In other words, in addition to the Conv and Linear modules that contain parameters, ReLU and MaxPool2d must also be defined in __init__ for Eager Mode Quantization to process them correctly. In addition, some patterns, such as Conv+ReLU, must be fused before quantization, and these layers must also be specified manually. Currently, Eager Mode Quantization supports fusing Conv+BN, Conv+BN+ReLU, Conv+ReLU, Linear+ReLU and BN+ReLU.

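```python
import torch

model = NetQuant()
model.eval()  # post-training quantization is done in eval mode
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
modules_to_fuse = [['conv1', 'relu1'], ['conv2', 'relu2']]  # Specify the layers to be fused
model_fused = torch.quantization.fuse_modules(model, modules_to_fuse)
model_prepared = torch.quantization.prepare(model_fused)
# post_training_quantize is a user-written calibration loop (not a PyTorch API)
# that feeds batches from train_loader through model_prepared
post_training_quantize(model_prepared, train_loader)  # Post training quantization
model_int8 = torch.quantization.convert(model_prepared)
```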
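The modules_to_fuse list above has to be written by hand, and that is the step FX Graph Mode automates. As a rough illustration, the hypothetical helper below (find_fusable is not a PyTorch API) scans a layer list in execution order and greedily groups the supported fusion patterns into the same format. It assumes the layers form a plain sequential chain, which is exactly the structural information Eager mode cannot read off forward by itself.

```python
# Supported patterns, longest first so Conv+BN+ReLU wins over Conv+BN.
FUSION_PATTERNS = [
    ("Conv2d", "BatchNorm2d", "ReLU"),
    ("Conv2d", "BatchNorm2d"),
    ("Conv2d", "ReLU"),
    ("Linear", "ReLU"),
    ("BatchNorm2d", "ReLU"),
]


def find_fusable(layers):
    """Group consecutive layers that match a supported fusion pattern.

    layers: list of (name, type_name) tuples in execution order, assumed to
    be a plain sequential chain. Returns a list of name lists in the
    modules_to_fuse format.
    """
    groups = []
    i = 0
    while i < len(layers):
        for pattern in FUSION_PATTERNS:
            window = layers[i:i + len(pattern)]
            if tuple(t for _, t in window) == pattern:
                groups.append([name for name, _ in window])
                i += len(pattern)
                break
        else:
            i += 1  # no pattern starts at this layer; move on
    return groups


layers = [
    ("conv1", "Conv2d"), ("relu1", "ReLU"), ("pool1", "MaxPool2d"),
    ("conv2", "Conv2d"), ("relu2", "ReLU"), ("pool2", "MaxPool2d"),
    ("fc", "Linear"),
]
print(find_fusable(layers))  # [['conv1', 'relu1'], ['conv2', 'relu2']]
```

The greedy longest-first order is a simplification; a graph-based tool like FX can also check that the fused layers really feed each other rather than merely being adjacent in a list.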

FX Graph Mode Quantization

The introduction to the PyTorch FX module is in the next chapter.

Since PyTorch FX automatically traces the code in forward, it records every node in the network. This makes it much more capable than Eager mode at fusing operators and dynamically inserting quantization nodes. Therefore, for the code above, we don't need to modify the network manually; we only need to let FX modify it for us. Here is an example:

```python
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

model = Net()
model.eval()
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}  # "" applies the qconfig to the whole model
model_prepared = prepare_fx(model, qconfig_dict)  # PyTorch 1.9 signature; newer releases also require example_inputs
post_training_quantize(model_prepared, train_loader)  # Post training quantization (user-written calibration loop)
model_int8 = convert_fx(model_prepared)
```

From the comparison above, FX Graph Mode Quantization is clearly the better approach, because it does not impose additional restrictions on how users define the model: users can write the code as they like. OneFlow's quantization aware training is also based on the FX module.

In NVIDIA's TensorRT repository (https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization), we can see that PyTorch models are still quantized in Eager Mode before being exported to ONNX.

0x2. OneFlow FX (Write Pass in Eager)

What can FX be used for?

Diagram of FX

FX can transform an nn.Module into another nn.Module by applying Transformations (also called Passes) to its structure. For example, it can automatically insert fake quantization nodes after Conv layers to implement quantization aware training, and then generate a GraphModule (itself an nn.Module) that can be trained and converted to ONNX for deployment.

OneFlow's FX module is implemented in this PR: https://github.com/Oneflow-Inc/oneflow/pull/5939, which reuses the core logic and code of PyTorch's FX infrastructure. The work in this PR includes:

  • Streamline PyTorch FX by removing its special designs, such as tracing into _C and the interaction with JIT, while retaining its four core functions: Symbolic Tracing, Intermediate Representation, Transformation and Python Codegen.
  • Implement these four functions in code that is fully adapted to OneFlow's design. You can try them with import oneflow.fx. Nearly all Eager models built with the OneFlow API can be traced and transformed into an equivalent nn.Module, on which custom Transformation Passes can then be written.
  • Test on AlexNet, ResNet50, MobileNetV2 and other models.

Now let's look at the overall idea of OneFlow FX.

Here is a simple demonstration:

```python
import oneflow


# Simple module for demonstration
class MyModule(oneflow.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = oneflow.nn.Parameter(oneflow.rand(3, 4))
        self.linear = oneflow.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


module = MyModule()

from oneflow.fx import symbolic_trace

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: oneflow.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param; x = param = None
    linear = self.linear(add); add = None
    clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
    return clamp
"""
```

The Proxy class in FX wraps OneFlow's functions (recorded as call_function), Tensor methods (call_method), common magic methods and math library functions, so that every OneFlow operator call is recorded; this wrapping is done when import oneflow.fx is executed. Then, when an nn.Module is passed to symbolic_trace for tracing, the submodules created in its __init__ are processed first, and these submodules as well as the input data are wrapped by Proxy.

After all the operators in the program are wrapped by Proxy, forward is executed once more. This time the input to forward is no longer a Tensor but a Proxy(Tensor). Since program execution resembles pushing and popping operators on a stack, we can unpack the data and ops recorded by the Proxy in execution order. After unpacking, we get the real Tensors, Parameters, operators and so on, and use them as nodes and edges to construct a new Graph. So how is a Graph turned back into an nn.Module? FX holds the Graph in a data structure called GraphModule. A GraphModule also holds code and forward, both generated automatically from the Graph. GraphModule is still an nn.Module.

The automatically generated code is the code attribute of the GraphModule, and printing it shows the complete execution process of the entire forward function.
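The core trick, running forward on stand-in objects whose overloaded operators record what happens to them, can be shown with a toy tracer in plain Python. This is a simplified illustration, not FX's actual Proxy class: ToyProxy, trace and the (name, op, args) node format are invented for the sketch, and only +, * and clamp are recorded.

```python
class ToyProxy:
    """Stand-in value that records every operation applied to it."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph  # shared list of recorded nodes

    def _record(self, op, *args):
        # Store proxies by name so each node only holds symbols and constants
        arg_names = tuple(a.name if isinstance(a, ToyProxy) else a for a in args)
        out = ToyProxy(f"%{len(self.graph)}", self.graph)
        self.graph.append((out.name, op, arg_names))
        return out

    def __add__(self, other):
        return self._record("add", self, other)

    def __mul__(self, other):
        return self._record("mul", self, other)

    def clamp(self, lo, hi):
        return self._record("clamp", self, lo, hi)


def trace(fn, *arg_names):
    """Run fn once on proxies and return (recorded graph, output name)."""
    graph = []
    inputs = [ToyProxy(n, graph) for n in arg_names]
    out = fn(*inputs)
    return graph, out.name


def forward(x, w):
    return (x * w + 1.0).clamp(0.0, 1.0)


graph, output = trace(forward, "x", "w")
for node in graph:
    print(node)
# ('%0', 'mul', ('x', 'w'))
# ('%1', 'add', ('%0', 1.0))
# ('%2', 'clamp', ('%1', 0.0, 1.0))
print("return", output)  # return %2
```

Because forward only ever sees proxies, the recorded list is the program's dataflow in execution order; generating Python source from such a list is what the code attribute of a GraphModule provides.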

FX also provides an Interpreter class that lets users customize how an nn.Module is executed. For example, this PR provides a Pass built on this class that infers the shapes of all intermediate Tensors. It also provides a Pass that visualizes the structure of a GraphModule with pydot, as shown in the figure below.

Pydot visualization of a model built in dynamic graph mode
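The Interpreter idea, running a graph node by node through an overridable hook, can be sketched in plain Python. ToyInterpreter and ShapeRecorder are hypothetical names with a made-up node format (they are not FX classes), and values are plain lists; the subclass mirrors how a shape-inference Pass reuses normal execution while recording extra information.

```python
class ToyInterpreter:
    """Execute a list of (name, op, arg_names) nodes in order."""

    OPS = {
        "add": lambda a, b: [x + y for x, y in zip(a, b)],
        "relu": lambda a: [max(x, 0.0) for x in a],
        "sum": lambda a: [sum(a)],
    }

    def __init__(self, nodes):
        self.nodes = nodes

    def run_node(self, name, op, args):
        # Default behaviour: apply the op; subclasses override this hook
        return self.OPS[op](*args)

    def run(self, **inputs):
        env = dict(inputs)  # maps value names to computed values
        for name, op, arg_names in self.nodes:
            env[name] = self.run_node(name, op, [env[a] for a in arg_names])
        return env[self.nodes[-1][0]]


class ShapeRecorder(ToyInterpreter):
    """Runs the graph normally but also records each intermediate length."""

    def __init__(self, nodes):
        super().__init__(nodes)
        self.shapes = {}

    def run_node(self, name, op, args):
        out = super().run_node(name, op, args)
        self.shapes[name] = (len(out),)
        return out


nodes = [("h", "add", ("x", "b")), ("r", "relu", ("h",)), ("s", "sum", ("r",))]
interp = ShapeRecorder(nodes)
print(interp.run(x=[1.0, -2.0, 3.0], b=[0.5, 0.5, 0.5]))  # [5.0]
print(interp.shapes)  # {'h': (3,), 'r': (3,), 's': (1,)}
```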

By now, I believe most of you have a general understanding of FX. One of its best features is that we can modify an nn.Module and get back the modified nn.Module. Moreover, quantization aware training essentially replaces components such as Conv+BN, Conv and Linear with versions that contain fake quantization nodes, so we can implement it by writing a Pass based on FX.

This answers the question at the beginning of this chapter: FX lets us write Passes in Eager mode.

However, FX has a shortcoming: it cannot handle control flow (this has little impact in practice, because networks that are deployed generally don't contain control flow).

0x3. Obtain Quantization Aware Training Pass

With OneFlow FX, we can write a Pass for quantization aware training that automatically inserts the quantization aware training components into a user-defined network.

Take ResNet18 as an example. It only contains the Conv+BN pattern, i.e., every convolutional layer is followed by a BN layer. During inference, TensorRT fuses Conv with BN, so we must also fuse Conv and BN during training; otherwise the deployed model would lose accuracy. That is, we fuse the parameters of the BN layer into the convolutional layer and then quantize the fused parameters. The process is shown in the figure below.

Process of simulated quantization during training

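The formulas for fusing Conv and BN are as follows. The convolutional layer computes

$$y = Wx + b$$

and the BN layer that follows it computes

$$y_{bn} = \gamma \frac{y - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

So:

$$y_{bn} = \frac{\gamma W}{\sqrt{\sigma^2 + \epsilon}} x + \frac{\gamma (b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta = W_{merge} x + b_{merge}$$

where W and b represent the weight and bias of the convolutional layer, x and y represent its input and output, and μ, σ², γ, β and ε are the mean, variance, scale, shift and epsilon of the BN layer (all per output channel). The fused weight and bias are therefore

$$W_{merge} = \frac{\gamma W}{\sqrt{\sigma^2 + \epsilon}}, \qquad b_{merge} = \frac{\gamma (b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta$$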
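A quick numerical check of the folding rule W_merge = γW / sqrt(σ² + ε) and b_merge = γ(b − μ) / sqrt(σ² + ε) + β, in plain Python with made-up numbers: a 1x1 convolution on a single pixel (so each output channel is just a dot product) followed by inference-mode BN gives the same output as a single convolution with the folded parameters. The function names are hypothetical; this only demonstrates the algebra and is not OneFlow code.

```python
import math


def conv1x1(x, W, b):
    """1x1 convolution on one pixel: y[c] = sum_k W[c][k] * x[k] + b[c]."""
    return [sum(w * xi for w, xi in zip(row, x)) + bc for row, bc in zip(W, b)]


def batchnorm(y, mean, var, gamma, beta, eps):
    """Inference-mode BN applied per output channel."""
    return [g * (yc - m) / math.sqrt(v + eps) + bt
            for yc, m, v, g, bt in zip(y, mean, var, gamma, beta)]


def fold_bn_params(W, b, mean, var, gamma, beta, eps):
    """Return (W_merge, b_merge) for each output channel."""
    W_merge, b_merge = [], []
    for row, bc, m, v, g, bt in zip(W, b, mean, var, gamma, beta):
        scale = g / math.sqrt(v + eps)  # gamma / std for this channel
        W_merge.append([scale * w for w in row])
        b_merge.append(scale * (bc - m) + bt)
    return W_merge, b_merge


# Two output channels, three input channels (made-up numbers)
W = [[0.2, -0.5, 1.0], [1.5, 0.3, -0.7]]
b = [0.1, -0.2]
mean, var = [0.4, -1.0], [2.0, 0.5]
gamma, beta, eps = [1.2, 0.8], [0.05, -0.3], 1e-5
x = [1.0, 2.0, -0.5]

reference = batchnorm(conv1x1(x, W, b), mean, var, gamma, beta, eps)
folded = conv1x1(x, *fold_bn_params(W, b, mean, var, gamma, beta, eps))
print(all(abs(r - f) < 1e-9 for r, f in zip(reference, folded)))  # True
```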

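Based on the formulas above, we obtain the quantization aware training component for the fused Conv+BN. Training and inference differ in the self.training branch of forward: during training, BN is folded using the mean and variance of the current batch (and BN's running statistics are updated), while during inference the running statistics are used.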

```python
import oneflow as flow


class QConvBN(flow.nn.Module):
    def __init__(
        self,
        conv_module,
        bn_module,
        quantization_bit=8,
        quantization_scheme="symmetric",
        quantization_formula="google",
        per_layer_quantization=True,
        momentum=0.95,
    ):
        super().__init__()
        self.quantization_bit = quantization_bit
        self.quantization_scheme = quantization_scheme
        self.quantization_formula = quantization_formula
        self.per_layer_quantization = per_layer_quantization
        self.conv_module = conv_module
        self.bn_module = bn_module
        # Observer for input activations (moving average of min/max)
        self.moving_min_max_observer = flow.nn.MovingAverageMinMaxObserver(
            training=self.training,
            quantization_formula=quantization_formula,
            stop_update_after_iters=1,
            quantization_bit=quantization_bit,
            quantization_scheme=quantization_scheme,
            momentum=momentum,
        )
        # Observer for the BN-folded weights
        self.min_max_observer = flow.nn.MinMaxObserver(
            quantization_formula=quantization_formula,
            quantization_bit=quantization_bit,
            quantization_scheme=quantization_scheme,
            per_layer_quantization=per_layer_quantization,
        )
        self.fake_quantization = flow.nn.FakeQuantization(
            quantization_formula=quantization_formula,
            quantization_bit=quantization_bit,
            quantization_scheme=quantization_scheme,
        )

    def fold_bn(self, mean, std):
        # W_merge = gamma / std * W,  b_merge = gamma / std * (b - mean) + beta
        if self.bn_module.affine:
            gamma_ = self.bn_module.weight / std
            weight = self.conv_module.weight * gamma_.view(
                self.conv_module.out_channels, 1, 1, 1
            )
            if self.conv_module.bias is not None:
                bias = (
                    gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
                )
            else:
                bias = self.bn_module.bias - gamma_ * mean
        else:
            gamma_ = 1 / std
            # Reshape so the per-channel factor broadcasts over (out, in, kH, kW)
            weight = self.conv_module.weight * gamma_.view(
                self.conv_module.out_channels, 1, 1, 1
            )
            if self.conv_module.bias is not None:
                bias = gamma_ * self.conv_module.bias - gamma_ * mean
            else:
                bias = -gamma_ * mean
        return weight, bias

    def forward(self, x):
        # Fake-quantize the input activation
        scale, zero_point = self.moving_min_max_observer(
            x, flow.tensor([0], dtype=flow.int64).to(x.device.type)
        )
        x = self.fake_quantization(x, scale, zero_point)
        if self.training:
            # Training: fold BN with the statistics of the current batch
            y = flow.nn.functional.conv2d(
                x,
                self.conv_module.weight,
                self.conv_module.bias,
                stride=self.conv_module.stride,
                padding=self.conv_module.padding,
                dilation=self.conv_module.dilation,
                groups=self.conv_module.groups,
            )
            y = y.permute(1, 0, 2, 3)  # NCHW -> CNHW
            # reshape rather than view: the permuted tensor may not be contiguous
            y = y.reshape(self.conv_module.out_channels, -1)  # CNHW -> C,NHW
            mean = y.mean(1)
            var = y.var(1)
            with flow.no_grad():
                # Standard BN convention: running = (1 - momentum) * running + momentum * batch
                self.bn_module.running_mean = (
                    (1 - self.bn_module.momentum) * self.bn_module.running_mean
                    + self.bn_module.momentum * mean
                )
                self.bn_module.running_var = (
                    (1 - self.bn_module.momentum) * self.bn_module.running_var
                    + self.bn_module.momentum * var
                )
        else:
            # Inference: fold BN with the running statistics
            mean = flow.Tensor(self.bn_module.running_mean)
            var = flow.Tensor(self.bn_module.running_var)
        std = flow.sqrt(var + self.bn_module.eps)
        weight, bias = self.fold_bn(mean, std)
        # Fake-quantize the folded weight, then convolve with it
        weight_scale, weight_zero_point = self.min_max_observer(weight)
        res = flow.nn.functional.conv2d(
            x,
            self.fake_quantization(weight, weight_scale, weight_zero_point),
            bias,
            stride=self.conv_module.stride,
            padding=self.conv_module.padding,
            dilation=self.conv_module.dilation,
            groups=self.conv_module.groups,
        )
        return res
```
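The moving_min_max_observer above smooths the activation range across batches using momentum (0.95 in the qconfig). A plain-Python sketch of that idea follows. The class is hypothetical, and the rule that the first batch initializes the statistics is an assumption made for the illustration, not something taken from OneFlow's source.

```python
class ToyMovingAverageMinMax:
    """Track activation min/max with an exponential moving average.

    Assumption for illustration: the first batch initializes the statistics,
    later batches are blended in with weight (1 - momentum).
    """

    def __init__(self, momentum=0.95):
        self.momentum = momentum
        self.moving_min = None
        self.moving_max = None

    def update(self, batch):
        cur_min, cur_max = min(batch), max(batch)
        if self.moving_max is None:
            self.moving_min, self.moving_max = cur_min, cur_max
        else:
            m = self.momentum
            self.moving_min = m * self.moving_min + (1 - m) * cur_min
            self.moving_max = m * self.moving_max + (1 - m) * cur_max
        return self.moving_min, self.moving_max


obs = ToyMovingAverageMinMax(momentum=0.95)
obs.update([-1.0, 0.5, 2.0])          # first batch sets the range to (-1.0, 2.0)
print(obs.update([-3.0, 0.0, 10.0]))  # approximately (-1.1, 2.4): the outlier only nudges it
```

A high momentum keeps a single outlier batch from blowing up the quantization scale, which matters because a wider range means a coarser step for every other activation.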

With this component in place, we obtain the quantization aware training Pass by replacing each Conv+BN in the computation graph of the user's nn.Module with a QConvBN component. The replacement code is shown below:

```python
# gm is the GraphModule traced from the user's model; insert_place, insert_op_state
# (node target -> module), cnt and get_current_module_space are set up earlier in the Pass.
for x in gm.graph.nodes:
    if x.target in insert_place:
        with gm.graph.inserting_after(x):
            y = x.next
            # Match a Conv2d node immediately followed by a BatchNorm2d node
            if (
                isinstance(insert_op_state[x.target], flow.nn.Conv2d)
                and y.target in insert_place
                and isinstance(insert_op_state[y.target], flow.nn.BatchNorm2d)
            ):
                now_target = get_current_module_space(x.target)
                if now_target == "":
                    now_target = f"fake_conv_bn.{cnt}"
                else:
                    now_target = (
                        f"{get_current_module_space(x.target)}.fake_conv_bn.{cnt}"
                    )
                gm.add_submodule(
                    now_target,
                    QConvBN(
                        insert_op_state[x.target],
                        insert_op_state[y.target],
                        quantization_bit,
                        quantization_scheme,
                        quantization_formula,
                        per_layer_quantization,
                        momentum,
                    ),
                )
                # Users of BN now read from Conv; then drop the BN node and module
                y.replace_all_uses_with(x)
                gm.graph.erase_node(y)
                gm.delete_submodule(y.target)
                # Call QConvBN with the Conv's inputs, then drop the Conv node and module
                qconvbn = gm.graph.call_module(module_name=now_target, args=x.args)
                cnt = cnt + 1
                x.replace_all_uses_with(qconvbn)
                gm.graph.erase_node(x)
                gm.delete_submodule(x.target)
```

The code finds each Conv+BN pair in gm (the GraphModule traced from ResNet18, which is still an nn.Module), deletes it, and replaces it with a QConvBN component.
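Stripped of the FX API, the Pass above is a pattern match plus a rewiring of uses. The sketch below applies the same idea to a hypothetical list of (name, op, args) nodes in plain Python; fuse_conv_bn and the node format are invented for illustration, and it only handles a Conv that is immediately followed by a BN consuming its output.

```python
def fuse_conv_bn(nodes):
    """Replace every Conv -> BN pair with one fused node.

    nodes: list of (name, op, args) in topological order (hypothetical format).
    """
    out, rename = [], {}
    i = 0
    while i < len(nodes):
        name, op, args = nodes[i]
        args = tuple(rename.get(a, a) for a in args)  # rewire to replacement nodes
        nxt = nodes[i + 1] if i + 1 < len(nodes) else None
        if op == "conv" and nxt is not None and nxt[1] == "bn" and nxt[2] == (name,):
            fused = f"qconvbn_{len(rename)}"
            out.append((fused, "qconvbn", args))  # the fused node takes the Conv's inputs
            rename[nxt[0]] = fused  # users of the BN output now read the fused node
            i += 2
        else:
            out.append((name, op, args))
            i += 1
    return out


nodes = [
    ("x", "input", ()),
    ("conv1", "conv", ("x",)),
    ("bn1", "bn", ("conv1",)),
    ("relu", "relu", ("bn1",)),
    ("conv2", "conv", ("relu",)),
    ("bn2", "bn", ("conv2",)),
    ("out", "output", ("bn2",)),
]
for node in fuse_conv_bn(nodes):
    print(node)
# ('x', 'input', ())
# ('qconvbn_0', 'qconvbn', ('x',))
# ('relu', 'relu', ('qconvbn_0',))
# ('qconvbn_1', 'qconvbn', ('relu',))
# ('out', 'output', ('qconvbn_1',))
```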

0x4. Quantization Aware Training of Cifar10 Based On ResNet18

With the quantization Pass obtained above, we can conduct quantization aware training on a custom model. Taking ResNet18 as an example, we only need to add the following lines to the code of the original dynamic computation graph:

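```python
import oneflow as flow

# net is the user's FP32 ResNet18 (an nn.Module), device is e.g. "cuda";
# quantization_aware_training is the quantization Pass provided by this project
gm: flow.fx.GraphModule = flow.fx.symbolic_trace(net)
qconfig = {
    'quantization_bit': 8,
    'quantization_scheme': "symmetric",
    'quantization_formula': "cambricon",
    'per_layer_quantization': True,
    'momentum': 0.95,
}
net = quantization_aware_training(gm, flow.randn(1, 3, 32, 32), qconfig)
net = net.to(device)
```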

With qconfig, users can easily configure every quantization mode supported by OneFlow.

The first net is the user-defined dynamic computation graph. The new net returned by the Pass has the quantization aware training components inserted automatically. The rest of training and testing is exactly the same as FP32 training and is not detailed here. I trained ResNet18 on Cifar10 with several quantization configurations supported by OneFlow, each for 200 epochs with the same hyperparameters. The results are as follows:

Note:
The `momentum` parameter in the `MovingAverageMinMaxObserver` class defaults to 0.95, which will not be changed in the following experiments.
## Accuracy
| Model | quantization_bit | quantization_scheme | quantization_formula | per_layer_quantization | Acc |
| ----------------- | ----------- | ----------- | ----------- | ----------- | ----------- |
| ResNet18 | 8 | symmetric | google | True | 95.19% |
| ResNet18 | 8 | symmetric | google | False | 95.24% |
| ResNet18 | 8 | affine | google | True | 95.32% |
| ResNet18 | 8 | affine | google | False | 95.30% |
| ResNet18 | 8 | symmetric | cambricon | True | 95.19% |

Project address: https://github.com/BBuf/oneflow-cifar.

The FP32 accuracy of ResNet18 trained on Cifar10 is 95.62%, so under every quantization configuration above, quantization aware training stays within 0.5 percentage points of the original accuracy. In the table, cambricon denotes Cambricon's quantization formula and google denotes Google's.
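To make the qconfig fields concrete, here is a plain-Python sketch of min/max-based quantization parameters and fake quantization (quantize, clamp, dequantize). It assumes the common "google"-style formulas: symmetric uses scale = max(|min|, |max|) / (2^(bit-1) - 1) with zero point 0, and affine uses scale = (max - min) / (2^bit - 1) with a nonzero zero point. The function names are hypothetical, the cambricon formula is not covered, and OneFlow's kernels may differ in rounding details.

```python
def qparams(values, bit=8, scheme="symmetric"):
    """Per-tensor (scale, zero_point) from min/max (assumed google-style formulas)."""
    lo, hi = min(values), max(values)
    if scheme == "symmetric":
        scale = max(abs(lo), abs(hi)) / (2 ** (bit - 1) - 1)
        zero_point = 0
    else:  # "affine"
        scale = (hi - lo) / (2 ** bit - 1)
        zero_point = -round(lo / scale)
    return scale, zero_point


def fake_quantize(values, scale, zero_point, bit=8, scheme="symmetric"):
    """Quantize to integers, clamp to the representable range, dequantize back."""
    if scheme == "symmetric":
        q_min, q_max = -(2 ** (bit - 1)), 2 ** (bit - 1) - 1
    else:
        q_min, q_max = 0, 2 ** bit - 1
    out = []
    for v in values:
        q = min(max(round(v / scale) + zero_point, q_min), q_max)
        out.append((q - zero_point) * scale)
    return out


def per_channel_qparams(channels, bit=8, scheme="symmetric"):
    """per_layer_quantization=False: one (scale, zero_point) per output channel."""
    return [qparams(c, bit, scheme) for c in channels]


weights = [[0.5, -1.27, 0.02], [0.001, 0.004, -0.003]]  # two output channels
flat = [w for c in weights for w in c]

s, z = qparams(flat)  # per-layer: the small channel must share the large range
print(fake_quantize(weights[1], s, z))  # [0.0, 0.0, 0.0]
print(per_channel_qparams(weights))     # the second channel gets a much finer scale
```

The example shows why per-channel quantization can help when channels have very different ranges; on ResNet18 in the table above, per-layer and per-channel results are nearly identical.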

0x5. Rewrite the Original Model Based On the Quantization Aware Training Model

We have trained the quantization aware training model above; now we need to consider how to deploy it. This model is not what we want to deploy as-is, because in the deployed model BN should be merged into the convolutional layer rather than kept as a separate layer. So we need to rewrite the original model using the parameters of the quantization aware training model, and then convert it to ONNX and on to TensorRT.

Similar to quantization aware training, we write a dequantization Pass that replaces each QConvBN component with a DConv2d component. The code of DConv2d is as follows:

```python
import oneflow as flow


class DConv2d(flow.nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        quantization_bit=8,
        quantization_scheme="symmetric",
        quantization_formula="google",
        per_layer_quantization=True,
        momentum=0.95,
    ) -> None:
        super(DConv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups
        )
        # Observer for input activations; its moving_min/moving_max are later
        # overwritten with the values learned by the trained QConvBN
        self.moving_min_max_observer = flow.nn.MovingAverageMinMaxObserver(
            training=self.training,
            quantization_formula=quantization_formula,
            stop_update_after_iters=1,
            quantization_bit=quantization_bit,
            quantization_scheme=quantization_scheme,
            momentum=momentum,
        )
        self.min_max_observer = flow.nn.MinMaxObserver(
            quantization_formula=quantization_formula,
            quantization_bit=quantization_bit,
            quantization_scheme=quantization_scheme,
            per_layer_quantization=per_layer_quantization,
        )
        self.fake_quantization = flow.nn.FakeQuantization(
            quantization_formula=quantization_formula,
            quantization_bit=quantization_bit,
            quantization_scheme=quantization_scheme,
        )
        # Constant zero tensor passed as the observer's second argument
        self.register_buffer("new_zero", flow.Tensor(1))
        self.new_zero.fill_(0)

    def forward(self, x):
        # Fake-quantize the input, then run an ordinary convolution whose
        # weight and bias already contain the folded BN parameters
        scale, zero_point = self.moving_min_max_observer(
            x, self.new_zero.to(flow.int64).to(x.device.type)
        )
        x = self.fake_quantization(x, scale, zero_point)
        return flow.nn.functional.conv2d(
            x,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
```

Then we only need to replace each Conv+BN in the original ResNet18 model with this component. Note that the weight and bias of this component must be set to the BN-folded weight and bias of the corresponding QConvBN component in the trained quantization aware model, and the moving_min/moving_max parameters of its moving_min_max_observer must be set to those of that QConvBN's moving_min_max_observer.

The core parts of the dequantization Pass are as follows:

```python
# origin_gm is the GraphModule traced from the original FP32 model; quantization_op_state
# maps submodule names of the trained quantization model to its modules (here, QConvBN).
# insert_place, insert_op_state, cnt and get_current_module_space are set up earlier in the Pass.
for x in origin_gm.graph.nodes:
    if x.target in insert_place:
        with origin_gm.graph.inserting_after(x):
            y = x.next
            if (
                isinstance(insert_op_state[x.target], flow.nn.Conv2d)
                and y.target in insert_place
                and isinstance(insert_op_state[y.target], flow.nn.BatchNorm2d)
            ):
                now_target = get_current_module_space(x.target)
                if now_target == "":
                    now_target = f"fake_conv_bn.{cnt}"
                else:
                    now_target = (
                        f"{get_current_module_space(x.target)}.fake_conv_bn.{cnt}"
                    )
                trained = quantization_op_state[now_target]  # the trained QConvBN module
                dequantization_conv = DConv2d(
                    trained.conv_module.in_channels,
                    trained.conv_module.out_channels,
                    trained.conv_module.kernel_size,
                    trained.conv_module.stride,
                    trained.conv_module.padding,
                    trained.conv_module.dilation,
                    trained.conv_module.groups,
                    quantization_bit,
                    quantization_scheme,
                    quantization_formula,
                    per_layer_quantization,
                    momentum,
                )
                # Fold BN into the convolution using the trained running statistics
                mean = flow.Tensor(trained.bn_module.running_mean)
                var = flow.Tensor(trained.bn_module.running_var)
                std = flow.sqrt(var + trained.bn_module.eps)
                if trained.bn_module.affine:
                    gamma_ = trained.bn_module.weight / std
                    weight = trained.conv_module.weight * gamma_.view(
                        trained.conv_module.out_channels, 1, 1, 1
                    )
                    if trained.conv_module.bias is not None:
                        bias = gamma_ * trained.conv_module.bias - gamma_ * mean + trained.bn_module.bias
                    else:
                        bias = trained.bn_module.bias - gamma_ * mean
                else:
                    gamma_ = 1 / std
                    weight = trained.conv_module.weight * gamma_.view(
                        trained.conv_module.out_channels, 1, 1, 1
                    )
                    if trained.conv_module.bias is not None:
                        bias = gamma_ * trained.conv_module.bias - gamma_ * mean
                    else:
                        bias = -gamma_ * mean
                dequantization_conv.weight = flow.nn.Parameter(weight)
                dequantization_conv.bias = flow.nn.Parameter(bias)
                # Reuse the activation ranges learned during quantization aware training
                dequantization_conv.moving_min_max_observer.moving_max = trained.moving_min_max_observer.moving_max
                dequantization_conv.moving_min_max_observer.moving_min = trained.moving_min_max_observer.moving_min
                origin_gm.add_submodule(now_target, dequantization_conv)
                # Rewire the graph exactly as in the quantization Pass
                y.replace_all_uses_with(x)
                origin_gm.graph.erase_node(y)
                origin_gm.delete_submodule(y.target)
                dconv_node = origin_gm.graph.call_module(module_name=now_target, args=x.args)
                cnt = cnt + 1
                x.replace_all_uses_with(dconv_node)
                origin_gm.graph.erase_node(x)
                origin_gm.delete_submodule(x.target)
```

Here the Conv and BN fusion is performed by hand, and the fused weight and bias are assigned to the DConv2d component.

0x6. Convert to ONNX & TensorRT Inference

With the trained quantization aware training model and the dequantization Pass, we obtain the nn.Module used for inference. We then convert this nn.Module to ONNX and run it with TensorRT. See https://github.com/Oneflow-Inc/oneflow_convert/blob/add_fx_train_quantization/examples/oneflow2onnx/quantization/test_resnet18.py for a complete demonstration; here we explain its core part.

```python
import tempfile

import oneflow as flow

# gm and qconfig come from the training code; resnet18 is the original FP32 model.
# convert_to_onnx_and_check comes from oneflow_onnx; run_onnx, get_onnx_provider,
# run_tensorrt and compare_result are helpers defined in the example script above.

# Load the weight of the trained quantization model
quantization_resnet18 = quantization_aware_training(gm, flow.randn(1, 3, 32, 32).to("cuda"), qconfig)
quantization_resnet18 = quantization_resnet18.to("cuda")
quantization_resnet18.eval()
checkpoint = flow.load('/home/zhangxiaoyu/oneflow-cifar/checkpoint/epoch_11_val_acc_83.280000')
quantization_resnet18.load_state_dict(checkpoint)

# Rewrite the original model based on the quantization aware training model
origin_gm: flow.fx.GraphModule = flow.fx.symbolic_trace(resnet18)
dequantization_resnet18 = dequantization_aware_training(origin_gm, gm, flow.randn(1, 3, 32, 32).to("cuda"), qconfig)
dequantization_resnet18 = dequantization_resnet18.to("cuda")
dequantization_resnet18.eval()


# nn.Graph is the bridge to convert to ONNX
class ResNet18Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.m = dequantization_resnet18

    def build(self, x):
        out = self.m(x)
        return out


# Test Function
def test_resnet():
    resnet_graph = ResNet18Graph()
    resnet_graph._compile(flow.randn(1, 3, 32, 32).to("cuda"))
    with tempfile.TemporaryDirectory() as tmpdirname:
        flow.save(dequantization_resnet18.state_dict(), tmpdirname)
        convert_to_onnx_and_check(resnet_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp", print_outlier=True)
        ipt_dict, onnx_res = run_onnx("/tmp/model.onnx", get_onnx_provider("cpu"))
        trt_res = run_tensorrt("/tmp/model.onnx", ipt_dict[list(ipt_dict.keys())[0]])
        compare_result(onnx_res, trt_res, atol=1e-4, print_outlier=True)


test_resnet()
```

First, we use the dequantization Pass to rewrite the original model, updating its weights at the same time. Then we convert the model to be deployed (an nn.Module) into a static computation graph with OneFlow's nn.Graph. For the nn.Graph documentation, see https://docs.oneflow.org/master/basics/08_nn_graph.html.

Because OneFlow's ONNX conversion tool works on static computation graphs, nn.Graph is required. If you don't want to dig into the underlying principles, that's fine: the complete usage is shown above.

The OneFlow->ONNX conversion tool requires the following packages:

python>=3.5
onnx>=1.8.0
onnxruntime>=1.6.0
oneflow>=0.5.0

Run pip install oneflow_onnx.

Then call the convert_to_onnx_and_check API in oneflow_onnx to convert the quantization aware training model to ONNX. The figure below shows ResNet18 converted to ONNX after quantization aware training.

ResNet18 quantization aware training model

To run this model with TensorRT, we also need to set up the environment by installing:

onnx>=1.8.0
onnxruntime-gpu>=1.8.0
opencv-python
pytest
nvidia-tensorrt==8.0.0.3
pycuda
flake8

After these packages are installed, TensorRT can be used for inference:

ipt_dict, onnx_res = run_onnx("/tmp/model.onnx", get_onnx_provider("cpu"))
trt_res = run_tensorrt("/tmp/model.onnx", ipt_dict[list(ipt_dict.keys())[0]])
compare_result(onnx_res, trt_res, atol=1e-4, print_outlier=True)

Other details can be found in the code repository. Here is the final result: with the same random input, the outputs of ONNX and TensorRT inference are essentially identical (the differences are on the order of 1e-6):

-2.9825006 -2.9825
-5.438802 -5.4388037
3.5198674 3.5198674
2.409646 2.4096458
4.5826764 4.5826764
0.019911028 0.019910894
6.6347113 6.634712
-3.5996702 -3.5996711
-1.3407612 -1.340761
-3.8473191 -3.847319
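The check performed by compare_result can be sketched as an element-wise absolute-tolerance test. compare_outputs below is a hypothetical stand-in (the real helper lives in the example script and may also apply a relative tolerance); applied to the ten value pairs printed above, no pair exceeds atol=1e-4.

```python
def compare_outputs(expected, actual, atol=1e-4, print_outlier=False):
    """Return the indices where |expected - actual| exceeds atol."""
    if len(expected) != len(actual):
        raise ValueError("output lengths differ")
    outliers = []
    for i, (e, a) in enumerate(zip(expected, actual)):
        if abs(e - a) > atol:
            outliers.append(i)
            if print_outlier:
                print(f"index {i}: {e} vs {a}")
    return outliers


# The two columns printed above
first_column = [-2.9825006, -5.438802, 3.5198674, 2.409646, 4.5826764,
                0.019911028, 6.6347113, -3.5996702, -1.3407612, -3.8473191]
second_column = [-2.9825, -5.4388037, 3.5198674, 2.4096458, 4.5826764,
                 0.019910894, 6.634712, -3.5996711, -1.340761, -3.847319]
print(compare_outputs(first_column, second_column, atol=1e-4))  # []
```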

At this point, the task of deploying the original dynamic graph model to the GPU for inference after quantization aware training is complete.

0x7. Conclusion

In this article, I shared my recent work in OneFlow: porting PyTorch FX to OneFlow and using it to perform automatic quantization aware training on dynamic computation graphs (nn.Module in both PyTorch and OneFlow). Starting from their own nn.Module, users can now go from quantization aware training to GPU deployment with TensorRT with only a few code changes, following the pipeline OneFlow dynamic computation graph (nn.Module) -> OneFlow quantization aware training model (nn.Module) -> OneFlow static computation graph (nn.Graph) -> ONNX -> TensorRT. Quantization aware training is built on the FX module, which supports writing Passes in Eager mode (FX was first proposed by PyTorch, and the author ported its infrastructure to OneFlow).

0x8. Links and Documentation for Reference

I hope this article helps with your deep learning projects😊. If you want to try out OneFlow's features, you can follow the method described in this article. If you have any questions or comments💡 about usage, please feel free to leave them in the comments section below; the same goes for any remarks or suggestions for improvement. In future articles, we'll introduce more OneFlow features, such as automatic data type promotion.

Related articles:

  1. On AI Industrialization Dilemma and the Inspiration from Database Standardization
  2. Combating Software System Complexity: Conceptual Integrity and Uniform Metaphor

Welcome to visit OneFlow on GitHub and follow us on Twitter and LinkedIn.

Also, welcome to join our Discord group to discuss and ask OneFlow related questions, and connect with OneFlow contributors and users all around the world.

OneFlow

OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient. https://github.com/Oneflow-Inc/oneflow