Commit 35cc605b authored by Jerry Zhang's avatar Jerry Zhang Committed by Facebook GitHub Bot
Browse files

torch.quantization --> torch.ao.quantization in deeplearning/projects/fairseq-py

Summary:
codemod -m -d $dir --extensions py \
            'torch.quantization' \
            'torch.ao.quantization'

Reviewed By: z-a-f

Differential Revision: D31294192

fbshipit-source-id: fcad50d07a8397fc2ab8fd7188ab338f51f3ba10
parent 72bb4447
......@@ -27,7 +27,7 @@ from fairseq.models.speech_to_text.utils import (
layer_norm_backward_hook,
)
from torch import Tensor, device as Device
from torch.quantization.qconfig import (
from torch.ao.quantization.qconfig import (
default_dynamic_qconfig,
per_channel_dynamic_qconfig,
)
......@@ -140,7 +140,7 @@ class PositionwiseFF(nn.Module):
qconfig = per_channel_dynamic_qconfig
else:
qconfig = default_dynamic_qconfig
torch.quantization.quantize_dynamic(
torch.ao.quantization.quantize_dynamic(
self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
)
return self
......@@ -728,7 +728,7 @@ class NoSegAugmentedMemoryMultiheadAttentionBmm(nn.Module):
qconfig = per_channel_dynamic_qconfig
else:
qconfig = default_dynamic_qconfig
torch.quantization.quantize_dynamic(
torch.ao.quantization.quantize_dynamic(
self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
)
return self
......@@ -1771,7 +1771,7 @@ class NoSegAugmentedMemoryTransformerEncoderLayer(FairseqEncoder):
qconfig = per_channel_dynamic_qconfig
else:
qconfig = default_dynamic_qconfig
torch.quantization.quantize_dynamic(
torch.ao.quantization.quantize_dynamic(
self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
)
return self
......
......@@ -21,7 +21,7 @@ def quantize(w, scale, zero_point, bits=8):
def emulate_int8_histogram(w, scale=None, zero_point=None, bits=8):
if scale is None:
obs = torch.quantization.observer.HistogramObserver()
obs = torch.ao.quantization.observer.HistogramObserver()
obs.to(device=w.device)
_ = obs(w.float())
scale, zero_point = obs.calculate_qparams()
......@@ -32,7 +32,7 @@ def emulate_int8_histogram(w, scale=None, zero_point=None, bits=8):
def emulate_int8_channel(w, scale=None, zero_point=None, bits=8):
if scale is None:
obs = torch.quantization.observer.PerChannelMinMaxObserver(
obs = torch.ao.quantization.observer.PerChannelMinMaxObserver(
ch_axis=-1, qscheme=torch.per_channel_symmetric
)
obs.to(device=w.device)
......@@ -45,7 +45,7 @@ def emulate_int8_channel(w, scale=None, zero_point=None, bits=8):
def emulate_int8_tensor(w, scale=None, zero_point=None, bits=8):
if scale is None:
obs = torch.quantization.observer.MinMaxObserver()
obs = torch.ao.quantization.observer.MinMaxObserver()
obs.to(device=w.device)
_ = obs(w)
scale, zero_point = obs.calculate_qparams()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment