tico 0.1.0.dev250812__py3-none-any.whl → 0.1.0.dev250814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -29,7 +29,7 @@ __all__ = [
 ]
 
 # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
-__version__ = "0.1.0.dev250812"
+__version__ = "0.1.0.dev250814"
 
 MINIMUM_SUPPORTED_VERSION = "2.5.0"
 SECURE_TORCH_VERSION = "2.6.0"
tico/experimental/quantization/ptq/observers/__init__.py ADDED
@@ -0,0 +1,13 @@
+from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
+from tico.experimental.quantization.ptq.observers.base import ObserverBase
+from tico.experimental.quantization.ptq.observers.ema import EMAObserver
+from tico.experimental.quantization.ptq.observers.identity import IdentityObserver
+from tico.experimental.quantization.ptq.observers.minmax import MinMaxObserver
+
+__all__ = [
+    "AffineObserverBase",
+    "ObserverBase",
+    "EMAObserver",
+    "IdentityObserver",
+    "MinMaxObserver",
+]
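
The re-exports above flatten the package namespace so downstream code can import observers directly. A minimal usage sketch, assuming the wheel is installed:

    from tico.experimental.quantization.ptq.observers import MinMaxObserver

    # Keyword-only constructor; dtype defaults to UINT8, qscheme to PER_TENSOR_ASYMM.
    obs = MinMaxObserver(name="act0")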
tico/experimental/quantization/ptq/observers/affine_base.py ADDED
@@ -0,0 +1,128 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Tuple
+
+import torch
+
+from tico.experimental.quantization.ptq.dtypes import DType, UINT8
+from tico.experimental.quantization.ptq.observers.base import ObserverBase
+from tico.experimental.quantization.ptq.qscheme import QScheme
+
+
+class AffineObserverBase(ObserverBase):
+    """Base for affine observers (min/max → scale/zp)."""
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        dtype: DType = UINT8,
+        qscheme: QScheme = QScheme.PER_TENSOR_ASYMM,
+        channel_axis: Optional[int] = None,
+    ):
+        super().__init__(
+            name=name, dtype=dtype, qscheme=qscheme, channel_axis=channel_axis
+        )
+
+    def reset(self) -> None:
+        """
+        Reset running min/max and drop cached qparams.
+        """
+        self.min_val: torch.Tensor = torch.tensor(math.inf)
+        self.max_val: torch.Tensor = torch.tensor(-math.inf)
+        if hasattr(self, "_cached_scale"):
+            del self._cached_scale
+        if hasattr(self, "_cached_zp"):
+            del self._cached_zp
+
+    def load_qparams(self, scale: torch.Tensor, zp: torch.Tensor, *, lock: bool = True):
+        """
+        Inject externally computed qparams and optionally lock the observer.
+
+        When locked, subsequent `collect()` calls are ignored.
+        """
+        self._cached_scale = scale.detach()
+        self._cached_zp = zp.to(torch.int)
+        if lock:
+            self.enabled = False
+
+    @property
+    def has_qparams(self) -> bool:
+        return hasattr(self, "_cached_scale")
+
+    def compute_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        qmin, qmax = self.dtype.qmin, self.dtype.qmax
+        rng = self.max_val - self.min_val
+        eps = 1e-12
+
+        if self.qscheme.is_symmetric():
+            max_abs = torch.maximum(self.max_val.abs(), self.min_val.abs())
+            scale = torch.clamp(max_abs, min=eps) / qmax
+            zp = torch.zeros_like(scale, dtype=torch.int)
+            self._cached_scale, self._cached_zp = scale, zp
+            return scale, zp
+
+        if self.channel_axis is None:
+            if torch.all(rng.abs() < 1e-8):
+                C = self.min_val
+                if torch.allclose(C, torch.zeros_like(C)):
+                    scale = torch.ones_like(C)
+                    zp = torch.zeros_like(C, dtype=torch.int)
+                elif (C > 0).all():
+                    scale = torch.clamp(C, min=eps)
+                    zp = torch.zeros_like(C, dtype=torch.int)
+                else:
+                    scale = torch.clamp(C.abs(), min=eps)
+                    zp = torch.full_like(C, qmax, dtype=torch.int)
+            else:
+                scale = torch.clamp(rng, min=eps) / (qmax - qmin)
+                zp = (
+                    torch.round(qmin - self.min_val / scale)
+                    .clamp(qmin, qmax)
+                    .to(torch.int)
+                )
+        else:
+            scale = torch.clamp(rng, min=eps) / (qmax - qmin)
+            zp = (
+                torch.round(qmin - self.min_val / scale).clamp(qmin, qmax).to(torch.int)
+            )
+
+        self._cached_scale, self._cached_zp = scale, zp
+        return scale, zp
+
+    def fake_quant(self, x: torch.Tensor) -> torch.Tensor:
+        if not self.has_qparams:
+            raise RuntimeError(
+                "Call compute_qparams()/freeze_qparams() or load_qparams() first."
+            )
+        scale, zp = self._cached_scale, self._cached_zp
+        if self.channel_axis is None:
+            return torch.fake_quantize_per_tensor_affine(
+                x,
+                scale=scale,
+                zero_point=zp,
+                quant_min=self.dtype.qmin,
+                quant_max=self.dtype.qmax,
+            )
+        else:
+            return torch.fake_quantize_per_channel_affine(
+                x,
+                scale=scale,
+                zero_point=zp,
+                axis=self.channel_axis,
+                quant_min=self.dtype.qmin,
+                quant_max=self.dtype.qmax,
+            )
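
For the per-tensor asymmetric path above, scale and zero-point follow the usual affine mapping. A worked example with UINT8 (qmin=0, qmax=255), assuming an observed range of [-1.0, 3.0]:

    qmin, qmax = 0, 255
    min_val, max_val = -1.0, 3.0

    scale = (max_val - min_val) / (qmax - qmin)  # 4.0 / 255 ≈ 0.01569
    zp = round(qmin - min_val / scale)           # round(63.75) = 64

    # Sanity check: the real value 0.0 lands exactly on integer grid point zp,
    # since quantize(x) = round(x / scale) + zp.
    assert round(0.0 / scale) + zp == 64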
tico/experimental/quantization/ptq/observers/base.py ADDED
@@ -0,0 +1,98 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+
+import torch
+
+from tico.experimental.quantization.ptq.dtypes import DType, UINT8
+from tico.experimental.quantization.ptq.qscheme import QScheme
+
+
+class ObserverBase(ABC):
+    """
+    Minimal abstract base for all observers/quantizers.
+
+    Subclasses must implement:
+      - reset()
+      - collect(x)
+      - fake_quant(x)
+      - compute_qparams(): optional in practice for some observers (e.g., MX),
+        but still part of the interface; those can return None.
+    """
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        dtype: DType = UINT8,
+        qscheme: QScheme = QScheme.PER_TENSOR_ASYMM,
+        channel_axis: Optional[int] = None,  # None → per-tensor
+    ):
+        self.name = name
+        self.dtype = dtype
+        self.qscheme = qscheme
+        self.channel_axis = channel_axis if qscheme.is_per_channel() else None
+        self.enabled = True
+        self.reset()
+
+    @abstractmethod
+    def reset(self) -> None:
+        """Clear any running statistics or cached params."""
+        raise NotImplementedError
+
+    def collect(self, x: torch.Tensor) -> None:
+        """
+        Update running statistics with a new batch of data.
+
+        This base implementation guards on `enabled` and then calls `_update_stats(x)`.
+        Subclasses should implement `_update_stats(x)` instead of overriding `collect`.
+        """
+        if not self.enabled:
+            return
+        self._update_stats(x)
+
+    @abstractmethod
+    def _update_stats(self, x: torch.Tensor) -> None:
+        """
+        Update running statistics (min/max, hist, mse buffers, ...).
+
+        Must be implemented by subclasses (e.g., MinMax, EMA, Histogram, MSE).
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def fake_quant(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Apply the observer's quantization.
+        Implementations may or may not rely on qparams.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def compute_qparams(self) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Compute and (if applicable) cache quantization params.
+        Affine observers typically return (scale, zero_point).
+        Observers that do not use qparams (e.g., MX) may return None.
+        """
+        raise NotImplementedError
+
+    # String repr helps debugging
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(name={self.name}, dtype={str(self.dtype)}, "
+            f"qscheme={str(self.qscheme)}, channel_axis={self.channel_axis}, enabled={self.enabled})"
+        )
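
The base class fixes the lifecycle: `collect()` checks `enabled` and delegates to the abstract `_update_stats()`. A minimal hypothetical subclass, written here only to illustrate the contract (AbsMaxObserver is not part of the package, and it assumes a signed dtype so that qmin < 0 makes the symmetric range meaningful):

    import torch

    from tico.experimental.quantization.ptq.observers.base import ObserverBase

    class AbsMaxObserver(ObserverBase):
        def reset(self) -> None:
            # Called by ObserverBase.__init__, so state is initialized here.
            self.max_abs = torch.tensor(0.0)

        def _update_stats(self, x: torch.Tensor) -> None:
            self.max_abs = torch.maximum(self.max_abs, x.abs().max())

        def compute_qparams(self):
            # Symmetric: zero-point 0, scale spans [-max_abs, +max_abs].
            scale = max(self.max_abs.item() / self.dtype.qmax, 1e-12)
            return torch.tensor(scale), torch.tensor(0, dtype=torch.int)

        def fake_quant(self, x: torch.Tensor) -> torch.Tensor:
            scale, zp = self.compute_qparams()
            return torch.fake_quantize_per_tensor_affine(
                x, scale.item(), int(zp), self.dtype.qmin, self.dtype.qmax
            )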
tico/experimental/quantization/ptq/observers/ema.py ADDED
@@ -0,0 +1,62 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
+from tico.experimental.quantization.ptq.utils import channelwise_minmax
+
+
+class EMAObserver(AffineObserverBase):
+    """
+    Exponential-Moving-Average min/max tracker.
+
+    Why?
+    -----
+    • Smoother than raw MinMax (reduces outlier shock).
+    • Much cheaper than histogram/MSE observers.
+
+    The update rule follows the common "momentum" form:
+
+        ema = momentum * ema + (1 - momentum) * new_value
+
+    With momentum → 0: *fast* adaptation, momentum → 1: *slow* adaptation.
+    """
+
+    def __init__(
+        self,
+        *,
+        momentum: float = 0.9,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        assert 0.0 < momentum < 1.0, "momentum must be in (0, 1)"
+        self.momentum = momentum
+
+    @torch.no_grad()
+    def _update_stats(self, x: torch.Tensor):
+        if self.channel_axis is None:
+            curr_min, curr_max = x.min(), x.max()
+        else:
+            curr_min, curr_max = channelwise_minmax(x, self.channel_axis)
+
+        if (
+            torch.isinf(self.min_val).any() and torch.isinf(self.max_val).any()
+        ):  # first batch → hard init
+            self.min_val, self.max_val = curr_min, curr_max
+            return
+
+        m = self.momentum
+        self.min_val = m * self.min_val + (1 - m) * curr_min
+        self.max_val = m * self.max_val + (1 - m) * curr_max
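
With the default momentum of 0.9, each batch moves the tracked range by only 10% of the gap to the new statistic. A quick numeric check of the update rule:

    momentum = 0.9
    ema_max, batch_max = 4.0, 6.0

    ema_max = momentum * ema_max + (1 - momentum) * batch_max
    # 0.9 * 4.0 + 0.1 * 6.0 = 4.2 → a single outlier shifts the range only slightly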
tico/experimental/quantization/ptq/observers/identity.py ADDED
@@ -0,0 +1,74 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+IdentityObserver: a *no-op* observer for FP-only modules.
+
+Motivation
+----------
+Some layers should stay in full precision even when the rest of the model
+is quantized. Attaching an `IdentityObserver` satisfies the wrapper API
+(`_update_stats()`, `compute_qparams()`, `fake_quant()`) without actually
+performing any statistics gathering or fake-quantization.
+"""
+import torch
+
+from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
+
+
+class IdentityObserver(AffineObserverBase):
+    """
+    Passthrough observer that **never** alters the tensor.
+
+    • `_update_stats()`   → does nothing
+    • `compute_qparams()` → returns (1.0, 0) *dummy* q-params
+    • `fake_quant()`      → returns *x* unchanged
+    """
+
+    def __init__(self, **kwargs):
+        # Call parent so the usual fields (`dtype`, `qscheme`, …) exist,
+        # but immediately disable any stateful behaviour.
+        super().__init__(**kwargs)
+
+        # Deactivate statistics collection permanently.
+        self.enabled = False
+
+        # Pre-cache sentinel q-params so wrapper code that blindly
+        # accesses them won't crash.
+        self._cached_scale = torch.tensor(1.0)
+        self._cached_zp = torch.tensor(0, dtype=torch.int)
+
+    def reset(self) -> None:  # (simple override – nothing to do)
+        """No internal state to reset."""
+        pass
+
+    def _update_stats(self, x: torch.Tensor) -> None:
+        """Skip statistic collection entirely."""
+        return
+
+    def compute_qparams(self):
+        """
+        Return the pre-cached (scale, zero_point) tuple.
+
+        Keeping the signature identical to other observers allows uniform
+        lifecycle management in wrapper code.
+        """
+        return self._cached_scale, self._cached_zp
+
+    def fake_quant(self, x: torch.Tensor):
+        """Identity mapping — leaves *x* in FP."""
+        return x
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
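
In use the observer is inert end to end: collection is disabled, fake_quant is the identity, and the dummy qparams keep wrapper code happy. A small sketch:

    import torch

    from tico.experimental.quantization.ptq.observers import IdentityObserver

    obs = IdentityObserver(name="fp_layer")
    x = torch.randn(4, 8)

    obs.collect(x)                             # no-op: observer is disabled
    assert torch.equal(obs.fake_quant(x), x)   # tensor passes through untouched
    scale, zp = obs.compute_qparams()          # (tensor(1.), tensor(0, dtype=int32))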
tico/experimental/quantization/ptq/observers/minmax.py ADDED
@@ -0,0 +1,39 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
+from tico.experimental.quantization.ptq.utils import channelwise_minmax
+
+
+class MinMaxObserver(AffineObserverBase):
+    """Plain min/max range tracker."""
+
+    @torch.no_grad()
+    def _update_stats(self, x: torch.Tensor) -> None:
+        """
+        Update running min/max with the incoming batch.
+
+        Per-tensor: use global min/max.
+        Per-channel: reduce all axes except the channel axis.
+        """
+        if self.channel_axis is None:
+            curr_min, curr_max = x.min(), x.max()
+        else:
+            curr_min, curr_max = channelwise_minmax(x, self.channel_axis)
+
+        # Broadcasting handles scalar-vs-vector cases
+        self.min_val = torch.minimum(self.min_val, curr_min)
+        self.max_val = torch.maximum(self.max_val, curr_max)
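
Unlike the EMA observer, this range only ever widens: the tracked (min, max) is the running union over all batches. A short sketch:

    import torch

    from tico.experimental.quantization.ptq.observers import MinMaxObserver

    obs = MinMaxObserver(name="act")
    obs.collect(torch.tensor([-1.0, 2.0]))
    obs.collect(torch.tensor([0.5, 3.0]))

    print(obs.min_val, obs.max_val)    # tensor(-1.) tensor(3.)
    scale, zp = obs.compute_qparams()  # affine qparams derived from [-1, 3]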
tico/experimental/quantization/ptq/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
+from tico.experimental.quantization.ptq.utils.reduce_utils import channelwise_minmax
+
+__all__ = [
+    "channelwise_minmax",
+]
tico/experimental/quantization/ptq/utils/reduce_utils.py ADDED
@@ -0,0 +1,25 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def channelwise_minmax(x: torch.Tensor, channel_axis: int):
+    """
+    Compute per-channel (min, max) by reducing all axes except `channel_axis`.
+    """
+    channel_axis = channel_axis % x.ndim  # handle negative indices safely
+    dims = tuple(d for d in range(x.ndim) if d != channel_axis)
+
+    return x.amin(dim=dims), x.amax(dim=dims)
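
The reduction keeps only `channel_axis`, so an NCHW activation with channel_axis=1 yields one (min, max) pair per channel:

    import torch

    from tico.experimental.quantization.ptq.utils import channelwise_minmax

    x = torch.randn(2, 3, 4, 4)            # N, C, H, W
    mins, maxs = channelwise_minmax(x, 1)  # reduce over N, H, W

    print(mins.shape, maxs.shape)          # torch.Size([3]) torch.Size([3])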
tico/passes/decompose_fake_quantize_tensor_qparams.py CHANGED
@@ -244,7 +244,7 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
             # So, let's remove `mask` from the output.args first.
             # mask_user(output).args == (dequantize_per_tensor.tensor, mask)
             if mask:
-                len(mask) == 1
+                assert len(mask) == 1
                 mask_user = list(mask[0].users.keys())[0]
                 assert len(mask_user.args) == 1
                 mask_user.args = ((mask_user.args[0][0],),)
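
The old line was a bare comparison: Python evaluates `len(mask) == 1` and discards the result, so the check never fired. Prefixing `assert` makes the invariant enforceable:

    mask = [object()]        # stand-in for a single fx node

    len(mask) == 1           # True, but the result is silently thrown away
    assert len(mask) == 1    # raises AssertionError if the invariant breaks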
tico/utils/register_custom_op.py CHANGED
@@ -707,7 +707,7 @@ def CircleRMSNorm():
     @custom_op("circle_custom::rms_norm", mutates_args=())
     def rms_norm(
         hidden_states: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
+        weight: torch.Tensor,
         eps: float = 1e-05,
     ) -> torch.Tensor:
         input_dtype = hidden_states.dtype
@@ -719,7 +719,7 @@ def CircleRMSNorm():
     @register_fake("circle_custom::rms_norm")
     def _(
         hidden_states: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
+        weight: torch.Tensor,
         eps: float = 1e-05,
     ) -> torch.Tensor:
         return hidden_states.new_empty(hidden_states.size())
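
Both the eager implementation and its fake (meta) registration drop the Optional default, so callers must now pass `weight` explicitly. For orientation, RMSNorm normalizes by the root-mean-square and then applies the weight; a standard reference sketch, not the package's exact implementation:

    import torch

    def rms_norm_reference(
        hidden_states: torch.Tensor,
        weight: torch.Tensor,
        eps: float = 1e-05,
    ) -> torch.Tensor:
        # Compute in fp32 for numerical stability, then cast back.
        input_dtype = hidden_states.dtype
        h = hidden_states.to(torch.float32)
        variance = h.pow(2).mean(-1, keepdim=True)
        h = h * torch.rsqrt(variance + eps)
        return (weight * h).to(input_dtype)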
tico/utils/utils.py CHANGED
@@ -79,73 +79,70 @@ def enforce_type(callable):
     def check_types(*args, **kwargs):
         parameters = dict(zip(spec.args, args))
         parameters.update(kwargs)
-        for name, value in parameters.items():
-            if name == "self":
-                # skip 'self' in spec.args
-                continue
-
-            assert (
-                name in spec.annotations
-            ), f"All parameter require type hints. {name} needs a type hint"
-
-            type_hint = spec.annotations[name]
 
-            # Return tuple of flattened types.
-            # Q) What is flatten?
-            # A) Optional/Union is not included. Below are included.
-            #    collections: List, Set, ...
-            #    primitive types: int, str, ...
-            def _flatten_type(type_hint) -> tuple:
-                # `get_origin` maps Union[...] and Optional[...] varieties to Union
-                if typing.get_origin(type_hint) == typing.Union:
-                    # ex. typing.Union[list, int] -> (list, int)
-                    # ex. typing.Optional[torch.fx.Node] -> (torch.fx.Node, NoneType)
-                    actual_type = tuple(
-                        [_flatten_type(t) for t in typing.get_args(type_hint)]
-                    )
-                else:
-                    actual_type = (type_hint,)
-                return actual_type
+        # Return tuple of flattened types.
+        # Q) What is flatten?
+        # A) Optional/Union is not included. Below are included.
+        #    collections: List, Set, ...
+        #    primitive types: int, str, ...
+        def _flatten_type(type_hint) -> tuple:
+            # `get_origin` maps Union[...] and Optional[...] varieties to Union
+            if typing.get_origin(type_hint) == typing.Union:
+                # ex. typing.Union[list, int] -> (list, int)
+                # ex. typing.Optional[torch.fx.Node] -> (torch.fx.Node, NoneType)
+                actual_type = tuple(
+                    _flatten_type(t) for t in typing.get_args(type_hint)
+                )
+            else:
+                actual_type = (type_hint,)
+            return actual_type
 
-            type_hint = _flatten_type(type_hint)
+        # Return true if value matches with type_hint
+        # Return false otherwise
+        def _check_type(value, type_hint):
+            if type_hint == typing.Any:
+                return True
 
-            # Return true if value matches with type_hint
-            # Return false otherwise
-            def _check_type(value, type_hint):
-                if type_hint == typing.Any:
-                    return True
+            if isinstance(type_hint, tuple):
+                return any(_check_type(value, t) for t in type_hint)
 
-                if isinstance(type_hint, tuple):
-                    return any([_check_type(value, t) for t in type_hint])
+            if typing.get_origin(type_hint) in (list, set):
+                if not isinstance(value, typing.get_origin(type_hint)):
+                    return False
 
-                if typing.get_origin(type_hint) in (list, set):
-                    if not isinstance(value, typing.get_origin(type_hint)):
+                for v in value:
+                    if not any(_check_type(v, t) for t in typing.get_args(type_hint)):
                         return False
 
-                    for v in value:
-                        if not any(
-                            [_check_type(v, t) for t in typing.get_args(type_hint)]
-                        ):
-                            return False
+                return True
 
-                    return True
+            if typing.get_origin(type_hint) is dict:
+                if not isinstance(value, typing.get_origin(type_hint)):
+                    return False
 
-                if typing.get_origin(type_hint) is dict:
-                    if not isinstance(value, typing.get_origin(type_hint)):
+                for k, v in value.items():
+                    k_type, v_type = typing.get_args(type_hint)
+                    if not _check_type(k, k_type):
+                        return False
+                    if not _check_type(v, v_type):
                         return False
 
-                    for k, v in value.items():
-                        k_type, v_type = typing.get_args(type_hint)
-                        if not _check_type(k, k_type):
-                            return False
-                        if not _check_type(v, v_type):
-                            return False
+                return True
+
+            # TODO: Support more type hints
+            return isinstance(value, type_hint)
 
-                    return True
+        for name, value in parameters.items():
+            if name == "self":
+                # skip 'self' in spec.args
+                continue
 
-            # TODO: Support more type hints
-            return isinstance(value, type_hint)
+            assert (
+                name in spec.annotations
+            ), f"All parameter require type hints. {name} needs a type hint"
 
+            type_hint = spec.annotations[name]
+            type_hint = _flatten_type(type_hint)
             type_check_result = _check_type(value, type_hint)
             if not type_check_result:
                 raise ArgTypeError(
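
The refactor defines `_flatten_type` and `_check_type` once per call instead of once per parameter, and replaces list comprehensions inside any()/tuple() with generator expressions; validation behavior is unchanged. A usage sketch of the decorator, assuming it is imported from tico.utils.utils:

    from typing import Optional

    from tico.utils.utils import enforce_type

    @enforce_type
    def scale_values(values: list, factor: Optional[float]) -> list:
        return [v * factor for v in values] if factor else values

    scale_values([1.0, 2.0], 0.5)   # OK
    scale_values([1.0, 2.0], "x")   # raises ArgTypeError: str matches neither float nor None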
tico/utils/validate_args_kwargs.py CHANGED
@@ -175,13 +175,12 @@ class CatArgs:
 @dataclass
 class CircleRMSNormArgs:
     """
-    This is not aten ops but custom op for RMSNorm.
-    circle_custom.rms_norm(Tensor input, Tensor? weight=None, float? eps=None) -> Tensor
+    For circle.BuiltinOperator.BuiltinOperator.RMS_NORM
     """
 
     input: torch.fx.Node
-    weight: Optional[torch.fx.Node]
-    eps: Optional[float]
+    weight: torch.fx.Node
+    eps: float
 
 
 @enforce_type
tico-0.1.0.dev250812.dist-info/METADATA → tico-0.1.0.dev250814.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tico
-Version: 0.1.0.dev250812
+Version: 0.1.0.dev250814
 Summary: Convert exported Torch module to circle
 Home-page: UNKNOWN
 License: UNKNOWN
tico-0.1.0.dev250812.dist-info/RECORD → tico-0.1.0.dev250814.dist-info/RECORD RENAMED
@@ -1,4 +1,4 @@
-tico/__init__.py,sha256=QzZEOB1tLI8yPD6dzOftGhdVFbFkAintnJP1vv89AJQ,1883
+tico/__init__.py,sha256=Y5AdGv7QfIPC6o_P2M2S16JtXG5LLmtIfUU7gq6scYQ,1883
 tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
 tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
 tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
@@ -60,6 +60,14 @@ tico/experimental/quantization/ptq/__init__.py,sha256=ZoPdEwZ1i1n5pBFChx8GuUrkfR
 tico/experimental/quantization/ptq/dtypes.py,sha256=xfCBtq6mQmUYRwsoFgII6gvRl1raQi0Inj9pznDuKwQ,2236
 tico/experimental/quantization/ptq/mode.py,sha256=lT-T8vIv8YWcwrjT7xXVhOw1g7aoAdh_3PWB-ptPKaI,1052
 tico/experimental/quantization/ptq/qscheme.py,sha256=uwhv7bCxOOXB3I-IKlRyr_u4eXOq48uIqGy4TLDqGxY,1301
+tico/experimental/quantization/ptq/observers/__init__.py,sha256=wyrO0KTZve78aFWTwvsOE82Vu2kbCxJv8aqjiO1QL2s,524
+tico/experimental/quantization/ptq/observers/affine_base.py,sha256=e2Eba64nrxKQyE4F_WJ7WTSsk3xe6bkdGUKaoLFWGFw,4638
+tico/experimental/quantization/ptq/observers/base.py,sha256=Wons1MzpqK1mfcy-ppl-B2Dum0edXg2dWW2Lw3V18tw,3280
+tico/experimental/quantization/ptq/observers/ema.py,sha256=WZiYWEHrkgizAwnRCtfOm9JPHfZrjZTxMr6X9Wuovmo,2061
+tico/experimental/quantization/ptq/observers/identity.py,sha256=jdlNH52z8ANOZbs_0KFZ4iEstVfNC1OUzQsm1a9FFpM,2595
+tico/experimental/quantization/ptq/observers/minmax.py,sha256=mLHkwIzWFzQXev7EU7w1333KckwRjukc3_cUPJOnUfs,1486
+tico/experimental/quantization/ptq/utils/__init__.py,sha256=PL9IZgiWoMtsXVljeOy7KymmLVP238SXEFRLXYK72WQ,126
+tico/experimental/quantization/ptq/utils/reduce_utils.py,sha256=3kWawLB91EcvvHlCrNqqfZF7tpgr22htBSA049mKw_4,973
 tico/interpreter/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/interpreter/infer.py,sha256=1ZFe3DVMR2mlwBosoedqoL0-CGN_01CKLgMgxuw62KA,4861
 tico/interpreter/interpreter.py,sha256=tGbluCbrehTCqBu8mtGDNzby_ieJ2ry8_RH_eC0CQxk,3828
@@ -75,7 +83,7 @@ tico/passes/convert_to_relu6.py,sha256=1BJpUwUb6Zli_1y3eyJQo7dg9B1xvZ7sYjMbvEQsF
 tico/passes/decompose_addmm.py,sha256=KjnpZjSuA0uvNmKaTN_EMwobcOi3CAB81buORzTDxro,3979
 tico/passes/decompose_batch_norm.py,sha256=06LAxhSmpTxFZJmUelwB3I_GipNWrLoM7PfM6ZkxOZY,6512
 tico/passes/decompose_fake_quantize.py,sha256=736srs8SM8K_mLR0WG10LVMMLRkYkBM9OF0k1GCkAW0,5218
-tico/passes/decompose_fake_quantize_tensor_qparams.py,sha256=k9MJhMVABFNF6lXgEum1fJyGpdQwVRKxWOYhkMR2M7c,13915
+tico/passes/decompose_fake_quantize_tensor_qparams.py,sha256=p-sz_cgir4jMWp43CR75fj0TbLkNvSl888fvkRqFRtE,13922
 tico/passes/decompose_group_norm.py,sha256=6BqvYtMTPzeIgp8cPA8OFMwEBvb7odcg04IUgwtp7NQ,10120
 tico/passes/decompose_grouped_conv2d.py,sha256=n2qv320akL1ju33ucZ6lU1cKEAaj0NI8YZ5CrUnkRLM,8512
 tico/passes/decompose_slice_scatter.py,sha256=xqMHKhW2595YoAeubKZ4jRhYW4TQ09EXPgLNgODqXG8,5653
@@ -203,20 +211,20 @@ tico/utils/padding.py,sha256=qKke-dJeeLHiRaePjDS66txrGyiYuipLVQeqLYad8uk,3349
 tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
 tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,5219
 tico/utils/record_input.py,sha256=QN-8D71G_WAX3QQQ5CIwbEfFJZTQ3CvL4wCMiVddua4,3894
-tico/utils/register_custom_op.py,sha256=n91UtmPedoqhkR8fBNRbk9Msq79pn9DHNHlt99l2s_w,28142
+tico/utils/register_custom_op.py,sha256=dPemLyjrf4xMYCUlhhZeUhixL8Eat3Ywlv6K5kTqG8Y,28108
 tico/utils/serialize.py,sha256=mEuusEzi82WFsz3AkowgWwxSLeo50JDxyOj6yYDQhEI,1914
 tico/utils/signature.py,sha256=R2GV0alRpXEbZISqPKyxCUWbgDcsrQ2ovbVG3737IzA,9595
 tico/utils/torch_compat.py,sha256=oc6PztVsXdHcQ3iaVR90wLLxrGaj6zFHWZ8K9rRS6q8,1795
 tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
-tico/utils/utils.py,sha256=A5p3iAAxRGDsZJh4ybp-Qo3MX3vk5RrmSY-R3rXqVeI,12976
-tico/utils/validate_args_kwargs.py,sha256=yikeUbYfSg2378wagEMXDlJeSRv8HKI2oxpjWarolec,27268
+tico/utils/utils.py,sha256=aySftYnNTsqVAMcGs_3uX3-hz577a2cj4p1aVV-1XeQ,12747
+tico/utils/validate_args_kwargs.py,sha256=aY7hyDaZKrZn0ev0liUFZdHS8__0Vpp5QyqDEobZ_zM,27163
 tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
 tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
 tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
-tico-0.1.0.dev250812.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
-tico-0.1.0.dev250812.dist-info/METADATA,sha256=VYlCGN-A_iI52ctvoTcdzg3UrtQ2zhZRCxyunaCdiqU,8450
-tico-0.1.0.dev250812.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
-tico-0.1.0.dev250812.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
-tico-0.1.0.dev250812.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
-tico-0.1.0.dev250812.dist-info/RECORD,,
+tico-0.1.0.dev250814.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+tico-0.1.0.dev250814.dist-info/METADATA,sha256=CWbuRvOwPOiqf8FwnG5H8tij1fQqH3k_JpO8zhtXQfg,8450
+tico-0.1.0.dev250814.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+tico-0.1.0.dev250814.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+tico-0.1.0.dev250814.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+tico-0.1.0.dev250814.dist-info/RECORD,,