bigsmall 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bigsmall/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ """BigSmall - Lossless neural network weight compression.
2
+
3
+ Public API:
4
+ bigsmall.compress(src, dst, mode="balanced") - compress safetensors -> .bs
5
+ bigsmall.decompress(src, dst=None) -> dict[str, ndarray]
6
+ bigsmall.load(src, device="cpu") -> dict[str, torch.Tensor]
7
+ bigsmall.info(src) -> dict
8
+ bigsmall.verify(src) -> bool
9
+ bigsmall.compress_delta(finetune, base, dst, mode="balanced")
10
+ bigsmall.decompress_delta(delta_src, base_src, dst=None) -> dict[str, ndarray]
11
+
12
+ HuggingFace Hub round-trip (Phase 4):
13
+ bigsmall.compress_for_hub(source, output_dir) - compress any HF model
14
+ bigsmall.upload_to_hub(output_dir, repo_id) - push to the Hub
15
+ bigsmall.from_pretrained(repo_or_path) - download + decompress -> state_dict
16
+ bigsmall.install_hook() - monkey-patch safetensors.load_file
17
+
18
+ Streaming loader (Phase 4 cont.):
19
+ bigsmall.StreamingLoader(path, device="cuda") - layer-by-layer decompression
20
+ """
21
+ __version__ = "1.0.0"
22
+
23
+ from .encoder import compress, compress_delta
24
+ from .decoder import decompress, decompress_delta, load
25
+ from .verify import verify
26
+ from .container import info
27
+ from .hub import compress_for_hub, upload_to_hub, from_pretrained
28
+ from .integrations.huggingface import install_hook
29
+ from .streaming import StreamingLoader
30
+
31
# Names re-exported as the package's public API; this list is what
# `from bigsmall import *` exposes.
__all__ = [
    # Core compress / decompress / inspect entry points
    "compress",
    "decompress",
    "load",
    "info",
    "verify",
    # Delta variants (fine-tune encoded against a base file)
    "compress_delta",
    "decompress_delta",
    # HuggingFace Hub round-trip helpers
    "compress_for_hub",
    "upload_to_hub",
    "from_pretrained",
    "install_hook",
    # Streaming loader
    "StreamingLoader",
    "__version__",
]
bigsmall/cli.py ADDED
@@ -0,0 +1,124 @@
1
+ """BigSmall command-line interface."""
2
+ import argparse
3
+ import sys
4
+ import time
5
+ from pathlib import Path
6
+
7
+
8
def _cmd_compress(args):
    """Run the ``compress`` sub-command and print a size/time report.

    Reads ``args.src``, ``args.output``, ``args.storage``, ``args.inference``
    and ``args.base``. Without ``--output`` the destination is the source path
    with its suffix replaced by ``.bs``. When ``args.base`` is set the delta
    encoder is used instead of the plain encoder.
    """
    from . import encoder

    source_path = Path(args.src)
    target_path = Path(args.output) if args.output else source_path.with_suffix(".bs")

    # --storage and --inference are mutually exclusive on the parser; anything
    # else (including an explicit --balanced) selects the balanced mode.
    if args.storage:
        mode = "storage"
    elif args.inference:
        mode = "inference"
    else:
        mode = "balanced"

    started = time.perf_counter()
    if args.base:
        encoder.compress_delta(source_path, args.base, target_path, mode=mode)
    else:
        encoder.compress(source_path, target_path, mode=mode)
    elapsed = time.perf_counter() - started

    src_size = source_path.stat().st_size
    dst_size = target_path.stat().st_size
    # Guard the ratio against an empty source file.
    pct = (dst_size / src_size * 100) if src_size > 0 else 0

    report = (
        f"compressed {source_path} -> {target_path}",
        f" source: {src_size:,} bytes",
        f" compressed: {dst_size:,} bytes ({pct:.2f}%)",
        f" saved: {src_size - dst_size:,} bytes",
        f" elapsed: {elapsed:.1f}s",
    )
    for line in report:
        print(line, flush=True)
35
+
36
+
37
def _cmd_decompress(args):
    """Run the ``decompress`` sub-command.

    Without ``--output`` the destination is the source path with its suffix
    replaced by ``.safetensors``. When ``args.base`` is set the delta decoder
    is used with that base file; otherwise the plain decoder runs.
    """
    from . import decoder

    source_path = Path(args.src)
    target_path = (
        Path(args.output) if args.output else source_path.with_suffix(".safetensors")
    )

    started = time.perf_counter()
    if args.base:
        decoder.decompress_delta(source_path, args.base, target_path)
    else:
        decoder.decompress(source_path, target_path)
    elapsed = time.perf_counter() - started

    print(f"decompressed {source_path} -> {target_path} ({elapsed:.1f}s)", flush=True)
52
+
53
+
54
def _cmd_info(args):
    """Run the ``info`` sub-command: print each metadata key/value of ``args.src``."""
    from .container import info

    metadata = info(args.src)
    for key, value in metadata.items():
        # Keys are left-aligned and padded to 24 columns so values line up.
        print(f" {key:24s} {value}")
59
+
60
+
61
def _cmd_verify(args):
    """Run the ``verify`` sub-command and exit with the result.

    Prints ``OK`` and exits with status 0 when verification succeeds,
    otherwise prints ``FAIL`` and exits with status 1. This function never
    returns normally: it always raises ``SystemExit``.
    """
    from .verify import verify

    passed = verify(args.src, source_safetensors=args.source)
    print("OK" if passed else "FAIL", flush=True)
    sys.exit(0 if passed else 1)
69
+
70
+
71
def _cmd_benchmark(args):
    """Run the ``benchmark`` sub-command: time one encode and one decode.

    Compresses ``args.src`` to a sibling file with the ``.bs`` suffix, then
    decompresses that file in memory, and prints elapsed time, throughput
    (relative to the source size) and the compression ratio.

    NOTE: the ``.bs`` file is written next to the source and is left in place.
    """
    from . import encoder, decoder

    src = Path(args.src)
    dst = src.with_suffix(".bs")
    print(f"Benchmarking {src.name}...")

    t0 = time.perf_counter()
    encoder.compress(src, dst)
    te = time.perf_counter() - t0

    t0 = time.perf_counter()
    decoder.decompress(dst)  # result discarded; only the timing matters
    td = time.perf_counter() - t0

    src_size = src.stat().st_size
    dst_size = dst.stat().st_size
    pct = (dst_size / src_size * 100) if src_size > 0 else 0

    def _mib_per_s(elapsed):
        # Same defensive treatment as the ratio above: never divide by zero
        # if the timer reports no elapsed time.
        if elapsed <= 0:
            return float("inf")
        return src_size / elapsed / 1024 / 1024

    print(f" encode: {te:.1f}s ({_mib_per_s(te):.1f} MiB/s)")
    print(f" decode: {td:.1f}s ({_mib_per_s(td):.1f} MiB/s)")
    print(f" ratio: {pct:.2f}% ({src_size:,} -> {dst_size:,})")
84
+
85
+
86
def main(argv=None):
    """Parse the command line and dispatch to the selected sub-command.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Each sub-parser stores its handler in ``func`` via ``set_defaults``; after
    parsing, that handler is invoked with the parsed namespace.
    """
    parser = argparse.ArgumentParser(
        prog="bigsmall",
        description="BigSmall lossless NN weight compression",
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    # compress
    compress_cmd = commands.add_parser("compress", help="Compress a .safetensors file to .bs")
    compress_cmd.add_argument("src")
    compress_cmd.add_argument("-o", "--output", default=None)
    compress_cmd.add_argument(
        "--base", default=None, help="Base safetensors path - enables delta mode"
    )
    mode_flags = compress_cmd.add_mutually_exclusive_group()
    mode_flags.add_argument("--storage", action="store_true", help="Maximum compression mode")
    mode_flags.add_argument(
        "--balanced", action="store_true", help="Balanced ratio+speed (default)"
    )
    mode_flags.add_argument("--inference", action="store_true", help="Fastest decode mode")
    compress_cmd.set_defaults(func=_cmd_compress)

    # decompress
    decompress_cmd = commands.add_parser(
        "decompress", help="Decompress a .bs file to .safetensors"
    )
    decompress_cmd.add_argument("src")
    decompress_cmd.add_argument("-o", "--output", default=None)
    decompress_cmd.add_argument(
        "--base", default=None, help="Base file path for delta decompression"
    )
    decompress_cmd.set_defaults(func=_cmd_decompress)

    # info
    info_cmd = commands.add_parser("info", help="Show metadata for a .bs file")
    info_cmd.add_argument("src")
    info_cmd.set_defaults(func=_cmd_info)

    # verify
    verify_cmd = commands.add_parser("verify", help="Verify md5 round-trip of a .bs file")
    verify_cmd.add_argument("src")
    verify_cmd.add_argument(
        "--source", default=None, help="Compare against original .safetensors"
    )
    verify_cmd.set_defaults(func=_cmd_verify)

    # benchmark
    benchmark_cmd = commands.add_parser(
        "benchmark", help="Encode/decode benchmark for a model"
    )
    benchmark_cmd.add_argument("src")
    benchmark_cmd.set_defaults(func=_cmd_benchmark)

    namespace = parser.parse_args(argv)
    namespace.func(namespace)


if __name__ == "__main__":
    main()
@@ -0,0 +1,19 @@
1
+ """BigSmall codecs - per-format encoder/decoder implementations."""
2
+ from . import bf16, fp32, fp16, fp8, fp4, special, generic
3
+
4
# Maps a codec name string to the codec module registered for it.
# Two names ("zstd" and "blosc2_shuffle_zstd") share the `generic` module.
CODEC_REGISTRY = {
    "bf16_se_ac": bf16,
    "fp16_se_ac": fp16,
    "fp32_se_ac": fp32,
    "fp8_cat_ac": fp8,
    "fp4_cat_ac": fp4,
    "special": special,
    "zstd": generic,
    "blosc2_shuffle_zstd": generic,
}
14
+
15
+
16
def get_codec(name):
    """Return the codec module registered under ``name``.

    Raises:
        KeyError: if ``name`` is not a key of ``CODEC_REGISTRY``.
    """
    if name in CODEC_REGISTRY:
        return CODEC_REGISTRY[name]
    raise KeyError(f"Unknown codec: {name}")
@@ -0,0 +1,170 @@
1
+ """BF16 codec: per-tensor (sign,exp) joint AC + per-tensor (mantissa | exp) AC.
2
+
3
+ BF16 layout: 1 sign bit | 8 exp bits | 7 mantissa bits = 16 bits total.
4
+
5
+ Encoding strategy (mirrors cc10_v2.py but operates per-tensor so it works on
6
+ ANY safetensors model - no GPT-2-specific handling here):
7
+
8
+ 1. Split each weight into (sign, exp, mantissa).
9
+ 2. Encode (sign, exp) jointly as a 9-bit alphabet per tensor with Categorical AC.
10
+ 3. Encode mantissa per-tensor, sorted by exp, with one Cat AC bucket per exp value.
11
+
12
+ This per-tensor approach generalises cleanly. The cc10_v2 script encoded the
13
+ mantissa stream globally across all non-special tensors; that gives a tiny
14
+ extra ratio improvement on GPT-2 but requires a fixed 'special tensor set' to
15
+ be known across the encode/decode path. Per-tensor encoding is simpler,
16
+ generalises, and the per-tensor overhead is negligible relative to the AC
17
+ codeword cost.
18
+ """
19
+ import struct
20
+ import io
21
+ import numpy as np
22
+ import constriction as c
23
+
24
# Bit layout for BF16: 1 sign bit | 8 exp bits | 7 mantissa bits.
SIGN_SHIFT = 15  # right-shift that brings the sign bit to bit 0
EXP_SHIFT = 7  # right-shift that brings the exponent field to bit 0
EXP_MASK = 0xFF  # 8 exponent bits
MANT_MASK = 0x7F  # 7 mantissa bits
SE_ALPHABET = 512  # 2 sign * 256 exp
MANT_ALPHABET = 128  # 7 mantissa bits
31
+
32
+
33
def _encode_cat(values: np.ndarray, alphabet: int) -> tuple[bytes, np.ndarray, np.ndarray]:
    """Range-encode an integer array under a Categorical model built from its histogram.

    Args:
        values: non-negative integer symbols, each below ``alphabet``.
        alphabet: number of symbols in the model.

    Returns:
        ``(codeword_bytes, nonzero_indices, nonzero_freqs)`` - the compressed
        stream plus the symbols that occur and their counts, which is what
        ``_decode_cat`` needs to rebuild the same model.
    """
    histogram = np.bincount(values, minlength=alphabet).astype(np.int64)
    used_symbols = np.nonzero(histogram)[0]

    # Every symbol gets +0.01 before normalising, so symbols that never occur
    # still receive a nonzero weight. _decode_cat repeats this exact arithmetic.
    weights = histogram.astype(np.float64) + 0.01
    weights /= weights.sum()

    model = c.stream.model.Categorical(weights, perfect=True)
    coder = c.stream.queue.RangeEncoder()
    coder.encode(values.astype(np.int32), model)
    codeword = coder.get_compressed().tobytes()
    return codeword, used_symbols, histogram[used_symbols].astype(np.int64)
47
+
48
+
49
def _decode_cat(cw_bytes: bytes, nz_idx: np.ndarray, freqs: np.ndarray, alphabet: int, n: int) -> np.ndarray:
    """Decode ``n`` symbols from a stream produced by ``_encode_cat``.

    The histogram is rebuilt from the stored (index, count) pairs and turned
    into probabilities with the same +0.01 smoothing the encoder applies.

    Args:
        cw_bytes: compressed stream, interpreted as little-endian-native uint32 words.
        nz_idx: symbols with a nonzero count.
        freqs: counts matching ``nz_idx``.
        alphabet: number of symbols in the model.
        n: number of symbols to decode.
    """
    histogram = np.zeros(alphabet, dtype=np.int64)
    histogram[nz_idx] = freqs

    weights = histogram.astype(np.float64) + 0.01
    weights /= weights.sum()

    model = c.stream.model.Categorical(weights, perfect=True)
    words = np.frombuffer(cw_bytes, dtype=np.uint32)
    return c.stream.queue.RangeDecoder(words).decode(model, n)
58
+
59
+
60
def encode(raw: bytes) -> tuple[bytes, dict]:
    """Encode a single BF16 tensor's raw bytes.

    Args:
        raw: bytes of the tensor, read as uint16 words (length must be even).

    Returns:
        ``(compressed_blob, extras_dict)``. The extras dict is always empty;
        everything needed to decode is stored inside the blob.

    Raises:
        ValueError: if ``len(raw)`` is odd.

    Blob layout (all integers little-endian):
        u32 weight count
        u16 number of (sign, exp) symbols present, then their u16 indices
            and u32 counts
        u32 (sign, exp) codeword length, then the codeword
        u32 mantissa block length, then the mantissa block:
            u32 number of exponent buckets, then their u16 exponent values
            per bucket: u32 bucket size, u8 number of mantissa symbols present,
                u8 indices, u32 counts, u32 codeword length, codeword
    """
    if len(raw) % 2 != 0:
        raise ValueError(f"BF16 tensor byte length must be even, got {len(raw)}")
    words = np.frombuffer(raw, dtype=np.uint16)
    n = len(words)
    if n == 0:
        return b"", {}

    # Split every word into its three fields.
    sign = ((words >> SIGN_SHIFT) & 1).astype(np.uint16)
    exp = ((words >> EXP_SHIFT) & EXP_MASK).astype(np.uint16)
    mant = (words & MANT_MASK).astype(np.uint16)

    # Joint 9-bit (sign, exp) symbol, coded as one stream.
    se = ((sign << 8) | exp).astype(np.int32)
    se_cw, se_symbols, se_counts = _encode_cat(se, SE_ALPHABET)

    # Group mantissas by exponent; the stable sort lets decode() rebuild the
    # same permutation from the decoded exponents alone.
    order = np.argsort(exp, kind="stable")
    mant_by_exp = mant[order]
    bucket_sizes = np.bincount(exp[order], minlength=256)
    present_exps = np.nonzero(bucket_sizes)[0].astype(np.int32)
    bounds = np.zeros(257, dtype=np.int64)
    bounds[1:] = np.cumsum(bucket_sizes)

    # Mantissa block: header, then one independently coded stream per bucket.
    m_parts = [
        struct.pack("<I", len(present_exps)),
        present_exps.astype(np.uint16).tobytes(),
    ]
    for ev in present_exps:
        lo = bounds[ev]
        hi = bounds[ev + 1]
        cw, symbols, counts = _encode_cat(mant_by_exp[lo:hi].astype(np.int32), MANT_ALPHABET)
        m_parts.extend(
            (
                struct.pack("<IB", hi - lo, len(symbols)),
                symbols.astype(np.uint8).tobytes(),
                counts.astype(np.uint32).tobytes(),
                struct.pack("<I", len(cw)),
                cw,
            )
        )
    m_blob = b"".join(m_parts)

    blob = b"".join(
        (
            struct.pack("<I", n),  # tensor weight count
            struct.pack("<H", len(se_symbols)),  # SE nonzero count
            se_symbols.astype(np.uint16).tobytes(),  # SE nonzero indices
            se_counts.astype(np.uint32).tobytes(),  # SE freqs
            struct.pack("<I", len(se_cw)),  # SE codeword length
            se_cw,
            struct.pack("<I", len(m_blob)),  # M block length
            m_blob,
        )
    )
    return blob, {}
122
+
123
+
124
def decode(blob: bytes, extras: dict, n_weights: int) -> bytes:
    """Decode a BF16 blob produced by ``encode`` back to raw tensor bytes.

    Args:
        blob: compressed blob.
        extras: unused by this codec.
        n_weights: expected number of 16-bit weights.

    Returns:
        The reconstructed uint16 words as bytes; ``b""`` when ``n_weights`` is
        0 or ``blob`` is empty.

    Raises:
        ValueError: if the weight count stored in the blob differs from
            ``n_weights``.
    """
    if n_weights == 0 or len(blob) == 0:
        return b""

    stream = io.BytesIO(blob)
    (n,) = struct.unpack("<I", stream.read(4))
    if n != n_weights:
        raise ValueError(f"BF16 decode: weight count mismatch ({n} vs {n_weights})")
    # n equals n_weights here, and n_weights is nonzero, so n > 0 from now on.

    # (sign, exp) section.
    (se_symbol_count,) = struct.unpack("<H", stream.read(2))
    se_symbols = np.frombuffer(stream.read(se_symbol_count * 2), dtype=np.uint16).astype(np.int32)
    se_counts = np.frombuffer(stream.read(se_symbol_count * 4), dtype=np.uint32)
    (se_cw_len,) = struct.unpack("<I", stream.read(4))
    se_codeword = stream.read(se_cw_len)
    se = _decode_cat(se_codeword, se_symbols, se_counts, SE_ALPHABET, n).astype(np.uint16)
    sign = ((se >> 8) & 1).astype(np.uint16)
    exp = (se & 0xFF).astype(np.uint16)

    # Mantissa section header.
    (m_len,) = struct.unpack("<I", stream.read(4))
    m_stream = io.BytesIO(stream.read(m_len))
    (bucket_count,) = struct.unpack("<I", m_stream.read(4))
    bucket_exps = np.frombuffer(m_stream.read(bucket_count * 2), dtype=np.uint16)

    # Reconstruct sort order from decoded exp (same stable sort as encode()).
    order = np.argsort(exp, kind="stable")
    bounds = np.zeros(257, dtype=np.int64)
    bounds[1:] = np.cumsum(np.bincount(exp.astype(np.int64), minlength=256))

    mant_by_exp = np.empty(n, dtype=np.uint16)
    for ev in bucket_exps:
        bucket_size, symbol_count = struct.unpack("<IB", m_stream.read(5))
        symbols = np.frombuffer(m_stream.read(symbol_count), dtype=np.uint8).astype(np.int32)
        counts = np.frombuffer(m_stream.read(symbol_count * 4), dtype=np.uint32)
        (cw_len,) = struct.unpack("<I", m_stream.read(4))
        codeword = m_stream.read(cw_len)
        decoded = _decode_cat(codeword, symbols, counts, MANT_ALPHABET, bucket_size)
        mant_by_exp[bounds[ev]:bounds[ev + 1]] = decoded.astype(np.uint16)

    # Undo the exponent sort to put mantissas back in tensor order.
    mant = np.empty(n, dtype=np.uint16)
    mant[order] = mant_by_exp

    words = ((sign << SIGN_SHIFT) | (exp << EXP_SHIFT) | mant).astype(np.uint16)
    return words.tobytes()
@@ -0,0 +1,135 @@
1
+ """FP16 codec: per-tensor (sign,exp) joint AC + per-tensor (mantissa | exp) AC.
2
+
3
+ FP16 layout: 1 sign bit | 5 exp bits | 10 mantissa bits = 16 bits total.
4
+
5
+ Same structure as bf16.py with different alphabet sizes.
6
+ """
7
+ import struct
8
+ import io
9
+ import numpy as np
10
+ import constriction as c
11
+
12
# Bit layout for FP16: 1 sign bit | 5 exp bits | 10 mantissa bits.
SIGN_SHIFT = 15  # right-shift that brings the sign bit to bit 0
EXP_SHIFT = 10  # right-shift that brings the exponent field to bit 0
EXP_MASK = 0x1F  # 5 bits
MANT_MASK = 0x3FF  # 10 bits
SE_ALPHABET = 64  # 2 sign * 32 exp
MANT_ALPHABET = 1024  # 10 bits
18
+
19
+
20
def _encode_cat(values: np.ndarray, alphabet: int):
    """Range-encode an integer array under a Categorical model built from its histogram.

    Returns ``(codeword_bytes, nonzero_indices, nonzero_freqs)``; the last two
    are what ``_decode_cat`` needs to rebuild the same model.
    """
    histogram = np.bincount(values, minlength=alphabet).astype(np.int64)
    used_symbols = np.nonzero(histogram)[0]

    # +0.01 on every symbol keeps unseen symbols at a nonzero weight;
    # _decode_cat repeats this exact arithmetic.
    weights = histogram.astype(np.float64) + 0.01
    weights /= weights.sum()

    model = c.stream.model.Categorical(weights, perfect=True)
    coder = c.stream.queue.RangeEncoder()
    coder.encode(values.astype(np.int32), model)
    codeword = coder.get_compressed().tobytes()
    return codeword, used_symbols, histogram[used_symbols].astype(np.int64)
30
+
31
+
32
def _decode_cat(cw_bytes, nz_idx, freqs, alphabet, n):
    """Decode ``n`` symbols from a stream produced by ``_encode_cat``.

    The histogram is rebuilt from ``nz_idx``/``freqs`` and smoothed with the
    same +0.01 the encoder uses before the Categorical model is constructed.
    """
    histogram = np.zeros(alphabet, dtype=np.int64)
    histogram[nz_idx] = freqs

    weights = histogram.astype(np.float64) + 0.01
    weights /= weights.sum()

    model = c.stream.model.Categorical(weights, perfect=True)
    words = np.frombuffer(cw_bytes, dtype=np.uint32)
    return c.stream.queue.RangeDecoder(words).decode(model, n)
41
+
42
+
43
def encode(raw: bytes) -> tuple[bytes, dict]:
    """Encode a single FP16 tensor's raw bytes.

    Args:
        raw: bytes of the tensor, read as uint16 words (length must be even).

    Returns:
        ``(compressed_blob, extras_dict)``; the extras dict is always empty.

    Raises:
        ValueError: if ``len(raw)`` is odd.

    Blob layout (all integers little-endian):
        u32 weight count
        u8 number of (sign, exp) symbols present, then their u8 indices
            and u32 counts
        u32 (sign, exp) codeword length, then the codeword
        u32 mantissa block length, then the mantissa block:
            u32 number of exponent buckets, then their u8 exponent values
            per bucket: u32 bucket size, u16 number of mantissa symbols present,
                u16 indices, u32 counts, u32 codeword length, codeword
    """
    if len(raw) % 2 != 0:
        raise ValueError(f"FP16 tensor byte length must be even, got {len(raw)}")
    words = np.frombuffer(raw, dtype=np.uint16)
    n = len(words)
    if n == 0:
        return b"", {}

    # Split every word into its three fields.
    sign = ((words >> SIGN_SHIFT) & 1).astype(np.uint16)
    exp = ((words >> EXP_SHIFT) & EXP_MASK).astype(np.uint16)
    mant = (words & MANT_MASK).astype(np.uint16)

    # Joint 6-bit (sign, exp) symbol, coded as one stream.
    se = ((sign << 5) | exp).astype(np.int32)
    se_cw, se_symbols, se_counts = _encode_cat(se, SE_ALPHABET)

    # Group mantissas by exponent; the stable sort lets decode() rebuild the
    # same permutation from the decoded exponents alone.
    order = np.argsort(exp, kind="stable")
    mant_by_exp = mant[order]
    bucket_sizes = np.bincount(exp[order], minlength=32)
    present_exps = np.nonzero(bucket_sizes)[0].astype(np.int32)
    bounds = np.zeros(33, dtype=np.int64)
    bounds[1:] = np.cumsum(bucket_sizes)

    # Mantissa block: header, then one independently coded stream per bucket.
    m_parts = [
        struct.pack("<I", len(present_exps)),
        present_exps.astype(np.uint8).tobytes(),
    ]
    for ev in present_exps:
        lo = bounds[ev]
        hi = bounds[ev + 1]
        cw, symbols, counts = _encode_cat(mant_by_exp[lo:hi].astype(np.int32), MANT_ALPHABET)
        m_parts.extend(
            (
                struct.pack("<IH", hi - lo, len(symbols)),
                symbols.astype(np.uint16).tobytes(),
                counts.astype(np.uint32).tobytes(),
                struct.pack("<I", len(cw)),
                cw,
            )
        )
    m_blob = b"".join(m_parts)

    blob = b"".join(
        (
            struct.pack("<I", n),
            struct.pack("<B", len(se_symbols)),
            se_symbols.astype(np.uint8).tobytes(),
            se_counts.astype(np.uint32).tobytes(),
            struct.pack("<I", len(se_cw)),
            se_cw,
            struct.pack("<I", len(m_blob)),
            m_blob,
        )
    )
    return blob, {}
90
+
91
+
92
def decode(blob: bytes, extras: dict, n_weights: int) -> bytes:
    """Decode an FP16 blob produced by ``encode`` back to raw tensor bytes.

    Args:
        blob: compressed blob.
        extras: unused by this codec.
        n_weights: expected number of 16-bit weights.

    Returns:
        The reconstructed uint16 words as bytes; ``b""`` when ``n_weights`` is
        0 or ``blob`` is empty.

    Raises:
        ValueError: if the weight count stored in the blob differs from
            ``n_weights``.
    """
    if n_weights == 0 or len(blob) == 0:
        return b""

    stream = io.BytesIO(blob)
    (n,) = struct.unpack("<I", stream.read(4))
    if n != n_weights:
        raise ValueError(f"FP16 decode: weight count mismatch ({n} vs {n_weights})")
    # n equals n_weights here, and n_weights is nonzero, so n > 0 from now on.

    # (sign, exp) section.
    (se_symbol_count,) = struct.unpack("<B", stream.read(1))
    se_symbols = np.frombuffer(stream.read(se_symbol_count), dtype=np.uint8).astype(np.int32)
    se_counts = np.frombuffer(stream.read(se_symbol_count * 4), dtype=np.uint32)
    (se_cw_len,) = struct.unpack("<I", stream.read(4))
    se_codeword = stream.read(se_cw_len)
    se = _decode_cat(se_codeword, se_symbols, se_counts, SE_ALPHABET, n).astype(np.uint16)
    sign = ((se >> 5) & 1).astype(np.uint16)
    exp = (se & 0x1F).astype(np.uint16)

    # Mantissa section header.
    (m_len,) = struct.unpack("<I", stream.read(4))
    m_stream = io.BytesIO(stream.read(m_len))
    (bucket_count,) = struct.unpack("<I", m_stream.read(4))
    bucket_exps = np.frombuffer(m_stream.read(bucket_count), dtype=np.uint8)

    # Rebuild the exponent sort order with the same stable sort as encode().
    order = np.argsort(exp, kind="stable")
    bounds = np.zeros(33, dtype=np.int64)
    bounds[1:] = np.cumsum(np.bincount(exp.astype(np.int64), minlength=32))

    mant_by_exp = np.empty(n, dtype=np.uint16)
    for ev in bucket_exps:
        bucket_size, symbol_count = struct.unpack("<IH", m_stream.read(6))
        symbols = np.frombuffer(m_stream.read(symbol_count * 2), dtype=np.uint16).astype(np.int32)
        counts = np.frombuffer(m_stream.read(symbol_count * 4), dtype=np.uint32)
        (cw_len,) = struct.unpack("<I", m_stream.read(4))
        codeword = m_stream.read(cw_len)
        decoded = _decode_cat(codeword, symbols, counts, MANT_ALPHABET, bucket_size)
        mant_by_exp[bounds[ev]:bounds[ev + 1]] = decoded.astype(np.uint16)

    # Undo the exponent sort to put mantissas back in tensor order.
    mant = np.empty(n, dtype=np.uint16)
    mant[order] = mant_by_exp

    words = ((sign << SIGN_SHIFT) | (exp << EXP_SHIFT) | mant).astype(np.uint16)
    return words.tobytes()
@@ -0,0 +1,119 @@
1
+ """FP32 codec: per-tensor (sign,exp) joint AC + per-tensor (mantissa | exp) AC.
2
+
3
+ FP32 layout: 1 sign | 8 exp | 23 mantissa = 32 bits total.
4
+
5
+ The 23-bit mantissa cannot be efficiently AC-encoded as a single 8M-symbol
6
+ alphabet, so we encode mantissa per-(exp) bucket using 3 byte-streams
7
+ (low, mid, high) compressed with zstd. This mirrors cc10_v2 byte-transpose
8
+ intuition while staying simple.
9
+
10
+ Strategy per tensor:
11
+ 1. Split into sign(1), exp(8), mant_lo8 + mant_mid8 + mant_hi7.
12
+ 2. Encode (sign, exp) as 9-bit alphabet (512) per tensor with Cat AC.
13
+ 3. Compress mant_lo / mant_mid / mant_hi as three byte streams with zstd L9
14
+ (the high 7 bits of mantissa have lower entropy than the low 16).
15
+ """
16
+ import struct
17
+ import io
18
+ import numpy as np
19
+ import constriction as c
20
+ import zstandard as zstd
21
+
22
+ SE_ALPHABET = 512 # 2 sign * 256 exp
23
+
24
+
25
def _encode_cat(values: np.ndarray, alphabet: int):
    """Range-encode an integer array under a Categorical model built from its histogram.

    Returns ``(codeword_bytes, nonzero_indices, nonzero_freqs)``; the last two
    are what ``_decode_cat`` needs to rebuild the same model.
    """
    histogram = np.bincount(values, minlength=alphabet).astype(np.int64)
    used_symbols = np.nonzero(histogram)[0]

    # +0.01 on every symbol keeps unseen symbols at a nonzero weight;
    # _decode_cat repeats this exact arithmetic.
    weights = histogram.astype(np.float64) + 0.01
    weights /= weights.sum()

    model = c.stream.model.Categorical(weights, perfect=True)
    coder = c.stream.queue.RangeEncoder()
    coder.encode(values.astype(np.int32), model)
    codeword = coder.get_compressed().tobytes()
    return codeword, used_symbols, histogram[used_symbols].astype(np.int64)
35
+
36
+
37
def _decode_cat(cw_bytes, nz_idx, freqs, alphabet, n):
    """Decode ``n`` symbols from a stream produced by ``_encode_cat``.

    The histogram is rebuilt from ``nz_idx``/``freqs`` and smoothed with the
    same +0.01 the encoder uses before the Categorical model is constructed.
    """
    histogram = np.zeros(alphabet, dtype=np.int64)
    histogram[nz_idx] = freqs

    weights = histogram.astype(np.float64) + 0.01
    weights /= weights.sum()

    model = c.stream.model.Categorical(weights, perfect=True)
    words = np.frombuffer(cw_bytes, dtype=np.uint32)
    return c.stream.queue.RangeDecoder(words).decode(model, n)
46
+
47
+
48
def encode(raw: bytes) -> tuple[bytes, dict]:
    """Encode a single FP32 tensor's raw bytes.

    Args:
        raw: bytes of the tensor, read as uint32 words (length must be a
            multiple of 4).

    Returns:
        ``(compressed_blob, extras_dict)``; the extras dict is always empty.

    Raises:
        ValueError: if ``len(raw)`` is not a multiple of 4.

    Blob layout (all integers little-endian):
        u32 weight count
        u16 number of (sign, exp) symbols present, then their u16 indices
            and u32 counts
        u32 (sign, exp) codeword length, then the codeword
        three zstd frames, each preceded by its u32 length: mantissa bits
            0-7, bits 8-15, and bits 16-22
    """
    if len(raw) % 4 != 0:
        raise ValueError(f"FP32 tensor byte length must be % 4 == 0, got {len(raw)}")
    words = np.frombuffer(raw, dtype=np.uint32)
    n = len(words)
    if n == 0:
        return b"", {}

    sign = ((words >> 31) & 1).astype(np.uint16)
    exp = ((words >> 23) & 0xFF).astype(np.uint16)
    mant = (words & 0x7FFFFF).astype(np.uint32)

    # Joint 9-bit (sign, exp) symbol, coded as one stream.
    se = ((sign << 8) | exp).astype(np.int32)
    se_cw, se_symbols, se_counts = _encode_cat(se, SE_ALPHABET)

    parts = [
        struct.pack("<I", n),
        struct.pack("<H", len(se_symbols)),
        se_symbols.astype(np.uint16).tobytes(),
        se_counts.astype(np.uint32).tobytes(),
        struct.pack("<I", len(se_cw)),
        se_cw,
    ]

    # Split the 23-bit mantissa into low 8, mid 8 and high 7 bits; each plane
    # is zstd-compressed on its own, in that order.
    byte_planes = (
        (mant & 0xFF).astype(np.uint8),
        ((mant >> 8) & 0xFF).astype(np.uint8),
        ((mant >> 16) & 0x7F).astype(np.uint8),
    )
    compressor = zstd.ZstdCompressor(level=9)
    for plane in byte_planes:
        frame = compressor.compress(plane.tobytes())
        parts.append(struct.pack("<I", len(frame)))
        parts.append(frame)

    return b"".join(parts), {}
88
+
89
+
90
def decode(blob: bytes, extras: dict, n_weights: int) -> bytes:
    """Decode an FP32 blob produced by ``encode`` back to raw tensor bytes.

    Args:
        blob: compressed blob.
        extras: unused by this codec.
        n_weights: expected number of 32-bit weights.

    Returns:
        The reconstructed uint32 words as bytes; ``b""`` when ``n_weights`` is
        0 or ``blob`` is empty.

    Raises:
        ValueError: if the weight count stored in the blob differs from
            ``n_weights``, or if a decompressed mantissa plane does not hold
            exactly one byte per weight.
    """
    if n_weights == 0 or len(blob) == 0:
        return b""
    inp = io.BytesIO(blob)
    n, = struct.unpack("<I", inp.read(4))
    if n != n_weights:
        raise ValueError(f"FP32 decode: weight count mismatch ({n} vs {n_weights})")
    if n == 0:
        return b""

    # (sign, exp) section.
    se_n_nz, = struct.unpack("<H", inp.read(2))
    se_nz_idx = np.frombuffer(inp.read(se_n_nz * 2), dtype=np.uint16).astype(np.int32)
    se_freqs = np.frombuffer(inp.read(se_n_nz * 4), dtype=np.uint32)
    se_cw_len, = struct.unpack("<I", inp.read(4))
    se_cw = inp.read(se_cw_len)
    se = _decode_cat(se_cw, se_nz_idx, se_freqs, SE_ALPHABET, n).astype(np.uint32)
    sign = ((se >> 8) & 1).astype(np.uint32)
    exp = (se & 0xFF).astype(np.uint32)

    # Three length-prefixed zstd frames: low, mid and high mantissa bytes.
    dctx = zstd.ZstdDecompressor()
    planes = []
    for label in ("low", "mid", "high"):
        frame_len, = struct.unpack("<I", inp.read(4))
        plane = np.frombuffer(dctx.decompress(inp.read(frame_len)), dtype=np.uint8)
        # Without this check a corrupt plane of length 1 would be broadcast
        # across all n weights and silently produce wrong data.
        if len(plane) != n:
            raise ValueError(
                f"FP32 decode: {label} mantissa plane has {len(plane)} bytes, expected {n}"
            )
        planes.append(plane.astype(np.uint32))
    mant_lo, mant_mid, mant_hi = planes
    mant = mant_lo | (mant_mid << 8) | (mant_hi << 16)

    out = ((sign << 31) | (exp << 23) | mant).astype(np.uint32)
    return out.tobytes()
bigsmall/codecs/fp4.py ADDED
@@ -0,0 +1,76 @@
1
+ """FP4 codec: per-tensor Categorical AC on 4-bit indices (alphabet 16).
2
+
3
+ The codec operates on UNPACKED FP4 (one byte per 4-bit value). The encoder
4
+ takes raw bytes representing already-unpacked 4-bit indices in the low nibble.
5
+ """
6
+ import struct
7
+ import io
8
+ import numpy as np
9
+ import constriction as c
10
+
11
+
12
def encode(raw: bytes) -> tuple[bytes, dict]:
    """Encode an FP4 stream with a 16-symbol Categorical range coder.

    Input is normally one 4-bit value per byte (low nibble). If any byte
    exceeds 15 the input is treated as packed (two values per byte) and is
    unpacked first, low nibble before high nibble.

    Returns:
        ``(blob, extras)`` where ``extras["was_packed"]`` records whether the
        input was unpacked here; empty input yields ``(b"", {})``.

    Blob layout (little-endian): u32 symbol count, u8 number of symbols
    present, u8 packed flag, u8 indices, u32 counts, u32 codeword length,
    codeword.
    """
    if len(raw) == 0:
        return b"", {}

    nibbles = np.frombuffer(raw, dtype=np.uint8)
    # Defensive: any value above 15 means the caller passed packed FP4.
    was_packed = bool(nibbles.max() > 15)
    if was_packed:
        # Interleave (low, high) per byte: shape (len, 2) flattened row-wise.
        nibbles = np.stack((nibbles & 0x0F, (nibbles >> 4) & 0x0F), axis=1).reshape(-1)

    n = len(nibbles)
    symbols = nibbles.astype(np.int32)
    histogram = np.bincount(symbols, minlength=16).astype(np.int64)
    present = np.nonzero(histogram)[0]

    # +0.01 smoothing; decode() repeats this exact arithmetic.
    weights = histogram.astype(np.float64) + 0.01
    weights /= weights.sum()

    model = c.stream.model.Categorical(weights, perfect=True)
    coder = c.stream.queue.RangeEncoder()
    coder.encode(symbols, model)
    codeword = coder.get_compressed().tobytes()

    blob = b"".join(
        (
            struct.pack("<IBB", n, len(present), 1 if was_packed else 0),
            present.astype(np.uint8).tobytes(),
            histogram[present].astype(np.uint32).tobytes(),
            struct.pack("<I", len(codeword)),
            codeword,
        )
    )
    return blob, {"was_packed": was_packed}
46
+
47
+
48
def decode(blob: bytes, extras: dict, n_weights: int) -> bytes:
    """Decode an FP4 blob produced by ``encode``.

    Returns one byte per 4-bit value (low nibble), unless the blob's packed
    flag is set, in which case pairs of values are re-packed into single bytes
    (first value in the low nibble, second in the high nibble).

    Note: ``n_weights`` is only used for the empty short-circuit; the symbol
    count actually decoded is the one stored in the blob. ``extras`` is not
    read - the packed flag comes from the blob header.

    Raises:
        ValueError: if the packed flag is set and the stored count is odd.
    """
    if n_weights == 0 or len(blob) == 0:
        return b""

    stream = io.BytesIO(blob)
    n, symbol_count, packed_flag = struct.unpack("<IBB", stream.read(6))
    present = np.frombuffer(stream.read(symbol_count), dtype=np.uint8)
    counts = np.frombuffer(stream.read(symbol_count * 4), dtype=np.uint32)
    (cw_len,) = struct.unpack("<I", stream.read(4))
    codeword = stream.read(cw_len)

    # Rebuild the model exactly as encode() built it.
    histogram = np.zeros(16, dtype=np.int64)
    histogram[present] = counts
    weights = histogram.astype(np.float64) + 0.01
    weights /= weights.sum()
    model = c.stream.model.Categorical(weights, perfect=True)

    words = np.frombuffer(codeword, dtype=np.uint32)
    nibbles = c.stream.queue.RangeDecoder(words).decode(model, n).astype(np.uint8)

    if not packed_flag:
        return nibbles.tobytes()
    if n % 2 != 0:
        raise ValueError("FP4 packed decode requires even unpacked count")
    low = nibbles[0::2]
    high = nibbles[1::2]
    return (low | (high << 4)).astype(np.uint8).tobytes()