mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138):
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
@@ -0,0 +1,159 @@
1
+ module MlxLm
2
+ module SampleUtils
3
+ module_function
4
+
5
+ # Build a sampler callable (proc) from the given parameters.
6
+ # Returns a proc that takes logprobs (mx.array) and returns a token (mx.array).
7
+ def make_sampler(
8
+ temp: 0.0,
9
+ top_p: 0.0,
10
+ min_p: 0.0,
11
+ min_tokens_to_keep: 1,
12
+ top_k: 0
13
+ )
14
+ mx = MLX::Core
15
+
16
+ if temp == 0
17
+ return ->(x) { mx.argmax(x, -1) }
18
+ end
19
+
20
+ sampling_methods = []
21
+ if top_p > 0 && top_p < 1.0
22
+ sampling_methods << ->(x) { apply_top_p(x, top_p) }
23
+ end
24
+ if min_p != 0.0
25
+ sampling_methods << ->(x) { apply_min_p(x, min_p, min_tokens_to_keep) }
26
+ end
27
+ if top_k > 0
28
+ sampling_methods << ->(x) { apply_top_k(x, top_k) }
29
+ end
30
+
31
+ ->(logprobs) {
32
+ sampling_methods.each { |method| logprobs = method.call(logprobs) }
33
+ categorical_sampling(logprobs, temp)
34
+ }
35
+ end
36
+
37
+ def make_logits_processors(repetition_penalty: nil, repetition_context_size: 20)
38
+ processors = []
39
+ if repetition_penalty && repetition_penalty != 0.0
40
+ processors << make_repetition_penalty(repetition_penalty, repetition_context_size)
41
+ end
42
+ processors
43
+ end
44
+
45
+ def apply_top_k(logprobs, top_k)
46
+ mx = MLX::Core
47
+ vocab_size = logprobs.shape[-1]
48
+ raise ArgumentError, "top_k must be in (0, #{vocab_size}]" unless top_k.is_a?(Integer) && top_k > 0 && top_k < vocab_size
49
+
50
+ neg_logprobs = mx.negative(logprobs)
51
+ mask_idx = mx.argpartition(neg_logprobs, top_k - 1, -1)
52
+ # Get indices after top_k (the ones to mask)
53
+ rest = mx.split(mask_idx, [top_k], -1)[1]
54
+ neg_inf = mx.array([-Float::INFINITY], dtype: logprobs.dtype)
55
+ mx.put_along_axis(logprobs, rest, neg_inf, -1)
56
+ end
57
+
58
+ def apply_min_p(logprobs, min_p, min_tokens_to_keep = 1)
59
+ mx = MLX::Core
60
+ raise ArgumentError, "min_p must be in [0, 1]" unless min_p >= 0 && min_p <= 1.0
61
+
62
+ # Sort indices in decreasing order
63
+ neg_logprobs = mx.negative(logprobs)
64
+ sorted_indices = mx.argsort(neg_logprobs, -1)
65
+ sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, -1)
66
+
67
+ # Top probability
68
+ top_logprobs = mx.split(sorted_logprobs, [1], -1)[0]
69
+
70
+ # Calculate the min_p threshold
71
+ scaled_min_p = top_logprobs + Math.log(min_p)
72
+
73
+ # Mask tokens below threshold
74
+ tokens_to_remove = mx.less(sorted_logprobs, scaled_min_p)
75
+
76
+ neg_inf = mx.array(-Float::INFINITY, dtype: sorted_logprobs.dtype)
77
+ selected_logprobs = mx.where(tokens_to_remove, neg_inf, sorted_logprobs)
78
+
79
+ # Restore the top min_tokens_to_keep tokens regardless
80
+ if min_tokens_to_keep > 0
81
+ top_sorted = mx.split(sorted_logprobs, [min_tokens_to_keep], -1)[0]
82
+ rest_selected = mx.split(selected_logprobs, [min_tokens_to_keep], -1)[1]
83
+ selected_logprobs = mx.concatenate([top_sorted, rest_selected], -1)
84
+ end
85
+
86
+ # Create inverse mapping to restore original order
87
+ inverse_indices = mx.put_along_axis(
88
+ mx.zeros_like(sorted_indices),
89
+ sorted_indices,
90
+ mx.arange(sorted_indices.shape[-1]).astype(sorted_indices.dtype),
91
+ -1
92
+ )
93
+
94
+ mx.take_along_axis(selected_logprobs, inverse_indices, -1)
95
+ end
96
+
97
+ def apply_top_p(logprobs, top_p)
98
+ mx = MLX::Core
99
+ probs = mx.exp(logprobs)
100
+ # sort in ascending order
101
+ sorted_indices = mx.argsort(logprobs, -1)
102
+ sorted_probs = mx.take_along_axis(probs, sorted_indices, -1)
103
+
104
+ cumulative_probs = mx.cumsum(sorted_probs, -1)
105
+
106
+ # Rearrange cumulative probs back to original order
107
+ inverse_indices = mx.put_along_axis(
108
+ mx.zeros_like(sorted_indices),
109
+ sorted_indices,
110
+ mx.arange(sorted_indices.shape[-1]).astype(sorted_indices.dtype),
111
+ -1
112
+ )
113
+ cumulative_probs = mx.take_along_axis(cumulative_probs, inverse_indices, -1)
114
+
115
+ # select tokens with cumulative probs above threshold
116
+ threshold = mx.array(1.0 - top_p, dtype: cumulative_probs.dtype)
117
+ mask = mx.greater(cumulative_probs, threshold)
118
+ neg_inf = mx.array(-Float::INFINITY, dtype: logprobs.dtype)
119
+ mx.where(mask, logprobs, neg_inf)
120
+ end
121
+
122
+ def categorical_sampling(logits, temp)
123
+ mx = MLX::Core
124
+ mx.categorical(logits * (1.0 / temp))
125
+ end
126
+
127
+ def make_repetition_penalty(penalty, context_size = 20)
128
+ mx = MLX::Core
129
+ raise ArgumentError, "penalty must be a non-negative float" unless penalty.is_a?(Numeric) && penalty >= 0
130
+
131
+ ->(tokens, logits) {
132
+ if tokens && tokens.size > 0
133
+ recent = if tokens.is_a?(::Array)
134
+ tokens.last(context_size)
135
+ elsif tokens.respond_to?(:tolist)
136
+ tokens.tolist.last(context_size)
137
+ else
138
+ []
139
+ end
140
+ if recent.length > 0
141
+ token_indices = mx.array(recent, dtype: mx.int32)
142
+ n_tokens = recent.length
143
+ idx_2d = token_indices.reshape([1, n_tokens])
144
+ selected_logits = mx.take_along_axis(logits, idx_2d, -1)
145
+ zero = mx.array(0.0, dtype: selected_logits.dtype)
146
+ is_negative = mx.less(selected_logits, zero)
147
+ selected_logits = mx.where(
148
+ is_negative,
149
+ selected_logits * penalty,
150
+ selected_logits / penalty
151
+ )
152
+ logits = mx.put_along_axis(logits, idx_2d, selected_logits, -1)
153
+ end
154
+ end
155
+ logits
156
+ }
157
+ end
158
+ end
159
+ end
@@ -0,0 +1,190 @@
1
+ require "json"
2
+ require "securerandom"
3
+
4
module MlxLm
  # Minimal OpenAI-compatible chat-completions HTTP server (WEBrick based).
  module Server
    # Request schema for POST /v1/chat/completions.
    class ChatCompletionRequest
      attr_reader :model, :messages, :max_tokens, :temperature, :top_p, :stream, :stop

      # Build a request from a parsed JSON body, substituting defaults for
      # missing (or explicitly null) fields.
      def self.from_hash(h)
        new(
          model: h["model"],
          messages: h["messages"] || [],
          max_tokens: h["max_tokens"] || 256,
          temperature: h["temperature"] || 0.0,
          top_p: h["top_p"] || 1.0,
          stream: h.fetch("stream", false),
          stop: h["stop"]
        )
      end

      def initialize(model:, messages:, max_tokens: 256, temperature: 0.0, top_p: 1.0, stream: false, stop: nil)
        @model = model
        @messages = messages
        @max_tokens = max_tokens
        @temperature = temperature
        @top_p = top_p
        @stream = stream
        @stop = stop
      end
    end

    # Response schema for a non-streaming chat completion.
    class ChatCompletionResponse
      def initialize(model:, content:, prompt_tokens:, completion_tokens:, finish_reason: "stop")
        @model = model
        @content = content
        @prompt_tokens = prompt_tokens
        @completion_tokens = completion_tokens
        @finish_reason = finish_reason
        @id = "chatcmpl-#{SecureRandom.hex(12)}"
        @created = Time.now.to_i
      end

      def to_hash
        {
          "id" => @id,
          "object" => "chat.completion",
          "created" => @created,
          "model" => @model,
          "choices" => [
            {
              "index" => 0,
              "message" => {
                "role" => "assistant",
                "content" => @content,
              },
              "finish_reason" => @finish_reason,
            }
          ],
          "usage" => {
            "prompt_tokens" => @prompt_tokens,
            "completion_tokens" => @completion_tokens,
            "total_tokens" => @prompt_tokens + @completion_tokens,
          }
        }
      end

      # Accepts (and ignores) a generator state argument so the json library
      # can serialize this object when nested in other structures — the
      # previous zero-arity signature broke that protocol.
      def to_json(*_args)
        JSON.generate(to_hash)
      end
    end

    # One server-sent-events chunk of a streaming chat completion.
    class ChatCompletionChunk
      def initialize(model:, content:, finish_reason: nil)
        @model = model
        @content = content
        @finish_reason = finish_reason
        @id = "chatcmpl-#{SecureRandom.hex(12)}"
        @created = Time.now.to_i
      end

      def to_hash
        {
          "id" => @id,
          "object" => "chat.completion.chunk",
          "created" => @created,
          "model" => @model,
          "choices" => [
            {
              "index" => 0,
              "delta" => {
                "content" => @content,
              },
              "finish_reason" => @finish_reason,
            }
          ]
        }
      end

      # Encode this chunk as a single SSE "data:" event.
      def to_sse
        "data: #{JSON.generate(to_hash)}\n\n"
      end
    end

    # GET /v1/models response.
    class ModelsListResponse
      def initialize(models:)
        @models = models
        # One timestamp for the whole listing (previously recomputed per model
        # inside the map, so entries could disagree).
        @created = Time.now.to_i
      end

      def to_hash
        {
          "object" => "list",
          "data" => @models.map { |m|
            {
              "id" => m,
              "object" => "model",
              "created" => @created,
              "owned_by" => "mlx-lm",
            }
          }
        }
      end

      def to_json(*_args)
        JSON.generate(to_hash)
      end
    end

    module_function

    # Load the model at model_path and serve the OpenAI-compatible API until
    # interrupted (INT shuts the server down cleanly).
    def start(model_path:, host: "127.0.0.1", port: 8080)
      require "webrick"

      model, tokenizer = LoadUtils.load(model_path)

      server = WEBrick::HTTPServer.new(Port: port, BindAddress: host)

      server.mount_proc "/v1/models" do |req, res|
        res["Content-Type"] = "application/json"
        res.body = ModelsListResponse.new(models: [model_path]).to_json
      end

      server.mount_proc "/v1/chat/completions" do |req, res|
        begin
          # to_s guards against a nil body; malformed JSON is a client error
          # (400), not an unhandled 500 as before.
          body = JSON.parse(req.body.to_s)
        rescue JSON::ParserError => e
          res.status = 400
          res["Content-Type"] = "application/json"
          res.body = JSON.generate("error" => { "message" => "invalid JSON: #{e.message}" })
          next
        end
        chat_req = ChatCompletionRequest.from_hash(body)

        prompt = ChatTemplate.apply(chat_req.messages)
        sampler = SampleUtils.make_sampler(temp: chat_req.temperature, top_p: chat_req.top_p)

        if chat_req.stream
          res["Content-Type"] = "text/event-stream"
          res["Cache-Control"] = "no-cache"

          # NOTE(review): relies on WEBrick streaming an Enumerator body via
          # #each / chunked encoding — confirm with the WEBrick version in use.
          res.body = Enumerator.new { |yielder|
            Generate.stream_generate(model, tokenizer, prompt,
              max_tokens: chat_req.max_tokens, sampler: sampler).each do |resp|
              chunk = ChatCompletionChunk.new(
                model: chat_req.model,
                content: resp.text,
                finish_reason: resp.finish_reason
              )
              yielder << chunk.to_sse
            end
            yielder << "data: [DONE]\n\n"
          }
        else
          text = Generate.generate(model, tokenizer, prompt,
            max_tokens: chat_req.max_tokens, sampler: sampler)

          res["Content-Type"] = "application/json"
          # NOTE(review): "token" counts are string lengths, not tokenizer
          # counts — placeholder values, inaccurate for real usage reporting.
          resp = ChatCompletionResponse.new(
            model: chat_req.model,
            content: text,
            prompt_tokens: prompt.length,
            completion_tokens: text.length,
            finish_reason: "stop"
          )
          res.body = resp.to_json
        end
      end

      trap("INT") { server.shutdown }
      server.start
    end
  end
end
@@ -0,0 +1,158 @@
1
require "json"
require "set"
require "tokenizers"
3
+
4
+ module MlxLm
5
# Wraps a HuggingFace tokenizer (loaded via the tokenizers gem),
# providing encode/decode and metadata access.
class TokenizerWrapper
  attr_reader :tokenizer

  # Can be initialized with:
  # 1. A path string (directory containing tokenizer.json)
  # 2. A Tokenizers::Tokenizer object (with optional eos_token/eos_token_id)
  #
  # eos_token_id may be a single id or an array of ids; it is normalized to
  # an array internally (previously a bare integer crashed eos_token_ids,
  # which called .each / .empty? on it).
  def initialize(path_or_tokenizer, eos_token: nil, eos_token_id: nil)
    if path_or_tokenizer.is_a?(String)
      tokenizer_json = File.join(path_or_tokenizer, "tokenizer.json")
      @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)

      config_path = File.join(path_or_tokenizer, "tokenizer_config.json")
      @config = File.exist?(config_path) ? JSON.parse(File.read(config_path)) : {}
    else
      @tokenizer = path_or_tokenizer
      @config = {}
    end

    @eos_token_override = eos_token
    # nil stays nil; a scalar or array becomes a flat array of ids.
    @eos_token_id_override = eos_token_id.nil? ? nil : Array(eos_token_id)

    @_detokenizer = nil
  end

  def encode(text, add_special_tokens: true)
    @tokenizer.encode(text, add_special_tokens: add_special_tokens).ids
  end

  def decode(ids, skip_special_tokens: false)
    @tokenizer.decode(ids, skip_special_tokens: skip_special_tokens)
  end

  # EOS token string: explicit override first, then tokenizer_config.json
  # (which may store it as a plain string or an added-token hash).
  def eos_token
    return @eos_token_override if @eos_token_override
    token = @config["eos_token"]
    token = token["content"] if token.is_a?(Hash)
    token
  end

  # First resolvable EOS token id, or nil when none is known.
  def eos_token_id
    # Explicit id override wins.
    if @eos_token_id_override && !@eos_token_id_override.empty?
      return @eos_token_id_override.first
    end

    # Then the config's eos_token looked up in the vocabulary.
    if @config["eos_token"]
      token = @config["eos_token"]
      token = token["content"] if token.is_a?(Hash)
      id = @tokenizer.token_to_id(token)
      return id if id
    end

    # Finally the eos_token string override, if it resolves.
    if @eos_token_override
      id = @tokenizer.token_to_id(@eos_token_override)
      return id if id
    end

    nil
  end

  # Returns a Set of all known EOS token IDs.
  # NOTE(review): Set is only autoloaded on Ruby >= 3.2; ensure the file
  # requires "set" for older Rubies.
  def eos_token_ids
    ids = Set.new
    @eos_token_id_override&.each { |id| ids << id if id }
    base_id = eos_token_id
    ids << base_id if base_id
    ids
  end

  # BOS token string from tokenizer_config.json (string or added-token hash).
  def bos_token
    token = @config["bos_token"]
    token = token["content"] if token.is_a?(Hash)
    token
  end

  # BOS token id resolved through the vocabulary, or nil.
  def bos_token_id
    if @config["bos_token"]
      token = @config["bos_token"]
      token = token["content"] if token.is_a?(Hash)
      id = @tokenizer.token_to_id(token)
      return id if id
    end
    nil
  end

  def vocab_size
    @tokenizer.vocab_size
  end

  def id_to_token(id)
    @tokenizer.id_to_token(id)
  end

  def token_to_id(token)
    @tokenizer.token_to_id(token)
  end

  # Lazily-built streaming detokenizer bound to this wrapper.
  def detokenizer
    @_detokenizer ||= StreamingDetokenizer.new(self)
  end

  # True when tokenizer_config.json declares a chat template.
  def has_chat_template
    !!@config["chat_template"]
  end
end
116
+
117
# Incremental detokenizer for streaming generation.
#
# Keeps the full list of generated token ids, re-decodes the whole buffer
# after every added token, and exposes only the text appended since the
# previous decode. (Each step decodes the entire buffer, so total work
# grows quadratically with sequence length — simple rather than optimal.)
class StreamingDetokenizer
  attr_reader :last_segment

  def initialize(tokenizer_wrapper)
    @tokenizer = tokenizer_wrapper
    reset
  end

  # Append one token id; returns (and stores) the newly produced text.
  def add_token(token_id)
    @token_ids << token_id
    emit_new_text
  end

  # Flush any remaining text after generation ends.
  def finalize
    return "" if @token_ids.empty?
    emit_new_text
  end

  # Full decoded text so far.
  def text
    @prev_text
  end

  # Clear all buffered state so the instance can be reused.
  def reset
    @token_ids = []
    @prev_text = ""
    @last_segment = ""
  end

  private

  # Decode the whole buffer and record the suffix added since last time.
  def emit_new_text
    decoded = @tokenizer.decode(@token_ids)
    @last_segment = decoded[@prev_text.length..] || ""
    @prev_text = decoded
    @last_segment
  end
end
158
+ end
@@ -0,0 +1,165 @@
1
+ module MlxLm
2
+ module Tuner
3
# LoRA adapter for Linear layers.
# Forward: y = linear(x) + scale * (dropout(x) @ lora_a @ lora_b)
class LoRALinear < MLX::NN::Module
  # Wrap an existing (possibly quantized) Linear layer with LoRA.
  # The base layer's weights are kept frozen; only lora_a/lora_b train.
  def self.from_base(linear, r: 8, dropout: 0.0, scale: 20.0)
    if linear.is_a?(MLX::NN::QuantizedLinear)
      # Quantized weights pack (32 / bits) values per stored element.
      input_dims = linear.instance_variable_get(:@weight).shape[1] * 32 /
        (linear.instance_variable_get(:@bits) || 4)
      output_dims = linear.instance_variable_get(:@weight).shape[0]
      bias = !linear.instance_variable_get(:@bias).nil?
    else
      weight = linear.weight
      output_dims, input_dims = weight.shape
      # Fixed: was `!linear.respond_to?(:bias) || !linear.bias.nil? rescue false`,
      # which reported bias: true for layers WITHOUT a bias reader (inverted
      # logic) and hid errors with a bare rescue. A layer has a bias iff it
      # exposes one and it is non-nil. (The flag only shapes the placeholder
      # linear below, which is replaced by the original right after.)
      bias = linear.respond_to?(:bias) && !linear.bias.nil?
    end
    lora = new(input_dims, output_dims, r: r, dropout: dropout, scale: scale, bias: bias)
    lora.linear = linear
    lora
  end

  def initialize(input_dims, output_dims, r: 8, dropout: 0.0, scale: 20.0, bias: false)
    super()
    mx = MLX::Core
    @scale = scale
    self.linear = MLX::NN::Linear.new(input_dims, output_dims, bias: bias)
    self.dropout = MLX::NN::Dropout.new(dropout)

    # lora_a: small uniform init; lora_b: zeros so the adapter starts as identity.
    lora_scale = 1.0 / Math.sqrt(input_dims)
    self.lora_a = mx.random_uniform(
      [input_dims, r], -lora_scale, lora_scale, mx.float32
    )
    self.lora_b = mx.zeros([r, output_dims])
  end

  # y = base(x) + scale * (dropout(x) @ A @ B)
  def call(x)
    mx = MLX::Core
    y = linear.call(x)
    z = dropout.call(x)
    z = mx.matmul(mx.matmul(z, lora_a), lora_b)
    y + z * @scale
  end

  # Merge the LoRA delta into a plain Linear layer.
  # With dequantize: true, a quantized base is dequantized first.
  def fuse(dequantize: false)
    mx = MLX::Core
    lin = linear

    if dequantize && lin.is_a?(MLX::NN::QuantizedLinear)
      lin = MlxLm::Quantize.linear_from_quantized(lin)
    end

    weight = lin.weight
    bias_val = lin.respond_to?(:bias) ? lin.bias : nil

    # Fuse: W' = W + scale * (lora_a @ lora_b)^T
    lora_weight = mx.matmul(lora_a, lora_b)
    fused_weight = weight + mx.transpose(lora_weight) * @scale

    out_features, in_features = fused_weight.shape
    result = MLX::NN::Linear.new(in_features, out_features, bias: !bias_val.nil?)
    result.weight = fused_weight
    result.bias = bias_val if bias_val
    result
  end
end
67
+
68
# LoRA adapter for Embedding layers.
class LoRAEmbedding < MLX::NN::Module
  # Wrap an existing Embedding with a LoRA adapter of rank r.
  # The base embedding's weights are kept; only lora_a/lora_b are new.
  def self.from_base(embedding, r: 8, dropout: 0.0, scale: 20.0)
    weight = embedding.weight
    num_embeddings, dims = weight.shape
    lora = new(num_embeddings, dims, r: r, dropout: dropout, scale: scale)
    lora.embedding = embedding
    lora
  end

  def initialize(num_embeddings, dims, r: 8, dropout: 0.0, scale: 20.0)
    super()
    mx = MLX::Core
    @scale = scale
    self.embedding = MLX::NN::Embedding.new(num_embeddings, dims)
    self.dropout = MLX::NN::Dropout.new(dropout)

    # lora_a random, lora_b zeros: the adapter contributes nothing at start.
    lora_scale = 1.0 / Math.sqrt(num_embeddings)
    self.lora_a = mx.random_uniform(
      [num_embeddings, r], -lora_scale, lora_scale, mx.float32
    )
    self.lora_b = mx.zeros([r, dims])
  end

  # Embedding lookup plus the LoRA delta for the looked-up rows.
  # NOTE(review): dropout here is applied AFTER the lora_a lookup and matmul
  # with lora_b, whereas as_linear applies it to x BEFORE the matmuls —
  # confirm this asymmetry is intended.
  def call(x)
    mx = MLX::Core
    y = embedding.call(x)
    # LoRA for embedding: look up lora_a rows, then multiply by lora_b
    z = mx.matmul(mx.take(lora_a, x, 0), lora_b)
    z = dropout.call(z)
    y + z * @scale
  end

  # Use the (adapted) embedding as an output projection (tied weights):
  # logits = x @ (W + scale * A @ B)^T, computed without materializing A @ B
  # by multiplying through B^T then A^T.
  def as_linear(x)
    mx = MLX::Core
    y = embedding.as_linear(x)
    z = mx.matmul(mx.matmul(dropout.call(x), mx.transpose(lora_b)), mx.transpose(lora_a))
    y + z * @scale
  end

  # Merge the LoRA delta into a plain Embedding layer.
  # With dequantize: true, a quantized base is dequantized first.
  def fuse(dequantize: false)
    mx = MLX::Core
    embed = embedding

    if dequantize && embed.is_a?(MLX::NN::QuantizedEmbedding)
      embed = MlxLm::Quantize.embedding_from_quantized(embed)
    end

    # Fuse: W' = W + scale * (lora_a @ lora_b); same orientation as W.
    weight = embed.weight
    lora_weight = mx.matmul(lora_a, lora_b)
    fused_weight = weight + lora_weight * @scale

    num_embeddings, dims = fused_weight.shape
    result = MLX::NN::Embedding.new(num_embeddings, dims)
    result.weight = fused_weight
    result
  end
end
126
+
127
module_function

# Default LoRA target keys (layer names that get LoRA applied)
DEFAULT_LORA_KEYS = %w[self_attn.q_proj self_attn.k_proj self_attn.v_proj].freeze

# Apply LoRA adapters in-place to the last num_layers layers of a model
# (all layers when num_layers is nil).
#
# config accepts string or symbol keys: rank, scale, dropout, keys.
def apply_lora_layers(model, num_layers: nil, config: {})
  rank = config["rank"] || config[:rank] || 8
  scale = config["scale"] || config[:scale] || 20.0
  dropout = config["dropout"] || config[:dropout] || 0.0
  keys = config["keys"] || config[:keys] || DEFAULT_LORA_KEYS

  layers = model.layers
  num_layers ||= layers.length
  target_layers = layers.last(num_layers)

  target_layers.each do |layer|
    _apply_lora_to_module(layer, "", keys, rank: rank, scale: scale, dropout: dropout)
  end
end

# Recursively walk a module's state, replacing Linear/Embedding children
# whose dotted path matches one of keys with LoRA-wrapped versions.
def _apply_lora_to_module(mod, prefix, keys, rank:, scale:, dropout:)
  # include? subsumes end_with?, so a single substring test is enough
  # (the previous `end_with?(k) || include?(k)` was redundant and the
  # predicate was duplicated across both branches).
  matches = ->(full_key) { keys.any? { |k| full_key.include?(k) } }

  mod.state.each do |key, value|
    full_key = prefix.empty? ? key : "#{prefix}.#{key}"

    if value.is_a?(MLX::NN::Linear) && matches.call(full_key)
      mod.state[key] = LoRALinear.from_base(value, r: rank, scale: scale, dropout: dropout)
    elsif value.is_a?(MLX::NN::Embedding) && matches.call(full_key)
      mod.state[key] = LoRAEmbedding.from_base(value, r: rank, scale: scale, dropout: dropout)
    elsif value.is_a?(MLX::NN::Module)
      _apply_lora_to_module(value, full_key, keys, rank: rank, scale: scale, dropout: dropout)
    end
  end
end
module_function :_apply_lora_to_module
164
+ end
165
+ end
@@ -0,0 +1,3 @@
1
module MlxLm
  # Gem version string; frozen so callers cannot mutate the shared constant.
  VERSION = "0.30.7.1".freeze
end