pertype 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,996 @@
1
+ /* Native hot loops (compiled to audio.so by native.py, called via ctypes).
2
+ *
3
+ * Each function must produce bit-identical output to its pure-Python reference
4
+ * (in audiocodec.py / transform.py), or losslessness breaks. Compiled with
5
+ * -fwrapv (signed wrap like numpy int64) and -ffp-contract=off (no FMA, so the
6
+ * float Rice `run` update matches Python).
7
+ *
8
+ * Contents: the lossless audio codec's predictor + Rice coder, and the byte-
9
+ * stream `delta` transform used by the image/numeric path.
10
+ */
11
+ #include <stdint.h>
12
+ #include <stdlib.h>
13
+ #include <string.h>
14
+
15
/* Sign of v as -1, 0 or +1. */
static inline int64_t sgn(int64_t v) { return (v > 0) - (v < 0); }

/*
 * Shared adaptive-predictor loop behind lms_fwd / lms_inv.
 *
 * Per sample: pred = (sum_j w[j] * h[j]) >> shift, where h holds the `taps`
 * most recent samples (newest first) and w the adaptive weights, both starting
 * at zero.  Each weight then moves by sgn(err) * sgn(h[j]), and the sample is
 * pushed onto the history.
 *
 *   inverse == 0: in = samples,   out = residuals (err = sample - pred)
 *   inverse != 0: in = residuals, out = samples   (sample = err + pred)
 *
 * Both directions apply the same state updates in the same order, which is
 * what makes the pair invert each other exactly.  The products and sums rely
 * on the build's -fwrapv for wrap-around on overflow.
 *
 * taps <= 0 means "no predictor": pred is always 0 and in is copied to out.
 * If the state cannot be allocated the function returns without writing out
 * (same convention as calic_code in this file).
 */
static void lms_run(const int64_t *in, int64_t *out, long n, int taps, int shift,
                    int inverse) {
    int64_t *w = NULL;
    int64_t *h = NULL;
    if (taps > 0) {
        w = calloc((size_t)taps, sizeof *w);
        h = calloc((size_t)taps, sizeof *h);
        if (!w || !h) { free(w); free(h); return; }
    }
    for (long i = 0; i < n; i++) {
        int64_t sum = 0;
        for (int j = 0; j < taps; j++) sum += w[j] * h[j];
        int64_t pred = sum >> shift;   /* arithmetic shift = floor, matches Python */

        int64_t err, sample;
        if (inverse) {
            err = in[i];
            sample = err + pred;
            out[i] = sample;
        } else {
            sample = in[i];
            err = sample - pred;
            out[i] = err;
        }

        if (err > 0) {
            for (int j = 0; j < taps; j++) w[j] += sgn(h[j]);
        } else if (err < 0) {
            for (int j = 0; j < taps; j++) w[j] -= sgn(h[j]);
        }

        /* Only touch the history when it exists: with taps <= 0 there is no
         * h[0] to write to. */
        if (taps > 0) {
            for (int j = taps - 1; j > 0; j--) h[j] = h[j - 1];
            h[0] = sample;
        }
    }
    free(w);
    free(h);
}

/* Forward pass: x[0..n) samples -> out[0..n) prediction residuals. */
void lms_fwd(const int64_t *x, int64_t *out, long n, int taps, int shift) {
    lms_run(x, out, n, taps, shift, 0);
}

/* Inverse pass: e[0..n) residuals -> x[0..n) reconstructed samples. */
void lms_inv(const int64_t *e, int64_t *x, long n, int taps, int shift) {
    lms_run(e, x, n, taps, shift, 1);
}
50
+
51
/* order-2 fixed predictor (inverse is a sequential recurrence)
 *
 * Forward: e[0] = x[0], e[1] = x[1], and for i >= 2
 *   e[i] = x[i] - (2*x[i-1] - x[i-2]).
 * Nothing is written when n <= 0.  The arithmetic is done on uint64_t so the
 * wrap-around is well defined regardless of compiler flags; the result is the
 * same two's-complement value the signed expression yields under -fwrapv. */
void fixed2_fwd(const int64_t *x, int64_t *e, long n) {
    if (n <= 0) return;            /* empty input: e has no element 0 to write */
    e[0] = x[0];
    if (n > 1) e[1] = x[1];
    for (long i = 2; i < n; i++) {
        uint64_t pred = 2 * (uint64_t)x[i - 1] - (uint64_t)x[i - 2];
        e[i] = (int64_t)((uint64_t)x[i] - pred);
    }
}
57
+
58
/* Inverse of the order-2 fixed predictor: x[0] = e[0], x[1] = e[1], and for
 * i >= 2, x[i] = e[i] + 2*x[i-1] - x[i-2] (each output depends on the two
 * previously reconstructed ones, so this runs strictly in order).
 * Nothing is written when n <= 0.  uint64_t arithmetic keeps wrap-around well
 * defined and bit-identical to the signed expression under -fwrapv. */
void fixed2_inv(const int64_t *e, int64_t *x, long n) {
    if (n <= 0) return;            /* empty input: x has no element 0 to write */
    x[0] = e[0];
    if (n > 1) x[1] = e[1];
    for (long i = 2; i < n; i++) {
        uint64_t pred = 2 * (uint64_t)x[i - 1] - (uint64_t)x[i - 2];
        x[i] = (int64_t)((uint64_t)e[i] + pred);
    }
}
63
+
64
+ /* --- adaptive Rice coding (MSB-first bits, matching bitio.py) --------------- */
65
#define RICE_ALPHA 0.02

/* Rice parameter for the current running mean: bit_length(int(run)) - 1,
 * or 0 when int(run) < 1.
 *
 * The conversion is range-checked before casting: converting a double that
 * does not fit the target integer type is undefined in C, and `run` follows
 * zigzag magnitudes that can approach 2^64.  Values in [2^63, 2^64) give 63.
 * A run of 2^64 or more is clamped to 63 as well, because a shift by 64 would
 * itself be undefined. */
static int _k_from_run(double run) {
    if (!(run >= 1.0)) return 0;                       /* also catches NaN */
    if (run >= 18446744073709551616.0) return 63;      /* 2^64 */
    return 63 - __builtin_clzll((unsigned long long)run);  /* trunc toward zero */
}

/* Append one bit (MSB-first) to the byte being assembled in *cur; when the
 * byte is complete, store it at out[*byte].  Returns 0, or -1 if that store
 * would exceed cap. */
static inline int rice_put_bit(uint8_t *out, long cap, long *byte, unsigned *cur,
                               int *nbits, unsigned bit) {
    *cur = (*cur << 1) | (bit & 1u);
    if (++*nbits == 8) {
        if (*byte >= cap) return -1;
        out[(*byte)++] = (uint8_t)*cur;
        *cur = 0;
        *nbits = 0;
    }
    return 0;
}

/* Read the bit at MSB-first bit offset *pos from in and advance *pos. */
static inline unsigned rice_get_bit(const uint8_t *in, long *pos) {
    unsigned bit = (unsigned)(in[*pos >> 3] >> (7 - (*pos & 7))) & 1u;
    (*pos)++;
    return bit;
}

/* Adaptive Rice encoder.
 *
 * Each residual is zigzag-mapped to an unsigned u, then written as
 * q = u >> k one-bits, a zero bit, and the low k bits of u (MSB first), where
 * k comes from the running mean `run` (initially 16.0, updated by
 * run += (u - run) * RICE_ALPHA after every symbol).  The last byte is padded
 * with zero bits.
 *
 * Returns bytes written, or -1 if the output buffer is too small. */
long rice_encode(const int64_t *res, long n, uint8_t *out, long cap) {
    double run = 16.0;
    long byte = 0;
    int nbits = 0;
    unsigned cur = 0;
    for (long i = 0; i < n; i++) {
        int64_t r = res[i];
        /* zigzag on unsigned values: no left shift of a negative number */
        uint64_t u = ((uint64_t)r << 1) ^ ((uint64_t)0 - (uint64_t)(r < 0));
        int k = _k_from_run(run);
        uint64_t q = u >> k;
        /* q ones then a zero; counting to q (not q + 1) stays correct even
         * when q is UINT64_MAX */
        for (uint64_t t = 0; t < q; t++) {
            if (rice_put_bit(out, cap, &byte, &cur, &nbits, 1u) < 0) return -1;
        }
        if (rice_put_bit(out, cap, &byte, &cur, &nbits, 0u) < 0) return -1;
        for (int s = k - 1; s >= 0; s--) {             /* k remainder bits, MSB first */
            if (rice_put_bit(out, cap, &byte, &cur, &nbits, (unsigned)((u >> s) & 1)) < 0)
                return -1;
        }
        run += ((double)u - run) * RICE_ALPHA;
    }
    if (nbits > 0) {
        if (byte >= cap) return -1;
        out[byte++] = (uint8_t)(cur << (8 - nbits));
    }
    return byte;
}

/* Adaptive Rice decoder: reads n symbols from `in` into out, mirroring
 * rice_encode's run/k adaptation exactly.
 *
 * NOTE(review): the input length is not passed in, so the reads cannot be
 * bounds-checked here; the caller must supply a buffer that really holds n
 * encoded symbols. */
void rice_decode(const uint8_t *in, long n, int64_t *out) {
    double run = 16.0;
    long pos = 0;                                      /* bit position */
    for (long i = 0; i < n; i++) {
        int k = _k_from_run(run);
        uint64_t q = 0;
        while (rice_get_bit(in, &pos)) q++;            /* ones, then the zero */
        uint64_t rem = 0;
        for (int s = 0; s < k; s++) rem = (rem << 1) | rice_get_bit(in, &pos);
        uint64_t u = (q << k) | rem;
        out[i] = (int64_t)((u >> 1) ^ ((uint64_t)0 - (u & 1)));  /* unzigzag */
        run += ((double)u - run) * RICE_ALPHA;
    }
}
113
+
114
+ /* --- byte-stream delta transform (image/numeric path) ----------------------- */
115
/* Forward byte delta: the first `stride` bytes are copied, every later byte
 * becomes data[i] - data[i - stride] modulo 256.
 *
 * A stride <= 0 has no predecessor byte to subtract (and a negative one would
 * index outside the buffers), so it is treated as a plain copy; a stride >= n
 * is a plain copy as well.  data and out must not overlap. */
void delta_fwd(const uint8_t *data, uint8_t *out, long n, int stride) {
    long head = (stride > 0 && stride < n) ? (long)stride : n;
    for (long i = 0; i < head; i++) out[i] = data[i];
    for (long i = head; i < n; i++) out[i] = (uint8_t)(data[i] - data[i - stride]);
}
119
+
120
/* Inverse byte delta: the first `stride` bytes are copied, every later byte is
 * rebuilt as out[i - stride] + data[i] modulo 256 from already-restored output.
 *
 * A stride <= 0 is treated as a plain copy: there is no earlier output byte to
 * add, and a negative stride would index outside the buffers.  A stride >= n
 * is a plain copy as well.  data and out must not overlap. */
void delta_inv(const uint8_t *data, uint8_t *out, long n, int stride) {
    long head = (stride > 0 && stride < n) ? (long)stride : n;
    for (long i = 0; i < head; i++) out[i] = data[i];
    for (long i = head; i < n; i++) out[i] = (uint8_t)(out[i - stride] + data[i]);
}
124
+
125
+ /* --- context-adaptive arithmetic residual coder (mirrors ctxcoder.py) ------
126
+ *
127
+ * Witten–Neal–Cleary 32-bit arithmetic coder driving a per-context adaptive
128
+ * model over magnitude buckets, with raw mantissa bits coded as uniform symbols
129
+ * through the same coder. Bit output is MSB-first with a zero-padded final byte,
130
+ * exactly matching bitio.BitWriter, so the C output is byte-identical to the
131
+ * pure-Python reference and files are interchangeable. All integer math.
132
+ */
133
+ #define CTX_NB 65 /* buckets 0..64 cover any int64 zigzag magnitude */
134
+ #define CTX_CLAMP 16 /* context bucket clamp (keeps the model dense) */
135
+ #define CTX_NCTX ((CTX_CLAMP + 1) * (CTX_CLAMP + 1)) /* order-2: (prev, prev-prev) */
136
+ #define CTX_INCR 32
137
+ #define CTX_RESCALE (1 << 14)
138
+ #define CTX_MINCR 24 /* adaptation of the modelled top-mantissa bit */
139
+ #define CTX_MRESCALE (1 << 13)
140
+ #define AC_HALF 0x80000000ULL
141
+ #define AC_QUARTER 0x40000000ULL
142
+ #define AC_3QUARTER 0xC0000000ULL
143
+ #define AC_MAX 0xFFFFFFFFULL
144
+
145
/* MSB-first bit writer.
 *
 * out/cap  : destination buffer and its size in bytes
 * byte     : number of bytes stored so far
 * cur/nbits: partially assembled byte and how many bits it holds
 * overflow : set once a completed byte could not be stored
 */
typedef struct {
    uint8_t *out;
    long cap, byte;
    unsigned cur;
    int nbits, overflow;
} bitw;

/* Shift one bit into the pending byte; every eighth bit the byte is stored,
 * or dropped with the overflow flag raised if the buffer is already full. */
static inline void bw_bit(bitw *w, int bit) {
    w->cur = (w->cur << 1) | (unsigned)(bit & 1);
    w->nbits++;
    if (w->nbits != 8) return;

    if (w->byte < w->cap) {
        w->out[w->byte++] = (uint8_t)w->cur;
    } else {
        w->overflow = 1;
    }
    w->cur = 0;
    w->nbits = 0;
}
155
+
156
+ typedef struct { uint64_t low, high; long pending; bitw *w; } aenc;
157
+
158
+ static inline void ae_emit(aenc *e, int bit) {
159
+ bw_bit(e->w, bit);
160
+ while (e->pending) { bw_bit(e->w, bit ^ 1); e->pending--; }
161
+ }
162
+
163
+ static void ae_encode(aenc *e, uint64_t cum, uint64_t freq, uint64_t total) {
164
+ uint64_t span = e->high - e->low + 1;
165
+ e->high = e->low + span * (cum + freq) / total - 1;
166
+ e->low = e->low + span * cum / total;
167
+ for (;;) {
168
+ if (e->high < AC_HALF) ae_emit(e, 0);
169
+ else if (e->low >= AC_HALF) { ae_emit(e, 1); e->low -= AC_HALF; e->high -= AC_HALF; }
170
+ else if (e->low >= AC_QUARTER && e->high < AC_3QUARTER) {
171
+ e->pending++; e->low -= AC_QUARTER; e->high -= AC_QUARTER;
172
+ } else break;
173
+ e->low <<= 1; e->high = (e->high << 1) | 1;
174
+ }
175
+ }
176
+
177
+ static void ctx_init(int freq[CTX_NCTX][CTX_NB], long tot[CTX_NCTX]) {
178
+ for (int c = 0; c < CTX_NCTX; c++) {
179
+ for (int s = 0; s < CTX_NB; s++) freq[c][s] = 1;
180
+ tot[c] = CTX_NB;
181
+ }
182
+ }
183
+
184
+ static inline int ctx_index(int pk, int pk2) {
185
+ int a = pk < CTX_CLAMP ? pk : CTX_CLAMP;
186
+ int b = pk2 < CTX_CLAMP ? pk2 : CTX_CLAMP;
187
+ return a * (CTX_CLAMP + 1) + b;
188
+ }
189
+
190
+ static void ctx_bump(int *f, long *tot, int k) {
191
+ f[k] += CTX_INCR; *tot += CTX_INCR;
192
+ if (*tot >= CTX_RESCALE) {
193
+ long t = 0;
194
+ for (int s = 0; s < CTX_NB; s++) { f[s] = (f[s] + 1) >> 1; t += f[s]; }
195
+ *tot = t;
196
+ }
197
+ }
198
+
199
+ /* Returns bytes written, or -1 if the output buffer is too small. */
200
+ /* adaptive binary model for the top mantissa bit, indexed by (ctx, k) */
201
+ static void ctx_mant_init(int mf[CTX_NCTX][CTX_NB][2], long mt[CTX_NCTX][CTX_NB]) {
202
+ for (int c = 0; c < CTX_NCTX; c++)
203
+ for (int s = 0; s < CTX_NB; s++) { mf[c][s][0] = 1; mf[c][s][1] = 1; mt[c][s] = 2; }
204
+ }
205
+
206
+ long ctx_encode(const int64_t *res, long n, uint8_t *out, long cap) {
207
+ static int freq[CTX_NCTX][CTX_NB]; static long tot[CTX_NCTX];
208
+ static int mf[CTX_NCTX][CTX_NB][2]; static long mt[CTX_NCTX][CTX_NB];
209
+ ctx_init(freq, tot);
210
+ ctx_mant_init(mf, mt);
211
+ bitw w = { out, cap, 0, 0, 0, 0 };
212
+ aenc e = { 0, AC_MAX, 0, &w };
213
+ int pk = 0, pk2 = 0;
214
+ for (long i = 0; i < n; i++) {
215
+ int64_t r = res[i];
216
+ uint64_t u = (((uint64_t)r) << 1) ^ (uint64_t)(r >> 63); /* zigzag */
217
+ int k = u ? (64 - __builtin_clzll(u)) : 0;
218
+ int ctx = ctx_index(pk, pk2);
219
+ int *f = freq[ctx];
220
+ uint64_t cum = 0;
221
+ for (int s = 0; s < k; s++) cum += (uint64_t)f[s];
222
+ ae_encode(&e, cum, (uint64_t)f[k], (uint64_t)tot[ctx]);
223
+ if (k >= 2) {
224
+ uint64_t mant = u & ((1ULL << (k - 1)) - 1);
225
+ int b1 = (mant >> (k - 2)) & 1; /* top mantissa bit: modelled */
226
+ int *m = mf[ctx][k];
227
+ ae_encode(&e, b1 == 0 ? 0 : (uint64_t)m[0], (uint64_t)m[b1], (uint64_t)mt[ctx][k]);
228
+ m[b1] += CTX_MINCR; mt[ctx][k] += CTX_MINCR;
229
+ if (mt[ctx][k] >= CTX_MRESCALE) {
230
+ m[0] = (m[0] + 1) >> 1; m[1] = (m[1] + 1) >> 1; mt[ctx][k] = m[0] + m[1];
231
+ }
232
+ for (int shift = k - 3; shift >= 0; shift--) /* remaining low bits raw */
233
+ ae_encode(&e, (mant >> shift) & 1, 1, 2);
234
+ }
235
+ if (w.overflow) return -1;
236
+ ctx_bump(f, &tot[ctx], k);
237
+ pk2 = pk; pk = k;
238
+ }
239
+ e.pending++; /* finish() */
240
+ ae_emit(&e, e.low < AC_QUARTER ? 0 : 1);
241
+ if (w.overflow) return -1;
242
+ if (w.nbits > 0) { /* getvalue(): pad final byte */
243
+ if (w.byte >= w.cap) return -1;
244
+ w.out[w.byte++] = (uint8_t)(w.cur << (8 - w.nbits));
245
+ }
246
+ return w.byte;
247
+ }
248
+
249
/* Arithmetic decoder state: current interval [low, high], the code value
 * read so far, and the input buffer with its byte length and bit cursor. */
typedef struct {
    uint64_t low, high, code;
    const uint8_t *in;
    long len, pos;
} adec;

/* Next input bit, MSB-first; reads past the end of the buffer yield 0. */
static inline int ad_bit(adec *d) {
    const long at = d->pos++;
    const long idx = at >> 3;
    if (idx >= d->len) return 0;
    return (d->in[idx] >> (7 - (at & 7))) & 1;
}

/* Scale the code's offset inside the current interval to a cumulative count
 * out of `total`. */
static uint64_t ad_target(adec *d, uint64_t total) {
    const uint64_t span = d->high - d->low + 1;
    const uint64_t offset = d->code - d->low + 1;
    return (offset * total - 1) / span;
}
262
+
263
+ static void ad_update(adec *d, uint64_t cum, uint64_t freq, uint64_t total) {
264
+ uint64_t span = d->high - d->low + 1;
265
+ d->high = d->low + span * (cum + freq) / total - 1;
266
+ d->low = d->low + span * cum / total;
267
+ for (;;) {
268
+ if (d->high < AC_HALF) {}
269
+ else if (d->low >= AC_HALF) { d->low -= AC_HALF; d->high -= AC_HALF; d->code -= AC_HALF; }
270
+ else if (d->low >= AC_QUARTER && d->high < AC_3QUARTER) {
271
+ d->low -= AC_QUARTER; d->high -= AC_QUARTER; d->code -= AC_QUARTER;
272
+ } else break;
273
+ d->low <<= 1; d->high = (d->high << 1) | 1; d->code = (d->code << 1) | (uint64_t)ad_bit(d);
274
+ }
275
+ }
276
+
277
+ void ctx_decode(const uint8_t *in, long len, long n, int64_t *out) {
278
+ static int freq[CTX_NCTX][CTX_NB]; static long tot[CTX_NCTX];
279
+ static int mf[CTX_NCTX][CTX_NB][2]; static long mt[CTX_NCTX][CTX_NB];
280
+ ctx_init(freq, tot);
281
+ ctx_mant_init(mf, mt);
282
+ adec d = { 0, AC_MAX, 0, in, len, 0 };
283
+ for (int i = 0; i < 32; i++) d.code = (d.code << 1) | (uint64_t)ad_bit(&d);
284
+ int pk = 0, pk2 = 0;
285
+ for (long i = 0; i < n; i++) {
286
+ int ctx = ctx_index(pk, pk2);
287
+ int *f = freq[ctx];
288
+ uint64_t total = (uint64_t)tot[ctx];
289
+ uint64_t target = ad_target(&d, total);
290
+ uint64_t cum = 0; int k = 0;
291
+ while (cum + (uint64_t)f[k] <= target) { cum += (uint64_t)f[k]; k++; }
292
+ ad_update(&d, cum, (uint64_t)f[k], total);
293
+ uint64_t u;
294
+ if (k == 0) u = 0;
295
+ else if (k == 1) u = 1;
296
+ else {
297
+ int *m = mf[ctx][k]; long mtv = mt[ctx][k];
298
+ int b1 = (ad_target(&d, (uint64_t)mtv) >= (uint64_t)m[0]) ? 1 : 0;
299
+ ad_update(&d, b1 == 0 ? 0 : (uint64_t)m[0], (uint64_t)m[b1], (uint64_t)mtv);
300
+ m[b1] += CTX_MINCR; mt[ctx][k] += CTX_MINCR;
301
+ if (mt[ctx][k] >= CTX_MRESCALE) {
302
+ m[0] = (m[0] + 1) >> 1; m[1] = (m[1] + 1) >> 1; mt[ctx][k] = m[0] + m[1];
303
+ }
304
+ uint64_t low = 0;
305
+ for (int j = 0; j < k - 2; j++) { /* remaining low bits raw */
306
+ int bit = (ad_target(&d, 2) >= 1) ? 1 : 0;
307
+ ad_update(&d, (uint64_t)bit, 1, 2);
308
+ low = (low << 1) | (uint64_t)bit;
309
+ }
310
+ u = (1ULL << (k - 1)) | ((uint64_t)b1 << (k - 2)) | low;
311
+ }
312
+ out[i] = (int64_t)(u >> 1) ^ -(int64_t)(u & 1); /* unzigzag */
313
+ ctx_bump(f, &tot[ctx], k);
314
+ pk2 = pk; pk = k;
315
+ }
316
+ }
317
+
318
+ /* --- LZ token-stream coder for codec.py (mirrors codec._encode/_decode_tokens)
319
+ *
320
+ * Drives the same WNC arithmetic coder with three static frequency models
321
+ * (main / dist / mode), each passed as a prefix-sum array `cum` of length n+1
322
+ * (total == cum[n]); the models' symbol alphabets are contiguous 0..n-1, so the
323
+ * symbol value indexes `cum` directly. Length/distance slot "extra" bits are
324
+ * coded as uniform symbols through the coder, and the repeat-offset cache is
325
+ * maintained identically here, so the output is byte-identical to the Python
326
+ * reference and the streams are interchangeable.
327
+ */
328
#define LZ_REP_N 16   /* must match model.REP_N */

/* Fill the repeat-offset cache with 1, 2, ..., LZ_REP_N in order
 * (matches model.REP_INIT). */
static void rep_init(int64_t *reps) {
    int64_t next = 1;
    for (int64_t *slot = reps; slot < reps + LZ_REP_N; slot++) {
        *slot = next++;
    }
}
334
+
335
+ static void model_encode(aenc *e, const int *cum, int n, int s) {
336
+ ae_encode(e, (uint64_t)cum[s], (uint64_t)(cum[s + 1] - cum[s]), (uint64_t)cum[n]);
337
+ }
338
+
339
+ static int model_decode(adec *d, const int *cum, int n) {
340
+ uint64_t total = (uint64_t)cum[n];
341
+ uint64_t target = ad_target(d, total);
342
+ int lo = 0, hi = n + 1; /* bisect_right(cum, target) */
343
+ while (lo < hi) {
344
+ int mid = (lo + hi) >> 1;
345
+ if ((uint64_t)cum[mid] <= target) lo = mid + 1; else hi = mid;
346
+ }
347
+ int s = lo - 1;
348
+ ad_update(d, (uint64_t)cum[s], (uint64_t)(cum[s + 1] - cum[s]), total);
349
+ return s;
350
+ }
351
+
352
+ static void enc_bits(aenc *e, uint64_t value, int nbits) {
353
+ for (int shift = nbits - 1; shift >= 0; shift--)
354
+ ae_encode(e, (value >> shift) & 1, 1, 2);
355
+ }
356
+
357
+ static uint64_t dec_bits(adec *d, int nbits) {
358
+ uint64_t v = 0;
359
+ for (int i = 0; i < nbits; i++) {
360
+ int bit = (ad_target(d, 2) >= 1) ? 1 : 0;
361
+ ad_update(d, (uint64_t)bit, 1, 2);
362
+ v = (v << 1) | (uint64_t)bit;
363
+ }
364
+ return v;
365
+ }
366
+
367
/* reps: pop the element at index p, then insert `distance` at the front; the
 * cache length is unchanged.
 *
 * Removing slot p and pushing at the front leaves every entry after p where it
 * was, so only the prefix reps[0..p-1] has to move down by one.  p must be a
 * valid index into the cache. */
static void rep_update(int64_t *reps, int p, int64_t distance) {
    for (int j = p; j > 0; j--) reps[j] = reps[j - 1];
    reps[0] = distance;
}
373
+
374
+ long lz_encode(const int *kind, const int64_t *aval, const int64_t *bval, long n_tokens,
375
+ const int *mcum, int m_n, const int *dcum, int d_n, const int *ocum, int o_n,
376
+ int len_base, int min_match, uint8_t *out, long cap) {
377
+ bitw w = { out, cap, 0, 0, 0, 0 };
378
+ aenc e = { 0, AC_MAX, 0, &w };
379
+ int64_t reps[LZ_REP_N]; rep_init(reps);
380
+ for (long i = 0; i < n_tokens; i++) {
381
+ int k = kind[i];
382
+ if (k == 0) { /* literal */
383
+ model_encode(&e, mcum, m_n, (int)aval[i]);
384
+ } else if (k == 1) { /* dict ref */
385
+ model_encode(&e, mcum, m_n, 256 + (int)aval[i]);
386
+ } else { /* match */
387
+ int64_t length = aval[i], distance = bval[i];
388
+ int64_t v = length - min_match + 1;
389
+ int lslot = 63 - __builtin_clzll((uint64_t)v);
390
+ model_encode(&e, mcum, m_n, len_base + lslot);
391
+ enc_bits(&e, (uint64_t)(v - ((int64_t)1 << lslot)), lslot);
392
+ int ri = -1;
393
+ for (int j = 0; j < LZ_REP_N; j++) if (reps[j] == distance) { ri = j; break; }
394
+ if (ri >= 0) {
395
+ model_encode(&e, ocum, o_n, ri + 1);
396
+ } else {
397
+ model_encode(&e, ocum, o_n, 0); /* MODE_NORMAL */
398
+ int dslot = 63 - __builtin_clzll((uint64_t)distance);
399
+ model_encode(&e, dcum, d_n, dslot);
400
+ enc_bits(&e, (uint64_t)(distance - ((int64_t)1 << dslot)), dslot);
401
+ }
402
+ rep_update(reps, ri >= 0 ? ri : LZ_REP_N - 1, distance);
403
+ }
404
+ if (w.overflow) return -1;
405
+ }
406
+ e.pending++;
407
+ ae_emit(&e, e.low < AC_QUARTER ? 0 : 1);
408
+ if (w.overflow) return -1;
409
+ if (w.nbits > 0) {
410
+ if (w.byte >= w.cap) return -1;
411
+ w.out[w.byte++] = (uint8_t)(w.cur << (8 - w.nbits));
412
+ }
413
+ return w.byte;
414
+ }
415
+
416
+ void lz_decode(const uint8_t *in, long len, long n_tokens,
417
+ const int *mcum, int m_n, const int *dcum, int d_n, const int *ocum, int o_n,
418
+ int len_base, int n_patterns, int min_match,
419
+ int *kind, int64_t *aval, int64_t *bval) {
420
+ adec d = { 0, AC_MAX, 0, in, len, 0 };
421
+ for (int i = 0; i < 32; i++) d.code = (d.code << 1) | (uint64_t)ad_bit(&d);
422
+ int64_t reps[LZ_REP_N]; rep_init(reps);
423
+ for (long i = 0; i < n_tokens; i++) {
424
+ int sym = model_decode(&d, mcum, m_n);
425
+ if (sym < 256) {
426
+ kind[i] = 0; aval[i] = sym; bval[i] = 0;
427
+ } else if (sym < 256 + n_patterns) {
428
+ kind[i] = 1; aval[i] = sym - 256; bval[i] = 0;
429
+ } else {
430
+ int lslot = sym - len_base;
431
+ uint64_t lextra = dec_bits(&d, lslot);
432
+ int64_t length = ((int64_t)1 << lslot) + (int64_t)lextra + min_match - 1;
433
+ int m = model_decode(&d, ocum, o_n);
434
+ int64_t distance;
435
+ int p;
436
+ if (m == 0) {
437
+ int dslot = model_decode(&d, dcum, d_n);
438
+ uint64_t dextra = dec_bits(&d, dslot);
439
+ distance = ((int64_t)1 << dslot) + (int64_t)dextra;
440
+ p = LZ_REP_N - 1;
441
+ } else {
442
+ distance = reps[m - 1];
443
+ p = m - 1;
444
+ }
445
+ rep_update(reps, p, distance);
446
+ kind[i] = 2; aval[i] = length; bval[i] = distance;
447
+ }
448
+ }
449
+ }
450
+
451
+ /* --- causal MED reconstruction (mirrors videocodec._med_fill) --------------
452
+ *
453
+ * For each intra pixel in raster order, predict from already-reconstructed
454
+ * neighbours (left a, above b, above-left c) with the JPEG-LS / LOCO-I median,
455
+ * then add the residual. Non-intra pixels are left as the caller filled them
456
+ * (skip/inter), so neighbours read by an intra pixel are always final. rec is
457
+ * modified in place. Integer-exact, so byte-identical to the Python loop. */
458
/* Reconstruct the intra pixels of an H x W image in place, in raster order.
 *
 * For each pixel with intra[i] != 0 the prediction is the median-edge value of
 * its left, above and above-left neighbours in rec, and rec[i] becomes
 * prediction + residual[i].  Missing neighbours fall back as follows: left
 * uses the pixel above (or 128 at the origin), above uses left, and above-left
 * uses above.  Pixels with intra[i] == 0 are not touched. */
void med_fill(int64_t *rec, const uint8_t *intra, const int64_t *residual,
              long H, long W) {
    for (long y = 0; y < H; y++) {
        const int has_up = y > 0;
        for (long x = 0; x < W; x++) {
            const long i = y * W + x;
            if (intra[i] == 0) continue;
            const int has_left = x > 0;

            int64_t left;
            if (has_left) left = rec[i - 1];
            else if (has_up) left = rec[i - W];
            else left = 128;
            const int64_t up = has_up ? rec[i - W] : left;
            const int64_t diag = (has_left && has_up) ? rec[i - W - 1] : up;

            int64_t hi = left, lo = up;
            if (up > left) { hi = up; lo = left; }

            int64_t pred;
            if (diag >= hi) pred = lo;
            else if (diag <= lo) pred = hi;
            else pred = left + up - diag;

            rec[i] = pred + residual[i];
        }
    }
}
473
+
474
+ /* --- causal GAP reconstruction (CALIC gradient-adjusted predictor) ----------
475
+ *
476
+ * Mirrors predictors.gap_predict: the first row predicts from the left, the
477
+ * first column from above, the origin from 128, and the interior uses the GAP
478
+ * gradient logic with same-row/col 2-back neighbours zero when out of range.
479
+ * All divisions are arithmetic right shifts (matching numpy int floor-division),
480
+ * so this is byte-identical to the vectorised forward. Thresholds t1>t2>t3 scale
481
+ * with bit depth. All-intra (mask all 1) for the image codec. */
482
/* Reconstruct the intra pixels of an H x W image in place with the
 * gradient-adjusted predictor, in raster order.
 *
 * The origin predicts 128, the rest of the first row predicts from the left
 * pixel and the rest of the first column from the pixel above.  Interior
 * pixels compare a horizontal and a vertical gradient sum: depending on how
 * far dv - dh lies beyond t1, t2 or t3 the prediction is the left/above pixel
 * itself or a blend of it with the base estimate.  Out-of-range NE, WW and NN
 * neighbours read as 0.  rec[i] becomes prediction + residual[i]; pixels with
 * intra[i] == 0 are not touched. */
void gap_fill(int64_t *rec, const uint8_t *intra, const int64_t *residual,
              long H, long W, long t1, long t2, long t3) {
    for (long y = 0; y < H; y++) {
        for (long x = 0; x < W; x++) {
            const long i = y * W + x;
            if (intra[i] == 0) continue;

            int64_t pred;
            if (y == 0) {
                pred = (x == 0) ? 128 : rec[i - 1];       /* origin / left (W) */
            } else if (x == 0) {
                pred = rec[i - W];                         /* above (N) */
            } else {
                const int64_t west = rec[i - 1];
                const int64_t north = rec[i - W];
                const int64_t nw = rec[i - W - 1];
                const int64_t ne = (x < W - 1) ? rec[i - W + 1] : 0;
                const int64_t ww = (x > 1) ? rec[i - 2] : 0;
                const int64_t nn = (y > 1) ? rec[i - 2 * W] : 0;

                const int64_t dh = llabs(west - ww) + llabs(north - nw) + llabs(north - ne);
                const int64_t dv = llabs(west - nw) + llabs(north - nn) + llabs(ne - nn);
                const int64_t base = ((west + north) >> 1) + ((ne - nw) >> 2);
                const int64_t grad = dv - dh;

                if (grad > t1) {
                    pred = west;
                } else if (grad < -t1) {
                    pred = north;
                } else if (grad > t2) {
                    pred = (base + west) >> 1;
                } else if (grad < -t2) {
                    pred = (base + north) >> 1;
                } else if (grad > t3) {
                    pred = (3 * base + west) >> 2;
                } else if (grad < -t3) {
                    pred = (3 * base + north) >> 2;
                } else {
                    pred = base;
                }
            }
            rec[i] = pred + residual[i];
        }
    }
}
516
+
517
+ /* --- CALIC GAP + context bias correction (sequential, encode & decode) ------
518
+ *
519
+ * On top of the GAP prediction, a per-context running mean prediction error
520
+ * (B[k]/C[k]) is subtracted, removing GAP's systematic bias in that context.
521
+ * The context k = energy_bin*64 + texture, where energy = dh+dv+2|e_west| (the
522
+ * causal west-pixel error) quantised into 11 bins, and texture is 6 neighbour
523
+ * sign bits vs the prediction. The bias state evolves identically on encode and
524
+ * decode, so one function serves both via ``mode`` (0 = encode: img -> res;
525
+ * 1 = decode: res -> img). All integer, byte-exact. NCTX = 11*64 = 704. */
526
+ #define CALIC_NCTX 704
527
+
528
/* B / C rounded to the nearest integer with halves rounded away from zero,
 * using integer division only; returns 0 when C <= 0. */
static int64_t calic_round(int64_t B, int64_t C) {
    if (C <= 0) return 0;
    const int64_t half = C / 2;
    if (B < 0) return -((half - B) / C);
    return (B + half) / C;
}
532
+
533
+ void calic_code(int64_t *img, int64_t *res, int mode, long H, long W, long scale) {
534
+ static const long base_th[10] = {1, 3, 6, 11, 18, 30, 50, 90, 160, 300};
535
+ long th[10];
536
+ for (int k = 0; k < 10; k++) th[k] = base_th[k] * scale;
537
+ int64_t *B = (int64_t *)calloc(CALIC_NCTX, sizeof(int64_t));
538
+ int64_t *C = (int64_t *)calloc(CALIC_NCTX, sizeof(int64_t));
539
+ if (!B || !C) { free(B); free(C); return; }
540
+
541
+ for (long y = 0; y < H; y++) {
542
+ int64_t e_left = 0;
543
+ for (long x = 0; x < W; x++) {
544
+ long i = y * W + x;
545
+ int64_t a = (x > 0) ? img[i - 1] : 0;
546
+ int64_t b = (y > 0) ? img[i - W] : 0;
547
+ int64_t nw = (x > 0 && y > 0) ? img[i - W - 1] : 0;
548
+ int64_t ne = (y > 0 && x < W - 1) ? img[i - W + 1] : 0;
549
+ int64_t ww = (x > 1) ? img[i - 2] : 0;
550
+ int64_t nn = (y > 1) ? img[i - 2 * W] : 0;
551
+
552
+ int64_t pred;
553
+ if (y == 0 && x == 0) pred = 128;
554
+ else if (y == 0) pred = a;
555
+ else if (x == 0) pred = b;
556
+ else {
557
+ int64_t base = ((a + b) >> 1) + ((ne - nw) >> 2);
558
+ int64_t dhp = llabs(a - ww) + llabs(b - nw) + llabs(b - ne);
559
+ int64_t dvp = llabs(a - nw) + llabs(b - nn) + llabs(ne - nn);
560
+ int64_t d = dvp - dhp, t1 = 80 * scale, t2 = 32 * scale, t3 = 8 * scale;
561
+ if (d > t1) pred = a;
562
+ else if (d < -t1) pred = b;
563
+ else if (d > t2) pred = (base + a) >> 1;
564
+ else if (d < -t2) pred = (base + b) >> 1;
565
+ else if (d > t3) pred = (3 * base + a) >> 2;
566
+ else if (d < -t3) pred = (3 * base + b) >> 2;
567
+ else pred = base;
568
+ }
569
+
570
+ int64_t dh = llabs(a - ww) + llabs(b - nw) + llabs(b - ne);
571
+ int64_t dv = llabs(a - nw) + llabs(b - nn) + llabs(ne - nn);
572
+ int64_t energy = dh + dv + 2 * llabs(e_left);
573
+ int delta = 0;
574
+ while (delta < 10 && energy >= th[delta]) delta++;
575
+ int tex = (a >= pred) | ((b >= pred) << 1) | ((nw >= pred) << 2)
576
+ | ((ne >= pred) << 3) | ((ww >= pred) << 4) | ((nn >= pred) << 5);
577
+ long k = (long)delta * 64 + tex;
578
+ int64_t corr = calic_round(B[k], C[k]);
579
+
580
+ int64_t e;
581
+ if (mode == 0) { /* encode: img known, write residual */
582
+ e = img[i] - pred;
583
+ res[i] = e - corr;
584
+ } else { /* decode: residual known, write img */
585
+ e = res[i] + corr;
586
+ img[i] = e + pred;
587
+ }
588
+ B[k] += e;
589
+ C[k] += 1;
590
+ if (C[k] >= 256) { B[k] >>= 1; C[k] >>= 1; }
591
+ e_left = e;
592
+ }
593
+ }
594
+ free(B);
595
+ free(C);
596
+ }
597
+
598
+ /* --- full CALIC codec: predict + bias + energy-conditional entropy coding ---
599
+ *
600
+ * The decisive image step (vs feeding CALIC residuals to the order-2 ctxcoder):
601
+ * the magnitude-bucket arithmetic model is selected by the local error ENERGY
602
+ * (dh+dv quantised to 12 bins) rather than the scan-order previous buckets. The
603
+ * energy is computed from already-reconstructed neighbours, so the model must be
604
+ * chosen inside the prediction loop — prediction, bias correction (the 704-context
605
+ * running mean), and energy-conditional coding are one integrated pass. Measured
606
+ * ~+2.6% (Bayer) / +1.1% (RGB) over the order-2 coder. Encode and decode below
607
+ * share the exact context/bias arithmetic, so they round-trip byte-exact. */
608
+ #define CALIC_NEBIN 12
609
+
610
/* Energy bin in 0..11: the index of the first of the eleven thresholds in
 * tent that exceeds energy, or 11 when none does. */
static int calic_ebin(int64_t energy, const long *tent) {
    for (int bin = 0; bin < 11; bin++) {
        if (energy < tent[bin]) return bin;
    }
    return 11;
}
615
+
616
+ long calic_codec_encode(const int64_t *img, long H, long W, long scale,
617
+ uint8_t *out, long cap) {
618
+ static const long bbase[10] = {1, 3, 6, 11, 18, 30, 50, 90, 160, 300};
619
+ static const long ebase[11] = {1, 3, 6, 11, 18, 30, 50, 90, 160, 300, 600};
620
+ long tbias[10], tent[11];
621
+ for (int i = 0; i < 10; i++) tbias[i] = bbase[i] * scale;
622
+ for (int i = 0; i < 11; i++) tent[i] = ebase[i] * scale;
623
+ long t1 = 80 * scale, t2 = 32 * scale, t3 = 8 * scale;
624
+ int64_t *B = (int64_t *)calloc(CALIC_NCTX, sizeof(int64_t));
625
+ int64_t *C = (int64_t *)calloc(CALIC_NCTX, sizeof(int64_t));
626
+ if (!B || !C) { free(B); free(C); return -1; }
627
+ int freq[CALIC_NEBIN][CTX_NB]; long tot[CALIC_NEBIN];
628
+ int mf[CALIC_NEBIN][CTX_NB][2]; long mt[CALIC_NEBIN][CTX_NB]; /* top mantissa bit | (ebin,k) */
629
+ for (int c = 0; c < CALIC_NEBIN; c++) {
630
+ for (int s = 0; s < CTX_NB; s++) {
631
+ freq[c][s] = 1; mf[c][s][0] = 1; mf[c][s][1] = 1; mt[c][s] = 2;
632
+ }
633
+ tot[c] = CTX_NB;
634
+ }
635
+ bitw w = { out, cap, 0, 0, 0, 0 };
636
+ aenc enc = { 0, AC_MAX, 0, &w };
637
+
638
+ for (long y = 0; y < H; y++) {
639
+ int64_t e_left = 0;
640
+ for (long x = 0; x < W; x++) {
641
+ long i = y * W + x;
642
+ int64_t a = (x > 0) ? img[i - 1] : 0;
643
+ int64_t b = (y > 0) ? img[i - W] : 0;
644
+ int64_t nw = (x > 0 && y > 0) ? img[i - W - 1] : 0;
645
+ int64_t ne = (y > 0 && x < W - 1) ? img[i - W + 1] : 0;
646
+ int64_t ww = (x > 1) ? img[i - 2] : 0;
647
+ int64_t nn = (y > 1) ? img[i - 2 * W] : 0;
648
+ int64_t dh = llabs(a - ww) + llabs(b - nw) + llabs(b - ne);
649
+ int64_t dv = llabs(a - nw) + llabs(b - nn) + llabs(ne - nn);
650
+ int64_t pred;
651
+ if (y == 0 && x == 0) pred = 128;
652
+ else if (y == 0) pred = a;
653
+ else if (x == 0) pred = b;
654
+ else {
655
+ int64_t base = ((a + b) >> 1) + ((ne - nw) >> 2), d = dv - dh;
656
+ if (d > t1) pred = a; else if (d < -t1) pred = b;
657
+ else if (d > t2) pred = (base + a) >> 1; else if (d < -t2) pred = (base + b) >> 1;
658
+ else if (d > t3) pred = (3 * base + a) >> 2; else if (d < -t3) pred = (3 * base + b) >> 2;
659
+ else pred = base;
660
+ }
661
+ int db = 0; int64_t ebias = dh + dv + 2 * llabs(e_left);
662
+ while (db < 10 && ebias >= tbias[db]) db++;
663
+ int tex = (a >= pred) | ((b >= pred) << 1) | ((nw >= pred) << 2)
664
+ | ((ne >= pred) << 3) | ((ww >= pred) << 4) | ((nn >= pred) << 5);
665
+ long kb = (long)db * 64 + tex;
666
+ int64_t corr = calic_round(B[kb], C[kb]);
667
+ int ebin = calic_ebin(dh + dv, tent);
668
+
669
+ int64_t e = img[i] - pred, r = e - corr;
670
+ uint64_t u = (((uint64_t)r) << 1) ^ (uint64_t)(r >> 63);
671
+ int k = u ? (64 - __builtin_clzll(u)) : 0;
672
+ int *f = freq[ebin];
673
+ uint64_t cum = 0;
674
+ for (int s = 0; s < k; s++) cum += (uint64_t)f[s];
675
+ ae_encode(&enc, cum, (uint64_t)f[k], (uint64_t)tot[ebin]);
676
+ if (k >= 2) {
677
+ uint64_t mant = u & ((1ULL << (k - 1)) - 1);
678
+ int b1 = (mant >> (k - 2)) & 1; /* top mantissa bit: modelled */
679
+ int *g = mf[ebin][k];
680
+ ae_encode(&enc, b1 == 0 ? 0 : (uint64_t)g[0], (uint64_t)g[b1], (uint64_t)mt[ebin][k]);
681
+ g[b1] += CTX_MINCR; mt[ebin][k] += CTX_MINCR;
682
+ if (mt[ebin][k] >= CTX_MRESCALE) {
683
+ g[0] = (g[0] + 1) >> 1; g[1] = (g[1] + 1) >> 1; mt[ebin][k] = g[0] + g[1];
684
+ }
685
+ for (int sh = k - 3; sh >= 0; sh--) ae_encode(&enc, (mant >> sh) & 1, 1, 2);
686
+ }
687
+ if (w.overflow) { free(B); free(C); return -1; }
688
+ ctx_bump(f, &tot[ebin], k);
689
+ B[kb] += e; C[kb] += 1;
690
+ if (C[kb] >= 256) { B[kb] >>= 1; C[kb] >>= 1; }
691
+ e_left = e;
692
+ }
693
+ }
694
+ enc.pending++;
695
+ ae_emit(&enc, enc.low < AC_QUARTER ? 0 : 1);
696
+ free(B); free(C);
697
+ if (w.overflow) return -1;
698
+ if (w.nbits > 0) { if (w.byte >= w.cap) return -1; w.out[w.byte++] = (uint8_t)(w.cur << (8 - w.nbits)); }
699
+ return w.byte;
700
+ }
701
+
702
+ void calic_codec_decode(const uint8_t *in, long len, int64_t *img,
703
+ long H, long W, long scale) {
704
+ static const long bbase[10] = {1, 3, 6, 11, 18, 30, 50, 90, 160, 300};
705
+ static const long ebase[11] = {1, 3, 6, 11, 18, 30, 50, 90, 160, 300, 600};
706
+ long tbias[10], tent[11];
707
+ for (int i = 0; i < 10; i++) tbias[i] = bbase[i] * scale;
708
+ for (int i = 0; i < 11; i++) tent[i] = ebase[i] * scale;
709
+ long t1 = 80 * scale, t2 = 32 * scale, t3 = 8 * scale;
710
+ int64_t *B = (int64_t *)calloc(CALIC_NCTX, sizeof(int64_t));
711
+ int64_t *C = (int64_t *)calloc(CALIC_NCTX, sizeof(int64_t));
712
+ if (!B || !C) { free(B); free(C); return; }
713
+ int freq[CALIC_NEBIN][CTX_NB]; long tot[CALIC_NEBIN];
714
+ int mf[CALIC_NEBIN][CTX_NB][2]; long mt[CALIC_NEBIN][CTX_NB]; /* top mantissa bit | (ebin,k) */
715
+ for (int c = 0; c < CALIC_NEBIN; c++) {
716
+ for (int s = 0; s < CTX_NB; s++) {
717
+ freq[c][s] = 1; mf[c][s][0] = 1; mf[c][s][1] = 1; mt[c][s] = 2;
718
+ }
719
+ tot[c] = CTX_NB;
720
+ }
721
+ adec d = { 0, AC_MAX, 0, in, len, 0 };
722
+ for (int i = 0; i < 32; i++) d.code = (d.code << 1) | (uint64_t)ad_bit(&d);
723
+
724
+ for (long y = 0; y < H; y++) {
725
+ int64_t e_left = 0;
726
+ for (long x = 0; x < W; x++) {
727
+ long i = y * W + x;
728
+ int64_t a = (x > 0) ? img[i - 1] : 0;
729
+ int64_t b = (y > 0) ? img[i - W] : 0;
730
+ int64_t nw = (x > 0 && y > 0) ? img[i - W - 1] : 0;
731
+ int64_t ne = (y > 0 && x < W - 1) ? img[i - W + 1] : 0;
732
+ int64_t ww = (x > 1) ? img[i - 2] : 0;
733
+ int64_t nn = (y > 1) ? img[i - 2 * W] : 0;
734
+ int64_t dh = llabs(a - ww) + llabs(b - nw) + llabs(b - ne);
735
+ int64_t dv = llabs(a - nw) + llabs(b - nn) + llabs(ne - nn);
736
+ int64_t pred;
737
+ if (y == 0 && x == 0) pred = 128;
738
+ else if (y == 0) pred = a;
739
+ else if (x == 0) pred = b;
740
+ else {
741
+ int64_t base = ((a + b) >> 1) + ((ne - nw) >> 2), dd = dv - dh;
742
+ if (dd > t1) pred = a; else if (dd < -t1) pred = b;
743
+ else if (dd > t2) pred = (base + a) >> 1; else if (dd < -t2) pred = (base + b) >> 1;
744
+ else if (dd > t3) pred = (3 * base + a) >> 2; else if (dd < -t3) pred = (3 * base + b) >> 2;
745
+ else pred = base;
746
+ }
747
+ int db = 0; int64_t ebias = dh + dv + 2 * llabs(e_left);
748
+ while (db < 10 && ebias >= tbias[db]) db++;
749
+ int tex = (a >= pred) | ((b >= pred) << 1) | ((nw >= pred) << 2)
750
+ | ((ne >= pred) << 3) | ((ww >= pred) << 4) | ((nn >= pred) << 5);
751
+ long kb = (long)db * 64 + tex;
752
+ int64_t corr = calic_round(B[kb], C[kb]);
753
+ int ebin = calic_ebin(dh + dv, tent);
754
+
755
+ int *f = freq[ebin];
756
+ uint64_t total = (uint64_t)tot[ebin];
757
+ uint64_t target = ad_target(&d, total);
758
+ uint64_t cum = 0; int k = 0;
759
+ while (cum + (uint64_t)f[k] <= target) { cum += (uint64_t)f[k]; k++; }
760
+ ad_update(&d, cum, (uint64_t)f[k], total);
761
+ uint64_t u;
762
+ if (k == 0) u = 0;
763
+ else if (k == 1) u = 1;
764
+ else {
765
+ int *g = mf[ebin][k]; long mtv = mt[ebin][k];
766
+ int b1 = (ad_target(&d, (uint64_t)mtv) >= (uint64_t)g[0]) ? 1 : 0;
767
+ ad_update(&d, b1 == 0 ? 0 : (uint64_t)g[0], (uint64_t)g[b1], (uint64_t)mtv);
768
+ g[b1] += CTX_MINCR; mt[ebin][k] += CTX_MINCR;
769
+ if (mt[ebin][k] >= CTX_MRESCALE) {
770
+ g[0] = (g[0] + 1) >> 1; g[1] = (g[1] + 1) >> 1; mt[ebin][k] = g[0] + g[1];
771
+ }
772
+ uint64_t low = 0;
773
+ for (int j = 0; j < k - 2; j++) {
774
+ int bit = (ad_target(&d, 2) >= 1) ? 1 : 0;
775
+ ad_update(&d, (uint64_t)bit, 1, 2);
776
+ low = (low << 1) | (uint64_t)bit;
777
+ }
778
+ u = (1ULL << (k - 1)) | ((uint64_t)b1 << (k - 2)) | low;
779
+ }
780
+ int64_t r = (int64_t)(u >> 1) ^ -(int64_t)(u & 1);
781
+ int64_t e = r + corr;
782
+ img[i] = e + pred;
783
+ ctx_bump(f, &tot[ebin], k);
784
+ B[kb] += e; C[kb] += 1;
785
+ if (C[kb] >= 256) { B[kb] >>= 1; C[kb] >>= 1; }
786
+ e_left = e;
787
+ }
788
+ }
789
+ free(B); free(C);
790
+ }
791
+
792
/* --- LZ match-finder forward pass (mirrors tokenizer.tokenize_optimal) ------
 *
 * Builds 3-byte hash chains over the combined buffer c[0..N) and, for each data
 * position p in [base, N), collects one candidate per distinct match length,
 * keeping the smallest distance for that length (chains run newest-first, so
 * the first time a length appears its distance is already minimal). Candidates
 * are emitted in first-appearance order, which is what the Python dict
 * preserves, so the downstream DP makes identical choices.
 *
 * The 3-byte key (min_match == 3) indexes a direct 2^24 table: exact, no hash
 * collisions, so a chain only ever contains true 3-byte matches.
 *
 * Parameters:
 *   c, N        combined buffer and its length; positions [0, base) are history
 *               that may be matched against but produce no output.
 *   window      largest allowed match distance.
 *   max_match   longest match length reported.
 *   max_chain   maximum number of chain entries examined per position.
 *   min_match   must be 3.
 *   out_off     CSR offsets, N-base+1 entries: candidates of position p are
 *               out_len/out_dist[out_off[p-base] .. out_off[p-base+1]).
 *   out_len, out_dist, cap   candidate arrays and their capacity.
 *
 * Returns the total number of candidates, -2 if min_match != 3, and -1 on any
 * other failure: allocation failure, more than `cap` candidates, or an
 * unusable range (N < 0, N > INT32_MAX, base outside [0, N]). On -1 the output
 * arrays may be partially written.
 */
#define HEAD_BITS 24
#define HEAD_SIZE (1 << HEAD_BITS)

/* Push position i onto the chain of its 3-byte key; positions too close to the
 * end of the buffer to own a full key get an empty chain link. */
static inline void lz_forward_link(const uint8_t *c, long N,
                                   int32_t *head, int32_t *prev, long i) {
    if (i + 3 <= N) {
        uint32_t k = ((uint32_t)c[i] << 16) | ((uint32_t)c[i + 1] << 8) | c[i + 2];
        prev[i] = head[k];
        head[k] = (int32_t)i;
    } else {
        prev[i] = -1;
    }
}

long lz_forward(const uint8_t *c, long N, long base, long window,
                int max_match, int max_chain, int min_match,
                int *out_off, int *out_len, int *out_dist, long cap) {
    if (min_match != 3) return -2;
    /* Positions are stored as int32_t and prev[] has exactly N slots, so an
     * out-of-range N or base would silently corrupt memory. */
    if (N < 0 || N > INT32_MAX || base < 0 || base > N) return -1;

    long result = -1;
    long total = 0;
    /* Scratch for the per-position found set: at most one entry per length. */
    const size_t fcap = (max_match > 0) ? (size_t)max_match + 1 : 1;
    /* Never request 0 bytes: malloc(0) may legally return NULL. */
    const size_t nprev = (N > 0) ? (size_t)N : 1;

    int32_t *head = malloc((size_t)HEAD_SIZE * sizeof *head);
    int32_t *prev = malloc(nprev * sizeof *prev);
    int *flen = malloc(fcap * sizeof *flen);
    int *fdist = malloc(fcap * sizeof *fdist);
    if (!head || !prev || !flen || !fdist) goto done;
    memset(head, 0xFF, (size_t)HEAD_SIZE * sizeof *head);   /* all -1 */

    /* History is searchable but produces no candidates of its own. */
    for (long i = 0; i < base; i++) lz_forward_link(c, N, head, prev, i);

    for (long p = base; p < N; p++) {
        out_off[p - base] = (int)total;
        int fc = 0;
        if (p + 3 <= N) {
            uint32_t key = ((uint32_t)c[p] << 16) | ((uint32_t)c[p + 1] << 8) | c[p + 2];
            long cand = head[key];
            int chain = max_chain;
            long limit = (max_match < N - p) ? max_match : (N - p);
            /* Chains are newest-first, so distances only grow: the first
             * candidate outside the window ends the walk. */
            while (cand != -1 && p - cand <= window && chain > 0) {
                long n = 0;
                while (n < limit && c[cand + n] == c[p + n]) n++;
                if (n >= min_match) {
                    int seen = 0;
                    for (int j = 0; j < fc; j++) {
                        if (flen[j] == (int)n) { seen = 1; break; }
                    }
                    if (!seen && (size_t)fc < fcap) {
                        flen[fc] = (int)n;
                        fdist[fc] = (int)(p - cand);
                        fc++;
                    }
                }
                cand = prev[cand];
                chain--;
            }
        }
        /* Offsets are stored as int, so the running total must fit as well. */
        if (total + fc > cap || total + fc > INT32_MAX) goto done;
        for (int j = 0; j < fc; j++) {
            out_len[total] = flen[j];
            out_dist[total] = fdist[j];
            total++;
        }
        lz_forward_link(c, N, head, prev, p);
    }
    out_off[N - base] = (int)total;
    result = total;

done:
    free(head);
    free(prev);
    free(flen);
    free(fdist);
    return result;
}
866
+
867
/* Greedy single-best match per position (mirrors tokenizer._find_lz): the
 * longest match wins, and because chains run newest-first and only a strictly
 * longer match replaces the current best, ties go to the smallest distance.
 * Used by the greedy/lazy parse (training). Integer-exact, so the produced
 * tokens are identical to the Python parse.
 *
 * Parameters:
 *   c, N        combined buffer and its length; positions [0, base) are history
 *               that may be matched against but produce no output.
 *   window      largest allowed match distance.
 *   max_match   longest match length reported.
 *   max_chain   maximum number of chain entries examined per position.
 *   min_match   must be 3.
 *   best_len, best_dist   N-base entries each; both are 0 where nothing
 *               qualifies.
 *
 * Returns 0 on success, -2 if min_match != 3, and -1 on allocation failure or
 * an unusable range (N < 0, N > INT32_MAX, base outside [0, N]).
 */

/* Push position i onto the chain of its 3-byte key; positions too close to the
 * end of the buffer to own a full key get an empty chain link. */
static inline void lz_best_link(const uint8_t *c, long N,
                                int32_t *head, int32_t *prev, long i) {
    if (i + 3 <= N) {
        uint32_t k = ((uint32_t)c[i] << 16) | ((uint32_t)c[i + 1] << 8) | c[i + 2];
        prev[i] = head[k];
        head[k] = (int32_t)i;
    } else {
        prev[i] = -1;
    }
}

long lz_best(const uint8_t *c, long N, long base, long window,
             int max_match, int max_chain, int min_match,
             int *best_len, int *best_dist) {
    if (min_match != 3) return -2;
    /* Positions are stored as int32_t and prev[] has exactly N slots, so an
     * out-of-range N or base would silently corrupt memory. */
    if (N < 0 || N > INT32_MAX || base < 0 || base > N) return -1;

    /* One slot per possible 3-byte key: the table is indexed directly by the
     * 24-bit key built below, so its size is fixed by the key width. */
    const size_t nkeys = (size_t)1 << 24;
    /* Never request 0 bytes: malloc(0) may legally return NULL. */
    const size_t nprev = (N > 0) ? (size_t)N : 1;

    int32_t *head = malloc(nkeys * sizeof *head);
    int32_t *prev = malloc(nprev * sizeof *prev);
    if (!head || !prev) { free(head); free(prev); return -1; }
    memset(head, 0xFF, nkeys * sizeof *head);   /* all -1 */

    /* History is searchable but produces no output of its own. */
    for (long i = 0; i < base; i++) lz_best_link(c, N, head, prev, i);

    for (long p = base; p < N; p++) {
        int bl = 0, bd = 0;
        if (p + 3 <= N) {
            uint32_t key = ((uint32_t)c[p] << 16) | ((uint32_t)c[p + 1] << 8) | c[p + 2];
            long cand = head[key];
            int chain = max_chain;
            long limit = (max_match < N - p) ? max_match : (N - p);
            /* Chains are newest-first, so distances only grow: the first
             * candidate outside the window ends the walk. */
            while (cand != -1 && p - cand <= window && chain > 0) {
                long n = 0;
                while (n < limit && c[cand + n] == c[p + n]) n++;
                /* n >= min_match only matters when max_match < min_match
                 * truncates the comparison; otherwise every chain entry
                 * already shares the 3-byte key. */
                if (n >= min_match && (int)n > bl) {
                    bl = (int)n;
                    bd = (int)(p - cand);
                    if (n == limit) break;      /* cannot be beaten */
                }
                cand = prev[cand];
                chain--;
            }
        }
        best_len[p - base] = bl;
        best_dist[p - base] = bd;
        lz_best_link(c, N, head, prev, p);
    }

    free(head);
    free(prev);
    return 0;
}
911
+
912
/* --- cost-optimal backward DP (mirrors tokenizer.tokenize_optimal's DP) ------
 *
 * Given the per-position LZ candidates (CSR off/clen/cdist), the per-position
 * dict match (dpid/dlen), and cost tables built from the model, compute the
 * minimum-cost parse of c[base..N) and walk it into tokens. All arithmetic is
 * on the exact double cost values supplied (lit_table[byte], dict_table[pid],
 * and mc_table[lslot*ND+dslot], since match cost depends only on the two
 * slots), evaluated in the order literal, dict, then candidates in array
 * order, with strict-< tie-breaking, so the first cheapest option wins.
 *
 * All per-position inputs (off, dpid, dlen) are indexed by p - base.
 *
 * Token encoding: kind 0 literal (aval = byte, bval = 0), kind 1 dict
 * (aval = pid, bval = 0), kind 2 match (aval = length, bval = distance).
 *
 * Preconditions (not checked here): every dict length and candidate length
 * stays within the buffer (p + length <= N), candidate lengths are >=
 * min_match and distances are >= 1 (the slot computation is undefined for 0),
 * and the out_* arrays hold at least N - base tokens.
 *
 * Returns the token count (0 for an empty range), or -1 on allocation failure
 * or negative base.
 */
long lz_dp(const uint8_t *c, long N, long base,
           const int *off, const int *clen, const int *cdist,
           const int *dpid, const int *dlen,
           const double *lit_table, const double *dict_table,
           const double *mc_table, int ND, int min_match,
           int *out_kind, int64_t *out_aval, int64_t *out_bval) {
    /* Best choice recorded for one position. */
    struct step { int kind, a, b; };

    if (base < 0) return -1;
    if (base >= N) return 0;            /* nothing to parse */

    /* Scratch covers only the parsed range; index pi means position base+pi. */
    const long span = N - base;
    double *cost = malloc(((size_t)span + 1) * sizeof *cost);   /* cost to end */
    struct step *pick = malloc((size_t)span * sizeof *pick);
    if (!cost || !pick) { free(cost); free(pick); return -1; }

    cost[span] = 0.0;
    for (long pi = span - 1; pi >= 0; pi--) {
        const uint8_t byte = c[base + pi];

        /* Literal is the baseline. */
        double best = lit_table[byte] + cost[pi + 1];
        int bk = 0, ba = byte, bb = 0;

        /* Dictionary match; its length is kept in b only to drive the walk. */
        int dl = dlen[pi];
        if (dl >= min_match) {
            double cc = dict_table[dpid[pi]] + cost[pi + dl];
            if (cc < best) { best = cc; bk = 1; ba = dpid[pi]; bb = dl; }
        }

        /* LZ candidates; slot = floor(log2(value)). */
        for (int idx = off[pi]; idx < off[pi + 1]; idx++) {
            int length = clen[idx], dist = cdist[idx];
            int lslot = 63 - __builtin_clzll((uint64_t)(length - min_match + 1));
            int dslot = 63 - __builtin_clzll((uint64_t)dist);
            double cc = mc_table[lslot * ND + dslot] + cost[pi + length];
            if (cc < best) { best = cc; bk = 2; ba = length; bb = dist; }
        }

        cost[pi] = best;
        pick[pi].kind = bk;
        pick[pi].a = ba;
        pick[pi].b = bb;
    }

    /* Follow the recorded choices from the start of the range. */
    long nt = 0;
    for (long pi = 0; pi < span; ) {
        const struct step s = pick[pi];
        out_kind[nt] = s.kind;
        out_aval[nt] = s.a;
        if (s.kind == 0)      { out_bval[nt] = 0;   pi += 1; }
        else if (s.kind == 1) { out_bval[nt] = 0;   pi += s.b; }
        else                  { out_bval[nt] = s.b; pi += s.a; }
        nt++;
    }

    free(cost);
    free(pick);
    return nt;
}
966
+
967
/* --- trained-dictionary longest-match per position (mirrors Dictionary.match) -
 *
 * For each position p in [base, N), find the longest dictionary pattern of at
 * least min_match bytes that is a prefix of c[p..N).
 *
 * Parameters:
 *   c, N, base   buffer, its length, and the first position to report.
 *   min_match    shortest pattern length accepted.
 *   pat_data, pat_off   patterns stored flat: pattern pid occupies
 *                pat_data[pat_off[pid] .. pat_off[pid+1]).
 *   npat         number of patterns (currently unused).
 *   bucket_off, bucket_pids   2-byte prefix index in CSR form: the patterns
 *                starting with bytes (hi, lo) are
 *                bucket_pids[bucket_off[key] .. bucket_off[key+1]) with
 *                key = hi << 8 | lo, so bucket_off needs 65537 entries.
 *                Pattern ids must be ordered longest-first within each key;
 *                both "first hit is the longest" and the early stop below
 *                rely on that ordering.
 *   out_pid, out_len   N-base entries each; receive (pid, length) of the
 *                match, or (-1, 0) when there is none.
 */
void dict_match_all(const uint8_t *c, long N, long base, int min_match,
                    const uint8_t *pat_data, const int *pat_off, int npat,
                    const int *bucket_off, const int *bucket_pids,
                    int *out_pid, int *out_len) {
    (void)npat;
    for (long p = base; p < N; p++) {
        const long pi = p - base;
        out_pid[pi] = -1;
        out_len[pi] = 0;

        const long room = N - p;        /* bytes left from p to the end */
        if (room < 2) continue;         /* no full 2-byte key here */

        const int key = ((int)c[p] << 8) | c[p + 1];
        const int end = bucket_off[key + 1];
        for (int idx = bucket_off[key]; idx < end; idx++) {
            const int pid = bucket_pids[idx];
            const int plen = pat_off[pid + 1] - pat_off[pid];
            /* Longest-first bucket: once a pattern is too short, every
             * remaining one is too, so stop instead of scanning them all. */
            if (plen < min_match) break;
            /* Too long for the remaining input; a shorter one may still fit. */
            if (plen > room) continue;
            if (memcmp(c + p, pat_data + pat_off[pid], (size_t)plen) == 0) {
                out_pid[pi] = pid;
                out_len[pi] = plen;
                break;
            }
        }
    }
}