multi_compress 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multi_compress might be problematic. Click here for more details.

@@ -1,5 +1,9 @@
1
1
  #include <ruby.h>
2
2
  #include <ruby/encoding.h>
3
+ #include <ruby/thread.h>
4
+ #ifdef HAVE_RUBY_FIBER_SCHEDULER_H
5
+ #include <ruby/fiber/scheduler.h>
6
+ #endif
3
7
  #include <zstd.h>
4
8
  #include <zdict.h>
5
9
  #include <lz4.h>
@@ -12,6 +16,9 @@
12
16
 
13
17
/* Upper bound (256 MiB) on decompressed output accepted by the zstd and lz4 decoders. */
#define MAX_DECOMPRESS_SIZE (256ULL << 20)

/* Inputs at or above this size (64 KiB) are compressed/decompressed with the GVL released. */
#define GVL_UNLOCK_THRESHOLD (1 << 16)
/* Bytes processed between cooperative yield points in fiber_maybe_yield() (64 KiB). */
#define FIBER_YIELD_CHUNK (1 << 16)
21
+
15
22
/* Top-level Ruby module and the extension's exception classes.
 * NOTE(review): presumably assigned in the extension's Init function, which is not shown here. */
static VALUE mMultiCompress;
static VALUE eError;     /* raised for general compression failures */
static VALUE eDataError; /* raised for malformed/corrupt input data */
@@ -27,6 +34,7 @@ static VALUE cDictionary;
27
34
  static VALUE mZstd;
28
35
  static VALUE mLZ4;
29
36
  static VALUE mBrotli;
37
+ static rb_encoding *binary_encoding;
30
38
 
31
39
  typedef enum { ALGO_ZSTD = 0, ALGO_LZ4 = 1, ALGO_BROTLI = 2 } compress_algo_t;
32
40
 
@@ -130,14 +138,75 @@ static compress_algo_t detect_algo(const uint8_t *data, size_t len) {
130
138
  return ALGO_ZSTD;
131
139
  }
132
140
 
141
+ static inline VALUE rb_binary_str_new(const char *ptr, long len) {
142
+ VALUE str = rb_str_new(ptr, len);
143
+ rb_enc_associate(str, binary_encoding);
144
+ return str;
145
+ }
146
+
147
+ static inline VALUE rb_binary_str_buf_new(long capa) {
148
+ VALUE str = rb_str_buf_new(capa);
149
+ rb_enc_associate(str, binary_encoding);
150
+ return str;
151
+ }
152
+
153
/* Return 1 when the current thread has a non-nil/non-false fiber scheduler, else 0.
 * Always 0 when the Ruby headers lack rb_fiber_scheduler_current(). */
static int has_fiber_scheduler(void) {
#ifndef HAVE_RB_FIBER_SCHEDULER_CURRENT
    return 0;
#else
    /* RTEST() is false exactly for Qnil and Qfalse. */
    return RTEST(rb_fiber_scheduler_current()) ? 1 : 0;
#endif
}
161
+
162
/* Unblocking function handed to rb_thread_call_without_gvl(). It deliberately does nothing,
 * so an interrupt cannot cut short the native work; that work runs until it returns. */
static void unblock_noop(void *unused) {
    (void)unused;
}
165
+
166
+ static inline void run_without_gvl(void *(*func)(void *), void *arg) {
167
+ rb_thread_call_without_gvl(func, arg, unblock_noop, NULL);
168
+ }
169
+
170
+ static inline size_t fiber_maybe_yield(size_t bytes_since_yield, size_t just_processed,
171
+ int *did_yield) {
172
+ *did_yield = 0;
173
+ bytes_since_yield += just_processed;
174
+ if (bytes_since_yield >= FIBER_YIELD_CHUNK) {
175
+ if (has_fiber_scheduler()) {
176
+ rb_thread_schedule();
177
+ *did_yield = 1;
178
+ }
179
+ return 0;
180
+ }
181
+ return bytes_since_yield;
182
+ }
183
+
184
+ #define DICT_CDICT_CACHE_SIZE 4
185
+
186
+ typedef struct {
187
+ int level;
188
+ ZSTD_CDict *cdict;
189
+ } cdict_cache_entry_t;
190
+
133
191
  typedef struct {
134
192
  compress_algo_t algo;
135
193
  uint8_t *data;
136
194
  size_t size;
195
+
196
+ cdict_cache_entry_t cdict_cache[DICT_CDICT_CACHE_SIZE];
197
+ int cdict_cache_count;
198
+
199
+ ZSTD_DDict *ddict;
137
200
  } dictionary_t;
138
201
 
139
202
  static void dict_free(void *ptr) {
140
203
  dictionary_t *dict = (dictionary_t *)ptr;
204
+ for (int i = 0; i < dict->cdict_cache_count; i++) {
205
+ if (dict->cdict_cache[i].cdict)
206
+ ZSTD_freeCDict(dict->cdict_cache[i].cdict);
207
+ }
208
+ if (dict->ddict)
209
+ ZSTD_freeDDict(dict->ddict);
141
210
  if (dict->data)
142
211
  xfree(dict->data);
143
212
  xfree(dict);
@@ -145,7 +214,15 @@ static void dict_free(void *ptr) {
145
214
 
146
215
  static size_t dict_memsize(const void *ptr) {
147
216
  const dictionary_t *d = (const dictionary_t *)ptr;
148
- return sizeof(dictionary_t) + d->size;
217
+ size_t total = sizeof(dictionary_t) + d->size;
218
+
219
+ for (int i = 0; i < d->cdict_cache_count; i++) {
220
+ if (d->cdict_cache[i].cdict)
221
+ total += d->size + 4096;
222
+ }
223
+ if (d->ddict)
224
+ total += d->size + 4096;
225
+ return total;
149
226
  }
150
227
 
151
228
  static const rb_data_type_t dictionary_type = {
@@ -157,6 +234,138 @@ static VALUE dict_alloc(VALUE klass) {
157
234
  return TypedData_Wrap_Struct(klass, &dictionary_type, d);
158
235
  }
159
236
 
237
+ static ZSTD_CDict *dict_get_cdict(dictionary_t *dict, int level) {
238
+ for (int i = 0; i < dict->cdict_cache_count; i++) {
239
+ if (dict->cdict_cache[i].level == level)
240
+ return dict->cdict_cache[i].cdict;
241
+ }
242
+
243
+ ZSTD_CDict *cdict = ZSTD_createCDict(dict->data, dict->size, level);
244
+ if (!cdict)
245
+ return NULL;
246
+
247
+ for (int i = 0; i < dict->cdict_cache_count; i++) {
248
+ if (dict->cdict_cache[i].level == level) {
249
+ ZSTD_freeCDict(cdict);
250
+ return dict->cdict_cache[i].cdict;
251
+ }
252
+ }
253
+
254
+ if (dict->cdict_cache_count < DICT_CDICT_CACHE_SIZE) {
255
+ dict->cdict_cache[dict->cdict_cache_count].level = level;
256
+ dict->cdict_cache[dict->cdict_cache_count].cdict = cdict;
257
+ dict->cdict_cache_count++;
258
+ } else {
259
+ ZSTD_CDict *old_cdict = dict->cdict_cache[0].cdict;
260
+ memmove(&dict->cdict_cache[0], &dict->cdict_cache[1],
261
+ sizeof(cdict_cache_entry_t) * (DICT_CDICT_CACHE_SIZE - 1));
262
+ dict->cdict_cache[DICT_CDICT_CACHE_SIZE - 1].level = level;
263
+ dict->cdict_cache[DICT_CDICT_CACHE_SIZE - 1].cdict = cdict;
264
+ if (old_cdict)
265
+ ZSTD_freeCDict(old_cdict);
266
+ }
267
+ return cdict;
268
+ }
269
+
270
+ static ZSTD_DDict *dict_get_ddict(dictionary_t *dict) {
271
+ if (!dict->ddict) {
272
+ dict->ddict = ZSTD_createDDict(dict->data, dict->size);
273
+ }
274
+ return dict->ddict;
275
+ }
276
+
277
+ typedef struct {
278
+ const char *src;
279
+ size_t src_len;
280
+ char *dst;
281
+ size_t dst_cap;
282
+ int level;
283
+ ZSTD_CDict *cdict;
284
+ size_t result;
285
+ int error;
286
+ } zstd_compress_args_t;
287
+
288
+ static void *zstd_compress_nogvl(void *arg) {
289
+ zstd_compress_args_t *a = (zstd_compress_args_t *)arg;
290
+ if (a->cdict) {
291
+ ZSTD_CCtx *cctx = ZSTD_createCCtx();
292
+ if (!cctx) {
293
+ a->error = 1;
294
+ return NULL;
295
+ }
296
+ a->result =
297
+ ZSTD_compress_usingCDict(cctx, a->dst, a->dst_cap, a->src, a->src_len, a->cdict);
298
+ ZSTD_freeCCtx(cctx);
299
+ } else {
300
+ a->result = ZSTD_compress(a->dst, a->dst_cap, a->src, a->src_len, a->level);
301
+ }
302
+ a->error = 0;
303
+ return NULL;
304
+ }
305
+
306
+ typedef struct {
307
+ const void *src;
308
+ size_t src_len;
309
+ void *dst;
310
+ size_t dst_cap;
311
+ ZSTD_DDict *ddict;
312
+ size_t result;
313
+ int error;
314
+ } zstd_decompress_args_t;
315
+
316
+ static void *zstd_decompress_nogvl(void *arg) {
317
+ zstd_decompress_args_t *a = (zstd_decompress_args_t *)arg;
318
+ if (a->ddict) {
319
+ ZSTD_DCtx *dctx = ZSTD_createDCtx();
320
+ if (!dctx) {
321
+ a->error = 1;
322
+ return NULL;
323
+ }
324
+ a->result =
325
+ ZSTD_decompress_usingDDict(dctx, a->dst, a->dst_cap, a->src, a->src_len, a->ddict);
326
+ ZSTD_freeDCtx(dctx);
327
+ } else {
328
+ a->result = ZSTD_decompress(a->dst, a->dst_cap, a->src, a->src_len);
329
+ }
330
+ a->error = 0;
331
+ return NULL;
332
+ }
333
+
334
+ typedef struct {
335
+ const char *src;
336
+ int src_len;
337
+ char *dst;
338
+ int dst_cap;
339
+ int level;
340
+ int result;
341
+ } lz4_compress_args_t;
342
+
343
+ static void *lz4_compress_nogvl(void *arg) {
344
+ lz4_compress_args_t *a = (lz4_compress_args_t *)arg;
345
+ if (a->level > 1) {
346
+ a->result = LZ4_compress_HC(a->src, a->dst, a->src_len, a->dst_cap, a->level);
347
+ } else {
348
+ a->result = LZ4_compress_default(a->src, a->dst, a->src_len, a->dst_cap);
349
+ }
350
+ return NULL;
351
+ }
352
+
353
+ typedef struct {
354
+ int level;
355
+ size_t src_len;
356
+ const uint8_t *src;
357
+ size_t *out_len;
358
+ uint8_t *dst;
359
+ BROTLI_BOOL result;
360
+ } brotli_compress_args_t;
361
+
362
+ static void *brotli_compress_nogvl(void *arg) {
363
+ brotli_compress_args_t *a = (brotli_compress_args_t *)arg;
364
+ a->result = BrotliEncoderCompress(a->level, BROTLI_DEFAULT_WINDOW, BROTLI_DEFAULT_MODE,
365
+ a->src_len, a->src, a->out_len, a->dst);
366
+ return NULL;
367
+ }
368
+
160
369
  static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
161
370
  VALUE data, opts;
162
371
  rb_scan_args(argc, argv, "1:", &data, &opts);
@@ -186,74 +395,166 @@ static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
186
395
  switch (algo) {
187
396
  case ALGO_ZSTD: {
188
397
  size_t bound = ZSTD_compressBound(slen);
189
- VALUE dst = rb_str_buf_new(bound);
190
398
 
191
399
  size_t csize;
192
- if (dict) {
193
- ZSTD_CCtx *cctx = ZSTD_createCCtx();
194
- if (!cctx)
400
+ if (slen >= GVL_UNLOCK_THRESHOLD) {
401
+ char *src_buf = xmalloc(slen);
402
+ memcpy(src_buf, src, slen);
403
+ char *dst_buf = xmalloc(bound);
404
+
405
+ ZSTD_CDict *cdict = NULL;
406
+ if (dict) {
407
+ cdict = dict_get_cdict(dict, level);
408
+ if (!cdict) {
409
+ xfree(src_buf);
410
+ xfree(dst_buf);
411
+ rb_raise(eMemError, "zstd: failed to create/get cdict");
412
+ }
413
+ }
414
+
415
+ zstd_compress_args_t args = {
416
+ .src = src_buf,
417
+ .src_len = slen,
418
+ .dst = dst_buf,
419
+ .dst_cap = bound,
420
+ .level = level,
421
+ .cdict = cdict,
422
+ };
423
+ run_without_gvl(zstd_compress_nogvl, &args);
424
+
425
+ if (args.error) {
426
+ xfree(src_buf);
427
+ xfree(dst_buf);
195
428
  rb_raise(eMemError, "zstd: failed to create context");
196
- ZSTD_CDict *cdict = ZSTD_createCDict(dict->data, dict->size, level);
197
- if (!cdict) {
198
- ZSTD_freeCCtx(cctx);
199
- rb_raise(eMemError, "zstd: failed to create cdict");
200
429
  }
201
- csize = ZSTD_compress_usingCDict(cctx, RSTRING_PTR(dst), bound, src, slen, cdict);
202
- ZSTD_freeCDict(cdict);
203
- ZSTD_freeCCtx(cctx);
430
+ csize = args.result;
431
+
432
+ if (ZSTD_isError(csize)) {
433
+ const char *err = ZSTD_getErrorName(csize);
434
+ xfree(src_buf);
435
+ xfree(dst_buf);
436
+ rb_raise(eError, "zstd compress: %s", err);
437
+ }
438
+
439
+ VALUE dst = rb_binary_str_new(dst_buf, (long)csize);
440
+ xfree(src_buf);
441
+ xfree(dst_buf);
442
+ RB_GC_GUARD(data);
443
+ return dst;
204
444
  } else {
205
- csize = ZSTD_compress(RSTRING_PTR(dst), bound, src, slen, level);
206
- }
445
+ VALUE dst = rb_binary_str_buf_new(bound);
446
+
447
+ if (dict) {
448
+ ZSTD_CDict *cdict = dict_get_cdict(dict, level);
449
+ if (!cdict)
450
+ rb_raise(eMemError, "zstd: failed to create/get cdict");
451
+ ZSTD_CCtx *cctx = ZSTD_createCCtx();
452
+ if (!cctx)
453
+ rb_raise(eMemError, "zstd: failed to create context");
454
+ csize = ZSTD_compress_usingCDict(cctx, RSTRING_PTR(dst), bound, src, slen, cdict);
455
+ ZSTD_freeCCtx(cctx);
456
+ } else {
457
+ csize = ZSTD_compress(RSTRING_PTR(dst), bound, src, slen, level);
458
+ }
207
459
 
208
- if (ZSTD_isError(csize)) {
209
- rb_raise(eError, "zstd compress: %s", ZSTD_getErrorName(csize));
460
+ if (ZSTD_isError(csize)) {
461
+ rb_raise(eError, "zstd compress: %s", ZSTD_getErrorName(csize));
462
+ }
463
+ rb_str_set_len(dst, csize);
464
+ RB_GC_GUARD(data);
465
+ return dst;
210
466
  }
211
- rb_str_set_len(dst, csize);
212
- return dst;
213
467
  }
214
468
  case ALGO_LZ4: {
215
469
  if (slen > (size_t)INT_MAX)
216
470
  rb_raise(eError, "lz4: input too large (max 2GB)");
217
471
  int bound = LZ4_compressBound((int)slen);
218
- VALUE dst = rb_str_buf_new(8 + bound + 4);
219
- char *out = RSTRING_PTR(dst);
220
-
221
- out[0] = (slen >> 0) & 0xFF;
222
- out[1] = (slen >> 8) & 0xFF;
223
- out[2] = (slen >> 16) & 0xFF;
224
- out[3] = (slen >> 24) & 0xFF;
225
472
 
226
473
  int csize;
227
- if (level > 1) {
228
- csize = LZ4_compress_HC(src, out + 8, (int)slen, bound, level);
474
+ if (slen >= GVL_UNLOCK_THRESHOLD) {
475
+ char *src_buf = xmalloc(slen);
476
+ memcpy(src_buf, src, slen);
477
+ char *dst_buf = xmalloc((size_t)bound);
478
+
479
+ lz4_compress_args_t args = {
480
+ .src = src_buf,
481
+ .src_len = (int)slen,
482
+ .dst = dst_buf,
483
+ .dst_cap = bound,
484
+ .level = level,
485
+ };
486
+ run_without_gvl(lz4_compress_nogvl, &args);
487
+ csize = args.result;
488
+
489
+ if (csize <= 0) {
490
+ xfree(src_buf);
491
+ xfree(dst_buf);
492
+ rb_raise(eError, "lz4 compress failed");
493
+ }
494
+
495
+ size_t total = 8 + (size_t)csize + 4;
496
+ VALUE dst = rb_binary_str_buf_new(total);
497
+ char *out = RSTRING_PTR(dst);
498
+
499
+ out[0] = (slen >> 0) & 0xFF;
500
+ out[1] = (slen >> 8) & 0xFF;
501
+ out[2] = (slen >> 16) & 0xFF;
502
+ out[3] = (slen >> 24) & 0xFF;
503
+ out[4] = (csize >> 0) & 0xFF;
504
+ out[5] = (csize >> 8) & 0xFF;
505
+ out[6] = (csize >> 16) & 0xFF;
506
+ out[7] = (csize >> 24) & 0xFF;
507
+ memcpy(out + 8, dst_buf, (size_t)csize);
508
+ out[8 + csize] = 0;
509
+ out[8 + csize + 1] = 0;
510
+ out[8 + csize + 2] = 0;
511
+ out[8 + csize + 3] = 0;
512
+
513
+ rb_str_set_len(dst, total);
514
+ xfree(src_buf);
515
+ xfree(dst_buf);
516
+ RB_GC_GUARD(data);
517
+ return dst;
229
518
  } else {
230
- csize = LZ4_compress_default(src, out + 8, (int)slen, bound);
231
- }
232
- if (csize <= 0) {
233
- rb_raise(eError, "lz4 compress failed");
519
+ VALUE dst = rb_binary_str_buf_new(8 + bound + 4);
520
+ char *out = RSTRING_PTR(dst);
521
+
522
+ out[0] = (slen >> 0) & 0xFF;
523
+ out[1] = (slen >> 8) & 0xFF;
524
+ out[2] = (slen >> 16) & 0xFF;
525
+ out[3] = (slen >> 24) & 0xFF;
526
+
527
+ if (level > 1) {
528
+ csize = LZ4_compress_HC(src, out + 8, (int)slen, bound, level);
529
+ } else {
530
+ csize = LZ4_compress_default(src, out + 8, (int)slen, bound);
531
+ }
532
+ if (csize <= 0)
533
+ rb_raise(eError, "lz4 compress failed");
534
+
535
+ out[4] = (csize >> 0) & 0xFF;
536
+ out[5] = (csize >> 8) & 0xFF;
537
+ out[6] = (csize >> 16) & 0xFF;
538
+ out[7] = (csize >> 24) & 0xFF;
539
+
540
+ size_t total = 8 + csize;
541
+ out[total] = 0;
542
+ out[total + 1] = 0;
543
+ out[total + 2] = 0;
544
+ out[total + 3] = 0;
545
+
546
+ rb_str_set_len(dst, total + 4);
547
+ RB_GC_GUARD(data);
548
+ return dst;
234
549
  }
235
-
236
- out[4] = (csize >> 0) & 0xFF;
237
- out[5] = (csize >> 8) & 0xFF;
238
- out[6] = (csize >> 16) & 0xFF;
239
- out[7] = (csize >> 24) & 0xFF;
240
-
241
- size_t total = 8 + csize;
242
- out[total] = 0;
243
- out[total + 1] = 0;
244
- out[total + 2] = 0;
245
- out[total + 3] = 0;
246
-
247
- rb_str_set_len(dst, total + 4);
248
- return dst;
249
550
  }
250
551
  case ALGO_BROTLI: {
251
552
  size_t out_len = BrotliEncoderMaxCompressedSize(slen);
252
553
  if (out_len == 0)
253
- out_len = slen + 1024;
254
- VALUE dst = rb_str_buf_new(out_len);
554
+ out_len = slen + (slen >> 2) + 1024;
255
555
 
256
556
  if (dict) {
557
+ VALUE dst = rb_binary_str_buf_new(out_len);
257
558
  BrotliEncoderState *enc = BrotliEncoderCreateInstance(NULL, NULL, NULL);
258
559
  if (!enc)
259
560
  rb_raise(eMemError, "brotli: failed to create encoder");
@@ -279,16 +580,45 @@ static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
279
580
  rb_raise(eError, "brotli compress with dict failed");
280
581
 
281
582
  rb_str_set_len(dst, initial_out - available_out);
583
+ RB_GC_GUARD(data);
584
+ return dst;
585
+ } else if (slen >= GVL_UNLOCK_THRESHOLD) {
586
+ uint8_t *src_buf = xmalloc(slen);
587
+ memcpy(src_buf, src, slen);
588
+ uint8_t *dst_buf = xmalloc(out_len);
589
+ size_t actual_out_len = out_len;
590
+
591
+ brotli_compress_args_t args = {
592
+ .level = level,
593
+ .src_len = slen,
594
+ .src = src_buf,
595
+ .out_len = &actual_out_len,
596
+ .dst = dst_buf,
597
+ };
598
+ run_without_gvl(brotli_compress_nogvl, &args);
599
+
600
+ if (!args.result) {
601
+ xfree(src_buf);
602
+ xfree(dst_buf);
603
+ rb_raise(eError, "brotli compress failed");
604
+ }
605
+
606
+ VALUE dst = rb_binary_str_new((const char *)dst_buf, (long)actual_out_len);
607
+ xfree(src_buf);
608
+ xfree(dst_buf);
609
+ RB_GC_GUARD(data);
610
+ return dst;
282
611
  } else {
612
+ VALUE dst = rb_binary_str_buf_new(out_len);
283
613
  BROTLI_BOOL ok =
284
614
  BrotliEncoderCompress(level, BROTLI_DEFAULT_WINDOW, BROTLI_DEFAULT_MODE, slen,
285
615
  (const uint8_t *)src, &out_len, (uint8_t *)RSTRING_PTR(dst));
286
-
287
616
  if (!ok)
288
617
  rb_raise(eError, "brotli compress failed");
289
618
  rb_str_set_len(dst, out_len);
619
+ RB_GC_GUARD(data);
620
+ return dst;
290
621
  }
291
- return dst;
292
622
  }
293
623
  }
294
624
 
@@ -331,28 +661,97 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
331
661
  rb_raise(eDataError, "zstd: not valid compressed data");
332
662
  }
333
663
 
664
+ if (frame_size != ZSTD_CONTENTSIZE_UNKNOWN && frame_size <= MAX_DECOMPRESS_SIZE) {
665
+ size_t dsize;
666
+
667
+ if (frame_size >= GVL_UNLOCK_THRESHOLD) {
668
+ char *src_buf = xmalloc(slen);
669
+ memcpy(src_buf, src, slen);
670
+ char *dst_buf = xmalloc((size_t)frame_size);
671
+
672
+ ZSTD_DDict *ddict = NULL;
673
+ if (dict) {
674
+ ddict = dict_get_ddict(dict);
675
+ if (!ddict) {
676
+ xfree(src_buf);
677
+ xfree(dst_buf);
678
+ rb_raise(eMemError, "zstd: failed to create ddict");
679
+ }
680
+ }
681
+
682
+ zstd_decompress_args_t args = {
683
+ .src = src_buf,
684
+ .src_len = slen,
685
+ .dst = dst_buf,
686
+ .dst_cap = (size_t)frame_size,
687
+ .ddict = ddict,
688
+ };
689
+ run_without_gvl(zstd_decompress_nogvl, &args);
690
+
691
+ if (args.error) {
692
+ xfree(src_buf);
693
+ xfree(dst_buf);
694
+ rb_raise(eMemError, "zstd: failed to create dctx");
695
+ }
696
+ dsize = args.result;
697
+
698
+ if (ZSTD_isError(dsize)) {
699
+ const char *err = ZSTD_getErrorName(dsize);
700
+ xfree(src_buf);
701
+ xfree(dst_buf);
702
+ rb_raise(eDataError, "zstd decompress: %s", err);
703
+ }
704
+
705
+ VALUE dst = rb_binary_str_new(dst_buf, (long)dsize);
706
+ xfree(src_buf);
707
+ xfree(dst_buf);
708
+ RB_GC_GUARD(data);
709
+ return dst;
710
+ } else {
711
+ VALUE dst = rb_binary_str_buf_new((size_t)frame_size);
712
+
713
+ if (dict) {
714
+ ZSTD_DDict *ddict = dict_get_ddict(dict);
715
+ if (!ddict)
716
+ rb_raise(eMemError, "zstd: failed to create ddict");
717
+ ZSTD_DCtx *dctx = ZSTD_createDCtx();
718
+ if (!dctx)
719
+ rb_raise(eMemError, "zstd: failed to create dctx");
720
+ dsize = ZSTD_decompress_usingDDict(dctx, RSTRING_PTR(dst), (size_t)frame_size,
721
+ src, slen, ddict);
722
+ ZSTD_freeDCtx(dctx);
723
+ } else {
724
+ dsize = ZSTD_decompress(RSTRING_PTR(dst), (size_t)frame_size, src, slen);
725
+ }
726
+
727
+ if (ZSTD_isError(dsize))
728
+ rb_raise(eDataError, "zstd decompress: %s", ZSTD_getErrorName(dsize));
729
+ rb_str_set_len(dst, dsize);
730
+ RB_GC_GUARD(data);
731
+ return dst;
732
+ }
733
+ }
734
+
334
735
  ZSTD_DCtx *dctx = ZSTD_createDCtx();
335
736
  if (!dctx)
336
737
  rb_raise(eMemError, "zstd: failed to create dctx");
337
738
 
338
739
  if (dict) {
339
- size_t r = ZSTD_DCtx_loadDictionary(dctx, dict->data, dict->size);
340
- if (ZSTD_isError(r)) {
341
- ZSTD_freeDCtx(dctx);
342
- rb_raise(eError, "zstd dict load: %s", ZSTD_getErrorName(r));
740
+ ZSTD_DDict *ddict = dict_get_ddict(dict);
741
+ if (ddict) {
742
+ size_t r = ZSTD_DCtx_refDDict(dctx, ddict);
743
+ if (ZSTD_isError(r)) {
744
+ ZSTD_freeDCtx(dctx);
745
+ rb_raise(eError, "zstd dict ref: %s", ZSTD_getErrorName(r));
746
+ }
343
747
  }
344
748
  }
345
749
 
346
- size_t alloc_size;
347
- if (frame_size != ZSTD_CONTENTSIZE_UNKNOWN && frame_size <= 256ULL * 1024 * 1024) {
348
- alloc_size = (size_t)frame_size;
349
- } else {
350
- alloc_size = (slen > MAX_DECOMPRESS_SIZE / 8) ? MAX_DECOMPRESS_SIZE : slen * 8;
351
- if (alloc_size < 4096)
352
- alloc_size = 4096;
353
- }
750
+ size_t alloc_size = (slen > MAX_DECOMPRESS_SIZE / 8) ? MAX_DECOMPRESS_SIZE : slen * 8;
751
+ if (alloc_size < 4096)
752
+ alloc_size = 4096;
354
753
 
355
- VALUE dst = rb_str_buf_new(alloc_size);
754
+ VALUE dst = rb_binary_str_buf_new(alloc_size);
356
755
  size_t total_out = 0;
357
756
 
358
757
  ZSTD_inBuffer input = {src, slen, 0};
@@ -382,13 +781,39 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
382
781
 
383
782
  ZSTD_freeDCtx(dctx);
384
783
  rb_str_set_len(dst, total_out);
784
+ RB_GC_GUARD(data);
385
785
  return dst;
386
786
  }
387
787
  case ALGO_LZ4: {
388
788
  if (slen < 4)
389
789
  rb_raise(eDataError, "lz4: data too short");
390
790
 
391
- VALUE result = rb_str_buf_new(0);
791
+ size_t total_orig = 0;
792
+ size_t scan_pos = 0;
793
+ while (scan_pos + 4 <= slen) {
794
+ uint32_t orig_size = (uint32_t)src[scan_pos] | ((uint32_t)src[scan_pos + 1] << 8) |
795
+ ((uint32_t)src[scan_pos + 2] << 16) |
796
+ ((uint32_t)src[scan_pos + 3] << 24);
797
+ if (orig_size == 0)
798
+ break;
799
+ if (scan_pos + 8 > slen)
800
+ rb_raise(eDataError, "lz4: truncated block header");
801
+ uint32_t comp_size = (uint32_t)src[scan_pos + 4] | ((uint32_t)src[scan_pos + 5] << 8) |
802
+ ((uint32_t)src[scan_pos + 6] << 16) |
803
+ ((uint32_t)src[scan_pos + 7] << 24);
804
+ if (scan_pos + 8 + comp_size > slen)
805
+ rb_raise(eDataError, "lz4: truncated block data");
806
+ if (orig_size > 256 * 1024 * 1024)
807
+ rb_raise(eDataError, "lz4: block too large (%u)", orig_size);
808
+ total_orig += orig_size;
809
+ if (total_orig > MAX_DECOMPRESS_SIZE)
810
+ rb_raise(eDataError, "lz4: total decompressed size exceeds limit");
811
+ scan_pos += 8 + comp_size;
812
+ }
813
+
814
+ VALUE result = rb_binary_str_buf_new(total_orig);
815
+ char *out_ptr = RSTRING_PTR(result);
816
+ size_t out_offset = 0;
392
817
  size_t pos = 0;
393
818
 
394
819
  while (pos + 4 <= slen) {
@@ -396,29 +821,20 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
396
821
  ((uint32_t)src[pos + 2] << 16) | ((uint32_t)src[pos + 3] << 24);
397
822
  if (orig_size == 0)
398
823
  break;
399
-
400
- if (pos + 8 > slen)
401
- rb_raise(eDataError, "lz4: truncated block header");
402
-
403
824
  uint32_t comp_size = (uint32_t)src[pos + 4] | ((uint32_t)src[pos + 5] << 8) |
404
825
  ((uint32_t)src[pos + 6] << 16) | ((uint32_t)src[pos + 7] << 24);
405
826
 
406
- if (pos + 8 + comp_size > slen)
407
- rb_raise(eDataError, "lz4: truncated block data");
408
- if (orig_size > 256 * 1024 * 1024)
409
- rb_raise(eDataError, "lz4: block too large (%u)", orig_size);
410
-
411
- VALUE block = rb_str_buf_new(orig_size);
412
- int dsize = LZ4_decompress_safe((const char *)(src + pos + 8), RSTRING_PTR(block),
827
+ int dsize = LZ4_decompress_safe((const char *)(src + pos + 8), out_ptr + out_offset,
413
828
  (int)comp_size, (int)orig_size);
414
829
  if (dsize < 0)
415
830
  rb_raise(eDataError, "lz4 decompress failed");
416
831
 
417
- rb_str_set_len(block, dsize);
418
- rb_str_cat(result, RSTRING_PTR(block), dsize);
832
+ out_offset += dsize;
419
833
  pos += 8 + comp_size;
420
834
  }
421
835
 
836
+ rb_str_set_len(result, out_offset);
837
+ RB_GC_GUARD(data);
422
838
  return result;
423
839
  }
424
840
  case ALGO_BROTLI: {
@@ -435,7 +851,7 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
435
851
  dict->data);
436
852
  }
437
853
 
438
- VALUE dst = rb_str_buf_new(alloc_size);
854
+ VALUE dst = rb_binary_str_buf_new(alloc_size);
439
855
  size_t total_out = 0;
440
856
 
441
857
  size_t available_in = slen;
@@ -470,6 +886,7 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
470
886
  rb_raise(eDataError, "brotli decompress failed");
471
887
  }
472
888
  rb_str_set_len(dst, total_out);
889
+ RB_GC_GUARD(data);
473
890
  return dst;
474
891
  }
475
892
  }
@@ -477,39 +894,56 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
477
894
  return Qnil;
478
895
  }
479
896
 
480
- static const uint32_t crc32_table[256] = {
481
- 0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3,
482
- 0x0EDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91,
483
- 0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7,
484
- 0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5,
485
- 0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B,
486
- 0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59,
487
- 0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F,
488
- 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D,
489
- 0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433,
490
- 0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01,
491
- 0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457,
492
- 0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65,
493
- 0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB,
494
- 0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9,
495
- 0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F,
496
- 0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, 0xB7BD5C3B, 0xC0BA6CAD,
497
- 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683,
498
- 0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1,
499
- 0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7,
500
- 0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5,
501
- 0xD6D6A3E8, 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B,
502
- 0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79,
503
- 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F,
504
- 0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D,
505
- 0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713,
506
- 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21,
507
- 0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777,
508
- 0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45,
509
- 0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB,
510
- 0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9,
511
- 0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF,
512
- 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D};
897
/*
 * CRC-32 (reflected polynomial 0xEDB88320, as used by zlib/gzip), computed with the
 * slicing-by-8 technique.
 *
 * crc32_tables[0] is the classic byte-at-a-time table. crc32_tables[t][i] is the register
 * value after feeding byte i and then t zero bytes into a zeroed register. This lets
 * crc32_compute() fold 8 input bytes per iteration.
 */
static uint32_t crc32_tables[8][256];
static int crc32_tables_initialized = 0;

/* Fill crc32_tables; subsequent calls are no-ops. The flag is a plain int, so concurrent first
 * calls must be excluded by the caller (visible callers hold the GVL). */
static void crc32_init_tables(void) {
    if (crc32_tables_initialized)
        return;

    for (uint32_t i = 0; i < 256; i++) {
        uint32_t crc = i;
        for (int j = 0; j < 8; j++) {
            /* Shift right; XOR in the polynomial when the bit shifted out was 1. The mask is
             * all-ones or zero, built with unsigned arithmetic (no signed negation). */
            crc = (crc >> 1) ^ (0xEDB88320u & (0u - (crc & 1u)));
        }
        crc32_tables[0][i] = crc;
    }

    for (uint32_t i = 0; i < 256; i++) {
        uint32_t crc = crc32_tables[0][i];
        for (int t = 1; t < 8; t++) {
            crc = crc32_tables[0][crc & 0xFF] ^ (crc >> 8);
            crc32_tables[t][i] = crc;
        }
    }

    crc32_tables_initialized = 1;
}

/*
 * Update a CRC-32 with len bytes from data.
 *
 * @param data  input bytes (may be NULL only when len == 0)
 * @param len   number of bytes to process
 * @param crc   CRC returned by a previous call, or 0 to start a new checksum
 * @return      the updated CRC. Chaining calls over consecutive pieces gives the same value as
 *              one call over the whole input.
 *
 * Initializes the lookup tables on first use. Otherwise a call made before
 * crc32_init_tables() would read all-zero tables and silently return a wrong checksum.
 */
static uint32_t crc32_compute(const uint8_t *data, size_t len, uint32_t crc) {
    if (!crc32_tables_initialized)
        crc32_init_tables();

    crc = ~crc;

    while (len >= 8) {
        /* Bytes are assembled little-endian explicitly, so this is endian- and alignment-safe. */
        uint32_t val0 = crc ^ ((uint32_t)data[0] | ((uint32_t)data[1] << 8) |
                               ((uint32_t)data[2] << 16) | ((uint32_t)data[3] << 24));
        uint32_t val1 = (uint32_t)data[4] | ((uint32_t)data[5] << 8) | ((uint32_t)data[6] << 16) |
                        ((uint32_t)data[7] << 24);

        crc = crc32_tables[7][val0 & 0xFF] ^ crc32_tables[6][(val0 >> 8) & 0xFF] ^
              crc32_tables[5][(val0 >> 16) & 0xFF] ^ crc32_tables[4][(val0 >> 24) & 0xFF] ^
              crc32_tables[3][val1 & 0xFF] ^ crc32_tables[2][(val1 >> 8) & 0xFF] ^
              crc32_tables[1][(val1 >> 16) & 0xFF] ^ crc32_tables[0][(val1 >> 24) & 0xFF];

        data += 8;
        len -= 8;
    }

    while (len--) {
        crc = crc32_tables[0][(crc ^ *data++) & 0xFF] ^ (crc >> 8);
    }

    return ~crc;
}
513
947
 
514
948
  static VALUE compress_crc32(int argc, VALUE *argv, VALUE self) {
515
949
  VALUE data, prev;
@@ -520,13 +954,7 @@ static VALUE compress_crc32(int argc, VALUE *argv, VALUE self) {
520
954
  size_t len = RSTRING_LEN(data);
521
955
  uint32_t crc = NIL_P(prev) ? 0 : NUM2UINT(prev);
522
956
 
523
- crc = ~crc;
524
- for (size_t i = 0; i < len; i++) {
525
- crc = crc32_table[(crc ^ src[i]) & 0xFF] ^ (crc >> 8);
526
- }
527
- crc = ~crc;
528
-
529
- return UINT2NUM(crc);
957
+ return UINT2NUM(crc32_compute(src, len, crc));
530
958
  }
531
959
 
532
960
  static VALUE compress_adler32(int argc, VALUE *argv, VALUE self) {
@@ -584,6 +1012,9 @@ static VALUE compress_version(VALUE self, VALUE algo_sym) {
584
1012
  return Qnil;
585
1013
  }
586
1014
 
1015
/* LZ4 streaming ring buffer. A deflater allocates LZ4_RING_BUFFER_TOTAL bytes, which is twice
 * LZ4_RING_BUFFER_SIZE (64 KiB). */
#define LZ4_RING_BUFFER_SIZE (1 << 16)
#define LZ4_RING_BUFFER_TOTAL (2 * LZ4_RING_BUFFER_SIZE)
1017
+
587
1018
  typedef struct {
588
1019
  compress_algo_t algo;
589
1020
  int level;
@@ -598,9 +1029,9 @@ typedef struct {
598
1029
 
599
1030
  struct {
600
1031
  char *buf;
601
- size_t len;
602
- size_t cap;
603
- } lz4_buf;
1032
+ size_t ring_offset;
1033
+ size_t pending;
1034
+ } lz4_ring;
604
1035
  } deflater_t;
605
1036
 
606
1037
  static void deflater_free(void *ptr) {
@@ -621,14 +1052,17 @@ static void deflater_free(void *ptr) {
621
1052
  break;
622
1053
  }
623
1054
  }
624
- if (d->lz4_buf.buf)
625
- xfree(d->lz4_buf.buf);
1055
+ if (d->lz4_ring.buf)
1056
+ xfree(d->lz4_ring.buf);
626
1057
  xfree(d);
627
1058
  }
628
1059
 
629
1060
  static size_t deflater_memsize(const void *ptr) {
630
1061
  const deflater_t *d = (const deflater_t *)ptr;
631
- return sizeof(deflater_t) + d->lz4_buf.cap;
1062
+ size_t s = sizeof(deflater_t);
1063
+ if (d->lz4_ring.buf)
1064
+ s += LZ4_RING_BUFFER_TOTAL;
1065
+ return s;
632
1066
  }
633
1067
 
634
1068
  static const rb_data_type_t deflater_type = {"Compress::Deflater",
@@ -708,9 +1142,9 @@ static VALUE deflater_initialize(int argc, VALUE *argv, VALUE self) {
708
1142
  if (!d->ctx.lz4)
709
1143
  rb_raise(eMemError, "lz4: failed to create stream");
710
1144
  LZ4_resetStream(d->ctx.lz4);
711
- d->lz4_buf.cap = 64 * 1024;
712
- d->lz4_buf.buf = ALLOC_N(char, d->lz4_buf.cap);
713
- d->lz4_buf.len = 0;
1145
+ d->lz4_ring.buf = ALLOC_N(char, LZ4_RING_BUFFER_TOTAL);
1146
+ d->lz4_ring.ring_offset = 0;
1147
+ d->lz4_ring.pending = 0;
714
1148
  break;
715
1149
  }
716
1150
  }
@@ -718,16 +1152,15 @@ static VALUE deflater_initialize(int argc, VALUE *argv, VALUE self) {
718
1152
  return self;
719
1153
  }
720
1154
 
721
- static VALUE lz4_compress_block(deflater_t *d) {
722
- if (d->lz4_buf.len == 0)
723
- return rb_str_new("", 0);
1155
+ static VALUE lz4_compress_ring_block(deflater_t *d) {
1156
+ if (d->lz4_ring.pending == 0)
1157
+ return rb_binary_str_new("", 0);
724
1158
 
725
- if (d->lz4_buf.len > (size_t)INT_MAX)
726
- rb_raise(eError, "lz4: block too large (max 2GB)");
727
- int src_size = (int)d->lz4_buf.len;
1159
+ char *block_start = d->lz4_ring.buf + d->lz4_ring.ring_offset - d->lz4_ring.pending;
1160
+ int src_size = (int)d->lz4_ring.pending;
728
1161
  int bound = LZ4_compressBound(src_size);
729
1162
 
730
- VALUE output = rb_str_buf_new(8 + bound);
1163
+ VALUE output = rb_binary_str_buf_new(8 + bound);
731
1164
  char *out = RSTRING_PTR(output);
732
1165
 
733
1166
  out[0] = (src_size >> 0) & 0xFF;
@@ -735,8 +1168,7 @@ static VALUE lz4_compress_block(deflater_t *d) {
735
1168
  out[2] = (src_size >> 16) & 0xFF;
736
1169
  out[3] = (src_size >> 24) & 0xFF;
737
1170
 
738
- int csize = LZ4_compress_fast_continue(d->ctx.lz4, d->lz4_buf.buf, out + 8, src_size, bound, 1);
739
-
1171
+ int csize = LZ4_compress_fast_continue(d->ctx.lz4, block_start, out + 8, src_size, bound, 1);
740
1172
  if (csize <= 0)
741
1173
  rb_raise(eError, "lz4 stream compress block failed");
742
1174
 
@@ -746,7 +1178,12 @@ static VALUE lz4_compress_block(deflater_t *d) {
746
1178
  out[7] = (csize >> 24) & 0xFF;
747
1179
 
748
1180
  rb_str_set_len(output, 8 + csize);
749
- d->lz4_buf.len = 0;
1181
+ d->lz4_ring.pending = 0;
1182
+
1183
+ if (d->lz4_ring.ring_offset >= LZ4_RING_BUFFER_SIZE) {
1184
+ d->lz4_ring.ring_offset = 0;
1185
+ }
1186
+
750
1187
  return output;
751
1188
  }
752
1189
 
@@ -762,29 +1199,49 @@ static VALUE deflater_write(VALUE self, VALUE chunk) {
762
1199
  const char *src = RSTRING_PTR(chunk);
763
1200
  size_t slen = RSTRING_LEN(chunk);
764
1201
  if (slen == 0)
765
- return rb_str_new("", 0);
1202
+ return rb_binary_str_new("", 0);
766
1203
 
767
1204
  switch (d->algo) {
768
1205
  case ALGO_ZSTD: {
769
1206
  ZSTD_inBuffer input = {src, slen, 0};
770
1207
  size_t out_cap = ZSTD_CStreamOutSize();
771
- VALUE result = rb_str_buf_new(0);
1208
+ size_t result_cap = out_cap > slen ? out_cap : slen;
1209
+ VALUE result = rb_binary_str_buf_new(result_cap);
1210
+ size_t result_len = 0;
1211
+ int use_fiber = has_fiber_scheduler();
1212
+ size_t fiber_counter = 0;
772
1213
 
773
1214
  while (input.pos < input.size) {
774
- VALUE buf = rb_str_buf_new(out_cap);
775
- ZSTD_outBuffer output = {RSTRING_PTR(buf), out_cap, 0};
1215
+ if (result_len + out_cap > result_cap) {
1216
+ result_cap = result_cap * 2;
1217
+ rb_str_resize(result, result_cap);
1218
+ }
1219
+
1220
+ ZSTD_outBuffer output = {RSTRING_PTR(result) + result_len, out_cap, 0};
776
1221
  size_t ret = ZSTD_compressStream(d->ctx.zstd, &output, &input);
777
1222
  if (ZSTD_isError(ret))
778
1223
  rb_raise(eError, "zstd compress stream: %s", ZSTD_getErrorName(ret));
779
- if (output.pos > 0)
780
- rb_str_cat(result, RSTRING_PTR(buf), output.pos);
1224
+ result_len += output.pos;
1225
+ if (use_fiber) {
1226
+ int did_yield = 0;
1227
+ fiber_counter = fiber_maybe_yield(fiber_counter, output.pos, &did_yield);
1228
+ (void)did_yield;
1229
+ }
781
1230
  }
1231
+ rb_str_set_len(result, result_len);
1232
+ RB_GC_GUARD(chunk);
782
1233
  return result;
783
1234
  }
784
1235
  case ALGO_BROTLI: {
785
1236
  size_t available_in = slen;
786
1237
  const uint8_t *next_in = (const uint8_t *)src;
787
- VALUE result = rb_str_buf_new(0);
1238
+ size_t result_cap = slen;
1239
+ if (result_cap < 1024)
1240
+ result_cap = 1024;
1241
+ VALUE result = rb_binary_str_buf_new(result_cap);
1242
+ size_t result_len = 0;
1243
+ int use_fiber = has_fiber_scheduler();
1244
+ size_t fiber_counter = 0;
788
1245
 
789
1246
  while (available_in > 0 || BrotliEncoderHasMoreOutput(d->ctx.brotli)) {
790
1247
  size_t available_out = 0;
@@ -798,29 +1255,65 @@ static VALUE deflater_write(VALUE self, VALUE chunk) {
798
1255
  const uint8_t *out_data;
799
1256
  size_t out_size = 0;
800
1257
  out_data = BrotliEncoderTakeOutput(d->ctx.brotli, &out_size);
801
- if (out_size > 0)
802
- rb_str_cat(result, (const char *)out_data, out_size);
1258
+ if (out_size > 0) {
1259
+ if (result_len + out_size > result_cap) {
1260
+ result_cap = (result_len + out_size) * 2;
1261
+ rb_str_resize(result, result_cap);
1262
+ }
1263
+
1264
+ memcpy(RSTRING_PTR(result) + result_len, out_data, out_size);
1265
+ result_len += out_size;
1266
+ if (use_fiber) {
1267
+ int did_yield = 0;
1268
+ fiber_counter = fiber_maybe_yield(fiber_counter, out_size, &did_yield);
1269
+ (void)did_yield;
1270
+ }
1271
+ }
803
1272
  }
1273
+ rb_str_set_len(result, result_len);
1274
+ RB_GC_GUARD(chunk);
804
1275
  return result;
805
1276
  }
806
1277
  case ALGO_LZ4: {
807
- VALUE result = rb_str_buf_new(0);
1278
+ VALUE result = rb_binary_str_buf_new(0);
1279
+ size_t result_len = 0;
1280
+ size_t result_cap = 0;
1281
+
808
1282
  while (slen > 0) {
809
- size_t space = d->lz4_buf.cap - d->lz4_buf.len;
1283
+ size_t space = LZ4_RING_BUFFER_SIZE - d->lz4_ring.pending;
810
1284
  size_t copy = slen < space ? slen : space;
811
- memcpy(d->lz4_buf.buf + d->lz4_buf.len, src, copy);
812
- d->lz4_buf.len += copy;
1285
+
1286
+ if (d->lz4_ring.ring_offset + copy > LZ4_RING_BUFFER_TOTAL) {
1287
+ rb_raise(eError, "lz4: ring buffer overflow");
1288
+ }
1289
+
1290
+ memcpy(d->lz4_ring.buf + d->lz4_ring.ring_offset, src, copy);
1291
+ d->lz4_ring.ring_offset += copy;
1292
+ d->lz4_ring.pending += copy;
813
1293
  src += copy;
814
1294
  slen -= copy;
815
- if (d->lz4_buf.len >= d->lz4_buf.cap) {
816
- VALUE block = lz4_compress_block(d);
817
- rb_str_cat(result, RSTRING_PTR(block), RSTRING_LEN(block));
1295
+
1296
+ if (d->lz4_ring.pending >= (size_t)LZ4_RING_BUFFER_SIZE) {
1297
+ VALUE block = lz4_compress_ring_block(d);
1298
+ size_t blen = RSTRING_LEN(block);
1299
+ if (blen > 0) {
1300
+ if (result_len + blen > result_cap) {
1301
+ result_cap = (result_len + blen) * 2;
1302
+ if (result_cap < 256)
1303
+ result_cap = 256;
1304
+ rb_str_resize(result, result_cap);
1305
+ }
1306
+ memcpy(RSTRING_PTR(result) + result_len, RSTRING_PTR(block), blen);
1307
+ result_len += blen;
1308
+ }
818
1309
  }
819
1310
  }
1311
+ rb_str_set_len(result, result_len);
1312
+ RB_GC_GUARD(chunk);
820
1313
  return result;
821
1314
  }
822
1315
  }
823
- return rb_str_new("", 0);
1316
+ return rb_binary_str_new("", 0);
824
1317
  }
825
1318
 
826
1319
  static VALUE deflater_flush(VALUE self) {
@@ -834,23 +1327,34 @@ static VALUE deflater_flush(VALUE self) {
834
1327
  switch (d->algo) {
835
1328
  case ALGO_ZSTD: {
836
1329
  size_t out_cap = ZSTD_CStreamOutSize();
837
- VALUE result = rb_str_buf_new(0);
1330
+ size_t result_cap = out_cap;
1331
+ VALUE result = rb_binary_str_buf_new(result_cap);
1332
+ size_t result_len = 0;
838
1333
  size_t ret;
1334
+
839
1335
  do {
840
- VALUE buf = rb_str_buf_new(out_cap);
841
- ZSTD_outBuffer output = {RSTRING_PTR(buf), out_cap, 0};
1336
+ if (result_len + out_cap > result_cap) {
1337
+ result_cap *= 2;
1338
+ rb_str_resize(result, result_cap);
1339
+ }
1340
+
1341
+ ZSTD_outBuffer output = {RSTRING_PTR(result) + result_len, out_cap, 0};
842
1342
  ret = ZSTD_flushStream(d->ctx.zstd, &output);
843
1343
  if (ZSTD_isError(ret))
844
1344
  rb_raise(eError, "zstd flush: %s", ZSTD_getErrorName(ret));
845
- if (output.pos > 0)
846
- rb_str_cat(result, RSTRING_PTR(buf), output.pos);
1345
+ result_len += output.pos;
847
1346
  } while (ret > 0);
1347
+
1348
+ rb_str_set_len(result, result_len);
848
1349
  return result;
849
1350
  }
850
1351
  case ALGO_BROTLI: {
851
1352
  size_t available_in = 0;
852
1353
  const uint8_t *next_in = NULL;
853
- VALUE result = rb_str_buf_new(0);
1354
+ size_t result_cap = 1024;
1355
+ VALUE result = rb_binary_str_buf_new(result_cap);
1356
+ size_t result_len = 0;
1357
+
854
1358
  do {
855
1359
  size_t available_out = 0;
856
1360
  uint8_t *next_out = NULL;
@@ -862,15 +1366,24 @@ static VALUE deflater_flush(VALUE self) {
862
1366
  const uint8_t *out_data;
863
1367
  size_t out_size = 0;
864
1368
  out_data = BrotliEncoderTakeOutput(d->ctx.brotli, &out_size);
865
- if (out_size > 0)
866
- rb_str_cat(result, (const char *)out_data, out_size);
1369
+ if (out_size > 0) {
1370
+ if (result_len + out_size > result_cap) {
1371
+ result_cap = (result_len + out_size) * 2;
1372
+ rb_str_resize(result, result_cap);
1373
+ }
1374
+
1375
+ memcpy(RSTRING_PTR(result) + result_len, out_data, out_size);
1376
+ result_len += out_size;
1377
+ }
867
1378
  } while (BrotliEncoderHasMoreOutput(d->ctx.brotli));
1379
+
1380
+ rb_str_set_len(result, result_len);
868
1381
  return result;
869
1382
  }
870
1383
  case ALGO_LZ4:
871
- return lz4_compress_block(d);
1384
+ return lz4_compress_ring_block(d);
872
1385
  }
873
- return rb_str_new("", 0);
1386
+ return rb_binary_str_new("", 0);
874
1387
  }
875
1388
 
876
1389
  static VALUE deflater_finish(VALUE self) {
@@ -879,29 +1392,40 @@ static VALUE deflater_finish(VALUE self) {
879
1392
  if (d->closed)
880
1393
  rb_raise(eStreamError, "stream is closed");
881
1394
  if (d->finished)
882
- return rb_str_new("", 0);
1395
+ return rb_binary_str_new("", 0);
883
1396
  d->finished = 1;
884
1397
 
885
1398
  switch (d->algo) {
886
1399
  case ALGO_ZSTD: {
887
1400
  size_t out_cap = ZSTD_CStreamOutSize();
888
- VALUE result = rb_str_buf_new(0);
1401
+ size_t result_cap = out_cap;
1402
+ VALUE result = rb_binary_str_buf_new(result_cap);
1403
+ size_t result_len = 0;
889
1404
  size_t ret;
1405
+
890
1406
  do {
891
- VALUE buf = rb_str_buf_new(out_cap);
892
- ZSTD_outBuffer output = {RSTRING_PTR(buf), out_cap, 0};
1407
+ if (result_len + out_cap > result_cap) {
1408
+ result_cap *= 2;
1409
+ rb_str_resize(result, result_cap);
1410
+ }
1411
+
1412
+ ZSTD_outBuffer output = {RSTRING_PTR(result) + result_len, out_cap, 0};
893
1413
  ret = ZSTD_endStream(d->ctx.zstd, &output);
894
1414
  if (ZSTD_isError(ret))
895
1415
  rb_raise(eError, "zstd end stream: %s", ZSTD_getErrorName(ret));
896
- if (output.pos > 0)
897
- rb_str_cat(result, RSTRING_PTR(buf), output.pos);
1416
+ result_len += output.pos;
898
1417
  } while (ret > 0);
1418
+
1419
+ rb_str_set_len(result, result_len);
899
1420
  return result;
900
1421
  }
901
1422
  case ALGO_BROTLI: {
902
1423
  size_t available_in = 0;
903
1424
  const uint8_t *next_in = NULL;
904
- VALUE result = rb_str_buf_new(0);
1425
+ size_t result_cap = 1024;
1426
+ VALUE result = rb_binary_str_buf_new(result_cap);
1427
+ size_t result_len = 0;
1428
+
905
1429
  do {
906
1430
  size_t available_out = 0;
907
1431
  uint8_t *next_out = NULL;
@@ -913,25 +1437,54 @@ static VALUE deflater_finish(VALUE self) {
913
1437
  const uint8_t *out_data;
914
1438
  size_t out_size = 0;
915
1439
  out_data = BrotliEncoderTakeOutput(d->ctx.brotli, &out_size);
916
- if (out_size > 0)
917
- rb_str_cat(result, (const char *)out_data, out_size);
1440
+ if (out_size > 0) {
1441
+ if (result_len + out_size > result_cap) {
1442
+ result_cap = (result_len + out_size) * 2;
1443
+ rb_str_resize(result, result_cap);
1444
+ }
1445
+
1446
+ memcpy(RSTRING_PTR(result) + result_len, out_data, out_size);
1447
+ result_len += out_size;
1448
+ }
918
1449
  } while (BrotliEncoderHasMoreOutput(d->ctx.brotli) ||
919
1450
  !BrotliEncoderIsFinished(d->ctx.brotli));
1451
+
1452
+ rb_str_set_len(result, result_len);
920
1453
  return result;
921
1454
  }
922
1455
  case ALGO_LZ4: {
923
- VALUE result = rb_str_buf_new(0);
924
- if (d->lz4_buf.len > 0) {
925
- VALUE block = lz4_compress_block(d);
926
- rb_str_cat(result, RSTRING_PTR(block), RSTRING_LEN(block));
1456
+ size_t result_cap = 256;
1457
+ VALUE result = rb_binary_str_buf_new(result_cap);
1458
+ size_t result_len = 0;
1459
+
1460
+ if (d->lz4_ring.pending > 0) {
1461
+ VALUE block = lz4_compress_ring_block(d);
1462
+ size_t blen = RSTRING_LEN(block);
1463
+ if (blen > 0) {
1464
+ if (blen + 4 > result_cap) {
1465
+ result_cap = blen + 4;
1466
+ rb_str_resize(result, result_cap);
1467
+ }
1468
+
1469
+ memcpy(RSTRING_PTR(result), RSTRING_PTR(block), blen);
1470
+ result_len = blen;
1471
+ }
927
1472
  }
928
1473
 
929
- char eos[4] = {0, 0, 0, 0};
930
- rb_str_cat(result, eos, 4);
1474
+ if (result_len + 4 > result_cap) {
1475
+ result_cap = result_len + 4;
1476
+ rb_str_resize(result, result_cap);
1477
+ }
1478
+
1479
+ char *out = RSTRING_PTR(result) + result_len;
1480
+ out[0] = out[1] = out[2] = out[3] = 0;
1481
+ result_len += 4;
1482
+
1483
+ rb_str_set_len(result, result_len);
931
1484
  return result;
932
1485
  }
933
1486
  }
934
- return rb_str_new("", 0);
1487
+ return rb_binary_str_new("", 0);
935
1488
  }
936
1489
 
937
1490
  static VALUE deflater_reset(VALUE self) {
@@ -957,7 +1510,8 @@ static VALUE deflater_reset(VALUE self) {
957
1510
  case ALGO_LZ4:
958
1511
  if (d->ctx.lz4)
959
1512
  LZ4_resetStream(d->ctx.lz4);
960
- d->lz4_buf.len = 0;
1513
+ d->lz4_ring.ring_offset = 0;
1514
+ d->lz4_ring.pending = 0;
961
1515
  break;
962
1516
  }
963
1517
  d->closed = 0;
@@ -1015,6 +1569,7 @@ typedef struct {
1015
1569
  char *buf;
1016
1570
  size_t len;
1017
1571
  size_t cap;
1572
+ size_t offset;
1018
1573
  } lz4_buf;
1019
1574
  } inflater_t;
1020
1575
 
@@ -1108,6 +1663,7 @@ static VALUE inflater_initialize(int argc, VALUE *argv, VALUE self) {
1108
1663
  inf->lz4_buf.cap = 16 * 1024;
1109
1664
  inf->lz4_buf.buf = ALLOC_N(char, inf->lz4_buf.cap);
1110
1665
  inf->lz4_buf.len = 0;
1666
+ inf->lz4_buf.offset = 0;
1111
1667
  break;
1112
1668
  }
1113
1669
 
@@ -1124,30 +1680,52 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
1124
1680
  const char *src = RSTRING_PTR(chunk);
1125
1681
  size_t slen = RSTRING_LEN(chunk);
1126
1682
  if (slen == 0)
1127
- return rb_str_new("", 0);
1683
+ return rb_binary_str_new("", 0);
1128
1684
 
1129
1685
  switch (inf->algo) {
1130
1686
  case ALGO_ZSTD: {
1131
1687
  ZSTD_inBuffer input = {src, slen, 0};
1132
1688
  size_t out_cap = ZSTD_DStreamOutSize();
1133
- VALUE result = rb_str_buf_new(0);
1689
+ size_t result_cap = out_cap > slen * 2 ? out_cap : slen * 2;
1690
+ VALUE result = rb_binary_str_buf_new(result_cap);
1691
+ size_t result_len = 0;
1692
+ int use_fiber = has_fiber_scheduler();
1693
+ size_t fiber_counter = 0;
1694
+
1134
1695
  while (input.pos < input.size) {
1135
- VALUE buf = rb_str_buf_new(out_cap);
1136
- ZSTD_outBuffer output = {RSTRING_PTR(buf), out_cap, 0};
1696
+ if (result_len + out_cap > result_cap) {
1697
+ result_cap = result_cap * 2;
1698
+ rb_str_resize(result, result_cap);
1699
+ }
1700
+
1701
+ ZSTD_outBuffer output = {RSTRING_PTR(result) + result_len, out_cap, 0};
1137
1702
  size_t ret = ZSTD_decompressStream(inf->ctx.zstd, &output, &input);
1138
1703
  if (ZSTD_isError(ret))
1139
1704
  rb_raise(eDataError, "zstd decompress stream: %s", ZSTD_getErrorName(ret));
1140
- if (output.pos > 0)
1141
- rb_str_cat(result, RSTRING_PTR(buf), output.pos);
1705
+ result_len += output.pos;
1706
+ if (use_fiber) {
1707
+ int did_yield = 0;
1708
+ fiber_counter = fiber_maybe_yield(fiber_counter, output.pos, &did_yield);
1709
+ (void)did_yield;
1710
+ }
1142
1711
  if (ret == 0)
1143
1712
  break;
1144
1713
  }
1714
+ rb_str_set_len(result, result_len);
1715
+ RB_GC_GUARD(chunk);
1145
1716
  return result;
1146
1717
  }
1147
1718
  case ALGO_BROTLI: {
1148
1719
  size_t available_in = slen;
1149
1720
  const uint8_t *next_in = (const uint8_t *)src;
1150
- VALUE result = rb_str_buf_new(0);
1721
+ size_t result_cap = slen * 2;
1722
+ if (result_cap < 1024)
1723
+ result_cap = 1024;
1724
+ VALUE result = rb_binary_str_buf_new(result_cap);
1725
+ size_t result_len = 0;
1726
+ int use_fiber = has_fiber_scheduler();
1727
+ size_t fiber_counter = 0;
1728
+
1151
1729
  while (available_in > 0 || BrotliDecoderHasMoreOutput(inf->ctx.brotli)) {
1152
1730
  size_t available_out = 0;
1153
1731
  uint8_t *next_out = NULL;
@@ -1159,19 +1737,46 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
1159
1737
  const uint8_t *out_data;
1160
1738
  size_t out_size = 0;
1161
1739
  out_data = BrotliDecoderTakeOutput(inf->ctx.brotli, &out_size);
1162
- if (out_size > 0)
1163
- rb_str_cat(result, (const char *)out_data, out_size);
1740
+ if (out_size > 0) {
1741
+ if (result_len + out_size > result_cap) {
1742
+ result_cap = (result_len + out_size) * 2;
1743
+ rb_str_resize(result, result_cap);
1744
+ }
1745
+
1746
+ memcpy(RSTRING_PTR(result) + result_len, out_data, out_size);
1747
+ result_len += out_size;
1748
+ if (use_fiber) {
1749
+ int did_yield = 0;
1750
+ fiber_counter = fiber_maybe_yield(fiber_counter, out_size, &did_yield);
1751
+ (void)did_yield;
1752
+ }
1753
+ }
1164
1754
  if (res == BROTLI_DECODER_RESULT_SUCCESS)
1165
1755
  break;
1166
1756
  if (res == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT && available_in == 0)
1167
1757
  break;
1168
1758
  }
1759
+ rb_str_set_len(result, result_len);
1760
+ RB_GC_GUARD(chunk);
1169
1761
  return result;
1170
1762
  }
1171
1763
  case ALGO_LZ4: {
1172
- VALUE result = rb_str_buf_new(0);
1764
+ size_t data_len = inf->lz4_buf.len - inf->lz4_buf.offset;
1765
+ size_t needed = data_len + slen;
1766
+
1767
+ if (inf->lz4_buf.offset > 0 && needed > inf->lz4_buf.cap) {
1768
+ if (data_len > 0)
1769
+ memmove(inf->lz4_buf.buf, inf->lz4_buf.buf + inf->lz4_buf.offset, data_len);
1770
+ inf->lz4_buf.offset = 0;
1771
+ inf->lz4_buf.len = data_len;
1772
+ } else if (inf->lz4_buf.offset > inf->lz4_buf.cap / 2) {
1773
+ if (data_len > 0)
1774
+ memmove(inf->lz4_buf.buf, inf->lz4_buf.buf + inf->lz4_buf.offset, data_len);
1775
+ inf->lz4_buf.offset = 0;
1776
+ inf->lz4_buf.len = data_len;
1777
+ }
1173
1778
 
1174
- size_t needed = inf->lz4_buf.len + slen;
1779
+ needed = inf->lz4_buf.len + slen;
1175
1780
  if (needed > inf->lz4_buf.cap) {
1176
1781
  inf->lz4_buf.cap = needed * 2;
1177
1782
  REALLOC_N(inf->lz4_buf.buf, char, inf->lz4_buf.cap);
@@ -1179,7 +1784,15 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
1179
1784
  memcpy(inf->lz4_buf.buf + inf->lz4_buf.len, src, slen);
1180
1785
  inf->lz4_buf.len += slen;
1181
1786
 
1182
- size_t pos = 0;
1787
+ size_t result_cap = slen * 4;
1788
+ if (result_cap < 256)
1789
+ result_cap = 256;
1790
+ VALUE result = rb_binary_str_buf_new(result_cap);
1791
+ size_t result_len = 0;
1792
+ int use_fiber = has_fiber_scheduler();
1793
+ size_t fiber_counter = 0;
1794
+
1795
+ size_t pos = inf->lz4_buf.offset;
1183
1796
  while (pos + 4 <= inf->lz4_buf.len) {
1184
1797
  const uint8_t *p = (const uint8_t *)(inf->lz4_buf.buf + pos);
1185
1798
  uint32_t orig_size = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16) |
@@ -1198,25 +1811,32 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
1198
1811
  if (orig_size > 64 * 1024 * 1024)
1199
1812
  rb_raise(eDataError, "lz4 stream: block too large (%u)", orig_size);
1200
1813
 
1201
- VALUE block = rb_str_buf_new(orig_size);
1202
- int dsize = LZ4_decompress_safe(inf->lz4_buf.buf + pos + 8, RSTRING_PTR(block),
1203
- (int)comp_size, (int)orig_size);
1814
+ if (result_len + orig_size > result_cap) {
1815
+ result_cap = (result_len + orig_size) * 2;
1816
+ rb_str_resize(result, result_cap);
1817
+ }
1818
+
1819
+ int dsize =
1820
+ LZ4_decompress_safe(inf->lz4_buf.buf + pos + 8, RSTRING_PTR(result) + result_len,
1821
+ (int)comp_size, (int)orig_size);
1204
1822
  if (dsize < 0)
1205
1823
  rb_raise(eDataError, "lz4 stream decompress block failed");
1206
- rb_str_set_len(block, dsize);
1207
- rb_str_cat(result, RSTRING_PTR(block), dsize);
1824
+ result_len += dsize;
1208
1825
  pos += 8 + comp_size;
1826
+ if (use_fiber) {
1827
+ int did_yield = 0;
1828
+ fiber_counter = fiber_maybe_yield(fiber_counter, (size_t)dsize, &did_yield);
1829
+ (void)did_yield;
1830
+ }
1209
1831
  }
1210
1832
 
1211
- if (pos > 0) {
1212
- inf->lz4_buf.len -= pos;
1213
- if (inf->lz4_buf.len > 0)
1214
- memmove(inf->lz4_buf.buf, inf->lz4_buf.buf + pos, inf->lz4_buf.len);
1215
- }
1833
+ inf->lz4_buf.offset = pos;
1834
+ rb_str_set_len(result, result_len);
1835
+ RB_GC_GUARD(chunk);
1216
1836
  return result;
1217
1837
  }
1218
1838
  }
1219
- return rb_str_new("", 0);
1839
+ return rb_binary_str_new("", 0);
1220
1840
  }
1221
1841
 
1222
1842
  static VALUE inflater_finish(VALUE self) {
@@ -1225,7 +1845,7 @@ static VALUE inflater_finish(VALUE self) {
1225
1845
  if (inf->closed)
1226
1846
  rb_raise(eStreamError, "stream is closed");
1227
1847
  inf->finished = 1;
1228
- return rb_str_new("", 0);
1848
+ return rb_binary_str_new("", 0);
1229
1849
  }
1230
1850
 
1231
1851
  static VALUE inflater_reset(VALUE self) {
@@ -1247,6 +1867,7 @@ static VALUE inflater_reset(VALUE self) {
1247
1867
  break;
1248
1868
  case ALGO_LZ4:
1249
1869
  inf->lz4_buf.len = 0;
1870
+ inf->lz4_buf.offset = 0;
1250
1871
  break;
1251
1872
  }
1252
1873
  inf->closed = 0;
@@ -1310,23 +1931,18 @@ static VALUE dict_initialize(int argc, VALUE *argv, VALUE self) {
1310
1931
  return self;
1311
1932
  }
1312
1933
 
1313
- static VALUE dict_train(int argc, VALUE *argv, VALUE self) {
1934
+ static VALUE brotli_train_dictionary(int argc, VALUE *argv, VALUE self) {
1314
1935
  VALUE samples, opts;
1315
1936
  rb_scan_args(argc, argv, "1:", &samples, &opts);
1316
1937
  Check_Type(samples, T_ARRAY);
1317
1938
 
1318
- VALUE algo_sym = Qnil, size_val = Qnil;
1939
+ VALUE size_val = Qnil;
1319
1940
  if (!NIL_P(opts)) {
1320
- algo_sym = rb_hash_aref(opts, ID2SYM(rb_intern("algo")));
1321
1941
  size_val = rb_hash_aref(opts, ID2SYM(rb_intern("size")));
1322
1942
  }
1323
1943
 
1324
- compress_algo_t algo = NIL_P(algo_sym) ? ALGO_ZSTD : sym_to_algo(algo_sym);
1325
1944
  size_t dict_capacity = NIL_P(size_val) ? 32768 : NUM2SIZET(size_val);
1326
1945
 
1327
- if (algo == ALGO_LZ4)
1328
- rb_raise(eUnsupportedError, "LZ4 does not support dictionary training");
1329
-
1330
1946
  long num_samples = RARRAY_LEN(samples);
1331
1947
  if (num_samples < 1)
1332
1948
  rb_raise(rb_eArgError, "need at least 1 sample for training");
@@ -1335,55 +1951,57 @@ static VALUE dict_train(int argc, VALUE *argv, VALUE self) {
1335
1951
  for (long i = 0; i < num_samples; i++) {
1336
1952
  VALUE s = rb_ary_entry(samples, i);
1337
1953
  StringValue(s);
1338
- total_size += RSTRING_LEN(s);
1954
+ size_t slen = RSTRING_LEN(s);
1955
+ if (slen < 8) {
1956
+ rb_raise(rb_eArgError, "sample %ld is too small (%zu bytes), minimum is 8 bytes", i,
1957
+ slen);
1958
+ }
1959
+ total_size += slen;
1960
+ }
1961
+
1962
+ uint8_t *dict_buf = ALLOC_N(uint8_t, dict_capacity);
1963
+ if (!dict_buf) {
1964
+ rb_raise(eMemError, "failed to allocate memory for brotli dictionary training");
1339
1965
  }
1340
1966
 
1341
1967
  char *concat = ALLOC_N(char, total_size);
1342
- size_t *sizes = ALLOC_N(size_t, num_samples);
1343
- size_t offset = 0;
1968
+ if (!concat) {
1969
+ xfree(dict_buf);
1970
+ rb_raise(eMemError, "failed to allocate memory for brotli dictionary training");
1971
+ }
1344
1972
 
1973
+ size_t offset = 0;
1345
1974
  for (long i = 0; i < num_samples; i++) {
1346
1975
  VALUE s = rb_ary_entry(samples, i);
1347
- size_t slen = RSTRING_LEN(s);
1348
- memcpy(concat + offset, RSTRING_PTR(s), slen);
1349
- sizes[i] = slen;
1350
- offset += slen;
1351
- }
1352
-
1353
- uint8_t *dict_buf = ALLOC_N(uint8_t, dict_capacity);
1976
+ StringValue(s);
1354
1977
 
1355
- if (algo == ALGO_ZSTD) {
1356
- size_t result =
1357
- ZDICT_trainFromBuffer(dict_buf, dict_capacity, concat, sizes, (unsigned)num_samples);
1358
- xfree(concat);
1359
- xfree(sizes);
1978
+ const char *str_ptr = RSTRING_PTR(s);
1979
+ size_t slen = RSTRING_LEN(s);
1360
1980
 
1361
- if (ZDICT_isError(result)) {
1981
+ if (offset + slen > total_size) {
1982
+ xfree(concat);
1362
1983
  xfree(dict_buf);
1363
- rb_raise(eError, "zstd dict training: %s", ZDICT_getErrorName(result));
1984
+ rb_raise(eError, "buffer overflow during concatenation");
1364
1985
  }
1365
1986
 
1366
- VALUE dict_obj = rb_obj_alloc(cDictionary);
1367
- dictionary_t *d;
1368
- TypedData_Get_Struct(dict_obj, dictionary_t, &dictionary_type, d);
1369
- d->algo = ALGO_ZSTD;
1370
- d->data = dict_buf;
1371
- d->size = result;
1372
- return dict_obj;
1373
- } else {
1374
- xfree(sizes);
1375
- size_t actual_size = total_size < dict_capacity ? total_size : dict_capacity;
1376
- memcpy(dict_buf, concat, actual_size);
1377
- xfree(concat);
1378
-
1379
- VALUE dict_obj = rb_obj_alloc(cDictionary);
1380
- dictionary_t *d;
1381
- TypedData_Get_Struct(dict_obj, dictionary_t, &dictionary_type, d);
1382
- d->algo = ALGO_BROTLI;
1383
- d->data = dict_buf;
1384
- d->size = actual_size;
1385
- return dict_obj;
1987
+ memcpy(concat + offset, str_ptr, slen);
1988
+ offset += slen;
1989
+
1990
+ RB_GC_GUARD(s);
1386
1991
  }
1992
+
1993
+ size_t dict_size = total_size < dict_capacity ? total_size : dict_capacity;
1994
+ memcpy(dict_buf, concat, dict_size);
1995
+ xfree(concat);
1996
+
1997
+ VALUE dict_obj = rb_obj_alloc(cDictionary);
1998
+ dictionary_t *d;
1999
+ TypedData_Get_Struct(dict_obj, dictionary_t, &dictionary_type, d);
2000
+ memset(d, 0, sizeof(*d));
2001
+ d->algo = ALGO_BROTLI;
2002
+ d->data = dict_buf;
2003
+ d->size = dict_size;
2004
+ return dict_obj;
1387
2005
  }
1388
2006
 
1389
2007
  static VALUE dict_load(int argc, VALUE *argv, VALUE self) {
@@ -1470,6 +2088,9 @@ static VALUE dict_size(VALUE self) {
1470
2088
  }
1471
2089
 
1472
2090
  void Init_multi_compress(void) {
2091
+ binary_encoding = rb_ascii8bit_encoding();
2092
+ crc32_init_tables();
2093
+
1473
2094
  mMultiCompress = rb_define_module("MultiCompress");
1474
2095
 
1475
2096
  eError = rb_define_class_under(mMultiCompress, "Error", rb_eStandardError);
@@ -1526,9 +2147,9 @@ void Init_multi_compress(void) {
1526
2147
  cDictionary = rb_define_class_under(mMultiCompress, "Dictionary", rb_cObject);
1527
2148
  rb_define_alloc_func(cDictionary, dict_alloc);
1528
2149
  rb_define_method(cDictionary, "initialize", dict_initialize, -1);
1529
- rb_define_singleton_method(cDictionary, "train", dict_train, -1);
1530
2150
  rb_define_singleton_method(cDictionary, "load", dict_load, -1);
1531
2151
  rb_define_method(cDictionary, "save", dict_save, 1);
1532
2152
  rb_define_method(cDictionary, "algo", dict_algo, 0);
1533
2153
  rb_define_method(cDictionary, "size", dict_size, 0);
2154
+ rb_define_singleton_method(mBrotli, "train_dictionary", brotli_train_dictionary, -1);
1534
2155
  }