bigsmall 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bigsmall/__init__.py +45 -0
- bigsmall/cli.py +124 -0
- bigsmall/codecs/__init__.py +19 -0
- bigsmall/codecs/bf16.py +170 -0
- bigsmall/codecs/fp16.py +135 -0
- bigsmall/codecs/fp32.py +119 -0
- bigsmall/codecs/fp4.py +76 -0
- bigsmall/codecs/fp8.py +56 -0
- bigsmall/codecs/generic.py +44 -0
- bigsmall/codecs/special.py +156 -0
- bigsmall/container.py +110 -0
- bigsmall/decoder.py +243 -0
- bigsmall/delta.py +32 -0
- bigsmall/encoder.py +413 -0
- bigsmall/formats.py +59 -0
- bigsmall/hub.py +286 -0
- bigsmall/hub_index.py +141 -0
- bigsmall/integrations/__init__.py +1 -0
- bigsmall/integrations/diffusion.py +99 -0
- bigsmall/integrations/huggingface.py +99 -0
- bigsmall/integrations/vllm.py +151 -0
- bigsmall/streaming.py +237 -0
- bigsmall/streaming_model.py +176 -0
- bigsmall/tensor_analysis.py +143 -0
- bigsmall/verify.py +50 -0
- bigsmall-1.0.0.dist-info/METADATA +234 -0
- bigsmall-1.0.0.dist-info/RECORD +31 -0
- bigsmall-1.0.0.dist-info/WHEEL +5 -0
- bigsmall-1.0.0.dist-info/entry_points.txt +2 -0
- bigsmall-1.0.0.dist-info/licenses/LICENSE +201 -0
- bigsmall-1.0.0.dist-info/top_level.txt +1 -0
bigsmall/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""BigSmall - Lossless neural network weight compression.
|
|
2
|
+
|
|
3
|
+
Public API:
|
|
4
|
+
bigsmall.compress(src, dst, mode="balanced") - compress safetensors -> .bs
|
|
5
|
+
bigsmall.decompress(src, dst=None) -> dict[str, ndarray]
|
|
6
|
+
bigsmall.load(src, device="cpu") -> dict[str, torch.Tensor]
|
|
7
|
+
bigsmall.info(src) -> dict
|
|
8
|
+
bigsmall.verify(src) -> bool
|
|
9
|
+
bigsmall.compress_delta(finetune, base, dst, mode="balanced")
|
|
10
|
+
bigsmall.decompress_delta(delta_src, base_src, dst=None) -> dict[str, ndarray]
|
|
11
|
+
|
|
12
|
+
HuggingFace Hub round-trip (Phase 4):
|
|
13
|
+
bigsmall.compress_for_hub(source, output_dir) - compress any HF model
|
|
14
|
+
bigsmall.upload_to_hub(output_dir, repo_id) - push to the Hub
|
|
15
|
+
bigsmall.from_pretrained(repo_or_path) - download + decompress -> state_dict
|
|
16
|
+
bigsmall.install_hook() - monkey-patch safetensors.load_file
|
|
17
|
+
|
|
18
|
+
Streaming loader (Phase 4 cont.):
|
|
19
|
+
bigsmall.StreamingLoader(path, device="cuda") - layer-by-layer decompression
|
|
20
|
+
"""
|
|
21
|
+
__version__ = "1.0.0"
|
|
22
|
+
|
|
23
|
+
from .encoder import compress, compress_delta
|
|
24
|
+
from .decoder import decompress, decompress_delta, load
|
|
25
|
+
from .verify import verify
|
|
26
|
+
from .container import info
|
|
27
|
+
from .hub import compress_for_hub, upload_to_hub, from_pretrained
|
|
28
|
+
from .integrations.huggingface import install_hook
|
|
29
|
+
from .streaming import StreamingLoader
|
|
30
|
+
|
|
31
|
+
__all__ = [
|
|
32
|
+
"compress",
|
|
33
|
+
"decompress",
|
|
34
|
+
"load",
|
|
35
|
+
"info",
|
|
36
|
+
"verify",
|
|
37
|
+
"compress_delta",
|
|
38
|
+
"decompress_delta",
|
|
39
|
+
"compress_for_hub",
|
|
40
|
+
"upload_to_hub",
|
|
41
|
+
"from_pretrained",
|
|
42
|
+
"install_hook",
|
|
43
|
+
"StreamingLoader",
|
|
44
|
+
"__version__",
|
|
45
|
+
]
|
bigsmall/cli.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""BigSmall command-line interface."""
|
|
2
|
+
import argparse
|
|
3
|
+
import sys
|
|
4
|
+
import time
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _cmd_compress(args):
    """Run the ``compress`` sub-command and print a size/time summary.

    Args:
        args: Parsed namespace from the ``compress`` sub-parser. Reads
            ``src``, ``output``, ``base``, ``storage`` and ``inference``.

    The destination defaults to the source path with its suffix replaced by
    ``.bs``. The mode is ``"storage"`` or ``"inference"`` when the matching
    flag is set and ``"balanced"`` otherwise. When ``args.base`` is given the
    source is delta-compressed against that base file.

    Raises:
        SystemExit: If the destination resolves to the source file itself
            (e.g. ``bigsmall compress model.bs`` without ``-o``), because
            writing the output would overwrite the input.
    """
    src = Path(args.src)
    dst = Path(args.output) if args.output else src.with_suffix(".bs")
    if dst.resolve() == src.resolve():
        sys.exit(
            f"error: output {dst} would overwrite the source file; "
            "choose another path with -o"
        )

    # Deferred import, matching the other sub-command handlers.
    from . import encoder

    if args.storage:
        mode = "storage"
    elif args.inference:
        mode = "inference"
    else:
        mode = "balanced"

    t0 = time.perf_counter()
    if args.base:
        encoder.compress_delta(src, args.base, dst, mode=mode)
    else:
        encoder.compress(src, dst, mode=mode)
    elapsed = time.perf_counter() - t0

    src_size = src.stat().st_size
    dst_size = dst.stat().st_size
    # Guard the ratio against an empty source file.
    pct = (dst_size / src_size * 100) if src_size > 0 else 0
    print(f"compressed {src} -> {dst}", flush=True)
    print(f" source: {src_size:,} bytes", flush=True)
    print(f" compressed: {dst_size:,} bytes ({pct:.2f}%)", flush=True)
    print(f" saved: {src_size - dst_size:,} bytes", flush=True)
    print(f" elapsed: {elapsed:.1f}s", flush=True)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _cmd_decompress(args):
    """Run the ``decompress`` sub-command and report the elapsed time.

    Args:
        args: Parsed namespace from the ``decompress`` sub-parser. Reads
            ``src``, ``output`` and ``base``.

    The destination defaults to the source path with its suffix replaced by
    ``.safetensors``. A truthy ``args.base`` selects delta decompression
    against that base file.
    """
    from . import decoder

    src = Path(args.src)
    dst = Path(args.output) if args.output else src.with_suffix(".safetensors")

    started = time.perf_counter()
    if not args.base:
        decoder.decompress(src, dst)
    else:
        decoder.decompress_delta(src, args.base, dst)
    elapsed = time.perf_counter() - started

    print(f"decompressed {src} -> {dst} ({elapsed:.1f}s)", flush=True)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _cmd_info(args):
    """Run the ``info`` sub-command: print one ``key value`` line per metadata entry.

    Args:
        args: Parsed namespace from the ``info`` sub-parser; reads ``src``.
    """
    from .container import info

    metadata = info(args.src)
    for key, value in metadata.items():
        # Keys are left-aligned in a 24-character column.
        print(f" {key:24s} {value}")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _cmd_verify(args):
    """Run the ``verify`` sub-command and exit with a status reflecting the result.

    Args:
        args: Parsed namespace from the ``verify`` sub-parser; reads ``src``
            and ``source``.

    Prints ``OK`` and exits with status 0 when verification succeeds,
    otherwise prints ``FAIL`` and exits with status 1.
    """
    from .verify import verify

    if not verify(args.src, source_safetensors=args.source):
        print("FAIL", flush=True)
        sys.exit(1)
    print("OK", flush=True)
    sys.exit(0)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _cmd_benchmark(args):
    """Run the ``benchmark`` sub-command: time one compress and one decompress pass.

    Args:
        args: Parsed namespace from the ``benchmark`` sub-parser; reads ``src``.

    The compressed file is written next to the source, with the suffix
    replaced by ``.bs``. Both throughput figures are computed relative to the
    source file size.

    Raises:
        SystemExit: If the derived ``.bs`` path resolves to the source file
            itself, because the benchmark would overwrite its own input.
    """
    src = Path(args.src)
    dst = src.with_suffix(".bs")
    if dst.resolve() == src.resolve():
        sys.exit(
            f"error: benchmark output {dst} would overwrite the source file"
        )

    # Deferred import, matching the other sub-command handlers.
    from . import encoder, decoder

    print(f"Benchmarking {src.name}...")

    t0 = time.perf_counter()
    encoder.compress(src, dst)
    te = time.perf_counter() - t0

    t0 = time.perf_counter()
    decoder.decompress(dst)
    td = time.perf_counter() - t0

    src_size = src.stat().st_size
    dst_size = dst.stat().st_size
    pct = (dst_size / src_size * 100) if src_size > 0 else 0

    def _mib_per_s(elapsed):
        # A zero reading from the timer must not crash the report.
        if elapsed <= 0:
            return float("inf")
        return src_size / elapsed / 1024 / 1024

    print(f" encode: {te:.1f}s ({_mib_per_s(te):.1f} MiB/s)")
    print(f" decode: {td:.1f}s ({_mib_per_s(td):.1f} MiB/s)")
    print(f" ratio: {pct:.2f}% ({src_size:,} -> {dst_size:,})")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def main(argv=None):
    """Parse command-line arguments and dispatch to the selected sub-command.

    Args:
        argv: Argument list handed to ``ArgumentParser.parse_args``; the
            default ``None`` lets argparse read the process arguments.

    Sub-commands: ``compress``, ``decompress``, ``info``, ``verify`` and
    ``benchmark``. Each registers its handler through ``set_defaults(func=...)``.
    """
    parser = argparse.ArgumentParser(
        prog="bigsmall",
        description="BigSmall lossless NN weight compression",
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    # compress
    compress_cmd = commands.add_parser("compress", help="Compress a .safetensors file to .bs")
    compress_cmd.add_argument("src")
    compress_cmd.add_argument("-o", "--output", default=None)
    compress_cmd.add_argument("--base", default=None, help="Base safetensors path - enables delta mode")
    mode_flags = compress_cmd.add_mutually_exclusive_group()
    for flag, text in (
        ("--storage", "Maximum compression mode"),
        ("--balanced", "Balanced ratio+speed (default)"),
        ("--inference", "Fastest decode mode"),
    ):
        mode_flags.add_argument(flag, action="store_true", help=text)
    compress_cmd.set_defaults(func=_cmd_compress)

    # decompress
    decompress_cmd = commands.add_parser("decompress", help="Decompress a .bs file to .safetensors")
    decompress_cmd.add_argument("src")
    decompress_cmd.add_argument("-o", "--output", default=None)
    decompress_cmd.add_argument("--base", default=None, help="Base file path for delta decompression")
    decompress_cmd.set_defaults(func=_cmd_decompress)

    # info
    info_cmd = commands.add_parser("info", help="Show metadata for a .bs file")
    info_cmd.add_argument("src")
    info_cmd.set_defaults(func=_cmd_info)

    # verify
    verify_cmd = commands.add_parser("verify", help="Verify md5 round-trip of a .bs file")
    verify_cmd.add_argument("src")
    verify_cmd.add_argument("--source", default=None, help="Compare against original .safetensors")
    verify_cmd.set_defaults(func=_cmd_verify)

    # benchmark
    benchmark_cmd = commands.add_parser("benchmark", help="Encode/decode benchmark for a model")
    benchmark_cmd.add_argument("src")
    benchmark_cmd.set_defaults(func=_cmd_benchmark)

    parsed = parser.parse_args(argv)
    parsed.func(parsed)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
if __name__ == "__main__":
|
|
124
|
+
main()
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""BigSmall codecs - per-format encoder/decoder implementations."""
|
|
2
|
+
from . import bf16, fp32, fp16, fp8, fp4, special, generic
|
|
3
|
+
|
|
4
|
+
CODEC_REGISTRY = {
|
|
5
|
+
"bf16_se_ac": bf16,
|
|
6
|
+
"fp16_se_ac": fp16,
|
|
7
|
+
"fp32_se_ac": fp32,
|
|
8
|
+
"fp8_cat_ac": fp8,
|
|
9
|
+
"fp4_cat_ac": fp4,
|
|
10
|
+
"special": special,
|
|
11
|
+
"zstd": generic,
|
|
12
|
+
"blosc2_shuffle_zstd": generic,
|
|
13
|
+
}
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_codec(name):
    """Return the codec module registered under *name*.

    Raises:
        KeyError: If *name* is not a key of ``CODEC_REGISTRY``.
    """
    codec = CODEC_REGISTRY.get(name)
    # Registry values are modules, so None can only mean "not registered".
    if codec is None:
        raise KeyError(f"Unknown codec: {name}")
    return codec
|
bigsmall/codecs/bf16.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
"""BF16 codec: per-tensor (sign,exp) joint AC + per-tensor (mantissa | exp) AC.
|
|
2
|
+
|
|
3
|
+
BF16 layout: 1 sign bit | 8 exp bits | 7 mantissa bits = 16 bits total.
|
|
4
|
+
|
|
5
|
+
Encoding strategy (mirrors cc10_v2.py but operates per-tensor so it works on
|
|
6
|
+
ANY safetensors model - no GPT-2-specific handling here):
|
|
7
|
+
|
|
8
|
+
1. Split each weight into (sign, exp, mantissa).
|
|
9
|
+
2. Encode (sign, exp) jointly as a 9-bit alphabet per tensor with Categorical AC.
|
|
10
|
+
3. Encode mantissa per-tensor, sorted by exp, with one Cat AC bucket per exp value.
|
|
11
|
+
|
|
12
|
+
This per-tensor approach generalises cleanly. The cc10_v2 script encoded the
|
|
13
|
+
mantissa stream globally across all non-special tensors; that gives a tiny
|
|
14
|
+
extra ratio improvement on GPT-2 but requires a fixed 'special tensor set' to
|
|
15
|
+
be known across the encode/decode path. Per-tensor encoding is simpler,
|
|
16
|
+
generalises, and the per-tensor overhead is negligible relative to the AC
|
|
17
|
+
codeword cost.
|
|
18
|
+
"""
|
|
19
|
+
import struct
|
|
20
|
+
import io
|
|
21
|
+
import numpy as np
|
|
22
|
+
import constriction as c
|
|
23
|
+
|
|
24
|
+
# Bit layout for BF16
|
|
25
|
+
SIGN_SHIFT = 15
|
|
26
|
+
EXP_SHIFT = 7
|
|
27
|
+
EXP_MASK = 0xFF
|
|
28
|
+
MANT_MASK = 0x7F
|
|
29
|
+
SE_ALPHABET = 512 # 2 sign * 256 exp
|
|
30
|
+
MANT_ALPHABET = 128 # 7 mantissa bits
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _encode_cat(values: np.ndarray, alphabet: int) -> tuple[bytes, np.ndarray, np.ndarray]:
    """Range-encode an integer symbol array under its own empirical distribution.

    Args:
        values: Non-negative integer symbols, each below *alphabet*.
        alphabet: Number of distinct symbols the model covers.

    Returns:
        ``(codeword_bytes, nonzero_indices, nonzero_freqs)`` - the compressed
        stream plus the symbols that occur and their counts. ``_decode_cat``
        rebuilds the same model from the latter two.
    """
    counts = np.bincount(values, minlength=alphabet).astype(np.int64)
    present = np.nonzero(counts)[0]

    # 0.01 is added to every count so no symbol gets probability zero;
    # _decode_cat applies the identical smoothing.
    weights = counts.astype(np.float64) + 0.01
    weights /= weights.sum()
    model = c.stream.model.Categorical(weights, perfect=True)

    coder = c.stream.queue.RangeEncoder()
    coder.encode(values.astype(np.int32), model)
    payload = coder.get_compressed().tobytes()
    return payload, present, counts[present].astype(np.int64)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _decode_cat(cw_bytes: bytes, nz_idx: np.ndarray, freqs: np.ndarray, alphabet: int, n: int) -> np.ndarray:
    """Decode *n* symbols from a stream produced by ``_encode_cat``.

    Args:
        cw_bytes: Compressed stream, interpreted as an array of uint32 words.
        nz_idx: Symbols that occurred in the encoded data.
        freqs: Occurrence count for each entry of *nz_idx*.
        alphabet: Number of distinct symbols the model covers.
        n: Number of symbols to decode.

    Returns:
        The decoded symbol array.
    """
    counts = np.zeros(alphabet, dtype=np.int64)
    counts[nz_idx] = freqs

    # Same 0.01 smoothing and normalisation as the encoder, so both sides
    # build the model from identical inputs.
    weights = counts.astype(np.float64) + 0.01
    weights /= weights.sum()
    model = c.stream.model.Categorical(weights, perfect=True)

    words = np.frombuffer(cw_bytes, dtype=np.uint32)
    return c.stream.queue.RangeDecoder(words).decode(model, n)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def encode(raw: bytes) -> tuple[bytes, dict]:
    """Compress the raw bytes of one BF16 tensor.

    Args:
        raw: Tensor bytes, read as uint16 words; the length must be even.

    Returns:
        ``(compressed_blob, extras_dict)``. The extras dict is always empty
        because everything the decoder needs is stored in the blob. Empty
        input yields ``(b"", {})``.

    Raises:
        ValueError: If ``len(raw)`` is odd.

    Blob layout (little-endian):
        u32 weight count | u16 SE nonzero count | u16[] SE symbols |
        u32[] SE freqs | u32 SE codeword length | SE codeword |
        u32 mantissa block length | mantissa block

    Mantissa block:
        u32 bucket count | u16[] exponent values, then per bucket:
        u32 size | u8 nonzero count | u8[] symbols | u32[] freqs |
        u32 codeword length | codeword
    """
    if len(raw) % 2 != 0:
        raise ValueError(f"BF16 tensor byte length must be even, got {len(raw)}")
    words = np.frombuffer(raw, dtype=np.uint16)
    n = len(words)
    if n == 0:
        return b"", {}

    # Split every weight into its three bit fields.
    sign = ((words >> SIGN_SHIFT) & 1).astype(np.uint16)
    exp = ((words >> EXP_SHIFT) & EXP_MASK).astype(np.uint16)
    mant = (words & MANT_MASK).astype(np.uint16)

    # Sign and exponent are coded jointly as one 9-bit symbol.
    se_symbols = ((sign << 8) | exp).astype(np.int32)
    se_cw, se_nz_idx, se_freqs = _encode_cat(se_symbols, SE_ALPHABET)

    # Group mantissas by exponent. The stable sort lets the decoder rebuild
    # the same permutation from the decoded exponents alone.
    by_exp = np.argsort(exp, kind="stable")
    mant_sorted = mant[by_exp]
    counts = np.bincount(exp[by_exp], minlength=256)
    nonzero_exps = np.nonzero(counts)[0].astype(np.int32)
    edges = np.zeros(257, dtype=np.int64)
    edges[1:] = np.cumsum(counts)

    # One independently modelled stream per exponent value that occurs.
    m_parts = [
        struct.pack("<I", len(nonzero_exps)),
        nonzero_exps.astype(np.uint16).tobytes(),
    ]
    for lo, hi in zip(edges[nonzero_exps], edges[nonzero_exps + 1]):
        cw, nz_idx, freqs = _encode_cat(mant_sorted[lo:hi].astype(np.int32), MANT_ALPHABET)
        m_parts += [
            struct.pack("<IB", hi - lo, len(nz_idx)),
            nz_idx.astype(np.uint8).tobytes(),
            freqs.astype(np.uint32).tobytes(),
            struct.pack("<I", len(cw)),
            cw,
        ]
    m_blob = b"".join(m_parts)

    blob = b"".join([
        struct.pack("<I", n),                    # tensor weight count
        struct.pack("<H", len(se_nz_idx)),       # SE nonzero count
        se_nz_idx.astype(np.uint16).tobytes(),   # SE nonzero indices
        se_freqs.astype(np.uint32).tobytes(),    # SE freqs
        struct.pack("<I", len(se_cw)),           # SE codeword length
        se_cw,
        struct.pack("<I", len(m_blob)),          # M block length
        m_blob,
    ])
    return blob, {}
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def decode(blob: bytes, extras: dict, n_weights: int) -> bytes:
    """Rebuild the raw BF16 tensor bytes from a blob written by ``encode``.

    Args:
        blob: Compressed blob.
        extras: Not read by this codec.
        n_weights: Expected number of 16-bit weights.

    Returns:
        The tensor bytes, or ``b""`` when *n_weights* is 0 or *blob* is empty.

    Raises:
        ValueError: If the weight count stored in the blob differs from
            *n_weights*.
    """
    if n_weights == 0 or len(blob) == 0:
        return b""

    def take(stream, fmt):
        # Read exactly the bytes that *fmt* describes and unpack them.
        return struct.unpack(fmt, stream.read(struct.calcsize(fmt)))

    stream = io.BytesIO(blob)
    (n,) = take(stream, "<I")
    if n != n_weights:
        raise ValueError(f"BF16 decode: weight count mismatch ({n} vs {n_weights})")
    if n == 0:
        return b""

    # Joint (sign, exp) stream.
    (se_n_nz,) = take(stream, "<H")
    se_nz_idx = np.frombuffer(stream.read(se_n_nz * 2), dtype=np.uint16).astype(np.int32)
    se_freqs = np.frombuffer(stream.read(se_n_nz * 4), dtype=np.uint32)
    (se_cw_len,) = take(stream, "<I")
    se_cw = stream.read(se_cw_len)
    se = _decode_cat(se_cw, se_nz_idx, se_freqs, SE_ALPHABET, n).astype(np.uint16)
    sign = ((se >> 8) & 1).astype(np.uint16)
    exp = (se & 0xFF).astype(np.uint16)

    # Mantissa block header.
    (m_blob_len,) = take(stream, "<I")
    m_stream = io.BytesIO(stream.read(m_blob_len))
    (n_nz_exp,) = take(m_stream, "<I")
    nonzero_exps = np.frombuffer(m_stream.read(n_nz_exp * 2), dtype=np.uint16)

    # Re-derive the encoder's stable sort order and bucket edges from the
    # decoded exponents.
    by_exp = np.argsort(exp, kind="stable")
    counts = np.bincount(exp.astype(np.int64), minlength=256)
    edges = np.zeros(257, dtype=np.int64)
    edges[1:] = np.cumsum(counts)

    mant_sorted = np.empty(n, dtype=np.uint16)
    for ev in nonzero_exps:
        nb, n_nz = take(m_stream, "<IB")
        nz_idx = np.frombuffer(m_stream.read(n_nz), dtype=np.uint8).astype(np.int32)
        freqs = np.frombuffer(m_stream.read(n_nz * 4), dtype=np.uint32)
        (cw_len,) = take(m_stream, "<I")
        cw = m_stream.read(cw_len)
        bucket = _decode_cat(cw, nz_idx, freqs, MANT_ALPHABET, nb)
        mant_sorted[edges[ev]:edges[ev + 1]] = bucket.astype(np.uint16)

    # Undo the sort so each mantissa returns to its original position.
    mant = np.empty(n, dtype=np.uint16)
    mant[by_exp] = mant_sorted

    words = ((sign << SIGN_SHIFT) | (exp << EXP_SHIFT) | mant).astype(np.uint16)
    return words.tobytes()
|
bigsmall/codecs/fp16.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""FP16 codec: per-tensor (sign,exp) joint AC + per-tensor (mantissa | exp) AC.
|
|
2
|
+
|
|
3
|
+
FP16 layout: 1 sign bit | 5 exp bits | 10 mantissa bits = 16 bits total.
|
|
4
|
+
|
|
5
|
+
Same structure as bf16.py with different alphabet sizes.
|
|
6
|
+
"""
|
|
7
|
+
import struct
|
|
8
|
+
import io
|
|
9
|
+
import numpy as np
|
|
10
|
+
import constriction as c
|
|
11
|
+
|
|
12
|
+
SIGN_SHIFT = 15
|
|
13
|
+
EXP_SHIFT = 10
|
|
14
|
+
EXP_MASK = 0x1F # 5 bits
|
|
15
|
+
MANT_MASK = 0x3FF # 10 bits
|
|
16
|
+
SE_ALPHABET = 64 # 2 sign * 32 exp
|
|
17
|
+
MANT_ALPHABET = 1024 # 10 bits
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _encode_cat(values: np.ndarray, alphabet: int):
    """Range-encode an integer symbol array under its own empirical distribution.

    Args:
        values: Non-negative integer symbols, each below *alphabet*.
        alphabet: Number of distinct symbols the model covers.

    Returns:
        ``(codeword_bytes, nonzero_indices, nonzero_freqs)`` - the compressed
        stream plus the symbols that occur and their counts. ``_decode_cat``
        rebuilds the same model from the latter two.
    """
    counts = np.bincount(values, minlength=alphabet).astype(np.int64)
    present = np.nonzero(counts)[0]

    # 0.01 is added to every count so no symbol gets probability zero;
    # _decode_cat applies the identical smoothing.
    weights = counts.astype(np.float64) + 0.01
    weights /= weights.sum()
    model = c.stream.model.Categorical(weights, perfect=True)

    coder = c.stream.queue.RangeEncoder()
    coder.encode(values.astype(np.int32), model)
    payload = coder.get_compressed().tobytes()
    return payload, present, counts[present].astype(np.int64)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _decode_cat(cw_bytes, nz_idx, freqs, alphabet, n):
    """Decode *n* symbols from a stream produced by ``_encode_cat``.

    Args:
        cw_bytes: Compressed stream, interpreted as an array of uint32 words.
        nz_idx: Symbols that occurred in the encoded data.
        freqs: Occurrence count for each entry of *nz_idx*.
        alphabet: Number of distinct symbols the model covers.
        n: Number of symbols to decode.

    Returns:
        The decoded symbol array.
    """
    counts = np.zeros(alphabet, dtype=np.int64)
    counts[nz_idx] = freqs

    # Same 0.01 smoothing and normalisation as the encoder, so both sides
    # build the model from identical inputs.
    weights = counts.astype(np.float64) + 0.01
    weights /= weights.sum()
    model = c.stream.model.Categorical(weights, perfect=True)

    words = np.frombuffer(cw_bytes, dtype=np.uint32)
    return c.stream.queue.RangeDecoder(words).decode(model, n)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def encode(raw: bytes) -> tuple[bytes, dict]:
    """Compress the raw bytes of one FP16 tensor.

    Args:
        raw: Tensor bytes, read as uint16 words; the length must be even.

    Returns:
        ``(compressed_blob, extras_dict)`` with an always-empty extras dict.
        Empty input yields ``(b"", {})``.

    Raises:
        ValueError: If ``len(raw)`` is odd.

    Blob layout (little-endian):
        u32 weight count | u8 SE nonzero count | u8[] SE symbols |
        u32[] SE freqs | u32 SE codeword length | SE codeword |
        u32 mantissa block length | mantissa block

    Mantissa block:
        u32 bucket count | u8[] exponent values, then per bucket:
        u32 size | u16 nonzero count | u16[] symbols | u32[] freqs |
        u32 codeword length | codeword
    """
    if len(raw) % 2 != 0:
        raise ValueError(f"FP16 tensor byte length must be even, got {len(raw)}")
    words = np.frombuffer(raw, dtype=np.uint16)
    n = len(words)
    if n == 0:
        return b"", {}

    # Split every weight into its three bit fields.
    sign = ((words >> SIGN_SHIFT) & 1).astype(np.uint16)
    exp = ((words >> EXP_SHIFT) & EXP_MASK).astype(np.uint16)
    mant = (words & MANT_MASK).astype(np.uint16)

    # Sign and exponent are coded jointly as one 6-bit symbol.
    se_symbols = ((sign << 5) | exp).astype(np.int32)
    se_cw, se_nz_idx, se_freqs = _encode_cat(se_symbols, SE_ALPHABET)

    # Group mantissas by exponent. The stable sort lets the decoder rebuild
    # the same permutation from the decoded exponents alone.
    by_exp = np.argsort(exp, kind="stable")
    mant_sorted = mant[by_exp]
    counts = np.bincount(exp[by_exp], minlength=32)
    nonzero_exps = np.nonzero(counts)[0].astype(np.int32)
    edges = np.zeros(33, dtype=np.int64)
    edges[1:] = np.cumsum(counts)

    # One independently modelled stream per exponent value that occurs.
    m_parts = [
        struct.pack("<I", len(nonzero_exps)),
        nonzero_exps.astype(np.uint8).tobytes(),
    ]
    for lo, hi in zip(edges[nonzero_exps], edges[nonzero_exps + 1]):
        cw, nz_idx, freqs = _encode_cat(mant_sorted[lo:hi].astype(np.int32), MANT_ALPHABET)
        m_parts += [
            struct.pack("<IH", hi - lo, len(nz_idx)),
            nz_idx.astype(np.uint16).tobytes(),
            freqs.astype(np.uint32).tobytes(),
            struct.pack("<I", len(cw)),
            cw,
        ]
    m_blob = b"".join(m_parts)

    blob = b"".join([
        struct.pack("<I", n),
        struct.pack("<B", len(se_nz_idx)),
        se_nz_idx.astype(np.uint8).tobytes(),
        se_freqs.astype(np.uint32).tobytes(),
        struct.pack("<I", len(se_cw)),
        se_cw,
        struct.pack("<I", len(m_blob)),
        m_blob,
    ])
    return blob, {}
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def decode(blob: bytes, extras: dict, n_weights: int) -> bytes:
    """Rebuild the raw FP16 tensor bytes from a blob written by ``encode``.

    Args:
        blob: Compressed blob.
        extras: Not read by this codec.
        n_weights: Expected number of 16-bit weights.

    Returns:
        The tensor bytes, or ``b""`` when *n_weights* is 0 or *blob* is empty.

    Raises:
        ValueError: If the weight count stored in the blob differs from
            *n_weights*.
    """
    if n_weights == 0 or len(blob) == 0:
        return b""

    def take(stream, fmt):
        # Read exactly the bytes that *fmt* describes and unpack them.
        return struct.unpack(fmt, stream.read(struct.calcsize(fmt)))

    stream = io.BytesIO(blob)
    (n,) = take(stream, "<I")
    if n != n_weights:
        raise ValueError(f"FP16 decode: weight count mismatch ({n} vs {n_weights})")
    if n == 0:
        return b""

    # Joint (sign, exp) stream.
    (se_n_nz,) = take(stream, "<B")
    se_nz_idx = np.frombuffer(stream.read(se_n_nz), dtype=np.uint8).astype(np.int32)
    se_freqs = np.frombuffer(stream.read(se_n_nz * 4), dtype=np.uint32)
    (se_cw_len,) = take(stream, "<I")
    se_cw = stream.read(se_cw_len)
    se = _decode_cat(se_cw, se_nz_idx, se_freqs, SE_ALPHABET, n).astype(np.uint16)
    sign = ((se >> 5) & 1).astype(np.uint16)
    exp = (se & 0x1F).astype(np.uint16)

    # Mantissa block header.
    (m_blob_len,) = take(stream, "<I")
    m_stream = io.BytesIO(stream.read(m_blob_len))
    (n_nz_exp,) = take(m_stream, "<I")
    nonzero_exps = np.frombuffer(m_stream.read(n_nz_exp), dtype=np.uint8)

    # Re-derive the encoder's stable sort order and bucket edges from the
    # decoded exponents.
    by_exp = np.argsort(exp, kind="stable")
    counts = np.bincount(exp.astype(np.int64), minlength=32)
    edges = np.zeros(33, dtype=np.int64)
    edges[1:] = np.cumsum(counts)

    mant_sorted = np.empty(n, dtype=np.uint16)
    for ev in nonzero_exps:
        nb, n_nz = take(m_stream, "<IH")
        nz_idx = np.frombuffer(m_stream.read(n_nz * 2), dtype=np.uint16).astype(np.int32)
        freqs = np.frombuffer(m_stream.read(n_nz * 4), dtype=np.uint32)
        (cw_len,) = take(m_stream, "<I")
        cw = m_stream.read(cw_len)
        bucket = _decode_cat(cw, nz_idx, freqs, MANT_ALPHABET, nb)
        mant_sorted[edges[ev]:edges[ev + 1]] = bucket.astype(np.uint16)

    # Undo the sort so each mantissa returns to its original position.
    mant = np.empty(n, dtype=np.uint16)
    mant[by_exp] = mant_sorted

    words = ((sign << SIGN_SHIFT) | (exp << EXP_SHIFT) | mant).astype(np.uint16)
    return words.tobytes()
|
bigsmall/codecs/fp32.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""FP32 codec: per-tensor (sign,exp) joint AC + per-tensor (mantissa | exp) AC.
|
|
2
|
+
|
|
3
|
+
FP32 layout: 1 sign | 8 exp | 23 mantissa = 32 bits total.
|
|
4
|
+
|
|
5
|
+
The 23-bit mantissa cannot be efficiently AC-encoded as a single 8M-symbol
|
|
6
|
+
alphabet, so we split each tensor's mantissa into 3 byte-streams
|
|
7
|
+
(low, mid, high) compressed with zstd. This mirrors cc10_v2 byte-transpose
|
|
8
|
+
intuition while staying simple.
|
|
9
|
+
|
|
10
|
+
Strategy per tensor:
|
|
11
|
+
1. Split into sign(1), exp(8), mant_lo8 + mant_mid8 + mant_hi7.
|
|
12
|
+
2. Encode (sign, exp) as 9-bit alphabet (512) per tensor with Cat AC.
|
|
13
|
+
3. Compress mant_lo / mant_mid / mant_hi as three byte streams with zstd L9
|
|
14
|
+
(the high 7 bits of mantissa have lower entropy than the low 16).
|
|
15
|
+
"""
|
|
16
|
+
import struct
|
|
17
|
+
import io
|
|
18
|
+
import numpy as np
|
|
19
|
+
import constriction as c
|
|
20
|
+
import zstandard as zstd
|
|
21
|
+
|
|
22
|
+
SE_ALPHABET = 512 # 2 sign * 256 exp
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _encode_cat(values: np.ndarray, alphabet: int):
    """Range-encode an integer symbol array under its own empirical distribution.

    Args:
        values: Non-negative integer symbols, each below *alphabet*.
        alphabet: Number of distinct symbols the model covers.

    Returns:
        ``(codeword_bytes, nonzero_indices, nonzero_freqs)`` - the compressed
        stream plus the symbols that occur and their counts. ``_decode_cat``
        rebuilds the same model from the latter two.
    """
    counts = np.bincount(values, minlength=alphabet).astype(np.int64)
    present = np.nonzero(counts)[0]

    # 0.01 is added to every count so no symbol gets probability zero;
    # _decode_cat applies the identical smoothing.
    weights = counts.astype(np.float64) + 0.01
    weights /= weights.sum()
    model = c.stream.model.Categorical(weights, perfect=True)

    coder = c.stream.queue.RangeEncoder()
    coder.encode(values.astype(np.int32), model)
    payload = coder.get_compressed().tobytes()
    return payload, present, counts[present].astype(np.int64)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _decode_cat(cw_bytes, nz_idx, freqs, alphabet, n):
    """Decode *n* symbols from a stream produced by ``_encode_cat``.

    Args:
        cw_bytes: Compressed stream, interpreted as an array of uint32 words.
        nz_idx: Symbols that occurred in the encoded data.
        freqs: Occurrence count for each entry of *nz_idx*.
        alphabet: Number of distinct symbols the model covers.
        n: Number of symbols to decode.

    Returns:
        The decoded symbol array.
    """
    counts = np.zeros(alphabet, dtype=np.int64)
    counts[nz_idx] = freqs

    # Same 0.01 smoothing and normalisation as the encoder, so both sides
    # build the model from identical inputs.
    weights = counts.astype(np.float64) + 0.01
    weights /= weights.sum()
    model = c.stream.model.Categorical(weights, perfect=True)

    words = np.frombuffer(cw_bytes, dtype=np.uint32)
    return c.stream.queue.RangeDecoder(words).decode(model, n)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def encode(raw: bytes) -> tuple[bytes, dict]:
    """Compress the raw bytes of one FP32 tensor.

    Args:
        raw: Tensor bytes, read as uint32 words; the length must be a
            multiple of 4.

    Returns:
        ``(compressed_blob, extras_dict)`` with an always-empty extras dict.
        Empty input yields ``(b"", {})``.

    Raises:
        ValueError: If ``len(raw)`` is not a multiple of 4.

    Blob layout (little-endian):
        u32 weight count | u16 SE nonzero count | u16[] SE symbols |
        u32[] SE freqs | u32 SE codeword length | SE codeword |
        then three length-prefixed (u32) zstd frames holding the low 8,
        middle 8 and high 7 mantissa bits of every weight.
    """
    if len(raw) % 4 != 0:
        raise ValueError(f"FP32 tensor byte length must be % 4 == 0, got {len(raw)}")
    words = np.frombuffer(raw, dtype=np.uint32)
    n = len(words)
    if n == 0:
        return b"", {}

    sign = ((words >> 31) & 1).astype(np.uint16)
    exp = ((words >> 23) & 0xFF).astype(np.uint16)
    mant = (words & 0x7FFFFF).astype(np.uint32)

    # Sign and exponent are coded jointly as one 9-bit symbol.
    se_symbols = ((sign << 8) | exp).astype(np.int32)
    se_cw, se_nz_idx, se_freqs = _encode_cat(se_symbols, SE_ALPHABET)

    # The 23-bit mantissa is cut into byte planes (low 8, mid 8, high 7 bits)
    # and each plane is zstd-compressed on its own.
    compressor = zstd.ZstdCompressor(level=9)
    planes = [
        compressor.compress(((mant >> shift) & mask).astype(np.uint8).tobytes())
        for shift, mask in ((0, 0xFF), (8, 0xFF), (16, 0x7F))
    ]

    parts = [
        struct.pack("<I", n),
        struct.pack("<H", len(se_nz_idx)),
        se_nz_idx.astype(np.uint16).tobytes(),
        se_freqs.astype(np.uint32).tobytes(),
        struct.pack("<I", len(se_cw)),
        se_cw,
    ]
    for plane in planes:
        parts.append(struct.pack("<I", len(plane)))
        parts.append(plane)
    return b"".join(parts), {}
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def decode(blob: bytes, extras: dict, n_weights: int) -> bytes:
    """Rebuild the raw FP32 tensor bytes from a blob written by ``encode``.

    Args:
        blob: Compressed blob.
        extras: Not read by this codec.
        n_weights: Expected number of 32-bit weights.

    Returns:
        The tensor bytes, or ``b""`` when *n_weights* is 0 or *blob* is empty.

    Raises:
        ValueError: If the weight count stored in the blob differs from
            *n_weights*.
    """
    if n_weights == 0 or len(blob) == 0:
        return b""

    stream = io.BytesIO(blob)

    def read_u32():
        (value,) = struct.unpack("<I", stream.read(4))
        return value

    n = read_u32()
    if n != n_weights:
        raise ValueError(f"FP32 decode: weight count mismatch ({n} vs {n_weights})")
    if n == 0:
        return b""

    # Joint (sign, exp) stream.
    (se_n_nz,) = struct.unpack("<H", stream.read(2))
    se_nz_idx = np.frombuffer(stream.read(se_n_nz * 2), dtype=np.uint16).astype(np.int32)
    se_freqs = np.frombuffer(stream.read(se_n_nz * 4), dtype=np.uint32)
    se_cw = stream.read(read_u32())
    se = _decode_cat(se_cw, se_nz_idx, se_freqs, SE_ALPHABET, n).astype(np.uint32)
    sign = ((se >> 8) & 1).astype(np.uint32)
    exp = (se & 0xFF).astype(np.uint32)

    # Read all three length-prefixed frames (low, mid, high) before
    # decompressing any of them.
    frames = []
    for _ in range(3):
        frames.append(stream.read(read_u32()))

    decompressor = zstd.ZstdDecompressor()
    mant_lo, mant_mid, mant_hi = (
        np.frombuffer(decompressor.decompress(frame), dtype=np.uint8).astype(np.uint32)
        for frame in frames
    )
    mant = mant_lo | (mant_mid << 8) | (mant_hi << 16)

    words = ((sign << 31) | (exp << 23) | mant).astype(np.uint32)
    return words.tobytes()
|
bigsmall/codecs/fp4.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""FP4 codec: per-tensor Categorical AC on 4-bit indices (alphabet 16).
|
|
2
|
+
|
|
3
|
+
The codec operates on UNPACKED FP4 (one byte per 4-bit value). The encoder
|
|
4
|
+
takes raw bytes representing already-unpacked 4-bit indices in the low nibble.
|
|
5
|
+
"""
|
|
6
|
+
import struct
|
|
7
|
+
import io
|
|
8
|
+
import numpy as np
|
|
9
|
+
import constriction as c
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def encode(raw: bytes) -> tuple[bytes, dict]:
    """Compress a stream of FP4 values with a 16-symbol categorical model.

    Args:
        raw: Either one 4-bit value per byte (low nibble) or two values
            packed per byte. Any byte above 15 makes the input count as
            packed; it is then split low nibble first, high nibble second.

    Returns:
        ``(compressed_blob, extras_dict)`` where extras is
        ``{"was_packed": bool}``. Empty input yields ``(b"", {})``.

    Blob layout (little-endian):
        u32 unpacked value count | u8 nonzero count | u8 packed flag |
        u8[] symbols | u32[] freqs | u32 codeword length | codeword
    """
    if len(raw) == 0:
        return b"", {}

    nibbles = np.frombuffer(raw, dtype=np.uint8)
    was_packed = bool(nibbles.max() > 15)
    if was_packed:
        # Interleave low and high nibbles: byte b -> (b & 0x0F, b >> 4).
        nibbles = np.column_stack((nibbles & 0x0F, (nibbles >> 4) & 0x0F)).ravel()

    n = len(nibbles)
    symbols = nibbles.astype(np.int32)
    counts = np.bincount(symbols, minlength=16).astype(np.int64)
    present = np.nonzero(counts)[0]

    # 0.01 is added to every count so no symbol gets probability zero;
    # decode() rebuilds the model the same way.
    weights = counts.astype(np.float64) + 0.01
    weights /= weights.sum()
    model = c.stream.model.Categorical(weights, perfect=True)
    coder = c.stream.queue.RangeEncoder()
    coder.encode(symbols, model)
    codeword = coder.get_compressed().tobytes()

    blob = b"".join([
        struct.pack("<IBB", n, len(present), 1 if was_packed else 0),
        present.astype(np.uint8).tobytes(),
        counts[present].astype(np.uint32).tobytes(),
        struct.pack("<I", len(codeword)),
        codeword,
    ])
    return blob, {"was_packed": was_packed}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def decode(blob: bytes, extras: dict, n_weights: int) -> bytes:
    """Rebuild the FP4 byte stream from a blob written by ``encode``.

    Args:
        blob: Compressed blob.
        extras: Not read here; the packed flag comes from the blob header.
        n_weights: Only tested against 0. The number of values decoded is
            the count stored in the blob.

    Returns:
        One value per byte (low nibble), or two values per byte when the
        header's packed flag is set. ``b""`` when *n_weights* is 0 or *blob*
        is empty.

    Raises:
        ValueError: If the packed flag is set but the stored value count is odd.
    """
    if n_weights == 0 or len(blob) == 0:
        return b""

    stream = io.BytesIO(blob)
    n, n_nz, was_packed_flag = struct.unpack("<IBB", stream.read(6))
    present = np.frombuffer(stream.read(n_nz), dtype=np.uint8)
    freqs = np.frombuffer(stream.read(n_nz * 4), dtype=np.uint32)
    (cw_len,) = struct.unpack("<I", stream.read(4))
    codeword = stream.read(cw_len)

    # Rebuild the encoder's model: same counts, same 0.01 smoothing.
    counts = np.zeros(16, dtype=np.int64)
    counts[present] = freqs
    weights = counts.astype(np.float64) + 0.01
    weights /= weights.sum()
    model = c.stream.model.Categorical(weights, perfect=True)

    words = np.frombuffer(codeword, dtype=np.uint32)
    nibbles = c.stream.queue.RangeDecoder(words).decode(model, n).astype(np.uint8)

    if not was_packed_flag:
        return nibbles.tobytes()

    if n % 2 != 0:
        raise ValueError("FP4 packed decode requires even unpacked count")
    # Even positions hold the low nibble, odd positions the high nibble.
    packed = (nibbles[0::2] | (nibbles[1::2] << 4)).astype(np.uint8)
    return packed.tobytes()
|