multi_compress 0.2.4 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@
6
6
  #include <brotli/encode.h>
7
7
  #include <lz4.h>
8
8
  #include <lz4hc.h>
9
+ #include <lz4frame.h>
9
10
  #include <pthread.h>
10
11
  #include <stdio.h>
11
12
  #include <stdint.h>
@@ -20,6 +21,20 @@
20
21
  #define RATIO_MIN_INPUT_BYTES 1024ULL
21
22
  #define DICT_FILE_MAX_SIZE (32ULL * 1024 * 1024)
22
23
 
/*
 * Portability helpers.
 *   MC_ALWAYS_INLINE        - declares a static inline function; on GCC/Clang it also
 *                             requests always_inline.
 *   MC_LIKELY / MC_UNLIKELY - branch-prediction hints on GCC/Clang; on other compilers
 *                             they expand to the bare condition.
 */
#if defined(__GNUC__) || defined(__clang__)
#  define MC_ALWAYS_INLINE static inline __attribute__((always_inline))
#  define MC_LIKELY(x) __builtin_expect(!!(x), 1)
#  define MC_UNLIKELY(x) __builtin_expect(!!(x), 0)
#else
#  define MC_ALWAYS_INLINE static inline
#  define MC_LIKELY(x) (x)
#  define MC_UNLIKELY(x) (x)
#endif
37
+
23
38
  typedef struct {
24
39
  size_t gvl_unlock_threshold;
25
40
  size_t fiber_yield_chunk;
@@ -62,7 +77,7 @@ static VALUE mBrotli;
62
77
  static rb_encoding *binary_encoding;
63
78
  static struct {
64
79
  ID zstd, lz4, brotli;
65
- ID algo, algorithm, level, dictionary, size;
80
+ ID algo, algorithm, level, dictionary, size, format, block, frame;
66
81
  ID max_output_size, max_ratio;
67
82
  ID fastest, default_, best;
68
83
  ID yield_, join;
@@ -71,11 +86,17 @@ static struct {
71
86
 
72
87
  static struct {
73
88
  VALUE zstd, lz4, brotli;
74
- VALUE algo, algorithm, level, dictionary, size;
89
+ VALUE algo, algorithm, level, dictionary, size, format, block, frame;
75
90
  VALUE max_output_size, max_ratio;
76
91
  } sym_cache;
77
92
 
/* Supported algorithms. The values double as indices into per-algorithm tables
 * (level_spec is indexed by compress_algo_t). */
typedef enum { ALGO_ZSTD = 0, ALGO_LZ4 = 1, ALGO_BROTLI = 2 } compress_algo_t;

/* LZ4 container selected by the `format:` option: this extension's block-based
 * layout, or the standard LZ4 frame format. */
typedef enum { LZ4_FORMAT_BLOCK = 0, LZ4_FORMAT_FRAME = 1 } lz4_format_t;

/* Number of compress_algo_t values. */
#define MC_NUM_ALGOS 3

/* Tables sized by MC_NUM_ALGOS rely on the algorithm ids being 0..MC_NUM_ALGOS-1. */
_Static_assert(ALGO_BROTLI == MC_NUM_ALGOS - 1,
               "compress_algo_t must be contiguous [0..MC_NUM_ALGOS-1]");
79
100
 
80
101
  typedef struct dictionary_s dictionary_t;
81
102
  static const rb_data_type_t dictionary_type;
@@ -89,6 +110,9 @@ static void init_id_cache(void) {
89
110
  id_cache.level = rb_intern("level");
90
111
  id_cache.dictionary = rb_intern("dictionary");
91
112
  id_cache.size = rb_intern("size");
113
+ id_cache.format = rb_intern("format");
114
+ id_cache.block = rb_intern("block");
115
+ id_cache.frame = rb_intern("frame");
92
116
  id_cache.max_output_size = rb_intern("max_output_size");
93
117
  id_cache.max_ratio = rb_intern("max_ratio");
94
118
  id_cache.fastest = rb_intern("fastest");
@@ -106,6 +130,9 @@ static void init_id_cache(void) {
106
130
  sym_cache.level = ID2SYM(id_cache.level);
107
131
  sym_cache.dictionary = ID2SYM(id_cache.dictionary);
108
132
  sym_cache.size = ID2SYM(id_cache.size);
133
+ sym_cache.format = ID2SYM(id_cache.format);
134
+ sym_cache.block = ID2SYM(id_cache.block);
135
+ sym_cache.frame = ID2SYM(id_cache.frame);
109
136
  sym_cache.max_output_size = ID2SYM(id_cache.max_output_size);
110
137
  sym_cache.max_ratio = ID2SYM(id_cache.max_ratio);
111
138
  }
@@ -118,6 +145,30 @@ static inline VALUE opt_lookup2(VALUE opts, VALUE sym, VALUE default_value) {
118
145
  return NIL_P(opts) ? default_value : rb_hash_lookup2(opts, sym, default_value);
119
146
  }
120
147
 
/* LZ4 frame magic number 0x184D2204 as it appears on the wire (little-endian bytes). */
enum { LZ4_FRAME_MAGIC_LEN = 4 };
static const uint8_t LZ4_FRAME_MAGIC[LZ4_FRAME_MAGIC_LEN] = {0x04, 0x22, 0x4D, 0x18};

/*
 * Returns 1 when the `len` bytes at `data` start with the LZ4 frame magic, else 0.
 * Inputs shorter than the magic never match, and in that case `data` is not read
 * (so it may be NULL when len == 0).
 */
static inline int is_lz4_frame_magic(const uint8_t *data, size_t len) {
  if (len < LZ4_FRAME_MAGIC_LEN)
    return 0;
  return memcmp(data, LZ4_FRAME_MAGIC, sizeof LZ4_FRAME_MAGIC) == 0;
}
154
+
155
+ static lz4_format_t parse_lz4_format(VALUE opts, compress_algo_t algo, int explicit_algo) {
156
+ VALUE format_val = opt_lookup2(opts, sym_cache.format, Qundef);
157
+ if (format_val == Qundef || NIL_P(format_val))
158
+ return LZ4_FORMAT_BLOCK;
159
+ if (explicit_algo && algo != ALGO_LZ4)
160
+ rb_raise(eUnsupportedError, "format is only supported for algo: :lz4");
161
+ if (!SYMBOL_P(format_val))
162
+ rb_raise(rb_eTypeError, "format must be a Symbol (:block or :frame)");
163
+ ID id = SYM2ID(format_val);
164
+ if (id == id_cache.block)
165
+ return LZ4_FORMAT_BLOCK;
166
+ if (id == id_cache.frame)
167
+ return LZ4_FORMAT_FRAME;
168
+ rb_raise(rb_eArgError, "Unknown LZ4 format: %s", rb_id2name(id));
169
+ return LZ4_FORMAT_BLOCK;
170
+ }
171
+
121
172
  static inline void reject_algorithm_keyword(VALUE opts) {
122
173
  if (NIL_P(opts))
123
174
  return;
@@ -161,11 +212,11 @@ static inline void dictionary_ivar_set(VALUE self, VALUE dictionary) {
161
212
  rb_ivar_set(self, id_cache.ivar_dictionary, dictionary);
162
213
  }
163
214
 
164
- static inline uint32_t read_le_u32(const uint8_t *p) {
215
+ MC_ALWAYS_INLINE uint32_t read_le_u32(const uint8_t *restrict p) {
165
216
  return (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16) | ((uint32_t)p[3] << 24);
166
217
  }
167
218
 
168
- static inline void write_le_u32(uint8_t *p, uint32_t v) {
219
+ MC_ALWAYS_INLINE void write_le_u32(uint8_t *restrict p, uint32_t v) {
169
220
  p[0] = (uint8_t)(v & 0xFF);
170
221
  p[1] = (uint8_t)((v >> 8) & 0xFF);
171
222
  p[2] = (uint8_t)((v >> 16) & 0xFF);
@@ -223,6 +274,8 @@ static const level_spec_t level_spec[] = {
223
274
  [ALGO_BROTLI] =
224
275
  {.min = 0, .max = 11, .fastest = 0, .default_ = 6, .best = 11, .name = "brotli"},
225
276
  };
277
+ _Static_assert(sizeof(level_spec) / sizeof(level_spec[0]) == MC_NUM_ALGOS,
278
+ "level_spec must cover every compress_algo_t value");
226
279
 
227
280
  static int resolve_level(compress_algo_t algo, VALUE level_val) {
228
281
  const level_spec_t *spec = &level_spec[algo];
@@ -241,7 +294,7 @@ static int resolve_level(compress_algo_t algo, VALUE level_val) {
241
294
  rb_raise(eLevelError, "Unknown named level: %s", rb_id2name(id));
242
295
  }
243
296
 
244
- int level = NUM2INT(level_val);
297
+ const int level = NUM2INT(level_val);
245
298
  if (level < spec->min || level > spec->max)
246
299
  rb_raise(eLevelError, "%s level must be %d..%d, got %d", spec->name, spec->min, spec->max,
247
300
  level);
@@ -249,23 +302,39 @@ static int resolve_level(compress_algo_t algo, VALUE level_val) {
249
302
  }
250
303
 
251
304
  static compress_algo_t detect_algo(const uint8_t *data, size_t len) {
252
- if (len >= 4) {
253
- if (data[0] == 0x28 && data[1] == 0xB5 && data[2] == 0x2F && data[3] == 0xFD) {
254
- return ALGO_ZSTD;
255
- }
305
+ enum { ZSTD_MAGIC_LEN = 4 };
306
+ static const uint8_t ZSTD_MAGIC[ZSTD_MAGIC_LEN] = {0x28, 0xB5, 0x2F, 0xFD};
307
+ enum { LZ4_BLOCK_SANITY_MAX = 256U * 1024 * 1024 };
308
+
309
+ if (is_lz4_frame_magic(data, len)) {
310
+ return ALGO_LZ4;
311
+ }
312
+
313
+ if (len >= ZSTD_MAGIC_LEN && memcmp(data, ZSTD_MAGIC, ZSTD_MAGIC_LEN) == 0) {
314
+ return ALGO_ZSTD;
256
315
  }
257
316
 
258
317
  if (len >= 12) {
259
- uint32_t orig = read_le_u32(data);
260
- uint32_t comp = read_le_u32(data + 4);
261
- if (orig > 0 && orig <= 256U * 1024 * 1024 && comp > 0 && comp <= 256U * 1024 * 1024 &&
262
- orig <= (uint32_t)INT_MAX && comp <= (uint32_t)LZ4_compressBound((int)orig) &&
263
- (size_t)8 + (size_t)comp + 4 == len) {
264
- size_t tail = 8 + (size_t)comp;
265
- if (data[tail] == 0 && data[tail + 1] == 0 && data[tail + 2] == 0 &&
266
- data[tail + 3] == 0) {
267
- return ALGO_LZ4;
318
+ size_t pos = 0;
319
+ int saw_block = 0;
320
+ while (pos + 4 <= len) {
321
+ uint32_t orig = read_le_u32(data + pos);
322
+ if (orig == 0) {
323
+ if (saw_block && pos + 4 == len)
324
+ return ALGO_LZ4;
325
+ break;
268
326
  }
327
+ if (pos + 8 > len)
328
+ break;
329
+ uint32_t comp = read_le_u32(data + pos + 4);
330
+ if (comp == 0 || orig > LZ4_BLOCK_SANITY_MAX || orig > (uint32_t)INT_MAX)
331
+ break;
332
+ if (comp > LZ4_BLOCK_SANITY_MAX || comp > (uint32_t)LZ4_compressBound((int)orig))
333
+ break;
334
+ if (pos + 8 + (size_t)comp > len)
335
+ break;
336
+ saw_block = 1;
337
+ pos += 8 + (size_t)comp;
269
338
  }
270
339
  }
271
340
 
@@ -347,8 +416,8 @@ static void parse_limits_from_opts(VALUE opts, limits_config_t *limits) {
347
416
  limits_config_apply_opts(opts, limits);
348
417
  }
349
418
 
350
- static size_t checked_add_size(size_t left, size_t right, const char *message) {
351
- if (SIZE_MAX - left < right)
419
+ static inline size_t checked_add_size(size_t left, size_t right, const char *message) {
420
+ if (MC_UNLIKELY(SIZE_MAX - left < right))
352
421
  rb_raise(eDataError, "%s", message);
353
422
  return left + right;
354
423
  }
@@ -361,10 +430,10 @@ static size_t ratio_limit_bytes(size_t total_input, unsigned long long max_ratio
361
430
  return total_input * (size_t)max_ratio;
362
431
  }
363
432
 
364
- static void enforce_output_and_ratio_limits(size_t total_output, size_t total_input,
365
- size_t max_output_size, int max_ratio_enabled,
366
- unsigned long long max_ratio) {
367
- if (total_output > max_output_size) {
433
+ static inline void enforce_output_and_ratio_limits(size_t total_output, size_t total_input,
434
+ size_t max_output_size, int max_ratio_enabled,
435
+ unsigned long long max_ratio) {
436
+ if (MC_UNLIKELY(total_output > max_output_size)) {
368
437
  rb_raise(eDataError, "decompressed output exceeds limit (%zu bytes)", max_output_size);
369
438
  }
370
439
 
@@ -372,7 +441,7 @@ static void enforce_output_and_ratio_limits(size_t total_output, size_t total_in
372
441
  return;
373
442
 
374
443
  size_t ratio_limit = ratio_limit_bytes(total_input, max_ratio);
375
- if (total_output > ratio_limit) {
444
+ if (MC_UNLIKELY(total_output > ratio_limit)) {
376
445
  size_t ratio = total_input == 0 ? 0 : (total_output / total_input);
377
446
  rb_raise(eDataError, "decompression ratio exceeds limit (ratio=%zu, max=%llu)", ratio,
378
447
  max_ratio);
@@ -401,12 +470,75 @@ static inline void run_without_gvl(void *(*func)(void *), void *arg) {
401
470
  typedef struct {
402
471
  void *(*func)(void *);
403
472
  void *arg;
404
-
405
- VALUE scheduler;
406
- VALUE blocker;
407
- VALUE fiber;
473
+ size_t arg_size;
474
+ VALUE thread;
408
475
  } fiber_worker_ctx_t;
409
476
 
/* How run_with_exec_mode executes a unit of native work. */
typedef enum {
  WORK_EXEC_DIRECT = 0, /* call inline on the current thread */
  WORK_EXEC_NOGVL = 1,  /* call via run_without_gvl on the current thread */
  WORK_EXEC_FIBER = 2,  /* hand off to a helper thread and wait for it (run_via_fiber_worker) */
} work_exec_mode_t;
482
+
483
+ static void fiber_worker_mark(void *ptr) {
484
+ fiber_worker_ctx_t *c = (fiber_worker_ctx_t *)ptr;
485
+ if (!c)
486
+ return;
487
+ rb_gc_mark(c->thread);
488
+ }
489
+
490
+ static void fiber_worker_free(void *ptr) {
491
+ fiber_worker_ctx_t *c = (fiber_worker_ctx_t *)ptr;
492
+ if (!c)
493
+ return;
494
+ if (c->arg)
495
+ xfree(c->arg);
496
+ xfree(c);
497
+ }
498
+
499
+ static size_t fiber_worker_memsize(const void *ptr) {
500
+ const fiber_worker_ctx_t *c = (const fiber_worker_ctx_t *)ptr;
501
+ return sizeof(fiber_worker_ctx_t) + (c ? c->arg_size : 0);
502
+ }
503
+
504
+ static const rb_data_type_t fiber_worker_type = {
505
+ "MultiCompress/FiberWorker",
506
+ {fiber_worker_mark, fiber_worker_free, fiber_worker_memsize},
507
+ 0,
508
+ 0,
509
+ RUBY_TYPED_FREE_IMMEDIATELY};
510
+
511
+ static inline work_exec_mode_t select_fiber_or_direct_mode(VALUE scheduler, size_t work_size,
512
+ size_t fiber_threshold) {
513
+ if (scheduler != Qnil && work_size >= fiber_threshold)
514
+ return WORK_EXEC_FIBER;
515
+ return WORK_EXEC_DIRECT;
516
+ }
517
+
518
+ static inline work_exec_mode_t select_fiber_nogvl_or_direct_mode(VALUE scheduler, size_t work_size,
519
+ size_t fiber_threshold,
520
+ size_t nogvl_threshold) {
521
+ if (scheduler != Qnil && work_size >= fiber_threshold)
522
+ return WORK_EXEC_FIBER;
523
+ if (scheduler == Qnil && work_size >= nogvl_threshold)
524
+ return WORK_EXEC_NOGVL;
525
+ return WORK_EXEC_DIRECT;
526
+ }
527
+
528
+ static VALUE fiber_worker_new(void *(*func)(void *), const void *arg, size_t arg_size) {
529
+ fiber_worker_ctx_t *c;
530
+ VALUE worker = TypedData_Make_Struct(rb_cObject, fiber_worker_ctx_t, &fiber_worker_type, c);
531
+ memset(c, 0, sizeof(*c));
532
+ c->func = func;
533
+ c->arg_size = arg_size;
534
+ if (arg_size > 0) {
535
+ c->arg = xmalloc(arg_size);
536
+ memcpy(c->arg, arg, arg_size);
537
+ }
538
+ c->thread = Qnil;
539
+ return worker;
540
+ }
541
+
410
542
  static void *fiber_worker_nogvl(void *arg) {
411
543
  fiber_worker_ctx_t *c = (fiber_worker_ctx_t *)arg;
412
544
  c->func(c->arg);
@@ -415,24 +547,55 @@ static void *fiber_worker_nogvl(void *arg) {
415
547
 
416
548
  static VALUE fiber_worker_thread(void *arg) {
417
549
  fiber_worker_ctx_t *c = (fiber_worker_ctx_t *)arg;
418
- rb_thread_call_without_gvl(fiber_worker_nogvl, c, RUBY_UBF_PROCESS, NULL);
419
- rb_fiber_scheduler_unblock(c->scheduler, c->blocker, c->fiber);
550
+ rb_thread_call_without_gvl(fiber_worker_nogvl, c, unblock_noop, NULL);
551
+ return Qnil;
552
+ }
553
+
554
+ static VALUE fiber_worker_wait(VALUE worker) {
555
+ fiber_worker_ctx_t *c;
556
+ TypedData_Get_Struct(worker, fiber_worker_ctx_t, &fiber_worker_type, c);
557
+ c->thread = rb_thread_create(fiber_worker_thread, c);
558
+ join_thread(c->thread);
559
+ return Qnil;
560
+ }
561
+
562
+ static VALUE fiber_worker_join_ensure(VALUE worker) {
563
+ fiber_worker_ctx_t *c;
564
+ TypedData_Get_Struct(worker, fiber_worker_ctx_t, &fiber_worker_type, c);
565
+ if (!NIL_P(c->thread))
566
+ join_thread(c->thread);
420
567
  return Qnil;
421
568
  }
422
569
 
423
- static void run_via_fiber_worker(VALUE scheduler, void *(*func)(void *), void *arg) {
424
- fiber_worker_ctx_t ctx = {
425
- .func = func,
426
- .arg = arg,
427
- .scheduler = scheduler,
428
- .blocker = rb_obj_alloc(rb_cObject),
429
- .fiber = rb_fiber_current(),
430
- };
431
- VALUE th = rb_thread_create(fiber_worker_thread, &ctx);
432
- rb_fiber_scheduler_block(scheduler, ctx.blocker, Qnil);
433
- join_thread(th);
570
+ static void run_via_fiber_worker(void *(*func)(void *), void *arg, size_t arg_size) {
571
+ VALUE worker = fiber_worker_new(func, arg, arg_size);
572
+ rb_ensure(fiber_worker_wait, worker, fiber_worker_join_ensure, worker);
573
+ fiber_worker_ctx_t *c;
574
+ TypedData_Get_Struct(worker, fiber_worker_ctx_t, &fiber_worker_type, c);
575
+ if (arg_size > 0)
576
+ memcpy(arg, c->arg, arg_size);
577
+ RB_GC_GUARD(worker);
578
+ }
579
+
580
+ static inline void run_with_exec_mode(work_exec_mode_t mode, void *(*func)(void *), void *arg,
581
+ size_t arg_size) {
582
+ switch (mode) {
583
+ case WORK_EXEC_FIBER:
584
+ run_via_fiber_worker(func, arg, arg_size);
585
+ break;
586
+ case WORK_EXEC_NOGVL:
587
+ run_without_gvl(func, arg);
588
+ break;
589
+ case WORK_EXEC_DIRECT:
590
+ default:
591
+ func(arg);
592
+ break;
593
+ }
434
594
  }
435
595
 
596
+ #define RUN_VIA_FIBER_WORKER(func, arg) run_via_fiber_worker((func), &(arg), sizeof(arg))
597
+ #define RUN_WITH_EXEC_MODE(mode, func, arg) run_with_exec_mode((mode), (func), &(arg), sizeof(arg))
598
+
436
599
  static inline size_t fiber_maybe_yield(size_t bytes_since_yield, size_t just_processed,
437
600
  size_t yield_chunk, int *did_yield) {
438
601
  *did_yield = 0;
@@ -448,7 +611,8 @@ static inline size_t fiber_maybe_yield(size_t bytes_since_yield, size_t just_pro
448
611
  return bytes_since_yield;
449
612
  }
450
613
 
/*
 * Per-dictionary CDict cache capacity, one slot per distinct zstd level used
 * with the dictionary. Slots are never evicted, and dict_get_cdict raises once
 * all of them are taken. NOTE(review): 22 presumably matches the zstd level
 * range in level_spec; confirm if that range changes.
 */
#define DICT_CDICT_CACHE_SIZE 22
_Static_assert(DICT_CDICT_CACHE_SIZE > 0, "CDict cache needs at least one slot");
452
616
 
453
617
  typedef struct {
454
618
  int level;
@@ -459,6 +623,7 @@ struct dictionary_s {
459
623
  compress_algo_t algo;
460
624
  uint8_t *data;
461
625
  size_t size;
626
+ pthread_mutex_t cache_mutex;
462
627
 
463
628
  cdict_cache_entry_t cdict_cache[DICT_CDICT_CACHE_SIZE];
464
629
  int cdict_cache_count;
@@ -468,12 +633,15 @@ struct dictionary_s {
468
633
 
469
634
  static void dict_free(void *ptr) {
470
635
  dictionary_t *dict = (dictionary_t *)ptr;
636
+ if (!dict)
637
+ return;
471
638
  for (int i = 0; i < dict->cdict_cache_count; i++) {
472
639
  if (dict->cdict_cache[i].cdict)
473
640
  ZSTD_freeCDict(dict->cdict_cache[i].cdict);
474
641
  }
475
642
  if (dict->ddict)
476
643
  ZSTD_freeDDict(dict->ddict);
644
+ pthread_mutex_destroy(&dict->cache_mutex);
477
645
  if (dict->data)
478
646
  xfree(dict->data);
479
647
  xfree(dict);
@@ -481,14 +649,18 @@ static void dict_free(void *ptr) {
481
649
 
482
650
  static size_t dict_memsize(const void *ptr) {
483
651
  const dictionary_t *d = (const dictionary_t *)ptr;
484
- size_t total = sizeof(dictionary_t) + d->size;
652
+ if (!d)
653
+ return 0;
485
654
 
486
- for (int i = 0; i < d->cdict_cache_count; i++) {
487
- if (d->cdict_cache[i].cdict)
488
- total += d->size + 4096;
655
+ size_t total = sizeof(dictionary_t) + d->size;
656
+ if (d->algo == ALGO_ZSTD) {
657
+ for (int i = 0; i < d->cdict_cache_count; i++) {
658
+ if (d->cdict_cache[i].cdict)
659
+ total += ZSTD_sizeof_CDict(d->cdict_cache[i].cdict);
660
+ }
661
+ if (d->ddict)
662
+ total += ZSTD_sizeof_DDict(d->ddict);
489
663
  }
490
- if (d->ddict)
491
- total += d->size + 4096;
492
664
  return total;
493
665
  }
494
666
 
@@ -498,47 +670,78 @@ static const rb_data_type_t dictionary_type = {
498
670
  static VALUE dict_alloc(VALUE klass) {
499
671
  dictionary_t *d = ALLOC(dictionary_t);
500
672
  memset(d, 0, sizeof(dictionary_t));
673
+ if (pthread_mutex_init(&d->cache_mutex, NULL) != 0) {
674
+ xfree(d);
675
+ rb_raise(eMemError, "failed to initialize dictionary cache mutex");
676
+ }
501
677
  return TypedData_Wrap_Struct(klass, &dictionary_type, d);
502
678
  }
503
679
 
504
680
  static ZSTD_CDict *dict_get_cdict(dictionary_t *dict, int level) {
681
+ ZSTD_CDict *existing = NULL;
682
+
683
+ pthread_mutex_lock(&dict->cache_mutex);
505
684
  for (int i = 0; i < dict->cdict_cache_count; i++) {
506
- if (dict->cdict_cache[i].level == level)
507
- return dict->cdict_cache[i].cdict;
685
+ if (dict->cdict_cache[i].level == level) {
686
+ existing = dict->cdict_cache[i].cdict;
687
+ break;
688
+ }
508
689
  }
690
+ pthread_mutex_unlock(&dict->cache_mutex);
691
+
692
+ if (existing)
693
+ return existing;
509
694
 
510
695
  ZSTD_CDict *cdict = ZSTD_createCDict(dict->data, dict->size, level);
511
696
  if (!cdict)
512
697
  return NULL;
513
698
 
699
+ pthread_mutex_lock(&dict->cache_mutex);
514
700
  for (int i = 0; i < dict->cdict_cache_count; i++) {
515
701
  if (dict->cdict_cache[i].level == level) {
702
+ existing = dict->cdict_cache[i].cdict;
703
+ pthread_mutex_unlock(&dict->cache_mutex);
516
704
  ZSTD_freeCDict(cdict);
517
- return dict->cdict_cache[i].cdict;
705
+ return existing;
518
706
  }
519
707
  }
520
708
 
521
- if (dict->cdict_cache_count < DICT_CDICT_CACHE_SIZE) {
522
- dict->cdict_cache[dict->cdict_cache_count].level = level;
523
- dict->cdict_cache[dict->cdict_cache_count].cdict = cdict;
524
- dict->cdict_cache_count++;
525
- } else {
526
- ZSTD_CDict *old_cdict = dict->cdict_cache[0].cdict;
527
- memmove(&dict->cdict_cache[0], &dict->cdict_cache[1],
528
- sizeof(cdict_cache_entry_t) * (DICT_CDICT_CACHE_SIZE - 1));
529
- dict->cdict_cache[DICT_CDICT_CACHE_SIZE - 1].level = level;
530
- dict->cdict_cache[DICT_CDICT_CACHE_SIZE - 1].cdict = cdict;
531
- if (old_cdict)
532
- ZSTD_freeCDict(old_cdict);
709
+ if (dict->cdict_cache_count >= DICT_CDICT_CACHE_SIZE) {
710
+ pthread_mutex_unlock(&dict->cache_mutex);
711
+ ZSTD_freeCDict(cdict);
712
+ rb_raise(eError, "zstd dictionary cdict cache exhausted");
533
713
  }
714
+
715
+ dict->cdict_cache[dict->cdict_cache_count].level = level;
716
+ dict->cdict_cache[dict->cdict_cache_count].cdict = cdict;
717
+ dict->cdict_cache_count++;
718
+ pthread_mutex_unlock(&dict->cache_mutex);
534
719
  return cdict;
535
720
  }
536
721
 
537
722
  static ZSTD_DDict *dict_get_ddict(dictionary_t *dict) {
723
+ ZSTD_DDict *existing;
724
+
725
+ pthread_mutex_lock(&dict->cache_mutex);
726
+ existing = dict->ddict;
727
+ pthread_mutex_unlock(&dict->cache_mutex);
728
+ if (existing)
729
+ return existing;
730
+
731
+ ZSTD_DDict *created = ZSTD_createDDict(dict->data, dict->size);
732
+ if (!created)
733
+ return NULL;
734
+
735
+ pthread_mutex_lock(&dict->cache_mutex);
538
736
  if (!dict->ddict) {
539
- dict->ddict = ZSTD_createDDict(dict->data, dict->size);
737
+ dict->ddict = created;
738
+ pthread_mutex_unlock(&dict->cache_mutex);
739
+ return created;
540
740
  }
541
- return dict->ddict;
741
+ existing = dict->ddict;
742
+ pthread_mutex_unlock(&dict->cache_mutex);
743
+ ZSTD_freeDDict(created);
744
+ return existing;
542
745
  }
543
746
 
544
747
  typedef struct {
@@ -609,9 +812,9 @@ typedef struct {
609
812
 
610
813
  static void *lz4_decompress_all_nogvl(void *arg) {
611
814
  lz4_decompress_all_args_t *a = (lz4_decompress_all_args_t *)arg;
612
- const uint8_t *src = a->src;
815
+ const uint8_t *restrict src = a->src;
613
816
  size_t slen = a->src_len;
614
- char *out_ptr = a->dst;
817
+ char *restrict out_ptr = a->dst;
615
818
  size_t out_offset = 0;
616
819
  size_t pos = 0;
617
820
 
@@ -623,13 +826,14 @@ static void *lz4_decompress_all_nogvl(void *arg) {
623
826
 
624
827
  int dsize = LZ4_decompress_safe((const char *)(src + pos + 8), out_ptr + out_offset,
625
828
  (int)comp_size, (int)orig_size);
626
- if (dsize < 0) {
829
+ if (MC_UNLIKELY(dsize < 0)) {
627
830
  a->error = 1;
628
- snprintf(a->err_msg, sizeof(a->err_msg), "lz4 decompress failed");
831
+ static const char kLz4FailMsg[] = "lz4 decompress failed";
832
+ memcpy(a->err_msg, kLz4FailMsg, sizeof(kLz4FailMsg));
629
833
  return NULL;
630
834
  }
631
835
 
632
- out_offset += dsize;
836
+ out_offset += (size_t)dsize;
633
837
  pos += 8 + comp_size;
634
838
  }
635
839
 
@@ -647,6 +851,24 @@ typedef struct {
647
851
  int result;
648
852
  } lz4_compress_args_t;
649
853
 
/* Arguments and results for lz4frame_compress_nogvl. */
typedef struct {
  const void *src;   /* input bytes */
  size_t src_len;    /* input length */
  void *dst;         /* output buffer */
  size_t dst_cap;    /* output capacity */
  size_t result;     /* LZ4F_compressFrame() return value */
  size_t error_code; /* that value when it is an LZ4F error, otherwise 0 */
} lz4frame_compress_args_t;

/* Arguments and results for lz4frame_decompress_nogvl. */
typedef struct {
  const void *src;   /* LZ4 frame input */
  size_t src_len;    /* input length */
  void *dst;         /* output buffer */
  size_t dst_cap;    /* output capacity */
  size_t result;     /* bytes written on success, 0 on failure */
  size_t error_code; /* 0 on success; otherwise an error code set by the decoder */
} lz4frame_decompress_args_t;
871
+
650
872
  static void *lz4_compress_nogvl(void *arg) {
651
873
  lz4_compress_args_t *a = (lz4_compress_args_t *)arg;
652
874
  if (a->level > 1) {
@@ -666,6 +888,72 @@ typedef struct {
666
888
  BROTLI_BOOL result;
667
889
  } brotli_compress_args_t;
668
890
 
891
+ static void *lz4frame_compress_nogvl(void *arg) {
892
+ lz4frame_compress_args_t *a = (lz4frame_compress_args_t *)arg;
893
+ LZ4F_preferences_t prefs;
894
+ memset(&prefs, 0, sizeof(prefs));
895
+ prefs.frameInfo.blockChecksumFlag = LZ4F_blockChecksumEnabled;
896
+ prefs.frameInfo.contentChecksumFlag = LZ4F_contentChecksumEnabled;
897
+ a->result = LZ4F_compressFrame(a->dst, a->dst_cap, a->src, a->src_len, &prefs);
898
+ a->error_code = LZ4F_isError(a->result) ? a->result : 0;
899
+ return NULL;
900
+ }
901
+
902
+ static void *lz4frame_decompress_nogvl(void *arg) {
903
+ lz4frame_decompress_args_t *a = (lz4frame_decompress_args_t *)arg;
904
+ LZ4F_dctx *dctx = NULL;
905
+ size_t rc = LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION);
906
+ if (LZ4F_isError(rc)) {
907
+ a->result = 0;
908
+ a->error_code = rc;
909
+ return NULL;
910
+ }
911
+
912
+ const uint8_t *src = (const uint8_t *)a->src;
913
+ uint8_t *dst = (uint8_t *)a->dst;
914
+ size_t src_pos = 0;
915
+ size_t dst_pos = 0;
916
+ size_t hint = 1;
917
+
918
+ while (src_pos < a->src_len && hint != 0) {
919
+ size_t src_size = a->src_len - src_pos;
920
+ size_t dst_size = a->dst_cap - dst_pos;
921
+ rc = LZ4F_decompress(dctx, dst + dst_pos, &dst_size, src + src_pos, &src_size, NULL);
922
+ if (LZ4F_isError(rc)) {
923
+ a->result = 0;
924
+ a->error_code = rc;
925
+ LZ4F_freeDecompressionContext(dctx);
926
+ return NULL;
927
+ }
928
+ src_pos += src_size;
929
+ dst_pos += dst_size;
930
+ if (dst_pos > a->dst_cap) {
931
+ a->result = 0;
932
+ a->error_code = (size_t)-1;
933
+ LZ4F_freeDecompressionContext(dctx);
934
+ return NULL;
935
+ }
936
+ hint = rc;
937
+ if (src_size == 0 && dst_size == 0 && hint != 0)
938
+ break;
939
+ }
940
+
941
+ LZ4F_freeDecompressionContext(dctx);
942
+ if (hint != 0) {
943
+ if (dst_pos == a->dst_cap) {
944
+ a->result = 0;
945
+ a->error_code = (size_t)-1;
946
+ return NULL;
947
+ }
948
+ a->result = 0;
949
+ a->error_code = (size_t)-2;
950
+ return NULL;
951
+ }
952
+ a->result = dst_pos;
953
+ a->error_code = 0;
954
+ return NULL;
955
+ }
956
+
669
957
  static void *brotli_compress_nogvl(void *arg) {
670
958
  brotli_compress_args_t *a = (brotli_compress_args_t *)arg;
671
959
  a->result = BrotliEncoderCompress(a->level, BROTLI_DEFAULT_WINDOW, BROTLI_DEFAULT_MODE,
@@ -695,64 +983,38 @@ typedef struct {
695
983
  size_t dst_cap;
696
984
  size_t result;
697
985
  int error;
698
-
699
- VALUE scheduler;
700
- VALUE blocker;
701
- VALUE fiber;
702
986
  } zstd_fiber_compress_t;
703
987
 
704
988
  typedef struct {
705
989
  ZSTD_CStream *cstream;
706
- ZSTD_inBuffer *input;
707
- ZSTD_outBuffer *output;
990
+ ZSTD_inBuffer input;
991
+ ZSTD_outBuffer output;
708
992
  size_t result;
709
-
710
- VALUE scheduler;
711
- VALUE blocker;
712
- VALUE fiber;
713
993
  } zstd_stream_chunk_fiber_t;
714
994
 
715
995
  static void *zstd_stream_chunk_fiber_nogvl(void *arg) {
716
996
  zstd_stream_chunk_fiber_t *a = (zstd_stream_chunk_fiber_t *)arg;
717
- a->result = ZSTD_compressStream(a->cstream, a->output, a->input);
997
+ a->result = ZSTD_compressStream(a->cstream, &a->output, &a->input);
718
998
  return NULL;
719
999
  }
720
1000
 
721
- static VALUE zstd_stream_chunk_fiber_thread(void *arg) {
722
- zstd_stream_chunk_fiber_t *a = (zstd_stream_chunk_fiber_t *)arg;
723
- rb_thread_call_without_gvl(zstd_stream_chunk_fiber_nogvl, a, RUBY_UBF_PROCESS, NULL);
724
- rb_fiber_scheduler_unblock(a->scheduler, a->blocker, a->fiber);
725
- return Qnil;
726
- }
727
-
728
1001
  typedef struct {
729
1002
  BrotliEncoderState *enc;
730
1003
  BrotliEncoderOperation op;
731
- size_t *available_in;
732
- const uint8_t **next_in;
733
- size_t *available_out;
734
- uint8_t **next_out;
1004
+ size_t available_in;
1005
+ const uint8_t *next_in;
1006
+ size_t available_out;
1007
+ uint8_t *next_out;
735
1008
  BROTLI_BOOL result;
736
-
737
- VALUE scheduler;
738
- VALUE blocker;
739
- VALUE fiber;
740
1009
  } brotli_stream_chunk_fiber_t;
741
1010
 
742
1011
  static void *brotli_stream_chunk_fiber_nogvl(void *arg) {
743
1012
  brotli_stream_chunk_fiber_t *a = (brotli_stream_chunk_fiber_t *)arg;
744
- a->result = BrotliEncoderCompressStream(a->enc, a->op, a->available_in, a->next_in,
745
- a->available_out, a->next_out, NULL);
1013
+ a->result = BrotliEncoderCompressStream(a->enc, a->op, &a->available_in, &a->next_in,
1014
+ &a->available_out, &a->next_out, NULL);
746
1015
  return NULL;
747
1016
  }
748
1017
 
749
- static VALUE brotli_stream_chunk_fiber_thread(void *arg) {
750
- brotli_stream_chunk_fiber_t *a = (brotli_stream_chunk_fiber_t *)arg;
751
- rb_thread_call_without_gvl(brotli_stream_chunk_fiber_nogvl, a, RUBY_UBF_PROCESS, NULL);
752
- rb_fiber_scheduler_unblock(a->scheduler, a->blocker, a->fiber);
753
- return Qnil;
754
- }
755
-
756
1018
  typedef struct {
757
1019
  size_t encoded_size;
758
1020
  const uint8_t *encoded_buffer;
@@ -781,6 +1043,19 @@ static void *zstd_decompress_stream_chunk_nogvl(void *arg) {
781
1043
  return NULL;
782
1044
  }
783
1045
 
1046
+ typedef struct {
1047
+ ZSTD_DStream *dstream;
1048
+ ZSTD_outBuffer output;
1049
+ ZSTD_inBuffer input;
1050
+ size_t result;
1051
+ } zstd_decompress_stream_chunk_fiber_t;
1052
+
1053
+ static void *zstd_decompress_stream_chunk_fiber_nogvl(void *arg) {
1054
+ zstd_decompress_stream_chunk_fiber_t *a = (zstd_decompress_stream_chunk_fiber_t *)arg;
1055
+ a->result = ZSTD_decompressStream(a->dstream, &a->output, &a->input);
1056
+ return NULL;
1057
+ }
1058
+
784
1059
  typedef struct {
785
1060
  BrotliDecoderState *dec;
786
1061
  size_t *available_in;
@@ -797,6 +1072,22 @@ static void *brotli_decompress_stream_nogvl(void *arg) {
797
1072
  return NULL;
798
1073
  }
799
1074
 
1075
+ typedef struct {
1076
+ BrotliDecoderState *dec;
1077
+ size_t available_in;
1078
+ const uint8_t *next_in;
1079
+ size_t available_out;
1080
+ uint8_t *next_out;
1081
+ BrotliDecoderResult result;
1082
+ } brotli_decompress_stream_fiber_t;
1083
+
1084
+ static void *brotli_decompress_stream_fiber_nogvl(void *arg) {
1085
+ brotli_decompress_stream_fiber_t *a = (brotli_decompress_stream_fiber_t *)arg;
1086
+ a->result = BrotliDecoderDecompressStream(a->dec, &a->available_in, &a->next_in,
1087
+ &a->available_out, &a->next_out, NULL);
1088
+ return NULL;
1089
+ }
1090
+
800
1091
  static void *zstd_fiber_compress_nogvl(void *arg) {
801
1092
  zstd_fiber_compress_t *a = (zstd_fiber_compress_t *)arg;
802
1093
  if (a->cdict) {
@@ -814,13 +1105,6 @@ static void *zstd_fiber_compress_nogvl(void *arg) {
814
1105
  return NULL;
815
1106
  }
816
1107
 
817
- static VALUE zstd_fiber_compress_thread(void *arg) {
818
- zstd_fiber_compress_t *a = (zstd_fiber_compress_t *)arg;
819
- rb_thread_call_without_gvl(zstd_fiber_compress_nogvl, a, RUBY_UBF_PROCESS, NULL);
820
- rb_fiber_scheduler_unblock(a->scheduler, a->blocker, a->fiber);
821
- return Qnil;
822
- }
823
-
824
1108
  static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
825
1109
  VALUE data, opts;
826
1110
  rb_scan_args(argc, argv, "1:", &data, &opts);
@@ -834,7 +1118,9 @@ static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
834
1118
  dict_val = opt_get(opts, sym_cache.dictionary);
835
1119
  }
836
1120
 
837
- compress_algo_t algo = NIL_P(algo_sym) ? ALGO_ZSTD : sym_to_algo(algo_sym);
1121
+ int explicit_algo = !NIL_P(algo_sym);
1122
+ compress_algo_t algo = explicit_algo ? sym_to_algo(algo_sym) : ALGO_ZSTD;
1123
+ lz4_format_t lz4_format = parse_lz4_format(opts, algo, explicit_algo);
838
1124
  int level = resolve_level(algo, level_val);
839
1125
 
840
1126
  dictionary_t *dict = NULL;
@@ -876,18 +1162,19 @@ static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
876
1162
  rb_raise(eError, "zstd compress: %s", ZSTD_getErrorName(csize));
877
1163
  rb_str_set_len(dst, (long)csize);
878
1164
  RB_GC_GUARD(data);
1165
+ RB_GC_GUARD(dict_val);
879
1166
  return dst;
880
1167
  }
881
1168
 
882
1169
  {
883
1170
  VALUE scheduler = current_fiber_scheduler();
884
- if (scheduler != Qnil) {
1171
+ work_exec_mode_t mode = select_fiber_nogvl_or_direct_mode(
1172
+ scheduler, slen, policy->gvl_unlock_threshold, policy->gvl_unlock_threshold);
1173
+
1174
+ if (mode == WORK_EXEC_FIBER) {
885
1175
  char *out_buf = (char *)malloc(bound);
886
1176
  if (!out_buf)
887
1177
  rb_raise(eMemError, "zstd: malloc failed");
888
-
889
- VALUE blocker = rb_obj_alloc(rb_cObject);
890
-
891
1178
  zstd_fiber_compress_t fargs = {
892
1179
  .src = src,
893
1180
  .src_len = slen,
@@ -897,14 +1184,9 @@ static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
897
1184
  .dst_cap = bound,
898
1185
  .result = 0,
899
1186
  .error = 0,
900
- .scheduler = scheduler,
901
- .blocker = blocker,
902
- .fiber = rb_fiber_current(),
903
1187
  };
904
1188
 
905
- VALUE rb_thread = rb_thread_create(zstd_fiber_compress_thread, &fargs);
906
- rb_fiber_scheduler_block(scheduler, blocker, Qnil);
907
- join_thread(rb_thread);
1189
+ RUN_WITH_EXEC_MODE(mode, zstd_fiber_compress_nogvl, fargs);
908
1190
 
909
1191
  if (fargs.error) {
910
1192
  free(out_buf);
@@ -918,6 +1200,7 @@ static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
918
1200
  VALUE result = rb_binary_str_new(out_buf, (long)fargs.result);
919
1201
  free(out_buf);
920
1202
  RB_GC_GUARD(data);
1203
+ RB_GC_GUARD(dict_val);
921
1204
  return result;
922
1205
  }
923
1206
  }
@@ -934,7 +1217,7 @@ static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
934
1217
  .result = 0,
935
1218
  .error = 0,
936
1219
  };
937
- run_without_gvl(zstd_compress_nogvl, &args);
1220
+ RUN_WITH_EXEC_MODE(WORK_EXEC_NOGVL, zstd_compress_nogvl, args);
938
1221
 
939
1222
  if (args.error)
940
1223
  rb_raise(eMemError, "zstd: failed to create context");
@@ -943,10 +1226,39 @@ static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
943
1226
 
944
1227
  rb_str_set_len(dst, (long)args.result);
945
1228
  RB_GC_GUARD(data);
1229
+ RB_GC_GUARD(dict_val);
946
1230
  return dst;
947
1231
  }
948
1232
  }
949
1233
  case ALGO_LZ4: {
1234
+ if (lz4_format == LZ4_FORMAT_FRAME) {
1235
+ LZ4F_preferences_t prefs;
1236
+ memset(&prefs, 0, sizeof(prefs));
1237
+ prefs.frameInfo.blockChecksumFlag = LZ4F_blockChecksumEnabled;
1238
+ prefs.frameInfo.contentChecksumFlag = LZ4F_contentChecksumEnabled;
1239
+ size_t bound = LZ4F_compressFrameBound(slen, &prefs);
1240
+ VALUE dst = rb_binary_str_buf_reserve((long)bound);
1241
+ lz4frame_compress_args_t args = {
1242
+ .src = src,
1243
+ .src_len = slen,
1244
+ .dst = RSTRING_PTR(dst),
1245
+ .dst_cap = bound,
1246
+ .result = 0,
1247
+ .error_code = 0,
1248
+ };
1249
+ {
1250
+ VALUE scheduler = current_fiber_scheduler();
1251
+ work_exec_mode_t mode = select_fiber_nogvl_or_direct_mode(
1252
+ scheduler, slen, policy->gvl_unlock_threshold, policy->gvl_unlock_threshold);
1253
+ RUN_WITH_EXEC_MODE(mode, lz4frame_compress_nogvl, args);
1254
+ }
1255
+ if (args.error_code)
1256
+ rb_raise(eError, "lz4 frame compress failed: %s",
1257
+ LZ4F_getErrorName(args.error_code));
1258
+ rb_str_set_len(dst, (long)args.result);
1259
+ RB_GC_GUARD(data);
1260
+ return dst;
1261
+ }
950
1262
  if (slen > (size_t)INT_MAX)
951
1263
  rb_raise(eError, "lz4: input too large (max 2GB)");
952
1264
  int bound = LZ4_compressBound((int)slen);
@@ -967,11 +1279,9 @@ static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
967
1279
  };
968
1280
 
969
1281
  VALUE scheduler = current_fiber_scheduler();
970
- if (scheduler != Qnil) {
971
- run_via_fiber_worker(scheduler, lz4_compress_nogvl, &args);
972
- } else {
973
- run_without_gvl(lz4_compress_nogvl, &args);
974
- }
1282
+ work_exec_mode_t mode = select_fiber_nogvl_or_direct_mode(
1283
+ scheduler, slen, policy->gvl_unlock_threshold, policy->gvl_unlock_threshold);
1284
+ RUN_WITH_EXEC_MODE(mode, lz4_compress_nogvl, args);
975
1285
  csize = args.result;
976
1286
 
977
1287
  if (csize <= 0)
@@ -1052,6 +1362,7 @@ static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
1052
1362
 
1053
1363
  rb_str_set_len(dst, initial_out - available_out);
1054
1364
  RB_GC_GUARD(data);
1365
+ RB_GC_GUARD(dict_val);
1055
1366
  return dst;
1056
1367
  } else if (slen >= policy->gvl_unlock_threshold) {
1057
1368
  VALUE dst = rb_binary_str_buf_reserve(out_len);
@@ -1066,11 +1377,9 @@ static VALUE compress_compress(int argc, VALUE *argv, VALUE self) {
1066
1377
  };
1067
1378
 
1068
1379
  VALUE scheduler = current_fiber_scheduler();
1069
- if (scheduler != Qnil) {
1070
- run_via_fiber_worker(scheduler, brotli_compress_nogvl, &args);
1071
- } else {
1072
- run_without_gvl(brotli_compress_nogvl, &args);
1073
- }
1380
+ work_exec_mode_t mode = select_fiber_nogvl_or_direct_mode(
1381
+ scheduler, slen, policy->gvl_unlock_threshold, policy->gvl_unlock_threshold);
1382
+ RUN_WITH_EXEC_MODE(mode, brotli_compress_nogvl, args);
1074
1383
 
1075
1384
  if (!args.result)
1076
1385
  rb_raise(eError, "brotli compress failed");
@@ -1112,12 +1421,14 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
1112
1421
  const uint8_t *src = (const uint8_t *)RSTRING_PTR(data);
1113
1422
  size_t slen = RSTRING_LEN(data);
1114
1423
 
1424
+ int explicit_algo = !NIL_P(algo_sym);
1115
1425
  compress_algo_t algo;
1116
- if (NIL_P(algo_sym)) {
1426
+ if (!explicit_algo) {
1117
1427
  algo = detect_algo(src, slen);
1118
1428
  } else {
1119
1429
  algo = sym_to_algo(algo_sym);
1120
1430
  }
1431
+ lz4_format_t lz4_format = parse_lz4_format(opts, algo, explicit_algo);
1121
1432
 
1122
1433
  const algo_policy_t *policy = algo_policy(algo);
1123
1434
 
@@ -1167,11 +1478,10 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
1167
1478
  };
1168
1479
 
1169
1480
  VALUE scheduler = current_fiber_scheduler();
1170
- if (scheduler != Qnil) {
1171
- run_via_fiber_worker(scheduler, zstd_decompress_nogvl, &args);
1172
- } else {
1173
- run_without_gvl(zstd_decompress_nogvl, &args);
1174
- }
1481
+ work_exec_mode_t mode = select_fiber_nogvl_or_direct_mode(
1482
+ scheduler, frame_size, policy->gvl_unlock_threshold,
1483
+ policy->gvl_unlock_threshold);
1484
+ RUN_WITH_EXEC_MODE(mode, zstd_decompress_nogvl, args);
1175
1485
 
1176
1486
  if (args.error)
1177
1487
  rb_raise(eMemError, "zstd: failed to create dctx");
@@ -1183,6 +1493,7 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
1183
1493
  limits.max_ratio_enabled, limits.max_ratio);
1184
1494
  rb_str_set_len(dst, (long)dsize);
1185
1495
  RB_GC_GUARD(data);
1496
+ RB_GC_GUARD(dict_val);
1186
1497
  return dst;
1187
1498
  } else {
1188
1499
  VALUE dst = rb_binary_str_buf_reserve((size_t)frame_size);
@@ -1207,6 +1518,7 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
1207
1518
  limits.max_ratio_enabled, limits.max_ratio);
1208
1519
  rb_str_set_len(dst, dsize);
1209
1520
  RB_GC_GUARD(data);
1521
+ RB_GC_GUARD(dict_val);
1210
1522
  return dst;
1211
1523
  }
1212
1524
  }
@@ -1278,9 +1590,57 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
1278
1590
  ZSTD_freeDCtx(dctx);
1279
1591
  rb_str_set_len(dst, total_out);
1280
1592
  RB_GC_GUARD(data);
1593
+ RB_GC_GUARD(dict_val);
1281
1594
  return dst;
1282
1595
  }
1283
1596
  case ALGO_LZ4: {
1597
+ if (lz4_format == LZ4_FORMAT_FRAME || is_lz4_frame_magic(src, slen)) {
1598
+ size_t alloc_size =
1599
+ (slen > limits.max_output_size / 4) ? limits.max_output_size : slen * 4;
1600
+ if (alloc_size < 4096)
1601
+ alloc_size = limits.max_output_size < 4096 ? limits.max_output_size : 4096;
1602
+ if (alloc_size == 0)
1603
+ alloc_size = limits.max_output_size;
1604
+ VALUE dst = rb_binary_str_buf_reserve((long)alloc_size);
1605
+ while (1) {
1606
+ lz4frame_decompress_args_t args = {
1607
+ .src = src,
1608
+ .src_len = slen,
1609
+ .dst = RSTRING_PTR(dst),
1610
+ .dst_cap = alloc_size,
1611
+ .result = 0,
1612
+ .error_code = 0,
1613
+ };
1614
+ {
1615
+ VALUE scheduler = current_fiber_scheduler();
1616
+ work_exec_mode_t mode = select_fiber_nogvl_or_direct_mode(
1617
+ scheduler, slen, policy->gvl_unlock_threshold,
1618
+ policy->gvl_unlock_threshold);
1619
+ RUN_WITH_EXEC_MODE(mode, lz4frame_decompress_nogvl, args);
1620
+ }
1621
+ if (args.error_code == (size_t)-1) {
1622
+ if (alloc_size >= limits.max_output_size)
1623
+ rb_raise(eDataError, "decompressed output exceeds limit (%zu bytes)",
1624
+ limits.max_output_size);
1625
+ size_t next_cap = alloc_size * 2;
1626
+ if (next_cap > limits.max_output_size)
1627
+ next_cap = limits.max_output_size;
1628
+ alloc_size = next_cap;
1629
+ grow_binary_str(dst, 0, alloc_size);
1630
+ continue;
1631
+ }
1632
+ if (args.error_code == (size_t)-2)
1633
+ rb_raise(eDataError, "lz4 frame decompress failed: truncated frame");
1634
+ if (args.error_code)
1635
+ rb_raise(eDataError, "lz4 frame decompress failed: %s",
1636
+ LZ4F_getErrorName(args.error_code));
1637
+ enforce_output_and_ratio_limits(args.result, slen, limits.max_output_size,
1638
+ limits.max_ratio_enabled, limits.max_ratio);
1639
+ rb_str_set_len(dst, (long)args.result);
1640
+ RB_GC_GUARD(data);
1641
+ return dst;
1642
+ }
1643
+ }
1284
1644
  if (slen < 4)
1285
1645
  rb_raise(eDataError, "lz4: data too short");
1286
1646
 
@@ -1318,15 +1678,12 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
1318
1678
  .error = 0,
1319
1679
  };
1320
1680
 
1321
- if (total_orig >= algo_policy(ALGO_LZ4)->gvl_unlock_threshold) {
1681
+ {
1322
1682
  VALUE scheduler = current_fiber_scheduler();
1323
- if (scheduler != Qnil) {
1324
- run_via_fiber_worker(scheduler, lz4_decompress_all_nogvl, &args);
1325
- } else {
1326
- run_without_gvl(lz4_decompress_all_nogvl, &args);
1327
- }
1328
- } else {
1329
- lz4_decompress_all_nogvl(&args);
1683
+ work_exec_mode_t mode = select_fiber_nogvl_or_direct_mode(
1684
+ scheduler, total_orig, algo_policy(ALGO_LZ4)->gvl_unlock_threshold,
1685
+ algo_policy(ALGO_LZ4)->gvl_unlock_threshold);
1686
+ RUN_WITH_EXEC_MODE(mode, lz4_decompress_all_nogvl, args);
1330
1687
  }
1331
1688
 
1332
1689
  if (args.error)
@@ -1376,16 +1733,21 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
1376
1733
  available_out = remaining_budget;
1377
1734
  uint8_t *next_out = (uint8_t *)RSTRING_PTR(dst) + total_out;
1378
1735
 
1379
- if (scheduler != Qnil && available_in >= policy->fiber_stream_threshold) {
1380
- brotli_decompress_stream_args_t sargs = {
1736
+ if (select_fiber_or_direct_mode(scheduler, available_in,
1737
+ policy->fiber_stream_threshold) == WORK_EXEC_FIBER) {
1738
+ brotli_decompress_stream_fiber_t sargs = {
1381
1739
  .dec = dec,
1382
- .available_in = &available_in,
1383
- .next_in = &next_in,
1384
- .available_out = &available_out,
1385
- .next_out = &next_out,
1740
+ .available_in = available_in,
1741
+ .next_in = next_in,
1742
+ .available_out = available_out,
1743
+ .next_out = next_out,
1386
1744
  .result = BROTLI_DECODER_RESULT_ERROR,
1387
1745
  };
1388
- run_via_fiber_worker(scheduler, brotli_decompress_stream_nogvl, &sargs);
1746
+ RUN_VIA_FIBER_WORKER(brotli_decompress_stream_fiber_nogvl, sargs);
1747
+ available_in = sargs.available_in;
1748
+ next_in = sargs.next_in;
1749
+ available_out = sargs.available_out;
1750
+ next_out = sargs.next_out;
1389
1751
  res = sargs.result;
1390
1752
  } else {
1391
1753
  res = BrotliDecoderDecompressStream(dec, &available_in, &next_in, &available_out,
@@ -1417,6 +1779,7 @@ static VALUE compress_decompress(int argc, VALUE *argv, VALUE self) {
1417
1779
  }
1418
1780
  rb_str_set_len(dst, total_out);
1419
1781
  RB_GC_GUARD(data);
1782
+ RB_GC_GUARD(dict_val);
1420
1783
  return dst;
1421
1784
  }
1422
1785
  }
@@ -1450,14 +1813,14 @@ static void crc32_init_tables(void) {
1450
1813
  crc32_tables_initialized = 1;
1451
1814
  }
1452
1815
 
1453
- static uint32_t crc32_compute(const uint8_t *data, size_t len, uint32_t crc) {
1816
+ static uint32_t crc32_compute(const uint8_t *restrict data, size_t len, uint32_t crc) {
1454
1817
  crc = ~crc;
1455
1818
 
1456
1819
  while (len >= 8) {
1457
- uint32_t val0 = crc ^ ((uint32_t)data[0] | ((uint32_t)data[1] << 8) |
1458
- ((uint32_t)data[2] << 16) | ((uint32_t)data[3] << 24));
1459
- uint32_t val1 = (uint32_t)data[4] | ((uint32_t)data[5] << 8) | ((uint32_t)data[6] << 16) |
1460
- ((uint32_t)data[7] << 24);
1820
+ const uint32_t val0 = crc ^ ((uint32_t)data[0] | ((uint32_t)data[1] << 8) |
1821
+ ((uint32_t)data[2] << 16) | ((uint32_t)data[3] << 24));
1822
+ const uint32_t val1 = (uint32_t)data[4] | ((uint32_t)data[5] << 8) |
1823
+ ((uint32_t)data[6] << 16) | ((uint32_t)data[7] << 24);
1461
1824
 
1462
1825
  crc = crc32_tables[7][(val0) & 0xFF] ^ crc32_tables[6][(val0 >> 8) & 0xFF] ^
1463
1826
  crc32_tables[5][(val0 >> 16) & 0xFF] ^ crc32_tables[4][(val0 >> 24) & 0xFF] ^
@@ -1488,28 +1851,49 @@ static VALUE compress_crc32(int argc, VALUE *argv, VALUE self) {
1488
1851
  }
1489
1852
 
1490
1853
  static VALUE compress_adler32(int argc, VALUE *argv, VALUE self) {
1854
+ (void)self;
1491
1855
  VALUE data, prev;
1492
1856
  rb_scan_args(argc, argv, "11", &data, &prev);
1493
1857
  StringValue(data);
1494
1858
 
1495
- const uint8_t *src = (const uint8_t *)RSTRING_PTR(data);
1496
- size_t len = RSTRING_LEN(data);
1497
- uint32_t adler = NIL_P(prev) ? 1 : NUM2UINT(prev);
1859
+ const uint8_t *restrict src = (const uint8_t *)RSTRING_PTR(data);
1860
+ size_t len = (size_t)RSTRING_LEN(data);
1861
+ const uint32_t adler = NIL_P(prev) ? 1u : NUM2UINT(prev);
1498
1862
 
1499
- uint32_t s1 = adler & 0xFFFF;
1500
- uint32_t s2 = (adler >> 16) & 0xFFFF;
1501
- const uint32_t BASE = 65521;
1863
+ uint32_t s1 = adler & 0xFFFFu;
1864
+ uint32_t s2 = (adler >> 16) & 0xFFFFu;
1865
+ enum { ADLER_BASE = 65521, ADLER_NMAX = 5552 };
1502
1866
 
1503
1867
  while (len > 0) {
1504
- size_t chunk = len > 5552 ? 5552 : len;
1868
+ size_t chunk = len > ADLER_NMAX ? (size_t)ADLER_NMAX : len;
1505
1869
  len -= chunk;
1506
- for (size_t i = 0; i < chunk; i++) {
1507
- s1 += src[i];
1870
+
1871
+ while (chunk >= 8) {
1872
+ s1 += src[0];
1873
+ s2 += s1;
1874
+ s1 += src[1];
1875
+ s2 += s1;
1876
+ s1 += src[2];
1877
+ s2 += s1;
1878
+ s1 += src[3];
1879
+ s2 += s1;
1880
+ s1 += src[4];
1881
+ s2 += s1;
1882
+ s1 += src[5];
1883
+ s2 += s1;
1884
+ s1 += src[6];
1885
+ s2 += s1;
1886
+ s1 += src[7];
1887
+ s2 += s1;
1888
+ src += 8;
1889
+ chunk -= 8;
1890
+ }
1891
+ while (chunk--) {
1892
+ s1 += *src++;
1508
1893
  s2 += s1;
1509
1894
  }
1510
- s1 %= BASE;
1511
- s2 %= BASE;
1512
- src += chunk;
1895
+ s1 %= ADLER_BASE;
1896
+ s2 %= ADLER_BASE;
1513
1897
  }
1514
1898
 
1515
1899
  return UINT2NUM((s2 << 16) | s1);
@@ -1545,6 +1929,9 @@ static VALUE compress_version(VALUE self, VALUE algo_sym) {
1545
1929
  /* Geometry of the LZ4 deflater's ring buffer (d->lz4_ring): a 64 KiB window,
     LZ4's maximum back-reference distance being just under 64 KiB, with the
     total buffer (the amount deflater_memsize reports) twice the window size. */
  #define LZ4_RING_BUFFER_SIZE (64 * 1024)
1546
1930
  #define LZ4_RING_BUFFER_TOTAL (LZ4_RING_BUFFER_SIZE * 2)
1547
1931
 
1932
+ /* Compile-time guard tying the two ring-buffer macros together. With
+    LZ4_RING_BUFFER_TOTAL defined literally as LZ4_RING_BUFFER_SIZE * 2 this
+    always holds; it only fires if TOTAL is later redefined independently. */
+ _Static_assert(LZ4_RING_BUFFER_TOTAL == 2 * LZ4_RING_BUFFER_SIZE,
1933
+ "ring buffer total must be exactly twice the window size");
1934
+
1548
1935
  typedef struct {
1549
1936
  compress_algo_t algo;
1550
1937
  int level;
@@ -1589,9 +1976,22 @@ static void deflater_free(void *ptr) {
1589
1976
 
1590
1977
  /*
   * deflater_memsize: approximate memory footprint of a deflater_t.
   * Presumably the dsize callback of the deflater's rb_data_type_t (the type
   * definition is outside this view; confirm), so Ruby's GC can account for it.
   * The 0.2.4 version (- lines) counted only the LZ4 ring buffer for every
   * algorithm. The 0.3.1 version (+ lines) adds a NULL guard and counts per
   * algorithm.
   */
  static size_t deflater_memsize(const void *ptr) {
1591
1978
  const deflater_t *d = (const deflater_t *)ptr;
  /* No wrapped struct: report nothing rather than dereferencing NULL. */
1979
+ if (!d)
1980
+ return 0;
1981
+
1592
1982
  size_t s = sizeof(deflater_t);
1593
- if (d->lz4_ring.buf)
1594
- s += LZ4_RING_BUFFER_TOTAL;
  /* No default case is needed: the three cases cover every compress_algo_t value. */
1983
+ switch (d->algo) {
1984
+ case ALGO_ZSTD:
+ /* zstd: add the compression stream's own reported size, if one exists. */
1985
+ if (d->ctx.zstd)
1986
+ s += ZSTD_sizeof_CStream(d->ctx.zstd);
1987
+ break;
1988
+ case ALGO_BROTLI:
+ /* brotli: encoder state is not added, so this under-reports brotli deflaters. */
1989
+ break;
1990
+ case ALGO_LZ4:
+ /* lz4: the ring buffer is counted as LZ4_RING_BUFFER_TOTAL bytes when present
+    (assumes that is its allocation size; confirm at the allocation site). */
1991
+ if (d->lz4_ring.buf)
1992
+ s += LZ4_RING_BUFFER_TOTAL;
1993
+ break;
1994
+ }
1595
1995
  return s;
1596
1996
  }
1597
1997
 
@@ -1711,7 +2111,12 @@ static VALUE lz4_compress_ring_block(deflater_t *d) {
1711
2111
 
1712
2112
  write_le_u32((uint8_t *)out, (uint32_t)src_size);
1713
2113
 
1714
- int csize = LZ4_compress_fast_continue(d->ctx.lz4, block_start, out + 8, src_size, bound, 1);
2114
+ int csize;
2115
+ if (d->level > 1) {
2116
+ csize = LZ4_compress_HC(block_start, out + 8, src_size, bound, d->level);
2117
+ } else {
2118
+ csize = LZ4_compress_default(block_start, out + 8, src_size, bound);
2119
+ }
1715
2120
  if (csize <= 0)
1716
2121
  rb_raise(eError, "lz4 stream compress block failed");
1717
2122
 
@@ -1759,37 +2164,41 @@ static VALUE deflater_write(VALUE self, VALUE chunk) {
1759
2164
 
1760
2165
  ZSTD_outBuffer output = {RSTRING_PTR(result) + result_len, out_cap, 0};
1761
2166
 
1762
- if (scheduler != Qnil && (input.size - input.pos) >= policy->fiber_stream_threshold) {
1763
- zstd_stream_chunk_fiber_t fargs = {
1764
- .cstream = d->ctx.zstd,
1765
- .input = &input,
1766
- .output = &output,
1767
- .result = 0,
1768
- .scheduler = scheduler,
1769
- .blocker = rb_obj_alloc(rb_cObject),
1770
- .fiber = rb_fiber_current(),
1771
- };
1772
- VALUE th = rb_thread_create(zstd_stream_chunk_fiber_thread, &fargs);
1773
- rb_fiber_scheduler_block(scheduler, fargs.blocker, Qnil);
1774
- join_thread(th);
1775
-
1776
- if (ZSTD_isError(fargs.result))
1777
- rb_raise(eError, "zstd compress stream: %s", ZSTD_getErrorName(fargs.result));
1778
- } else if (scheduler == Qnil &&
1779
- (input.size - input.pos) >= policy->gvl_unlock_threshold) {
1780
- zstd_stream_chunk_args_t args = {
1781
- .cstream = d->ctx.zstd,
1782
- .output = &output,
1783
- .input = &input,
1784
- .result = 0,
1785
- };
1786
- run_without_gvl(zstd_compress_stream_chunk_nogvl, &args);
1787
- if (ZSTD_isError(args.result))
1788
- rb_raise(eError, "zstd compress stream: %s", ZSTD_getErrorName(args.result));
1789
- } else {
1790
- size_t ret = ZSTD_compressStream(d->ctx.zstd, &output, &input);
1791
- if (ZSTD_isError(ret))
1792
- rb_raise(eError, "zstd compress stream: %s", ZSTD_getErrorName(ret));
2167
+ {
2168
+ work_exec_mode_t mode = select_fiber_nogvl_or_direct_mode(
2169
+ scheduler, input.size - input.pos, policy->fiber_stream_threshold,
2170
+ policy->gvl_unlock_threshold);
2171
+
2172
+ if (mode == WORK_EXEC_FIBER) {
2173
+ zstd_stream_chunk_fiber_t fargs = {
2174
+ .cstream = d->ctx.zstd,
2175
+ .input = input,
2176
+ .output = output,
2177
+ .result = 0,
2178
+ };
2179
+ RUN_WITH_EXEC_MODE(mode, zstd_stream_chunk_fiber_nogvl, fargs);
2180
+ input.pos = fargs.input.pos;
2181
+ output.pos = fargs.output.pos;
2182
+
2183
+ if (ZSTD_isError(fargs.result))
2184
+ rb_raise(eError, "zstd compress stream: %s",
2185
+ ZSTD_getErrorName(fargs.result));
2186
+ } else if (mode == WORK_EXEC_NOGVL) {
2187
+ zstd_stream_chunk_args_t args = {
2188
+ .cstream = d->ctx.zstd,
2189
+ .output = &output,
2190
+ .input = &input,
2191
+ .result = 0,
2192
+ };
2193
+ RUN_WITH_EXEC_MODE(mode, zstd_compress_stream_chunk_nogvl, args);
2194
+ if (ZSTD_isError(args.result))
2195
+ rb_raise(eError, "zstd compress stream: %s",
2196
+ ZSTD_getErrorName(args.result));
2197
+ } else {
2198
+ size_t ret = ZSTD_compressStream(d->ctx.zstd, &output, &input);
2199
+ if (ZSTD_isError(ret))
2200
+ rb_raise(eError, "zstd compress stream: %s", ZSTD_getErrorName(ret));
2201
+ }
1793
2202
  }
1794
2203
  result_len += output.pos;
1795
2204
  }
@@ -1814,22 +2223,23 @@ static VALUE deflater_write(VALUE self, VALUE chunk) {
1814
2223
  uint8_t *next_out = NULL;
1815
2224
  BROTLI_BOOL ok;
1816
2225
 
1817
- if (use_fiber && available_in >= policy->fiber_stream_threshold) {
2226
+ if (use_fiber &&
2227
+ select_fiber_or_direct_mode(scheduler, available_in,
2228
+ policy->fiber_stream_threshold) == WORK_EXEC_FIBER) {
1818
2229
  brotli_stream_chunk_fiber_t fargs = {
1819
2230
  .enc = d->ctx.brotli,
1820
2231
  .op = BROTLI_OPERATION_PROCESS,
1821
- .available_in = &available_in,
1822
- .next_in = &next_in,
1823
- .available_out = &available_out,
1824
- .next_out = &next_out,
2232
+ .available_in = available_in,
2233
+ .next_in = next_in,
2234
+ .available_out = available_out,
2235
+ .next_out = next_out,
1825
2236
  .result = BROTLI_FALSE,
1826
- .scheduler = scheduler,
1827
- .blocker = rb_obj_alloc(rb_cObject),
1828
- .fiber = rb_fiber_current(),
1829
2237
  };
1830
- VALUE th = rb_thread_create(brotli_stream_chunk_fiber_thread, &fargs);
1831
- rb_fiber_scheduler_block(scheduler, fargs.blocker, Qnil);
1832
- join_thread(th);
2238
+ RUN_VIA_FIBER_WORKER(brotli_stream_chunk_fiber_nogvl, fargs);
2239
+ available_in = fargs.available_in;
2240
+ next_in = fargs.next_in;
2241
+ available_out = fargs.available_out;
2242
+ next_out = fargs.next_out;
1833
2243
  ok = fargs.result;
1834
2244
  } else {
1835
2245
  ok = BrotliEncoderCompressStream(d->ctx.brotli, BROTLI_OPERATION_PROCESS,
@@ -2221,7 +2631,22 @@ static void inflater_free(void *ptr) {
2221
2631
 
2222
2632
  /*
   * inflater_memsize: approximate memory footprint of an inflater_t.
   * Presumably the dsize callback of inflater_type ("Compress::Inflater");
   * only the start of that struct initializer is visible here, so confirm.
   * The 0.2.4 version (- line) dereferenced ptr unconditionally and added
   * lz4_buf.cap for every algorithm. The 0.3.1 version (+ lines) adds a NULL
   * guard and counts per algorithm.
   */
  static size_t inflater_memsize(const void *ptr) {
2223
2633
  const inflater_t *inf = (const inflater_t *)ptr;
2224
- return sizeof(inflater_t) + inf->lz4_buf.cap;
+ /* No wrapped struct: report nothing rather than dereferencing NULL. */
2634
+ if (!inf)
2635
+ return 0;
2636
+
2637
+ size_t s = sizeof(inflater_t);
+ /* No default case is needed: the three cases cover every compress_algo_t value. */
2638
+ switch (inf->algo) {
2639
+ case ALGO_ZSTD:
+ /* zstd: add the decompression stream's own reported size, if one exists. */
2640
+ if (inf->ctx.zstd)
2641
+ s += ZSTD_sizeof_DStream(inf->ctx.zstd);
2642
+ break;
2643
+ case ALGO_BROTLI:
+ /* brotli: decoder state is not added, so this under-reports brotli inflaters. */
2644
+ break;
2645
+ case ALGO_LZ4:
+ /* lz4: capacity of the buffer inflater_write uses to carry LZ4 input
+    across write calls (len/offset/cap are managed there). */
2646
+ s += inf->lz4_buf.cap;
2647
+ break;
2648
+ }
2649
+ return s;
2225
2650
  }
2226
2651
 
2227
2652
  static const rb_data_type_t inflater_type = {"Compress::Inflater",
@@ -2317,8 +2742,7 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
2317
2742
  if (slen == 0)
2318
2743
  return rb_binary_str_new("", 0);
2319
2744
 
2320
- inf->total_input =
2321
- checked_add_size(inf->total_input, slen, "compressed input exceeds representable size");
2745
+ size_t input_accounted_before = inf->total_input;
2322
2746
 
2323
2747
  switch (inf->algo) {
2324
2748
  case ALGO_ZSTD: {
@@ -2357,14 +2781,17 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
2357
2781
  ZSTD_outBuffer output = {RSTRING_PTR(result) + result_len, current_out_cap, 0};
2358
2782
  size_t ret;
2359
2783
 
2360
- if (scheduler != Qnil && (input.size - input.pos) >= policy->fiber_stream_threshold) {
2361
- zstd_decompress_stream_chunk_args_t args = {
2784
+ if (select_fiber_or_direct_mode(scheduler, input.size - input.pos,
2785
+ policy->fiber_stream_threshold) == WORK_EXEC_FIBER) {
2786
+ zstd_decompress_stream_chunk_fiber_t args = {
2362
2787
  .dstream = inf->ctx.zstd,
2363
- .output = &output,
2364
- .input = &input,
2788
+ .output = output,
2789
+ .input = input,
2365
2790
  .result = 0,
2366
2791
  };
2367
- run_via_fiber_worker(scheduler, zstd_decompress_stream_chunk_nogvl, &args);
2792
+ RUN_VIA_FIBER_WORKER(zstd_decompress_stream_chunk_fiber_nogvl, args);
2793
+ output.pos = args.output.pos;
2794
+ input.pos = args.input.pos;
2368
2795
  ret = args.result;
2369
2796
  } else {
2370
2797
  ret = ZSTD_decompressStream(inf->ctx.zstd, &output, &input);
@@ -2376,11 +2803,15 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
2376
2803
  "decompressed output exceeds representable size");
2377
2804
  size_t total_output = checked_add_size(
2378
2805
  inf->total_output, result_len, "decompressed output exceeds representable size");
2379
- enforce_output_and_ratio_limits(total_output, inf->total_input, inf->max_output_size,
2806
+ size_t total_input = checked_add_size(input_accounted_before, input.pos,
2807
+ "compressed input exceeds representable size");
2808
+ enforce_output_and_ratio_limits(total_output, total_input, inf->max_output_size,
2380
2809
  inf->max_ratio_enabled, inf->max_ratio);
2381
2810
  if (ret == 0)
2382
2811
  break;
2383
2812
  }
2813
+ inf->total_input = checked_add_size(input_accounted_before, input.pos,
2814
+ "compressed input exceeds representable size");
2384
2815
  inf->total_output = checked_add_size(inf->total_output, result_len,
2385
2816
  "decompressed output exceeds representable size");
2386
2817
  rb_str_set_len(result, result_len);
@@ -2409,16 +2840,21 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
2409
2840
  uint8_t *next_out = NULL;
2410
2841
  BrotliDecoderResult res;
2411
2842
 
2412
- if (scheduler != Qnil && available_in >= policy->fiber_stream_threshold) {
2413
- brotli_decompress_stream_args_t sargs = {
2843
+ if (select_fiber_or_direct_mode(scheduler, available_in,
2844
+ policy->fiber_stream_threshold) == WORK_EXEC_FIBER) {
2845
+ brotli_decompress_stream_fiber_t sargs = {
2414
2846
  .dec = inf->ctx.brotli,
2415
- .available_in = &available_in,
2416
- .next_in = &next_in,
2417
- .available_out = &available_out,
2418
- .next_out = &next_out,
2847
+ .available_in = available_in,
2848
+ .next_in = next_in,
2849
+ .available_out = available_out,
2850
+ .next_out = next_out,
2419
2851
  .result = BROTLI_DECODER_RESULT_ERROR,
2420
2852
  };
2421
- run_via_fiber_worker(scheduler, brotli_decompress_stream_nogvl, &sargs);
2853
+ RUN_VIA_FIBER_WORKER(brotli_decompress_stream_fiber_nogvl, sargs);
2854
+ available_in = sargs.available_in;
2855
+ next_in = sargs.next_in;
2856
+ available_out = sargs.available_out;
2857
+ next_out = sargs.next_out;
2422
2858
  res = sargs.result;
2423
2859
  } else {
2424
2860
  res = BrotliDecoderDecompressStream(inf->ctx.brotli, &available_in, &next_in,
@@ -2436,9 +2872,11 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
2436
2872
  checked_add_size(result_len, out_size,
2437
2873
  "decompressed output exceeds representable size"),
2438
2874
  "decompressed output exceeds representable size");
2439
- enforce_output_and_ratio_limits(total_output, inf->total_input,
2440
- inf->max_output_size, inf->max_ratio_enabled,
2441
- inf->max_ratio);
2875
+ size_t total_input =
2876
+ checked_add_size(input_accounted_before, slen - available_in,
2877
+ "compressed input exceeds representable size");
2878
+ enforce_output_and_ratio_limits(total_output, total_input, inf->max_output_size,
2879
+ inf->max_ratio_enabled, inf->max_ratio);
2442
2880
 
2443
2881
  if (result_len + out_size > result_cap) {
2444
2882
  result_cap = result_len + out_size;
@@ -2453,6 +2891,8 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
2453
2891
  if (res == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT && available_in == 0)
2454
2892
  break;
2455
2893
  }
2894
+ inf->total_input = checked_add_size(input_accounted_before, slen - available_in,
2895
+ "compressed input exceeds representable size");
2456
2896
  inf->total_output = checked_add_size(inf->total_output, result_len,
2457
2897
  "decompressed output exceeds representable size");
2458
2898
  rb_str_set_len(result, result_len);
@@ -2462,7 +2902,6 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
2462
2902
  case ALGO_LZ4: {
2463
2903
  size_t data_len = inf->lz4_buf.len - inf->lz4_buf.offset;
2464
2904
  size_t needed = data_len + slen;
2465
- // TODO(v0.4): optional standard LZ4 frame format support via lz4frame.h
2466
2905
 
2467
2906
  if (inf->lz4_buf.offset > 0 && needed > inf->lz4_buf.cap) {
2468
2907
  if (data_len > 0)
@@ -2521,7 +2960,10 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
2521
2960
  checked_add_size(result_len, orig_size,
2522
2961
  "decompressed output exceeds representable size"),
2523
2962
  "decompressed output exceeds representable size");
2524
- enforce_output_and_ratio_limits(total_output, inf->total_input, inf->max_output_size,
2963
+ size_t total_input = checked_add_size(
2964
+ input_accounted_before, (pos + 8 + (size_t)comp_size) - inf->lz4_buf.offset,
2965
+ "compressed input exceeds representable size");
2966
+ enforce_output_and_ratio_limits(total_output, total_input, inf->max_output_size,
2525
2967
  inf->max_ratio_enabled, inf->max_ratio);
2526
2968
 
2527
2969
  if (result_len + orig_size > result_cap) {
@@ -2545,6 +2987,8 @@ static VALUE inflater_write(VALUE self, VALUE chunk) {
2545
2987
  }
2546
2988
  }
2547
2989
 
2990
+ inf->total_input = checked_add_size(input_accounted_before, pos - inf->lz4_buf.offset,
2991
+ "compressed input exceeds representable size");
2548
2992
  inf->lz4_buf.offset = pos;
2549
2993
  inf->total_output = checked_add_size(inf->total_output, result_len,
2550
2994
  "decompressed output exceeds representable size");
@@ -2738,12 +3182,6 @@ static VALUE train_dictionary_internal(VALUE samples, VALUE size_val, compress_a
2738
3182
  }
2739
3183
 
2740
3184
  static VALUE zstd_train_dictionary(int argc, VALUE *argv, VALUE self) {
2741
- // #if defined(__APPLE__) && (defined(__arm64__) || defined(__aarch64__))
2742
- // rb_raise(eUnsupportedError,
2743
- // "Zstd dictionary training is temporarily disabled on arm64-darwin "
2744
- // "because the current vendored trainer path crashes on this platform");
2745
- // #endif
2746
-
2747
3185
  VALUE samples, opts;
2748
3186
  rb_scan_args(argc, argv, "1:", &samples, &opts);
2749
3187
  reject_algorithm_keyword(opts);