tico 0.1.0.dev251106__py3-none-any.whl → 0.2.0.dev260122__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (56)
  1. tico/__init__.py +2 -2
  2. tico/_version.py +1 -0
  3. tico/passes/convert_conv3d_to_conv2d.py +435 -0
  4. tico/passes/convert_sym_size_to_circle_shape.py +99 -0
  5. tico/passes/decompose_batch_norm.py +9 -5
  6. tico/passes/lower_copy.py +95 -0
  7. tico/passes/ops.py +4 -0
  8. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +251 -0
  9. tico/quantization/algorithm/fpi_gptq/quantizer.py +180 -0
  10. tico/quantization/algorithm/gptq/gptq.py +231 -11
  11. tico/quantization/algorithm/gptq/quantizer.py +18 -6
  12. tico/quantization/config/{pt2e.py → fpi_gptq.py} +11 -4
  13. tico/quantization/config/gptq.py +27 -4
  14. tico/quantization/public_interface.py +0 -10
  15. tico/quantization/wrapq/quantizer.py +2 -0
  16. tico/quantization/wrapq/wrappers/quant_elementwise.py +51 -11
  17. tico/serialize/operators/adapters/onert/llama_attention.py +51 -0
  18. tico/serialize/operators/op_attention.py +58 -0
  19. tico/serialize/operators/op_circle_shape.py +64 -0
  20. tico/serialize/operators/op_dequantize_per_channel.py +1 -0
  21. tico/serialize/operators/op_dequantize_per_tensor.py +1 -0
  22. tico/serialize/operators/op_transpose_conv.py +66 -50
  23. tico/utils/convert.py +16 -1
  24. tico/utils/padding.py +13 -5
  25. tico/utils/record_input.py +2 -2
  26. tico/utils/register_custom_op.py +63 -0
  27. tico/utils/validate_args_kwargs.py +49 -4
  28. tico-0.2.0.dev260122.dist-info/METADATA +631 -0
  29. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/RECORD +35 -46
  30. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/WHEEL +1 -1
  31. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/entry_points.txt +0 -1
  32. tico/quantization/algorithm/pt2e/annotation/annotator.py +0 -208
  33. tico/quantization/algorithm/pt2e/annotation/config.py +0 -26
  34. tico/quantization/algorithm/pt2e/annotation/op/__init__.py +0 -21
  35. tico/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +0 -63
  36. tico/quantization/algorithm/pt2e/annotation/op/add.py +0 -55
  37. tico/quantization/algorithm/pt2e/annotation/op/conv2d.py +0 -90
  38. tico/quantization/algorithm/pt2e/annotation/op/div.py +0 -55
  39. tico/quantization/algorithm/pt2e/annotation/op/linear.py +0 -92
  40. tico/quantization/algorithm/pt2e/annotation/op/mean.py +0 -51
  41. tico/quantization/algorithm/pt2e/annotation/op/mul.py +0 -55
  42. tico/quantization/algorithm/pt2e/annotation/op/relu6.py +0 -51
  43. tico/quantization/algorithm/pt2e/annotation/op/rsqrt.py +0 -51
  44. tico/quantization/algorithm/pt2e/annotation/op/sub.py +0 -55
  45. tico/quantization/algorithm/pt2e/annotation/spec.py +0 -45
  46. tico/quantization/algorithm/pt2e/annotation/utils.py +0 -88
  47. tico/quantization/algorithm/pt2e/quantizer.py +0 -81
  48. tico/quantization/algorithm/pt2e/transformation/__init__.py +0 -1
  49. tico/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -58
  50. tico/quantization/algorithm/pt2e/utils.py +0 -135
  51. tico/serialize/operators/op_copy.py +0 -187
  52. tico-0.1.0.dev251106.dist-info/METADATA +0 -392
  53. /tico/quantization/algorithm/{pt2e → fpi_gptq}/__init__.py +0 -0
  54. /tico/{quantization/algorithm/pt2e/annotation → serialize/operators/adapters/onert}/__init__.py +0 -0
  55. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info/licenses}/LICENSE +0 -0
  56. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/top_level.txt +0 -0
--- tico/quantization/algorithm/pt2e/annotation/op/conv2d.py
+++ /dev/null
@@ -1,90 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Callable, Optional, TYPE_CHECKING
-
- if TYPE_CHECKING:
-     import torch.fx
- import torch
- from torch.ao.quantization.quantizer import DerivedQuantizationSpec
-
- import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
- import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
- import tico.quantization.algorithm.pt2e.utils as quant_utils
- from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
- from tico.utils.validate_args_kwargs import Conv2DArgs
-
-
- @annot_spec.register_annotator(
-     [torch.ops.aten.conv2d.default, torch.ops.aten.conv2d.padding]
- )
- def _annotate_conv2d(
-     gm: torch.fx.GraphModule,
-     node: torch.fx.Node,
-     quantization_config: Optional[QuantizationConfig],
-     filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
- ):
-     for node in gm.graph.nodes:
-         if node.op != "call_function" or node.target not in [
-             torch.ops.aten.conv2d.default,
-             torch.ops.aten.conv2d.padding,
-         ]:
-             continue
-         if filter_fn and not filter_fn(node):
-             continue
-         if quant_utils.is_annotated(node):
-             continue
-
-         args = Conv2DArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
-         input_ = args.input
-         weight = args.weight
-         bias = args.bias
-
-         input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
-         weight_qspec = quant_utils.get_weight_qspec(quantization_config)
-         annot_utils.annotate_input_qspec_map(node, input_, input_act_qspec)
-         annot_utils.annotate_input_qspec_map(node, weight, weight_qspec)
-         nodes_to_mark_annotated = [input_, weight, node]
-         if bias:
-
-             def _derive_bias_qparams_from_act_and_weight_qparams(obs_or_fqs):
-                 act_scale, _ = obs_or_fqs[0].calculate_qparams()
-                 weight_scale, _ = obs_or_fqs[1].calculate_qparams()
-                 bias_scale = act_scale * weight_scale
-                 bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
-                 return bias_scale, bias_zero_point
-
-             bias_qspec = DerivedQuantizationSpec(
-                 derived_from=[
-                     (input_, node),
-                     (weight, node),
-                 ],
-                 derive_qparams_fn=_derive_bias_qparams_from_act_and_weight_qparams,
-                 dtype=torch.int32,
-                 quant_min=-(2**31),
-                 quant_max=2**31 - 1,
-                 qscheme=weight_qspec.qscheme,
-                 ch_axis=0 if weight_qspec.qscheme == torch.per_channel_affine else None,
-             )
-             annot_utils.annotate_input_qspec_map(
-                 node,
-                 bias,
-                 bias_qspec,
-             )
-             nodes_to_mark_annotated.append(bias)
-
-         output_act_qspec = quant_utils.get_output_act_qspec(quantization_config)
-         annot_utils.annotate_output_qspec(node, output_act_qspec)
-
-         annot_utils.mark_nodes_as_annotated(nodes_to_mark_annotated)
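
Note on the bias handling in the deleted annotator above: the DerivedQuantizationSpec quantizes the conv bias to int32 with scale act_scale * weight_scale and zero point 0, so the quantized bias adds directly into the conv's int32 accumulator. A minimal numeric sketch of that rule (the scale and bias values below are made up for illustration):

import torch

act_scale = torch.tensor(0.02)                 # hypothetical activation scale
weight_scale = torch.tensor([0.1, 0.05, 0.2])  # hypothetical per-channel weight scales
bias = torch.tensor([0.5, -0.3, 0.125])        # hypothetical float bias

# Same rule as _derive_bias_qparams_from_act_and_weight_qparams above.
bias_scale = act_scale * weight_scale
bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)

q_bias = torch.round(bias / bias_scale).to(torch.int32)
print(q_bias)  # tensor([ 250, -300,   31], dtype=torch.int32)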
--- tico/quantization/algorithm/pt2e/annotation/op/div.py
+++ /dev/null
@@ -1,55 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Callable, Optional, TYPE_CHECKING
-
- if TYPE_CHECKING:
-     import torch.fx
- import torch
-
- import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
- import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
- import tico.quantization.algorithm.pt2e.utils as quant_utils
- from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
- from tico.utils.validate_args_kwargs import DivTensorArgs
-
-
- @annot_spec.register_annotator([torch.ops.aten.div.Tensor])
- def _annotate_div(
-     gm: torch.fx.GraphModule,
-     node: torch.fx.Node,
-     quantization_config: Optional[QuantizationConfig],
-     filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
- ):
-     if node.op != "call_function" or node.target != torch.ops.aten.div.Tensor:
-         return
-     if filter_fn and not filter_fn(node):
-         return
-     if quant_utils.is_annotated(node):
-         return
-
-     args = DivTensorArgs(*node.args) # type: ignore[arg-type]
-     input = args.input
-     other = args.other
-
-     input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
-     if isinstance(input, torch.fx.Node):
-         annot_utils.annotate_input_qspec_map(node, input, input_act_qspec)
-     if isinstance(other, torch.fx.Node):
-         annot_utils.annotate_input_qspec_map(node, other, input_act_qspec)
-
-     output_act_qspec = quant_utils.get_output_act_qspec(quantization_config)
-     annot_utils.annotate_output_qspec(node, output_act_qspec)
-
-     annot_utils.mark_nodes_as_annotated(node)
--- tico/quantization/algorithm/pt2e/annotation/op/linear.py
+++ /dev/null
@@ -1,92 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Callable, Optional, TYPE_CHECKING
-
- if TYPE_CHECKING:
-     import torch.fx
- import torch
- from torch.ao.quantization.quantizer import DerivedQuantizationSpec
-
- import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
- import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
- import tico.quantization.algorithm.pt2e.utils as quant_utils
- from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
- from tico.utils.validate_args_kwargs import LinearArgs
-
-
- @annot_spec.register_annotator([torch.ops.aten.linear.default])
- def _annotate_linear(
-     gm: torch.fx.GraphModule,
-     node: torch.fx.Node,
-     quantization_config: Optional[QuantizationConfig],
-     filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
- ):
-     if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
-         return
-     if filter_fn and not filter_fn(node):
-         return
-     if quant_utils.is_annotated(node):
-         return
-
-     args = LinearArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
-     input_ = args.input
-     weight = args.weight
-     bias = args.bias
-
-     input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
-     output_act_qspec = quant_utils.get_output_act_qspec(quantization_config)
-     weight_qspec = quant_utils.get_weight_qspec(quantization_config)
-     bias_qspec = quant_utils.get_bias_qspec(quantization_config)
-
-     annot_utils.annotate_input_qspec_map(
-         node,
-         input_,
-         input_act_qspec,
-     )
-     annot_utils.annotate_input_qspec_map(
-         node,
-         weight,
-         weight_qspec,
-     )
-     nodes_to_mark_annotated = [node, weight]
-     if bias:
-
-         def _derive_bias_qparams_from_act_and_weight_qparams(obs_or_fqs):
-             act_scale, _ = obs_or_fqs[0].calculate_qparams()
-             weight_scale, _ = obs_or_fqs[1].calculate_qparams()
-             bias_scale = act_scale * weight_scale
-             bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
-             return bias_scale, bias_zero_point
-
-         bias_qspec = DerivedQuantizationSpec(
-             derived_from=[
-                 (input_, node),
-                 (weight, node),
-             ],
-             derive_qparams_fn=_derive_bias_qparams_from_act_and_weight_qparams,
-             dtype=torch.int32,
-             quant_min=-(2**31),
-             quant_max=2**31 - 1,
-             qscheme=weight_qspec.qscheme,
-             ch_axis=0 if weight_qspec.qscheme == torch.per_channel_affine else None,
-         )
-         annot_utils.annotate_input_qspec_map(
-             node,
-             bias,
-             bias_qspec,
-         )
-         nodes_to_mark_annotated.append(bias)
-     annot_utils.annotate_output_qspec(node, output_act_qspec)
-     annot_utils.mark_nodes_as_annotated(nodes_to_mark_annotated)
--- tico/quantization/algorithm/pt2e/annotation/op/mean.py
+++ /dev/null
@@ -1,51 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Callable, Optional, TYPE_CHECKING
-
- if TYPE_CHECKING:
-     import torch.fx
- import torch
-
- import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
- import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
- import tico.quantization.algorithm.pt2e.utils as quant_utils
- from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
- from tico.utils.validate_args_kwargs import MeanDimArgs
-
-
- @annot_spec.register_annotator([torch.ops.aten.mean.dim])
- def _annotate_mean(
-     gm: torch.fx.GraphModule,
-     node: torch.fx.Node,
-     quantization_config: Optional[QuantizationConfig],
-     filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
- ):
-     if node.op != "call_function" or node.target != torch.ops.aten.mean.dim:
-         return
-     if filter_fn and not filter_fn(node):
-         return
-     if quant_utils.is_annotated(node):
-         return
-
-     args = MeanDimArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
-     input = args.input
-
-     input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
-     annot_utils.annotate_input_qspec_map(node, input, input_act_qspec)
-
-     output_act_qspec = quant_utils.get_output_act_qspec(quantization_config)
-     annot_utils.annotate_output_qspec(node, output_act_qspec)
-
-     annot_utils.mark_nodes_as_annotated(node)
--- tico/quantization/algorithm/pt2e/annotation/op/mul.py
+++ /dev/null
@@ -1,55 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Callable, Optional, TYPE_CHECKING
-
- if TYPE_CHECKING:
-     import torch.fx
- import torch
-
- import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
- import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
- import tico.quantization.algorithm.pt2e.utils as quant_utils
- from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
- from tico.utils.validate_args_kwargs import MulTensorArgs
-
-
- @annot_spec.register_annotator([torch.ops.aten.mul.Tensor])
- def _annotate_mul(
-     gm: torch.fx.GraphModule,
-     node: torch.fx.Node,
-     quantization_config: Optional[QuantizationConfig],
-     filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
- ):
-     if node.op != "call_function" or node.target != torch.ops.aten.mul.Tensor:
-         return
-     if filter_fn and not filter_fn(node):
-         return
-     if quant_utils.is_annotated(node):
-         return
-
-     args = MulTensorArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
-     input = args.input
-     other = args.other
-
-     input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
-     if isinstance(input, torch.fx.Node):
-         annot_utils.annotate_input_qspec_map(node, input, input_act_qspec)
-     if isinstance(other, torch.fx.Node):
-         annot_utils.annotate_input_qspec_map(node, other, input_act_qspec)
-
-     output_act_qspec = quant_utils.get_output_act_qspec(quantization_config)
-     annot_utils.annotate_output_qspec(node, output_act_qspec)
-
-     annot_utils.mark_nodes_as_annotated(node)
--- tico/quantization/algorithm/pt2e/annotation/op/relu6.py
+++ /dev/null
@@ -1,51 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Callable, Optional, TYPE_CHECKING
-
- if TYPE_CHECKING:
-     import torch.fx
- import torch
-
- import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
- import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
- import tico.quantization.algorithm.pt2e.utils as quant_utils
- from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
- from tico.utils.validate_args_kwargs import Relu6Args
-
-
- @annot_spec.register_annotator([torch.ops.aten.relu6.default])
- def _annotate_relu6(
-     gm: torch.fx.GraphModule,
-     node: torch.fx.Node,
-     quantization_config: Optional[QuantizationConfig],
-     filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
- ):
-     if node.op != "call_function" or node.target != torch.ops.aten.relu6.default:
-         return
-     if filter_fn and not filter_fn(node):
-         return
-     if quant_utils.is_annotated(node):
-         return
-
-     args = Relu6Args(*node.args, **node.kwargs) # type: ignore
-     input = args.input
-
-     input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
-     annot_utils.annotate_input_qspec_map(node, input, input_act_qspec)
-
-     output_act_qspec = quant_utils.get_output_act_qspec(quantization_config)
-     annot_utils.annotate_output_qspec(node, output_act_qspec)
-
-     annot_utils.mark_nodes_as_annotated(node)
--- tico/quantization/algorithm/pt2e/annotation/op/rsqrt.py
+++ /dev/null
@@ -1,51 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Callable, Optional, TYPE_CHECKING
-
- if TYPE_CHECKING:
-     import torch.fx
- import torch
-
- import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
- import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
- import tico.quantization.algorithm.pt2e.utils as quant_utils
- from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
- from tico.utils.validate_args_kwargs import RsqrtArgs
-
-
- @annot_spec.register_annotator([torch.ops.aten.rsqrt.default])
- def _annotate_rsqrt(
-     gm: torch.fx.GraphModule,
-     node: torch.fx.Node,
-     quantization_config: Optional[QuantizationConfig],
-     filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
- ):
-     if node.op != "call_function" or node.target != torch.ops.aten.rsqrt.default:
-         return
-     if filter_fn and not filter_fn(node):
-         return
-     if quant_utils.is_annotated(node):
-         return
-
-     args = RsqrtArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
-     input = args.input
-
-     input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
-     annot_utils.annotate_input_qspec_map(node, input, input_act_qspec)
-
-     output_act_qspec = quant_utils.get_output_act_qspec(quantization_config)
-     annot_utils.annotate_output_qspec(node, output_act_qspec)
-
-     annot_utils.mark_nodes_as_annotated(node)
--- tico/quantization/algorithm/pt2e/annotation/op/sub.py
+++ /dev/null
@@ -1,55 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Callable, Optional, TYPE_CHECKING
-
- if TYPE_CHECKING:
-     import torch.fx
- import torch
-
- import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
- import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
- import tico.quantization.algorithm.pt2e.utils as quant_utils
- from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
- from tico.utils.validate_args_kwargs import SubTensorArgs
-
-
- @annot_spec.register_annotator([torch.ops.aten.sub.Tensor])
- def _annotate_sub(
-     gm: torch.fx.GraphModule,
-     node: torch.fx.Node,
-     quantization_config: Optional[QuantizationConfig],
-     filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
- ):
-     if node.op != "call_function" or node.target != torch.ops.aten.sub.Tensor:
-         return
-     if filter_fn and not filter_fn(node):
-         return
-     if quant_utils.is_annotated(node):
-         return
-
-     args = SubTensorArgs(*node.args) # type: ignore[arg-type]
-     input = args.input
-     other = args.other
-
-     input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
-     if isinstance(input, torch.fx.Node):
-         annot_utils.annotate_input_qspec_map(node, input, input_act_qspec)
-     if isinstance(other, torch.fx.Node):
-         annot_utils.annotate_input_qspec_map(node, other, input_act_qspec)
-
-     output_act_qspec = quant_utils.get_output_act_qspec(quantization_config)
-     annot_utils.annotate_output_qspec(node, output_act_qspec)
-
-     annot_utils.mark_nodes_as_annotated(node)
--- tico/quantization/algorithm/pt2e/annotation/spec.py
+++ /dev/null
@@ -1,45 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Callable, Dict, List, Optional, TYPE_CHECKING
-
- if TYPE_CHECKING:
-     import torch.fx
- import torch
-
- from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
-
- AnnotatorType = Callable[
-     [
-         torch.fx.GraphModule,
-         torch.fx.Node,
-         Optional[QuantizationConfig],
-         Optional[Callable[[torch.fx.Node], bool]],
-     ],
-     None,
- ]
- OP_TO_ANNOTATOR: Dict[torch._ops.OpOverload, AnnotatorType] = {}
- OP_TO_SHARE_QUANT_SPEC: List[Callable] = [
-     torch.ops.aten.view_copy.default,
-     torch.ops.aten.view.default,
- ]
-
-
- def register_annotator(target: List[torch._ops.OpOverload]):
-     def decorator(annotator: AnnotatorType):
-         for t in target:
-             OP_TO_ANNOTATOR[t] = annotator
-         return annotator
-
-     return decorator
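
Note: the registry deleted above was the dispatch point for all of the per-op annotators in this diff. Each annotator registered itself for a set of ATen overloads via register_annotator, and the pt2e quantizer's annotate step looked nodes up in OP_TO_ANNOTATOR. A minimal sketch of such a driver (annotate_graph is an illustrative name; the actual caller lived in the deleted pt2e quantizer.py, which this diff does not show line by line):

import torch
import torch.fx

from tico.quantization.algorithm.pt2e.annotation.spec import OP_TO_ANNOTATOR


def annotate_graph(gm: torch.fx.GraphModule, quantization_config, filter_fn=None):
    # Dispatch every call_function node to its registered annotator, if any.
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        annotator = OP_TO_ANNOTATOR.get(node.target)
        if annotator is not None:
            # Signature matches AnnotatorType: (gm, node, config, filter_fn) -> None.
            annotator(gm, node, quantization_config, filter_fn)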
--- tico/quantization/algorithm/pt2e/annotation/utils.py
+++ /dev/null
@@ -1,88 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import List, Optional, TYPE_CHECKING
-
- if TYPE_CHECKING:
-     import torch.fx
- import torch
- from torch.ao.quantization.quantizer import (
-     QuantizationAnnotation,
-     SharedQuantizationSpec,
- )
-
- import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
-
-
- def annotate_input_qspec_map(node: torch.fx.Node, input_node: torch.fx.Node, qspec):
-     quantization_annotation: QuantizationAnnotation = node.meta.get(
-         "quantization_annotation", QuantizationAnnotation()
-     )
-     quantization_annotation.input_qspec_map[input_node] = qspec
-     node.meta["quantization_annotation"] = quantization_annotation
-
-
- def annotate_output_qspec(node: torch.fx.Node, qspec):
-     quantization_annotation: QuantizationAnnotation = node.meta.get(
-         "quantization_annotation", QuantizationAnnotation()
-     )
-     quantization_annotation.output_qspec = qspec
-     node.meta["quantization_annotation"] = quantization_annotation
-
-
- def mark_nodes_as_annotated(nodes: List[torch.fx.Node] | torch.fx.Node):
-     if isinstance(nodes, torch.fx.Node):
-         nodes = [nodes]
-     for node in nodes:
-         if node is not None:
-             if "quantization_annotation" not in node.meta:
-                 node.meta["quantization_annotation"] = QuantizationAnnotation()
-             node.meta["quantization_annotation"]._annotated = True
-
-
- def propagate_annotation_forward(model: torch.fx.GraphModule) -> None:
-     for n in model.graph.nodes:
-         if n.op != "call_function" or n.target not in annot_spec.OP_TO_SHARE_QUANT_SPEC:
-             continue
-
-         prev_node = n.args[0]
-         if not isinstance(prev_node, torch.fx.Node):
-             continue
-
-         quantization_annotation: Optional[QuantizationAnnotation] = prev_node.meta.get(
-             "quantization_annotation", None
-         )
-         if not quantization_annotation:
-             continue
-
-         output_qspec = quantization_annotation.output_qspec
-         if not output_qspec:
-             continue
-
-         # Make sure current node is not annotated
-         if (
-             "quantization_annotation" in n.meta
-             and n.meta["quantization_annotation"]._annotated
-         ):
-             continue
-
-         shared_qspec = SharedQuantizationSpec(prev_node)
-         # Propagate the previous output_qspec to the current node
-         n.meta["quantization_annotation"] = QuantizationAnnotation(
-             input_qspec_map={
-                 prev_node: shared_qspec,
-             },
-             output_qspec=shared_qspec,
-             _annotated=True,
-         )
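
Note: propagate_annotation_forward, deleted above, let view/view_copy nodes inherit their producer's quantization parameters through SharedQuantizationSpec, so no extra observer (and no quantize/dequantize pair) is created around a pure reshape. Its core reduces to the following sketch (share_qparams_with_producer is an illustrative name, not part of the original module):

import torch
import torch.fx
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)


def share_qparams_with_producer(view_node: torch.fx.Node) -> None:
    prev = view_node.args[0]               # producer whose qparams are reused
    shared = SharedQuantizationSpec(prev)  # "same scale/zero-point as prev"
    view_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={prev: shared},
        output_qspec=shared,
        _annotated=True,
    )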