tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
|
@@ -1,164 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from typing import Dict, List, Optional
|
|
16
|
-
|
|
17
|
-
import torch
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@torch.no_grad()
def smooth_weights(
    front_module: torch.nn.Module,
    back_modules: torch.nn.Module | List[torch.nn.Module],
    activation_max: torch.Tensor,
    alpha: float,
):
    """
    Applies SmoothQuant-style smoothing to the weights and biases of two
    connected modules using activation maximum values.

    For every input channel ``j`` of the back modules a scale is computed as::

        s_j = clamp(act_max_j ** alpha / w_max_j ** (1 - alpha), min=1e-5)

    where ``w_max_j`` is the largest ``|W[:, j]|`` over all back modules
    (itself clamped to at least 1e-5). The front module's weight (and its bias,
    when present and not None) is divided by ``s`` in place, and every back
    module's weight is multiplied by ``s`` along its input dimension in place.
    This is intended to keep the front -> back composition mathematically
    unchanged while migrating activation outliers into the weights.

    NOTE All modules **MUST** have `weight` and optionally `bias` attributes.
    Currently only a ``LlamaRMSNorm`` front module and ``torch.nn.Linear``
    back modules are supported.

    Parameters
    -----------
    front_module
        The front module whose weights and biases will be divided by the scales.
    back_modules
        A single module, or a list/tuple of modules, whose weights will be
        multiplied by the scales.
    activation_max
        A tensor of channel-wise maximum activation values for the front module.
        Its number of elements must equal the number of channels.
    alpha
        The smoothing factor that determines the scaling for weight adjustments.

    Raises
    -------
    AttributeError
        If `front_module` or any module in `back_modules` does not have `weight` attributes.
    ValueError
        If `back_modules` is empty, or if the number of elements in `activation_max`
        does not match the number of channels of the front/back modules.
    NotImplementedError
        If `front_module` or any module in `back_modules` is of an unsupported type.
    """
    from transformers.models.llama.modeling_llama import LlamaRMSNorm

    # Accept a single module as well as any list/tuple of modules.
    if isinstance(back_modules, (list, tuple)):
        back_modules = list(back_modules)
    else:
        back_modules = [back_modules]

    # Check attributes
    if not hasattr(front_module, "weight"):
        raise AttributeError(
            f"The front module '{type(front_module).__name__}' does not have a 'weight' attribute."
        )
    for back_m in back_modules:
        if not hasattr(back_m, "weight"):
            raise AttributeError(
                f"The back module '{type(back_m).__name__}' does not have a 'weight' attribute."
            )

    # Check types and shapes
    if isinstance(front_module, LlamaRMSNorm):
        front_numel = front_module.weight.numel()
    else:
        raise NotImplementedError(
            f"Unsupported module type: {type(front_module).__name__}"
        )
    if not back_modules:
        raise ValueError("At least one back module is required for smoothing.")
    for back_m in back_modules:
        if not isinstance(back_m, torch.nn.Linear):
            # Report the offending back module, not the front module.
            raise NotImplementedError(
                f"Unsupported module type: {type(back_m).__name__}"
            )
        back_numel = back_m.in_features
        if front_numel != back_numel or back_numel != activation_max.numel():
            raise ValueError(
                f"Shape mismatch: front_numel({front_numel}), back_numel({back_numel}), activation_max_numel({activation_max.numel()})"
            )

    # Compute scales on the device/dtype of the (first) back module's weight.
    device, dtype = back_modules[0].weight.device, back_modules[0].weight.dtype
    activation_max = activation_max.to(device=device, dtype=dtype)  # type: ignore[arg-type]
    # Per input channel: max |W[:, j]| of each back module, then max across modules.
    weight_scales = torch.cat(
        [back_m.weight.abs().max(dim=0, keepdim=True)[0] for back_m in back_modules],  # type: ignore[operator]
        dim=0,
    )
    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
    scales = (
        (activation_max.pow(alpha) / weight_scales.pow(1 - alpha))
        .clamp(min=1e-5)
        .to(device=device, dtype=dtype)  # type: ignore[arg-type]
    )

    # Smooth: divide the front module by the scales ...
    front_module.weight.div_(scales)
    # `bias` may exist as an attribute but be None (e.g. bias-free modules).
    front_bias = getattr(front_module, "bias", None)
    if front_bias is not None:
        front_bias.div_(scales)

    # ... and multiply each back module's input columns by the same scales.
    for back_m in back_modules:
        back_m.weight.mul_(scales.view(1, -1))  # type: ignore[operator]
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
@torch.no_grad()
def apply_smoothing(
    model: torch.nn.Module,
    activation_max: Dict[str, torch.Tensor],
    alpha: float = 0.5,
    custom_alpha_map: Optional[Dict[str, float]] = None,
):
    """
    Applies SmoothQuant-style smoothing to the model's weights using activation maximum values.

    Every ``LlamaDecoderLayer`` yielded by ``model.named_modules()`` is smoothed
    in place twice:

    * ``input_layernorm`` -> ``self_attn.{q,k,v}_proj``, using the statistics
      stored under ``"<layer_name>.self_attn.q_proj"``;
    * ``post_attention_layernorm`` -> ``mlp.{gate,up}_proj``, using the
      statistics stored under ``"<layer_name>.mlp.gate_proj"``.

    Other module types are currently left untouched.

    Parameters
    -----------
    model
        A torch module whose weights will be smoothed.
    activation_max
        The channel-wise maximum activation values for the model, keyed by
        module name (see above for the keys that are read).
    alpha
        The default smoothing factor to apply across all modules. Must lie in [0, 1].
    custom_alpha_map
        A dictionary mapping layer/module names (as reported by
        ``model.named_modules()``) to custom alpha values. Layers specified in
        this dictionary will use the corresponding alpha value instead of the
        default. Values must also lie in [0, 1].

    Raises
    -------
    RuntimeError
        If the alpha selected for any module is greater than 1.0 or negative.
    KeyError
        If statistics required to smooth a decoder layer are missing from
        `activation_max`.
    """
    from transformers.models.llama.modeling_llama import LlamaDecoderLayer

    def _stats_for(key: str) -> torch.Tensor:
        # Fail with an actionable message instead of a bare KeyError(key).
        try:
            return activation_max[key]
        except KeyError:
            raise KeyError(
                f"activation_max has no entry for '{key}', which is required to smooth this decoder layer."
            ) from None

    for name, module in model.named_modules():
        alpha_to_apply = alpha
        if custom_alpha_map and name in custom_alpha_map:
            alpha_to_apply = custom_alpha_map[name]
        if alpha_to_apply > 1.0:
            raise RuntimeError(
                f"Alpha value cannot exceed 1.0. Given alpha: {alpha_to_apply}"
            )
        if alpha_to_apply < 0.0:
            # A negative exponent would invert the intended scale migration.
            raise RuntimeError(
                f"Alpha value cannot be negative. Given alpha: {alpha_to_apply}"
            )

        # SmoothQuant is applied before capturing the graph. Therefore, it needs to know
        # specific module information.
        # TODO Support more modules.
        if not isinstance(module, LlamaDecoderLayer):
            continue

        attn_ln = module.input_layernorm
        qkv = [
            module.self_attn.q_proj,
            module.self_attn.k_proj,
            module.self_attn.v_proj,
        ]
        # q/k/v are all smoothed against input_layernorm, so q_proj's input
        # statistics are used for all three projections.
        qkv_input_scales = _stats_for(name + ".self_attn.q_proj")
        smooth_weights(attn_ln, qkv, qkv_input_scales, alpha_to_apply)

        ffn_ln = module.post_attention_layernorm
        fcs = [module.mlp.gate_proj, module.mlp.up_proj]
        # Likewise, gate_proj's input statistics are used for gate_proj and up_proj.
        fcs_input_scales = _stats_for(name + ".mlp.gate_proj")
        smooth_weights(ffn_ln, fcs, fcs_input_scales, alpha_to_apply)
|
|
@@ -1,109 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from typing import Any, Callable, Dict, List
|
|
16
|
-
|
|
17
|
-
import numpy as np
|
|
18
|
-
import torch
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def compute_peir(base: torch.Tensor, target: torch.Tensor):
    """
    Calculate the Peak Error to Interval Ratio (PEIR) between two tensors.

    This function computes the PEIR between two tensors using the formula:
        PEIR = max(abs(target - base)) / (max(base) - min(base))

    The interval is taken from ``base`` (the reference tensor). If that
    interval is zero (a constant ``base``), 1.0 is used instead so the result
    stays finite; in that case the result equals the raw peak error.

    Both tensors must have the same shape and the float32 dtype. They are
    detached and moved to CPU before conversion to NumPy, so tensors that
    require grad or live on an accelerator are accepted.

    Returns
    -------
    The PEIR value as a NumPy scalar.
    """
    assert base.shape == target.shape, f"shape mismatch: {base.shape} != {target.shape}"
    base_tensor = base.detach().cpu().numpy()
    target_tensor = target.detach().cpu().numpy()
    assert (
        base_tensor.dtype == np.float32 and target_tensor.dtype == np.float32
    ), f"dtype should be float32: base({base_tensor.dtype}), target({target_tensor.dtype})"

    base_tensor = base_tensor.reshape(-1)
    target_tensor = target_tensor.reshape(-1)

    peak_error = np.max(np.absolute(target_tensor - base_tensor))
    interval = np.max(base_tensor) - np.min(base_tensor)
    # Avoid division by zero when `base` is constant.
    if interval == 0.0:
        interval = 1.0

    peir = peak_error / interval  # pylint: disable=invalid-name
    return peir
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
class MetricCalculator:
    """
    Compute metrics including both built-in and custom metrics.

    Parameters
    ----------
    metrics
        A list of built-in metric names (keys of ``builtin_metrics``) for comparison.
    custom_metrics
        A dictionary of metric names and corresponding callable functions for comparison.
        Each callable is invoked as ``fn(out1, out2)`` for every pair of outputs.
        Example: {'mse': mean_squared_error, 'cosine_similarity': cosine_similarity_fn}

    Raises
    ------
    RuntimeError
        If an unknown built-in metric is requested, or a custom metric name
        collides with a requested built-in metric.
    """

    builtin_metrics = {
        "peir": compute_peir,
    }

    def __init__(
        self,
        metrics: List[str] | None = None,
        custom_metrics: Dict[str, Callable] | None = None,
    ):
        # None defaults avoid sharing one mutable default object across instances.
        metrics = [] if metrics is None else metrics
        custom_metrics = {} if custom_metrics is None else custom_metrics

        self.metrics: Dict[str, Callable] = dict()

        for m in metrics:
            if m not in self.builtin_metrics:
                raise RuntimeError(f"Invalid metric: {m}")
            self.metrics[m] = self.builtin_metrics[m]

        duplicates = set(self.metrics).intersection(custom_metrics.keys())
        if duplicates:
            raise RuntimeError(f"There are duplicate metrics: {duplicates}")

        self.metrics = self.metrics | custom_metrics

    def compute(
        self, output1: List[torch.Tensor], output2: List[torch.Tensor]
    ) -> Dict[str, List[Any]]:
        """
        Compute every configured metric (built-in and custom) on paired outputs.

        Outputs are paired positionally with ``zip``, so extra trailing
        elements in the longer list are ignored.

        Returns
        --------
        Dict[str, List[Any]]
            A dictionary mapping each metric name to the list of values
            computed for every ``(output1[i], output2[i])`` pair.
        """
        results: Dict[str, List[Any]] = dict()

        # Look callables up in self.metrics so custom metrics are dispatched too.
        for name, metric_fn in self.metrics.items():
            results[name] = [
                metric_fn(out1, out2) for out1, out2 in zip(output1, output2)
            ]

        return results
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
/tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|