tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/cast_mixed_type_args.py +4 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/ops.py +1 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -6
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/quantization/algorithm/gptq/quantizer.py +292 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
- tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
- tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
- tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +19 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_graph.py +12 -4
- tico/serialize/circle_mapping.py +76 -2
- tico/serialize/circle_serializer.py +253 -148
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_any.py +7 -14
- tico/serialize/operators/op_avg_pool2d.py +11 -4
- tico/serialize/operators/op_clamp.py +5 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_conv2d.py +14 -6
- tico/serialize/operators/op_copy.py +26 -3
- tico/serialize/operators/op_cumsum.py +3 -1
- tico/serialize/operators/op_depthwise_conv2d.py +17 -7
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_index_select.py +8 -1
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_log1p.py +3 -2
- tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
- tico/serialize/operators/op_mm.py +15 -131
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_pow.py +3 -1
- tico/serialize/operators/op_repeat.py +12 -3
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/serialize/operators/op_split_with_sizes.py +16 -8
- tico/serialize/operators/op_transpose_conv.py +11 -8
- tico/serialize/operators/op_view.py +2 -1
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +30 -17
- tico/utils/dtype.py +42 -0
- tico/utils/graph.py +1 -1
- tico/utils/model.py +2 -1
- tico/utils/padding.py +2 -2
- tico/utils/pytree_utils.py +134 -0
- tico/utils/record_input.py +102 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/serialize.py +16 -3
- tico/utils/signature.py +247 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +50 -58
- tico/utils/validate_args_kwargs.py +38 -3
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- tico-0.1.0.dev251102.dist-info/RECORD +271 -0
- tico/experimental/quantization/__init__.py +0 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
- tico-0.1.0.dev250714.dist-info/RECORD +0 -209
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
|
@@ -20,18 +20,12 @@ import torch
|
|
|
20
20
|
from circle_schema import circle
|
|
21
21
|
from torch.utils import _pytree as pytree
|
|
22
22
|
|
|
23
|
-
from tico.
|
|
24
|
-
from tico.
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
from tico.
|
|
28
|
-
|
|
29
|
-
)
|
|
30
|
-
from tico.experimental.quantization.evaluation.executor.triv24_executor import (
|
|
31
|
-
Triv24Executor,
|
|
32
|
-
)
|
|
33
|
-
from tico.experimental.quantization.evaluation.metric import MetricCalculator
|
|
34
|
-
from tico.experimental.quantization.evaluation.utils import (
|
|
23
|
+
from tico.quantization.evaluation.backend import BACKEND
|
|
24
|
+
from tico.quantization.evaluation.executor.backend_executor import BackendExecutor
|
|
25
|
+
from tico.quantization.evaluation.executor.circle_executor import CircleExecutor
|
|
26
|
+
from tico.quantization.evaluation.executor.triv24_executor import Triv24Executor
|
|
27
|
+
from tico.quantization.evaluation.metric import MetricCalculator
|
|
28
|
+
from tico.quantization.evaluation.utils import (
|
|
35
29
|
ensure_list,
|
|
36
30
|
find_invalid_types,
|
|
37
31
|
get_graph_input_output,
|
|
@@ -114,7 +108,6 @@ def evaluate(
|
|
|
114
108
|
input_data: InputDataType = None,
|
|
115
109
|
*,
|
|
116
110
|
mode="plot",
|
|
117
|
-
metrics: List[str] = ["peir"],
|
|
118
111
|
custom_metrics: Dict[str, Callable] = dict(),
|
|
119
112
|
) -> Optional[Dict[str, Any]]:
|
|
120
113
|
"""
|
|
@@ -140,8 +133,6 @@ def evaluate(
|
|
|
140
133
|
The mode of operation. Options are:
|
|
141
134
|
- "plot": Plot the results (default)
|
|
142
135
|
- "return": Return the results.
|
|
143
|
-
metrics
|
|
144
|
-
A list of metric names for comparison.
|
|
145
136
|
custom_metrics
|
|
146
137
|
A dictionary of metric names and corresponding callable functions for comparison.
|
|
147
138
|
Example: {'mse': mean_squared_error, 'cosine_similarity': cosine_similarity_fn}
|
|
@@ -166,7 +157,7 @@ def evaluate(
|
|
|
166
157
|
)
|
|
167
158
|
if not isinstance(backend, BACKEND):
|
|
168
159
|
raise RuntimeError(
|
|
169
|
-
|
|
160
|
+
"Invalid backend. Please use tico.quantization.evaluate.BACKEND enum class"
|
|
170
161
|
)
|
|
171
162
|
# Make it a list for simpler logic.
|
|
172
163
|
if input_data is not None:
|
|
@@ -205,7 +196,7 @@ def evaluate(
|
|
|
205
196
|
)
|
|
206
197
|
|
|
207
198
|
# Computes the comparison score based on the provided metrics.
|
|
208
|
-
metric_calculator = MetricCalculator(
|
|
199
|
+
metric_calculator = MetricCalculator(custom_metrics)
|
|
209
200
|
results: Dict[str, Any] = metric_calculator.compute(torch_output, circle_output)
|
|
210
201
|
|
|
211
202
|
if mode == "return":
|
|
@@ -19,9 +19,7 @@ from typing import List
|
|
|
19
19
|
import numpy as np
|
|
20
20
|
import torch
|
|
21
21
|
|
|
22
|
-
from tico.
|
|
23
|
-
BackendExecutor,
|
|
24
|
-
)
|
|
22
|
+
from tico.quantization.evaluation.executor.backend_executor import BackendExecutor
|
|
25
23
|
from tico.utils.model import CircleModel
|
|
26
24
|
from tico.utils.utils import run_bash_cmd
|
|
27
25
|
|
|
@@ -72,4 +70,5 @@ class CircleExecutor(BackendExecutor):
|
|
|
72
70
|
return out
|
|
73
71
|
|
|
74
72
|
def __del__(self):
|
|
75
|
-
self.temp_dir
|
|
73
|
+
if hasattr(self, "temp_dir") and self.temp_dir:
|
|
74
|
+
self.temp_dir.cleanup()
|
|
@@ -20,10 +20,8 @@ import numpy as np
|
|
|
20
20
|
import torch
|
|
21
21
|
from circle_schema import circle
|
|
22
22
|
|
|
23
|
-
from tico.
|
|
24
|
-
|
|
25
|
-
)
|
|
26
|
-
from tico.experimental.quantization.evaluation.utils import (
|
|
23
|
+
from tico.quantization.evaluation.executor.backend_executor import BackendExecutor
|
|
24
|
+
from tico.quantization.evaluation.utils import (
|
|
27
25
|
dequantize,
|
|
28
26
|
get_graph_input_output,
|
|
29
27
|
quantize,
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def compute_max_abs_diff(base: torch.Tensor, target: torch.Tensor) -> float:
    """
    Return the *maximum* absolute element-wise difference between two tensors.

    Parameters
    ----------
    base
        Reference tensor.
    target
        Tensor compared against ``base``; must have the same shape.

    Returns
    -------
    float
        ``max(|base - target|)``. *Lower is better*.

    Raises
    ------
    AssertionError
        If the shapes of ``base`` and ``target`` differ.
    """
    if base.shape != target.shape:
        # Raised explicitly (instead of an `assert` statement) so the check is
        # not stripped when Python runs with -O; the exception type is the same.
        raise AssertionError("shape mismatch")

    ref, other = base.detach(), target.detach()
    if not (
        ref.is_floating_point()
        or ref.is_complex()
        or other.is_floating_point()
        or other.is_complex()
    ):
        # Both operands are integral/bool: unsigned subtraction would wrap
        # around (uint8: 1 - 3 == 254) and bool tensors do not support `-`.
        # Widening to a signed 64-bit integer keeps the difference exact.
        ref, other = ref.to(torch.int64), other.to(torch.int64)
    return (ref - other).abs().max().item()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def compute_peir(base: torch.Tensor, target: torch.Tensor) -> float:
    """
    Peak-Error-to-Interval Ratio (PEIR).

    PEIR = max(|base - target|) / (max(base) - min(base))

    The interval denominator uses the reference (*base*) tensor only — this
    makes PEIR independent of quantisation error in `target`.

    Parameters
    ----------
    base
        Reference tensor; also defines the interval in the denominator.
    target
        Tensor compared against ``base``; must have the same shape.

    Returns
    -------
    float
        The ratio described above. When ``base`` is constant (interval of
        zero) the interval is replaced by 1.0, so the peak error itself is
        returned.

    Raises
    ------
    AssertionError
        If the shapes of ``base`` and ``target`` differ.
    """
    if base.shape != target.shape:
        # Raised explicitly (instead of an `assert` statement) so the check is
        # not stripped when Python runs with -O; the exception type is the same.
        raise AssertionError("shape mismatch")

    ref, other = base.detach(), target.detach()
    if not (
        ref.is_floating_point()
        or ref.is_complex()
        or other.is_floating_point()
        or other.is_complex()
    ):
        # Both operands are integral/bool: unsigned subtraction would wrap
        # around (uint8: 0 - 1 == 255) and bool tensors do not support `-`.
        # Widening to a signed 64-bit integer keeps the arithmetic exact.
        ref, other = ref.to(torch.int64), other.to(torch.int64)

    peak_error = (ref - other).abs().max().item()
    interval = (ref.max() - ref.min()).item()
    interval = 1.0 if interval == 0.0 else interval  # avoid divide-by-zero
    return peak_error / interval
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def mse(base: torch.Tensor, target: torch.Tensor) -> float:
    """
    Mean Squared Error (MSE).

    Penalizes **larger** deviations more heavily than MAE by squaring each
    difference — helpful to expose occasional large spikes.

    Formula
    -------
    MSE = mean((base - target)²)

    Parameters
    ----------
    base
        Reference tensor.
    target
        Tensor compared against ``base``; must have the same shape.

    Returns
    -------
    float
        Mean squared error. *Lower is better*.

    Raises
    ------
    AssertionError
        If the shapes of ``base`` and ``target`` differ (same contract as the
        other built-in metrics; prevents silent broadcasting).
    """
    if base.shape != target.shape:
        raise AssertionError("shape mismatch")

    ref, other = base.detach(), target.detach()
    if not (
        ref.is_floating_point()
        or ref.is_complex()
        or other.is_floating_point()
        or other.is_complex()
    ):
        # `torch.mean` rejects integral/bool tensors, and unsigned subtraction
        # would wrap around before squaring. Promote to the default floating
        # dtype, mirroring PyTorch's own int -> float promotion convention.
        float_dtype = torch.get_default_dtype()
        ref, other = ref.to(float_dtype), other.to(float_dtype)
    return torch.mean((ref - other) ** 2).item()
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class MetricCalculator:
    """
    Lightweight registry-and-dispatcher for **pair-wise tensor comparison metrics**.

    Purpose
    -------
    Consolidate all metrics used to assess the discrepancy between a reference
    (usually FP32) tensor and its quantized counterpart, while letting the caller
    choose *at runtime* which subset to evaluate.

    Built-in metrics
    ----------------
    Key                      Description
    ------------------------ -------------------------------------------------
    "diff" / "max_abs_diff"  Maximum absolute element-wise difference
    "peir"                   Peak-Error-to-Interval Ratio
    "mse"                    Mean Squared Error

    Usage pattern
    -------------
    >>> calc = MetricCalculator(custom_metrics={'cosine': cosine_fn})
    >>> stats = calc.compute(fp_outs, q_outs, metrics=['diff', 'cosine'])

    • **Instantiation** registers any extra user metrics
      (signature: ``fn(base: Tensor, target: Tensor) -> float``).
      A custom name that collides with a built-in key raises ``RuntimeError``.
    • **compute(...)** takes two *equal-length* lists of tensors and an optional
      list of metric names.
        — If *metrics* is *None* (or empty), every registered metric is evaluated.
        — Returns a dict: ``{metric_name -> [value for each tensor pair]}``.

    Implementation notes
    --------------------
    * Registrations are stored in `self.registry` (str → callable).
    * Duplicate metric names between built-ins and custom metrics raise an error
      at construction time to prevent silent shadowing.
    """

    # Shared table of built-in metrics; every instance works on its own copy.
    builtin_metrics: Dict[str, Callable[[torch.Tensor, torch.Tensor], float]] = {
        "diff": compute_max_abs_diff,
        "max_abs_diff": compute_max_abs_diff,
        "peir": compute_peir,
        "mse": mse,
    }

    def __init__(
        self,
        custom_metrics: Optional[
            Dict[str, Callable[[torch.Tensor, torch.Tensor], float]]
        ] = None,
    ):
        """
        Build the registry from the built-ins plus optional user metrics.

        Parameters
        ----------
        custom_metrics
            Extra ``name -> fn(base, target)`` entries. ``None`` or an empty
            dict registers nothing extra.

        Raises
        ------
        RuntimeError
            If a custom metric name is already used by a built-in metric.
        """
        self.registry: Dict[str, Callable] = self.builtin_metrics.copy()
        if custom_metrics:
            dup = self.registry.keys() & custom_metrics.keys()
            if dup:
                raise RuntimeError(f"Duplicate metric names: {dup}")
            self.registry.update(custom_metrics)

    # ----------------------------------------------------------------- #
    # Public API                                                        #
    # ----------------------------------------------------------------- #
    def compute(
        self,
        base_outputs: List[torch.Tensor],
        target_outputs: List[torch.Tensor],
        metrics: Optional[List[str]] = None,
    ) -> Dict[str, List[Any]]:
        """
        Compute selected metrics for every (base, target) pair.

        Parameters
        ----------
        base_outputs
            Reference tensors.
        target_outputs
            Tensors to compare; paired with ``base_outputs`` by position.
        metrics
            List of metric names to evaluate **this call**.
            • None (or an empty list) → evaluate *all* registered metrics.
            • A name listed more than once is evaluated only once.

        Returns
        -------
        Dict[str, List[Any]]
            ``{metric_name: [value for each tensor pair]}``.

        Raises
        ------
        RuntimeError
            If a requested metric is not registered, or if the two output
            lists differ in length.
        """
        # De-duplicate while preserving order: a repeated name would otherwise
        # append its value twice per pair into the same result list.
        sel = list(dict.fromkeys(metrics)) if metrics else list(self.registry)
        unknown = set(sel) - self.registry.keys()
        if unknown:
            raise RuntimeError(f"Unknown metric(s): {unknown}")

        # `zip` stops at the shorter input, which would silently drop outputs.
        if len(base_outputs) != len(target_outputs):
            raise RuntimeError(
                "Number of outputs differs: "
                f"{len(base_outputs)} base vs {len(target_outputs)} target"
            )

        results: Dict[str, List[Any]] = {m: [] for m in sel}
        for base, tgt in zip(base_outputs, target_outputs):
            for m in sel:
                results[m].append(self.registry[m](base, tgt))
        return results
|
|
@@ -44,7 +44,7 @@ def quantize(
|
|
|
44
44
|
data = np.array(data)
|
|
45
45
|
# Perform quantization
|
|
46
46
|
if not scale:
|
|
47
|
-
logger.
|
|
47
|
+
logger.warning("WARNING: scale value is 0. 1e-7 will be used instead.")
|
|
48
48
|
scale = 1e-7
|
|
49
49
|
rescaled = np.round(data / scale) + zero_point
|
|
50
50
|
# Clamp the values
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# DO NOT REMOVE THIS FILE
|