tico 0.1.0.dev250904__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_expand.py +3 -1
  10. tico/quantization/__init__.py +6 -0
  11. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
  12. tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  14. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  29. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  31. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  32. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  33. tico/quantization/config/base.py +26 -0
  34. tico/quantization/config/fpi_gptq.py +29 -0
  35. tico/quantization/config/gptq.py +29 -0
  36. tico/quantization/config/pt2e.py +25 -0
  37. tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
  38. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
  39. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
  40. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  41. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  42. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  47. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  48. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
  49. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
  50. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
  51. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
  52. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  53. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
  54. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
  55. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
  56. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
  57. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
  58. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
  59. tico/quantization/wrapq/quantizer.py +179 -0
  60. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
  61. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
  62. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  63. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  64. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  65. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  66. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  67. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  68. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  69. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  70. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
  71. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
  72. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
  73. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  74. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
  75. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
  76. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
  77. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
  78. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
  79. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
  80. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
  81. tico/serialize/circle_serializer.py +11 -4
  82. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  83. tico/serialize/operators/op_le.py +54 -0
  84. tico/serialize/operators/op_mm.py +15 -132
  85. tico/utils/convert.py +20 -15
  86. tico/utils/register_custom_op.py +6 -4
  87. tico/utils/signature.py +7 -8
  88. tico/utils/validate_args_kwargs.py +12 -0
  89. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
  90. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
  91. tico/experimental/quantization/__init__.py +0 -6
  92. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  93. tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
  94. tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
  95. tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
  96. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  97. /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
  98. /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
  99. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  100. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  101. /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
  102. /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
  103. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  104. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  105. /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
  106. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  107. /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
  108. /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
  109. /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  111. /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
  112. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  113. /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
  114. /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
  115. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  116. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  117. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  118. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  119. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  120. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  121. /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
  122. /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
  123. /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
  124. /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
  125. /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
  126. /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
  127. /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
  128. /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
  129. /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
  130. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
tico/__init__.py CHANGED
@@ -29,7 +29,7 @@ __all__ = [
 ]
 
 # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
-__version__ = "0.1.0.dev250904"
+__version__ = "0.1.0.dev251109"
 
 MINIMUM_SUPPORTED_VERSION = "2.5.0"
 SECURE_TORCH_VERSION = "2.6.0"
tico/config/v1.py CHANGED
@@ -20,6 +20,11 @@ from tico.config.base import CompileConfigBase
 @dataclass
 class CompileConfigV1(CompileConfigBase):
     legalize_causal_mask_value: bool = False
+    remove_constant_input: bool = False
+    convert_lhs_const_mm_to_fc: bool = False
+    convert_rhs_const_mm_to_fc: bool = True
+    convert_single_batch_lhs_const_bmm_to_fc: bool = False
+    convert_expand_to_slice_cat: bool = False
 
     def get(self, name: str):
         return super().get(name)
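The new fields are plain dataclass defaults, so they can be toggled at construction time. A minimal sketch (field names come from the diff above; this assumes `get` returns the field value via the base class, and that the config object is handed to tico's compile entry point as usual):

    from tico.config.v1 import CompileConfigV1

    # Only convert_rhs_const_mm_to_fc is enabled by default.
    config = CompileConfigV1(
        convert_lhs_const_mm_to_fc=True,
        convert_expand_to_slice_cat=True,
    )
    assert config.get("convert_rhs_const_mm_to_fc") is True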
tico/passes/cast_mixed_type_args.py CHANGED
@@ -41,6 +41,8 @@ ops_to_promote = {
     torch.ops.aten.ge.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.gt.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.gt.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    torch.ops.aten.le.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    torch.ops.aten.le.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.ne.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
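Registering `aten.le` here lets the pass promote mixed-dtype operands of less-equal comparisons the same way the other comparison ops are handled. An illustrative sketch of the case this covers (dtypes chosen for illustration):

    import torch

    x = torch.tensor([1, 2, 3], dtype=torch.int32)
    # Eager PyTorch promotes the int32 operand before comparing against a
    # float scalar; the pass inserts the equivalent cast so both operands
    # have matching dtypes in the serialized graph.
    mask = torch.ops.aten.le.Scalar(x, 2.5)  # tensor([True, True, False])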
tico/passes/convert_expand_to_slice_cat.py ADDED
@@ -0,0 +1,153 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.passes import ops
+from tico.serialize.circle_mapping import extract_shape
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import is_target_node
+from tico.utils.validate_args_kwargs import ExpandArgs, ReshapeArgs
+
+
+@trace_graph_diff_on_pass
+class ConvertExpandToSliceCat(PassBase):
+    """
+    This pass replaces the `aten.reshape` + `aten.expand` pattern by rewriting it
+    as a series of `aten.slice` and `aten.cat` operations.
+
+    The pass is specialized for the expand used by a KV cache.
+    - Expects (batch, num_key_value_heads, seq_len, head_dim) as the input shape
+      of the reshape.
+    """
+
+    def __init__(self, enabled: bool = False):
+        super().__init__()
+        self.enabled = enabled
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        if not self.enabled:
+            return PassResult(False)
+
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+
+        # This pass handles expand on EXPAND_DIM only.
+        CAT_DIM = 1
+        EXPAND_DIM = 2
+
+        for node in graph.nodes:
+            if not isinstance(node, torch.fx.Node) or not is_target_node(
+                node, ops.aten.reshape
+            ):
+                continue
+
+            post_reshape = node
+            post_reshape_args = ReshapeArgs(*post_reshape.args, **post_reshape.kwargs)
+            post_reshape_input = post_reshape_args.input
+
+            if not isinstance(post_reshape_input, torch.fx.Node) or not is_target_node(
+                post_reshape_input, ops.aten.expand
+            ):
+                continue
+
+            expand = post_reshape_input
+            expand_args = ExpandArgs(*expand.args, **expand.kwargs)
+            expand_input = expand_args.input
+            expand_shape = extract_shape(expand)
+
+            if not isinstance(expand_input, torch.fx.Node) or not is_target_node(
+                expand_input, ops.aten.reshape
+            ):
+                continue
+
+            pre_reshape = expand_input
+            pre_reshape_args = ReshapeArgs(*pre_reshape.args, **pre_reshape.kwargs)
+            pre_reshape_input = pre_reshape_args.input
+            pre_reshape_shape = extract_shape(pre_reshape)
+
+            if pre_reshape_shape[EXPAND_DIM] != 1:
+                continue
+
+            reshape_input_shape = extract_shape(pre_reshape_input)
+
+            if len(expand_shape) != len(pre_reshape_shape):
+                continue
+
+            # Ensure all dimensions *except* EXPAND_DIM are identical.
+            if not (
+                expand_shape[:EXPAND_DIM] == pre_reshape_shape[:EXPAND_DIM]
+                and expand_shape[EXPAND_DIM + 1 :] == pre_reshape_shape[EXPAND_DIM + 1 :]
+            ):
+                continue
+
+            # Ensure the expansion dimension is a clean multiple.
+            if expand_shape[EXPAND_DIM] % pre_reshape_shape[EXPAND_DIM] != 0:
+                continue
+
+            expand_ratio = expand_shape[EXPAND_DIM] // pre_reshape_shape[EXPAND_DIM]
+
+            if expand_ratio <= 1:
+                continue
+
+            cat_nodes = []
+
+            for i in range(reshape_input_shape[CAT_DIM]):
+                with graph.inserting_before(expand):
+                    slice_copy_args = (pre_reshape_input, CAT_DIM, i, i + 1, 1)
+                    slice_node = create_node(
+                        graph,
+                        torch.ops.aten.slice.Tensor,
+                        args=slice_copy_args,
+                        origin=expand,
+                    )
+                with graph.inserting_after(slice_node):
+                    cat_args = ([slice_node] * expand_ratio, CAT_DIM)
+                    cat_node = create_node(
+                        graph,
+                        torch.ops.aten.cat.default,
+                        args=cat_args,
+                        origin=expand,
+                    )
+                    cat_nodes.append(cat_node)
+
+            with graph.inserting_after(expand):
+                cat_args = (cat_nodes, CAT_DIM)
+                cat_node = create_node(
+                    graph,
+                    torch.ops.aten.cat.default,
+                    args=cat_args,
+                    origin=expand,
+                )
+                expand.replace_all_uses_with(cat_node)
+
+            modified = True
+            logger.debug(f"{expand.name} is replaced with {cat_node.name} operators")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
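For intuition, a tensor-level sketch of the rewrite (shapes are illustrative; the matched pattern mirrors the repeat_kv expansion in grouped-query attention, where each of the num_key_value_heads is repeated expand_ratio times):

    import torch

    kv = torch.randn(1, 2, 8, 64)  # (batch, num_key_value_heads, seq_len, head_dim)

    # Original pattern: reshape -> expand -> reshape (expand_ratio = 3).
    expanded = kv[:, :, None, :, :].expand(1, 2, 3, 8, 64).reshape(1, 6, 8, 64)

    # Rewritten pattern: slice each head, repeat it via cat, then cat the heads.
    heads = [kv[:, i : i + 1] for i in range(2)]           # slice on CAT_DIM=1
    repeated = [torch.cat([h] * 3, dim=1) for h in heads]  # repeat via cat
    rewritten = torch.cat(repeated, dim=1)

    assert torch.equal(expanded, rewritten)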
tico/passes/convert_matmul_to_linear.py ADDED
@@ -0,0 +1,312 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
+from torch.export import ExportedProgram
+
+from tico.serialize.circle_mapping import extract_shape
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import BmmArgs, MatmulArgs
+
+
+class Converter:  # type: ignore[empty-body]
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:  # type: ignore[empty-body]
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:  # type: ignore[empty-body]
+        pass
+
+
+class MatmulToLinearConverter(Converter):
+    def __init__(self):
+        super().__init__()
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        lhs = mm_args.input
+        rhs = mm_args.other
+
+        with graph.inserting_before(node):
+            transpose_node = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(rhs, [1, 0]),
+            )
+            linear_node = create_node(
+                graph,
+                torch.ops.aten.linear.default,
+                args=(lhs, transpose_node),
+            )
+            node.replace_all_uses_with(linear_node, propagate_meta=True)
+
+        return linear_node
+
+
+class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.mm.default:
+            return False
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        rhs = mm_args.other
+        if isinstance(rhs, torch.fx.Node):
+            if is_lifted_tensor_constant(exported_program, rhs):
+                return True
+            elif is_param(exported_program, rhs):
+                return True
+            elif is_buffer(exported_program, rhs):
+                return True
+            else:
+                return False
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        return super().convert(exported_program, node)
+
+
+class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.mm.default:
+            return False
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)
+        lhs = mm_args.input
+        if isinstance(lhs, torch.fx.Node):
+            if is_lifted_tensor_constant(exported_program, lhs):
+                return True
+            elif is_param(exported_program, lhs):
+                return True
+            elif is_buffer(exported_program, lhs):
+                return True
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        return super().convert(exported_program, node)
+
+
+class SingleBatchLhsConstBmmToLinearConverter(Converter):
+    """
+    Convert a single-batched, lhs-const BatchMatMul to a `linear` operation.
+
+    [1] exchange lhs and rhs
+    [2] transpose rhs
+    [3] transpose output
+
+    **Before**
+
+    lhs[1,a,b](const)   rhs[1,b,c]
+            |               |
+            |               |
+            -------bmm------
+                    |
+               output[1,a,c]
+
+    **After**
+
+    rhs[1,b,c]
+        |
+        tr        lhs'[a,b](const-folded)
+        |[1,c,b]      |
+        |             |
+        ------fc------
+            |[1,c,a]
+            tr
+            |
+       output[1,a,c]
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.bmm.default:
+            return False
+
+        bmm_args = BmmArgs(*node.args, **node.kwargs)
+        lhs = bmm_args.input
+        rhs = bmm_args.mat2
+
+        # [1] Single-batch
+        lhs_shape = extract_shape(lhs)
+        rhs_shape = extract_shape(rhs)
+
+        assert len(lhs_shape) == len(
+            rhs_shape
+        ), f"Bmm input's ranks must be the same but got {lhs_shape} and {rhs_shape}"
+
+        if not (lhs_shape[0] == rhs_shape[0] == 1):
+            return False
+
+        # [2] Lhs is constant
+        if not isinstance(lhs, torch.fx.Node):
+            return False
+        if not (
+            is_lifted_tensor_constant(exported_program, lhs)
+            or is_param(exported_program, lhs)
+            or is_buffer(exported_program, lhs)
+        ):
+            return False
+
+        return True
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        bmm_args = BmmArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        lhs = bmm_args.input  # const
+        rhs = bmm_args.mat2  # non-const
+        lhs_shape = extract_shape(lhs)
+        rhs_shape = extract_shape(rhs)
+        assert rhs_shape[0] == 1
+        assert lhs_shape[0] == 1
+
+        with graph.inserting_before(node):
+            rhs_tr = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(rhs, [0, 2, 1]),
+            )
+            lhs_reshape = create_node(
+                graph,
+                torch.ops.aten.view.default,
+                args=(lhs, list(lhs_shape[1:])),
+            )
+
+            linear_node = create_node(
+                graph,
+                torch.ops.aten.linear.default,
+                args=(rhs_tr, lhs_reshape),
+            )
+
+            tr_linear_node = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(linear_node, [0, 2, 1]),
+            )
+
+            node.replace_all_uses_with(tr_linear_node, propagate_meta=False)
+
+        return tr_linear_node
+
+
+@trace_graph_diff_on_pass
+class ConvertMatmulToLinear(PassBase):
+    """
+    This pass selectively converts matmul (including single-batch bmm, in part)
+    to linear.
+
+    How to choose between `matmul` and `linear`?
+
+    * Linear has better quantization accuracy (NPU backend).
+      Due to the ONE compiler's quantization policy,
+      FullyConnected (= Linear) uses per-channel quantization for the weight and
+      per-tensor quantization for the input, while
+      BatchMatmul (= matmul) uses per-tensor quantization for both rhs and lhs.
+
+    * Matmul to Linear requires a Transpose, which may harm latency.
+      When RHS is constant, the additional transpose can be folded.
+
+    [RHS non-const case]
+    Constant folding cannot be performed.
+
+    lhs   rhs (non-const)
+     |     |
+     |   transpose
+     |     |
+     -- linear --
+          |
+         out
+
+    [RHS const case]
+    Constant folding folds the transpose into the constant.
+
+    lhs   rhs (const)           lhs   rhs (folded const)
+     |     |                     |     |
+     |   transpose               |     |
+     |     |              -->    |     |
+     -- linear --                -- linear --
+          |                           |
+         out                         out
+
+    enable_lhs_const: If true, convert matmul where LHS is a constant tensor. Default is False.
+    enable_rhs_const: If true, convert matmul where RHS is a constant tensor. Default is True.
+    """
+
+    def __init__(
+        self,
+        enable_lhs_const: Optional[bool] = False,
+        enable_rhs_const: Optional[bool] = True,
+        enable_single_batch_lhs_const_bmm: Optional[bool] = False,
+    ):
+        super().__init__()
+        self.converters: List[Converter] = []
+        if enable_lhs_const:
+            self.converters.append(LhsConstMatmulToLinearConverter())
+        if enable_rhs_const:
+            self.converters.append(RhsConstMatmulToLinearConverter())
+        if enable_single_batch_lhs_const_bmm:
+            self.converters.append(SingleBatchLhsConstBmmToLinearConverter())
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+        for node in graph.nodes:
+            if not node.op == "call_function":
+                continue
+
+            for converter in self.converters:
+                if not converter.match(exported_program, node):
+                    continue
+
+                new_node = converter.convert(exported_program, node)
+                modified = True
+                logger.debug(
+                    f"{node.name} is replaced with {new_node.name} operator (permute + linear)"
+                )
+                continue
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
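The identity behind the mm-to-linear rewrite can be checked numerically (an illustrative sketch, not part of the pass):

    import torch

    lhs = torch.randn(4, 8)
    rhs = torch.randn(8, 16)  # the constant operand in the matched pattern

    out_mm = torch.ops.aten.mm.default(lhs, rhs)
    # linear computes lhs @ weight.T, so passing the transposed rhs as the
    # weight reproduces lhs @ rhs exactly.
    out_linear = torch.ops.aten.linear.default(lhs, rhs.permute(1, 0))

    assert torch.allclose(out_mm, out_linear)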
tico/passes/convert_to_relu6.py CHANGED
@@ -172,7 +172,7 @@ class ConvertToReLU6(PassBase):
                 converter.convert(exported_program, node)
                 modified = True
                 logger.debug(f"{node.name} is replaced with ReLU6 operator")
-                break
+                continue
 
         graph.eliminate_dead_code()
         graph.lint()
tico/passes/decompose_fake_quantize_tensor_qparams.py CHANGED
@@ -245,9 +245,10 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
             # mask_user(output).args == (dequantize_per_tensor.tensor, mask)
             if mask:
                 assert len(mask) == 1
-                mask_user = list(mask[0].users.keys())[0]
-                assert len(mask_user.args) == 1
-                mask_user.args = ((mask_user.args[0][0],),)
+                if len(mask[0].users) > 0:
+                    mask_user = list(mask[0].users.keys())[0]
+                    assert len(mask_user.args) == 1
+                    mask_user.args = ((mask_user.args[0][0],),)
                 modified = True
             if (
                 node.target
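The added guard matters when the mask node has no remaining users (for example, after dead-code elimination); the old code then raised an IndexError. A minimal sketch of the failure mode:

    users: dict = {}  # a users mapping left empty
    # Old behavior: list(users.keys())[0] -> IndexError: list index out of range
    # New behavior: skip the argument rewrite entirely.
    if len(users) > 0:
        mask_user = list(users.keys())[0]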
tico/passes/ops.py CHANGED
@@ -69,7 +69,6 @@ class AtenOps:
             torch.ops.aten.unsqueeze_copy.default,
         ]
         self.view = [
-            torch.ops.aten.view,
            torch.ops.aten.view.default,
            torch.ops.aten.view_copy.default,
        ]
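The removed entry was the overload packet rather than a concrete overload; graphs produced by torch.export reference concrete overloads only, so the packet entry appears to have been redundant. A quick check of the distinction:

    import torch

    print(type(torch.ops.aten.view))          # <class 'torch._ops.OpOverloadPacket'>
    print(type(torch.ops.aten.view.default))  # <class 'torch._ops.OpOverload'>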
tico/passes/remove_redundant_expand.py CHANGED
@@ -46,7 +46,9 @@ class RemoveRedundantExpand(PassBase):
             input, size = args.input, args.size
 
             input_shape = extract_shape(input)
-            if list(input_shape) != size:
+            output_shape = extract_shape(node)
+
+            if input_shape != output_shape:
                 continue
 
             node.replace_all_uses_with(input, propagate_meta=False)
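Comparing against the node's actual output shape is more robust than comparing against the size argument, which may contain -1 placeholders that never match a concrete shape list. An illustrative sketch of an expand the old check missed:

    import torch

    x = torch.randn(2, 3, 4)
    y = x.expand(-1, -1, -1)  # redundant: output shape equals input shape

    # Old check: list(x.shape) != [-1, -1, -1] is True, so the node was kept.
    # New check: the input and output shapes match, so the node is removed.
    assert x.shape == y.shape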
tico/quantization/__init__.py ADDED
@@ -0,0 +1,6 @@
+from tico.quantization.public_interface import convert, prepare
+
+__all__ = [
+    "convert",
+    "prepare",
+]
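This makes tico.quantization the import path for the public quantization entry points that previously lived under tico.experimental.quantization (whose exports are removed in this release; see tico/experimental/quantization/__init__.py above). A migration sketch:

    # Before this release:
    #   from tico.experimental.quantization import convert, prepare
    # From this release on:
    from tico.quantization import convert, prepare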