tico-0.1.0.dev250427-py3-none-any.whl → tico-0.1.0.dev250429-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/register_custom_op.py +43 -0
- {tico-0.1.0.dev250427.dist-info → tico-0.1.0.dev250429.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250427.dist-info → tico-0.1.0.dev250429.dist-info}/RECORD +12 -8
- {tico-0.1.0.dev250427.dist-info → tico-0.1.0.dev250429.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250427.dist-info → tico-0.1.0.dev250429.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250427.dist-info → tico-0.1.0.dev250429.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250427.dist-info → tico-0.1.0.dev250429.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -22,7 +22,7 @@ from tico.config import CompileConfigV1, get_default_config
 from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
 
 # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
-__version__ = '0.1.0.
+__version__ = '0.1.0.dev250429'
 
 
 if Version(torch.__version__) < Version("2.5.0"):
tico/utils/mx/__init__.py
ADDED
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/utils/mx/elemwise_ops.py
ADDED
@@ -0,0 +1,267 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+
+Name: elemwise_ops.py
+
+Pytorch functions for elementwise (i.e. bfloat) quantization.
+
+Usage Notes:
+ - Use the "Exposed Methods" below to implement autograd functions
+ - Use autograd functions to then implement torch.nn.Module(s)
+ - Do *not* use methods in this file in Modules, they have no defined
+   backwards pass and will block gradient computation.
+ - Avoid importing internal function if at all possible.
+
+Exposed Methods:
+    quantize_elemwise_op - quantizes a tensor to bfloat or other
+                           custom float format
+"""
+import torch
+
+from .formats import RoundingMode, _get_format_params
+from .formats import _get_min_norm, _get_max_norm
+
+
+# -------------------------------------------------------------------------
+# Helper funcs
+# -------------------------------------------------------------------------
+# Never explicitly compute 2**(-exp) since subnorm numbers have
+# exponents smaller than -126
+def _safe_lshift(x, bits, exp):
+    if exp is None:
+        return x * (2**bits)
+    else:
+        return x / (2 ** exp) * (2**bits)
+
+
+def _safe_rshift(x, bits, exp):
+    if exp is None:
+        return x / (2**bits)
+    else:
+        return x / (2**bits) * (2 ** exp)
+
+
+def _round_mantissa(A, bits, round, clamp=False):
+    """
+    Rounds mantissa to nearest bits depending on the rounding method 'round'
+    Args:
+      A     {PyTorch tensor} -- Input tensor
+      round {str}            -- Rounding method
+                                "floor" rounds to the floor
+                                "nearest" rounds to ceil or floor, whichever is nearest
+    Returns:
+      A {PyTorch tensor} -- Tensor with mantissas rounded
+    """
+
+    if round == "dither":
+        rand_A = torch.rand_like(A, requires_grad=False)
+        A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A)
+    elif round == "floor":
+        A = torch.sign(A) * torch.floor(torch.abs(A))
+    elif round == "nearest":
+        A = torch.sign(A) * torch.floor(torch.abs(A) + 0.5)
+    elif round == "even":
+        absA = torch.abs(A)
+        # find 0.5, 2.5, 4.5 ...
+        maskA = ((absA - 0.5) % 2 == torch.zeros_like(A)).type(A.dtype)
+        A = torch.sign(A) * (torch.floor(absA + 0.5) - maskA)
+    else:
+        raise Exception("Unrecognized round method %s" % (round))
+
+    # Clip values that cannot be expressed by the specified number of bits
+    if clamp:
+        max_mantissa = 2 ** (bits - 1) - 1
+        A = torch.clamp(A, -max_mantissa, max_mantissa)
+    return A
+
+
+# -------------------------------------------------------------------------
+# Main funcs
+# -------------------------------------------------------------------------
+def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest',
+                            saturate_normals=False, allow_denorm=True,
+                            custom_cuda=False):
+    """ Core function used for element-wise quantization
+    Arguments:
+      A        {PyTorch tensor} -- A tensor to be quantized
+      bits     {int}   -- Number of mantissa bits. Includes
+                          sign bit and implicit one for floats
+      exp_bits {int}   -- Number of exponent bits, 0 for ints
+      max_norm {float} -- Largest representable normal number
+      round    {str}   -- Rounding mode: (floor, nearest, even)
+      saturate_normals {bool} -- If True, normal numbers (i.e., not NaN/Inf)
+                                 that exceed max norm are clamped.
+                                 Must be True for correct MX conversion.
+      allow_denorm {bool} -- If False, flush denorm numbers in the
+                             elem_format to zero.
+      custom_cuda  {str}  -- If True, use custom CUDA kernels
+    Returns:
+      quantized tensor {PyTorch tensor} -- A tensor that has been quantized
+    """
+    A_is_sparse = A.is_sparse
+    if A_is_sparse:
+        if A.layout != torch.sparse_coo:
+            raise NotImplementedError("Only COO layout sparse tensors are currently supported.")
+
+        sparse_A = A.coalesce()
+        A = sparse_A.values().clone()
+
+    # custom cuda only support floor and nearest rounding modes
+    custom_cuda = custom_cuda and round in RoundingMode.string_enums()
+
+    if custom_cuda:
+        A = A.contiguous()
+
+        from . import custom_extensions
+        if A.device.type == "cuda":
+            A = custom_extensions.funcs.quantize_elemwise_func_cuda(
+                A, bits, exp_bits, max_norm, RoundingMode[round],
+                saturate_normals, allow_denorm)
+        elif A.device.type == "cpu":
+            A = custom_extensions.funcs.quantize_elemwise_func_cpp(
+                A, bits, exp_bits, max_norm, RoundingMode[round],
+                saturate_normals, allow_denorm)
+        return A
+
+    # Flush values < min_norm to zero if denorms are not allowed
+    if not allow_denorm and exp_bits > 0:
+        min_norm = _get_min_norm(exp_bits)
+        out = (torch.abs(A) >= min_norm).type(A.dtype) * A
+    else:
+        out = A
+
+    if exp_bits != 0:
+        private_exp = torch.floor(torch.log2(
+            torch.abs(A) + (A == 0).type(A.dtype)))
+
+        # The minimum representable exponent for 8 exp bits is -126
+        min_exp = -(2**(exp_bits-1)) + 2
+        private_exp = private_exp.clip(min=min_exp)
+    else:
+        private_exp = None
+
+    # Scale up so appropriate number of bits are in the integer portion of the number
+    out = _safe_lshift(out, bits - 2, private_exp)
+
+    out = _round_mantissa(out, bits, round, clamp=False)
+
+    # Undo scaling
+    out = _safe_rshift(out, bits - 2, private_exp)
+
+    # Set values > max_norm to Inf if desired, else clamp them
+    if saturate_normals or exp_bits == 0:
+        out = torch.clamp(out, min=-max_norm, max=max_norm)
+    else:
+        out = torch.where((torch.abs(out) > max_norm),
+                          torch.sign(out) * float("Inf"), out)
+
+    # handle Inf/NaN
+    if not custom_cuda:
+        out[A == float("Inf")] = float("Inf")
+        out[A == -float("Inf")] = -float("Inf")
+        out[A == float("NaN")] = float("NaN")
+
+    if A_is_sparse:
+        output = torch.sparse_coo_tensor(sparse_A.indices(), output,
+            sparse_A.size(), dtype=sparse_A.dtype, device=sparse_A.device,
+            requires_grad=sparse_A.requires_grad)
+
+    return out
+
+
+def _quantize_elemwise(A, elem_format, round='nearest', custom_cuda=False,
+                       saturate_normals=False, allow_denorm=True):
+    """ Quantize values to a defined format. See _quantize_elemwise_core()
+    """
+    if elem_format == None:
+        return A
+
+    ebits, mbits, _, max_norm, _ = _get_format_params(elem_format)
+
+    output = _quantize_elemwise_core(
+            A, mbits, ebits, max_norm,
+            round=round, allow_denorm=allow_denorm,
+            saturate_normals=saturate_normals,
+            custom_cuda=custom_cuda)
+
+    return output
+
+
+def _quantize_bfloat(A, bfloat, round='nearest', custom_cuda=False, allow_denorm=True):
+    """ Quantize values to bfloatX format
+    Arguments:
+      bfloat {int} -- Total number of bits for bfloatX format,
+                      Includes 1 sign, 8 exp bits, and variable
+                      mantissa bits. Must be >= 9.
+    """
+    # Shortcut for no quantization
+    if bfloat == 0 or bfloat == 32:
+        return A
+
+    max_norm = _get_max_norm(8, bfloat-7)
+
+    return _quantize_elemwise_core(
+            A, bits=bfloat-7, exp_bits=8, max_norm=max_norm, round=round,
+            allow_denorm=allow_denorm, custom_cuda=custom_cuda)
+
+
+def _quantize_fp(A, exp_bits=None, mantissa_bits=None,
+                 round='nearest', custom_cuda=False, allow_denorm=True):
+    """ Quantize values to IEEE fpX format. The format defines NaN/Inf
+    and subnorm numbers in the same way as FP32 and FP16.
+    Arguments:
+      exp_bits      {int} -- number of bits used to store exponent
+      mantissa_bits {int} -- number of bits used to store mantissa, not
+                             including sign or implicit 1
+      round         {str} -- Rounding mode, (floor, nearest, even)
+    """
+    # Shortcut for no quantization
+    if exp_bits is None or mantissa_bits is None:
+        return A
+
+    max_norm = _get_max_norm(exp_bits, mantissa_bits+2)
+
+    output = _quantize_elemwise_core(
+            A, bits=mantissa_bits + 2, exp_bits=exp_bits,
+            max_norm=max_norm, round=round, allow_denorm=allow_denorm,
+            custom_cuda=custom_cuda)
+
+    return output
+
+
+def quantize_elemwise_op(A, mx_specs, round=None):
+    """A function used for element-wise quantization with mx_specs
+    Arguments:
+      A        {PyTorch tensor} -- a tensor that needs to be quantized
+      mx_specs {dictionary}     -- dictionary to specify mx_specs
+      round    {str}            -- Rounding mode, choose from (floor, nearest, even)
+                                   (default: "nearest")
+    Returns:
+      quantized value {PyTorch tensor} -- a tensor that has been quantized
+    """
+    if mx_specs is None:
+        return A
+    elif round is None:
+        round = mx_specs['round']
+
+    if mx_specs['bfloat'] == 16 and round == 'even'\
+            and torch.cuda.is_bf16_supported() \
+            and mx_specs['bfloat_subnorms'] == True:
+        return A.to(torch.bfloat16)
+
+    if mx_specs['bfloat'] > 0 and mx_specs['fp'] > 0:
+        raise ValueError("Cannot set both [bfloat] and [fp] in mx_specs.")
+    elif mx_specs['bfloat'] > 9:
+        A = _quantize_bfloat(A, bfloat=mx_specs['bfloat'], round=round,
+                             custom_cuda=mx_specs['custom_cuda'],
+                             allow_denorm=mx_specs['bfloat_subnorms'])
+    elif mx_specs['bfloat'] > 0 and mx_specs['bfloat'] <= 9:
+        raise ValueError("Cannot set [bfloat] <= 9 in mx_specs.")
+    elif mx_specs['fp'] > 6:
+        A = _quantize_fp(A, exp_bits=5, mantissa_bits=mx_specs['fp'] - 6,
+                         round=round, custom_cuda=mx_specs['custom_cuda'],
+                         allow_denorm=mx_specs['bfloat_subnorms'])
+    elif mx_specs['fp'] > 0 and mx_specs['fp'] <= 6:
+        raise ValueError("Cannot set [fp] <= 6 in mx_specs.")
+    return A
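For readers skimming the diff, a minimal sketch of how the exposed quantize_elemwise_op above can be driven. The mx_specs keys ('bfloat', 'fp', 'round', 'custom_cuda', 'bfloat_subnorms') mirror the lookups in the function body; the concrete values chosen here (bfloat16 with round-to-nearest-even) are illustrative assumptions, not defaults shipped by the package.

# Illustrative only, not part of the wheel: exercises quantize_elemwise_op
# from the added tico/utils/mx/elemwise_ops.py module.
import torch
from tico.utils.mx.elemwise_ops import quantize_elemwise_op

mx_specs = {
    "bfloat": 16,             # fake-quantize elements to bfloat16
    "fp": 0,                  # mutually exclusive with "bfloat"
    "round": "even",          # one of: floor, nearest, even
    "custom_cuda": False,     # stay on the pure-PyTorch path
    "bfloat_subnorms": True,  # keep denormals (allow_denorm=True)
}

x = torch.randn(4, 32)
x_q = quantize_elemwise_op(x, mx_specs)
# On a bf16-capable GPU this returns a bfloat16 tensor; otherwise a float32
# tensor whose values have been rounded to the bfloat16 grid.
print(x_q.dtype, x_q.shape)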
tico/utils/mx/formats.py
ADDED
@@ -0,0 +1,125 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+"""
+
+from enum import Enum, IntEnum
+
+FP32_EXPONENT_BIAS = 127
+FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1)
+
+# Enum for rounding modes
+class RoundingMode(IntEnum):
+    nearest = 0
+    floor = 1
+    even = 2
+
+    @staticmethod
+    def string_enums():
+        return [s.name for s in list(RoundingMode)]
+
+# Enum for scalar data formats
+class ElemFormat(Enum):
+    int8 = 1
+    int4 = 2
+    int2 = 3
+    fp8_e5m2 = 4
+    fp8_e4m3 = 5
+    fp6_e3m2 = 6
+    fp6_e2m3 = 7
+    fp4 = 8
+    fp4_e2m1 = 8
+    float16 = 9
+    fp16 = 9
+    bfloat16 = 10
+    bf16 = 10
+
+    @staticmethod
+    def from_str(s):
+        assert(s != None), "String elem_format == None"
+        s = s.lower()
+        if hasattr(ElemFormat, s):
+            return getattr(ElemFormat, s)
+        else:
+            raise Exception("Undefined elem format", s)
+
+
+def _get_min_norm(ebits):
+    """ Valid for all float formats """
+    emin = 2 - (2 ** (ebits - 1))
+    return 0 if ebits == 0 else 2 ** emin
+
+
+def _get_max_norm(ebits, mbits):
+    """ Valid only for floats that define NaN """
+    assert(ebits >= 5), "invalid for floats that don't define NaN"
+    emax = 0 if ebits==0 else 2**(ebits - 1) - 1
+    return 2**emax * float(2**(mbits-1) - 1) / 2**(mbits-2)
+
+
+_FORMAT_CACHE = {}
+def _get_format_params(fmt):
+    """ Allowed formats:
+        - intX:          2 <= X <= 32, assume sign-magnitude, 1.xxx representation
+        - floatX/fpX:    16 <= X <= 28, assume top exp is used for NaN/Inf
+        - bfloatX/bfX:   9 <= X <= 32
+        - fp4,           no NaN/Inf
+        - fp6_e3m2/e2m3, no NaN/Inf
+        - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior
+
+        Returns:
+          ebits: exponent bits
+          mbits: mantissa bits: includes sign and implicit bits
+          emax: max normal exponent
+          max_norm: max normal number
+          min_norm: min normal number
+    """
+    if type(fmt) is str:
+        fmt = ElemFormat.from_str(fmt)
+
+    if fmt in _FORMAT_CACHE:
+        return _FORMAT_CACHE[fmt]
+
+    if fmt == ElemFormat.int8:
+        ebits, mbits = 0, 8
+        emax = 0
+    elif fmt == ElemFormat.int4:
+        ebits, mbits = 0, 4
+        emax = 0
+    elif fmt == ElemFormat.int2:
+        ebits, mbits = 0, 2
+        emax = 0
+    elif fmt == ElemFormat.fp8_e5m2:
+        ebits, mbits = 5, 4
+        emax = 2**(ebits - 1) - 1
+    elif fmt == ElemFormat.fp8_e4m3:
+        ebits, mbits = 4, 5
+        emax = 2**(ebits - 1)
+    elif fmt == ElemFormat.fp6_e3m2:
+        ebits, mbits = 3, 4
+        emax = 2**(ebits - 1)
+    elif fmt == ElemFormat.fp6_e2m3:
+        ebits, mbits = 2, 5
+        emax = 2**(ebits - 1)
+    elif fmt == ElemFormat.fp4:
+        ebits, mbits = 2, 3
+        emax = 2**(ebits - 1)
+    elif fmt == ElemFormat.float16:
+        ebits, mbits = 5, 12
+        emax = 2**(ebits - 1) - 1
+    elif fmt == ElemFormat.bfloat16:
+        ebits, mbits = 8, 9
+        emax = 2**(ebits - 1) - 1
+    else:
+        raise Exception("Unknown element format %s" % fmt)
+
+    if fmt != ElemFormat.fp8_e4m3:
+        max_norm = 2**emax * float(2**(mbits-1) - 1) / 2**(mbits-2)
+    else:
+        max_norm = 2**emax * 1.75  # FP8 has custom max_norm
+
+    min_norm = _get_min_norm(ebits)
+
+    _FORMAT_CACHE[fmt] = (ebits, mbits, emax, max_norm, min_norm)
+
+    return ebits, mbits, emax, max_norm, min_norm
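As a quick sanity check of the format table above, the values below follow directly from the branches and the max/min-norm formulas in this file; they are worked out by hand for illustration and are not code shipped in the package.

# Worked examples against the branches above (illustrative only):
from tico.utils.mx.formats import _get_format_params

# fp8_e4m3: ebits=4, mbits=5, emax=2**(4-1)=8,
# max_norm is special-cased: 2**8 * 1.75 = 448.0, min_norm = 2**(2-8) = 0.015625
print(_get_format_params("fp8_e4m3"))   # (4, 5, 8, 448.0, 0.015625)

# int8: ebits=0, mbits=8, emax=0,
# max_norm = 2**0 * (2**7 - 1) / 2**6 = 1.984375, min_norm = 0
print(_get_format_params("int8"))       # (0, 8, 0, 1.984375, 0)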
tico/utils/mx/mx_ops.py
ADDED
@@ -0,0 +1,270 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This file was copied from https://github.com/microsoft/microxcaling/tree/v1.1.0
+# and modified for our purpose.
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+
+Name: mx_ops.py
+
+Pytorch methods for MX quantization.
+
+Usage Notes:
+ - Use the "Exposed Methods" below to implement autograd functions
+ - Use autograd functions to then implement torch.nn.Module(s)
+ - Do *not* use methods in this file in Modules, they have no defined
+   backwards pass and will block gradient computation.
+ - Avoid importing internal function if at all possible.
+
+Exposed Methods:
+    quantize_mx_op - quantizes a tensor to MX format.
+
+Internal Methods:
+    _safe_lshift, _safe_rshift - fp16 compatible shifts
+    _shared_exponents - Returns MX shared exponent for the passed tensor
+    _reshape_to_blocks - tiles a tensor by splitting one dim into two
+    _undo_reshape_to_blocks - undos the above reshaping
+    _quantize_mx - quantizes a tensor to MX format
+"""
+
+import torch
+
+from .elemwise_ops import _quantize_elemwise_core
+
+from .formats import (
+    _get_format_params,
+    FP32_EXPONENT_BIAS,
+    FP32_MIN_NORMAL,
+    RoundingMode,
+)
+
+
+# -------------------------------------------------------------------------
+# Helper funcs
+# -------------------------------------------------------------------------
+def _shared_exponents(A, method="max", axes=None, ebits=0):
+    """
+    Get shared exponents for the passed matrix A.
+    Args:
+      A      {PyTorch tensor} -- Input tensor
+      method {str}            -- Exponent selection method.
+                                 "max" uses the max absolute value
+                                 "none" uses an exponent for each value (i.e., no sharing)
+      axes   {list(int)}      -- List of integers which specifies the axes across which
+                                 shared exponents are calculated.
+    Returns:
+      shared_exp {PyTorch tensor} -- Tensor of shared exponents
+    """
+
+    if method == "max":
+        if axes is None:
+            shared_exp = torch.max(torch.abs(A))
+        else:
+            shared_exp = A
+            for axis in axes:
+                shared_exp, _ = torch.max(torch.abs(shared_exp), dim=axis, keepdim=True)
+    elif method == "none":
+        shared_exp = torch.abs(A)
+    else:
+        raise Exception("Unrecognized shared exponent selection method %s" % (method))
+
+    # log2(shared_exp) and truncate to integer
+    shared_exp = torch.floor(
+        torch.log2(
+            shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype)
+        )
+    )
+
+    # Restrict to [-emax, emax] range
+    if ebits > 0:
+        emax = 2 ** (ebits - 1) - 1
+        # shared_exp = torch.clamp(shared_exp, -emax, emax)
+        # Overflow to Inf
+        shared_exp[shared_exp > emax] = float("NaN")
+        # Underflows are set to -127 which causes them to be
+        # flushed to 0 later
+        shared_exp[shared_exp < -emax] = -emax
+
+    return shared_exp
+
+
+def _reshape_to_blocks(A, axes, block_size):
+    if axes is None:
+        raise Exception(
+            "axes required in order to determine which "
+            "dimension toapply block size to"
+        )
+    if block_size == 0:
+        raise Exception("block_size == 0 in _reshape_to_blocks")
+
+    # Fix axes to be positive and sort them
+    axes = [(x + len(A.shape) if x < 0 else x) for x in axes]
+    assert all(x >= 0 for x in axes)
+    axes = sorted(axes)
+
+    # Add extra dimension for tiles
+    for i in range(len(axes)):
+        axes[i] += i  # Shift axes due to added dimensions
+        A = torch.unsqueeze(A, dim=axes[i] + 1)
+
+    # Pad to block_size
+    orig_shape = A.size()
+    pad = []
+    for i in range(len(orig_shape)):
+        pad += [0, 0]
+
+    do_padding = False
+    for axis in axes:
+        pre_pad_size = orig_shape[axis]
+        if isinstance(pre_pad_size, torch.Tensor):
+            pre_pad_size = int(pre_pad_size.value)
+        # Don't pad if the axis is short enough to fit inside one tile
+        if pre_pad_size % block_size == 0:
+            pad[2 * axis] = 0
+        else:
+            pad[2 * axis] = block_size - pre_pad_size % block_size
+            do_padding = True
+
+    if do_padding:
+        pad = list(reversed(pad))
+        A = torch.nn.functional.pad(A, pad, mode="constant")
+
+    def _reshape(shape, reshape_block_size):
+        for axis in axes:
+            # Reshape to tiles if axis length > reshape_block_size
+            if shape[axis] >= reshape_block_size:
+                assert shape[axis] % reshape_block_size == 0
+                shape[axis + 1] = reshape_block_size
+                shape[axis] = shape[axis] // reshape_block_size
+            # Otherwise preserve length and insert a 1 into the shape
+            else:
+                shape[axis + 1] = shape[axis]
+                shape[axis] = 1
+        return shape
+
+    # Reshape to tiles
+    padded_shape = A.size()
+    reshape = _reshape(list(padded_shape), block_size)
+
+    A = A.view(reshape)
+    return A, axes, orig_shape, padded_shape
+
+
+def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes):
+    # Undo tile reshaping
+    A = A.view(padded_shape)
+    # Undo padding
+    if not list(padded_shape) == list(orig_shape):
+        slices = [slice(0, x) for x in orig_shape]
+        A = A[slices]
+    for axis in reversed(axes):
+        # Remove extra dimension
+        A = torch.squeeze(A, dim=axis + 1)
+    return A
+
+
+# -------------------------------------------------------------------------
+# Main funcs
+# -------------------------------------------------------------------------
+def _quantize_mx(
+    A,
+    scale_bits,
+    elem_format,  # can be None for no quantization
+    shared_exp_method="max",
+    axes=None,
+    block_size=0,
+    round="nearest",
+    flush_fp32_subnorms=False,
+    custom_cuda=False,
+):
+    """Function used for MX* quantization"""
+    # Shortcut for no quantization
+    if elem_format == None:
+        return A
+
+    assert scale_bits > 0
+
+    # Make sure axes is a list of non-negative numbers
+    axes = [axes] if type(axes) == int else axes
+    axes = [x + A.ndim if x < 0 else x for x in axes]
+
+    # Custom CUDA only supports limited rounding modes
+    custom_cuda = custom_cuda and round in RoundingMode.string_enums()
+
+    ebits, mbits, emax, max_norm, _ = _get_format_params(elem_format)
+
+    # Perform tiling to the hardware vector size
+    if block_size > 0:
+        A, axes, orig_shape, padded_shape = _reshape_to_blocks(A, axes, block_size)
+
+    ####################
+    # Quantize
+    ####################
+    shared_exp_axes = [x + 1 for x in axes] if block_size > 0 else axes
+
+    # Get shared exponents
+    shared_exp = _shared_exponents(
+        A,
+        method=shared_exp_method,
+        axes=shared_exp_axes,
+        ebits=0,
+    )
+
+    # Flush subnormal FP32 inputs to zero
+    if flush_fp32_subnorms:
+        A = A * (shared_exp > -FP32_EXPONENT_BIAS).type(A.dtype)
+
+    # Offset the max exponent by the largest representable exponent
+    # in the element data format
+    shared_exp = shared_exp - emax
+
+    scale_emax = 2 ** (scale_bits - 1) - 1
+    shared_exp[shared_exp > scale_emax] = float("NaN")
+    shared_exp[shared_exp < -scale_emax] = -scale_emax
+
+    A = A / (2**shared_exp)
+
+    A = _quantize_elemwise_core(
+        A,
+        mbits,
+        ebits,
+        max_norm,
+        round=round,
+        allow_denorm=True,
+        saturate_normals=True,
+        custom_cuda=custom_cuda,
+    )
+
+    A = A * (2**shared_exp)
+
+    # Undo tile reshaping
+    if block_size:
+        A = _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes)
+
+    return A
+
+
+# Wrapper function of circle_custom::quantize_mx
+def quantize_mx(
+    input_: torch.Tensor,
+    elem_format: str,
+    axis: int,
+    shared_exp_method: str = "max",
+    round: str = "nearest",
+) -> torch.Tensor:
+    return torch.ops.circle_custom.quantize_mx(
+        input_, elem_format, axis, shared_exp_method=shared_exp_method, round=round
+    )
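A small, self-contained sketch of how _quantize_mx above is typically invoked, mirroring the parameters the new custom op in the next file passes in (scale_bits=8, block_size=32, elem_format="int8"). This is illustrative usage, not code shipped in the wheel.

# Illustrative call of _quantize_mx with the same parameters used by
# circle_custom::quantize_mx below; not part of the package.
import torch
from tico.utils.mx.mx_ops import _quantize_mx

x = torch.randn(4, 64)            # last dim is a multiple of the 32-element block
x_mx = _quantize_mx(
    x,
    scale_bits=8,                 # bits for the per-block shared exponent
    elem_format="int8",           # element format from formats.ElemFormat
    axes=[-1],                    # share one exponent per 32 values along dim -1
    block_size=32,
    round="nearest",
)
assert x_mx.shape == x.shape      # fake quantization: same shape, MX-rounded values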
tico/utils/register_custom_op.py
CHANGED
@@ -18,6 +18,7 @@ import torch
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.library import custom_op, register_fake
 
+from tico.utils.mx.mx_ops import _quantize_mx
 
 # Note that an operator assumes input tensor has NHWC format.
 def CircleResizeNearestNeighbor():
@@ -550,6 +551,47 @@ def CircleInstanceNorm():
         return input.new_empty(input.size())
 
 
+def CircleQuantizeMX():
+    # This operator conducts fake-quantization of microscaling
+    # NOTE Why using "quantize"_mx not "fake_quantize"_mx?
+    # To align with function name of microxcaling repo.
+    # https://github.com/microsoft/microxcaling/blob/v1.1.0/mx/mx_ops.py#L173
+    @custom_op("circle_custom::quantize_mx", mutates_args=())
+    def quantize_mx(
+        input_: torch.Tensor,
+        elem_format: str,
+        axis: int,
+        shared_exp_method: str = "max",
+        round: str = "nearest",
+    ) -> torch.Tensor:
+        if elem_format == "int8":
+            scale_bits = 8
+            block_size = 32
+        else:
+            raise RuntimeError(f"Unsupported elem_format in quantize_mx: {elem_format}")
+
+        result = _quantize_mx(
+            input_,
+            scale_bits=scale_bits,
+            elem_format=elem_format,
+            axes=[axis],
+            block_size=block_size,
+            shared_exp_method=shared_exp_method,
+            round=round,
+        )
+        return result
+
+    @register_fake("circle_custom::quantize_mx")
+    def _(
+        input_: torch.Tensor,
+        elem_format: str,
+        axis: int,
+        shared_exp_method: str = "max",  # Fixed
+        round: str = "nearest",  # Fixed
+    ) -> torch.Tensor:
+        return input_
+
+
 # Add custom ops to the torch namespace
 def RegisterOps():
     CircleResizeNearestNeighbor()
@@ -560,3 +602,4 @@ def RegisterOps():
     CircleMaxPool2D()
     CircleAvgPool2D()
     CircleInstanceNorm()
+    CircleQuantizeMX()
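For context, once RegisterOps() has been called the new operator is reachable through the circle_custom namespace. A minimal sketch, assuming RegisterOps() is the intended entry point as shown in this diff; only "int8" is accepted by the op for now.

# Minimal sketch of using the newly registered op; not part of the package.
import torch
from tico.utils.register_custom_op import RegisterOps

RegisterOps()  # registers circle_custom::quantize_mx along with the other Circle ops

x = torch.randn(1, 8, 32)
y = torch.ops.circle_custom.quantize_mx(x, "int8", -1)  # (input, elem_format, axis)
assert y.shape == x.shape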
{tico-0.1.0.dev250427.dist-info → tico-0.1.0.dev250429.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-tico/__init__.py,sha256=
+tico/__init__.py,sha256=VqDTIcftfqxHmepF7PI4HaijESf5C9i2qtOFcdJ9DU8,1181
 tico/pt2_to_circle.py,sha256=PPmFNw20jw2Z2VyM3ln9pX__jTzBOAZiv0gT5a-p-Y8,2666
 tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
 tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -184,13 +184,17 @@ tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
 tico/utils/model.py,sha256=Uqc92AnJXQ2pbvctS2z2F3Ku3yNrwXZ9O33hZVis7is,1250
 tico/utils/padding.py,sha256=GGO27VbaOvtaMYLDrSaKv7uxjeet566aMJD0PyYeMvQ,1484
 tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
-tico/utils/register_custom_op.py,sha256=
+tico/utils/register_custom_op.py,sha256=iRQvdqlBqrJxq_pNkvJyDIJD_SYtCUl88wwbbuvSwlk,22952
 tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
 tico/utils/utils.py,sha256=pybDU1LoNhjEplANig11lboX9yzYRkvFCSmyYth_2Do,10359
 tico/utils/validate_args_kwargs.py,sha256=krT68b5CfBI9rxBIOsgYSy0LfEJqLfKfRikkp8ep9oQ,24726
-tico
-tico
-tico
-tico
-tico-0.1.0.
-tico-0.1.0.
+tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
+tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
+tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
+tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
+tico-0.1.0.dev250429.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+tico-0.1.0.dev250429.dist-info/METADATA,sha256=9CsjedscOI7c1_UCpwOmht40v6n6h-qTsGowGQ5mawo,7353
+tico-0.1.0.dev250429.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+tico-0.1.0.dev250429.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+tico-0.1.0.dev250429.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+tico-0.1.0.dev250429.dist-info/RECORD,,
File without changes
File without changes
File without changes
File without changes