pychop 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pychop/__init__.py +87 -0
- pychop/bfp_formats.py +390 -0
- pychop/bitchop.py +79 -0
- pychop/blas.py +74 -0
- pychop/builtin/__init__.py +4 -0
- pychop/builtin/cparray.py +68 -0
- pychop/builtin/cparray_jax.py +202 -0
- pychop/builtin/cpfloat.py +87 -0
- pychop/builtin/cptensor.py +106 -0
- pychop/chop.py +151 -0
- pychop/demo_harmonic.py +56 -0
- pychop/faultchop.py +222 -0
- pychop/fixed_point.py +123 -0
- pychop/float_params.py +123 -0
- pychop/integer.py +170 -0
- pychop/jx/__init__.py +36 -0
- pychop/jx/bfp_formats.py +236 -0
- pychop/jx/bitchop.py +262 -0
- pychop/jx/blas_jx.py +1591 -0
- pychop/jx/fixed_point.py +726 -0
- pychop/jx/float_point.py +1160 -0
- pychop/jx/integer.py +134 -0
- pychop/jx/layers.py +4609 -0
- pychop/jx/lightchop.py +785 -0
- pychop/jx/mx_formats.py +262 -0
- pychop/jx/squeeze.py +202 -0
- pychop/layers.py +339 -0
- pychop/math_func.py +324 -0
- pychop/mx_formats.py +344 -0
- pychop/np/__init__.py +5 -0
- pychop/np/bfp_formats.py +205 -0
- pychop/np/bitchop.py +222 -0
- pychop/np/blas_np.py +1584 -0
- pychop/np/fixed_point.py +560 -0
- pychop/np/float_point.py +1180 -0
- pychop/np/integer.py +91 -0
- pychop/np/lightchop.py +624 -0
- pychop/np/mx_formats.py +361 -0
- pychop/np/roundit.py +1 -0
- pychop/np/squeeze.py +180 -0
- pychop/optimizers.py +453 -0
- pychop/set_backend.py +80 -0
- pychop/simulate.py +196 -0
- pychop/tch/__init__.py +22 -0
- pychop/tch/bfp_formats.py +278 -0
- pychop/tch/bitchop.py +240 -0
- pychop/tch/blas_th.py +1648 -0
- pychop/tch/fixed_point.py +541 -0
- pychop/tch/float_point.py +991 -0
- pychop/tch/integer.py +166 -0
- pychop/tch/layers.py +2366 -0
- pychop/tch/lightchop.py +816 -0
- pychop/tch/mx_formats.py +438 -0
- pychop/tch/squeeze.py +173 -0
- pychop/utils.py +415 -0
- pychop-0.5.2.dist-info/METADATA +425 -0
- pychop-0.5.2.dist-info/RECORD +60 -0
- pychop-0.5.2.dist-info/WHEEL +5 -0
- pychop-0.5.2.dist-info/licenses/LICENSE +21 -0
- pychop-0.5.2.dist-info/top_level.txt +1 -0
pychop/__init__.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pychop: Precision Simulation for Low-Precision Arithmetic
|
|
3
|
+
|
|
4
|
+
A comprehensive Python package for simulating low-precision arithmetic in
|
|
5
|
+
scientific computing and machine learning, with support for multiple backends
|
|
6
|
+
(NumPy, JAX, PyTorch).
|
|
7
|
+
|
|
8
|
+
Supported formats:
|
|
9
|
+
- Floating-point (Chop): IEEE 754 and custom formats
|
|
10
|
+
- Fixed-point (Chopf): Integer and fractional bits
|
|
11
|
+
- Integer quantization (Chopi): Symmetric and asymmetric
|
|
12
|
+
- Block Floating Point (BFP): Shared exponent per block
|
|
13
|
+
- Microscaling (MX): OCP standard with block-level scaling
|
|
14
|
+
|
|
15
|
+
Backends:
|
|
16
|
+
- NumPy: Pure numerical computation
|
|
17
|
+
- JAX: Custom VJP for differentiation
|
|
18
|
+
- PyTorch: Straight-Through Estimator (STE) for QAT
|
|
19
|
+
|
|
20
|
+
Author: Erin Carson, Xinye Chen
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from .chop import Chop
|
|
24
|
+
from .integer import Chopi
|
|
25
|
+
from .fixed_point import Chopf
|
|
26
|
+
|
|
27
|
+
from .simulate import Simulate
|
|
28
|
+
|
|
29
|
+
from .float_params import float_params
|
|
30
|
+
from .bitchop import Bitchop
|
|
31
|
+
from .faultchop import FaultChop
|
|
32
|
+
|
|
33
|
+
from .layers import ChopSTE, ChopfSTE, ChopiSTE
|
|
34
|
+
from .math_func import *
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
__version__ = '0.5.2'

import os

# Keep any backend the user already selected through the environment;
# otherwise fall back to automatic detection.
os.environ.setdefault('chop_backend', 'auto')
|
|
42
|
+
|
|
43
|
+
from .set_backend import backend
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
from dataclasses import dataclass
|
|
47
|
+
from typing import Optional
|
|
48
|
+
|
|
49
|
+
from .bfp_formats import (
|
|
50
|
+
BFPSpec,
|
|
51
|
+
BFPTensor,
|
|
52
|
+
BFP_FORMATS,
|
|
53
|
+
create_bfp_spec,
|
|
54
|
+
bfp_quantize,
|
|
55
|
+
print_bfp_format_table,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
# MX Formats
|
|
59
|
+
from .mx_formats import (
|
|
60
|
+
MXSpec,
|
|
61
|
+
MXTensor,
|
|
62
|
+
MX_FORMATS,
|
|
63
|
+
create_mx_spec,
|
|
64
|
+
mx_quantize,
|
|
65
|
+
compare_mx_formats,
|
|
66
|
+
print_mx_format_table,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
@dataclass
class Customs:
    """
    User-supplied parameters describing a custom floating-point format.

    Every field is optional and defaults to ``None`` (meaning "not given").
    """

    # Maximum value of the exponent.
    emax: Optional[int] = None
    # Number of bits in the significand, including the hidden bit.
    t: Optional[int] = None
    # Number of exponent bits.
    exp_bits: Optional[int] = None
    # Number of significand bits, not including the hidden bit.
    sig_bits: Optional[int] = None
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass
class Options:
    """
    Bundle of parameters describing a rounding ("chop") configuration.

    All fields are required (there are no defaults). Comments marked
    "presumably" are inferred from field names and the sibling ``Customs``
    class and should be confirmed against the code that consumes ``Options``.
    """

    t: int  # presumably significand bits incl. hidden bit, as in Customs.t -- confirm
    emax: int  # presumably the maximum exponent, as in Customs.emax -- confirm
    prec: int
    subnormal: bool  # presumably whether subnormal numbers are kept -- confirm
    rmode: int  # rounding-mode selector (was mis-annotated as bool)
    flip: bool
    explim: bool
    p: float
|
|
87
|
+
|
pychop/bfp_formats.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Block Floating Point (BFP) Format - Backend Agnostic Entry Point
|
|
3
|
+
|
|
4
|
+
This module provides automatic backend detection and routing for BFP quantization.
|
|
5
|
+
Supports NumPy, JAX, and PyTorch backends with automatic detection.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
>>> import pychop
|
|
9
|
+
>>> pychop.backend('auto') # Auto-detect from input
|
|
10
|
+
>>>
|
|
11
|
+
>>> # NumPy
|
|
12
|
+
>>> import numpy as np
|
|
13
|
+
>>> X = np.random.randn(1024, 768)
|
|
14
|
+
>>> X_q = bfp_quantize(X, format='bfp8')
|
|
15
|
+
>>>
|
|
16
|
+
>>> # PyTorch (with STE for training)
|
|
17
|
+
>>> import torch
|
|
18
|
+
>>> X = torch.randn(128, 768, requires_grad=True)
|
|
19
|
+
>>> X_q = bfp_quantize(X, format='bfp8') # Automatic STE!
|
|
20
|
+
>>>
|
|
21
|
+
>>> # JAX
|
|
22
|
+
>>> import jax.numpy as jnp
|
|
23
|
+
>>> X = jnp.array(np.random.randn(512, 512))
|
|
24
|
+
>>> X_q = bfp_quantize(X, format='bfp8')
|
|
25
|
+
|
|
26
|
+
Author: Xinye Chen
|
|
27
|
+
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
import os
|
|
31
|
+
from typing import Union, Tuple, Optional, Any
|
|
32
|
+
from dataclasses import dataclass
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ============================================================================
|
|
36
|
+
# Backend Detection (inline to avoid import issues)
|
|
37
|
+
# ============================================================================
|
|
38
|
+
|
|
39
|
+
def _detect_array_type(x: Any) -> str:
|
|
40
|
+
"""
|
|
41
|
+
Detect backend from input array type.
|
|
42
|
+
|
|
43
|
+
Parameters
|
|
44
|
+
----------
|
|
45
|
+
x : Any
|
|
46
|
+
Input array or scalar
|
|
47
|
+
|
|
48
|
+
Returns
|
|
49
|
+
-------
|
|
50
|
+
str
|
|
51
|
+
Backend name: 'numpy', 'torch', or 'jax'
|
|
52
|
+
"""
|
|
53
|
+
module = type(x).__module__
|
|
54
|
+
|
|
55
|
+
if "torch" in module:
|
|
56
|
+
return "torch"
|
|
57
|
+
if "jax" in module:
|
|
58
|
+
return "jax"
|
|
59
|
+
return "numpy"
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _get_backend_env() -> str:
|
|
63
|
+
"""Get backend from environment variable."""
|
|
64
|
+
return os.environ.get('chop_backend', 'auto')
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# ============================================================================
|
|
68
|
+
# BFP Format Specification (Backend-Independent)
|
|
69
|
+
# ============================================================================
|
|
70
|
+
|
|
71
|
+
@dataclass
class BFPSpec:
    """
    Description of a Block Floating Point (BFP) format.

    The spec carries no backend-specific state, so one instance can be shared
    by the NumPy, JAX and PyTorch implementations.

    Attributes
    ----------
    name : str
        Format name.
    mantissa_bits : int
        Mantissa bits stored per element (sign included).
    block_size : int
        How many elements share one exponent.
    exponent_bits : int
        Width of the shared exponent.
    has_sign : bool
        Whether elements carry a sign bit.
    use_subnormals : bool
        Whether subnormal numbers are supported.
    """
    name: str
    mantissa_bits: int
    block_size: int
    exponent_bits: int = 8
    has_sign: bool = True
    use_subnormals: bool = False

    @property
    def total_bits_per_block(self) -> int:
        """Storage for one block: every element's mantissa plus the shared exponent."""
        element_bits = self.mantissa_bits * self.block_size
        return element_bits + self.exponent_bits

    @property
    def compression_vs_fp32(self) -> float:
        """How many times smaller a block is than the same elements in FP32."""
        return (32 * self.block_size) / self.total_bits_per_block

    @property
    def compression_vs_fp16(self) -> float:
        """How many times smaller a block is than the same elements in FP16."""
        return (16 * self.block_size) / self.total_bits_per_block

    def __repr__(self):
        head = f"BFPSpec(name='{self.name}', mantissa={self.mantissa_bits}b, "
        tail = f"block_size={self.block_size}, exponent={self.exponent_bits}b)"
        return head + tail
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# Predefined BFP formats (shared across all backends).
# Each row is (name, mantissa_bits, block_size, exponent_bits).
BFP_FORMATS = {
    fmt_name: BFPSpec(fmt_name, mantissa_bits=mant, block_size=block, exponent_bits=exp)
    for fmt_name, mant, block, exp in (
        ('bfp16', 16, 16, 8),
        ('bfp12', 12, 16, 8),
        ('bfp8', 8, 32, 8),
        ('bfp6', 6, 32, 8),
        ('bfp4', 4, 32, 8),
        ('bfp3', 3, 64, 8),
        ('bfp2', 2, 128, 8),
        ('flexpoint16', 16, 16, 5),
        ('flexpoint8', 8, 32, 5),
    )
}
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def create_bfp_spec(
    mantissa_bits: int,
    block_size: int,
    exponent_bits: int = 8,
    name: Optional[str] = None,
    has_sign: bool = True,
    use_subnormals: bool = False,
) -> BFPSpec:
    """
    Create custom BFP format specification.

    Parameters
    ----------
    mantissa_bits : int
        Number of mantissa bits (1-32). Must be at least 1.
    block_size : int
        Elements per block. Must be at least 1.
    exponent_bits : int
        Bits for shared exponent.
    name : str, optional
        Custom name; defaults to ``"custom_bfp{mantissa_bits}"``.
    has_sign : bool, default=True
        Whether elements have sign bits (forwarded to ``BFPSpec``).
    use_subnormals : bool, default=False
        Whether to support subnormal numbers (forwarded to ``BFPSpec``).

    Returns
    -------
    BFPSpec
        BFP format specification

    Raises
    ------
    ValueError
        If ``mantissa_bits`` or ``block_size`` is smaller than 1.
    """
    # Reject degenerate formats up front instead of producing a spec whose
    # compression ratios and downstream quantization are meaningless.
    if mantissa_bits < 1:
        raise ValueError(f"mantissa_bits must be >= 1, got {mantissa_bits}")
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")

    if name is None:
        name = f"custom_bfp{mantissa_bits}"

    return BFPSpec(
        name=name,
        mantissa_bits=mantissa_bits,
        block_size=block_size,
        exponent_bits=exponent_bits,
        has_sign=has_sign,
        use_subnormals=use_subnormals,
    )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# ============================================================================
|
|
173
|
+
# Backend Detection and Routing
|
|
174
|
+
# ============================================================================
|
|
175
|
+
|
|
176
|
+
def _resolve_backend(X: Any = None) -> str:
    """
    Decide which backend should handle a BFP operation.

    Parameters
    ----------
    X : Any, optional
        Input array; consulted only when the configured backend is ``'auto'``.

    Returns
    -------
    str
        ``'numpy'``, ``'jax'`` or ``'torch'``.

    Raises
    ------
    ValueError
        If the ``chop_backend`` setting is not a recognised backend or ``'auto'``.
    """
    requested = _get_backend_env()

    if requested != 'auto':
        if requested in {'numpy', 'jax', 'torch'}:
            return requested
        raise ValueError(
            f"Invalid backend: {requested}. "
            "Must be 'numpy', 'jax', 'torch', or 'auto'."
        )

    # 'auto': infer from the input; with no input, NumPy is the default.
    return 'numpy' if X is None else _detect_array_type(X)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _get_backend_module(backend: str):
|
|
209
|
+
"""
|
|
210
|
+
Get backend-specific BFP implementation.
|
|
211
|
+
|
|
212
|
+
Parameters
|
|
213
|
+
----------
|
|
214
|
+
backend : str
|
|
215
|
+
Backend name
|
|
216
|
+
|
|
217
|
+
Returns
|
|
218
|
+
-------
|
|
219
|
+
module
|
|
220
|
+
Backend-specific BFP module
|
|
221
|
+
"""
|
|
222
|
+
if backend == 'torch':
|
|
223
|
+
try:
|
|
224
|
+
from .tch import bfp_formats as backend_module
|
|
225
|
+
except ImportError:
|
|
226
|
+
raise ImportError(
|
|
227
|
+
"PyTorch backend not available. "
|
|
228
|
+
"Install with: pip install torch"
|
|
229
|
+
)
|
|
230
|
+
elif backend == 'jax':
|
|
231
|
+
try:
|
|
232
|
+
from .jx import bfp_formats as backend_module
|
|
233
|
+
except ImportError:
|
|
234
|
+
raise ImportError(
|
|
235
|
+
"JAX backend not available. "
|
|
236
|
+
"Install with: pip install jax jaxlib flax"
|
|
237
|
+
)
|
|
238
|
+
elif backend == 'numpy':
|
|
239
|
+
from .np import bfp_formats as backend_module
|
|
240
|
+
else:
|
|
241
|
+
raise ValueError(f"Unsupported backend: {backend}")
|
|
242
|
+
|
|
243
|
+
return backend_module
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
# ============================================================================
|
|
247
|
+
# User-Facing Functions
|
|
248
|
+
# ============================================================================
|
|
249
|
+
|
|
250
|
+
def bfp_quantize(
    data: Any,
    format: Union[str, BFPSpec, Tuple[int, int]] = 'bfp8',
    backend: Optional[str] = None
) -> Any:
    """
    Quantize ``data`` to a Block Floating Point format.

    When ``backend`` is omitted it is resolved from the ``chop_backend``
    setting (inspecting ``data`` when that setting is ``'auto'``); the work is
    then delegated to that backend's ``bfp_quantize``.

    Parameters
    ----------
    data : array-like
        Input data (numpy.ndarray, torch.Tensor, or jax.Array)
    format : str, BFPSpec, or tuple
        BFP format specification
    backend : str, optional
        Force specific backend ('numpy', 'jax', or 'torch').
        If None, auto-detects from input

    Returns
    -------
    array-like
        Quantized data, as returned by the backend implementation.

    Examples
    --------
    >>> # NumPy
    >>> import numpy as np
    >>> X = np.random.randn(1024, 768)
    >>> X_q = bfp_quantize(X, format='bfp8')
    >>>
    >>> # PyTorch
    >>> import torch
    >>> X = torch.randn(128, 768, requires_grad=True)
    >>> X_q = bfp_quantize(X, format='bfp8')
    >>> loss = X_q.sum()
    >>> loss.backward()
    >>>
    >>> # Custom format
    >>> X_q = bfp_quantize(X, format=(4, 32))  # 4-bit mantissa, 32 elem/block
    """
    chosen = _resolve_backend(data) if backend is None else backend
    implementation = _get_backend_module(chosen)
    return implementation.bfp_quantize(data, format=format)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class BFPTensor:
    """
    Backend-agnostic wrapper around a Block Floating Point tensor.

    The backend is taken from ``backend`` or, when that is None, resolved from
    the ``chop_backend`` setting and the input. All work is delegated to that
    backend's ``BFPTensor_`` implementation.

    Parameters
    ----------
    data : array-like
        Input tensor
    format : str, BFPSpec, or tuple
        BFP format
    backend : str, optional
        Force specific backend

    Examples
    --------
    >>> import numpy as np
    >>> X = np.random.randn(1024, 768)
    >>> bfp = BFPTensor(X, format='bfp8')
    >>> X_reconstructed = bfp.dequantize()
    >>> stats = bfp.statistics()
    """

    def __init__(
        self,
        data: Any,
        format: Union[str, BFPSpec, Tuple[int, int]] = 'bfp8',
        backend: Optional[str] = None
    ):
        self.backend = _resolve_backend(data) if backend is None else backend
        impl_module = _get_backend_module(self.backend)
        self._impl = impl_module.BFPTensor_(data, format=format)

    def dequantize(self) -> Any:
        """Return the backend implementation's dequantized data."""
        impl = self._impl
        return impl.dequantize()

    def statistics(self) -> dict:
        """Return the backend implementation's quantization statistics."""
        impl = self._impl
        return impl.statistics()

    def __repr__(self):
        return f"BFPTensor(backend={self.backend}, impl={self._impl})"
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def print_bfp_format_table(formats: Optional[dict] = None):
    """
    Print a table of BFP formats.

    Parameters
    ----------
    formats : dict, optional
        Mapping whose values are ``BFPSpec``-like objects (``name``,
        ``mantissa_bits``, ``block_size``, ``exponent_bits``,
        ``compression_vs_fp16``, ``total_bits_per_block``). Defaults to the
        predefined ``BFP_FORMATS``.
    """
    if formats is None:
        formats = BFP_FORMATS

    print("="*90)
    print("Predefined BFP Formats")
    print("="*90)

    header = (f"{'Name':<15} {'Mantissa':<10} {'Block Size':<12} "
              f"{'Exponent':<10} {'Compress FP16':<15} {'Total Bits':<12}")
    print(header)
    print("-"*90)

    for spec in formats.values():
        # Pad the whole "N.NNx" cell to the header's 15-char width so the
        # "Total Bits" column lines up with its heading.
        ratio = f"{spec.compression_vs_fp16:.2f}x"
        row = (f"{spec.name:<15} "
               f"{spec.mantissa_bits:<10} "
               f"{spec.block_size:<12} "
               f"{spec.exponent_bits:<10} "
               f"{ratio:<15} "
               f"{spec.total_bits_per_block}")
        print(row)

    print("="*90)
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
# Public API re-exported by ``from pychop.bfp_formats import *``.
__all__ = [
    'BFPSpec', 'BFPTensor', 'BFP_FORMATS',
    'create_bfp_spec', 'bfp_quantize', 'print_bfp_format_table',
]
|
pychop/bitchop.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
def Bitchop(exp_bits, sig_bits, rmode="nearest_even", subnormal=True, random_state=42, device="cpu", verbose=0):
    """
    Create a bit-level floating-point rounding object for the active backend.

    The backend comes from the ``chop_backend`` environment variable:
    ``'torch'`` and ``'jax'`` select those implementations; any other value
    (including ``'auto'`` or an unset variable) uses the NumPy implementation.

    Parameters
    ----------
    exp_bits : int
        Number of bits for the exponent in the target format. Determines the range
        of representable values (e.g., 5 bits gives a bias of 15, range -14 to 15).

    sig_bits : int
        Number of bits for the significand (mantissa) in the target format, excluding
        the implicit leading 1 for normalized numbers (e.g., 4 bits allows 0 to 15 plus implicit 1).

    rmode : int or str, default="nearest_even"
        Rounding mode to use when quantizing the significand. Options are:
        - 1 or "nearest_even": Round to nearest value, ties to even (IEEE 754 default).
        - 0 or "nearest_odd": Round to nearest value, ties to odd.
        - 2 or "plus_infinity": Round towards plus infinity (round up).
        - 3 or "minus_infinity": Round towards minus infinity (round down).
        - 4 or "toward_zero": Truncate toward zero (no rounding up).
        - 5 or "stochastic_prop": Stochastic rounding proportional to the fractional part.
        - 6 or "stochastic_equal": Stochastic rounding with 50% probability.

    subnormal : bool, default=True
        If True, subnormal numbers are supported when the exponent underflows
        (the significand is shifted). If False, subnormals are flushed to zero.

    random_state : int, default=42
        Random seed used by the stochastic rounding modes.

    device : str or torch.device, default="cpu"
        Device to perform computations on (e.g., "cpu", "cuda"). Passed to the
        torch and JAX backends; not passed to the NumPy backend.

    verbose : int or bool, default=0
        If truthy, print the unit roundoff.

    Returns
    -------
    object
        The backend-specific ``Bitchop`` instance, with an added attribute
        ``u`` holding the unit roundoff ``2**-(sig_bits + 1)``.
    """
    backend = os.environ.get('chop_backend', 'auto')

    # Alias the imported class so it does not shadow this factory's own name.
    if backend == 'torch':
        from .tch.bitchop import Bitchop as _Impl
        obj = _Impl(exp_bits=exp_bits, sig_bits=sig_bits, subnormal=subnormal, device=device,
                    random_state=random_state, rmode=rmode)

    elif backend == 'jax':
        from .jx.bitchop import Bitchop as _Impl
        obj = _Impl(exp_bits=exp_bits, sig_bits=sig_bits, subnormal=subnormal, device=device,
                    random_state=random_state, rmode=rmode)
    else:
        from .np.bitchop import Bitchop as _Impl
        obj = _Impl(exp_bits=exp_bits, sig_bits=sig_bits, subnormal=subnormal,
                    random_state=random_state, rmode=rmode)

    # Unit roundoff for a precision of t = sig_bits + 1 bits (hidden bit
    # included): u = 2**(-t) = 2**(-sig_bits) / 2. The previous formula
    # 2**sig_bits / 2 grew with precision instead of shrinking.
    obj.u = 2.0 ** (-sig_bits) / 2

    if verbose:
        print("The floating point format is with unit-roundoff of {:e}".format(
            obj.u)+" (≈2^"+str(int(np.log2(obj.u)))+").")

    return obj
|
pychop/blas.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import logging

import torch

import pychop
from pychop import LightChop
|
|
4
|
+
pychop.backend('torch')
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Rounding parameters for each named precision; every entry uses rmode 1.
precision_configs = {
    label: {'exp_bits': exp_bits, 'sig_bits': sig_bits, 'rmode': 1}
    for label, exp_bits, sig_bits in (
        ('q52', 5, 2),
        ('q43', 4, 3),
        ('bf16', 8, 7),
        ('half', 5, 10),
        ('tf32', 8, 10),
        ('fp32', 8, 23),
        ('fp64', 11, 52),
    )
}

# Precision fallback order: the configs above are listed lowest to highest.
precision_fallback = list(precision_configs)
|
|
18
|
+
|
|
19
|
+
def chop(x, precision_idx=0, device=None):
    """
    Round ``x`` to ``precision_fallback[precision_idx]``, escalating on overflow.

    If rounding produces NaN or Inf, the next (wider) precision in
    ``precision_fallback`` is tried recursively. At ``'fp64'``, or past the end
    of the list, the float64 tensor is returned unchanged.

    Parameters
    ----------
    x : torch.Tensor or array-like
        Values to round. Non-tensors are converted to a float64 tensor.
    precision_idx : int, default=0
        Index into ``precision_fallback`` to start from.
    device : str or torch.device, optional
        Device for converted inputs and for the rounded result. Defaults to
        ``x.device`` for tensor input and ``'cpu'`` otherwise. (Previously this
        referenced an undefined global ``device`` and raised NameError.)

    Returns
    -------
    torch.Tensor
        Rounded values as float64.
    """
    if device is None:
        device = x.device if torch.is_tensor(x) else 'cpu'
    if not torch.is_tensor(x):
        x = torch.tensor(x, dtype=torch.float64, device=device)
    if precision_idx >= len(precision_fallback):
        return x
    precision = precision_fallback[precision_idx]
    if precision == 'fp64':
        return x
    ch = LightChop(**precision_configs[precision])
    result = ch(x)
    if not torch.any(torch.isnan(result)) and not torch.any(torch.isinf(result)):
        return result.to(torch.float64).to(device)
    # 'fp64' is last in the list and returns above, so precision_idx + 1 is valid.
    logging.debug("Chop: Precision %s failed, escalating to %s",
                  precision, precision_fallback[precision_idx + 1])
    return chop(x, precision_idx + 1, device=device)
|
|
34
|
+
|
|
35
|
+
def rounding(x, precision):
    """Round ``x`` starting at the named ``precision`` from ``precision_fallback``."""
    start_idx = precision_fallback.index(precision)
    return chop(x, precision_idx=start_idx)
|
|
37
|
+
|
|
38
|
+
def mixed_precision_op(op, x, precision, y=None, device=None):
    """
    Apply ``op`` to operands rounded to ``precision`` and round the result too.

    Parameters
    ----------
    op : callable
        Unary ``op(x)`` when ``y`` is None, otherwise binary ``op(x, y)``.
    x, y : torch.Tensor or array-like
        Operands; both are rounded to ``precision`` before ``op`` is applied.
    precision : str
        A name from ``precision_fallback`` (e.g. ``'fp32'``).
    device : str or torch.device, optional
        Device for the returned tensor. Defaults to the device of the rounded
        ``x``. (Previously this referenced an undefined global ``device``.)

    Returns
    -------
    torch.Tensor
        The result, rounded to ``precision`` unless ``precision == 'fp64'``.
    """
    x = rounding(x, precision)
    if device is None:
        device = x.device
    if y is None:
        unrounded = op(x)
    else:
        y = rounding(y, precision)
        unrounded = op(x, y)
    if precision == 'fp64':
        return unrounded.to(device)
    result = chop(unrounded, precision_idx=precision_fallback.index(precision))
    return result.to(device)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def round_sparse_matrix(A, precision, device=None):
    """
    Round the stored entries of a sparse matrix to ``precision``.

    Parameters
    ----------
    A : sparse matrix
        Any object with ``tocoo()`` and ``shape``, e.g. a SciPy sparse matrix.
    precision : str
        A key of ``precision_configs``; ``'fp64'`` returns ``A`` untouched.
    device : str or torch.device, optional
        Device used for the rounding computation; defaults to ``'cpu'``.
        (Previously an undefined global ``device`` was referenced.)

    Returns
    -------
    scipy.sparse.csc_matrix or the original ``A``
        A CSC matrix with rounded values, or ``A`` itself when ``precision`` is
        ``'fp64'`` or rounding produced NaN/Inf.
    """
    if precision == 'fp64':
        return A
    # Imported here: SciPy is only needed by this helper.
    from scipy.sparse import csc_matrix

    if device is None:
        device = 'cpu'
    A_coo = A.tocoo()
    data = torch.tensor(A_coo.data, dtype=torch.float64, device=device)
    ch = LightChop(**precision_configs[precision])
    rounded_data = ch(data)
    if torch.any(torch.isnan(rounded_data)) or torch.any(torch.isinf(rounded_data)):
        logging.warning("Rounding sparse matrix to %s failed; using fp64", precision)
        return A
    return csc_matrix((rounded_data.cpu().numpy(), (A_coo.row, A_coo.col)), shape=A.shape)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
if __name__ == "__main__":
    # Demo: compare an exact float64 sum with the fp32-rounded one.
    import numpy as np

    A = np.random.randn(100, 100)
    B = np.random.randn(100, 100)
    C = A + B
    print("C:", C)

    def add(lhs, rhs):
        return lhs + rhs

    print("C (fp32):", mixed_precision_op(add, A, 'fp32', B))
|
|
pychop/builtin/cparray.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from pychop import Chop # Or: from pychop import Chop
|
|
3
|
+
|
|
4
|
+
class CPArray(np.ndarray):
    """
    A NumPy array subclass that keeps its values chopped to a target precision.

    - Inherits from np.ndarray for full compatibility.
    - Uses a chopper callable (e.g. a ``Chop`` instance) to round arrays.
    - Floating-point ufunc results and matrix products are chopped exactly once
      and returned as CPArray instances (0-d results come back as Python scalars).
    - In-place operators (``+=``, ``*=``, ``@=``, ...) and explicit ``out=``
      arguments write the chopped result into the given output array.
    - Results whose dtype is not floating/complex (e.g. boolean masks from
      comparisons) are returned unchopped as plain ``np.ndarray`` values.
    """

    def __new__(cls, input_array, chopper=None):
        if chopper is None:
            raise ValueError("Must provide a chopper (Chop or Chop instance)")
        # Chop a plain ndarray so the chopper never receives a CPArray
        # (which would re-enter __array_ufunc__).
        chopped_base = chopper(np.asarray(input_array))
        return cls._from_chopped(chopped_base, chopper)

    @classmethod
    def _from_chopped(cls, chopped, chopper):
        """View an already-chopped array as ``cls`` without rounding it again."""
        obj = np.asarray(chopped).view(cls)
        obj.chopper = chopper
        return obj

    @staticmethod
    def _plain(value):
        """Return a base-ndarray view of a CPArray; pass anything else through."""
        return value.view(np.ndarray) if isinstance(value, CPArray) else value

    @staticmethod
    def _is_inexact(value):
        """True when ``value`` has a floating-point or complex dtype."""
        return np.issubdtype(np.asarray(value).dtype, np.inexact)

    def _wrap_result(self, value):
        """Chop one ufunc output and package it as a CPArray or Python scalar."""
        if not self._is_inexact(value):
            return value
        chopped = self.chopper(value)
        if np.ndim(chopped) == 0:
            return np.asarray(chopped).item()  # Scalar fallback
        return CPArray._from_chopped(chopped, self.chopper)

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.chopper = getattr(obj, 'chopper', None)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """
        Run ufuncs (+, -, *, /, comparisons, reductions, ...) on plain arrays,
        then chop floating-point results.

        Raises
        ------
        ValueError
            If a CPArray input uses a different chopper than ``self``.
        """
        for inp in inputs:
            if isinstance(inp, CPArray) and inp.chopper != self.chopper:
                raise ValueError("All CPArray inputs must use the same chopper")

        # Strip CPArray from the outputs as well: leaving a CPArray in ``out``
        # would re-dispatch to this method forever (e.g. on ``a += 1``).
        outputs = kwargs.get('out')
        if outputs is not None:
            kwargs['out'] = tuple(self._plain(o) for o in outputs)

        # Only unwrap CPArray inputs; Python scalars stay as-is so NumPy's
        # scalar type-promotion rules still apply.
        plain_inputs = [self._plain(x) for x in inputs]
        result = getattr(ufunc, method)(*plain_inputs, **kwargs)

        if method == 'at':
            # ufunc.at modifies its first operand in place and returns None.
            target = inputs[0]
            if isinstance(target, CPArray):
                buf = target.view(np.ndarray)
                buf[...] = self.chopper(buf)
            return None

        if outputs is not None:
            # Chop each floating output in place and hand back the caller's
            # own output objects, as NumPy does for in-place operations.
            for target in outputs:
                if target is not None and self._is_inexact(target):
                    buf = np.asarray(self._plain(target))
                    buf[...] = self.chopper(buf)
            return outputs[0] if len(outputs) == 1 else outputs

        if isinstance(result, tuple):  # multi-output ufuncs such as np.divmod
            return tuple(self._wrap_result(r) for r in result)
        return self._wrap_result(result)

    # Matmul: multiply the plain arrays, then chop the product once.
    def __matmul__(self, other):
        product = np.matmul(self.view(np.ndarray), np.asarray(other))
        return CPArray(product, self.chopper)

    def __rmatmul__(self, other):
        product = np.matmul(np.asarray(other), self.view(np.ndarray))
        return CPArray(product, self.chopper)

    # Utility: View as regular array
    def to_regular(self):
        return np.asarray(self)

    def __str__(self):
        prec_info = f"exp_bits={self.chopper.exp_bits}, sig_bits={self.chopper.sig_bits}" if hasattr(self.chopper, 'exp_bits') else "custom"
        return f"CPArray({np.array2string(self)}, {prec_info})"

    def __repr__(self):
        return str(self)
|