ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic.

Files changed (91)
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
ai_edge_torch/quantize/pt2e_quantizer_utils.py
@@ -0,0 +1,1041 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from dataclasses import dataclass
17
+ import itertools
18
+ import operator
19
+ from typing import Callable, Dict, List, NamedTuple, Optional
20
+
21
+ import torch
22
+ from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
23
+ from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
24
+ from torch.ao.quantization.pt2e.utils import _conv1d_bn_example_inputs
25
+ from torch.ao.quantization.pt2e.utils import _conv2d_bn_example_inputs
26
+ from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern
27
+ from torch.ao.quantization.quantizer import QuantizationAnnotation
28
+ from torch.ao.quantization.quantizer import QuantizationSpec
29
+ from torch.ao.quantization.quantizer import QuantizationSpecBase
30
+ from torch.ao.quantization.quantizer import SharedQuantizationSpec
31
+ from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map
32
+ from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
33
+ from torch.fx import Node
34
+ from torch.fx.passes.utils.matcher_with_name_node_map_utils import SubgraphMatcherWithNameNodeMap # NOQA
35
+ from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
36
+ import torch.nn.functional as F
37
+
38
+ __all__ = [
39
+ "OperatorConfig",
40
+ "OperatorPatternType",
41
+ "QuantizationConfig",
42
+ "get_input_act_qspec",
43
+ "get_output_act_qspec",
44
+ "get_weight_qspec",
45
+ "get_bias_qspec",
46
+ "OP_TO_ANNOTATOR",
47
+ "propagate_annotation",
48
+ ]
49
+
50
+
51
+ @dataclass(eq=True, frozen=True)
52
+ class QuantizationConfig:
53
+ input_activation: Optional[QuantizationSpec]
54
+ output_activation: Optional[QuantizationSpec]
55
+ weight: Optional[QuantizationSpec]
56
+ bias: Optional[QuantizationSpec]
57
+ fixed_qparams: Optional[QuantizationSpec]
58
+ # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
59
+ is_qat: bool = False
60
+ is_dynamic: bool = False
61
+
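For orientation, a minimal sketch of filling in this config. It is not part of the released file, and the observer constructors and value ranges below are assumptions rather than the defaults this package ships.

# Illustrative only, not part of the released file.
import torch
from torch.ao.quantization.observer import (
    HistogramObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.quantizer import QuantizationSpec

act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)
weight_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
)
bias_qspec = QuantizationSpec(
    dtype=torch.float,
    observer_or_fake_quant_ctr=PlaceholderObserver,
)
static_int8_config = QuantizationConfig(
    input_activation=act_qspec,
    output_activation=act_qspec,
    weight=weight_qspec,
    bias=bias_qspec,
    fixed_qparams=None,
)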
62
+
63
+ OperatorPatternType = List[Callable]
64
+ OperatorPatternType.__module__ = "ai_edge_torch.quantize.pt2e_quantizer_utils"
65
+
66
+ AnnotatorType = Callable[
67
+ [
68
+ torch.fx.GraphModule,
69
+ Optional[QuantizationConfig],
70
+ Optional[Callable[[Node], bool]],
71
+ ],
72
+ Optional[List[List[Node]]],
73
+ ]
74
+ OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {}
75
+
76
+
77
+ def register_annotator(op: str):
78
+ def decorator(annotator: AnnotatorType):
79
+ OP_TO_ANNOTATOR[op] = annotator
80
+
81
+ return decorator
82
+
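A minimal sketch, assuming a hypothetical helper name, of how this registry is consumed once the annotators below have registered themselves:

# Hypothetical helper (not part of the released file): dispatch to a
# registered annotator by operator name.
def _annotate_by_name(
    gm: torch.fx.GraphModule,
    op_name: str,
    quantization_config: Optional[QuantizationConfig],
) -> Optional[List[List[Node]]]:
    annotate_fn = OP_TO_ANNOTATOR[op_name]  # e.g. "linear", "conv", "add"
    return annotate_fn(gm, quantization_config, None)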
83
+
84
+ class OperatorConfig(NamedTuple):
85
+ # TODO: replace List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
86
+ # Basically we are mapping a quantization config to some list of patterns.
87
+ # A pattern is defined as a list of nn.Module, function, or builtin function
88
+ # names, e.g. [nn.Conv2d, torch.relu, torch.add].
89
+ # We have not yet decided whether fusion is an internal detail of the
90
+ # quantizer that does not need to be communicated to the user.
91
+ # Note this pattern is not very informative since it does not describe the
92
+ # graph structure that results from the list of ops.
93
+ config: QuantizationConfig
94
+ operators: List[OperatorPatternType]
95
+
96
+
97
+ def _is_annotated(nodes: List[Node]):
98
+ """
99
+ Given a list of nodes (that represents an operator pattern),
100
+ return True if any of the nodes is annotated,
101
+ otherwise return False.
102
+ """
103
+ annotated = False
104
+ for node in nodes:
105
+ annotated = annotated or (
106
+ "quantization_annotation" in node.meta
107
+ and node.meta["quantization_annotation"]._annotated
108
+ )
109
+ return annotated
110
+
111
+
112
+ def _mark_nodes_as_annotated(nodes: List[Node]):
113
+ for node in nodes:
114
+ if node is not None:
115
+ if "quantization_annotation" not in node.meta:
116
+ node.meta["quantization_annotation"] = QuantizationAnnotation()
117
+ node.meta["quantization_annotation"]._annotated = True
118
+
119
+
120
+ def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
121
+ if quantization_config is None:
122
+ return None
123
+ if quantization_config.input_activation is None:
124
+ return None
125
+ quantization_spec: QuantizationSpec = quantization_config.input_activation
126
+ assert quantization_spec.qscheme in [
127
+ torch.per_tensor_affine,
128
+ torch.per_tensor_symmetric,
129
+ ]
130
+ return quantization_spec
131
+
132
+
133
+ def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
134
+ if quantization_config is None:
135
+ return None
136
+ if quantization_config.output_activation is None:
137
+ return None
138
+ quantization_spec: QuantizationSpec = quantization_config.output_activation
139
+ assert quantization_spec.qscheme in [
140
+ torch.per_tensor_affine,
141
+ torch.per_tensor_symmetric,
142
+ ]
143
+ return quantization_spec
144
+
145
+
146
+ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
147
+ if quantization_config is None:
148
+ return None
149
+ assert quantization_config is not None
150
+ if quantization_config.weight is None:
151
+ return None
152
+ quantization_spec: QuantizationSpec = quantization_config.weight
153
+ if quantization_spec.qscheme not in [
154
+ torch.per_tensor_symmetric,
155
+ torch.per_channel_symmetric,
156
+ ]:
157
+ raise ValueError(f"Unsupported quantization_spec {quantization_spec} for weight")
158
+ return quantization_spec
159
+
160
+
161
+ def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
162
+ if quantization_config is None:
163
+ return None
164
+ assert quantization_config is not None
165
+ if quantization_config.bias is None:
166
+ return None
167
+ quantization_spec: QuantizationSpec = quantization_config.bias
168
+ assert (
169
+ quantization_spec.dtype == torch.float
170
+ ), "Only float dtype for bias is supported for bias right now"
171
+ return quantization_spec
172
+
173
+
174
+ def get_fixed_qparams_qspec(quantization_config: Optional[QuantizationConfig]):
175
+ if quantization_config is None:
176
+ return None
177
+ assert quantization_config is not None
178
+ if quantization_config.fixed_qparams is None:
179
+ return None
180
+ quantization_spec: QuantizationSpec = quantization_config.fixed_qparams
181
+ return quantization_spec
182
+
183
+
184
+ @register_annotator("linear")
185
+ def _annotate_linear(
186
+ gm: torch.fx.GraphModule,
187
+ quantization_config: Optional[QuantizationConfig],
188
+ filter_fn: Optional[Callable[[Node], bool]] = None,
189
+ ) -> Optional[List[List[Node]]]:
190
+ annotated_partitions = []
191
+ input_act_qspec = get_input_act_qspec(quantization_config)
192
+ output_act_qspec = get_output_act_qspec(quantization_config)
193
+ weight_qspec = get_weight_qspec(quantization_config)
194
+ bias_qspec = get_bias_qspec(quantization_config)
195
+ for node in gm.graph.nodes:
196
+ if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
197
+ continue
198
+ if filter_fn and not filter_fn(node):
199
+ continue
200
+ act_node = node.args[0]
201
+ weight_node = node.args[1]
202
+ bias_node = None
203
+ if len(node.args) > 2:
204
+ bias_node = node.args[2]
205
+
206
+ if _is_annotated([node]) is False: # type: ignore[list-item]
207
+ _annotate_input_qspec_map(
208
+ node,
209
+ act_node,
210
+ input_act_qspec,
211
+ )
212
+ _annotate_input_qspec_map(
213
+ node,
214
+ weight_node,
215
+ weight_qspec,
216
+ )
217
+ nodes_to_mark_annotated = [node, weight_node]
218
+ if bias_node:
219
+ _annotate_input_qspec_map(
220
+ node,
221
+ bias_node,
222
+ bias_qspec,
223
+ )
224
+ nodes_to_mark_annotated.append(bias_node)
225
+ _annotate_output_qspec(node, output_act_qspec)
226
+ _mark_nodes_as_annotated(nodes_to_mark_annotated)
227
+ annotated_partitions.append(nodes_to_mark_annotated)
228
+
229
+ return annotated_partitions
230
+
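An illustrative usage sketch; the module and the capture call are assumptions (a pre-autograd capture is needed so that aten.linear.default survives in the graph):

# Illustrative only, not part of the released file.
class _TinyLinear(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Assuming gm is a pre-autograd captured graph of _TinyLinear, e.g.:
#   gm = torch.export.export(_TinyLinear(), (torch.randn(1, 4),)).module()
# the annotator is reached through the registry (register_annotator stores the
# decorated function there):
#   OP_TO_ANNOTATOR["linear"](gm, static_int8_config)  # config from the earlier sketch
# afterwards each matched node carries node.meta["quantization_annotation"].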
231
+
232
+ @register_annotator("addmm")
233
+ def _annotate_addmm(
234
+ gm: torch.fx.GraphModule,
235
+ quantization_config: Optional[QuantizationConfig],
236
+ filter_fn: Optional[Callable[[Node], bool]] = None,
237
+ ) -> Optional[List[List[Node]]]:
238
+ annotated_partitions = []
239
+ for n in gm.graph.nodes:
240
+ if n.op != "call_function" or n.target not in [
241
+ torch.ops.aten.addmm.default,
242
+ ]:
243
+ continue
244
+ addm_node = n
245
+
246
+ input_qspec_map = {}
247
+ input_act = addm_node.args[0]
248
+ assert isinstance(input_act, Node)
249
+ is_bias = (
250
+ len(list(input_act.meta["val"].size())) < 2
251
+ and input_act.op == "get_attr"
252
+ and "_param_constant" in input_act.target
253
+ )
254
+ input_qspec_map[input_act] = (
255
+ get_bias_qspec(quantization_config)
256
+ if is_bias
257
+ else get_input_act_qspec(quantization_config)
258
+ )
259
+
260
+ mat1_act = addm_node.args[1]
261
+ assert isinstance(mat1_act, Node)
262
+ input_qspec_map[mat1_act] = get_input_act_qspec(quantization_config)
263
+
264
+ mat2_act = addm_node.args[2]
265
+ assert isinstance(mat2_act, Node)
266
+ is_weight = False
267
+ if mat2_act.op == "get_attr" and "_param_constant" in mat2_act.target:
268
+ is_weight = True
269
+ elif mat2_act.target == torch.ops.aten.t.default:
270
+ t_in = mat2_act.args[0]
271
+ if t_in.op == "get_attr" and "_param_constant" in t_in.target:
272
+ is_weight = True
273
+ input_qspec_map[mat2_act] = (
274
+ get_weight_qspec(quantization_config)
275
+ if is_weight
276
+ else get_input_act_qspec(quantization_config)
277
+ )
278
+
279
+ partition = [addm_node, addm_node.args[1], addm_node.args[2]]
280
+
281
+ if _is_annotated(partition):
282
+ continue
283
+
284
+ if filter_fn and any(not filter_fn(n) for n in partition):
285
+ continue
286
+
287
+ addm_node.meta["quantization_annotation"] = QuantizationAnnotation(
288
+ input_qspec_map=input_qspec_map,
289
+ output_qspec=get_output_act_qspec(quantization_config),
290
+ _annotated=True,
291
+ )
292
+ _mark_nodes_as_annotated(partition)
293
+ annotated_partitions.append(partition)
294
+ return annotated_partitions
295
+
296
+
297
+ @register_annotator("conv")
298
+ def _annotate_conv(
299
+ gm: torch.fx.GraphModule,
300
+ quantization_config: Optional[QuantizationConfig],
301
+ filter_fn: Optional[Callable[[Node], bool]] = None,
302
+ ) -> Optional[List[List[Node]]]:
303
+ annotated_partitions = []
304
+ for n in gm.graph.nodes:
305
+ if n.op != "call_function" or n.target not in [
306
+ torch.ops.aten.conv1d.default,
307
+ torch.ops.aten.conv2d.default,
308
+ torch.ops.aten.convolution.default,
309
+ ]:
310
+ continue
311
+ conv_node = n
312
+
313
+ input_qspec_map = {}
314
+ input_act = conv_node.args[0]
315
+ assert isinstance(input_act, Node)
316
+ input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
317
+
318
+ weight = conv_node.args[1]
319
+ assert isinstance(weight, Node)
320
+ input_qspec_map[weight] = get_weight_qspec(quantization_config)
321
+
322
+ # adding weight node to the partition as well
323
+ partition = [conv_node, conv_node.args[1]]
324
+
325
+ bias = conv_node.args[2] if len(conv_node.args) > 2 else None
326
+ if isinstance(bias, Node):
327
+ input_qspec_map[bias] = get_bias_qspec(quantization_config)
328
+ partition.append(bias)
329
+
330
+ if _is_annotated(partition):
331
+ continue
332
+
333
+ if filter_fn and any(not filter_fn(n) for n in partition):
334
+ continue
335
+
336
+ conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
337
+ input_qspec_map=input_qspec_map,
338
+ output_qspec=get_output_act_qspec(quantization_config),
339
+ _annotated=True,
340
+ )
341
+ _mark_nodes_as_annotated(partition)
342
+ annotated_partitions.append(partition)
343
+ return annotated_partitions
344
+
345
+
346
+ @register_annotator("conv_relu")
347
+ def _annotate_conv_relu(
348
+ gm: torch.fx.GraphModule,
349
+ quantization_config: Optional[QuantizationConfig],
350
+ filter_fn: Optional[Callable[[Node], bool]] = None,
351
+ ) -> Optional[List[List[Node]]]:
352
+ annotated_partitions = []
353
+ for n in gm.graph.nodes:
354
+ if n.op != "call_function" or n.target not in [
355
+ torch.ops.aten.relu.default,
356
+ torch.ops.aten.relu_.default,
357
+ ]:
358
+ continue
359
+ relu_node = n
360
+ maybe_conv_node = n.args[0]
361
+ if (
362
+ not isinstance(maybe_conv_node, Node)
363
+ or maybe_conv_node.op != "call_function"
364
+ or maybe_conv_node.target
365
+ not in [
366
+ torch.ops.aten.conv1d.default,
367
+ torch.ops.aten.conv2d.default,
368
+ torch.ops.aten.convolution.default,
369
+ ]
370
+ ):
371
+ continue
372
+ conv_node = maybe_conv_node
373
+
374
+ input_qspec_map = {}
375
+ input_act = conv_node.args[0]
376
+ assert isinstance(input_act, Node)
377
+ input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
378
+
379
+ weight = conv_node.args[1]
380
+ assert isinstance(weight, Node)
381
+ input_qspec_map[weight] = get_weight_qspec(quantization_config)
382
+
383
+ # adding weight node to the partition as well
384
+ partition = [relu_node, conv_node, conv_node.args[1]]
385
+ bias = conv_node.args[2] if len(conv_node.args) > 2 else None
386
+ if isinstance(bias, Node):
387
+ input_qspec_map[bias] = get_bias_qspec(quantization_config)
388
+ partition.append(bias)
389
+
390
+ if _is_annotated(partition):
391
+ continue
392
+
393
+ if filter_fn and any(not filter_fn(n) for n in partition):
394
+ continue
395
+
396
+ conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
397
+ input_qspec_map=input_qspec_map, _annotated=True
398
+ )
399
+ relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
400
+ output_qspec=get_output_act_qspec(
401
+ quantization_config
402
+ ), # type: ignore[arg-type]
403
+ _annotated=True,
404
+ )
405
+ _mark_nodes_as_annotated(partition)
406
+ annotated_partitions.append(partition)
407
+ return annotated_partitions
408
+
409
+
410
+ @register_annotator("conv_bn")
411
+ def _annotate_conv_bn(
412
+ gm: torch.fx.GraphModule,
413
+ quantization_config: Optional[QuantizationConfig],
414
+ filter_fn: Optional[Callable[[Node], bool]] = None,
415
+ ) -> Optional[List[List[Node]]]:
416
+ """
417
+ Find conv + batchnorm partitions
418
+ Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
419
+ """
420
+ return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False)
421
+
422
+
423
+ @register_annotator("conv_bn_relu")
424
+ def _annotate_conv_bn_relu(
425
+ gm: torch.fx.GraphModule,
426
+ quantization_config: Optional[QuantizationConfig],
427
+ filter_fn: Optional[Callable[[Node], bool]] = None,
428
+ ) -> Optional[List[List[Node]]]:
429
+ """
430
+ Find conv + batchnorm + relu partitions
431
+ Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
432
+ """
433
+ return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)
434
+
435
+
436
+ def _do_annotate_conv_bn(
437
+ gm: torch.fx.GraphModule,
438
+ quantization_config: Optional[QuantizationConfig],
439
+ filter_fn: Optional[Callable[[Node], bool]],
440
+ has_relu: bool,
441
+ ) -> List[List[Node]]:
442
+ """
443
+ Match conv-bn[-relu] patterns built from the `conv_fn` variants below and
444
+ return a list of annotated partitions.
445
+
446
+ The output of the pattern must include a dictionary from string name to node
447
+ for the following names: "input", "conv", "weight", "bias", and "output".
448
+ """
449
+
450
+ def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
451
+ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
452
+ conv = conv_fn(x, conv_weight, conv_bias)
453
+ bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
454
+ if has_relu:
455
+ output = F.relu_(bn) if relu_is_inplace else F.relu(bn)
456
+ else:
457
+ output = bn
458
+ return output, {
459
+ "input": x,
460
+ "conv": conv,
461
+ "weight": conv_weight,
462
+ "bias": conv_bias,
463
+ "output": output,
464
+ }
465
+
466
+ return _conv_bn
467
+
468
+ # Needed for matching, otherwise the matches get filtered out due to unused
469
+ # nodes returned by batch norm
470
+ gm.graph.eliminate_dead_code()
471
+ gm.recompile()
472
+
473
+ matches = []
474
+ combinations = [
475
+ (F.conv1d, _conv1d_bn_example_inputs),
476
+ (F.conv2d, _conv2d_bn_example_inputs),
477
+ ]
478
+
479
+ # Add `is_cuda` and `relu_is_inplace` dimensions
480
+ combinations = itertools.product(
481
+ combinations,
482
+ [True, False] if torch.cuda.is_available() else [False], # is_cuda
483
+ [True, False] if has_relu else [False], # relu_is_inplace
484
+ )
485
+
486
+ # Match against all conv dimensions and cuda variants
487
+ for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
488
+ pattern = get_pattern(conv_fn, relu_is_inplace)
489
+ pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)
490
+ pattern.graph.eliminate_dead_code()
491
+ pattern.recompile()
492
+ matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
493
+ matches.extend(matcher.match(gm.graph))
494
+
495
+ # Annotate nodes returned in the matches
496
+ annotated_partitions = []
497
+ for match in matches:
498
+ name_node_map = match.name_node_map
499
+ input_node = name_node_map["input"]
500
+ conv_node = name_node_map["conv"]
501
+ weight_node = name_node_map["weight"]
502
+ bias_node = name_node_map["bias"]
503
+ output_node = name_node_map["output"]
504
+
505
+ # TODO: annotate the uses of input, weight, and bias separately instead
506
+ # of assuming they come from a single conv node. This is not possible today
507
+ # because input may have multiple users, and we can't rely on the conv node
508
+ # always being the first user. This was the case in models with skip
509
+ # connections like resnet18
510
+
511
+ # Validate conv args
512
+ if conv_node.args[0] is not input_node:
513
+ raise ValueError("Conv arg did not contain input node ", input_node)
514
+ if conv_node.args[1] is not weight_node:
515
+ raise ValueError("Conv arg did not contain weight node ", weight_node)
516
+ if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node:
517
+ raise ValueError("Conv arg did not contain bias node ", bias_node)
518
+
519
+ # Skip if the partition is already annotated or is filtered out by the user
520
+ partition = [conv_node, weight_node]
521
+ if bias_node is not None:
522
+ partition.append(bias_node)
523
+ if _is_annotated(partition):
524
+ continue
525
+ if filter_fn and any(not filter_fn(n) for n in partition):
526
+ continue
527
+
528
+ # Annotate conv inputs and pattern output
529
+ input_qspec_map = {}
530
+ input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
531
+ input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
532
+ if bias_node is not None:
533
+ input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
534
+ conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
535
+ input_qspec_map=input_qspec_map,
536
+ _annotated=True,
537
+ )
538
+ output_node.meta["quantization_annotation"] = QuantizationAnnotation(
539
+ output_qspec=get_output_act_qspec(
540
+ quantization_config
541
+ ), # type: ignore[arg-type]
542
+ _annotated=True,
543
+ )
544
+ _mark_nodes_as_annotated(partition)
545
+ annotated_partitions.append(partition)
546
+ return annotated_partitions
547
+
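For reference, a sketch of the kind of unfused QAT subgraph the matcher above is meant to find; the module and capture are illustrative:

# Illustrative only: conv + batch_norm (+ relu) before QAT fusion.
class _ConvBNReLU(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return torch.relu(self.bn(self.conv(x)))

# With gm captured from _ConvBNReLU in training mode (so batch_norm stays
# unfused in the graph), the registered annotator can be applied as:
#   OP_TO_ANNOTATOR["conv_bn_relu"](gm, static_int8_config)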
548
+
549
+ @register_annotator("gru_io_only")
550
+ def _annotate_gru_io_only(
551
+ gm: torch.fx.GraphModule,
552
+ quantization_config: Optional[QuantizationConfig],
553
+ filter_fn: Optional[Callable[[Node], bool]] = None,
554
+ ) -> Optional[List[List[Node]]]:
555
+ gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn)
556
+ gru_partitions = list(itertools.chain(*gru_partitions.values()))
557
+ annotated_partitions = []
558
+ for gru_partition in gru_partitions:
559
+ annotated_partitions.append(gru_partition.nodes)
560
+ output_nodes = gru_partition.output_nodes
561
+ input_nodes = gru_partition.input_nodes
562
+ # skip annotation if it is already annotated
563
+ if _is_annotated(input_nodes + output_nodes):
564
+ continue
565
+ # inside each GRU partition, we should be able to annotate each linear
566
+ # subgraph
567
+ input_qspec_map: Dict[Node, QuantizationSpecBase] = {}
568
+ input_act = input_nodes[0]
569
+ input_act_user = next(iter(input_act.users.keys()))
570
+ assert isinstance(input_act, Node)
571
+ assert isinstance(input_act_user, Node)
572
+ input_act_user.meta["quantization_annotation"] = QuantizationAnnotation(
573
+ input_qspec_map={
574
+ input_act: get_input_act_qspec(quantization_config),
575
+ },
576
+ _annotated=True,
577
+ )
578
+
579
+ hidden_state = input_nodes[1]
580
+ hidden_state_user = next(iter(hidden_state.users.keys()))
581
+ assert isinstance(hidden_state, Node)
582
+ assert isinstance(hidden_state_user, Node)
583
+ hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation(
584
+ input_qspec_map={
585
+ hidden_state: get_input_act_qspec(quantization_config),
586
+ },
587
+ _annotated=True,
588
+ )
589
+
590
+ assert len(output_nodes) == 2, "expecting GRU to have two outputs"
591
+ for output in output_nodes:
592
+ output.meta["quantization_annotation"] = QuantizationAnnotation(
593
+ output_qspec=get_output_act_qspec(quantization_config),
594
+ _annotated=True,
595
+ )
596
+ nodes_to_mark_annotated = list(gru_partition.nodes)
597
+ _mark_nodes_as_annotated(nodes_to_mark_annotated)
598
+ return annotated_partitions
599
+
600
+
601
+ @register_annotator("max_pool2d")
602
+ def _annotate_max_pool2d(
603
+ gm: torch.fx.GraphModule,
604
+ quantization_config: Optional[QuantizationConfig],
605
+ filter_fn: Optional[Callable[[Node], bool]] = None,
606
+ ) -> Optional[List[List[Node]]]:
607
+ module_partitions = get_source_partitions(
608
+ gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d], filter_fn
609
+ )
610
+ maxpool_partitions = list(itertools.chain(*module_partitions.values()))
611
+ annotated_partitions = []
612
+ for maxpool_partition in maxpool_partitions:
613
+ annotated_partitions.append(maxpool_partition.nodes)
614
+ output_node = maxpool_partition.output_nodes[0]
615
+ maxpool_node = None
616
+ for n in maxpool_partition.nodes:
617
+ if (
618
+ n.target == torch.ops.aten.max_pool2d.default
619
+ or n.target == torch.ops.aten.max_pool2d_with_indices.default
620
+ ):
621
+ maxpool_node = n
622
+
623
+ assert (
624
+ maxpool_node is not None
625
+ ), "PT2EQuantizer only works with torch.ops.aten.max_pool2d.default, "
626
+ "please make sure you are exporting the model correctly"
627
+ if _is_annotated([output_node, maxpool_node]): # type: ignore[list-item]
628
+ continue
629
+
630
+ input_act = maxpool_node.args[0] # type: ignore[union-attr]
631
+ assert isinstance(input_act, Node)
632
+
633
+ # only annotate maxpool when the output of the input node is annotated
634
+ if (
635
+ "quantization_annotation" not in input_act.meta
636
+ or not input_act.meta["quantization_annotation"]._annotated
637
+ or input_act.meta["quantization_annotation"].output_qspec is None
638
+ ):
639
+ continue
640
+ # the input and output of maxpool share quantization parameters with maxpool's input
641
+ act_qspec = SharedQuantizationSpec(input_act)
642
+ # act_qspec = get_act_qspec(quantization_config)
643
+ maxpool_node.meta[
644
+ # type: ignore[union-attr]
645
+ "quantization_annotation"
646
+ ] = QuantizationAnnotation(
647
+ input_qspec_map={
648
+ input_act: act_qspec,
649
+ },
650
+ _annotated=True,
651
+ )
652
+ output_node.meta["quantization_annotation"] = QuantizationAnnotation(
653
+ output_qspec=act_qspec,
654
+ _annotated=True,
655
+ )
656
+ return annotated_partitions
657
+
658
+
659
+ @register_annotator("adaptive_avg_pool2d")
660
+ def _annotate_adaptive_avg_pool2d(
661
+ gm: torch.fx.GraphModule,
662
+ quantization_config: Optional[QuantizationConfig],
663
+ filter_fn: Optional[Callable[[Node], bool]] = None,
664
+ ) -> Optional[List[List[Node]]]:
665
+ """Always annotate adaptive_avg_pool2d op"""
666
+ module_partitions = get_source_partitions(
667
+ gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
668
+ )
669
+ partitions = list(itertools.chain(*module_partitions.values()))
670
+ annotated_partitions = []
671
+ for partition in partitions:
672
+ pool_node = partition.output_nodes[0]
673
+ if pool_node.op != "call_function" or (
674
+ pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default
675
+ and pool_node.target != torch.ops.aten._adaptive_avg_pool2d.default
676
+ and pool_node.target != torch.ops.aten.mean.dim
677
+ and pool_node.target != torch.ops.aten.as_strided_.default
678
+ ):
679
+ raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")
680
+
681
+ if _is_annotated([pool_node]):
682
+ continue
683
+
684
+ annotated_partitions.append(partition.nodes)
685
+ input_act = pool_node.args[0]
686
+ assert isinstance(input_act, Node)
687
+
688
+ # only annotate input output sharing operator
689
+ # when the output of the input node is annotated
690
+ if (
691
+ "quantization_annotation" not in input_act.meta
692
+ or not input_act.meta["quantization_annotation"]._annotated
693
+ or input_act.meta["quantization_annotation"].output_qspec is None
694
+ ):
695
+ input_act_qspec = get_input_act_qspec(quantization_config)
696
+ else:
697
+ input_act_qspec = SharedQuantizationSpec(input_act)
698
+
699
+ # output sharing with input
700
+ output_act_qspec = SharedQuantizationSpec((input_act, pool_node))
701
+ pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
702
+ input_qspec_map={
703
+ input_act: input_act_qspec,
704
+ },
705
+ output_qspec=output_act_qspec,
706
+ _annotated=True,
707
+ )
708
+ return annotated_partitions
709
+
710
+
711
+ @register_annotator("fixed_qparams")
712
+ def _annotate_fixed_qparams(
713
+ gm: torch.fx.GraphModule,
714
+ quantization_config: Optional[QuantizationConfig],
715
+ filter_fn: Optional[Callable[[Node], bool]] = None,
716
+ ) -> Optional[List[List[Node]]]:
717
+ annotated_partitions = []
718
+ for node in gm.graph.nodes:
719
+ if node.op != "call_function" or (
720
+ node.target != torch.ops.aten.sigmoid.default
721
+ and node.target != torch.ops.aten._softmax.default
722
+ ):
723
+ continue
724
+
725
+ input_act = node.args[0] # type: ignore[union-attr]
726
+ assert isinstance(input_act, Node)
727
+
728
+ # only annotate when the output of the input node is annotated
729
+ if (
730
+ "quantization_annotation" not in input_act.meta
731
+ or not input_act.meta["quantization_annotation"]._annotated
732
+ or input_act.meta["quantization_annotation"].output_qspec is None
733
+ ):
734
+ continue
735
+ partition = [node]
736
+
737
+ if _is_annotated(partition):
738
+ continue
739
+
740
+ if filter_fn and any(not filter_fn(n) for n in partition):
741
+ continue
742
+
743
+ node.meta["quantization_annotation"] = QuantizationAnnotation(
744
+ output_qspec=get_fixed_qparams_qspec(quantization_config), _annotated=True
745
+ )
746
+ _mark_nodes_as_annotated(partition)
747
+ annotated_partitions.append(partition)
748
+
749
+ return annotated_partitions
750
+
751
+
752
+ @register_annotator("add_relu")
753
+ def _annotate_add_relu(
754
+ gm: torch.fx.GraphModule,
755
+ quantization_config: Optional[QuantizationConfig],
756
+ filter_fn: Optional[Callable[[Node], bool]] = None,
757
+ ) -> Optional[List[List[Node]]]:
758
+ fused_partitions = find_sequential_partitions(
759
+ gm, [torch.add, torch.nn.ReLU], filter_fn
760
+ )
761
+ annotated_partitions = []
762
+ for fused_partition in fused_partitions:
763
+ add_partition, relu_partition = fused_partition
764
+ annotated_partitions.append(add_partition.nodes + relu_partition.nodes)
765
+ if len(relu_partition.output_nodes) > 1:
766
+ raise ValueError("Relu partition has more than one output node")
767
+ relu_node = relu_partition.output_nodes[0]
768
+ if len(add_partition.output_nodes) > 1:
769
+ raise ValueError("add partition has more than one output node")
770
+ add_node = add_partition.output_nodes[0]
771
+
772
+ if _is_annotated([relu_node, add_node]):
773
+ continue
774
+
775
+ input_act_qspec = get_input_act_qspec(quantization_config)
776
+ output_act_qspec = get_output_act_qspec(quantization_config)
777
+
778
+ input_qspec_map = {}
779
+ input_act0 = add_node.args[0]
780
+ if isinstance(input_act0, Node):
781
+ input_qspec_map[input_act0] = input_act_qspec
782
+
783
+ input_act1 = add_node.args[1]
784
+ if isinstance(input_act1, Node):
785
+ input_qspec_map[input_act1] = input_act_qspec
786
+
787
+ add_node.meta["quantization_annotation"] = QuantizationAnnotation(
788
+ input_qspec_map=input_qspec_map,
789
+ _annotated=True,
790
+ )
791
+ relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
792
+ output_qspec=output_act_qspec,
793
+ _annotated=True,
794
+ )
795
+ return annotated_partitions
796
+
797
+
798
+ @register_annotator("add")
799
+ def _annotate_add(
800
+ gm: torch.fx.GraphModule,
801
+ quantization_config: Optional[QuantizationConfig],
802
+ filter_fn: Optional[Callable[[Node], bool]] = None,
803
+ ) -> Optional[List[List[Node]]]:
804
+ add_partitions = get_source_partitions(
805
+ gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
806
+ )
807
+ add_partitions = list(itertools.chain(*add_partitions.values()))
808
+ annotated_partitions = []
809
+ for add_partition in add_partitions:
810
+ annotated_partitions.append(add_partition.nodes)
811
+ add_node = add_partition.output_nodes[0]
812
+ if _is_annotated([add_node]):
813
+ continue
814
+
815
+ input_act_qspec = get_input_act_qspec(quantization_config)
816
+ output_act_qspec = get_output_act_qspec(quantization_config)
817
+
818
+ input_qspec_map = {}
819
+ input_act0 = add_node.args[0]
820
+ if isinstance(input_act0, Node):
821
+ input_qspec_map[input_act0] = input_act_qspec
822
+
823
+ input_act1 = add_node.args[1]
824
+ if isinstance(input_act1, Node):
825
+ input_qspec_map[input_act1] = input_act_qspec
826
+
827
+ add_node.meta["quantization_annotation"] = QuantizationAnnotation(
828
+ input_qspec_map=input_qspec_map,
829
+ output_qspec=output_act_qspec,
830
+ _annotated=True,
831
+ )
832
+ return annotated_partitions
833
+
834
+
835
+ @register_annotator("mul_relu")
836
+ def _annotate_mul_relu(
837
+ gm: torch.fx.GraphModule,
838
+ quantization_config: Optional[QuantizationConfig],
839
+ filter_fn: Optional[Callable[[Node], bool]] = None,
840
+ ) -> Optional[List[List[Node]]]:
841
+ fused_partitions = find_sequential_partitions(
842
+ gm, [torch.mul, torch.nn.ReLU], filter_fn
843
+ )
844
+ annotated_partitions = []
845
+ for fused_partition in fused_partitions:
846
+ mul_partition, relu_partition = fused_partition
847
+ annotated_partitions.append(mul_partition.nodes + relu_partition.nodes)
848
+ if len(relu_partition.output_nodes) > 1:
849
+ raise ValueError("Relu partition has more than one output node")
850
+ relu_node = relu_partition.output_nodes[0]
851
+ if len(mul_partition.output_nodes) > 1:
852
+ raise ValueError("mul partition has more than one output node")
853
+ mul_node = mul_partition.output_nodes[0]
854
+
855
+ if _is_annotated([relu_node, mul_node]):
856
+ continue
857
+
858
+ input_act_qspec = get_input_act_qspec(quantization_config)
859
+ output_act_qspec = get_output_act_qspec(quantization_config)
860
+
861
+ input_qspec_map = {}
862
+ input_act0 = mul_node.args[0]
863
+ if isinstance(input_act0, Node):
864
+ input_qspec_map[input_act0] = input_act_qspec
865
+
866
+ input_act1 = mul_node.args[1]
867
+ if isinstance(input_act1, Node):
868
+ input_qspec_map[input_act1] = input_act_qspec
869
+
870
+ mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
871
+ input_qspec_map=input_qspec_map,
872
+ _annotated=True,
873
+ )
874
+ relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
875
+ output_qspec=output_act_qspec,
876
+ _annotated=True,
877
+ )
878
+ return annotated_partitions
879
+
880
+
881
+ @register_annotator("mul")
882
+ def _annotate_mul(
883
+ gm: torch.fx.GraphModule,
884
+ quantization_config: Optional[QuantizationConfig],
885
+ filter_fn: Optional[Callable[[Node], bool]] = None,
886
+ ) -> Optional[List[List[Node]]]:
887
+ mul_partitions = get_source_partitions(
888
+ gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
889
+ )
890
+ mul_partitions = list(itertools.chain(*mul_partitions.values()))
891
+ annotated_partitions = []
892
+ for mul_partition in mul_partitions:
893
+ annotated_partitions.append(mul_partition.nodes)
894
+ mul_node = mul_partition.output_nodes[0]
895
+ if _is_annotated([mul_node]):
896
+ continue
897
+
898
+ input_act_qspec = get_input_act_qspec(quantization_config)
899
+ output_act_qspec = get_output_act_qspec(quantization_config)
900
+
901
+ input_qspec_map = {}
902
+ input_act0 = mul_node.args[0]
903
+ if isinstance(input_act0, Node):
904
+ input_qspec_map[input_act0] = input_act_qspec
905
+
906
+ input_act1 = mul_node.args[1]
907
+ if isinstance(input_act1, Node):
908
+ input_qspec_map[input_act1] = input_act_qspec
909
+
910
+ mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
911
+ input_qspec_map=input_qspec_map,
912
+ output_qspec=output_act_qspec,
913
+ _annotated=True,
914
+ )
915
+ return annotated_partitions
916
+
917
+
918
+ # TODO: remove Optional in return type, fix annotated_partitions logic
919
+ @register_annotator("cat")
920
+ def _annotate_cat(
921
+ gm: torch.fx.GraphModule,
922
+ quantization_config: Optional[QuantizationConfig],
923
+ filter_fn: Optional[Callable[[Node], bool]] = None,
924
+ ) -> Optional[List[List[Node]]]:
925
+ cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
926
+ cat_partitions = list(itertools.chain(*cat_partitions.values()))
927
+ annotated_partitions = []
928
+ for cat_partition in cat_partitions:
929
+ cat_node = cat_partition.output_nodes[0]
930
+ if _is_annotated([cat_node]):
931
+ continue
932
+
933
+ if cat_node.target != torch.ops.aten.cat.default:
934
+ raise Exception(
935
+ f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}"
936
+ " please check if you are calling the correct capture API"
937
+ )
938
+
939
+ annotated_partitions.append(cat_partition.nodes)
940
+
941
+ input_act_qspec = get_input_act_qspec(quantization_config)
942
+ inputs = cat_node.args[0]
943
+
944
+ input_qspec_map = {}
945
+ input_act0 = inputs[0]
946
+ if isinstance(input_act0, Node):
947
+ input_qspec_map[input_act0] = input_act_qspec
948
+
949
+ shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))
950
+ for input_act in inputs[1:]:
951
+ input_qspec_map[input_act] = shared_with_input0_qspec
952
+
953
+ output_act_qspec = shared_with_input0_qspec
954
+
955
+ cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
956
+ input_qspec_map=input_qspec_map,
957
+ output_qspec=output_act_qspec,
958
+ _annotated=True,
959
+ )
960
+ return annotated_partitions
961
+
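A short sketch (module and shapes are illustrative) of the sharing behaviour this annotator sets up:

# Illustrative only: after annotation, the first input of cat gets the regular
# activation qspec, while the remaining inputs and the output carry a
# SharedQuantizationSpec pointing at the (first input, cat) edge, so one
# observer's scale/zero-point is reused for all of them.
class _CatPair(torch.nn.Module):

    def forward(self, x, y):
        return torch.cat([x, y], dim=1)

# gm = <captured _CatPair with two (1, 2) float inputs>
# OP_TO_ANNOTATOR["cat"](gm, static_int8_config)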
962
+
963
+ def _is_share_obs_or_fq_op(op: Callable) -> bool:
964
+ return op in [
965
+ torch.ops.aten.hardtanh.default,
966
+ torch.ops.aten.hardtanh_.default,
967
+ torch.ops.aten.mean.default,
968
+ torch.ops.aten.mean.dim,
969
+ torch.ops.aten.permute.default,
970
+ torch.ops.aten.permute_copy.default,
971
+ torch.ops.aten.squeeze.dim,
972
+ torch.ops.aten.squeeze_copy.dim,
973
+ torch.ops.aten.adaptive_avg_pool2d.default,
974
+ torch.ops.aten.view_copy.default,
975
+ torch.ops.aten.view.default,
976
+ torch.ops.aten.slice_copy.Tensor,
977
+ torch.ops.aten.flatten.using_ints,
978
+ ]
979
+
980
+
981
+ def propagate_annotation(model: torch.fx.GraphModule) -> None:
982
+ for n in model.graph.nodes:
983
+ if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target):
984
+ continue
985
+
986
+ prev_node = n.args[0]
987
+ if not isinstance(prev_node, Node):
988
+ continue
989
+
990
+ quantization_annotation = prev_node.meta.get("quantization_annotation", None)
991
+ if not quantization_annotation:
992
+ continue
993
+
994
+ output_qspec = quantization_annotation.output_qspec
995
+ if not output_qspec:
996
+ continue
997
+
998
+ # make sure current node is not annotated
999
+ if (
1000
+ "quantization_annotation" in n.meta
1001
+ and n.meta["quantization_annotation"]._annotated
1002
+ ):
1003
+ continue
1004
+
1005
+ shared_qspec = SharedQuantizationSpec(prev_node)
1006
+ # propagate the previous output_qspec to the current node
1007
+ n.meta["quantization_annotation"] = QuantizationAnnotation(
1008
+ input_qspec_map={
1009
+ prev_node: shared_qspec,
1010
+ },
1011
+ output_qspec=shared_qspec,
1012
+ _annotated=True,
1013
+ )
1014
+
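A small sketch (module is illustrative) of what propagation does for a shape-preserving op:

# Illustrative only: once the producer of `view` has an output qspec,
# propagate_annotation marks the view with a SharedQuantizationSpec pointing
# back at that producer, so no extra observer is inserted for the reshape.
class _LinearThenView(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x).view(-1)

# gm = <captured _LinearThenView>
# OP_TO_ANNOTATOR["linear"](gm, static_int8_config)
# propagate_annotation(gm)  # aten.view now shares its input's quantization params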
1015
+
1016
+ # TODO: make the list of ops customizable
1017
+ def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1018
+ for n in model.graph.nodes:
1019
+ if n.op != "call_function" or n.target not in [
1020
+ torch.ops.aten.add.Tensor,
1021
+ torch.ops.aten.mul.Tensor,
1022
+ ]:
1023
+ continue
1024
+ args = list(n.args)
1025
+ new_args = []
1026
+ for i in range(len(args)):
1027
+ if isinstance(args[i], torch.fx.Node):
1028
+ new_args.append(args[i])
1029
+ continue
1030
+ prefix = "_tensor_constant_"
1031
+ get_new_attr_name = get_new_attr_name_with_prefix(prefix)
1032
+ tensor_constant_name = get_new_attr_name(model)
1033
+ model.register_buffer(tensor_constant_name, torch.tensor(args[i]))
1034
+ with model.graph.inserting_before(n):
1035
+ get_attr_node = model.graph.create_node(
1036
+ "get_attr", tensor_constant_name, (), {}
1037
+ )
1038
+ new_args.append(get_attr_node)
1039
+ n.args = tuple(new_args)
1040
+ model.recompile()
1041
+ return model
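Finally, a sketch (module is illustrative) of the scalar-to-attribute rewrite:

# Illustrative only: `x + 2.0` carries a Python scalar that observers cannot be
# attached to; _convert_scalars_to_attrs registers the scalar as a buffer and
# rewires aten.add.Tensor to read it through a get_attr node instead.
class _AddScalar(torch.nn.Module):

    def forward(self, x):
        return x + 2.0

# gm = <captured _AddScalar>
# gm = _convert_scalars_to_attrs(gm)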