tico-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/utils/mx/formats.py ADDED
@@ -0,0 +1,125 @@
+ """
+ Copyright (c) Microsoft Corporation.
+ Licensed under the MIT License.
+ """
+
+ from enum import Enum, IntEnum
+
+ FP32_EXPONENT_BIAS = 127
+ FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1)
+
+ # Enum for rounding modes
+ class RoundingMode(IntEnum):
+     nearest = 0
+     floor = 1
+     even = 2
+
+     @staticmethod
+     def string_enums():
+         return [s.name for s in list(RoundingMode)]
+
+ # Enum for scalar data formats
+ class ElemFormat(Enum):
+     int8 = 1
+     int4 = 2
+     int2 = 3
+     fp8_e5m2 = 4
+     fp8_e4m3 = 5
+     fp6_e3m2 = 6
+     fp6_e2m3 = 7
+     fp4 = 8
+     fp4_e2m1 = 8
+     float16 = 9
+     fp16 = 9
+     bfloat16 = 10
+     bf16 = 10
+
+     @staticmethod
+     def from_str(s):
+         assert s is not None, "String elem_format == None"
+         s = s.lower()
+         if hasattr(ElemFormat, s):
+             return getattr(ElemFormat, s)
+         else:
+             raise Exception("Undefined elem format", s)
+
+
+ def _get_min_norm(ebits):
+     """ Valid for all float formats """
+     emin = 2 - (2 ** (ebits - 1))
+     return 0 if ebits == 0 else 2 ** emin
+
+
+ def _get_max_norm(ebits, mbits):
+     """ Valid only for floats that define NaN """
+     assert ebits >= 5, "invalid for floats that don't define NaN"
+     emax = 0 if ebits == 0 else 2 ** (ebits - 1) - 1
+     return 2 ** emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2)
+
+
+ _FORMAT_CACHE = {}
+ def _get_format_params(fmt):
+     """ Allowed formats:
+         - intX: 2 <= X <= 32, assume sign-magnitude, 1.xxx representation
+         - floatX/fpX: 16 <= X <= 28, assume top exp is used for NaN/Inf
+         - bfloatX/bfX: 9 <= X <= 32
+         - fp4, no NaN/Inf
+         - fp6_e3m2/e2m3, no NaN/Inf
+         - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior
+
+     Returns:
+         ebits: exponent bits
+         mbits: mantissa bits: includes sign and implicit bits
+         emax: max normal exponent
+         max_norm: max normal number
+         min_norm: min normal number
+     """
+     if type(fmt) is str:
+         fmt = ElemFormat.from_str(fmt)
+
+     if fmt in _FORMAT_CACHE:
+         return _FORMAT_CACHE[fmt]
+
+     if fmt == ElemFormat.int8:
+         ebits, mbits = 0, 8
+         emax = 0
+     elif fmt == ElemFormat.int4:
+         ebits, mbits = 0, 4
+         emax = 0
+     elif fmt == ElemFormat.int2:
+         ebits, mbits = 0, 2
+         emax = 0
+     elif fmt == ElemFormat.fp8_e5m2:
+         ebits, mbits = 5, 4
+         emax = 2 ** (ebits - 1) - 1
+     elif fmt == ElemFormat.fp8_e4m3:
+         ebits, mbits = 4, 5
+         emax = 2 ** (ebits - 1)
+     elif fmt == ElemFormat.fp6_e3m2:
+         ebits, mbits = 3, 4
+         emax = 2 ** (ebits - 1)
+     elif fmt == ElemFormat.fp6_e2m3:
+         ebits, mbits = 2, 5
+         emax = 2 ** (ebits - 1)
+     elif fmt == ElemFormat.fp4:
+         ebits, mbits = 2, 3
+         emax = 2 ** (ebits - 1)
+     elif fmt == ElemFormat.float16:
+         ebits, mbits = 5, 12
+         emax = 2 ** (ebits - 1) - 1
+     elif fmt == ElemFormat.bfloat16:
+         ebits, mbits = 8, 9
+         emax = 2 ** (ebits - 1) - 1
+     else:
+         raise Exception("Unknown element format %s" % fmt)
+
+     if fmt != ElemFormat.fp8_e4m3:
+         max_norm = 2 ** emax * float(2 ** (mbits - 1) - 1) / 2 ** (mbits - 2)
+     else:
+         max_norm = 2 ** emax * 1.75  # FP8 has custom max_norm
+
+     min_norm = _get_min_norm(ebits)
+
+     _FORMAT_CACHE[fmt] = (ebits, mbits, emax, max_norm, min_norm)
+
+     return ebits, mbits, emax, max_norm, min_norm
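
As a quick sanity check of _get_format_params, here is a minimal sketch (a hypothetical snippet, not part of the package; the expected values follow directly from the branches above):

    from tico.utils.mx.formats import _get_format_params

    # fp8_e4m3: 4 exponent bits; mbits = 5 counts the 3 fraction bits
    # plus the sign and implicit leading bits.
    ebits, mbits, emax, max_norm, min_norm = _get_format_params("fp8_e4m3")
    assert (ebits, mbits, emax) == (4, 5, 8)
    assert max_norm == 448.0  # 2**8 * 1.75, the usual E4M3 maximum
    assert min_norm == 2 ** -6  # smallest normal number, 0.015625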
tico/utils/mx/mx_ops.py ADDED
@@ -0,0 +1,270 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # This file was copied from https://github.com/microsoft/microxcaling/tree/v1.1.0
+ # and modified for our purpose.
+ """
+ Copyright (c) Microsoft Corporation.
+ Licensed under the MIT License.
+
+ Name: mx_ops.py
+
+ Pytorch methods for MX quantization.
+
+ Usage Notes:
+  - Use the "Exposed Methods" below to implement autograd functions
+  - Use autograd functions to then implement torch.nn.Module(s)
+  - Do *not* use methods in this file in Modules, they have no defined
+    backwards pass and will block gradient computation.
+  - Avoid importing internal functions if at all possible.
+
+ Exposed Methods:
+     quantize_mx - quantizes a tensor to MX format.
+
+ Internal Methods:
+     _safe_lshift, _safe_rshift - fp16 compatible shifts
+     _shared_exponents - returns MX shared exponent for the passed tensor
+     _reshape_to_blocks - tiles a tensor by splitting one dim into two
+     _undo_reshape_to_blocks - undoes the above reshaping
+     _quantize_mx - quantizes a tensor to MX format
+ """
+
+ import torch
+
+ from .elemwise_ops import _quantize_elemwise_core
+
+ from .formats import (
+     _get_format_params,
+     FP32_EXPONENT_BIAS,
+     FP32_MIN_NORMAL,
+     RoundingMode,
+ )
+
+
+ # -------------------------------------------------------------------------
+ # Helper funcs
+ # -------------------------------------------------------------------------
+ def _shared_exponents(A, method="max", axes=None, ebits=0):
+     """
+     Get shared exponents for the passed matrix A.
+     Args:
+       A {PyTorch tensor} -- Input tensor
+       method {str} -- Exponent selection method.
+           "max" uses the max absolute value
+           "none" uses an exponent for each value (i.e., no sharing)
+       axes {list(int)} -- List of integers which specifies the axes across which
+           shared exponents are calculated.
+     Returns:
+       shared_exp {PyTorch tensor} -- Tensor of shared exponents
+     """
+
+     if method == "max":
+         if axes is None:
+             shared_exp = torch.max(torch.abs(A))
+         else:
+             shared_exp = A
+             for axis in axes:
+                 shared_exp, _ = torch.max(torch.abs(shared_exp), dim=axis, keepdim=True)
+     elif method == "none":
+         shared_exp = torch.abs(A)
+     else:
+         raise Exception("Unrecognized shared exponent selection method %s" % (method))
+
+     # log2(shared_exp) and truncate to integer
+     shared_exp = torch.floor(
+         torch.log2(
+             shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype)
+         )
+     )
+
+     # Restrict to [-emax, emax] range
+     if ebits > 0:
+         emax = 2 ** (ebits - 1) - 1
+         # shared_exp = torch.clamp(shared_exp, -emax, emax)
+         # Overflow to Inf
+         shared_exp[shared_exp > emax] = float("NaN")
+         # Underflows are set to -emax, which causes them to be
+         # flushed to 0 later
+         shared_exp[shared_exp < -emax] = -emax
+
+     return shared_exp
+
+
+ def _reshape_to_blocks(A, axes, block_size):
+     if axes is None:
+         raise Exception(
+             "axes required in order to determine which "
+             "dimension to apply block size to"
+         )
+     if block_size == 0:
+         raise Exception("block_size == 0 in _reshape_to_blocks")
+
+     # Fix axes to be positive and sort them
+     axes = [(x + len(A.shape) if x < 0 else x) for x in axes]
+     assert all(x >= 0 for x in axes)
+     axes = sorted(axes)
+
+     # Add extra dimension for tiles
+     for i in range(len(axes)):
+         axes[i] += i  # Shift axes due to added dimensions
+         A = torch.unsqueeze(A, dim=axes[i] + 1)
+
+     # Pad to block_size
+     orig_shape = A.size()
+     pad = []
+     for i in range(len(orig_shape)):
+         pad += [0, 0]
+
+     do_padding = False
+     for axis in axes:
+         pre_pad_size = orig_shape[axis]
+         if isinstance(pre_pad_size, torch.Tensor):
+             pre_pad_size = int(pre_pad_size.value)
+         # Don't pad if the axis is short enough to fit inside one tile
+         if pre_pad_size % block_size == 0:
+             pad[2 * axis] = 0
+         else:
+             pad[2 * axis] = block_size - pre_pad_size % block_size
+             do_padding = True
+
+     if do_padding:
+         pad = list(reversed(pad))
+         A = torch.nn.functional.pad(A, pad, mode="constant")
+
+     def _reshape(shape, reshape_block_size):
+         for axis in axes:
+             # Reshape to tiles if axis length > reshape_block_size
+             if shape[axis] >= reshape_block_size:
+                 assert shape[axis] % reshape_block_size == 0
+                 shape[axis + 1] = reshape_block_size
+                 shape[axis] = shape[axis] // reshape_block_size
+             # Otherwise preserve length and insert a 1 into the shape
+             else:
+                 shape[axis + 1] = shape[axis]
+                 shape[axis] = 1
+         return shape
+
+     # Reshape to tiles
+     padded_shape = A.size()
+     reshape = _reshape(list(padded_shape), block_size)
+
+     A = A.view(reshape)
+     return A, axes, orig_shape, padded_shape
+
+
+ def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes):
+     # Undo tile reshaping
+     A = A.view(padded_shape)
+     # Undo padding
+     if not list(padded_shape) == list(orig_shape):
+         slices = [slice(0, x) for x in orig_shape]
+         A = A[slices]
+     for axis in reversed(axes):
+         # Remove extra dimension
+         A = torch.squeeze(A, dim=axis + 1)
+     return A
+
+
+ # -------------------------------------------------------------------------
+ # Main funcs
+ # -------------------------------------------------------------------------
+ def _quantize_mx(
+     A,
+     scale_bits,
+     elem_format,  # can be None for no quantization
+     shared_exp_method="max",
+     axes=None,
+     block_size=0,
+     round="nearest",
+     flush_fp32_subnorms=False,
+     custom_cuda=False,
+ ):
+     """Function used for MX* quantization"""
+     # Shortcut for no quantization
+     if elem_format is None:
+         return A
+
+     assert scale_bits > 0
+
+     # Make sure axes is a list of non-negative numbers
+     axes = [axes] if type(axes) == int else axes
+     axes = [x + A.ndim if x < 0 else x for x in axes]
+
+     # Custom CUDA only supports limited rounding modes
+     custom_cuda = custom_cuda and round in RoundingMode.string_enums()
+
+     ebits, mbits, emax, max_norm, _ = _get_format_params(elem_format)
+
+     # Perform tiling to the hardware vector size
+     if block_size > 0:
+         A, axes, orig_shape, padded_shape = _reshape_to_blocks(A, axes, block_size)
+
+     ####################
+     # Quantize
+     ####################
+     shared_exp_axes = [x + 1 for x in axes] if block_size > 0 else axes
+
+     # Get shared exponents
+     shared_exp = _shared_exponents(
+         A,
+         method=shared_exp_method,
+         axes=shared_exp_axes,
+         ebits=0,
+     )
+
+     # Flush subnormal FP32 inputs to zero
+     if flush_fp32_subnorms:
+         A = A * (shared_exp > -FP32_EXPONENT_BIAS).type(A.dtype)
+
+     # Offset the max exponent by the largest representable exponent
+     # in the element data format
+     shared_exp = shared_exp - emax
+
+     scale_emax = 2 ** (scale_bits - 1) - 1
+     shared_exp[shared_exp > scale_emax] = float("NaN")
+     shared_exp[shared_exp < -scale_emax] = -scale_emax
+
+     A = A / (2 ** shared_exp)
+
+     A = _quantize_elemwise_core(
+         A,
+         mbits,
+         ebits,
+         max_norm,
+         round=round,
+         allow_denorm=True,
+         saturate_normals=True,
+         custom_cuda=custom_cuda,
+     )
+
+     A = A * (2 ** shared_exp)
+
+     # Undo tile reshaping
+     if block_size:
+         A = _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes)
+
+     return A
+
+
+ # Wrapper function of circle_custom::quantize_mx
+ def quantize_mx(
+     input_: torch.Tensor,
+     elem_format: str,
+     axis: int,
+     shared_exp_method: str = "max",
+     round: str = "nearest",
+ ) -> torch.Tensor:
+     return torch.ops.circle_custom.quantize_mx(
+         input_, elem_format, axis, shared_exp_method=shared_exp_method, round=round
+     )
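
As a usage sketch (illustrative only, not part of the package): the public wrapper quantize_mx dispatches to the registered circle_custom op, while the internal _quantize_mx routine can be exercised directly in eager mode, keeping in mind the docstring's warning against using internals inside Modules:

    import torch

    from tico.utils.mx.mx_ops import _quantize_mx

    # Quantize a (4, 64) tensor to MX with fp8_e4m3 elements, an 8-bit
    # shared scale, and blocks of 32 along the last axis.
    A = torch.randn(4, 64)
    A_mx = _quantize_mx(
        A,
        scale_bits=8,
        elem_format="fp8_e4m3",
        axes=[-1],
        block_size=32,
        round="nearest",
    )
    assert A_mx.shape == A.shape  # tiling and padding are undone afterwards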
tico/utils/padding.py ADDED
@@ -0,0 +1,47 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+ from tico.utils.errors import InvalidArgumentError
+
+ SAME = 0
+ VALID = 1
+
+
+ def is_valid_padding(padding: str | list):
+     if isinstance(padding, str):
+         return padding == "valid"
+
+     if isinstance(padding, list):
+         assert len(padding) == 2, "Padding should be a list of length 2."
+         return padding == [0, 0]
+
+     raise InvalidArgumentError("Invalid padding.")
+
+
+ def is_same_padding(
+     padding: str | list, input_shape: list | torch.Size, output_shape: list | torch.Size
+ ):
+     if isinstance(padding, str):
+         return padding == "same"
+
+     if isinstance(padding, list):
+         assert len(padding) == 2, "Padding should be a list of length 2."
+
+         input_HW = input_shape[1:3]  # H, W from N H W C layout
+         output_HW = output_shape[1:3]  # H, W from N H W C layout
+         return input_HW == output_HW
+
+     raise InvalidArgumentError("Invalid padding.")
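
A brief illustration of the intended behavior (a hypothetical snippet, assuming the N H W C layout noted in the comments):

    import torch

    from tico.utils.padding import is_same_padding, is_valid_padding

    assert is_valid_padding("valid")
    assert is_valid_padding([0, 0])
    # "same" padding preserves the spatial dims (H, W) of an N H W C tensor.
    assert is_same_padding([1, 1], torch.Size([1, 8, 8, 4]), torch.Size([1, 8, 8, 4]))
    assert not is_same_padding([0, 0], torch.Size([1, 8, 8, 4]), torch.Size([1, 6, 6, 4]))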
tico/utils/passes.py ADDED
@@ -0,0 +1,76 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import List
+
+ from torch.export import ExportedProgram
+
+
+ @dataclass
+ class PassResult:
+     modified: bool
+
+
+ class PassBase(ABC):
+     """
+     Base interface for passes.
+     """
+
+     @abstractmethod
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         pass
+
+
+ class PassStrategy(Enum):
+     # Run passes until there are no changes.
+     UNTIL_NO_CHANGE = (1,)
+     # Same as `UNTIL_NO_CHANGE`, but it starts again from the beginning.
+     RESTART = (2,)
+
+
+ class PassManager:
+     def __init__(
+         self,
+         passes: List[PassBase],
+         strategy: PassStrategy = PassStrategy.RESTART,
+     ):
+         self.passes: List[PassBase] = passes
+         self.strategy: PassStrategy = strategy
+
+     def run(self, exported_program: ExportedProgram):
+         MAXIMUM_STEP_COUNT = 1000
+         step = 0
+         while True:
+             modified = False
+             for _pass in self.passes:
+                 # Automatically update the signatures of the input and output.
+                 # https://github.com/pytorch/executorch/issues/4013#issuecomment-2187161844
+                 with exported_program.graph_module._set_replace_hook(
+                     exported_program.graph_signature.get_replace_hook()
+                 ):
+                     result = _pass.call(exported_program)
+                 modified = modified or result.modified
+                 if modified and self.strategy == PassStrategy.RESTART:
+                     break
+
+             if not modified:
+                 break
+             step += 1
+
+             assert (
+                 step < MAXIMUM_STEP_COUNT
+             ), f"Loop iterated for {MAXIMUM_STEP_COUNT} times. Circular loop is suspected in {self.passes}"