mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
@@ -0,0 +1,792 @@
1
+ module MlxLm
2
# Ruby constant names cannot begin with "_", so this is the _BaseCache abstraction.
#
# Default, stateless behaviour shared by every cache class. Concrete caches
# override the hooks that carry real data.
class BaseCache
  # @return [Array] the tensors backing this cache (none for the base class).
  def state
    []
  end

  # Accepts only a blank state: nil, or anything whose +empty?+ is true.
  #
  # @raise [ArgumentError] when a non-empty state is supplied.
  def state=(value)
    blank = value.nil? || (value.respond_to?(:empty?) && value.empty?)
    raise ArgumentError, "This cache has no state but a state was set." unless blank
  end

  # @return [String] serialisable metadata; empty for the base class.
  def meta_state
    ""
  end

  # Accepts only blank metadata, mirroring +state=+.
  #
  # @raise [ArgumentError] when non-empty metadata is supplied.
  def meta_state=(value)
    blank = value.nil? || (value.respond_to?(:empty?) && value.empty?)
    raise ArgumentError, "This cache has no meta_state but a meta_state was set." unless blank
  end

  # Whether trailing entries can be discarded with +trim+; never for the base.
  def is_trimmable
    false
  end

  # Number of cached positions; always zero here.
  def size
    0
  end

  # @raise [NotImplementedError] subclasses must report their memory footprint.
  def nbytes
    raise NotImplementedError, "Cache sub-class must implement nbytes"
  end

  # @raise [NotImplementedError] subclasses must report whether they hold data.
  def empty
    raise NotImplementedError, "Cache sub-class must implement empty"
  end

  # Rebuilds a cache from captured +state+ and +meta_state+ without running
  # +initialize+ (the instance is created with +allocate+). State is applied
  # before metadata.
  #
  # @return [BaseCache] an instance of the receiving class.
  def self.from_state(state, meta_state)
    allocate.tap do |restored|
      restored.state = state
      restored.meta_state = meta_state
    end
  end
end
47
+
48
# Simple KV Cache — concatenates new K,V to existing.
# Uses simple concatenation since MLX Ruby doesn't support in-place slice assignment.
#
# Keys and values are grown along axis 2 on every update, so the stored
# length along that axis always matches +offset+.
class KVCache < BaseCache
  attr_reader :offset

  def initialize
    @keys = @values = nil
    @offset = 0
  end

  # Appends a chunk along axis 2 and returns everything stored so far.
  #
  # @return [Array] [all_keys, all_values]
  def update_and_fetch(keys, values)
    if @keys
      @keys = MLX::Core.concatenate([@keys, keys], 2)
      @values = MLX::Core.concatenate([@values, values], 2)
    else
      @keys = keys
      @values = values
    end
    @offset += keys.shape[2]
    [@keys, @values]
  end

  def size
    @offset
  end

  def state
    [@keys, @values]
  end

  # Restores the raw tensors; the offset is re-derived from the key length.
  def state=(v)
    @keys, @values = v
    @offset = if @keys
                @keys.shape[2]
              else
                0
              end
  end

  def is_trimmable
    true
  end

  # Drops up to +n+ trailing positions and returns how many were removed.
  def trim(n)
    return 0 if @keys.nil? || n <= 0

    removed = [n, @offset].min
    @offset -= removed
    @keys, @values = [@keys, @values].map { |tensor| _slice_prefix(tensor, @offset) }
    removed
  end

  # Converts the stored tensors into a QuantizedKVCache with the given
  # quantization parameters. An empty cache yields an empty quantized cache.
  def to_quantized(group_size: 64, bits: 4)
    QuantizedKVCache.new(group_size: group_size, bits: bits).tap do |quantized|
      next if @keys.nil?

      quantized.state = [@keys, @values].map { |tensor| MLX::Core.quantize(tensor, group_size, bits) }
    end
  end

  def empty
    @keys.nil?
  end

  def nbytes
    @keys.nil? ? 0 : @keys.nbytes + @values.nbytes
  end

  # Batches several caches along axis 0, left-padding each stored sequence
  # with zeros up to the longest one. Empty caches contribute an all-zero row
  # (batch dimension forced to 1) shaped like the first non-empty cache.
  #
  # @return [KVCache] a new cache holding the merged batch.
  def self.merge(caches)
    populated = caches.reject(&:empty)
    return new if populated.empty?

    longest = populated.map(&:size).max
    merged = populated.first.state.each_with_index.map do |template, slot|
      rows = caches.map do |cache|
        next _left_pad_seq(cache.state[slot], longest) unless cache.empty

        blank_shape = template.shape.dup
        blank_shape[0] = 1
        blank_shape[2] = longest
        MLX::Core.zeros(blank_shape, template.dtype)
      end
      MLX::Core.concatenate(rows, 0)
    end

    new.tap { |out| out.state = merged }
  end

  # Prepends zeros along axis 2 so +array+ reaches +target_len+ positions.
  # (Public class method, as in the original layout.)
  def self._left_pad_seq(array, target_len)
    missing = target_len - array.shape[2]
    return array if missing.zero?

    pad_shape = array.shape.dup
    pad_shape[2] = missing
    MLX::Core.concatenate([MLX::Core.zeros(pad_shape, array.dtype), array], 2)
  end

  private

  # Keeps the first +length+ positions of +array+ along axis 2.
  def _slice_prefix(array, length)
    array.shape[2] == length ? array : MLX::Core.split(array, [length], 2)[0]
  end
end
176
+
177
# Rotating KV Cache — fixed maximum size, old entries rotate out.
#
# Keys/values are grown along axis 2. Once the stored length exceeds
# +max_size+, the oldest positions after the first +keep+ (which stay pinned)
# are evicted. +offset+ counts every position ever appended, so after
# rotation it is larger than the stored length.
class RotatingKVCache < BaseCache
  attr_reader :offset

  # @param max_size [Integer] maximum number of stored positions
  # @param keep [Integer] leading positions that are never evicted
  def initialize(max_size:, keep: 0)
    @max_size = max_size
    @keep = keep
    @keys = nil
    @values = nil
    @offset = 0
  end

  # Number of positions visible to attention, capped at +max_size+.
  def size
    [@offset, @max_size].min
  end

  # Appends +keys+/+values+ along axis 2. It then evicts the oldest
  # non-pinned positions beyond +max_size+.
  #
  # The first chunk is stored as-is, even when it is longer than +max_size+.
  # Eviction only runs on later updates.
  #
  # @return [Array] the stored [keys, values]
  def update_and_fetch(keys, values)
    mx = MLX::Core
    if @keys.nil?
      @keys = keys
      @values = values
    else
      @keys = mx.concatenate([@keys, keys], 2)
      @values = mx.concatenate([@values, values], 2)

      excess = @keys.shape[2] - @max_size
      if excess > 0
        @keys = _evict(@keys, excess)
        @values = _evict(@values, excess)
      end
    end
    @offset += keys.shape[2]

    [@keys, @values]
  end

  def state
    [@keys, @values]
  end

  # Restores the stored tensors. The offset is provisionally set to their
  # length; +meta_state=+ overrides it when restoring via +from_state+.
  def state=(v)
    @keys, @values = v
    @offset = @keys ? @keys.shape[2] : 0
  end

  def meta_state
    [@keep, @max_size, @offset]
  end

  def meta_state=(v)
    @keep, @max_size, @offset = v.map(&:to_i)
  end

  # Trimming is only reported as supported before the cache has rotated.
  def is_trimmable
    @offset < @max_size
  end

  # Removes the newest +n+ positions (clamped to +offset+) and returns how
  # many were removed.
  #
  # The buffer is cut by +n+ from its end rather than down to +offset+.
  # Before rotation the two are identical. After rotation +offset+ exceeds
  # the stored length, and cutting to it would leave the "trimmed" positions
  # in the cache.
  def trim(n)
    return 0 if @keys.nil? || n <= 0

    n = [@offset, n].min
    @offset -= n
    keep_len = [@keys.shape[2] - n, 0].max
    @keys = _slice_prefix(@keys, keep_len)
    @values = _slice_prefix(@values, keep_len)
    n
  end

  def empty
    @keys.nil?
  end

  def nbytes
    return 0 if @keys.nil?

    @keys.nbytes + @values.nbytes
  end

  # Delegates batching to KVCache.merge (the result is a plain KVCache).
  def self.merge(caches)
    KVCache.merge(caches)
  end

  private

  # Keeps the first +length+ positions of +array+ along axis 2.
  def _slice_prefix(array, length)
    return array if array.shape[2] == length

    MLX::Core.split(array, [length], 2)[0]
  end

  # Drops +excess+ positions from +array+ along axis 2. When +@keep+ is
  # positive, the first +@keep+ positions are preserved and the positions
  # right after them are dropped.
  def _evict(array, excess)
    mx = MLX::Core
    return mx.split(array, [excess], 2)[1] unless @keep > 0

    head = mx.split(array, [@keep], 2)[0]
    tail = mx.split(array, [excess + @keep], 2)[1]
    mx.concatenate([head, tail], 2)
  end
end
280
+
281
# KV cache that stores keys/values in quantized form. Each stored entry is
# whatever MLX::Core.quantize returns. Every element of that collection is
# grown along axis 2.
class QuantizedKVCache < BaseCache
  attr_reader :offset, :group_size, :bits

  def initialize(group_size: 64, bits: 8)
    @group_size = group_size
    @bits = bits
    @offset = 0
    @keys = nil
    @values = nil
  end

  # Quantizes the incoming chunk, appends it, and returns the stored pair.
  def update_and_fetch(keys, values)
    new_keys = MLX::Core.quantize(keys, @group_size, @bits)
    new_values = MLX::Core.quantize(values, @group_size, @bits)

    first_chunk = @keys.nil?
    @keys = first_chunk ? new_keys : _concat_quantized(@keys, new_keys)
    @values = first_chunk ? new_values : _concat_quantized(@values, new_values)
    @offset += keys.shape[2]

    [@keys, @values]
  end

  def size
    @offset
  end

  def state
    [@keys, @values]
  end

  # Restores stored tensors. The offset comes from the first quantized key
  # tensor's axis 2.
  def state=(v)
    @keys, @values = v
    @offset = @keys.nil? ? 0 : @keys[0].shape[2]
  end

  def meta_state
    [@offset, @group_size, @bits]
  end

  def meta_state=(v)
    @offset, @group_size, @bits = v.map(&:to_i)
  end

  def is_trimmable
    true
  end

  # Drops up to +n+ trailing positions from every quantized component.
  # Returns the number removed.
  def trim(n)
    return 0 if @keys.nil? || n <= 0

    removed = n < @offset ? n : @offset
    @offset -= removed
    @keys = _slice_quantized(@keys, @offset)
    @values = _slice_quantized(@values, @offset)
    removed
  end

  def empty
    @keys.nil?
  end

  def nbytes
    @keys.nil? ? 0 : _sum_nbytes(@keys) + _sum_nbytes(@values)
  end

  private

  # Concatenates matching components of two quantized tuples along axis 2.
  def _concat_quantized(lhs, rhs)
    lhs.map.with_index { |part, i| MLX::Core.concatenate([part, rhs[i]], 2) }
  end

  # Keeps the first +length+ positions (axis 2) of every component.
  def _slice_quantized(tensors, length)
    tensors.map do |part|
      if part.shape[2] == length
        part
      else
        MLX::Core.split(part, [length], 2)[0]
      end
    end
  end

  # Total byte size of all components.
  def _sum_nbytes(tensors)
    tensors.sum(&:nbytes)
  end
end
372
+
373
# Cache holding a fixed number of state slots (arrays batched along axis 0).
# Slots may be nil until filled.
#
# The optional +left_padding+ and +lengths+ arrays are indexed per batch row
# (see +make_mask+, +extract+).
class ArraysCache < BaseCache
  attr_reader :cache
  attr_accessor :left_padding, :lengths

  # @param size [Integer] number of state slots
  # @param left_padding [Object, nil] per-row padding, converted with MLX::Core.array
  def initialize(size, left_padding: nil)
    @cache = Array.new(size)
    @left_padding = left_padding ? MLX::Core.array(left_padding) : nil
    @lengths = nil
  end

  def []=(idx, value)
    @cache[idx] = value
  end

  def [](idx)
    @cache[idx]
  end

  def state
    @cache
  end

  def state=(v)
    @cache = v
  end

  def meta_state
    [@left_padding, @lengths]
  end

  def meta_state=(v)
    @left_padding, @lengths = v
  end

  # Keeps only the batch rows listed in +batch_indices+ (axis 0) of every
  # filled slot. The per-row +left_padding+ and +lengths+ are filtered the
  # same way, as +extract+ does. Otherwise +make_mask+ would keep producing
  # masks for the old batch size.
  #
  # @return [Array] the filtered state slots.
  def filter(batch_indices)
    idx = _indices_array(batch_indices)
    @left_padding = MLX::Core.take(@left_padding, idx, 0) if @left_padding
    @lengths = MLX::Core.take(@lengths, idx, 0) if @lengths
    @cache = @cache.map { |c| c.nil? ? nil : MLX::Core.take(c, idx, 0) }
  end

  # Appends the batch rows of +other+ after this cache's rows (axis 0).
  # A nil slot on either side takes the other side's value.
  #
  # NOTE: this overrides Kernel#extend, so module mixing via +extend+ is not
  # available on instances of this class.
  def extend(other)
    @cache = @cache.zip(other.cache).map do |c, o|
      if c.nil?
        o
      elsif o.nil?
        c
      else
        MLX::Core.concatenate([c, o], 0)
      end
    end

    if @left_padding && other.left_padding
      @left_padding = MLX::Core.concatenate([@left_padding, other.left_padding], 0)
    end
    if @lengths && other.lengths
      @lengths = MLX::Core.concatenate([@lengths, other.lengths], 0)
    end
  end

  # Returns a new ArraysCache holding only batch row +idx+.
  def extract(idx)
    single = _indices_array([idx])
    out = ArraysCache.new(@cache.length)
    out.state = @cache.map { |c| c.nil? ? nil : MLX::Core.take(c, single, 0) }
    out.left_padding = MLX::Core.take(@left_padding, single, 0) if @left_padding
    out.lengths = MLX::Core.take(@lengths, single, 0) if @lengths
    out
  end

  # Records per-row +lengths+ (other keyword arguments are ignored).
  def prepare(lengths: nil, **_kwargs)
    @lengths = lengths.nil? ? nil : MLX::Core.array(lengths)
  end

  # Clears the per-row padding/length bookkeeping.
  def finalize
    @lengths = nil
    @left_padding = nil
  end

  # Subtracts +n+ from the per-row lengths and left padding, when present.
  def advance(n)
    @lengths = MLX::Core.subtract(@lengths, n) if @lengths
    @left_padding = MLX::Core.subtract(@left_padding, n) if @left_padding
  end

  # Builds a boolean mask of shape [batch, n] over positions 0...n.
  # Left padding takes precedence over lengths. Returns nil when neither is set.
  def make_mask(n)
    mx = MLX::Core
    pos = mx.arange(n).reshape([1, n])
    if @left_padding
      mx.greater_equal(pos, @left_padding.reshape([@left_padding.shape[0], 1]))
    elsif @lengths
      mx.less(pos, @lengths.reshape([@lengths.shape[0], 1]))
    end
  end

  # Batches several caches along axis 0. A cache whose slot is nil
  # contributes a zero row (batch dimension forced to 1) shaped like the
  # first filled slot. Padding/lengths are merged only if every cache has them.
  def self.merge(caches)
    mx = MLX::Core
    n_state = caches[0].cache.length
    batch = caches.length
    out = new(n_state)

    n_state.times do |e|
      init = caches.map { |c| c[e] }.find { |v| !v.nil? }
      next if init.nil?

      shape = init.shape.dup
      shape[0] = 1
      zero = mx.zeros(shape, init.dtype)
      rows = caches.map { |c| c[e] || zero }
      out[e] = mx.concatenate(rows, 0)
    end

    left_padding_values = caches.map(&:left_padding).compact
    out.left_padding = mx.concatenate(left_padding_values, 0) if left_padding_values.length == batch

    length_values = caches.map(&:lengths).compact
    out.lengths = mx.concatenate(length_values, 0) if length_values.length == batch

    out
  end

  def empty
    @cache.empty? || @cache[0].nil?
  end

  def nbytes
    @cache.compact.reduce(0) { |acc, c| acc + c.nbytes }
  end

  private

  # Converts Ruby index lists to an int32 MLX array; MLX arrays pass through.
  def _indices_array(indices)
    return indices if indices.is_a?(MLX::Core::Array)

    MLX::Core.array(indices, dtype: MLX::Core.int32)
  end
end
512
+
513
# KV cache for chunked attention. The buffer grows during updates and is cut
# back to the most recent +chunk_size+ positions by an explicit
# +maybe_trim_front+ call.
#
# Invariant: +offset+ == +start_position+ + stored length (axis 2), where
# +start_position+ counts positions already dropped from the front.
class ChunkedKVCache < BaseCache
  attr_reader :offset, :chunk_size, :start_position

  def initialize(chunk_size)
    @keys = nil
    @values = nil
    @offset = 0
    @chunk_size = chunk_size
    @start_position = 0
  end

  # Drops the oldest positions so that at most +chunk_size+ remain stored.
  # +start_position+ advances by the number dropped.
  def maybe_trim_front
    return if @keys.nil? || @keys.shape[2] < @chunk_size

    excess = @keys.shape[2] - @chunk_size
    return if excess <= 0

    @start_position += excess
    @keys = _slice_tail(@keys, @chunk_size)
    @values = _slice_tail(@values, @chunk_size)
  end

  # Appends a chunk along axis 2. No trimming happens here; see
  # +maybe_trim_front+.
  #
  # @return [Array] the stored [keys, values]
  def update_and_fetch(keys, values)
    mx = MLX::Core
    if @keys.nil?
      @keys = keys
      @values = values
    else
      @keys = mx.concatenate([@keys, keys], 2)
      @values = mx.concatenate([@values, values], 2)
    end
    @offset += keys.shape[2]
    [@keys, @values]
  end

  # Number of positions currently stored.
  def size
    @offset - @start_position
  end

  def state
    [@keys, @values]
  end

  # Restores the stored tensors. +offset+ is rebuilt as +start_position+ plus
  # the stored length, so the invariant survives BaseCache.from_state. That
  # path bypasses +initialize+ and may run before +meta_state=+.
  def state=(v)
    @keys, @values = v
    @start_position ||= 0
    @offset = @start_position + _stored_length
  end

  def meta_state
    [@chunk_size, @start_position]
  end

  # Restores chunk size and front position. The meta state does not carry
  # +offset+, so it is recomputed from them and the stored tensors.
  def meta_state=(v)
    @chunk_size, @start_position = v.map(&:to_i)
    @offset = @start_position + _stored_length
  end

  def is_trimmable
    true
  end

  # Drops up to +n+ trailing positions (never past +start_position+).
  # Returns the number removed.
  def trim(n)
    return 0 if @keys.nil? || n <= 0

    available = @offset - @start_position
    n = [available, n].min
    @offset -= n
    keep_len = @offset - @start_position
    @keys = _slice_prefix(@keys, keep_len)
    @values = _slice_prefix(@values, keep_len)
    n
  end

  def empty
    @keys.nil?
  end

  def nbytes
    return 0 if @keys.nil?

    @keys.nbytes + @values.nbytes
  end

  private

  # Stored length along axis 2 (0 when empty).
  def _stored_length
    @keys ? @keys.shape[2] : 0
  end

  # Keeps the first +length+ positions of +array+ along axis 2.
  def _slice_prefix(array, length)
    return array if array.shape[2] == length

    MLX::Core.split(array, [length], 2)[0]
  end

  # Keeps the last +length+ positions of +array+ along axis 2.
  def _slice_tail(array, length)
    return array if array.shape[2] == length

    split_idx = array.shape[2] - length
    MLX::Core.split(array, [split_idx], 2)[1]
  end
end
610
+
611
# Groups several sub-caches behind one cache object. Most operations are
# forwarded to each sub-cache in order.
class CacheList < BaseCache
  attr_reader :caches

  def initialize(*caches)
    @caches = caches
  end

  def [](idx)
    @caches[idx]
  end

  # Trimmable only when every sub-cache is.
  def is_trimmable
    @caches.all?(&:is_trimmable)
  end

  # Trims every sub-cache by +n+. Returns the count reported by the last
  # sub-cache (0 when there are none).
  def trim(n)
    trimmed = 0
    @caches.each do |cache|
      trimmed = cache.trim(n)
    end
    trimmed
  end

  def state
    @caches.map(&:state)
  end

  def state=(v)
    @caches.zip(v).each do |cache, cache_state|
      cache.state = cache_state
    end
  end

  # [sub-cache class names (unqualified), sub-cache meta states]
  def meta_state
    [
      @caches.map { |c| c.class.name.split("::").last },
      @caches.map(&:meta_state),
    ]
  end

  # Applies per-sub-cache meta states; the class names are ignored here.
  def meta_state=(v)
    _classes, states = v
    @caches.zip(states).each do |cache, cache_state|
      cache.meta_state = cache_state
    end
  end

  def filter(batch_indices)
    @caches.each { |cache| cache.filter(batch_indices) if cache.respond_to?(:filter) }
  end

  # Extends each sub-cache with the matching sub-cache of +other+. Sub-caches
  # that do not define their own +extend+ are skipped.
  #
  # Every object inherits Kernel#extend (module mixing), whose owner is
  # Kernel, not Object. Calling that default with a cache argument raises
  # TypeError, so it must be detected by its owner.
  def extend(other)
    @caches.zip(other.caches).each do |cache, other_cache|
      owner = cache.class.instance_method(:extend).owner
      next if owner == Kernel || owner == Object

      cache.extend(other_cache)
    end
  end

  # Merges CacheLists position-wise, delegating to each sub-cache class's
  # +.merge+.
  #
  # @raise [NotImplementedError] when a sub-cache class lacks +.merge+.
  def self.merge(caches)
    merged = caches[0].caches.each_index.map do |i|
      batch = caches.map { |c| c.caches[i] }
      unless batch[0].class.respond_to?(:merge)
        raise NotImplementedError, "#{batch[0].class} does not implement .merge"
      end

      batch[0].class.merge(batch)
    end
    new(*merged)
  end

  def extract(idx)
    CacheList.new(*@caches.map { |cache| cache.extract(idx) })
  end

  def prepare(**kwargs)
    @caches.each { |cache| cache.prepare(**kwargs) if cache.respond_to?(:prepare) }
  end

  def finalize
    @caches.each { |cache| cache.finalize if cache.respond_to?(:finalize) }
  end

  # Largest sub-cache size (0 when empty).
  def size
    @caches.map(&:size).max || 0
  end

  def empty
    @caches.empty? || @caches[0].empty
  end

  def nbytes
    @caches.reduce(0) { |acc, cache| acc + cache.nbytes }
  end

  # Rebuilds sub-caches by resolving each saved class name under MlxLm.
  # NOTE(review): the names come from saved meta state; const_get will
  # resolve any MlxLm constant, so only load trusted cache files.
  def self.from_state(state, meta_state)
    classes, metas = meta_state
    caches = state.each_with_index.map do |sub_state, i|
      klass = MlxLm.const_get(classes[i])
      klass.from_state(sub_state, metas[i])
    end
    new(*caches)
  end
end
715
+
716
# Helpers for building, saving, and loading per-layer prompt caches.
module Cache
  module_function

  # Builds one cache per model layer.
  #
  # Models that define +make_cache+ supply their own caches. Otherwise every
  # layer gets a KVCache. When +max_kv_size+ is given, each layer instead gets
  # a RotatingKVCache that pins the first 4 positions.
  def make_prompt_cache(model, max_kv_size: nil)
    return model.make_cache if model.respond_to?(:make_cache)

    num_layers = model.layers.length
    if max_kv_size
      Array.new(num_layers) { RotatingKVCache.new(max_size: max_kv_size, keep: 4) }
    else
      Array.new(num_layers) { KVCache.new }
    end
  end

  # Writes the keys/values of every non-empty layer cache to +path+ in
  # safetensors format, plus per-layer offsets under "_meta_offsets".
  #
  # Tensors that are not float32/int32 are cast to float32 before writing.
  # Lower-precision caches therefore come back as float32 when reloaded.
  def save_prompt_cache(path, cache)
    mx = MLX::Core
    tensors = {}

    cache.each_with_index do |layer_cache, i|
      keys, values = layer_cache.state
      next unless keys

      mx.eval(keys, values)
      tensors["layer.#{i}.keys"] = keys
      tensors["layer.#{i}.values"] = values
    end

    tensors["_meta_offsets"] = mx.array(cache.map(&:offset), mx.int32)

    # Serialize using safetensors gem
    st_tensors = {}
    tensors.each do |name, arr|
      arr = arr.astype(mx.float32) unless [mx.float32, mx.int32].include?(arr.dtype)
      mx.eval(arr)
      # tolist may yield a scalar or nested arrays; flatten to a flat
      # row-major list of scalars.
      data = [arr.tolist].flatten

      st_tensors[name] =
        if arr.dtype == mx.int32
          { "dtype" => "int32", "shape" => arr.shape, "data" => data.map(&:to_i).pack("l<*") }
        else
          { "dtype" => "float32", "shape" => arr.shape, "data" => data.pack("e*") }
        end
    end

    File.binwrite(path, Safetensors.serialize(st_tensors))
  end

  # Loads a cache written by +save_prompt_cache+ into fresh KVCache objects,
  # one per model layer. Layers without saved tensors stay empty.
  #
  # The saved "_meta_offsets" are not read. A KVCache re-derives its offset
  # from the stored key length, so files without that entry also load.
  def load_prompt_cache(path, model)
    loaded = WeightUtils.load_safetensors(path)

    Array.new(model.layers.length) do |i|
      KVCache.new.tap do |layer_cache|
        keys = loaded["layer.#{i}.keys"]
        layer_cache.state = [keys, loaded["layer.#{i}.values"]] if keys
      end
    end
  end
end
792
+ end