tico 0.1.0.dev250818__py3-none-any.whl → 0.1.0.dev250820__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/ptq/observers/__init__.py +2 -0
- tico/experimental/quantization/ptq/observers/ema.py +1 -1
- tico/experimental/quantization/ptq/observers/identity.py +5 -5
- tico/experimental/quantization/ptq/observers/mx.py +60 -0
- tico/experimental/quantization/ptq/quant_config.py +111 -0
- tico/experimental/quantization/ptq/wrappers/__init__.py +0 -0
- tico/experimental/quantization/ptq/wrappers/nn/__init__.py +5 -0
- tico/experimental/quantization/ptq/wrappers/nn/quant_linear.py +66 -0
- tico/experimental/quantization/ptq/wrappers/quant_module_base.py +153 -0
- tico/experimental/quantization/ptq/wrappers/registry.py +112 -0
- {tico-0.1.0.dev250818.dist-info → tico-0.1.0.dev250820.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250818.dist-info → tico-0.1.0.dev250820.dist-info}/RECORD +17 -10
- {tico-0.1.0.dev250818.dist-info → tico-0.1.0.dev250820.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250818.dist-info → tico-0.1.0.dev250820.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250818.dist-info → tico-0.1.0.dev250820.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250818.dist-info → tico-0.1.0.dev250820.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -3,6 +3,7 @@ from tico.experimental.quantization.ptq.observers.base import ObserverBase
|
|
3
3
|
from tico.experimental.quantization.ptq.observers.ema import EMAObserver
|
4
4
|
from tico.experimental.quantization.ptq.observers.identity import IdentityObserver
|
5
5
|
from tico.experimental.quantization.ptq.observers.minmax import MinMaxObserver
|
6
|
+
from tico.experimental.quantization.ptq.observers.mx import MXObserver
|
6
7
|
|
7
8
|
__all__ = [
|
8
9
|
"AffineObserverBase",
|
@@ -10,4 +11,5 @@ __all__ = [
|
|
10
11
|
"EMAObserver",
|
11
12
|
"IdentityObserver",
|
12
13
|
"MinMaxObserver",
|
14
|
+
"MXObserver",
|
13
15
|
]
|
@@ -31,7 +31,7 @@ class EMAObserver(AffineObserverBase):
|
|
31
31
|
|
32
32
|
ema = momentum * ema + (1 - momentum) * new_value
|
33
33
|
|
34
|
-
With momentum → 0:
|
34
|
+
With momentum → 0: FAST adaptation, momentum → 1: SLOW adaptation.
|
35
35
|
"""
|
36
36
|
|
37
37
|
def __init__(
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
"""
|
16
|
-
IdentityObserver: a
|
16
|
+
IdentityObserver: a "no-op" observer for FP-only modules.
|
17
17
|
|
18
18
|
Motivation
|
19
19
|
----------
|
@@ -29,11 +29,11 @@ from tico.experimental.quantization.ptq.observers.affine_base import AffineObser
|
|
29
29
|
|
30
30
|
class IdentityObserver(AffineObserverBase):
|
31
31
|
"""
|
32
|
-
Passthrough observer that
|
32
|
+
Passthrough observer that NEVER alters the tensor.
|
33
33
|
|
34
34
|
• `_update_stats()` → does nothing
|
35
|
-
• `compute_qparams()` → returns (1.0, 0)
|
36
|
-
• `fake_quant()` → returns
|
35
|
+
• `compute_qparams()` → returns (1.0, 0) "dummy" q-params
|
36
|
+
• `fake_quant()` → returns `x` unchanged
|
37
37
|
"""
|
38
38
|
|
39
39
|
def __init__(self, **kwargs):
|
@@ -67,7 +67,7 @@ class IdentityObserver(AffineObserverBase):
|
|
67
67
|
return self._cached_scale, self._cached_zp
|
68
68
|
|
69
69
|
def fake_quant(self, x: torch.Tensor):
|
70
|
-
"""Identity mapping — leaves
|
70
|
+
"""Identity mapping — leaves `x` in FP."""
|
71
71
|
return x
|
72
72
|
|
73
73
|
def __repr__(self) -> str:
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import torch
|
16
|
+
|
17
|
+
from tico.experimental.quantization.ptq.observers.base import ObserverBase
|
18
|
+
from tico.utils.mx.mx_ops import quantize_mx
|
19
|
+
|
20
|
+
|
21
|
+
class MXObserver(ObserverBase):
|
22
|
+
"""MX (micro-scaling) observer: no min/max, no affine qparams."""
|
23
|
+
|
24
|
+
def __init__(
|
25
|
+
self,
|
26
|
+
*,
|
27
|
+
name: str,
|
28
|
+
elem_format: str = "int8",
|
29
|
+
axis: int = 0,
|
30
|
+
shared_exp_method: str = "max",
|
31
|
+
round: str = "nearest",
|
32
|
+
**base_kwargs,
|
33
|
+
):
|
34
|
+
super().__init__(name=name, **base_kwargs)
|
35
|
+
self.elem_format = elem_format
|
36
|
+
self.axis = axis
|
37
|
+
self.shared_exp_method = shared_exp_method
|
38
|
+
self.round = round
|
39
|
+
|
40
|
+
def reset(self) -> None:
|
41
|
+
# No state to reset
|
42
|
+
return
|
43
|
+
|
44
|
+
@torch.no_grad()
|
45
|
+
def _update_stats(self, x: torch.Tensor) -> None:
|
46
|
+
# No stats required
|
47
|
+
return None
|
48
|
+
|
49
|
+
def compute_qparams(self):
|
50
|
+
# MX path does not produce affine qparams; keep interface contract.
|
51
|
+
return None
|
52
|
+
|
53
|
+
def fake_quant(self, x: torch.Tensor) -> torch.Tensor:
|
54
|
+
return quantize_mx(
|
55
|
+
x,
|
56
|
+
elem_format=self.elem_format,
|
57
|
+
axis=self.axis,
|
58
|
+
shared_exp_method=self.shared_exp_method,
|
59
|
+
round=self.round,
|
60
|
+
)
|
@@ -0,0 +1,111 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from dataclasses import dataclass, field
|
16
|
+
from typing import Any, Dict, Mapping, Type
|
17
|
+
|
18
|
+
from tico.experimental.quantization.ptq.dtypes import DType
|
19
|
+
from tico.experimental.quantization.ptq.observers.base import ObserverBase
|
20
|
+
from tico.experimental.quantization.ptq.observers.minmax import MinMaxObserver
|
21
|
+
from tico.experimental.quantization.ptq.qscheme import QScheme
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
|
25
|
+
class QuantConfig:
|
26
|
+
"""
|
27
|
+
One object describes the quantization preferences for a single wrapper
|
28
|
+
and its descendants.
|
29
|
+
|
30
|
+
Parameters
|
31
|
+
----------
|
32
|
+
default_dtype : DType
|
33
|
+
Fallback dtype for every observer that DOES NOT receive an explicit
|
34
|
+
override.
|
35
|
+
default_observer : Type[ObserverBase], optional
|
36
|
+
Observer class to instantiate when the caller (or an override) does
|
37
|
+
not provide an `observer` key.
|
38
|
+
default_qscheme : QScheme
|
39
|
+
Fallback quantization scheme (per-tensor / per-channel,
|
40
|
+
asymmetric / symmetric) for observers that DO NOT receive an explicit
|
41
|
+
override.
|
42
|
+
overrides : Mapping[str, Mapping[str, Any]]
|
43
|
+
Two-level mapping of scopes → observer-kwargs.
|
44
|
+
|
45
|
+
• SCOPE can be either
|
46
|
+
- the attribute name of a child wrapper
|
47
|
+
(e.g. "gate_proj" or "up_proj"), or
|
48
|
+
- an observer logical name inside this wrapper
|
49
|
+
(e.g. "mul", "act_in").
|
50
|
+
|
51
|
+
• "Observer-kwargs" is forwarded verbatim to the observer constructor
|
52
|
+
(`dtype`, `qscheme`, `channel_axis`, `observer`, …).
|
53
|
+
|
54
|
+
Example
|
55
|
+
-------
|
56
|
+
```python
|
57
|
+
from ptq.observers import PercentileObserver
|
58
|
+
|
59
|
+
cfg = QuantConfig(
|
60
|
+
default_dtype = DType.uint(8),
|
61
|
+
default_qscheme = QScheme.PER_TENSOR_SYMM, # <- global scheme
|
62
|
+
default_observer = PercentileObserver, # <- global algorithm
|
63
|
+
overrides={
|
64
|
+
# local override: input observer now MinMax & 4-bit, per-channel asymmetric
|
65
|
+
"act_in": {"observer": MinMaxObserver,
|
66
|
+
"dtype": DType.uint(4),
|
67
|
+
"qscheme": QScheme.PER_CHANNEL_ASYMM},
|
68
|
+
},
|
69
|
+
)
|
70
|
+
```
|
71
|
+
"""
|
72
|
+
|
73
|
+
default_dtype: DType = DType.uint(8)
|
74
|
+
default_observer: Type[ObserverBase] = MinMaxObserver
|
75
|
+
default_qscheme: QScheme = QScheme.PER_TENSOR_ASYMM
|
76
|
+
overrides: Mapping[str, Mapping[str, Any]] = field(default_factory=dict)
|
77
|
+
|
78
|
+
def get_kwargs(self, obs_name: str) -> Dict[str, Any]:
|
79
|
+
"""
|
80
|
+
Return user-specified kwargs for *obs_name* inside **this** wrapper.
|
81
|
+
|
82
|
+
NOTE:
|
83
|
+
Do NOT inject a dtype/qscheme here. `_make_obs()` resolves precedence:
|
84
|
+
1) user override (kw_cfg["dtype" | "qscheme"])
|
85
|
+
2) wrapper's default passed to `_make_obs(..., dtype=..., qscheme=...)`
|
86
|
+
3) `self.default_dtype` / `self.default_qscheme`
|
87
|
+
"""
|
88
|
+
return dict(self.overrides.get(obs_name, {}))
|
89
|
+
|
90
|
+
def child(self, scope: str) -> "QuantConfig":
|
91
|
+
"""
|
92
|
+
Produce a *view* for a child wrapper.
|
93
|
+
|
94
|
+
The child inherits:
|
95
|
+
• same `default_dtype`
|
96
|
+
• same `default_observer`
|
97
|
+
• same `default_qscheme`
|
98
|
+
• overrides under `self.overrides.get(scope, {})`
|
99
|
+
|
100
|
+
Other scopes remain invisible to the child.
|
101
|
+
"""
|
102
|
+
sub_overrides = self.overrides.get(scope, {})
|
103
|
+
return QuantConfig(
|
104
|
+
self.default_dtype,
|
105
|
+
self.default_observer,
|
106
|
+
default_qscheme=self.default_qscheme,
|
107
|
+
overrides=sub_overrides,
|
108
|
+
)
|
109
|
+
|
110
|
+
def __repr__(self):
|
111
|
+
return f"QuantConfig(default_dtype={self.default_dtype}, default_observer={self.default_observer}, default_qscheme={self.default_qscheme}, overrides={dict(self.overrides)})"
|
File without changes
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
import torch.nn as nn
|
18
|
+
import torch.nn.functional as F
|
19
|
+
|
20
|
+
from tico.experimental.quantization.ptq.mode import Mode
|
21
|
+
from tico.experimental.quantization.ptq.qscheme import QScheme
|
22
|
+
from tico.experimental.quantization.ptq.quant_config import QuantConfig
|
23
|
+
from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
|
24
|
+
QuantModuleBase,
|
25
|
+
)
|
26
|
+
from tico.experimental.quantization.ptq.wrappers.registry import register
|
27
|
+
|
28
|
+
|
29
|
+
@register(nn.Linear)
|
30
|
+
class QuantLinear(QuantModuleBase):
|
31
|
+
"""Per-channel weight fake-quant, eager-output activation fake-quant."""
|
32
|
+
|
33
|
+
def __init__(
|
34
|
+
self,
|
35
|
+
fp: nn.Linear,
|
36
|
+
*,
|
37
|
+
qcfg: Optional[QuantConfig] = None,
|
38
|
+
fp_name: Optional[str] = None
|
39
|
+
):
|
40
|
+
super().__init__(qcfg, fp_name=fp_name)
|
41
|
+
self.weight_obs = self._make_obs(
|
42
|
+
"weight", qscheme=QScheme.PER_CHANNEL_ASYMM, channel_axis=0
|
43
|
+
)
|
44
|
+
self.act_in_obs = self._make_obs("act_in")
|
45
|
+
self.act_out_obs = self._make_obs("act_out")
|
46
|
+
self.module = fp
|
47
|
+
|
48
|
+
def enable_calibration(self) -> None:
|
49
|
+
super().enable_calibration()
|
50
|
+
# immediately capture the fixed weight range
|
51
|
+
self.weight_obs.collect(self.module.weight)
|
52
|
+
|
53
|
+
def forward(self, x):
|
54
|
+
x_q = self._fq(x, self.act_in_obs)
|
55
|
+
|
56
|
+
w = self.module.weight
|
57
|
+
if self._mode is Mode.QUANT:
|
58
|
+
w = self.weight_obs.fake_quant(w)
|
59
|
+
b = self.module.bias
|
60
|
+
|
61
|
+
out = F.linear(x_q, w, b)
|
62
|
+
|
63
|
+
return self._fq(out, self.act_out_obs)
|
64
|
+
|
65
|
+
def _all_observers(self):
|
66
|
+
return (self.weight_obs, self.act_in_obs, self.act_out_obs)
|
@@ -0,0 +1,153 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from abc import ABC, abstractmethod
|
16
|
+
from typing import Iterable, Optional, Tuple
|
17
|
+
|
18
|
+
import torch.nn as nn
|
19
|
+
|
20
|
+
from tico.experimental.quantization.ptq.mode import Mode
|
21
|
+
from tico.experimental.quantization.ptq.observers.base import ObserverBase
|
22
|
+
from tico.experimental.quantization.ptq.quant_config import QuantConfig
|
23
|
+
|
24
|
+
|
25
|
+
class QuantModuleBase(nn.Module, ABC):
|
26
|
+
"""
|
27
|
+
Abstract parent for EVERY wrapper.
|
28
|
+
|
29
|
+
Responsibilities
|
30
|
+
----------------
|
31
|
+
• Own *one* Mode enum (`NO_QUANT / CALIB / QUANT`)
|
32
|
+
• Own a QuantConfig describing default / per-observer dtypes
|
33
|
+
• Expose a canonical lifecycle:
|
34
|
+
enable_calibration()
|
35
|
+
freeze_qparams()
|
36
|
+
• Provide helper `_fq(x, observer)` (“fake-quant or collect”) so
|
37
|
+
subclasses write arithmetic code without boilerplate.
|
38
|
+
"""
|
39
|
+
|
40
|
+
def __init__(
|
41
|
+
self, qcfg: Optional[QuantConfig] = None, *, fp_name: Optional[str] = None
|
42
|
+
) -> None:
|
43
|
+
super().__init__()
|
44
|
+
self.qcfg = qcfg or QuantConfig()
|
45
|
+
self._mode: Mode = Mode.NO_QUANT # default state
|
46
|
+
self.fp_name = fp_name
|
47
|
+
|
48
|
+
def _child_quant_modules(self):
|
49
|
+
"""Yield direct children that are QuantModuleBase."""
|
50
|
+
for m in self.children():
|
51
|
+
if isinstance(m, QuantModuleBase):
|
52
|
+
yield m
|
53
|
+
|
54
|
+
def enable_calibration(self) -> None:
|
55
|
+
self._mode = Mode.CALIB
|
56
|
+
for obs in self._all_observers():
|
57
|
+
obs.enabled = True
|
58
|
+
obs.reset()
|
59
|
+
|
60
|
+
# propagate to children
|
61
|
+
for child in self._child_quant_modules():
|
62
|
+
child.enable_calibration()
|
63
|
+
|
64
|
+
def freeze_qparams(self) -> None:
|
65
|
+
self._mode = Mode.QUANT
|
66
|
+
for obs in self._all_observers():
|
67
|
+
obs.enabled = False
|
68
|
+
obs.compute_qparams()
|
69
|
+
|
70
|
+
# propagate to children
|
71
|
+
for child in self._child_quant_modules():
|
72
|
+
child.freeze_qparams()
|
73
|
+
|
74
|
+
def _fq(self, x, obs: ObserverBase):
|
75
|
+
"""Fake-quant or collect."""
|
76
|
+
if self._mode is Mode.CALIB:
|
77
|
+
obs.collect(x.detach())
|
78
|
+
return x
|
79
|
+
if self._mode is Mode.QUANT:
|
80
|
+
return obs.fake_quant(x)
|
81
|
+
return x # NO_QUANT
|
82
|
+
|
83
|
+
@abstractmethod
|
84
|
+
def _all_observers(self) -> Iterable[ObserverBase]:
|
85
|
+
"""Return every observer owned by this module."""
|
86
|
+
...
|
87
|
+
|
88
|
+
def named_observers(self) -> Iterable[Tuple[str, ObserverBase]]:
|
89
|
+
for obs in self._all_observers():
|
90
|
+
yield obs.name, obs
|
91
|
+
|
92
|
+
def get_observer(self, name: str) -> Optional[ObserverBase]:
|
93
|
+
for obs in self._all_observers():
|
94
|
+
if obs.name == name:
|
95
|
+
return obs
|
96
|
+
return None
|
97
|
+
|
98
|
+
def _make_obs(
|
99
|
+
self,
|
100
|
+
name: str,
|
101
|
+
**default_kwargs,
|
102
|
+
) -> ObserverBase:
|
103
|
+
"""
|
104
|
+
Instantiate an observer named *name*.
|
105
|
+
|
106
|
+
Precedence (3-tier) for keys:
|
107
|
+
• observer: user > wrapper-default > QuantConfig.default_observer
|
108
|
+
• dtype: user > wrapper-default > QuantConfig.default_dtype
|
109
|
+
• qscheme: user > wrapper-default > QuantConfig.default_qscheme
|
110
|
+
|
111
|
+
Other kwargs (e.g., channel_axis) remain:
|
112
|
+
user override > wrapper-default
|
113
|
+
"""
|
114
|
+
_UNSPEC = object()
|
115
|
+
|
116
|
+
wrapper_defaults = default_kwargs.copy()
|
117
|
+
user_cfg = self.qcfg.get_kwargs(name).copy()
|
118
|
+
|
119
|
+
def pick3(user_val, wrap_val, global_val):
|
120
|
+
return (
|
121
|
+
user_val
|
122
|
+
if user_val is not _UNSPEC
|
123
|
+
else wrap_val
|
124
|
+
if wrap_val is not _UNSPEC
|
125
|
+
else global_val
|
126
|
+
)
|
127
|
+
|
128
|
+
# 1) resolve observer class
|
129
|
+
user_observer = user_cfg.pop("observer", _UNSPEC)
|
130
|
+
wrapper_observer = wrapper_defaults.pop("observer", _UNSPEC)
|
131
|
+
obs_cls = pick3(user_observer, wrapper_observer, self.qcfg.default_observer)
|
132
|
+
|
133
|
+
# 2) resolve dtype
|
134
|
+
user_dtype = user_cfg.pop("dtype", _UNSPEC)
|
135
|
+
wrapper_dtype = wrapper_defaults.pop("dtype", _UNSPEC)
|
136
|
+
final_dtype = pick3(user_dtype, wrapper_dtype, self.qcfg.default_dtype)
|
137
|
+
|
138
|
+
# 3) resolve qscheme
|
139
|
+
user_qscheme = user_cfg.pop("qscheme", _UNSPEC)
|
140
|
+
wrapper_qscheme = wrapper_defaults.pop("qscheme", _UNSPEC)
|
141
|
+
final_qscheme = pick3(user_qscheme, wrapper_qscheme, self.qcfg.default_qscheme)
|
142
|
+
|
143
|
+
# 4) merge remaining kwargs: user_cfg wins
|
144
|
+
final_kw = wrapper_defaults
|
145
|
+
final_kw.update(user_cfg)
|
146
|
+
final_kw["dtype"] = final_dtype
|
147
|
+
final_kw["qscheme"] = final_qscheme
|
148
|
+
|
149
|
+
return obs_cls(**final_kw, name=name)
|
150
|
+
|
151
|
+
# nice repr
|
152
|
+
def extra_repr(self) -> str:
|
153
|
+
return f"mode={self._mode.name.lower()}"
|
@@ -0,0 +1,112 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import importlib
|
16
|
+
from typing import Callable, Dict, Type
|
17
|
+
|
18
|
+
import torch.nn as nn
|
19
|
+
|
20
|
+
from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
|
21
|
+
QuantModuleBase,
|
22
|
+
)
|
23
|
+
|
24
|
+
_WRAPPERS: Dict[Type[nn.Module], Type[QuantModuleBase]] = {}
|
25
|
+
_IMPORT_ONCE = False
|
26
|
+
_CORE_MODULES = (
|
27
|
+
"tico.experimental.quantization.ptq.wrappers.nn.quant_linear",
|
28
|
+
# add future core wrappers here
|
29
|
+
)
|
30
|
+
|
31
|
+
|
32
|
+
def _lazy_init():
|
33
|
+
"""
|
34
|
+
Deferred one-shot import of "core wrapper modules".
|
35
|
+
|
36
|
+
Why not import everything when the program first starts?
|
37
|
+
--------------------------------------------------
|
38
|
+
* **Avoid circular-import hell**
|
39
|
+
Core wrappers often import `PTQWrapper`, which in turn calls
|
40
|
+
`registry.lookup()`. Importing those files eagerly here would create a
|
41
|
+
cycle (`registry → wrapper → registry`). Delaying the import until the
|
42
|
+
*first* `lookup()` call lets Python finish constructing the registry
|
43
|
+
module before any wrapper files are touched.
|
44
|
+
|
45
|
+
* **Cold-start speed**
|
46
|
+
Most user code never wraps layers explicitly; they only hit
|
47
|
+
`PTQWrapper` if they are doing quantization. Deferring half-a-dozen
|
48
|
+
heavyweight `import torch …` files until they are really needed
|
49
|
+
reduces library start-up latency in the common path.
|
50
|
+
|
51
|
+
* **Optional dependencies**
|
52
|
+
Core wrappers listed in `_CORE_MODULES` are chosen to be dependency-free
|
53
|
+
(pure PyTorch). Anything that needs `transformers`, `torchvision`,
|
54
|
+
etc. uses the `@try_register()` decorator inside its own module. Those
|
55
|
+
optional modules are *not* imported here, so users without the extra
|
56
|
+
packages still get a clean import.
|
57
|
+
|
58
|
+
Implementation notes
|
59
|
+
--------------------
|
60
|
+
* `_IMPORT_ONCE` guard ensures we execute the import loop only once,
|
61
|
+
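across repeated `lookup()` calls. No lock is taken, so concurrent first calls may each run the loop (presumably harmless, since re-importing an already-imported module is a no-op).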
|
62
|
+
* Each path in `_CORE_MODULES` is a "fully-qualified module string"
|
63
|
+
(e.g. "tico.experimental.quantization.ptq.wrappers.nn.quant_linear"). Importing the module runs all
|
64
|
+
its `@register(...)` decorators (e.g. `@register(nn.Linear)`), populating `_WRAPPERS`.
|
65
|
+
* After the first call the function becomes a cheap constant-time no-op.
|
66
|
+
"""
|
67
|
+
global _IMPORT_ONCE
|
68
|
+
if _IMPORT_ONCE:
|
69
|
+
return
|
70
|
+
for mod in _CORE_MODULES:
|
71
|
+
__import__(mod) # triggers decorators
|
72
|
+
_IMPORT_ONCE = True
|
73
|
+
|
74
|
+
|
75
|
+
# ───────────────────────────── decorator for always-present classes
|
76
|
+
def register(
|
77
|
+
fp_cls: Type[nn.Module],
|
78
|
+
) -> Callable[[Type[QuantModuleBase]], Type[QuantModuleBase]]:
|
79
|
+
def _decorator(quant_cls: Type[QuantModuleBase]):
|
80
|
+
_WRAPPERS[fp_cls] = quant_cls
|
81
|
+
return quant_cls
|
82
|
+
|
83
|
+
return _decorator
|
84
|
+
|
85
|
+
|
86
|
+
# ───────────────────────────── conditional decorator
|
87
|
+
def try_register(path: str) -> Callable[[Type[QuantModuleBase]], Type[QuantModuleBase]]:
|
88
|
+
"""
|
89
|
+
@try_register("transformers.models.llama.modeling_llama.LlamaMLP")
|
90
|
+
|
91
|
+
• If import succeeds → behave like `@register`
|
92
|
+
• If module/class not found → become a NO-OP
|
93
|
+
"""
|
94
|
+
|
95
|
+
def _decorator(quant_cls: Type[QuantModuleBase]):
|
96
|
+
module_name, _, cls_name = path.rpartition(".")
|
97
|
+
try:
|
98
|
+
mod = importlib.import_module(module_name)
|
99
|
+
fp_cls = getattr(mod, cls_name)
|
100
|
+
_WRAPPERS[fp_cls] = quant_cls
|
101
|
+
except (ModuleNotFoundError, AttributeError):
|
102
|
+
# transformers not installed or class renamed – silently skip
|
103
|
+
pass
|
104
|
+
return quant_cls
|
105
|
+
|
106
|
+
return _decorator
|
107
|
+
|
108
|
+
|
109
|
+
# ───────────────────────────── lookup
|
110
|
+
def lookup(fp_cls: Type[nn.Module]) -> Type[QuantModuleBase] | None:
|
111
|
+
_lazy_init()
|
112
|
+
return _WRAPPERS.get(fp_cls)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tico/__init__.py,sha256=
|
1
|
+
tico/__init__.py,sha256=I2l2YkiyVNNCBv8PmWOlk2OxIDiMcpubxCENlSiOh8E,1883
|
2
2
|
tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
4
4
|
tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
|
@@ -60,14 +60,21 @@ tico/experimental/quantization/ptq/__init__.py,sha256=ZoPdEwZ1i1n5pBFChx8GuUrkfR
|
|
60
60
|
tico/experimental/quantization/ptq/dtypes.py,sha256=xfCBtq6mQmUYRwsoFgII6gvRl1raQi0Inj9pznDuKwQ,2236
|
61
61
|
tico/experimental/quantization/ptq/mode.py,sha256=lT-T8vIv8YWcwrjT7xXVhOw1g7aoAdh_3PWB-ptPKaI,1052
|
62
62
|
tico/experimental/quantization/ptq/qscheme.py,sha256=uwhv7bCxOOXB3I-IKlRyr_u4eXOq48uIqGy4TLDqGxY,1301
|
63
|
-
tico/experimental/quantization/ptq/
|
63
|
+
tico/experimental/quantization/ptq/quant_config.py,sha256=nm7570Y1X2mOT_8s27ilWid04otor6cVTi9GwgAEaKc,4300
|
64
|
+
tico/experimental/quantization/ptq/observers/__init__.py,sha256=WF2MvL9M_jl-B1FqcY9zic34NOCRp17HkRYv-TMxMr4,613
|
64
65
|
tico/experimental/quantization/ptq/observers/affine_base.py,sha256=e2Eba64nrxKQyE4F_WJ7WTSsk3xe6bkdGUKaoLFWGFw,4638
|
65
66
|
tico/experimental/quantization/ptq/observers/base.py,sha256=Wons1MzpqK1mfcy-ppl-B2Dum0edXg2dWW2Lw3V18tw,3280
|
66
|
-
tico/experimental/quantization/ptq/observers/ema.py,sha256=
|
67
|
-
tico/experimental/quantization/ptq/observers/identity.py,sha256=
|
67
|
+
tico/experimental/quantization/ptq/observers/ema.py,sha256=MAMdBmjVNMg_vsqXrcBzbw_1nFJ-j4Gz651k3-VlaMQ,2057
|
68
|
+
tico/experimental/quantization/ptq/observers/identity.py,sha256=vkec8Or-7VwM4zkFEvEKROQJk8XEHMVX8mBNDnxSyS8,2591
|
68
69
|
tico/experimental/quantization/ptq/observers/minmax.py,sha256=mLHkwIzWFzQXev7EU7w1333KckwRjukc3_cUPJOnUfs,1486
|
70
|
+
tico/experimental/quantization/ptq/observers/mx.py,sha256=aP4qmBgeiRIYZJksShN5gs6UyYOFi2-Sbk5k5xvPQ4w,1863
|
69
71
|
tico/experimental/quantization/ptq/utils/__init__.py,sha256=PL9IZgiWoMtsXVljeOy7KymmLVP238SXEFRLXYK72WQ,126
|
70
72
|
tico/experimental/quantization/ptq/utils/reduce_utils.py,sha256=3kWawLB91EcvvHlCrNqqfZF7tpgr22htBSA049mKw_4,973
|
73
|
+
tico/experimental/quantization/ptq/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
74
|
+
tico/experimental/quantization/ptq/wrappers/quant_module_base.py,sha256=6RK4bn9G1pzFmkIdBdFf7liBOpb-b7rpthgD83AgkbQ,5256
|
75
|
+
tico/experimental/quantization/ptq/wrappers/registry.py,sha256=exXl2wNNzVgC2P9gMjpF_-PqIBgYERGruzh0u1Pril0,4367
|
76
|
+
tico/experimental/quantization/ptq/wrappers/nn/__init__.py,sha256=q4A9BiGlsa8ZdGV3y0SDiSkzkdVugsK2iz2daiJqBCY,118
|
77
|
+
tico/experimental/quantization/ptq/wrappers/nn/quant_linear.py,sha256=xW-VEPB7RJoslS3xLVCdhIuMjppknvpkZleRGK4JFVQ,2240
|
71
78
|
tico/interpreter/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
72
79
|
tico/interpreter/infer.py,sha256=1ZFe3DVMR2mlwBosoedqoL0-CGN_01CKLgMgxuw62KA,4861
|
73
80
|
tico/interpreter/interpreter.py,sha256=tGbluCbrehTCqBu8mtGDNzby_ieJ2ry8_RH_eC0CQxk,3828
|
@@ -222,9 +229,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
222
229
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
223
230
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
224
231
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
225
|
-
tico-0.1.0.
|
226
|
-
tico-0.1.0.
|
227
|
-
tico-0.1.0.
|
228
|
-
tico-0.1.0.
|
229
|
-
tico-0.1.0.
|
230
|
-
tico-0.1.0.
|
232
|
+
tico-0.1.0.dev250820.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
233
|
+
tico-0.1.0.dev250820.dist-info/METADATA,sha256=SF_FvVX_JdbjpMAN2-VJGo7zsYU2_nEPXCHap296J3Y,8450
|
234
|
+
tico-0.1.0.dev250820.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
235
|
+
tico-0.1.0.dev250820.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
236
|
+
tico-0.1.0.dev250820.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
237
|
+
tico-0.1.0.dev250820.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|