tico 0.1.0.dev250428__py3-none-any.whl → 0.1.0.dev250429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -22,7 +22,7 @@ from tico.config import CompileConfigV1, get_default_config
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2

  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
- __version__ = '0.1.0.dev250428'
+ __version__ = '0.1.0.dev250429'


  if Version(torch.__version__) < Version("2.5.0"):
tico/utils/mx/__init__.py ADDED
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/utils/mx/elemwise_ops.py ADDED
@@ -0,0 +1,267 @@
+ """
+ Copyright (c) Microsoft Corporation.
+ Licensed under the MIT License.
+
+ Name: elemwise_ops.py
+
+ PyTorch functions for elementwise (i.e., bfloat) quantization.
+
+ Usage Notes:
+  - Use the "Exposed Methods" below to implement autograd functions
+  - Use autograd functions to then implement torch.nn.Module(s)
+  - Do *not* use methods in this file in Modules; they have no defined
+    backwards pass and will block gradient computation.
+  - Avoid importing internal functions if at all possible.
+
+ Exposed Methods:
+     quantize_elemwise_op - quantizes a tensor to bfloat or other
+                            custom float format
+ """
+ import torch
+
+ from .formats import RoundingMode, _get_format_params
+ from .formats import _get_min_norm, _get_max_norm
+
+
+ # -------------------------------------------------------------------------
+ # Helper funcs
+ # -------------------------------------------------------------------------
+ # Never explicitly compute 2**(-exp) since subnorm numbers have
+ # exponents smaller than -126
+ def _safe_lshift(x, bits, exp):
+     if exp is None:
+         return x * (2**bits)
+     else:
+         return x / (2**exp) * (2**bits)
+
+
+ def _safe_rshift(x, bits, exp):
+     if exp is None:
+         return x / (2**bits)
+     else:
+         return x / (2**bits) * (2**exp)
+
+
+ def _round_mantissa(A, bits, round, clamp=False):
+     """
+     Rounds the mantissa to `bits` bits using the rounding method `round`.
+     Args:
+         A     {PyTorch tensor} -- Input tensor
+         bits  {int}            -- Number of mantissa bits (used when clamp=True)
+         round {str}            -- Rounding method:
+                                   "floor" rounds toward zero
+                                   "nearest" rounds to ceil or floor, whichever is nearest
+                                   "even" rounds half-way cases to even
+                                   "dither" rounds stochastically
+         clamp {bool}           -- If True, clip to the representable mantissa range
+     Returns:
+         A     {PyTorch tensor} -- Tensor with mantissas rounded
+     """
+
+     if round == "dither":
+         rand_A = torch.rand_like(A, requires_grad=False)
+         A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A)
+     elif round == "floor":
+         A = torch.sign(A) * torch.floor(torch.abs(A))
+     elif round == "nearest":
+         A = torch.sign(A) * torch.floor(torch.abs(A) + 0.5)
+     elif round == "even":
+         absA = torch.abs(A)
+         # find 0.5, 2.5, 4.5 ...
+         maskA = ((absA - 0.5) % 2 == torch.zeros_like(A)).type(A.dtype)
+         A = torch.sign(A) * (torch.floor(absA + 0.5) - maskA)
+     else:
+         raise Exception("Unrecognized round method %s" % (round))
+
+     # Clip values that cannot be expressed by the specified number of bits
+     if clamp:
+         max_mantissa = 2 ** (bits - 1) - 1
+         A = torch.clamp(A, -max_mantissa, max_mantissa)
+     return A
+
+
+ # -------------------------------------------------------------------------
+ # Main funcs
+ # -------------------------------------------------------------------------
+ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest',
+                             saturate_normals=False, allow_denorm=True,
+                             custom_cuda=False):
+     """ Core function used for element-wise quantization
+     Arguments:
+         A                {PyTorch tensor} -- A tensor to be quantized
+         bits             {int}   -- Number of mantissa bits. Includes
+                                     sign bit and implicit one for floats
+         exp_bits         {int}   -- Number of exponent bits, 0 for ints
+         max_norm         {float} -- Largest representable normal number
+         round            {str}   -- Rounding mode: (floor, nearest, even)
+         saturate_normals {bool}  -- If True, normal numbers (i.e., not NaN/Inf)
+                                     that exceed max_norm are clamped.
+                                     Must be True for correct MX conversion.
+         allow_denorm     {bool}  -- If False, flush denorm numbers in the
+                                     elem_format to zero.
+         custom_cuda      {bool}  -- If True, use custom CUDA kernels
+     Returns:
+         quantized tensor {PyTorch tensor} -- A tensor that has been quantized
+     """
+     A_is_sparse = A.is_sparse
+     if A_is_sparse:
+         if A.layout != torch.sparse_coo:
+             raise NotImplementedError("Only COO layout sparse tensors are currently supported.")
+
+         sparse_A = A.coalesce()
+         A = sparse_A.values().clone()
+
+     # custom cuda only supports the rounding modes listed in RoundingMode
+     custom_cuda = custom_cuda and round in RoundingMode.string_enums()
+
+     if custom_cuda:
+         A = A.contiguous()
+
+         from . import custom_extensions
+         if A.device.type == "cuda":
+             A = custom_extensions.funcs.quantize_elemwise_func_cuda(
+                 A, bits, exp_bits, max_norm, RoundingMode[round],
+                 saturate_normals, allow_denorm)
+         elif A.device.type == "cpu":
+             A = custom_extensions.funcs.quantize_elemwise_func_cpp(
+                 A, bits, exp_bits, max_norm, RoundingMode[round],
+                 saturate_normals, allow_denorm)
+         return A
+
+     # Flush values < min_norm to zero if denorms are not allowed
+     if not allow_denorm and exp_bits > 0:
+         min_norm = _get_min_norm(exp_bits)
+         out = (torch.abs(A) >= min_norm).type(A.dtype) * A
+     else:
+         out = A
+
+     if exp_bits != 0:
+         private_exp = torch.floor(torch.log2(
+             torch.abs(A) + (A == 0).type(A.dtype)))
+
+         # The minimum representable exponent for 8 exp bits is -126
+         min_exp = -(2**(exp_bits - 1)) + 2
+         private_exp = private_exp.clip(min=min_exp)
+     else:
+         private_exp = None
+
+     # Scale up so the appropriate number of bits are in the integer portion of the number
+     out = _safe_lshift(out, bits - 2, private_exp)
+
+     out = _round_mantissa(out, bits, round, clamp=False)
+
+     # Undo scaling
+     out = _safe_rshift(out, bits - 2, private_exp)
+
+     # Set values > max_norm to Inf if desired, else clamp them
+     if saturate_normals or exp_bits == 0:
+         out = torch.clamp(out, min=-max_norm, max=max_norm)
+     else:
+         out = torch.where((torch.abs(out) > max_norm),
+                           torch.sign(out) * float("Inf"), out)
+
+     # handle Inf/NaN (NaN != NaN, so torch.isnan is required here)
+     if not custom_cuda:
+         out[A == float("Inf")] = float("Inf")
+         out[A == -float("Inf")] = -float("Inf")
+         out[torch.isnan(A)] = float("NaN")
+
+     if A_is_sparse:
+         # Rebuild the sparse tensor from the quantized values
+         out = torch.sparse_coo_tensor(sparse_A.indices(), out,
+             sparse_A.size(), dtype=sparse_A.dtype, device=sparse_A.device,
+             requires_grad=sparse_A.requires_grad)
+
+     return out
+
+
+ def _quantize_elemwise(A, elem_format, round='nearest', custom_cuda=False,
+                        saturate_normals=False, allow_denorm=True):
+     """ Quantize values to a defined format. See _quantize_elemwise_core()
+     """
+     if elem_format is None:
+         return A
+
+     ebits, mbits, _, max_norm, _ = _get_format_params(elem_format)
+
+     output = _quantize_elemwise_core(
+         A, mbits, ebits, max_norm,
+         round=round, allow_denorm=allow_denorm,
+         saturate_normals=saturate_normals,
+         custom_cuda=custom_cuda)
+
+     return output
+
+
+ def _quantize_bfloat(A, bfloat, round='nearest', custom_cuda=False, allow_denorm=True):
+     """ Quantize values to bfloatX format
+     Arguments:
+         bfloat {int} -- Total number of bits for the bfloatX format:
+                         1 sign bit, 8 exponent bits, and a variable number
+                         of mantissa bits. Must be >= 9.
+     """
+     # Shortcut for no quantization
+     if bfloat == 0 or bfloat == 32:
+         return A
+
+     max_norm = _get_max_norm(8, bfloat - 7)
+
+     return _quantize_elemwise_core(
+         A, bits=bfloat - 7, exp_bits=8, max_norm=max_norm, round=round,
+         allow_denorm=allow_denorm, custom_cuda=custom_cuda)
+
+
+ def _quantize_fp(A, exp_bits=None, mantissa_bits=None,
+                  round='nearest', custom_cuda=False, allow_denorm=True):
+     """ Quantize values to IEEE fpX format. The format defines NaN/Inf
+     and subnorm numbers in the same way as FP32 and FP16.
+     Arguments:
+         exp_bits      {int} -- number of bits used to store the exponent
+         mantissa_bits {int} -- number of bits used to store the mantissa, not
+                                including the sign or implicit 1
+         round         {str} -- Rounding mode, (floor, nearest, even)
+     """
+     # Shortcut for no quantization
+     if exp_bits is None or mantissa_bits is None:
+         return A
+
+     max_norm = _get_max_norm(exp_bits, mantissa_bits + 2)
+
+     output = _quantize_elemwise_core(
+         A, bits=mantissa_bits + 2, exp_bits=exp_bits,
+         max_norm=max_norm, round=round, allow_denorm=allow_denorm,
+         custom_cuda=custom_cuda)
+
+     return output
+
+
+ def quantize_elemwise_op(A, mx_specs, round=None):
+     """A function used for element-wise quantization with mx_specs
+     Arguments:
+         A        {PyTorch tensor} -- a tensor that needs to be quantized
+         mx_specs {dictionary}     -- dictionary to specify mx_specs
+         round    {str}            -- Rounding mode, chosen from (floor, nearest, even)
+                                      (default: "nearest")
+     Returns:
+         quantized value {PyTorch tensor} -- a tensor that has been quantized
+     """
+     if mx_specs is None:
+         return A
+     elif round is None:
+         round = mx_specs['round']
+
+     if mx_specs['bfloat'] == 16 and round == 'even' \
+             and torch.cuda.is_bf16_supported() \
+             and mx_specs['bfloat_subnorms'] == True:
+         return A.to(torch.bfloat16)
+
+     if mx_specs['bfloat'] > 0 and mx_specs['fp'] > 0:
+         raise ValueError("Cannot set both [bfloat] and [fp] in mx_specs.")
+     elif mx_specs['bfloat'] > 9:
+         A = _quantize_bfloat(A, bfloat=mx_specs['bfloat'], round=round,
+                              custom_cuda=mx_specs['custom_cuda'],
+                              allow_denorm=mx_specs['bfloat_subnorms'])
+     elif mx_specs['bfloat'] > 0 and mx_specs['bfloat'] <= 9:
+         raise ValueError("Cannot set [bfloat] <= 9 in mx_specs.")
+     elif mx_specs['fp'] > 6:
+         A = _quantize_fp(A, exp_bits=5, mantissa_bits=mx_specs['fp'] - 6,
+                          round=round, custom_cuda=mx_specs['custom_cuda'],
+                          allow_denorm=mx_specs['bfloat_subnorms'])
+     elif mx_specs['fp'] > 0 and mx_specs['fp'] <= 6:
+         raise ValueError("Cannot set [fp] <= 6 in mx_specs.")
+     return A
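
For orientation, a minimal usage sketch of the exposed entry point above (illustrative, not part of the diff; the mx_specs dict is hypothetical, but its keys 'bfloat', 'fp', 'round', 'bfloat_subnorms', and 'custom_cuda' are exactly the ones quantize_elemwise_op reads, and the import path follows the RECORD at the bottom of this diff):

    import torch

    from tico.utils.mx.elemwise_ops import quantize_elemwise_op

    # Hypothetical mx_specs: round-to-nearest bfloat16 fake quantization.
    mx_specs = {
        "bfloat": 16,             # 1 sign + 8 exponent + 7 mantissa bits
        "fp": 0,                  # 0 disables the IEEE-fp path
        "round": "nearest",
        "bfloat_subnorms": True,  # passed through as allow_denorm
        "custom_cuda": False,
    }

    A = torch.randn(4, 8)
    # Values are snapped to the bfloat16 grid; the dtype stays float32 because
    # the direct A.to(torch.bfloat16) shortcut only fires when round == 'even'.
    A_q = quantize_elemwise_op(A, mx_specs)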
tico/utils/mx/formats.py ADDED
@@ -0,0 +1,125 @@
+ """
+ Copyright (c) Microsoft Corporation.
+ Licensed under the MIT License.
+ """
+
+ from enum import Enum, IntEnum
+
+ FP32_EXPONENT_BIAS = 127
+ FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1)
+
+ # Enum for rounding modes
+ class RoundingMode(IntEnum):
+     nearest = 0
+     floor = 1
+     even = 2
+
+     @staticmethod
+     def string_enums():
+         return [s.name for s in list(RoundingMode)]
+
+ # Enum for scalar data formats
+ class ElemFormat(Enum):
+     int8 = 1
+     int4 = 2
+     int2 = 3
+     fp8_e5m2 = 4
+     fp8_e4m3 = 5
+     fp6_e3m2 = 6
+     fp6_e2m3 = 7
+     fp4 = 8
+     fp4_e2m1 = 8
+     float16 = 9
+     fp16 = 9
+     bfloat16 = 10
+     bf16 = 10
+
+     @staticmethod
+     def from_str(s):
+         assert s is not None, "String elem_format == None"
+         s = s.lower()
+         if hasattr(ElemFormat, s):
+             return getattr(ElemFormat, s)
+         else:
+             raise Exception("Undefined elem format", s)
+
+
+ def _get_min_norm(ebits):
+     """ Valid for all float formats """
+     emin = 2 - (2 ** (ebits - 1))
+     return 0 if ebits == 0 else 2 ** emin
+
+
+ def _get_max_norm(ebits, mbits):
+     """ Valid only for floats that define NaN """
+     assert ebits >= 5, "invalid for floats that don't define NaN"
+     emax = 0 if ebits == 0 else 2**(ebits - 1) - 1
+     return 2**emax * float(2**(mbits - 1) - 1) / 2**(mbits - 2)
+
+
+ _FORMAT_CACHE = {}
+ def _get_format_params(fmt):
+     """ Allowed formats:
+         - intX:        2 <= X <= 32, assume sign-magnitude, 1.xxx representation
+         - floatX/fpX:  16 <= X <= 28, assume top exp is used for NaN/Inf
+         - bfloatX/bfX: 9 <= X <= 32
+         - fp4,           no NaN/Inf
+         - fp6_e3m2/e2m3, no NaN/Inf
+         - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior
+
+     Returns:
+         ebits:    exponent bits
+         mbits:    mantissa bits: includes sign and implicit bits
+         emax:     max normal exponent
+         max_norm: max normal number
+         min_norm: min normal number
+     """
+     if type(fmt) is str:
+         fmt = ElemFormat.from_str(fmt)
+
+     if fmt in _FORMAT_CACHE:
+         return _FORMAT_CACHE[fmt]
+
+     if fmt == ElemFormat.int8:
+         ebits, mbits = 0, 8
+         emax = 0
+     elif fmt == ElemFormat.int4:
+         ebits, mbits = 0, 4
+         emax = 0
+     elif fmt == ElemFormat.int2:
+         ebits, mbits = 0, 2
+         emax = 0
+     elif fmt == ElemFormat.fp8_e5m2:
+         ebits, mbits = 5, 4
+         emax = 2**(ebits - 1) - 1
+     elif fmt == ElemFormat.fp8_e4m3:
+         ebits, mbits = 4, 5
+         emax = 2**(ebits - 1)
+     elif fmt == ElemFormat.fp6_e3m2:
+         ebits, mbits = 3, 4
+         emax = 2**(ebits - 1)
+     elif fmt == ElemFormat.fp6_e2m3:
+         ebits, mbits = 2, 5
+         emax = 2**(ebits - 1)
+     elif fmt == ElemFormat.fp4:
+         ebits, mbits = 2, 3
+         emax = 2**(ebits - 1)
+     elif fmt == ElemFormat.float16:
+         ebits, mbits = 5, 12
+         emax = 2**(ebits - 1) - 1
+     elif fmt == ElemFormat.bfloat16:
+         ebits, mbits = 8, 9
+         emax = 2**(ebits - 1) - 1
+     else:
+         raise Exception("Unknown element format %s" % fmt)
+
+     if fmt != ElemFormat.fp8_e4m3:
+         max_norm = 2**emax * float(2**(mbits - 1) - 1) / 2**(mbits - 2)
+     else:
+         max_norm = 2**emax * 1.75  # FP8 e4m3 has a custom max_norm
+
+     min_norm = _get_min_norm(ebits)
+
+     _FORMAT_CACHE[fmt] = (ebits, mbits, emax, max_norm, min_norm)
+
+     return ebits, mbits, emax, max_norm, min_norm
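
A quick sketch of what _get_format_params returns, shown here for fp8_e4m3 (the values follow directly from the code above; the snippet itself is illustrative only):

    from tico.utils.mx.formats import _get_format_params

    # fp8_e4m3: ebits=4; mbits=5 counts the sign and implicit bits plus
    # the 3 stored mantissa bits.
    ebits, mbits, emax, max_norm, min_norm = _get_format_params("fp8_e4m3")
    assert (ebits, mbits, emax) == (4, 5, 8)  # emax = 2**(4 - 1)
    assert max_norm == 448.0                  # 2**8 * 1.75 (e4m3 special case)
    assert min_norm == 2.0 ** -6              # 2**(2 - 2**(4 - 1))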
tico/utils/mx/mx_ops.py ADDED
@@ -0,0 +1,270 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # This file was copied from https://github.com/microsoft/microxcaling/tree/v1.1.0
+ # and modified for our purpose.
+ """
+ Copyright (c) Microsoft Corporation.
+ Licensed under the MIT License.
+
+ Name: mx_ops.py
+
+ PyTorch methods for MX quantization.
+
+ Usage Notes:
+  - Use the "Exposed Methods" below to implement autograd functions
+  - Use autograd functions to then implement torch.nn.Module(s)
+  - Do *not* use methods in this file in Modules; they have no defined
+    backwards pass and will block gradient computation.
+  - Avoid importing internal functions if at all possible.
+
+ Exposed Methods:
+     quantize_mx - quantizes a tensor to MX format.
+
+ Internal Methods:
+     _shared_exponents       - returns the MX shared exponents for the passed tensor
+     _reshape_to_blocks      - tiles a tensor by splitting one dim into two
+     _undo_reshape_to_blocks - undoes the above reshaping
+     _quantize_mx            - quantizes a tensor to MX format
+ """
+
+ import torch
+
+ from .elemwise_ops import _quantize_elemwise_core
+
+ from .formats import (
+     _get_format_params,
+     FP32_EXPONENT_BIAS,
+     FP32_MIN_NORMAL,
+     RoundingMode,
+ )
+
+
+ # -------------------------------------------------------------------------
+ # Helper funcs
+ # -------------------------------------------------------------------------
+ def _shared_exponents(A, method="max", axes=None, ebits=0):
+     """
+     Get shared exponents for the passed matrix A.
+     Args:
+         A      {PyTorch tensor} -- Input tensor
+         method {str}            -- Exponent selection method.
+                                    "max" uses the max absolute value
+                                    "none" uses an exponent for each value (i.e., no sharing)
+         axes   {list(int)}      -- List of integers which specifies the axes across which
+                                    shared exponents are calculated.
+     Returns:
+         shared_exp {PyTorch tensor} -- Tensor of shared exponents
+     """
+
+     if method == "max":
+         if axes is None:
+             shared_exp = torch.max(torch.abs(A))
+         else:
+             shared_exp = A
+             for axis in axes:
+                 shared_exp, _ = torch.max(torch.abs(shared_exp), dim=axis, keepdim=True)
+     elif method == "none":
+         shared_exp = torch.abs(A)
+     else:
+         raise Exception("Unrecognized shared exponent selection method %s" % (method))
+
+     # log2(shared_exp) and truncate to integer
+     shared_exp = torch.floor(
+         torch.log2(
+             shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype)
+         )
+     )
+
+     # Restrict to [-emax, emax] range
+     if ebits > 0:
+         emax = 2 ** (ebits - 1) - 1
+         # shared_exp = torch.clamp(shared_exp, -emax, emax)
+         # Overflows are set to NaN
+         shared_exp[shared_exp > emax] = float("NaN")
+         # Underflows are set to -emax, which causes the corresponding
+         # values to be flushed to 0 later
+         shared_exp[shared_exp < -emax] = -emax
+
+     return shared_exp
+
+
+ def _reshape_to_blocks(A, axes, block_size):
+     if axes is None:
+         raise Exception(
+             "axes required in order to determine which "
+             "dimension to apply block size to"
+         )
+     if block_size == 0:
+         raise Exception("block_size == 0 in _reshape_to_blocks")
+
+     # Fix axes to be positive and sort them
+     axes = [(x + len(A.shape) if x < 0 else x) for x in axes]
+     assert all(x >= 0 for x in axes)
+     axes = sorted(axes)
+
+     # Add extra dimension for tiles
+     for i in range(len(axes)):
+         axes[i] += i  # Shift axes due to added dimensions
+         A = torch.unsqueeze(A, dim=axes[i] + 1)
+
+     # Pad to block_size
+     orig_shape = A.size()
+     pad = []
+     for i in range(len(orig_shape)):
+         pad += [0, 0]
+
+     do_padding = False
+     for axis in axes:
+         pre_pad_size = orig_shape[axis]
+         if isinstance(pre_pad_size, torch.Tensor):
+             pre_pad_size = int(pre_pad_size.value)
+         # Don't pad if the axis is short enough to fit inside one tile
+         if pre_pad_size % block_size == 0:
+             pad[2 * axis] = 0
+         else:
+             pad[2 * axis] = block_size - pre_pad_size % block_size
+             do_padding = True
+
+     if do_padding:
+         pad = list(reversed(pad))
+         A = torch.nn.functional.pad(A, pad, mode="constant")
+
+     def _reshape(shape, reshape_block_size):
+         for axis in axes:
+             # Reshape to tiles if axis length > reshape_block_size
+             if shape[axis] >= reshape_block_size:
+                 assert shape[axis] % reshape_block_size == 0
+                 shape[axis + 1] = reshape_block_size
+                 shape[axis] = shape[axis] // reshape_block_size
+             # Otherwise preserve length and insert a 1 into the shape
+             else:
+                 shape[axis + 1] = shape[axis]
+                 shape[axis] = 1
+         return shape
+
+     # Reshape to tiles
+     padded_shape = A.size()
+     reshape = _reshape(list(padded_shape), block_size)
+
+     A = A.view(reshape)
+     return A, axes, orig_shape, padded_shape
+
+
+ def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes):
+     # Undo tile reshaping
+     A = A.view(padded_shape)
+     # Undo padding
+     if not list(padded_shape) == list(orig_shape):
+         slices = [slice(0, x) for x in orig_shape]
+         A = A[slices]
+     for axis in reversed(axes):
+         # Remove extra dimension
+         A = torch.squeeze(A, dim=axis + 1)
+     return A
+
+
+ # -------------------------------------------------------------------------
+ # Main funcs
+ # -------------------------------------------------------------------------
+ def _quantize_mx(
+     A,
+     scale_bits,
+     elem_format,  # can be None for no quantization
+     shared_exp_method="max",
+     axes=None,
+     block_size=0,
+     round="nearest",
+     flush_fp32_subnorms=False,
+     custom_cuda=False,
+ ):
+     """Function used for MX* quantization"""
+     # Shortcut for no quantization
+     if elem_format is None:
+         return A
+
+     assert scale_bits > 0
+
+     # Make sure axes is a list of non-negative numbers
+     axes = [axes] if type(axes) == int else axes
+     axes = [x + A.ndim if x < 0 else x for x in axes]
+
+     # Custom CUDA only supports limited rounding modes
+     custom_cuda = custom_cuda and round in RoundingMode.string_enums()
+
+     ebits, mbits, emax, max_norm, _ = _get_format_params(elem_format)
+
+     # Perform tiling to the hardware vector size
+     if block_size > 0:
+         A, axes, orig_shape, padded_shape = _reshape_to_blocks(A, axes, block_size)
+
+     ####################
+     # Quantize
+     ####################
+     shared_exp_axes = [x + 1 for x in axes] if block_size > 0 else axes
+
+     # Get shared exponents
+     shared_exp = _shared_exponents(
+         A,
+         method=shared_exp_method,
+         axes=shared_exp_axes,
+         ebits=0,
+     )
+
+     # Flush subnormal FP32 inputs to zero
+     if flush_fp32_subnorms:
+         A = A * (shared_exp > -FP32_EXPONENT_BIAS).type(A.dtype)
+
+     # Offset the max exponent by the largest representable exponent
+     # in the element data format
+     shared_exp = shared_exp - emax
+
+     scale_emax = 2 ** (scale_bits - 1) - 1
+     shared_exp[shared_exp > scale_emax] = float("NaN")
+     shared_exp[shared_exp < -scale_emax] = -scale_emax
+
+     A = A / (2**shared_exp)
+
+     A = _quantize_elemwise_core(
+         A,
+         mbits,
+         ebits,
+         max_norm,
+         round=round,
+         allow_denorm=True,
+         saturate_normals=True,
+         custom_cuda=custom_cuda,
+     )
+
+     A = A * (2**shared_exp)
+
+     # Undo tile reshaping
+     if block_size:
+         A = _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes)
+
+     return A
+
+
+ # Wrapper function of circle_custom::quantize_mx
+ def quantize_mx(
+     input_: torch.Tensor,
+     elem_format: str,
+     axis: int,
+     shared_exp_method: str = "max",
+     round: str = "nearest",
+ ) -> torch.Tensor:
+     return torch.ops.circle_custom.quantize_mx(
+         input_, elem_format, axis, shared_exp_method=shared_exp_method, round=round
+     )
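
A sketch of driving _quantize_mx directly, using the MXINT8-style parameters that the wrapper in tico/utils/register_custom_op.py (next section) hard-codes for elem_format == "int8"; illustrative only:

    import torch

    from tico.utils.mx.mx_ops import _quantize_mx

    x = torch.randn(2, 64)
    # One shared 8-bit scale per block of 32 int8 elements along the last axis.
    x_q = _quantize_mx(
        x,
        scale_bits=8,
        elem_format="int8",
        axes=[-1],
        block_size=32,
        round="nearest",
    )
    assert x_q.shape == x.shape  # fake quantization preserves the shape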
tico/utils/register_custom_op.py CHANGED
@@ -18,6 +18,7 @@ import torch
  from torch._subclasses.fake_tensor import FakeTensor
  from torch.library import custom_op, register_fake

+ from tico.utils.mx.mx_ops import _quantize_mx

  # Note that an operator assumes input tensor has NHWC format.
  def CircleResizeNearestNeighbor():
@@ -550,6 +551,47 @@ def CircleInstanceNorm():
          return input.new_empty(input.size())


+ def CircleQuantizeMX():
+     # This operator conducts fake quantization with microscaling (MX) formats.
+     # NOTE Why "quantize"_mx rather than "fake_quantize"_mx?
+     #  To align with the function name in the microxcaling repo.
+     #  https://github.com/microsoft/microxcaling/blob/v1.1.0/mx/mx_ops.py#L173
+     @custom_op("circle_custom::quantize_mx", mutates_args=())
+     def quantize_mx(
+         input_: torch.Tensor,
+         elem_format: str,
+         axis: int,
+         shared_exp_method: str = "max",
+         round: str = "nearest",
+     ) -> torch.Tensor:
+         if elem_format == "int8":
+             scale_bits = 8
+             block_size = 32
+         else:
+             raise RuntimeError(f"Unsupported elem_format in quantize_mx: {elem_format}")
+
+         result = _quantize_mx(
+             input_,
+             scale_bits=scale_bits,
+             elem_format=elem_format,
+             axes=[axis],
+             block_size=block_size,
+             shared_exp_method=shared_exp_method,
+             round=round,
+         )
+         return result
+
+     @register_fake("circle_custom::quantize_mx")
+     def _(
+         input_: torch.Tensor,
+         elem_format: str,
+         axis: int,
+         shared_exp_method: str = "max",  # Fixed
+         round: str = "nearest",  # Fixed
+     ) -> torch.Tensor:
+         return input_
+
+
  # Add custom ops to the torch namespace
  def RegisterOps():
      CircleResizeNearestNeighbor()
@@ -560,3 +602,4 @@ def RegisterOps():
      CircleMaxPool2D()
      CircleAvgPool2D()
      CircleInstanceNorm()
+     CircleQuantizeMX()
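
An end-to-end sketch of the newly registered op (one assumption here: RegisterOps() may already be invoked when tico is imported, in which case the explicit call below is redundant):

    import torch

    from tico.utils.register_custom_op import RegisterOps

    RegisterOps()  # registers circle_custom::quantize_mx along with the other ops

    x = torch.randn(8, 32)
    # Only elem_format == "int8" is accepted by this wrapper; it implies
    # scale_bits=8 and block_size=32 internally.
    y = torch.ops.circle_custom.quantize_mx(x, "int8", -1)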
{tico-0.1.0.dev250428.dist-info → tico-0.1.0.dev250429.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: tico
- Version: 0.1.0.dev250428
+ Version: 0.1.0.dev250429
  Summary: Convert exported Torch module to circle
  Home-page: UNKNOWN
  License: UNKNOWN
{tico-0.1.0.dev250428.dist-info → tico-0.1.0.dev250429.dist-info}/RECORD RENAMED
@@ -1,4 +1,4 @@
- tico/__init__.py,sha256=kSSRG796hu-9-49yOofe9x7zKByX-BsUETRlHWZbDHc,1181
+ tico/__init__.py,sha256=VqDTIcftfqxHmepF7PI4HaijESf5C9i2qtOFcdJ9DU8,1181
  tico/pt2_to_circle.py,sha256=PPmFNw20jw2Z2VyM3ln9pX__jTzBOAZiv0gT5a-p-Y8,2666
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
  tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -184,13 +184,17 @@ tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
  tico/utils/model.py,sha256=Uqc92AnJXQ2pbvctS2z2F3Ku3yNrwXZ9O33hZVis7is,1250
  tico/utils/padding.py,sha256=GGO27VbaOvtaMYLDrSaKv7uxjeet566aMJD0PyYeMvQ,1484
  tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
- tico/utils/register_custom_op.py,sha256=FbMcrg8o5vKWC_aoVxL2GrIcR14KFi1yKG0mFGqXkPY,21595
+ tico/utils/register_custom_op.py,sha256=iRQvdqlBqrJxq_pNkvJyDIJD_SYtCUl88wwbbuvSwlk,22952
  tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
  tico/utils/utils.py,sha256=pybDU1LoNhjEplANig11lboX9yzYRkvFCSmyYth_2Do,10359
  tico/utils/validate_args_kwargs.py,sha256=krT68b5CfBI9rxBIOsgYSy0LfEJqLfKfRikkp8ep9oQ,24726
- tico-0.1.0.dev250428.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
- tico-0.1.0.dev250428.dist-info/METADATA,sha256=U5ulm5GVuL1RPg8sA0RVHl-TRPCWRsXM2uZDFPcvnA8,7353
- tico-0.1.0.dev250428.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
- tico-0.1.0.dev250428.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
- tico-0.1.0.dev250428.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
- tico-0.1.0.dev250428.dist-info/RECORD,,
+ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
+ tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
+ tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
+ tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
+ tico-0.1.0.dev250429.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+ tico-0.1.0.dev250429.dist-info/METADATA,sha256=9CsjedscOI7c1_UCpwOmht40v6n6h-qTsGowGQ5mawo,7353
+ tico-0.1.0.dev250429.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+ tico-0.1.0.dev250429.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+ tico-0.1.0.dev250429.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+ tico-0.1.0.dev250429.dist-info/RECORD,,