pychop 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. pychop/__init__.py +87 -0
  2. pychop/bfp_formats.py +390 -0
  3. pychop/bitchop.py +79 -0
  4. pychop/blas.py +74 -0
  5. pychop/builtin/__init__.py +4 -0
  6. pychop/builtin/cparray.py +68 -0
  7. pychop/builtin/cparray_jax.py +202 -0
  8. pychop/builtin/cpfloat.py +87 -0
  9. pychop/builtin/cptensor.py +106 -0
  10. pychop/chop.py +151 -0
  11. pychop/demo_harmonic.py +56 -0
  12. pychop/faultchop.py +222 -0
  13. pychop/fixed_point.py +123 -0
  14. pychop/float_params.py +123 -0
  15. pychop/integer.py +170 -0
  16. pychop/jx/__init__.py +36 -0
  17. pychop/jx/bfp_formats.py +236 -0
  18. pychop/jx/bitchop.py +262 -0
  19. pychop/jx/blas_jx.py +1591 -0
  20. pychop/jx/fixed_point.py +726 -0
  21. pychop/jx/float_point.py +1160 -0
  22. pychop/jx/integer.py +134 -0
  23. pychop/jx/layers.py +4609 -0
  24. pychop/jx/lightchop.py +785 -0
  25. pychop/jx/mx_formats.py +262 -0
  26. pychop/jx/squeeze.py +202 -0
  27. pychop/layers.py +339 -0
  28. pychop/math_func.py +324 -0
  29. pychop/mx_formats.py +344 -0
  30. pychop/np/__init__.py +5 -0
  31. pychop/np/bfp_formats.py +205 -0
  32. pychop/np/bitchop.py +222 -0
  33. pychop/np/blas_np.py +1584 -0
  34. pychop/np/fixed_point.py +560 -0
  35. pychop/np/float_point.py +1180 -0
  36. pychop/np/integer.py +91 -0
  37. pychop/np/lightchop.py +624 -0
  38. pychop/np/mx_formats.py +361 -0
  39. pychop/np/roundit.py +1 -0
  40. pychop/np/squeeze.py +180 -0
  41. pychop/optimizers.py +453 -0
  42. pychop/set_backend.py +80 -0
  43. pychop/simulate.py +196 -0
  44. pychop/tch/__init__.py +22 -0
  45. pychop/tch/bfp_formats.py +278 -0
  46. pychop/tch/bitchop.py +240 -0
  47. pychop/tch/blas_th.py +1648 -0
  48. pychop/tch/fixed_point.py +541 -0
  49. pychop/tch/float_point.py +991 -0
  50. pychop/tch/integer.py +166 -0
  51. pychop/tch/layers.py +2366 -0
  52. pychop/tch/lightchop.py +816 -0
  53. pychop/tch/mx_formats.py +438 -0
  54. pychop/tch/squeeze.py +173 -0
  55. pychop/utils.py +415 -0
  56. pychop-0.5.2.dist-info/METADATA +425 -0
  57. pychop-0.5.2.dist-info/RECORD +60 -0
  58. pychop-0.5.2.dist-info/WHEEL +5 -0
  59. pychop-0.5.2.dist-info/licenses/LICENSE +21 -0
  60. pychop-0.5.2.dist-info/top_level.txt +1 -0
pychop/__init__.py ADDED
@@ -0,0 +1,87 @@
1
+ """
2
+ Pychop: Precision Simulation for Low-Precision Arithmetic
3
+
4
+ A comprehensive Python package for simulating low-precision arithmetic in
5
+ scientific computing and machine learning, with support for multiple backends
6
+ (NumPy, JAX, PyTorch).
7
+
8
+ Supported formats:
9
+ - Floating-point (Chop): IEEE 754 and custom formats
10
+ - Fixed-point (Chopf): Integer and fractional bits
11
+ - Integer quantization (Chopi): Symmetric and asymmetric
12
+ - Block Floating Point (BFP): Shared exponent per block
13
+ - Microscaling (MX): OCP standard with block-level scaling
14
+
15
+ Backends:
16
+ - NumPy: Pure numerical computation
17
+ - JAX: Custom VJP for differentiation
18
+ - PyTorch: Straight-Through Estimator (STE) for QAT
19
+
20
+ Author: Erin Carson, Xinye Chen
21
+ """
22
+
23
+ from .chop import Chop
24
+ from .integer import Chopi
25
+ from .fixed_point import Chopf
26
+
27
+ from .simulate import Simulate
28
+
29
+ from .float_params import float_params
30
+ from .bitchop import Bitchop
31
+ from .faultchop import FaultChop
32
+
33
+ from .layers import ChopSTE, ChopfSTE, ChopiSTE
34
+ from .math_func import *
35
+
36
+
37
+ __version__ = '0.5.2'
38
+
39
+ import os
40
# Leave a backend chosen through the environment untouched; otherwise fall
# back to automatic detection.
os.environ.setdefault('chop_backend', 'auto')
42
+
43
+ from .set_backend import backend
44
+
45
+
46
+ from dataclasses import dataclass
47
+ from typing import Optional
48
+
49
+ from .bfp_formats import (
50
+ BFPSpec,
51
+ BFPTensor,
52
+ BFP_FORMATS,
53
+ create_bfp_spec,
54
+ bfp_quantize,
55
+ print_bfp_format_table,
56
+ )
57
+
58
+ # MX Formats
59
+ from .mx_formats import (
60
+ MXSpec,
61
+ MXTensor,
62
+ MX_FORMATS,
63
+ create_mx_spec,
64
+ mx_quantize,
65
+ compare_mx_formats,
66
+ print_mx_format_table,
67
+ )
68
+
69
@dataclass
class Customs:
    """Parameters describing a user-defined floating-point format.

    Every field is optional and defaults to ``None``.

    Attributes
    ----------
    emax : int, optional
        The maximum value of the exponent.
    t : int, optional
        The number of bits in the significand (including the hidden bit).
    exp_bits : int, optional
        The exponent bits.
    sig_bits : int, optional
        The significand bits (not including the hidden bit).
    """

    emax: Optional[int] = None
    t: Optional[int] = None
    exp_bits: Optional[int] = None
    sig_bits: Optional[int] = None
75
+
76
+
77
@dataclass
class Options:
    """Plain record holding one complete set of chop options.

    All eight fields are required and are accepted positionally in the
    order declared below.

    NOTE(review): the meaning of each field is not documented in this file;
    presumably they mirror the corresponding ``Chop`` parameters - confirm
    against ``pychop/chop.py`` before relying on them.
    """

    t: int
    emax: int
    prec: int
    subnormal: bool
    rmode: bool
    flip: bool
    explim: bool
    p: float
87
+
pychop/bfp_formats.py ADDED
@@ -0,0 +1,390 @@
1
+ """
2
+ Block Floating Point (BFP) Format - Backend Agnostic Entry Point
3
+
4
+ This module provides automatic backend detection and routing for BFP quantization.
5
+ Supports NumPy, JAX, and PyTorch backends with automatic detection.
6
+
7
+ Usage:
8
+ >>> import pychop
9
+ >>> pychop.backend('auto') # Auto-detect from input
10
+ >>>
11
+ >>> # NumPy
12
+ >>> import numpy as np
13
+ >>> X = np.random.randn(1024, 768)
14
+ >>> X_q = bfp_quantize(X, format='bfp8')
15
+ >>>
16
+ >>> # PyTorch (with STE for training)
17
+ >>> import torch
18
+ >>> X = torch.randn(128, 768, requires_grad=True)
19
+ >>> X_q = bfp_quantize(X, format='bfp8') # Automatic STE!
20
+ >>>
21
+ >>> # JAX
22
+ >>> import jax.numpy as jnp
23
+ >>> X = jnp.array(np.random.randn(512, 512))
24
+ >>> X_q = bfp_quantize(X, format='bfp8')
25
+
26
+ Author: Xinye Chen
27
+
28
+ """
29
+
30
+ import os
31
+ from typing import Union, Tuple, Optional, Any
32
+ from dataclasses import dataclass
33
+
34
+
35
+ # ============================================================================
36
+ # Backend Detection (inline to avoid import issues)
37
+ # ============================================================================
38
+
39
def _detect_array_type(x: Any) -> str:
    """
    Guess the backend from the module that defines ``type(x)``.

    Parameters
    ----------
    x : Any
        Input array or scalar

    Returns
    -------
    str
        ``'torch'`` if the defining module's name contains "torch",
        otherwise ``'jax'`` if it contains "jax", otherwise ``'numpy'``
        (which is therefore also the answer for plain Python objects).
    """
    origin = type(x).__module__

    # Order matters: "torch" is tested before "jax".
    for candidate in ("torch", "jax"):
        if candidate in origin:
            return candidate
    return "numpy"
60
+
61
+
62
def _get_backend_env() -> str:
    """Return the ``chop_backend`` environment variable, or ``'auto'`` if unset."""
    try:
        return os.environ['chop_backend']
    except KeyError:
        return 'auto'
65
+
66
+
67
+ # ============================================================================
68
+ # BFP Format Specification (Backend-Independent)
69
+ # ============================================================================
70
+
71
@dataclass
class BFPSpec:
    """
    Description of one Block Floating Point (BFP) format.

    The spec carries no backend-specific state, so every backend
    implementation can share the same instances.

    Attributes
    ----------
    name : str
        Format name
    mantissa_bits : int
        Number of mantissa bits per element (including sign)
    block_size : int
        Number of elements sharing same exponent
    exponent_bits : int
        Number of bits for shared exponent (default 8)
    has_sign : bool
        Whether elements have sign bits (default True)
    use_subnormals : bool
        Whether to support subnormal numbers (default False)
    """

    name: str
    mantissa_bits: int
    block_size: int
    exponent_bits: int = 8
    has_sign: bool = True
    use_subnormals: bool = False

    @property
    def total_bits_per_block(self) -> int:
        """Storage for one block: one shared exponent plus every mantissa."""
        mantissa_payload = self.mantissa_bits * self.block_size
        return self.exponent_bits + mantissa_payload

    @property
    def compression_vs_fp32(self) -> float:
        """Size of a block stored as FP32 divided by its BFP size."""
        return (32 * self.block_size) / self.total_bits_per_block

    @property
    def compression_vs_fp16(self) -> float:
        """Size of a block stored as FP16 divided by its BFP size."""
        return (16 * self.block_size) / self.total_bits_per_block

    def __repr__(self):
        # Compact form; has_sign and use_subnormals are deliberately omitted.
        template = "BFPSpec(name='{}', mantissa={}b, block_size={}, exponent={}b)"
        return template.format(
            self.name, self.mantissa_bits, self.block_size, self.exponent_bits
        )
120
+
121
+
122
+ # Predefined BFP formats (shared across all backends)
123
# Built from (name, mantissa_bits, block_size, exponent_bits) rows; the dict
# key is always the spec's own name and insertion order follows the rows.
BFP_FORMATS = {
    label: BFPSpec(label, mantissa_bits=mant, block_size=block, exponent_bits=expo)
    for label, mant, block, expo in (
        ('bfp16', 16, 16, 8),
        ('bfp12', 12, 16, 8),
        ('bfp8', 8, 32, 8),
        ('bfp6', 6, 32, 8),
        ('bfp4', 4, 32, 8),
        ('bfp3', 3, 64, 8),
        ('bfp2', 2, 128, 8),
        ('flexpoint16', 16, 16, 5),
        ('flexpoint8', 8, 32, 5),
    )
}
134
+
135
+
136
def create_bfp_spec(
    mantissa_bits: int,
    block_size: int,
    exponent_bits: int = 8,
    name: Optional[str] = None
) -> BFPSpec:
    """
    Build a ``BFPSpec`` for a custom format.

    Parameters
    ----------
    mantissa_bits : int
        Number of mantissa bits (documented range 1-32; not validated here)
    block_size : int
        Elements per block
    exponent_bits : int
        Bits for shared exponent
    name : str, optional
        Custom name. When omitted the spec is named
        ``custom_bfp<mantissa_bits>``.

    Returns
    -------
    BFPSpec
        BFP format specification
    """
    label = f"custom_bfp{mantissa_bits}" if name is None else name
    fields = {
        'name': label,
        'mantissa_bits': mantissa_bits,
        'block_size': block_size,
        'exponent_bits': exponent_bits,
    }
    return BFPSpec(**fields)
170
+
171
+
172
+ # ============================================================================
173
+ # Backend Detection and Routing
174
+ # ============================================================================
175
+
176
def _resolve_backend(X: Any = None) -> str:
    """
    Decide which backend should handle a call.

    An explicit ``chop_backend`` environment setting wins; only the value
    ``'auto'`` triggers detection from the input.

    Parameters
    ----------
    X : Any, optional
        Input array (if provided, used for auto-detection)

    Returns
    -------
    str
        Backend name: 'numpy', 'jax', or 'torch'

    Raises
    ------
    ValueError
        If the environment names a backend other than 'numpy', 'jax',
        'torch' or 'auto'.
    """
    choice = _get_backend_env()

    if choice != 'auto':
        if choice in {'numpy', 'jax', 'torch'}:
            return choice
        raise ValueError(
            f"Invalid backend: {choice}. "
            "Must be 'numpy', 'jax', 'torch', or 'auto'."
        )

    # Auto mode: with nothing to inspect, default to numpy.
    return 'numpy' if X is None else _detect_array_type(X)
206
+
207
+
208
def _get_backend_module(backend: str):
    """
    Import and return the backend-specific BFP implementation module.

    The import is performed lazily so that optional dependencies (PyTorch,
    JAX) are only required when their backend is actually requested.

    Parameters
    ----------
    backend : str
        Backend name: 'torch', 'jax' or 'numpy'

    Returns
    -------
    module
        Backend-specific BFP module

    Raises
    ------
    ImportError
        If the torch or jax backend module cannot be imported. The original
        import failure is attached as ``__cause__`` so the real reason
        (missing package, broken install, ...) stays visible.
    ValueError
        If ``backend`` is not one of the supported names.
    """
    if backend == 'torch':
        try:
            from .tch import bfp_formats as backend_module
        except ImportError as err:
            raise ImportError(
                "PyTorch backend not available. "
                "Install with: pip install torch"
            ) from err
    elif backend == 'jax':
        try:
            from .jx import bfp_formats as backend_module
        except ImportError as err:
            raise ImportError(
                "JAX backend not available. "
                "Install with: pip install jax jaxlib flax"
            ) from err
    elif backend == 'numpy':
        from .np import bfp_formats as backend_module
    else:
        raise ValueError(f"Unsupported backend: {backend}")

    return backend_module
244
+
245
+
246
+ # ============================================================================
247
+ # User-Facing Functions
248
+ # ============================================================================
249
+
250
def bfp_quantize(
    data: Any,
    format: Union[str, BFPSpec, Tuple[int, int]] = 'bfp8',
    backend: Optional[str] = None
) -> Any:
    """
    Quantize array to BFP format by delegating to a backend implementation.

    Parameters
    ----------
    data : array-like
        Input data (numpy.ndarray, torch.Tensor, or jax.Array)
    format : str, BFPSpec, or tuple
        BFP format specification; forwarded unchanged to the backend
    backend : str, optional
        Force specific backend ('numpy', 'jax', or 'torch').
        If None, the backend is resolved from the ``chop_backend``
        environment setting and, in 'auto' mode, from the type of ``data``.

    Returns
    -------
    array-like
        Whatever the selected backend's ``bfp_quantize`` returns

    Examples
    --------
    >>> # NumPy
    >>> import numpy as np
    >>> X = np.random.randn(1024, 768)
    >>> X_q = bfp_quantize(X, format='bfp8')
    >>>
    >>> # PyTorch (with automatic STE if requires_grad=True)
    >>> import torch
    >>> X = torch.randn(128, 768, requires_grad=True)
    >>> X_q = bfp_quantize(X, format='bfp8')
    >>> loss = X_q.sum()
    >>> loss.backward()  # Gradients flow through!
    >>>
    >>> # Custom format
    >>> X_q = bfp_quantize(X, format=(4, 32))  # 4-bit mantissa, 32 elem/block
    """
    chosen = _resolve_backend(data) if backend is None else backend
    implementation = _get_backend_module(chosen)
    return implementation.bfp_quantize(data, format=format)
301
+
302
+
303
class BFPTensor:
    """
    Backend-agnostic BFP tensor wrapper.

    Construction picks a backend and builds that backend's ``BFPTensor_``
    object; every method simply forwards to it.

    Parameters
    ----------
    data : array-like
        Input tensor
    format : str, BFPSpec, or tuple
        BFP format
    backend : str, optional
        Force specific backend. If None it is resolved from the environment
        setting and, in 'auto' mode, from the type of ``data``.

    Examples
    --------
    >>> # NumPy backend
    >>> import numpy as np
    >>> X = np.random.randn(1024, 768)
    >>> bfp = BFPTensor(X, format='bfp8')
    >>> X_reconstructed = bfp.dequantize()
    >>> stats = bfp.statistics()
    """

    def __init__(
        self,
        data: Any,
        format: Union[str, BFPSpec, Tuple[int, int]] = 'bfp8',
        backend: Optional[str] = None
    ):
        # Name of the backend that owns the wrapped implementation.
        self.backend = _resolve_backend(data) if backend is None else backend

        # Backend-specific tensor object that does the real work.
        impl_factory = _get_backend_module(self.backend).BFPTensor_
        self._impl = impl_factory(data, format=format)

    def dequantize(self) -> Any:
        """Return the backend implementation's dequantized data."""
        return self._impl.dequantize()

    def statistics(self) -> dict:
        """Return the backend implementation's quantization statistics."""
        return self._impl.statistics()

    def __repr__(self):
        return "BFPTensor(backend={}, impl={})".format(self.backend, self._impl)
356
+
357
+
358
def print_bfp_format_table():
    """Print a fixed-width table of every entry in ``BFP_FORMATS`` to stdout."""
    heavy_rule = "=" * 90
    light_rule = "-" * 90

    print(heavy_rule)
    print("Predefined BFP Formats")
    print(heavy_rule)

    # (title, column width); every header cell is left-aligned.
    columns = (
        ('Name', 15),
        ('Mantissa', 10),
        ('Block Size', 12),
        ('Exponent', 10),
        ('Compress FP16', 15),
        ('Total Bits', 12),
    )
    print(" ".join(f"{title:<{width}}" for title, width in columns))
    print(light_rule)

    for spec in BFP_FORMATS.values():
        cells = (
            f"{spec.name:<15}",
            f"{spec.mantissa_bits:<10}",
            f"{spec.block_size:<12}",
            f"{spec.exponent_bits:<10}",
            # Ratio is followed by a fixed 11-space pad rather than a width.
            f"{spec.compression_vs_fp16:.2f}x" + " " * 11,
            str(spec.total_bits_per_block),
        )
        print(" ".join(cells))

    print(heavy_rule)
379
+
380
+
381
+
382
+
383
+ __all__ = [
384
+ 'BFPSpec',
385
+ 'BFPTensor',
386
+ 'BFP_FORMATS',
387
+ 'create_bfp_spec',
388
+ 'bfp_quantize',
389
+ 'print_bfp_format_table',
390
+ ]
pychop/bitchop.py ADDED
@@ -0,0 +1,79 @@
1
+ import os
2
+ import numpy as np
3
+
4
def Bitchop(exp_bits, sig_bits, rmode="nearest_even", subnormal=True, random_state=42, device="cpu", verbose=0):
    """
    Create a bit-level chopper for the backend named by ``chop_backend``.

    Parameters
    ----------
    exp_bits : int
        Number of bits for the exponent in the target format. Determines the range
        of representable values (e.g., 5 bits gives a bias of 15, range -14 to 15).

    sig_bits : int
        Number of bits for the significand (mantissa) in the target format, excluding
        the implicit leading 1 for normalized numbers (e.g., 4 bits allows 0 to 15 plus implicit 1).

    rmode : int or str, default="nearest_even"
        Rounding mode to use when quantizing the significand. Options are:
        - 1 or "nearest_even": Round to nearest value, ties to even (IEEE 754 default).
        - 0 or "nearest_odd": Round to nearest value, ties to odd.
        - 2 or "plus_infinity": Round towards plus infinity (round up).
        - 3 or "minus_infinity": Round towards minus infinity (round down).
        - 4 or "toward_zero": Truncate toward zero (no rounding up).
        - 5 or "stochastic_prop": Stochastic rounding proportional to the fractional part.
        - 6 or "stochastic_equal": Stochastic rounding with 50% probability.

    subnormal : bool, default=True
        Whether subnormal numbers are supported.
        If ``subnormal=False``, subnormals are flushed to zero.

    random_state : int, default=42
        Random seed used by the stochastic rounding modes.

    device : str or torch.device, default="cpu"
        Device to perform computations on (e.g., "cpu", "cuda"). Only
        forwarded to the torch and jax implementations.

    verbose : int or bool, default=0
        Whether or not to print out the unit roundoff.

    Returns
    -------
    object
        The backend-specific ``Bitchop`` instance, with an extra attribute
        ``u`` holding the unit roundoff ``2**(-sig_bits) / 2`` of the format.
        Calling the instance on ``x`` converts ``x`` to the target format.
    """
    # Default to 'auto' when the variable is unset (e.g. this module imported
    # on its own) instead of failing with KeyError; any value other than
    # 'torch' or 'jax' selects the NumPy implementation, as before.
    backend = os.environ.get('chop_backend', 'auto')

    if backend == 'torch':
        from .tch.bitchop import Bitchop as _BackendBitchop
        obj = _BackendBitchop(exp_bits=exp_bits, sig_bits=sig_bits, subnormal=subnormal, device=device,
                              random_state=random_state, rmode=rmode)

    elif backend == 'jax':
        from .jx.bitchop import Bitchop as _BackendBitchop
        obj = _BackendBitchop(exp_bits=exp_bits, sig_bits=sig_bits, subnormal=subnormal, device=device,
                              random_state=random_state, rmode=rmode)
    else:
        from .np.bitchop import Bitchop as _BackendBitchop
        obj = _BackendBitchop(exp_bits=exp_bits, sig_bits=sig_bits, subnormal=subnormal,
                              random_state=random_state, rmode=rmode)

    # Precision is sig_bits + 1 (implicit leading bit), so the unit roundoff
    # is 2**-(sig_bits + 1). The previous 2**sig_bits / 2 grew with precision.
    obj.u = 2.0 ** (-sig_bits) / 2

    if verbose:
        print("The floating point format is with unit-roundoff of {:e}".format(
            obj.u)+" (≈2^"+str(int(np.log2(obj.u)))+").")

    return obj
pychop/blas.py ADDED
@@ -0,0 +1,74 @@
1
+ from pychop import LightChop
2
+ import torch
3
+ import pychop
4
+ pychop.backend('torch')
5
+
6
+
7
# Rows are (name, exp_bits, sig_bits); every format uses rmode 1.
precision_configs = {
    label: {'exp_bits': exponent, 'sig_bits': significand, 'rmode': 1}
    for label, exponent, significand in (
        ('q52', 5, 2),
        ('q43', 4, 3),
        ('bf16', 8, 7),
        ('half', 5, 10),
        ('tf32', 8, 10),
        ('fp32', 8, 23),
        ('fp64', 11, 52),
    )
}

# Precision fallback order
precision_fallback = [
    'q52',
    'q43',
    'bf16',
    'half',
    'tf32',
    'fp32',
    'fp64',
]
18
+
19
def chop(x, precision_idx=0):
    """Round ``x`` to a low precision, escalating when the result is not finite.

    Parameters
    ----------
    x : torch.Tensor or array-like
        Values to round. Non-tensors are converted to a float64 tensor.
    precision_idx : int, default=0
        Index into ``precision_fallback`` of the first precision to try.

    Returns
    -------
    torch.Tensor
        ``x`` rounded with the first precision, starting at
        ``precision_idx``, whose result contains no NaN or Inf, returned as
        float64 on the device of ``x``. ``x`` itself is returned unrounded
        when the precision reached is 'fp64' or the index runs past the end
        of ``precision_fallback``.
    """
    # Imported here because this module has no top-level ``import logging``.
    import logging

    if not torch.is_tensor(x):
        x = torch.tensor(x, dtype=torch.float64)
    if precision_idx >= len(precision_fallback):
        return x

    precision = precision_fallback[precision_idx]
    if precision == 'fp64':
        return x

    ch = LightChop(**precision_configs[precision])
    result = ch(x)
    if torch.isfinite(result).all():
        # Keep the result on the input's device; the module-level ``device``
        # this used to reference was never defined (NameError).
        return result.to(dtype=torch.float64, device=x.device)

    # 'fp64' is the last entry and returns above, so idx + 1 is always valid.
    logging.debug(f"Chop: Precision {precision} failed, escalating to {precision_fallback[precision_idx + 1]}")
    return chop(x, precision_idx + 1)
34
+
35
def rounding(x, precision):
    """Round ``x`` via ``chop``, starting at the named precision.

    ``precision`` must be an entry of ``precision_fallback``; an unknown
    name raises ``ValueError`` from ``list.index``.
    """
    start = precision_fallback.index(precision)
    return chop(x, precision_idx=start)
37
+
38
def mixed_precision_op(op, x, precision, y=None):
    """Apply ``op`` to operands rounded to ``precision``, then round the result.

    Parameters
    ----------
    op : callable
        Unary operation (when ``y`` is None) or binary operation.
    x : torch.Tensor or array-like
        First operand; rounded to ``precision`` before ``op`` is applied.
    precision : str
        Name of an entry in ``precision_fallback``.
    y : torch.Tensor or array-like, optional
        Second operand; rounded the same way as ``x`` when given.

    Returns
    -------
    torch.Tensor
        ``op``'s result, left unrounded for 'fp64' and otherwise rounded to
        ``precision``, on the device of the rounded ``x``.
    """
    x = rounding(x, precision)
    if y is None:
        unrounded = op(x)
    else:
        y = rounding(y, precision)
        unrounded = op(x, y)

    # ``rounding`` always returns a tensor, so its device is the reference;
    # the module-level ``device`` used previously was never defined.
    target_device = x.device
    if precision == 'fp64':
        return unrounded.to(target_device)
    return rounding(unrounded, precision).to(target_device)
50
+
51
+
52
def round_sparse_matrix(A, precision):
    """Round the stored entries of a sparse matrix to ``precision``.

    Parameters
    ----------
    A : sparse matrix
        Any object exposing ``tocoo()`` and ``shape`` (e.g. a SciPy sparse
        matrix).
    precision : str
        Key of ``precision_configs``. 'fp64' returns ``A`` untouched.

    Returns
    -------
    sparse matrix
        A new CSC matrix holding the rounded values at the same coordinates,
        or ``A`` itself when ``precision`` is 'fp64' or rounding produced a
        NaN/Inf (a warning is logged in the latter case).
    """
    if precision == 'fp64':
        return A

    # Imported here because this module does not import them at top level;
    # the bare names used previously raised NameError.
    import logging
    from scipy.sparse import csc_matrix

    A_coo = A.tocoo()
    # Built on torch's default device; the module-level ``device`` that this
    # line used to reference was never defined.
    data = torch.tensor(A_coo.data, dtype=torch.float64)
    ch = LightChop(**precision_configs[precision])
    rounded_data = ch(data)
    if not torch.isfinite(rounded_data).all():
        logging.warning(f"Rounding sparse matrix to {precision} failed; using fp64")
        return A
    return csc_matrix((rounded_data.cpu().numpy(), (A_coo.row, A_coo.col)), shape=A.shape)
64
+
65
+
66
+
67
if __name__ == "__main__":
    import numpy as np

    # Smoke demo: add two random 100x100 matrices in full precision, then
    # repeat the addition through the 'fp32' mixed-precision path.
    A = np.random.randn(100, 100)
    B = np.random.randn(100, 100)
    C = A + B
    print("C:", C)

    def _add(lhs, rhs):
        return lhs + rhs

    print("C (fp32):", mixed_precision_op(_add, A, 'fp32', B))
@@ -0,0 +1,4 @@
1
+ from .cpfloat import *
2
+ from .cparray import *
3
+ from .cparray_jax import *
4
+ from .cptensor import *
@@ -0,0 +1,68 @@
1
+ import numpy as np
2
+ from pychop import Chop # Or: from pychop import Chop
3
+
4
class CPArray(np.ndarray):
    """
    A NumPy array subclass that maintains chopped precision after arithmetic ops.

    - Inherits from np.ndarray for full compatibility.
    - ``chopper`` is any callable mapping an ndarray to a rounded ndarray.
    - Ufunc and matmul results are passed through ``chopper`` and returned
      as CPArray instances (0-d results are returned as Python scalars).
    - In-place operators (``a += b``) and ``out=`` write the chopped result
      back into the target array.
    """

    def __new__(cls, input_array, chopper=None):
        """Chop ``input_array`` with ``chopper`` and view the result as CPArray.

        Raises
        ------
        ValueError
            If no ``chopper`` is given.
        """
        if chopper is None:
            raise ValueError("Must provide a chopper (Chop or Chop instance)")
        # Chop a plain ndarray so the chopper never sees the subclass.
        base_input = np.asarray(input_array)
        chopped_base = chopper(base_input)
        obj = chopped_base.view(cls)
        obj.chopper = chopper
        return obj

    def __array_finalize__(self, obj):
        # Views and slices inherit the chopper of the array they came from.
        if obj is None:
            return
        self.chopper = getattr(obj, 'chopper', None)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """
        Run a ufunc on plain ndarrays and chop whatever it produces.

        Raises
        ------
        ValueError
            If another CPArray input carries a different chopper.
        """
        for inp in inputs:
            if isinstance(inp, CPArray) and inp.chopper != self.chopper:
                raise ValueError("All CPArray inputs must use the same chopper")

        plain_inputs = [np.asarray(x) for x in inputs]

        # NumPy also dispatches on ``out`` arrays. Forwarding a CPArray there
        # (as every in-place operator does) would re-enter this method
        # forever, so hand the ufunc plain views and remember the originals.
        targets = kwargs.get('out')
        if targets is not None:
            if not isinstance(targets, tuple):
                targets = (targets,)
            kwargs['out'] = tuple(
                t.view(np.ndarray) if isinstance(t, CPArray) else t
                for t in targets
            )

        result = getattr(ufunc, method)(*plain_inputs, **kwargs)

        # Multi-output ufuncs (e.g. np.modf) return a tuple: chop each part.
        if isinstance(result, tuple):
            if targets is None:
                targets = (None,) * len(result)
            return tuple(
                self._wrap_result(part, target)
                for part, target in zip(result, targets)
            )
        return self._wrap_result(result, targets[0] if targets else None)

    def _wrap_result(self, raw, target=None):
        """Chop one ufunc output and return it in the form callers expect."""
        chopped = self.chopper(raw)
        if target is not None:
            # Write through a plain view so the assignment itself cannot
            # dispatch back into __array_ufunc__; return the caller's object.
            np.asarray(target)[...] = chopped
            return target
        if chopped.ndim == 0:
            return chopped.item()  # Scalar fallback
        return CPArray(chopped, self.chopper)

    def __matmul__(self, other):
        # Multiply plain arrays; the CPArray constructor chops the product.
        self_pure = self.view(np.ndarray)
        other_pure = np.asarray(other)
        result = np.matmul(self_pure, other_pure)
        return CPArray(result, self.chopper)

    def __rmatmul__(self, other):
        return CPArray(np.matmul(np.asarray(other), self.view(np.ndarray)), self.chopper)

    def to_regular(self):
        """Return the data as a plain ``np.ndarray``."""
        return np.asarray(self)

    def __str__(self):
        # Show the format bits when the chopper exposes them, else "custom".
        prec_info = f"exp_bits={self.chopper.exp_bits}, sig_bits={self.chopper.sig_bits}" if hasattr(self.chopper, 'exp_bits') else "custom"
        return f"CPArray({np.array2string(self)}, {prec_info})"

    def __repr__(self):
        return str(self)