tico-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
tico/utils/mx/formats.py
ADDED
@@ -0,0 +1,125 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
"""

from enum import Enum, IntEnum

FP32_EXPONENT_BIAS = 127
FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1)

# Enum for rounding modes
class RoundingMode(IntEnum):
    nearest = 0
    floor = 1
    even = 2

    @staticmethod
    def string_enums():
        return [s.name for s in list(RoundingMode)]

# Enum for scalar data formats
class ElemFormat(Enum):
    int8 = 1
    int4 = 2
    int2 = 3
    fp8_e5m2 = 4
    fp8_e4m3 = 5
    fp6_e3m2 = 6
    fp6_e2m3 = 7
    fp4 = 8
    fp4_e2m1 = 8
    float16 = 9
    fp16 = 9
    bfloat16 = 10
    bf16 = 10

    @staticmethod
    def from_str(s):
        assert (s != None), "String elem_format == None"
        s = s.lower()
        if hasattr(ElemFormat, s):
            return getattr(ElemFormat, s)
        else:
            raise Exception("Undefined elem format", s)


def _get_min_norm(ebits):
    """ Valid for all float formats """
    emin = 2 - (2 ** (ebits - 1))
    return 0 if ebits == 0 else 2 ** emin


def _get_max_norm(ebits, mbits):
    """ Valid only for floats that define NaN """
    assert (ebits >= 5), "invalid for floats that don't define NaN"
    emax = 0 if ebits == 0 else 2**(ebits - 1) - 1
    return 2**emax * float(2**(mbits - 1) - 1) / 2**(mbits - 2)


_FORMAT_CACHE = {}
def _get_format_params(fmt):
    """ Allowed formats:
        - intX:          2 <= X <= 32, assume sign-magnitude, 1.xxx representation
        - floatX/fpX:    16 <= X <= 28, assume top exp is used for NaN/Inf
        - bfloatX/bfX:   9 <= X <= 32
        - fp4,           no NaN/Inf
        - fp6_e3m2/e2m3, no NaN/Inf
        - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior

        Returns:
          ebits: exponent bits
          mbits: mantissa bits: includes sign and implicit bits
          emax: max normal exponent
          max_norm: max normal number
          min_norm: min normal number
    """
    if type(fmt) is str:
        fmt = ElemFormat.from_str(fmt)

    if fmt in _FORMAT_CACHE:
        return _FORMAT_CACHE[fmt]

    if fmt == ElemFormat.int8:
        ebits, mbits = 0, 8
        emax = 0
    elif fmt == ElemFormat.int4:
        ebits, mbits = 0, 4
        emax = 0
    elif fmt == ElemFormat.int2:
        ebits, mbits = 0, 2
        emax = 0
    elif fmt == ElemFormat.fp8_e5m2:
        ebits, mbits = 5, 4
        emax = 2**(ebits - 1) - 1
    elif fmt == ElemFormat.fp8_e4m3:
        ebits, mbits = 4, 5
        emax = 2**(ebits - 1)
    elif fmt == ElemFormat.fp6_e3m2:
        ebits, mbits = 3, 4
        emax = 2**(ebits - 1)
    elif fmt == ElemFormat.fp6_e2m3:
        ebits, mbits = 2, 5
        emax = 2**(ebits - 1)
    elif fmt == ElemFormat.fp4:
        ebits, mbits = 2, 3
        emax = 2**(ebits - 1)
    elif fmt == ElemFormat.float16:
        ebits, mbits = 5, 12
        emax = 2**(ebits - 1) - 1
    elif fmt == ElemFormat.bfloat16:
        ebits, mbits = 8, 9
        emax = 2**(ebits - 1) - 1
    else:
        raise Exception("Unknown element format %s" % fmt)

    if fmt != ElemFormat.fp8_e4m3:
        max_norm = 2**emax * float(2**(mbits - 1) - 1) / 2**(mbits - 2)
    else:
        max_norm = 2**emax * 1.75  # FP8 has custom max_norm

    min_norm = _get_min_norm(ebits)

    _FORMAT_CACHE[fmt] = (ebits, mbits, emax, max_norm, min_norm)

    return ebits, mbits, emax, max_norm, min_norm
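
For illustration, a minimal sketch (hypothetical usage, not part of the released files) of what `_get_format_params` returns; the values follow directly from the branches above:

    # Hypothetical usage of _get_format_params; not part of the package diff.
    from tico.utils.mx.formats import _get_format_params

    # fp8_e4m3: 4 exponent bits, 5 "mbits" (sign + implicit + 3 mantissa bits),
    # emax = 2**(4 - 1) = 8, and the custom max_norm = 2**8 * 1.75 = 448.0.
    ebits, mbits, emax, max_norm, min_norm = _get_format_params("fp8_e4m3")
    assert (ebits, mbits, emax, max_norm) == (4, 5, 8, 448.0)

    # Integer formats carry no exponent bits, so emax is 0.
    assert _get_format_params("int8")[:3] == (0, 8, 0)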
tico/utils/mx/mx_ops.py
ADDED
@@ -0,0 +1,270 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file was copied from https://github.com/microsoft/microxcaling/tree/v1.1.0
# and modified for our purpose.
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.

Name: mx_ops.py

Pytorch methods for MX quantization.

Usage Notes:
 - Use the "Exposed Methods" below to implement autograd functions
 - Use autograd functions to then implement torch.nn.Module(s)
 - Do *not* use methods in this file in Modules, they have no defined
   backwards pass and will block gradient computation.
 - Avoid importing internal functions if at all possible.

Exposed Methods:
    quantize_mx_op - quantizes a tensor to MX format.

Internal Methods:
    _safe_lshift, _safe_rshift - fp16 compatible shifts
    _shared_exponents - Returns MX shared exponent for the passed tensor
    _reshape_to_blocks - tiles a tensor by splitting one dim into two
    _undo_reshape_to_blocks - undoes the above reshaping
    _quantize_mx - quantizes a tensor to MX format
"""

import torch

from .elemwise_ops import _quantize_elemwise_core

from .formats import (
    _get_format_params,
    FP32_EXPONENT_BIAS,
    FP32_MIN_NORMAL,
    RoundingMode,
)


# -------------------------------------------------------------------------
# Helper funcs
# -------------------------------------------------------------------------
def _shared_exponents(A, method="max", axes=None, ebits=0):
    """
    Get shared exponents for the passed matrix A.
    Args:
      A      {PyTorch tensor} -- Input tensor
      method {str}            -- Exponent selection method.
                                 "max" uses the max absolute value
                                 "none" uses an exponent for each value (i.e., no sharing)
      axes   {list(int)}      -- List of integers which specifies the axes across which
                                 shared exponents are calculated.
    Returns:
      shared_exp {PyTorch tensor} -- Tensor of shared exponents
    """

    if method == "max":
        if axes is None:
            shared_exp = torch.max(torch.abs(A))
        else:
            shared_exp = A
            for axis in axes:
                shared_exp, _ = torch.max(torch.abs(shared_exp), dim=axis, keepdim=True)
    elif method == "none":
        shared_exp = torch.abs(A)
    else:
        raise Exception("Unrecognized shared exponent selection method %s" % (method))

    # log2(shared_exp) and truncate to integer
    shared_exp = torch.floor(
        torch.log2(
            shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype)
        )
    )

    # Restrict to [-emax, emax] range
    if ebits > 0:
        emax = 2 ** (ebits - 1) - 1
        # shared_exp = torch.clamp(shared_exp, -emax, emax)
        # Overflow to Inf
        shared_exp[shared_exp > emax] = float("NaN")
        # Underflows are set to -127 which causes them to be
        # flushed to 0 later
        shared_exp[shared_exp < -emax] = -emax

    return shared_exp


def _reshape_to_blocks(A, axes, block_size):
    if axes is None:
        raise Exception(
            "axes required in order to determine which "
            "dimension to apply block size to"
        )
    if block_size == 0:
        raise Exception("block_size == 0 in _reshape_to_blocks")

    # Fix axes to be positive and sort them
    axes = [(x + len(A.shape) if x < 0 else x) for x in axes]
    assert all(x >= 0 for x in axes)
    axes = sorted(axes)

    # Add extra dimension for tiles
    for i in range(len(axes)):
        axes[i] += i  # Shift axes due to added dimensions
        A = torch.unsqueeze(A, dim=axes[i] + 1)

    # Pad to block_size
    orig_shape = A.size()
    pad = []
    for i in range(len(orig_shape)):
        pad += [0, 0]

    do_padding = False
    for axis in axes:
        pre_pad_size = orig_shape[axis]
        if isinstance(pre_pad_size, torch.Tensor):
            pre_pad_size = int(pre_pad_size.value)
        # Don't pad if the axis is short enough to fit inside one tile
        if pre_pad_size % block_size == 0:
            pad[2 * axis] = 0
        else:
            pad[2 * axis] = block_size - pre_pad_size % block_size
            do_padding = True

    if do_padding:
        pad = list(reversed(pad))
        A = torch.nn.functional.pad(A, pad, mode="constant")

    def _reshape(shape, reshape_block_size):
        for axis in axes:
            # Reshape to tiles if axis length > reshape_block_size
            if shape[axis] >= reshape_block_size:
                assert shape[axis] % reshape_block_size == 0
                shape[axis + 1] = reshape_block_size
                shape[axis] = shape[axis] // reshape_block_size
            # Otherwise preserve length and insert a 1 into the shape
            else:
                shape[axis + 1] = shape[axis]
                shape[axis] = 1
        return shape

    # Reshape to tiles
    padded_shape = A.size()
    reshape = _reshape(list(padded_shape), block_size)

    A = A.view(reshape)
    return A, axes, orig_shape, padded_shape


def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes):
    # Undo tile reshaping
    A = A.view(padded_shape)
    # Undo padding
    if not list(padded_shape) == list(orig_shape):
        slices = [slice(0, x) for x in orig_shape]
        A = A[slices]
    for axis in reversed(axes):
        # Remove extra dimension
        A = torch.squeeze(A, dim=axis + 1)
    return A


# -------------------------------------------------------------------------
# Main funcs
# -------------------------------------------------------------------------
def _quantize_mx(
    A,
    scale_bits,
    elem_format,  # can be None for no quantization
    shared_exp_method="max",
    axes=None,
    block_size=0,
    round="nearest",
    flush_fp32_subnorms=False,
    custom_cuda=False,
):
    """Function used for MX* quantization"""
    # Shortcut for no quantization
    if elem_format == None:
        return A

    assert scale_bits > 0

    # Make sure axes is a list of non-negative numbers
    axes = [axes] if type(axes) == int else axes
    axes = [x + A.ndim if x < 0 else x for x in axes]

    # Custom CUDA only supports limited rounding modes
    custom_cuda = custom_cuda and round in RoundingMode.string_enums()

    ebits, mbits, emax, max_norm, _ = _get_format_params(elem_format)

    # Perform tiling to the hardware vector size
    if block_size > 0:
        A, axes, orig_shape, padded_shape = _reshape_to_blocks(A, axes, block_size)

    ####################
    # Quantize
    ####################
    shared_exp_axes = [x + 1 for x in axes] if block_size > 0 else axes

    # Get shared exponents
    shared_exp = _shared_exponents(
        A,
        method=shared_exp_method,
        axes=shared_exp_axes,
        ebits=0,
    )

    # Flush subnormal FP32 inputs to zero
    if flush_fp32_subnorms:
        A = A * (shared_exp > -FP32_EXPONENT_BIAS).type(A.dtype)

    # Offset the max exponent by the largest representable exponent
    # in the element data format
    shared_exp = shared_exp - emax

    scale_emax = 2 ** (scale_bits - 1) - 1
    shared_exp[shared_exp > scale_emax] = float("NaN")
    shared_exp[shared_exp < -scale_emax] = -scale_emax

    A = A / (2**shared_exp)

    A = _quantize_elemwise_core(
        A,
        mbits,
        ebits,
        max_norm,
        round=round,
        allow_denorm=True,
        saturate_normals=True,
        custom_cuda=custom_cuda,
    )

    A = A * (2**shared_exp)

    # Undo tile reshaping
    if block_size:
        A = _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes)

    return A


# Wrapper function of circle_custom::quantize_mx
def quantize_mx(
    input_: torch.Tensor,
    elem_format: str,
    axis: int,
    shared_exp_method: str = "max",
    round: str = "nearest",
) -> torch.Tensor:
    return torch.ops.circle_custom.quantize_mx(
        input_, elem_format, axis, shared_exp_method=shared_exp_method, round=round
    )
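
The file's own notes warn that these internals have no backward pass, so the following is a sketch only (hypothetical usage, not part of the released files). It shows how the pieces compose: tile to blocks, compute one shared exponent per block, quantize elements, then undo the tiling:

    # Hypothetical sketch of the internal MX flow; not part of the package diff.
    import torch
    from tico.utils.mx.mx_ops import _quantize_mx

    x = torch.randn(4, 64)

    # One 8-bit shared scale per block of 32 elements along the last axis,
    # fp8_e4m3 elements, round-to-nearest (an MXFP8-style configuration).
    xq = _quantize_mx(
        x,
        scale_bits=8,
        elem_format="fp8_e4m3",
        axes=[-1],
        block_size=32,
        round="nearest",
    )

    assert xq.shape == x.shape  # shape preserved; values snapped to the MX grid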
tico/utils/padding.py
ADDED
@@ -0,0 +1,47 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from tico.utils.errors import InvalidArgumentError

SAME = 0
VALID = 1


def is_valid_padding(padding: str | list):
    if isinstance(padding, str):
        return padding == "valid"

    if isinstance(padding, list):
        assert len(padding) == 2, "Padding should be a list of length 2."
        return padding == [0, 0]

    raise InvalidArgumentError("Invalid padding.")


def is_same_padding(
    padding: str | list, input_shape: list | torch.Size, output_shape: list | torch.Size
):
    if isinstance(padding, str):
        return padding == "same"

    if isinstance(padding, list):
        assert len(padding) == 2, "Padding should be a list of length 2."

        # SAME padding preserves the spatial dims: compare H and W
        # (dims 1 and 2 of the N H W C shapes).
        input_HW = input_shape[1:3]  # N H W C
        output_HW = output_shape[1:3]  # N H W C
        return input_HW == output_HW

    raise InvalidArgumentError("Invalid padding.")
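
A short sketch (hypothetical usage, not part of the released files) of how these predicates classify a convolution's padding, assuming NHWC shapes as the in-code comments indicate:

    # Hypothetical usage; not part of the package diff.
    import torch
    from tico.utils.padding import is_same_padding, is_valid_padding

    assert is_valid_padding("valid") and is_valid_padding([0, 0])

    # SAME padding keeps H and W: NHWC 1x8x8x3 in, 1x8x8x16 out.
    assert is_same_padding([1, 1], torch.Size([1, 8, 8, 3]), torch.Size([1, 8, 8, 16]))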
tico/utils/passes.py
ADDED
@@ -0,0 +1,76 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import List

from torch.export import ExportedProgram


@dataclass
class PassResult:
    modified: bool


class PassBase(ABC):
    """
    Base interface for passes.
    """

    @abstractmethod
    def call(self, exported_program: ExportedProgram) -> PassResult:
        pass


class PassStrategy(Enum):
    # Run passes until there are no changes.
    UNTIL_NO_CHANGE = (1,)
    # Same as `UNTIL_NO_CHANGE` but it starts again from the beginning.
    RESTART = (2,)


class PassManager:
    def __init__(
        self,
        passes: List[PassBase],
        strategy: PassStrategy = PassStrategy.RESTART,
    ):
        self.passes: List[PassBase] = passes
        self.strategy: PassStrategy = strategy

    def run(self, exported_program: ExportedProgram):
        MAXIMUM_STEP_COUNT = 1000
        step = 0
        while True:
            modified = False
            for _pass in self.passes:
                # Automatically update the signatures of the input and output.
                # https://github.com/pytorch/executorch/issues/4013#issuecomment-2187161844
                with exported_program.graph_module._set_replace_hook(
                    exported_program.graph_signature.get_replace_hook()
                ):
                    result = _pass.call(exported_program)
                modified = modified or result.modified
                if modified and self.strategy == PassStrategy.RESTART:
                    break

            if not modified:
                break
            step += 1

            assert (
                step < MAXIMUM_STEP_COUNT
            ), f"Loop iterated for {MAXIMUM_STEP_COUNT} times. Circular loop is suspected in {self.passes}"
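
A minimal sketch (hypothetical usage, not part of the released files) of plugging a pass into `PassManager`; a pass that never reports a modification makes the manager stop after one sweep:

    # Hypothetical usage; not part of the package diff.
    import torch
    from tico.utils.passes import PassBase, PassManager, PassResult

    class CountNodesPass(PassBase):
        """A do-nothing example pass: inspects the graph, reports no change."""
        def call(self, exported_program):
            print(f"graph has {len(list(exported_program.graph.nodes))} nodes")
            return PassResult(modified=False)

    ep = torch.export.export(torch.nn.Linear(4, 2), (torch.randn(1, 4),))
    PassManager([CountNodesPass()]).run(ep)  # one sweep, nothing modified, stops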