tico 0.1.0.dev251106__py3-none-any.whl → 0.2.0.dev260122__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tico/__init__.py +2 -2
  2. tico/_version.py +1 -0
  3. tico/passes/convert_conv3d_to_conv2d.py +435 -0
  4. tico/passes/convert_sym_size_to_circle_shape.py +99 -0
  5. tico/passes/decompose_batch_norm.py +9 -5
  6. tico/passes/lower_copy.py +95 -0
  7. tico/passes/ops.py +4 -0
  8. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +251 -0
  9. tico/quantization/algorithm/fpi_gptq/quantizer.py +180 -0
  10. tico/quantization/algorithm/gptq/gptq.py +231 -11
  11. tico/quantization/algorithm/gptq/quantizer.py +18 -6
  12. tico/quantization/config/{pt2e.py → fpi_gptq.py} +11 -4
  13. tico/quantization/config/gptq.py +27 -4
  14. tico/quantization/public_interface.py +0 -10
  15. tico/quantization/wrapq/quantizer.py +2 -0
  16. tico/quantization/wrapq/wrappers/quant_elementwise.py +51 -11
  17. tico/serialize/operators/adapters/onert/llama_attention.py +51 -0
  18. tico/serialize/operators/op_attention.py +58 -0
  19. tico/serialize/operators/op_circle_shape.py +64 -0
  20. tico/serialize/operators/op_dequantize_per_channel.py +1 -0
  21. tico/serialize/operators/op_dequantize_per_tensor.py +1 -0
  22. tico/serialize/operators/op_transpose_conv.py +66 -50
  23. tico/utils/convert.py +16 -1
  24. tico/utils/padding.py +13 -5
  25. tico/utils/record_input.py +2 -2
  26. tico/utils/register_custom_op.py +63 -0
  27. tico/utils/validate_args_kwargs.py +49 -4
  28. tico-0.2.0.dev260122.dist-info/METADATA +631 -0
  29. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/RECORD +35 -46
  30. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/WHEEL +1 -1
  31. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/entry_points.txt +0 -1
  32. tico/quantization/algorithm/pt2e/annotation/annotator.py +0 -208
  33. tico/quantization/algorithm/pt2e/annotation/config.py +0 -26
  34. tico/quantization/algorithm/pt2e/annotation/op/__init__.py +0 -21
  35. tico/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +0 -63
  36. tico/quantization/algorithm/pt2e/annotation/op/add.py +0 -55
  37. tico/quantization/algorithm/pt2e/annotation/op/conv2d.py +0 -90
  38. tico/quantization/algorithm/pt2e/annotation/op/div.py +0 -55
  39. tico/quantization/algorithm/pt2e/annotation/op/linear.py +0 -92
  40. tico/quantization/algorithm/pt2e/annotation/op/mean.py +0 -51
  41. tico/quantization/algorithm/pt2e/annotation/op/mul.py +0 -55
  42. tico/quantization/algorithm/pt2e/annotation/op/relu6.py +0 -51
  43. tico/quantization/algorithm/pt2e/annotation/op/rsqrt.py +0 -51
  44. tico/quantization/algorithm/pt2e/annotation/op/sub.py +0 -55
  45. tico/quantization/algorithm/pt2e/annotation/spec.py +0 -45
  46. tico/quantization/algorithm/pt2e/annotation/utils.py +0 -88
  47. tico/quantization/algorithm/pt2e/quantizer.py +0 -81
  48. tico/quantization/algorithm/pt2e/transformation/__init__.py +0 -1
  49. tico/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -58
  50. tico/quantization/algorithm/pt2e/utils.py +0 -135
  51. tico/serialize/operators/op_copy.py +0 -187
  52. tico-0.1.0.dev251106.dist-info/METADATA +0 -392
  53. /tico/quantization/algorithm/{pt2e → fpi_gptq}/__init__.py +0 -0
  54. /tico/{quantization/algorithm/pt2e/annotation → serialize/operators/adapters/onert}/__init__.py +0 -0
  55. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info/licenses}/LICENSE +0 -0
  56. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/top_level.txt +0 -0
tico/quantization/algorithm/gptq/gptq.py
@@ -31,16 +31,147 @@ torch.backends.cuda.matmul.allow_tf32 = False
  torch.backends.cudnn.allow_tf32 = False


+ def convtranspose2d_weights_to_conv2d_weights(layer, w) -> torch.Tensor:
+     if layer.groups == 1:
+         # The last two dimensions of w are (k_h, k_w); to get the equivalent Conv2D we flip them, so that
+         # `w_conv2D_equivalent_to_w[i, j] = w_conv[k_h - i - 1, k_w - j - 1]`.
+         # The first two dimensions of w are (input_channels, output_channels), so we transpose them,
+         # because Conv2D weights are stored as (output_channels, input_channels).
+         # See https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/torch/nn/modules/conv.py#L1059-L1061 for additional info.
+         w_conv_transposed = w.transpose(1, 0).flip((-2, -1))
+     else:
+         # basically the same as for `layer.groups == 1`, but group-wise
+         in_channels, out_channels, kernel_h, kernel_w = layer.weight.shape
+         out_channels *= layer.groups
+         w_conv_transposed = torch.zeros(
+             out_channels, in_channels // layer.groups, kernel_h, kernel_w
+         )
+         for i in range(0, layer.groups):
+             w_conv_transposed[
+                 i * out_channels // layer.groups : (i + 1) * out_channels // layer.groups, :, :, :
+             ] = (
+                 w[i * in_channels // layer.groups : (i + 1) * in_channels // layer.groups, :, :, :]
+                 .transpose(1, 0)
+                 .flip((-2, -1))
+             )
+
+     return w_conv_transposed
+
+
+ def conv2d_weights_to_convtranspose2d_weights(orig_layer, w) -> torch.Tensor:
+     # this is just the inverse of convtranspose2d_weights_to_conv2d_weights
+     if orig_layer.groups > 1:
+         in_channels, out_channels, _, _ = orig_layer.weight.shape
+         out_channels *= orig_layer.groups
+         w_conv_transposed = torch.zeros_like(orig_layer.weight)
+         for i in range(0, orig_layer.groups):
+             w_conv_transposed[
+                 i * in_channels // orig_layer.groups : (i + 1) * in_channels // orig_layer.groups, :, :, :
+             ] = (
+                 w[i * out_channels // orig_layer.groups : (i + 1) * out_channels // orig_layer.groups, :, :, :]
+                 .transpose(1, 0)
+                 .flip((-2, -1))
+             )
+     else:
+         w_conv_transposed = w.transpose(1, 0).flip((-2, -1))
+
+     return w_conv_transposed
+
+
+ def get_matmul_input_for_convtranspose2d(layer, inp):
+     # See https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/torch/nn/modules/conv.py#L996-L998 for the padding
+     strided_pad = (
+         layer.dilation[0] * (layer.kernel_size[0] - 1) - layer.padding[0],
+         layer.dilation[1] * (layer.kernel_size[1] - 1) - layer.padding[1],
+     )
+
+     # interleave the input with zero rows and columns according to the stride
+     # See https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/torch/nn/modules/conv.py#L991-L994 for more info
+     inp_strided = torch.zeros(
+         inp.shape[0],
+         inp.shape[1],
+         layer.stride[0] * (inp.shape[2] - 1) + 2 * strided_pad[0] + 1,
+         layer.stride[1] * (inp.shape[3] - 1) + 2 * strided_pad[1] + 1,
+         device=inp.device,
+     )
+
+     indices = torch.arange(0, inp.shape[2], device=inp.device)
+     # insert the original input values according to the stride, to match https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/torch/nn/modules/conv.py#L991-L994
+     inp_strided[
+         :,
+         :,
+         layer.stride[0] * indices + strided_pad[0],
+         strided_pad[1] : -strided_pad[1] : layer.stride[1],
+     ] = inp[:, :, indices, :]
+     del inp
+     inp = inp_strided  # the rest is just the usual processing for a Conv2D with transposed weights
+
+     # TODO reduce code duplication with Conv2D
+     unfold = nn.Unfold(
+         layer.kernel_size,
+         dilation=layer.dilation,
+         padding=(0, 0),  # the equivalent Conv2D has (0, 0) padding for inp_strided as input
+         stride=(1, 1),  # the equivalent Conv2D has (1, 1) stride for inp_strided as input
+     )
+
+     if layer.groups != 1:
+         inp = inp.reshape(
+             inp.size(0) * layer.groups,
+             inp.size(1) // layer.groups,
+             inp.shape[2],
+             inp.shape[3],
+         )  # inp.shape == (batch * groups, in_channels / groups, H, W) for group-wise convolution, so that each group is convolved with its own filter
+
+     inp = unfold(inp).permute([1, 0, 2]).flatten(1)
+     return inp
+
+
  class GPTQ:
      def __init__(self, layer):
          self.layer = layer
          self.dev = self.layer.weight.device
          W = layer.weight.data.clone()
-         if isinstance(self.layer, nn.Conv2d):
+         if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
+             W = W.flatten(1)  # reshaped to a matrix (out_channels x the_rest)
+         elif isinstance(self.layer, nn.ConvTranspose2d):
+             W = convtranspose2d_weights_to_conv2d_weights(self.layer, W)
              W = W.flatten(1)

-         if isinstance(self.layer, nn.Conv1d):
-             W = W.t()
          self.rows = W.shape[0]
          self.columns = W.shape[1]
          self.H: Optional[torch.Tensor] = torch.zeros(
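The two conversion helpers above are exact inverses of each other: per group they transpose the (in, out) channel axes and flip the kernel, and both operations are involutions. A quick round-trip sanity check, as a sketch with small random weights (not part of the package; it only assumes the helpers stay module-level under the path shown in the file list):

    import torch
    import torch.nn as nn

    from tico.quantization.algorithm.gptq.gptq import (
        conv2d_weights_to_convtranspose2d_weights,
        convtranspose2d_weights_to_conv2d_weights,
    )

    for groups in (1, 2):
        layer = nn.ConvTranspose2d(4, 6, kernel_size=3, groups=groups, bias=False)
        w = layer.weight.data  # (in_channels, out_channels // groups, k_h, k_w)
        w_conv = convtranspose2d_weights_to_conv2d_weights(layer, w)
        w_back = conv2d_weights_to_convtranspose2d_weights(layer, w_conv)
        assert torch.allclose(w, w_back)  # ConvTranspose2d -> Conv2d -> ConvTranspose2d is lossless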
@@ -53,7 +184,7 @@ class GPTQ:
          if len(inp.shape) == 2:
              inp = inp.unsqueeze(0)
          tmp = inp.shape[0]
-         if isinstance(self.layer, nn.Linear) or isinstance(self.layer, nn.Conv1d):
+         if isinstance(self.layer, nn.Linear):
              if len(inp.shape) > 2:
                  inp = inp.reshape((-1, inp.shape[-1]))
              inp = inp.t()
@@ -65,10 +196,59 @@ class GPTQ:
                  stride=self.layer.stride,
              )

+             if self.layer.groups != 1:
+                 # The idea behind converting a depthwise convolution to a matmul is described here:
+                 # https://discuss.pytorch.org/t/conv1d-implementation-using-torch-nn-functional-unfold/109643/2
+                 # Although a depthwise convolution is equal to a set of MatMuls
+                 # (note that `w.view(1, groups, out_channels // groups, -1)` in the reference above is not just w.flatten(1)),
+                 # we can approximate the group-wise Hessians with their mean,
+                 # so that we keep a single Hessian and the usual GPTQ update applies.
+                 inp = inp.reshape(
+                     inp.size(0) * self.layer.groups,
+                     inp.size(1) // self.layer.groups,
+                     inp.shape[2],
+                     inp.shape[3],
+                 )  # inp.shape == (batch * groups, in_channels / groups, H, W) for group-wise convolution, so that each group is convolved with its own filter
+
+             inp = unfold(inp)  # inp.shape == (batch * groups, k_h * k_w * in_channels / groups, flattened_patches)
+             inp = inp.permute([1, 0, 2])  # inp.shape == (k_h * k_w * in_channels / groups, batch * groups, flattened_patches)
+             inp = inp.flatten(1)  # inp.shape == (k_h * k_w * in_channels / groups, batch * groups * flattened_patches)
+             # so inp.matmul(inp.t()).shape == (k_h * k_w * in_channels / groups, k_h * k_w * in_channels / groups), matching the columns of W.flatten(1)
+
+         if isinstance(self.layer, nn.Conv1d):
+             # nn.Conv1d is basically the same as nn.Conv2d, so we can reuse the same idea
+             # TODO reduce code duplication
+             # represent conv1d as conv2d(1, k) on a reshaped input of shape (batch, in_channels, 1, L)
+             unfold = nn.Unfold(
+                 (1, self.layer.kernel_size[0]),
+                 dilation=(1, self.layer.dilation[0]),
+                 padding=(0, self.layer.padding[0]),
+                 stride=(1, self.layer.stride[0]),
+             )
+             if self.layer.groups != 1:
+                 # please see the Conv2D case above for additional info
+                 inp = inp.reshape(
+                     inp.size(0) * self.layer.groups,
+                     inp.size(1) // self.layer.groups,
+                     inp.shape[2],
+                 )  # inp.shape == (batch * groups, in_channels / groups, L) for group-wise convolution, so that each group is convolved with its own filter
+
+             inp = inp.unsqueeze(-2)  # (batch * groups, in_channels / groups, L) -> (batch * groups, in_channels / groups, 1, L), valid for Conv2D
              inp = unfold(inp)
              inp = inp.permute([1, 0, 2])
              inp = inp.flatten(1)

+         if isinstance(self.layer, nn.ConvTranspose2d):
+             inp = get_matmul_input_for_convtranspose2d(self.layer, inp)
+
          self.H *= self.nsamples / (self.nsamples + tmp)
          self.nsamples += tmp
          inp = math.sqrt(2 / self.nsamples) * inp.float()
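The unfold-based handling above relies on the standard im2col identity: a Conv2d forward is a matmul between W.flatten(1) and the unfolded input patches, which is exactly the matrix whose Gram matrix add_batch accumulates into H. A small sketch with toy shapes (not from the package) illustrating that identity:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 3, 8, 8)
    conv = nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False)
    unfold = nn.Unfold(
        conv.kernel_size, dilation=conv.dilation, padding=conv.padding, stride=conv.stride
    )
    cols = unfold(x)                     # (batch, in_channels * k_h * k_w, patches)
    out = conv.weight.flatten(1) @ cols  # W.flatten(1) applied to every patch
    out = out.reshape(2, 4, 8, 8)        # patches are enumerated in the conv output order
    assert torch.allclose(out, conv(x), atol=1e-5)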
@@ -84,10 +264,13 @@ class GPTQ:
          verbose=False,
      ):
          W = self.layer.weight.data.clone()
-         if isinstance(self.layer, nn.Conv2d):
-             W = W.flatten(1)
-         if isinstance(self.layer, nn.Conv1d):
-             W = W.t()
+         if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
+             W = W.flatten(1)  # reshaped to a matrix (out_channels x the_rest)
+         elif isinstance(self.layer, nn.ConvTranspose2d):
+             W = convtranspose2d_weights_to_conv2d_weights(self.layer, W)
+             conv2d_shape = W.shape
+             W = W.flatten(1)  # reshaped to a matrix (out_channels x the_rest)
+
          W = W.float()
          tick = time.time()
          if not self.quantizer.ready():
@@ -181,9 +364,46 @@ class GPTQ:
          if actorder:
              Q = Q[:, invperm]

-         self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(
-             self.layer.weight.data.dtype
-         )
+         if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
+             if groupsize == -1:  # TODO support groupsize != -1
+                 Q[:, dead] = quantize(
+                     self.layer.weight.flatten(1)[:, dead],
+                     self.quantizer.scale,
+                     self.quantizer.zero,
+                     self.quantizer.maxq,
+                 )
+         elif isinstance(self.layer, nn.ConvTranspose2d):
+             if groupsize == -1:  # TODO support groupsize != -1
+                 Q[:, dead] = quantize(
+                     convtranspose2d_weights_to_conv2d_weights(
+                         self.layer, self.layer.weight.data
+                     ).flatten(1)[:, dead],
+                     self.quantizer.scale,
+                     self.quantizer.zero,
+                     self.quantizer.maxq,
+                 )
+         else:
+             if groupsize == -1:  # TODO support groupsize != -1
+                 Q[:, dead] = quantize(
+                     self.layer.weight[:, dead],
+                     self.quantizer.scale,
+                     self.quantizer.zero,
+                     self.quantizer.maxq,
+                 )
+
+         assert (
+             groupsize == -1 or torch.sum(dead) == 0
+         )  # TODO `dead` elements should be RTN quantized for groupwise
+
+         if isinstance(self.layer, nn.ConvTranspose2d):
+             Q_conv2d = Q.reshape(conv2d_shape).to(self.layer.weight.data.dtype)
+             self.layer.weight.data = conv2d_weights_to_convtranspose2d_weights(
+                 self.layer, Q_conv2d
+             )
+         else:
+             self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(
+                 self.layer.weight.data.dtype
+             )

      def free(self):
          self.H = None
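The quantize() call used for the dead columns is, in the reference GPTQ implementation this code follows, a plain round-to-nearest affine fake-quantizer. A sketch of those assumed semantics, shown only to make the dead-column handling concrete (the packaged helper may differ in detail):

    import torch

    def quantize(x, scale, zero, maxq):
        # assumed semantics (standard GPTQ helper): quantize, then immediately dequantize
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)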
tico/quantization/algorithm/gptq/quantizer.py
@@ -170,6 +170,7 @@ class GPTQQuantizer(BaseQuantizer):

          gptq_conf = self.config
          assert isinstance(gptq_conf, GPTQConfig)
+         gptq_conf.validate()
          # Disable use_cache during calibration
          if hasattr(model, "config") and hasattr(model.config, "use_cache"):
              orig_use_cache = model.config.use_cache
@@ -193,7 +194,15 @@
              )
          ):
              # 1) Identify quantizable submodules within the layer
-             full = find_layers(layer)
+             full = find_layers(
+                 layer,
+                 layers=[
+                     torch.nn.Linear,
+                     torch.nn.Conv2d,
+                     torch.nn.Conv1d,
+                     torch.nn.ConvTranspose2d,
+                 ],
+             )
              sequential = [list(full.keys())]

              # 2) Set up GPTQ objects and gather stats
@@ -204,7 +213,10 @@
              for name in subset:
                  gptq[name] = GPTQ(subset[name])
                  gptq[name].quantizer.configure(
-                     bits=8, perchannel=True, sym=False, mse=False
+                     bits=gptq_conf.weight_bits,
+                     perchannel=gptq_conf.perchannel,
+                     sym=gptq_conf.symmetric,
+                     mse=gptq_conf.mse,
                  )

              # Hook to collect (inp, out) for GPTQ
@@ -244,10 +256,10 @@
                  if gptq_conf.verbose:
                      print(f"[Layer {l_idx}] {name} -> Quantizing ...")
                  gptq[name].fasterquant(
-                     percdamp=0.01,
-                     groupsize=-1,
-                     actorder=True,
-                     static_groups=False,
+                     percdamp=gptq_conf.percdamp,
+                     groupsize=gptq_conf.groupsize,
+                     actorder=gptq_conf.actorder,
+                     static_groups=gptq_conf.static_groups,
                      verbose=gptq_conf.verbose,
                  )
                  quantizers[f"model.layers.{l_idx}.{name}"] = gptq[name].quantizer
tico/quantization/config/{pt2e.py → fpi_gptq.py}
@@ -12,14 +12,21 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from tico.quantization.config.base import BaseConfig
+ from dataclasses import dataclass

+ from tico.quantization.config.gptq import GPTQConfig

- class PT2EConfig(BaseConfig):
+
+ @dataclass
+ class FPIGPTQConfig(GPTQConfig):
      """
-     Configuration for pytorch 2.0 export quantization.
+     Configuration for FPIGPTQ (Fixed Point Iteration GPTQ).
      """

+     def __init__(self, verbose: bool = False, show_progress: bool = True):
+         self.verbose = verbose
+         self.show_progress = show_progress
+
      @property
      def name(self) -> str:
-         return "pt2e"
+         return "fpi_gptq"
tico/quantization/config/gptq.py
@@ -12,18 +12,41 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from dataclasses import dataclass
+
  from tico.quantization.config.base import BaseConfig


+ @dataclass
  class GPTQConfig(BaseConfig):
      """
-     Configuration for GPTQ.
+     Configuration for GPTQ weight quantization.
      """

-     def __init__(self, verbose: bool = False, show_progress: bool = True):
-         self.verbose = verbose
-         self.show_progress = show_progress
+     # general
+     verbose: bool = False
+     show_progress: bool = True
+
+     # quantizer.configure params (weight quantization spec)
+     weight_bits: int = 8
+     perchannel: bool = True
+     symmetric: bool = False
+     mse: bool = False
+
+     # GPTQ.fasterquant params (algorithm hyperparams)
+     percdamp: float = 0.01
+     groupsize: int = -1
+     actorder: bool = True
+     static_groups: bool = False

      @property
      def name(self) -> str:
          return "gptq"
+
+     def validate(self) -> None:
+         if self.weight_bits <= 0:
+             raise ValueError(f"weight_bits must be positive. got {self.weight_bits}")
+         if self.groupsize != -1 and self.groupsize <= 0:
+             raise ValueError(f"groupsize must be -1 or positive. got {self.groupsize}")
+         if not (0.0 < self.percdamp <= 1.0):
+             raise ValueError(f"percdamp must be in (0, 1]. got {self.percdamp}")
tico/quantization/public_interface.py
@@ -18,7 +18,6 @@ from typing import Any, Dict, Optional
  import torch

  from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer
- from tico.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
  from tico.quantization.config.base import BaseConfig
  from tico.quantization.quantizer import BaseQuantizer
  from tico.quantization.quantizer_registry import get_quantizer
@@ -55,11 +54,6 @@ def prepare(
          raise RuntimeError("prepare() already has been called.")
      quantizer = get_quantizer(quant_config)

-     if isinstance(quantizer, PT2EQuantizer) and inplace:
-         raise RuntimeError(
-             "In-place is not supported for PT2E quantization due to limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
-         )
-
      model = model if inplace else copy.deepcopy(model)

      model = quantizer.prepare(model, args, kwargs)
@@ -90,10 +84,6 @@ def convert(model, inplace: Optional[bool] = True):
      else:
          raise RuntimeError("Call prepare() function first.")

-     if isinstance(quantizer, PT2EQuantizer) and inplace:
-         raise RuntimeError(
-             "In-place is not supported for PT2E quantization due to limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
-         )
      # deepcopy prevents the quantizer from restoring the catcher used for calibration.
      # TODO Revisit `inplace` policy.
      if isinstance(quantizer, GPTQQuantizer) and not inplace:
tico/quantization/wrapq/quantizer.py
@@ -115,6 +115,7 @@ class PTQQuantizer(BaseQuantizer):
                  assert not self.strict_wrap
                  wrapped = self._wrap_supported(wrapped, child_cfg)
                  root[i] = wrapped  # type: ignore[index]
+             return root

          if isinstance(root, nn.ModuleDict):
              for k, child in list(root.items()):
@@ -128,6 +129,7 @@ class PTQQuantizer(BaseQuantizer):
                  assert not self.strict_wrap
                  wrapped = self._wrap_supported(wrapped, child_cfg)
                  root[k] = wrapped  # type: ignore[index]
+             return root

          # Case C: Leaf node
          root_name = getattr(root, "_get_name", lambda: None)()
tico/quantization/wrapq/wrappers/quant_elementwise.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Callable, Optional
+ from typing import Any, Optional

  import torch
  import torch.nn as nn
@@ -31,7 +31,7 @@ class QuantElementwise(QuantModuleBase):
      """

      # subclass must set this
-     FUNC: Callable[[torch.Tensor], torch.Tensor] | None = None
+     FUNC: Any = None

      def __init_subclass__(cls, **kwargs):
          super().__init_subclass__(**kwargs)
@@ -68,7 +68,7 @@


  """
- Why `FUNC` is a `staticmethod`
+ Q1) Why `FUNC` is a `staticmethod`

  - Prevents automatic binding: calling `self.FUNC(x)` will not inject `self`,
    so the callable keeps the expected signature `Tensor -> Tensor`
@@ -85,27 +85,67 @@ Why `FUNC` is a `staticmethod`
    than an `nn.Module` instance that would appear in the module tree.

  - Small perf/alloc win: no bound-method objects are created on each call.
+
+ Q2) Why we define small Python wrappers (_relu, _tanh, etc.)
+
+ - torch.relu / torch.tanh / torch.sigmoid are CPython built-ins.
+   Their type is `builtin_function_or_method`, not a Python `FunctionType`.
+   This causes `torch.export` (and FX tracing) to fail with:
+   "expected FunctionType, found builtin_function_or_method".
+
+ - By defining a thin Python wrapper (e.g., `def _tanh(x): return torch.tanh(x)`),
+   we turn the callable into a normal Python function object (`FunctionType`),
+   which satisfies export/tracing requirements.
+
+ - Functionally, this adds zero overhead and preserves semantics,
+   but makes the callable introspectable (it has __code__, __name__, etc.)
+   and compatible with TorchDynamo / FX graph capture.
+
+ - It also keeps FUNC pure and stateless, ensuring the elementwise op
+   is represented as `call_function(_tanh)` in the traced graph
+   rather than a bound `call_method` or a module attribute access.
  """

- # Sigmoid
+
+ def _relu(x: torch.Tensor) -> torch.Tensor:
+     return torch.relu(x)
+
+
+ def _tanh(x: torch.Tensor) -> torch.Tensor:
+     return torch.tanh(x)
+
+
+ def _sigmoid(x: torch.Tensor) -> torch.Tensor:
+     return torch.sigmoid(x)
+
+
+ def _gelu(x: torch.Tensor) -> torch.Tensor:
+     return torch.nn.functional.gelu(x)
+
+
  @register(nn.Sigmoid)
  class QuantSigmoid(QuantElementwise):
-     FUNC = staticmethod(torch.sigmoid)
+     @staticmethod
+     def FUNC(x: torch.Tensor) -> torch.Tensor:
+         return _sigmoid(x)


- # Tanh
  @register(nn.Tanh)
  class QuantTanh(QuantElementwise):
-     FUNC = staticmethod(torch.tanh)
+     @staticmethod
+     def FUNC(x: torch.Tensor) -> torch.Tensor:
+         return _tanh(x)


- # ReLU
  @register(nn.ReLU)
  class QuantReLU(QuantElementwise):
-     FUNC = staticmethod(torch.relu)
+     @staticmethod
+     def FUNC(x: torch.Tensor) -> torch.Tensor:
+         return _relu(x)


- # GELU (approximate)
  @register(nn.GELU)
  class QuantGELU(QuantElementwise):
-     FUNC = staticmethod(torch.nn.functional.gelu)
+     @staticmethod
+     def FUNC(x: torch.Tensor) -> torch.Tensor:
+         return _gelu(x)
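The Q2 note above can be checked directly; this sketch (not part of the package) shows the type difference that trips up torch.export and why the thin wrappers satisfy it:

    import inspect
    import types

    import torch

    def _tanh(x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(x)

    print(type(torch.tanh).__name__)              # builtin_function_or_method
    print(inspect.isfunction(torch.tanh))         # False: rejected where a FunctionType is expected
    print(isinstance(_tanh, types.FunctionType))  # True: safe to use as FUNC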
tico/serialize/operators/adapters/onert/llama_attention.py (new file)
@@ -0,0 +1,51 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List, TYPE_CHECKING
+
+ import torch
+
+ from transformers.cache_utils import DynamicCache
+ from transformers.models.llama.modeling_llama import LlamaAttention
+
+
+ def llama_attention_forward_adapter(
+     self: LlamaAttention,
+     hidden_states: torch.Tensor,
+     position_embeddings: List[torch.Tensor],
+     attention_mask: torch.Tensor,
+     past_key_value: DynamicCache,
+     cache_position: torch.Tensor,
+     **kwargs,
+ ):
+     # past_key_value is a dict-like cache holding key_cache and value_cache.
+     # It needs to be decomposed for tico and circle, which do not understand dicts.
+     key_cache = past_key_value.key_cache  # type: ignore[union-attr]
+     value_cache = past_key_value.value_cache  # type: ignore[union-attr]
+     return (
+         torch.ops.circle_custom.attention(
+             hidden_states,
+             self.q_proj.weight,
+             self.k_proj.weight,
+             self.v_proj.weight,
+             self.o_proj.weight,
+             position_embeddings[0],  # cos
+             position_embeddings[1],  # sin
+             attention_mask,
+             key_cache[self.layer_idx],
+             value_cache[self.layer_idx],  # likewise for the value cache
+             cache_position,
+         ),
+         None,
+     )
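How this adapter gets installed is not shown in the diff; presumably it is bound in place of LlamaAttention.forward before export. A hedged sketch of that kind of wiring (an assumption for illustration, not the package's actual code; the helper name install_attention_adapter is made up):

    from types import MethodType

    from transformers.models.llama.modeling_llama import LlamaAttention

    from tico.serialize.operators.adapters.onert.llama_attention import (
        llama_attention_forward_adapter,
    )

    def install_attention_adapter(model):
        # assumed wiring: rebind forward on every LlamaAttention so export sees circle_custom.attention
        for module in model.modules():
            if isinstance(module, LlamaAttention):
                module.forward = MethodType(llama_attention_forward_adapter, module)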
tico/serialize/operators/op_attention.py (new file)
@@ -0,0 +1,58 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch._ops
+     import torch.fx
+ import torch
+ from circle_schema import circle
+
+ from tico.serialize.circle_graph import CircleSubgraph
+ from tico.serialize.operators.hashable_opcode import OpCode
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+ from tico.utils.validate_args_kwargs import CircleAttentionArgs
+
+
+ @register_node_visitor
+ class AttentionVisitor(NodeVisitor):
+     target: List[torch._ops.OpOverload] = [
+         torch.ops.circle_custom.attention.default,
+     ]
+
+     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+         super().__init__(op_codes, graph)
+
+     def define_node(
+         self,
+         node: torch.fx.Node,
+     ) -> circle.Operator.OperatorT:
+         args = CircleAttentionArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+         op_index = get_op_index(
+             circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
+         )
+
+         inputs = node.args
+         outputs = [node]
+         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+         # Op-specific option
+         operator.builtinOptionsType = (
+             circle.BuiltinOptions.BuiltinOptions.AttentionOptions
+         )
+         operator.builtinOptions = circle.AttentionOptions.AttentionOptionsT()
+
+         return operator