tico 0.1.0.dev250722__py3-none-any.whl → 0.1.0.dev250724__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/experimental/quantization/__init__.py +5 -0
  4. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +1 -6
  5. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +1 -1
  6. tico/experimental/quantization/algorithm/pt2e/utils.py +0 -1
  7. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +1 -1
  8. tico/experimental/quantization/evaluation/evaluate.py +1 -1
  9. tico/experimental/quantization/passes/fold_quant_ops.py +0 -1
  10. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +1 -1
  11. tico/experimental/quantization/passes/quantize_bias.py +0 -1
  12. tico/experimental/quantization/passes/remove_weight_dequant_op.py +1 -1
  13. tico/passes/cast_aten_where_arg_type.py +1 -1
  14. tico/passes/cast_mixed_type_args.py +2 -2
  15. tico/passes/const_prop_pass.py +1 -1
  16. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  17. tico/passes/decompose_addmm.py +0 -3
  18. tico/passes/decompose_batch_norm.py +2 -2
  19. tico/passes/decompose_fake_quantize.py +0 -3
  20. tico/passes/decompose_fake_quantize_tensor_qparams.py +0 -2
  21. tico/passes/decompose_group_norm.py +0 -3
  22. tico/passes/legalize_predefined_layout_operators.py +2 -11
  23. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  24. tico/passes/lower_to_slice.py +1 -1
  25. tico/passes/merge_consecutive_cat.py +1 -1
  26. tico/passes/remove_redundant_expand.py +0 -5
  27. tico/passes/remove_redundant_reshape.py +5 -5
  28. tico/passes/segment_index_select.py +1 -1
  29. tico/serialize/circle_graph.py +1 -1
  30. tico/serialize/circle_serializer.py +234 -141
  31. tico/serialize/operators/op_any.py +0 -3
  32. tico/serialize/operators/op_clamp.py +2 -5
  33. tico/serialize/operators/op_full_like.py +0 -2
  34. tico/serialize/operators/op_instance_norm.py +0 -6
  35. tico/serialize/operators/op_mul.py +2 -8
  36. tico/serialize/operators/op_transpose_conv.py +0 -2
  37. tico/serialize/quant_param.py +5 -5
  38. tico/utils/convert.py +1 -1
  39. tico/utils/graph.py +1 -1
  40. tico/utils/padding.py +0 -2
  41. tico/utils/serialize.py +0 -3
  42. tico/utils/utils.py +1 -2
  43. tico/utils/validate_args_kwargs.py +1 -3
  44. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/METADATA +1 -1
  45. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/RECORD +49 -49
  46. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/LICENSE +0 -0
  47. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/WHEEL +0 -0
  48. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/entry_points.txt +0 -0
  49. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/top_level.txt +0 -0
tico/__init__.py CHANGED
@@ -20,8 +20,16 @@ from packaging.version import Version
  from tico.config import CompileConfigV1, get_default_config
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2

+ __all__ = [
+ "CompileConfigV1",
+ "get_default_config",
+ "convert",
+ "convert_from_exported_program",
+ "convert_from_pt2",
+ ]
+
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
- __version__ = "0.1.0.dev250722"
+ __version__ = "0.1.0.dev250724"

  MINIMUM_SUPPORTED_VERSION = "2.5.0"
  SECURE_TORCH_VERSION = "2.6.0"
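Note on the new __all__ list in tico/__init__.py: __all__ only controls which names a wildcard import re-exports; the explicit imports above are unaffected. A minimal usage sketch (the module and symbol names come from the diff above; everything else is illustrative, not package code):

    # Wildcard import now pulls in exactly the five declared names.
    from tico import *  # CompileConfigV1, get_default_config, convert, ...

    import tico
    print(tico.__version__)      # "0.1.0.dev250724" in this release
    print(sorted(tico.__all__))  # the five public symbols listed above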
tico/config/base.py CHANGED
@@ -31,7 +31,7 @@ class CompileConfigBase:
  config = cls()
  for key in config_dict:
  if key in config.to_dict():
- assert type(config.get(key)) == bool
+ assert isinstance(config.get(key), bool)
  config.set(key, config_dict[key])

  return config
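Background on the assert shown above (a standalone illustration, not package code): isinstance() is the idiomatic type check and also accepts subclasses, while an exact type comparison does not.

    flag = True
    assert isinstance(flag, bool)    # passes
    assert type(flag) == bool        # also passes for a plain bool

    class StrictFlag(int):           # hypothetical subclass, for illustration only
        pass

    value = StrictFlag(1)
    assert isinstance(value, int)    # passes: subclass instances count
    assert not (type(value) == int)  # exact-type comparison rejects them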
tico/experimental/quantization/__init__.py CHANGED
@@ -1 +1,6 @@
  from tico.experimental.quantization.public_interface import convert, prepare
+
+ __all__ = [
+ "convert",
+ "prepare",
+ ]
tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py CHANGED
@@ -21,12 +21,7 @@ if TYPE_CHECKING:
  import torch.fx
  from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
  import torch
- from torch.ao.quantization.observer import (
- MinMaxObserver,
- MovingAverageMinMaxObserver,
- MovingAveragePerChannelMinMaxObserver,
- PerChannelMinMaxObserver,
- )
+ from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
  from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
  from torch.ao.quantization.quantizer.utils import _get_module_name_filter

tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py CHANGED
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Callable, List, Optional, TYPE_CHECKING
+ from typing import Callable, Optional, TYPE_CHECKING

  if TYPE_CHECKING:
  import torch.fx
tico/experimental/quantization/algorithm/pt2e/utils.py CHANGED
@@ -19,7 +19,6 @@ if TYPE_CHECKING:
  import torch
  from torch.ao.quantization.quantizer import QuantizationSpec
  from torch.ao.quantization.quantizer.utils import _get_module_name_filter
- from torch.utils import _pytree as pytree

  from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
  QuantizationConfig,
tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py CHANGED
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, List, Optional
+ from typing import Dict, List, Optional

  import torch

tico/experimental/quantization/evaluation/evaluate.py CHANGED
@@ -166,7 +166,7 @@ def evaluate(
  )
  if not isinstance(backend, BACKEND):
  raise RuntimeError(
- f"Invalid backend. Please use tico.quantization.evaluate.BACKEND enum class"
+ "Invalid backend. Please use tico.quantization.evaluate.BACKEND enum class"
  )
  # Make it a list for simpler logic.
  if input_data is not None:
tico/experimental/quantization/passes/fold_quant_ops.py CHANGED
@@ -16,7 +16,6 @@ from typing import TYPE_CHECKING

  if TYPE_CHECKING:
  import torch.fx
- import copy

  import torch
  from torch.export import ExportedProgram
tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py CHANGED
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
  import copy

  from collections import defaultdict
- from typing import Any, Callable
+ from typing import Any

  import torch
  from torch.export import ExportedProgram
tico/experimental/quantization/passes/quantize_bias.py CHANGED
@@ -16,7 +16,6 @@ from typing import TYPE_CHECKING

  if TYPE_CHECKING:
  import torch.fx
- import copy

  import torch
  from torch.export import ExportedProgram
tico/experimental/quantization/passes/remove_weight_dequant_op.py CHANGED
@@ -53,7 +53,7 @@ class ValRange:
  if isinstance(val, torch.Tensor):
  self.max = torch.max(val).item()
  self.min = torch.min(val).item()
- elif type(val) == list:
+ elif isinstance(val, list):
  self.max = max(val)
  self.min = min(val)
  else:
tico/passes/cast_aten_where_arg_type.py CHANGED
@@ -176,7 +176,7 @@ class CastATenWhereArgType(PassBase):
  node_dtype = extract_torch_dtype(node)
  assert (
  node_dtype == node_dtype_ori
- ), f"Type casting doesn't change node's dtype."
+ ), "Type casting doesn't change node's dtype."

  logger.debug(
  f"{to_cast.name}'s dtype was casted from {buf_data.dtype} to {dtype_to_cast}"
tico/passes/cast_mixed_type_args.py CHANGED
@@ -124,7 +124,7 @@ class CastMixedTypeArgs(PassBase):
  if rhs_val.dtype == type_to_promote:
  ori_type = lhs_val.dtype
  arg_to_promote = lhs
- assert arg_to_promote != None
+ assert arg_to_promote is not None

  if isinstance(arg_to_promote, torch.fx.Node):
  with graph.inserting_after(arg_to_promote):
@@ -178,7 +178,7 @@ class CastMixedTypeArgs(PassBase):
  node_dtype = extract_torch_dtype(node)
  assert (
  node_dtype == node_dtype_ori
- ), f"Type casting doesn't change node's dtype."
+ ), "Type casting doesn't change node's dtype."

  graph.eliminate_dead_code()
  graph.lint()
tico/passes/const_prop_pass.py CHANGED
@@ -301,7 +301,7 @@ class ConstPropPass(PassBase):
  graph.eliminate_dead_code()
  graph_module.recompile()

- logger.debug(f"Constant nodes are propagated")
+ logger.debug("Constant nodes are propagated")
  # Constant folding can be done with only one time run. Let's set `modified` to False.
  modified = False
  return PassResult(modified)
tico/passes/convert_conv1d_to_conv2d.py CHANGED
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
  import torch
  from torch.export import ExportedProgram

- from tico.serialize.circle_graph import extract_shape
+ from tico.serialize.circle_mapping import extract_shape
  from tico.utils import logging
  from tico.utils.errors import NotYetSupportedError
  from tico.utils.graph import create_node
tico/passes/decompose_addmm.py CHANGED
@@ -20,7 +20,6 @@ import torch
  from torch.export import ExportedProgram

  from tico.serialize.circle_mapping import extract_shape
- from tico.utils import logging
  from tico.utils.graph import add_placeholder, create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -59,8 +58,6 @@ class DecomposeAddmm(PassBase):
  super().__init__()

  def call(self, exported_program: ExportedProgram) -> PassResult:
- logger = logging.getLogger(__name__)
-
  gm = exported_program.graph_module
  graph: torch.fx.Graph = gm.graph
  modified = False
tico/passes/decompose_batch_norm.py CHANGED
@@ -96,9 +96,9 @@ class DecomposeBatchNorm(PassBase):
  eps = args.eps

  if not running_mean:
- raise NotYetSupportedError(f"running_mean=None is not supported yet")
+ raise NotYetSupportedError("running_mean=None is not supported yet")
  if not running_var:
- raise NotYetSupportedError(f"running_var=None is not supported yet")
+ raise NotYetSupportedError("running_var=None is not supported yet")

  """
  Only support the cases generated from torch.nn.BatchNorm2d module,
tico/passes/decompose_fake_quantize.py CHANGED
@@ -19,10 +19,8 @@ if TYPE_CHECKING:
  import torch

  # To import torch.ops.quantized_decomposed related operator
- from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
  from torch.export import ExportedProgram

- from tico.utils import logging
  from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -66,7 +64,6 @@ class DecomposeFakeQuantize(PassBase):
  super().__init__()

  def call(self, exported_program: ExportedProgram) -> PassResult:
- logger = logging.getLogger(__name__)
  modified = False

  gm = exported_program.graph_module
tico/passes/decompose_fake_quantize_tensor_qparams.py CHANGED
@@ -26,10 +26,8 @@ from torch._export.utils import (
  )

  # To import torch.ops.quantized_decomposed related operator
- from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
  from torch.export import ExportedProgram

- from tico.utils import logging
  from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import (
tico/passes/decompose_group_norm.py CHANGED
@@ -22,7 +22,6 @@ import torch
  from torch.export import ExportedProgram

  from tico.serialize.circle_mapping import extract_shape
- from tico.utils import logging
  from tico.utils.graph import create_node
  from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -126,8 +125,6 @@ class DecomposeGroupNorm(PassBase):
  )

  def call(self, exported_program: ExportedProgram) -> PassResult:
- logger = logging.getLogger(__name__)
-
  gm = exported_program.graph_module
  graph: torch.fx.Graph = gm.graph
  modified = False
tico/passes/legalize_predefined_layout_operators.py CHANGED
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
  import torch
  from torch.export import ExportedProgram

- from tico.serialize.circle_graph import extract_shape
+ from tico.serialize.circle_mapping import extract_shape
  from tico.utils import logging
  from tico.utils.errors import NotYetSupportedError
  from tico.utils.graph import create_node
@@ -206,7 +206,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):

  args = ConvTranspose2DArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
  input = args.input
- padding = args.padding
  groups = args.groups
  dilation = args.dilation

@@ -288,13 +287,12 @@ class LegalizePreDefinedLayoutOperators(PassBase):
  input = args.input
  weight = args.weight
  bias = args.bias
- eps = args.eps

  running_mean = args.running_mean
  running_var = args.running_var
  use_input_stats = args.use_input_stats

- if not (use_input_stats == True):
+ if not use_input_stats:
  raise NotYetSupportedError("Only support use_input_stats is True.")
  if not isinstance(running_mean, NoneType):
  raise NotYetSupportedError("Only support running_mean=None")
@@ -350,10 +348,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
  # max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
  args = MaxPool2dWithIndicesArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
  input_ = args.input
- kernel_size = args.kernel_size
- stride = args.stride
- padding = args.padding
- dilation = args.dilation
  ceil_mode = args.ceil_mode
  if ceil_mode:
  raise NotYetSupportedError("Only support non-ceil model.")
@@ -402,9 +396,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
  # avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
  args = AvgPool2dArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
  input_ = args.input
- kernel_size = args.kernel_size
- stride = args.stride
- padding = args.padding
  ceil_mode = args.ceil_mode
  if ceil_mode:
  raise NotYetSupportedError("Only support non-ceil model.")
tico/passes/lower_to_resize_nearest_neighbor.py CHANGED
@@ -67,7 +67,7 @@ class LowerToResizeNearestNeighbor(PassBase):
  return None
  # indices = [None, None, H index, W index]
  N, C, H, W = indices
- if N != None or C != None:
+ if N is not None or C is not None:
  return None
  if not isinstance(H, torch.fx.Node):
  return None
tico/passes/lower_to_slice.py CHANGED
@@ -28,7 +28,7 @@ from torch._export.utils import (
  from torch.export import ExportedProgram

  from tico.passes import ops
- from tico.serialize.circle_graph import extract_shape
+ from tico.serialize.circle_mapping import extract_shape
  from tico.utils import logging
  from tico.utils.graph import create_node, is_single_value_tensor
  from tico.utils.passes import PassBase, PassResult
tico/passes/merge_consecutive_cat.py CHANGED
@@ -51,7 +51,7 @@ class MergeConsecutiveCat(PassBase):
  if not prev_cat.op == "call_function":
  continue

- if not prev_cat.target in ops.aten.cat:
+ if prev_cat.target not in ops.aten.cat:
  continue

  prev_args = CatArgs(*prev_cat.args, **prev_cat.kwargs) # type: ignore[arg-type]
tico/passes/remove_redundant_expand.py CHANGED
@@ -12,11 +12,6 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING
-
- if TYPE_CHECKING:
- import torch.fx
- import torch
  from torch.export import ExportedProgram

  from tico.passes import ops
tico/passes/remove_redundant_reshape.py CHANGED
@@ -90,7 +90,7 @@ class RemoveRedundantReshapePattern1(PassBase):
  if len(permute.users) != 1:
  continue
  permute_args = PermuteArgs(*permute.args, **permute.kwargs) # type: ignore[arg-type]
- permute_input, permute_dims = permute_args.input, permute_args.dims
+ permute_dims = permute_args.dims
  # (1xAxBxC) - `aten.permute` - (1xAxCxB)
  if permute_dims != [0, 1, 3, 2]:
  continue
@@ -172,7 +172,7 @@ class RemoveRedundantReshapePattern2(PassBase):
  if len(permute.users) != 1:
  continue
  permute_args = PermuteArgs(*permute.args, **permute.kwargs) # type: ignore[arg-type]
- permute_input, permute_dims = permute_args.input, permute_args.dims
+ permute_dims = permute_args.dims
  # (1xAxBxC) - `aten.permute` - (Bx1xAxC)
  if permute_dims != [2, 0, 1, 3]:
  continue
@@ -262,7 +262,7 @@ class RemoveRedundantReshapePattern3(PassBase):
  continue

  # add
- if not add.target in ops.aten.add:
+ if add.target not in ops.aten.add:
  continue
  add_args = AddTensorArgs(*add.args, **add.kwargs) # type: ignore[arg-type]
  reshape_2, reshape_3 = add_args.input, add_args.other
@@ -272,7 +272,7 @@ class RemoveRedundantReshapePattern3(PassBase):
  # reshape_2
  if not reshape_2.op == "call_function":
  continue
- if not reshape_2.target in ops.aten.reshape:
+ if reshape_2.target not in ops.aten.reshape:
  continue
  reshape_2_args = ReshapeArgs(*reshape_2.args, **reshape_2.kwargs) # type: ignore[arg-type]
  reshape_2_input = reshape_2_args.input
@@ -280,7 +280,7 @@ class RemoveRedundantReshapePattern3(PassBase):
  # reshape_3
  if not reshape_3.op == "call_function":
  continue
- if not reshape_3.target in ops.aten.reshape:
+ if reshape_3.target not in ops.aten.reshape:
  continue
  reshape_3_args = ReshapeArgs(*reshape_3.args, **reshape_3.kwargs) # type: ignore[arg-type]
  reshape_3_input = reshape_3_args.input
tico/passes/segment_index_select.py CHANGED
@@ -29,7 +29,7 @@ from torch._export.utils import (
  from torch.export import ExportedProgram

  from tico.passes import ops
- from tico.serialize.circle_graph import extract_shape
+ from tico.serialize.circle_mapping import extract_shape
  from tico.utils import logging
  from tico.utils.graph import add_placeholder, create_node, is_single_value_tensor
  from tico.utils.passes import PassBase, PassResult
tico/serialize/circle_graph.py CHANGED
@@ -323,7 +323,7 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
  self, node: Union[torch.fx.Node, circle.Tensor.TensorT, ConstData]
  ) -> int:
  # return -1 if node is None. This is for generating CircleOutputExclude
- if node == None:
+ if node is None:
  return -1

  if hasattr(node, "name") and node.name in self.name_to_tid:
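Several hunks above (cast_mixed_type_args.py, lower_to_resize_nearest_neighbor.py, circle_graph.py) switch None checks from equality to identity tests. A standalone sketch of why "is None" is the safer idiom (hypothetical class, not package code):

    class SymbolicValue:
        # Hypothetical: __eq__ returns an expression object instead of a bool.
        def __eq__(self, other):
            return ("eq", other)

    v = SymbolicValue()
    print(v == None)     # ('eq', None) -- a truthy non-bool, easy to misread
    print(v is None)     # False -- identity test always yields a plain bool
    print(None is None)  # True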