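import torch
import torch.nn as nn
import torchvision.models as models

# quantizable resnet18 without its fc layer and quant/dequant stubs,
# followed by a binary classification head
resnet = nn.Sequential(
    *list(models.quantization.resnet18(pretrained=True).children())[:-3],
    nn.Flatten(),
    nn.Linear(512, 2)
).cuda()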
The what
So you have trained a neural network and want to deploy it. In production, performance (speed and computational cost, not just accuracy) matters a lot. If your model can achieve low enough latency on a CPU instance, deployment will cost far less than on a GPU instance. Lower costs mean higher profits.
Model quantization is (usually) the easiest way to massively speed up your model. If you want to learn more about the theory behind quantization and how it works, check out this blogpost. Feeling too lazy to read through all that? Here's a quick summary.

Quantization provides a way to compress the weights (and activations) of our model. Weights are usually represented as 32-bit floats, but we "quantize" them and reduce this to 8 bits instead. You can go even further and use as little as 1 bit per parameter, creating binary neural networks, but that is beyond the scope of this post. While quantization directly reduces model size by 4x, that is not the most important part. Using reduced precision significantly reduces the time taken for matrix multiplication and addition. And I am not talking about measly 10-20% gains either: you can expect a 3-5x speedup when quantizing a model from FP32 to INT8. These gains are large enough to offset much of the performance gap between a CPU and a GPU, making real-time inference possible on CPU.
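To make the 4x figure concrete, here is a plain-Python sketch (an illustration, not PyTorch's implementation) that maps a few float32 values onto 8-bit unsigned integers with an affine scale and zero point, then compares storage. The helper names `quantize` and `dequantize` and the sample weights are assumptions made up for this example.

```python
from array import array

def quantize(values, scale, zero_point, qmin=0, qmax=255):
    """Affine quantization: q = clamp(round(x / scale) + zero_point, qmin, qmax)."""
    return [min(qmax, max(qmin, round(v / scale) + zero_point)) for v in values]

def dequantize(qvalues, scale, zero_point):
    """Map 8-bit codes back to approximate floats: x ~ (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in qvalues]

weights = [-0.51, 0.0, 0.27, 1.03, -1.2, 0.8]   # made-up example weights
lo, hi = min(weights), max(weights)
scale = (hi - lo) / 255            # spread the observed range over 256 codes
zero_point = round(-lo / scale)    # the code that represents 0.0

fp32 = array('f', weights)                               # 4 bytes per value
int8 = array('B', quantize(weights, scale, zero_point))  # 1 byte per value
print(fp32.itemsize * len(fp32), int8.itemsize * len(int8))  # 24 6

restored = dequantize(int8, scale, zero_point)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(max_err <= scale / 2 + 1e-9)  # rounding error is at most half a step
```

The storage saving is exact (one byte instead of four per value), while the accuracy cost is bounded by half a quantization step for values inside the observed range.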
So… what's the catch, you ask? Lower-precision arithmetic greatly limits the range in which values can lie, which increases the chance of arithmetic overflow. There are ways to reduce the probability of overflow (more on this later, in the calibration step), but the risk never disappears completely.
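A small plain-Python illustration of the overflow risk: every input below fits in 8 bits, but their dot product does not fit in 16 bits, so an accumulator that wraps around (two's complement) returns a meaningless result. The 16-bit accumulator is an assumption chosen for illustration; real int8 kernels generally accumulate in wider integers, but the principle is the same.

```python
def wrap(value, bits):
    """Interpret value as a two's-complement integer of the given bit width."""
    mask = (1 << bits) - 1
    value &= mask
    return value - (1 << bits) if value >= 1 << (bits - 1) else value

a = [127, 127, 127, 127]   # int8 activations at the top of their range
b = [127, 127, 127, 127]   # int8 weights at the top of their range

exact = sum(x * y for x, y in zip(a, b))
acc16 = 0
for x, y in zip(a, b):
    acc16 = wrap(acc16 + x * y, 16)   # hypothetical 16-bit accumulator

print(exact, acc16)  # 64516 -1020
```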
The how
Quantizing common PyTorch models is pretty simple thanks to PyTorch's quantization API. You need to perform the following steps to get a basic quantized model:
Step 0: Create a model
Let's create a basic resnet18 model with a binary classification head. Note that we need to use the 'quantization' version of resnet18 instead of the standard torchvision version; the latter will give an error. I will explain the reason for this later.
Step 1: Fuse layers
In this step we will ‘combine’ the layers of our model. This step is actually not related to quantization, but it does give extra speedups.
# fuse the stem: conv1, bn1 and relu sit at positions 0, 1 and 2 of the Sequential
# (note: recent PyTorch releases only fuse conv+bn like this in eval mode;
#  for QAT they provide torch.ao.quantization.fuse_modules_qat instead)
torch.quantization.fuse_modules(resnet, [['0', '1', '2']], inplace=True)
# In our sequential model, there are 4x2 resblocks in positions [4, 5, 6, 7]
for i in range(4, 8):
    # fuse modules in the first resblock
    torch.quantization.fuse_modules(resnet[i][0],
                                    [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']],
                                    inplace=True)
    # if this resblock does downsampling, also fuse the skip-connection
    if resnet[i][0].downsample is not None:
        torch.quantization.fuse_modules(resnet[i][0].downsample,
                                        [['0', '1']],
                                        inplace=True)
    # fuse modules in the second resblock
    torch.quantization.fuse_modules(resnet[i][1],
                                    [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']],
                                    inplace=True)
Step 2: Prepare for qat
Prepare the model for quantization aware training
resnet.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
resnet = torch.quantization.prepare_qat(resnet).cuda()
Step 3: Train normally
# epochs, train_loader, loss_fn and opt are your usual training setup (not shown)
for _ in range(epochs):
    for x, y in train_loader:
        loss = loss_fn(resnet(x.cuda()), y.cuda())
        opt.zero_grad()
        loss.backward()
        opt.step()
Step 4: Post training steps
This is where the magic happens
class Qresnet(nn.Module):
    def __init__(self, m):
        super().__init__()
        self.q = torch.quantization.QuantStub()
        self.m = m
        self.dq = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dq(self.m(self.q(x)))

# load the best model from training phase
resnet.load_state_dict(torch.load('best_model.pth'))

# 4.1: wrap qat resnet with quant dequant stubs
qmodel = Qresnet(resnet)

# add quantization recipe to model again
qmodel.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# 4.2: prepare the modules in the model to be quantized
qmodel = torch.quantization.prepare(qmodel)

# 4.3: calibrate (forward passes only, so no gradients are needed)
with torch.no_grad():
    for x, y in train_loader:
        qmodel(x.cuda())

# 4.4: actually quantize the trained model (quantized kernels run on CPU)
qmodel = torch.quantization.convert(qmodel.cpu())

# put to eval mode
qmodel = qmodel.eval()

# 4.5: script the model using TorchScript for easy deployment (optional)
scripted = torch.jit.script(qmodel)
torch.jit.save(scripted, 'qmodel.pt')  # illustrative file name
Going deeper
The quint8 datatype
When PyTorch quantizes a 32-bit float tensor to the quint8 type, each value is stored as an 8-bit unsigned integer. (Weights usually use the signed variant, qint8, as we will see in step 4.) You can quantize any fp32 tensor to quint8 using the following command:
import torch

a = torch.randn(10, 10)
aq = torch.quantize_per_tensor(a, scale=0.1, zero_point=10, dtype=torch.quint8)
The concepts of scale and zero point are explained in the blogpost mentioned above. As a short summary, these values determine the range of values the quantized tensor can represent. In quantization we achieve performance boosts by significantly limiting the range in which our weights and activations can lie. Unlike the value of an fp32 tensor, the stored integer of a quint8 tensor has no meaning on its own: together with the scale, it determines an offset from the fixed zero_point, i.e. the real value is (q - zero_point) * scale.
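A plain-Python sketch of that mapping, using the same scale=0.1 and zero_point=10 as the example above (the function names are illustrative, not PyTorch API; in PyTorch itself you can inspect the stored integers with `aq.int_repr()` and recover floats with `aq.dequantize()`). It shows that with these parameters only values in [-1.0, 24.5] are representable, so any `torch.randn` entries below -1.0 get clipped.

```python
def quantize(x, scale, zero_point, qmin=0, qmax=255):
    # round to the nearest step, shift by the zero point, clamp to the quint8 range
    return min(qmax, max(qmin, round(x / scale) + zero_point))

def dequantize(q, scale, zero_point):
    # the stored integer only has meaning together with scale and zero_point
    return (q - zero_point) * scale

scale, zero_point = 0.1, 10
lowest = dequantize(0, scale, zero_point)     # smallest representable value
highest = dequantize(255, scale, zero_point)  # largest representable value
print(lowest, highest)

for x in [0.0, 0.73, -2.5]:
    q = quantize(x, scale, zero_point)
    print(x, q, dequantize(q, scale, zero_point))
```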
Step 0
In step 0 we created a 'quantization' version of resnet18. This version is exactly the same as the original, except for the forward function of the residual blocks. We cannot use the normal version because resnets have skip connections: the final output of a resblock is created by adding the block's input to the output of its conv layers. Standard addition or multiplication is not allowed for tensors of type quint8, because we cannot simply add their stored values without taking the scale and zero_point into account.
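The tensor resulting from any mathematical operation between two quint8 tensors will have a different scale and zero_point. This is taken care of by the nn.quantized.FloatFunctional module. Before conversion it behaves like ordinary float arithmetic (while letting an observer record the output range), and torch.quantization.convert swaps it for nn.quantized.QFunctional, which operates on quantized tensors. It is used like this: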
import torch
import torch.nn as nn

ff = nn.quantized.FloatFunctional()
a = torch.tensor(3.0)
b = torch.tensor(4.0)
ff.add(a, b)  # Equivalent to ``torch.add(a, b)``
Besides add, the FloatFunctional module also provides mul, add_relu (addition followed by ReLU), cat, add_scalar and mul_scalar.
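Why the result needs its own scale and zero point: a plain-Python sketch of adding two quantized values. It dequantizes both inputs, adds them in float and requantizes with an output scale and zero point. Real quantized kernels do this with integer arithmetic, so treat it only as an illustration of the bookkeeping; the helper names and all scale/zero-point values are assumptions.

```python
def quantize(x, scale, zero_point, qmin=0, qmax=255):
    return min(qmax, max(qmin, round(x / scale) + zero_point))

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

def quantized_add(qa, a_params, qb, b_params, out_params):
    """Add two quantized values and requantize the sum with the output parameters."""
    total = dequantize(qa, *a_params) + dequantize(qb, *b_params)
    return quantize(total, *out_params)

a_params, b_params = (0.1, 0), (0.05, 0)   # (scale, zero_point) of each input
out_params = (0.1, 0)                      # assumed output scale and zero point

qa = quantize(3.0, *a_params)   # 30
qb = quantize(4.0, *b_params)   # 80

naive = qa + qb                 # adding raw codes ignores the different scales
print(dequantize(naive, *out_params))       # roughly 11.0, which is wrong
qsum = quantized_add(qa, a_params, qb, b_params, out_params)
print(qsum, dequantize(qsum, *out_params))  # 70 and roughly 7.0
```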
Step 1
This step doesn't have anything to do with quantization per se. It just reduces the number of operations in your model by combining convolution, batch norm and relu layers into single modules. Here we fuse the sequences [conv, batch norm] and [conv, batch norm, relu]; fuse_modules also supports a few other patterns, such as [conv, relu] and [linear, relu]. Note that after this step the state_dict of your model will become slightly different, so if you want to load pretrained weights, you should do it before this step.
# before fusing
resnet[0]
# out: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# after fusing
torch.quantization.fuse_modules(resnet, [['0', '1', '2']], inplace=False)[0]
# out: ConvBnReLU2d(
# (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU()
# )
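Part of the speedup from fusion comes from folding batch norm into the preceding convolution for inference, so one op does the work of two. Below is a plain-Python sketch of that folding for a single-channel 1x1 convolution, using the standard eval-mode batch norm formula; the function names and numbers are made up for illustration.

```python
import math

def conv1x1(x, w, b):
    # a single-channel 1x1 convolution is just a scale and shift per pixel
    return [w * v + b for v in x]

def batch_norm_eval(x, gamma, beta, mean, var, eps=1e-5):
    # eval-mode batch norm uses the running statistics
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in x]

def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return (w', b') such that conv1x1(x, w', b') equals batch_norm_eval(conv1x1(x, w, b))."""
    factor = gamma / math.sqrt(var + eps)
    return w * factor, (b - mean) * factor + beta

x = [0.5, -1.0, 2.0, 3.5]
w, b = 0.8, 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0

two_ops = batch_norm_eval(conv1x1(x, w, b), gamma, beta, mean, var)
fw, fb = fold_bn_into_conv(w, b, gamma, beta, mean, var)
one_op = conv1x1(x, fw, fb)
max_diff = max(abs(p - q) for p, q in zip(two_ops, one_op))
print(max_diff)  # only floating-point rounding noise
```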
Step 2
This step adds a quantization recipe to our model. We are using the 'fbgemm' recipe, which targets x86 CPUs and seems to be the PyTorch-recommended default for them. The recipe is essentially a strategy for quantizing the model, and different recipes may or may not be able to quantize different operations. As an end user, this is all we need to worry about.
The next line actually prepares the model for quantization aware training. The model won't actually be trained in quint8: the weights and gradients will still be fp32, but some fake quantization layers will be visible in the state dict. These are responsible for simulating the precision loss caused by quantization at inference time. Ideally the model learns to work around this precision loss, and there should be no noticeable drop in performance when the model is deployed.
PyTorch calls this fake quantization. Let's see what happens to the model after this step.
# first conv block after prepare_qat
resnet[0]
# out: ConvBnReLU2d(
# 3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
# (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
# fake_quant_enabled=tensor([1], device='cuda:0'), observer_enabled=tensor([1], device='cuda:0'),
# scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32),
# dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
# (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([], device='cuda:0'),
# max_val=tensor([], device='cuda:0'))
# )
# (activation_post_process): FusedMovingAvgObsFakeQuantize(
# fake_quant_enabled=tensor([1], device='cuda:0'), observer_enabled=tensor([1], device='cuda:0'),
# scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32),
# dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
# (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
# )
# )
As can be seen, some fake quantization ‘stubs’ have been added. These deal with the simulation of precision loss, explained above. Now let’s see what happened to the state_dict.
# state dict of the first conv block after prepare_qat
resnet[0].state_dict().keys()
# out: odict_keys(['weight', 'bn.weight', 'bn.bias', 'bn.running_mean', 'bn.running_var',
# 'bn.num_batches_tracked', 'weight_fake_quant.fake_quant_enabled', 'weight_fake_quant.observer_enabled',
# 'weight_fake_quant.scale', 'weight_fake_quant.zero_point', 'weight_fake_quant.activation_post_process.eps',
# 'weight_fake_quant.activation_post_process.min_val', 'weight_fake_quant.activation_post_process.max_val',
# 'activation_post_process.fake_quant_enabled', 'activation_post_process.observer_enabled',
# 'activation_post_process.scale', 'activation_post_process.zero_point',
# 'activation_post_process.activation_post_process.eps', 'activation_post_process.activation_post_process.min_val',
# 'activation_post_process.activation_post_process.max_val'])
Quite a few new keys have been added for the fake quantization stubs. There is a quantization stub to quantize the weights called weight_fake_quant and another to quantize the activations called activation_post_process. The zero points and scales of these stubs have been added to the state_dict.
# dtype of weight
resnet[0].weight.dtype
# out: torch.float32
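The weights stay float32 because fake quantization only snaps values onto the 8-bit grid and immediately maps them back to float. Here is a plain-Python sketch of that forward computation, using the signed, symmetric settings visible in the weight_fake_quant printout (zero point 0, range [-128, 127]); the scale of 0.02 and the sample weights are assumptions. In the backward pass, PyTorch passes gradients straight through the rounding (a straight-through estimator), which this sketch does not model.

```python
def fake_quantize(x, scale, zero_point=0, qmin=-128, qmax=127):
    """Quantize then immediately dequantize: the result is a float on the int8 grid."""
    q = min(qmax, max(qmin, round(x / scale) + zero_point))
    return (q - zero_point) * scale

scale = 0.02   # assumed scale for this example
weights = [0.013, -0.5, 0.731, 3.0]
fq = [fake_quantize(w, scale) for w in weights]
print(fq)  # still floats, but each is a multiple of the scale; 3.0 clips to 127 * 0.02
```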
Step 4
4.1: Wrap
After training, we need to actually convert the model from fp32 to 8-bit integers. The quantized model will take quint8 inputs, so we need to wrap the model with quantization and dequantization stubs. This is achieved with the new Qresnet class. We need this because we cannot manually quantize our input tensors without knowing the scale and zero point; using a QuantStub lets torch calculate these values from the train set. Note that the DeQuantStub has no parameters.
4.2: Prepare
This is another prepare step. Its purpose is to add the components necessary to calculate the zero point and scale for quantization. In our case it does not change the state dict, since we trained the model with QAT, and the fake quantization stubs added by prepare_qat have already calculated the necessary scales and zero points.
If we were training without QAT, however, this step would add a HistogramObserver module:
# qmodel.m is a resnet on which prepare_qat has NOT been called
qmodel.m[0]
# out: ConvBnReLU2d(
# (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU()
# (activation_post_process): HistogramObserver()
# )
4.3: Calibrate
In this step we forward pass all elements of the training set through the model to calibrate the quantization parameters. Remember how we mentioned earlier that there is a way to reduce the probability of overflow due to reduced precision in computation? That is what is happening here. We assume that the train set is representative of the actual data distribution. Under this assumption we can calculate the range in which the inputs and outputs of each layer lie, and set the scales and zero points accordingly. As long as we do not encounter a strongly out-of-distribution data point at test time, there is a low risk of overflow.
Of course, since we used QAT, these values have already been computed, except for the input quant stub. If you have a very large train set, this step can take quite some time. In that case you might be able to get away with using a subset of your train set.
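A simplified sketch of what calibration computes: track the running min and max of everything seen, then derive an affine quint8 scale and zero point from that range. This is modeled on the idea behind PyTorch's MinMaxObserver but is not its implementation (the default fbgemm recipe actually uses a HistogramObserver for activations, which picks the range more carefully). The class name `RangeObserver` and the calibration batches are hypothetical.

```python
class RangeObserver:
    """Hypothetical min/max observer for quint8 activations."""

    def __init__(self, qmin=0, qmax=255):
        self.qmin, self.qmax = qmin, qmax
        self.min_val, self.max_val = float('inf'), float('-inf')

    def observe(self, batch):
        # update the running range with every calibration batch
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def calculate_qparams(self):
        # widen the range to include 0 so that 0.0 maps to an exact integer
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = max((hi - lo) / (self.qmax - self.qmin), 1e-8)
        zero_point = self.qmin - round(lo / scale)
        zero_point = min(self.qmax, max(self.qmin, zero_point))
        return scale, zero_point

obs = RangeObserver()
for batch in [[0.1, 2.3, -0.4], [1.7, -1.1, 0.9], [3.2, 0.0, -0.2]]:
    obs.observe(batch)
scale, zero_point = obs.calculate_qparams()
print(scale, zero_point)
```

An out-of-distribution value far outside [-1.1, 3.2] would be clamped to the ends of the range at inference time, which is why the calibration data should be representative.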
4.4: Quantize
After this step, the weights have finally been converted from fp32 to qint8 (signed 8-bit integers), while the activations passed between layers are quint8.
# first conv block after quantization
qmodel.m[0]
# out: QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.03747640550136566,
# zero_point=0, padding=(3, 3))
# weight of first conv block
# note that weight is now a function instead of parameter
qmodel.m[0].weight().dtype
# out: torch.qint8
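The weights are signed because the fbgemm recipe quantizes them per output channel with a symmetric scheme (zero point 0), as the qscheme=torch.per_channel_symmetric entry in the QAT printout indicates. Below is a plain-Python sketch of that scheme for a toy 2-channel weight. The scale formula (largest absolute weight divided by half of the integer range) is modeled on PyTorch's symmetric observers, but treat it as an approximation; the function name and weights are made up.

```python
def quantize_per_channel_symmetric(weight, qmin=-128, qmax=127):
    """Quantize each output channel (row) with its own scale and a zero point of 0."""
    result = []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        scale = max(max_abs / ((qmax - qmin) / 2), 1e-8)  # avoid a zero scale
        codes = [min(qmax, max(qmin, round(w / scale))) for w in row]
        result.append((scale, codes))
    return result

weight = [
    [0.50, -0.25, 0.10],   # channel 0: large weights
    [0.02, -0.01, 0.005],  # channel 1: tiny weights get their own, finer scale
]
res = quantize_per_channel_symmetric(weight)
for scale, codes in res:
    print(round(scale, 6), codes)
```

With a single per-tensor scale, channel 1 would be squeezed into a handful of integer codes; per-channel scales keep its resolution.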
4.5: Script
Finally, you can use TorchScript to compile your model. This step is also optional but recommended: if you only save the weights of your quantized model, you will have to repeat all the steps for creating the quantized model at inference time before you can load those weights.