mlx-ruby-lm 0.30.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.txt +21 -0
- data/README.md +83 -0
- data/exe/mlx_lm +7 -0
- data/lib/mlx_lm/benchmark.rb +67 -0
- data/lib/mlx_lm/chat_template.rb +41 -0
- data/lib/mlx_lm/cli.rb +113 -0
- data/lib/mlx_lm/config.rb +30 -0
- data/lib/mlx_lm/convert_utils.rb +51 -0
- data/lib/mlx_lm/generate.rb +204 -0
- data/lib/mlx_lm/load_utils.rb +87 -0
- data/lib/mlx_lm/model_args.rb +54 -0
- data/lib/mlx_lm/models/activations.rb +46 -0
- data/lib/mlx_lm/models/afm7.rb +131 -0
- data/lib/mlx_lm/models/afmoe.rb +421 -0
- data/lib/mlx_lm/models/apertus.rb +179 -0
- data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
- data/lib/mlx_lm/models/bailing_moe.rb +399 -0
- data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
- data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
- data/lib/mlx_lm/models/bitnet.rb +176 -0
- data/lib/mlx_lm/models/cache.rb +792 -0
- data/lib/mlx_lm/models/cohere.rb +150 -0
- data/lib/mlx_lm/models/cohere2.rb +224 -0
- data/lib/mlx_lm/models/dbrx.rb +286 -0
- data/lib/mlx_lm/models/deepseek.rb +239 -0
- data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
- data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
- data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
- data/lib/mlx_lm/models/dots1.rb +292 -0
- data/lib/mlx_lm/models/ernie4_5.rb +165 -0
- data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
- data/lib/mlx_lm/models/exaone.rb +169 -0
- data/lib/mlx_lm/models/exaone4.rb +233 -0
- data/lib/mlx_lm/models/exaone_moe.rb +421 -0
- data/lib/mlx_lm/models/falcon_h1.rb +102 -0
- data/lib/mlx_lm/models/gated_delta.rb +136 -0
- data/lib/mlx_lm/models/gemma.rb +159 -0
- data/lib/mlx_lm/models/gemma2.rb +198 -0
- data/lib/mlx_lm/models/gemma3.rb +85 -0
- data/lib/mlx_lm/models/gemma3_text.rb +270 -0
- data/lib/mlx_lm/models/gemma3n.rb +79 -0
- data/lib/mlx_lm/models/glm.rb +164 -0
- data/lib/mlx_lm/models/glm4.rb +180 -0
- data/lib/mlx_lm/models/glm4_moe.rb +343 -0
- data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
- data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
- data/lib/mlx_lm/models/gpt2.rb +166 -0
- data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
- data/lib/mlx_lm/models/gpt_neox.rb +178 -0
- data/lib/mlx_lm/models/gpt_oss.rb +319 -0
- data/lib/mlx_lm/models/granite.rb +170 -0
- data/lib/mlx_lm/models/granitemoe.rb +58 -0
- data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
- data/lib/mlx_lm/models/helium.rb +158 -0
- data/lib/mlx_lm/models/hunyuan.rb +378 -0
- data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
- data/lib/mlx_lm/models/internlm2.rb +160 -0
- data/lib/mlx_lm/models/internlm3.rb +237 -0
- data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
- data/lib/mlx_lm/models/jamba.rb +158 -0
- data/lib/mlx_lm/models/kimi_k25.rb +98 -0
- data/lib/mlx_lm/models/kimi_linear.rb +124 -0
- data/lib/mlx_lm/models/kimi_vl.rb +93 -0
- data/lib/mlx_lm/models/klear.rb +283 -0
- data/lib/mlx_lm/models/lfm2.rb +120 -0
- data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
- data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
- data/lib/mlx_lm/models/lille_130m.rb +148 -0
- data/lib/mlx_lm/models/llama.rb +183 -0
- data/lib/mlx_lm/models/llama4.rb +357 -0
- data/lib/mlx_lm/models/llama4_text.rb +195 -0
- data/lib/mlx_lm/models/longcat_flash.rb +153 -0
- data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
- data/lib/mlx_lm/models/mamba.rb +301 -0
- data/lib/mlx_lm/models/mamba2.rb +292 -0
- data/lib/mlx_lm/models/mimo.rb +174 -0
- data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
- data/lib/mlx_lm/models/minicpm.rb +169 -0
- data/lib/mlx_lm/models/minicpm3.rb +237 -0
- data/lib/mlx_lm/models/minimax.rb +282 -0
- data/lib/mlx_lm/models/ministral3.rb +304 -0
- data/lib/mlx_lm/models/mistral3.rb +84 -0
- data/lib/mlx_lm/models/mixtral.rb +192 -0
- data/lib/mlx_lm/models/mla.rb +75 -0
- data/lib/mlx_lm/models/nanochat.rb +167 -0
- data/lib/mlx_lm/models/nemotron.rb +202 -0
- data/lib/mlx_lm/models/nemotron_h.rb +212 -0
- data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
- data/lib/mlx_lm/models/olmo.rb +165 -0
- data/lib/mlx_lm/models/olmo2.rb +169 -0
- data/lib/mlx_lm/models/olmo3.rb +254 -0
- data/lib/mlx_lm/models/olmoe.rb +64 -0
- data/lib/mlx_lm/models/openelm.rb +208 -0
- data/lib/mlx_lm/models/phi.rb +156 -0
- data/lib/mlx_lm/models/phi3.rb +171 -0
- data/lib/mlx_lm/models/phi3small.rb +196 -0
- data/lib/mlx_lm/models/phimoe.rb +206 -0
- data/lib/mlx_lm/models/phixtral.rb +208 -0
- data/lib/mlx_lm/models/pipeline.rb +37 -0
- data/lib/mlx_lm/models/pixtral.rb +47 -0
- data/lib/mlx_lm/models/plamo.rb +169 -0
- data/lib/mlx_lm/models/plamo2.rb +173 -0
- data/lib/mlx_lm/models/qwen.rb +175 -0
- data/lib/mlx_lm/models/qwen2.rb +162 -0
- data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
- data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
- data/lib/mlx_lm/models/qwen3.rb +167 -0
- data/lib/mlx_lm/models/qwen3_5.rb +69 -0
- data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
- data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
- data/lib/mlx_lm/models/qwen3_next.rb +147 -0
- data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
- data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
- data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
- data/lib/mlx_lm/models/rope_utils.rb +316 -0
- data/lib/mlx_lm/models/rwkv7.rb +101 -0
- data/lib/mlx_lm/models/seed_oss.rb +167 -0
- data/lib/mlx_lm/models/smollm3.rb +89 -0
- data/lib/mlx_lm/models/solar_open.rb +79 -0
- data/lib/mlx_lm/models/ssm.rb +162 -0
- data/lib/mlx_lm/models/stablelm.rb +160 -0
- data/lib/mlx_lm/models/starcoder2.rb +161 -0
- data/lib/mlx_lm/models/step3p5.rb +479 -0
- data/lib/mlx_lm/models/switch_layers.rb +221 -0
- data/lib/mlx_lm/models/telechat3.rb +192 -0
- data/lib/mlx_lm/models/youtu_llm.rb +230 -0
- data/lib/mlx_lm/models.rb +33 -0
- data/lib/mlx_lm/perplexity.rb +48 -0
- data/lib/mlx_lm/quantize.rb +131 -0
- data/lib/mlx_lm/sample_utils.rb +159 -0
- data/lib/mlx_lm/server.rb +190 -0
- data/lib/mlx_lm/tokenizer_utils.rb +158 -0
- data/lib/mlx_lm/tuner/lora.rb +165 -0
- data/lib/mlx_lm/version.rb +3 -0
- data/lib/mlx_lm/weight_utils.rb +170 -0
- data/lib/mlx_lm.rb +135 -0
- metadata +272 -0
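
For orientation, here is a minimal end-to-end usage sketch assembled from the APIs that appear in the diffs below (LoadUtils.load, ChatTemplate.apply, SampleUtils.make_sampler, Generate.generate). The model directory path is hypothetical and the exact return shapes are assumptions read off the server code, not documented guarantees:

require "mlx_lm"

# Hypothetical path to a converted MLX model directory.
model, tokenizer = MlxLm::LoadUtils.load("models/llama-3-8b-mlx")

# Render a chat prompt and build a sampler (temp: 0.0 would mean greedy argmax).
prompt = MlxLm::ChatTemplate.apply([{ "role" => "user", "content" => "Hello!" }])
sampler = MlxLm::SampleUtils.make_sampler(temp: 0.7, top_p: 0.9)

puts MlxLm::Generate.generate(model, tokenizer, prompt, max_tokens: 128, sampler: sampler)
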
--- /dev/null
+++ b/data/lib/mlx_lm/sample_utils.rb
@@ -0,0 +1,159 @@
+module MlxLm
+  module SampleUtils
+    module_function
+
+    # Build a sampler callable (proc) from the given parameters.
+    # Returns a proc that takes logprobs (mx.array) and returns a token (mx.array).
+    def make_sampler(
+      temp: 0.0,
+      top_p: 0.0,
+      min_p: 0.0,
+      min_tokens_to_keep: 1,
+      top_k: 0
+    )
+      mx = MLX::Core
+
+      if temp == 0
+        return ->(x) { mx.argmax(x, -1) }
+      end
+
+      sampling_methods = []
+      if top_p > 0 && top_p < 1.0
+        sampling_methods << ->(x) { apply_top_p(x, top_p) }
+      end
+      if min_p != 0.0
+        sampling_methods << ->(x) { apply_min_p(x, min_p, min_tokens_to_keep) }
+      end
+      if top_k > 0
+        sampling_methods << ->(x) { apply_top_k(x, top_k) }
+      end
+
+      ->(logprobs) {
+        sampling_methods.each { |method| logprobs = method.call(logprobs) }
+        categorical_sampling(logprobs, temp)
+      }
+    end
+
+    def make_logits_processors(repetition_penalty: nil, repetition_context_size: 20)
+      processors = []
+      if repetition_penalty && repetition_penalty != 0.0
+        processors << make_repetition_penalty(repetition_penalty, repetition_context_size)
+      end
+      processors
+    end
+
+    def apply_top_k(logprobs, top_k)
+      mx = MLX::Core
+      vocab_size = logprobs.shape[-1]
+      raise ArgumentError, "top_k must be in (0, #{vocab_size})" unless top_k.is_a?(Integer) && top_k > 0 && top_k < vocab_size
+
+      neg_logprobs = mx.negative(logprobs)
+      mask_idx = mx.argpartition(neg_logprobs, top_k - 1, -1)
+      # Get indices after top_k (the ones to mask)
+      rest = mx.split(mask_idx, [top_k], -1)[1]
+      neg_inf = mx.array([-Float::INFINITY], dtype: logprobs.dtype)
+      mx.put_along_axis(logprobs, rest, neg_inf, -1)
+    end
+
+    def apply_min_p(logprobs, min_p, min_tokens_to_keep = 1)
+      mx = MLX::Core
+      raise ArgumentError, "min_p must be in [0, 1]" unless min_p >= 0 && min_p <= 1.0
+
+      # Sort indices in decreasing order
+      neg_logprobs = mx.negative(logprobs)
+      sorted_indices = mx.argsort(neg_logprobs, -1)
+      sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, -1)
+
+      # Top probability
+      top_logprobs = mx.split(sorted_logprobs, [1], -1)[0]
+
+      # Calculate the min_p threshold
+      scaled_min_p = top_logprobs + Math.log(min_p)
+
+      # Mask tokens below threshold
+      tokens_to_remove = mx.less(sorted_logprobs, scaled_min_p)
+
+      neg_inf = mx.array(-Float::INFINITY, dtype: sorted_logprobs.dtype)
+      selected_logprobs = mx.where(tokens_to_remove, neg_inf, sorted_logprobs)
+
+      # Restore the top min_tokens_to_keep tokens regardless
+      if min_tokens_to_keep > 0
+        top_sorted = mx.split(sorted_logprobs, [min_tokens_to_keep], -1)[0]
+        rest_selected = mx.split(selected_logprobs, [min_tokens_to_keep], -1)[1]
+        selected_logprobs = mx.concatenate([top_sorted, rest_selected], -1)
+      end
+
+      # Create inverse mapping to restore original order
+      inverse_indices = mx.put_along_axis(
+        mx.zeros_like(sorted_indices),
+        sorted_indices,
+        mx.arange(sorted_indices.shape[-1]).astype(sorted_indices.dtype),
+        -1
+      )
+
+      mx.take_along_axis(selected_logprobs, inverse_indices, -1)
+    end
+
+    def apply_top_p(logprobs, top_p)
+      mx = MLX::Core
+      probs = mx.exp(logprobs)
+      # sort in ascending order
+      sorted_indices = mx.argsort(logprobs, -1)
+      sorted_probs = mx.take_along_axis(probs, sorted_indices, -1)
+
+      cumulative_probs = mx.cumsum(sorted_probs, -1)
+
+      # Rearrange cumulative probs back to original order
+      inverse_indices = mx.put_along_axis(
+        mx.zeros_like(sorted_indices),
+        sorted_indices,
+        mx.arange(sorted_indices.shape[-1]).astype(sorted_indices.dtype),
+        -1
+      )
+      cumulative_probs = mx.take_along_axis(cumulative_probs, inverse_indices, -1)
+
+      # select tokens with cumulative probs above threshold
+      threshold = mx.array(1.0 - top_p, dtype: cumulative_probs.dtype)
+      mask = mx.greater(cumulative_probs, threshold)
+      neg_inf = mx.array(-Float::INFINITY, dtype: logprobs.dtype)
+      mx.where(mask, logprobs, neg_inf)
+    end
+
+    def categorical_sampling(logits, temp)
+      mx = MLX::Core
+      mx.categorical(logits * (1.0 / temp))
+    end
+
+    def make_repetition_penalty(penalty, context_size = 20)
+      mx = MLX::Core
+      raise ArgumentError, "penalty must be a non-negative float" unless penalty.is_a?(Numeric) && penalty >= 0
+
+      ->(tokens, logits) {
+        if tokens && tokens.size > 0
+          recent = if tokens.is_a?(::Array)
+            tokens.last(context_size)
+          elsif tokens.respond_to?(:tolist)
+            tokens.tolist.last(context_size)
+          else
+            []
+          end
+          if recent.length > 0
+            token_indices = mx.array(recent, dtype: mx.int32)
+            n_tokens = recent.length
+            idx_2d = token_indices.reshape([1, n_tokens])
+            selected_logits = mx.take_along_axis(logits, idx_2d, -1)
+            zero = mx.array(0.0, dtype: selected_logits.dtype)
+            is_negative = mx.less(selected_logits, zero)
+            selected_logits = mx.where(
+              is_negative,
+              selected_logits * penalty,
+              selected_logits / penalty
+            )
+            logits = mx.put_along_axis(logits, idx_2d, selected_logits, -1)
+          end
+        end
+        logits
+      }
+    end
+  end
+end
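
A short sketch of how these pieces compose. The logprob values, the generated-token list, and the sampler settings are all placeholders; the filters return the logprobs with rejected entries set to -Infinity, so categorical sampling never picks them:

require "mlx_lm"
mx = MLX::Core

# Log-probabilities of a toy 6-token distribution (roughly ln of 0.4, 0.3, 0.1, 0.1, 0.05, 0.05).
logprobs = mx.array([[-0.92, -1.20, -2.30, -2.30, -3.00, -3.00]])

# temp > 0, so the top-p and top-k filters run before categorical sampling.
sampler = MlxLm::SampleUtils.make_sampler(temp: 0.8, top_p: 0.9, top_k: 3)
next_token = sampler.call(logprobs)

# Logits processors take (tokens_so_far, logits) and return adjusted logits.
generated = [2, 5, 5]
processors = MlxLm::SampleUtils.make_logits_processors(repetition_penalty: 1.1)
adjusted = processors.reduce(logprobs) { |l, p| p.call(generated, l) }
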
--- /dev/null
+++ b/data/lib/mlx_lm/server.rb
@@ -0,0 +1,190 @@
+require "json"
+require "securerandom"
+
+module MlxLm
+  module Server
+    # Request schema for POST /v1/chat/completions
+    class ChatCompletionRequest
+      attr_reader :model, :messages, :max_tokens, :temperature, :top_p, :stream, :stop
+
+      def self.from_hash(h)
+        new(
+          model: h["model"],
+          messages: h["messages"] || [],
+          max_tokens: h["max_tokens"] || 256,
+          temperature: h["temperature"] || 0.0,
+          top_p: h["top_p"] || 1.0,
+          stream: h.fetch("stream", false),
+          stop: h["stop"]
+        )
+      end
+
+      def initialize(model:, messages:, max_tokens: 256, temperature: 0.0, top_p: 1.0, stream: false, stop: nil)
+        @model = model
+        @messages = messages
+        @max_tokens = max_tokens
+        @temperature = temperature
+        @top_p = top_p
+        @stream = stream
+        @stop = stop
+      end
+    end
+
+    # Response schema for non-streaming chat completion
+    class ChatCompletionResponse
+      def initialize(model:, content:, prompt_tokens:, completion_tokens:, finish_reason: "stop")
+        @model = model
+        @content = content
+        @prompt_tokens = prompt_tokens
+        @completion_tokens = completion_tokens
+        @finish_reason = finish_reason
+        @id = "chatcmpl-#{SecureRandom.hex(12)}"
+        @created = Time.now.to_i
+      end
+
+      def to_hash
+        {
+          "id" => @id,
+          "object" => "chat.completion",
+          "created" => @created,
+          "model" => @model,
+          "choices" => [
+            {
+              "index" => 0,
+              "message" => {
+                "role" => "assistant",
+                "content" => @content,
+              },
+              "finish_reason" => @finish_reason,
+            }
+          ],
+          "usage" => {
+            "prompt_tokens" => @prompt_tokens,
+            "completion_tokens" => @completion_tokens,
+            "total_tokens" => @prompt_tokens + @completion_tokens,
+          }
+        }
+      end
+
+      def to_json
+        JSON.generate(to_hash)
+      end
+    end
+
+    # Streaming chunk response
+    class ChatCompletionChunk
+      def initialize(model:, content:, finish_reason: nil)
+        @model = model
+        @content = content
+        @finish_reason = finish_reason
+        @id = "chatcmpl-#{SecureRandom.hex(12)}"
+        @created = Time.now.to_i
+      end
+
+      def to_hash
+        {
+          "id" => @id,
+          "object" => "chat.completion.chunk",
+          "created" => @created,
+          "model" => @model,
+          "choices" => [
+            {
+              "index" => 0,
+              "delta" => {
+                "content" => @content,
+              },
+              "finish_reason" => @finish_reason,
+            }
+          ]
+        }
+      end
+
+      def to_sse
+        "data: #{JSON.generate(to_hash)}\n\n"
+      end
+    end
+
+    # GET /v1/models response
+    class ModelsListResponse
+      def initialize(models:)
+        @models = models
+      end
+
+      def to_hash
+        {
+          "object" => "list",
+          "data" => @models.map { |m|
+            {
+              "id" => m,
+              "object" => "model",
+              "created" => Time.now.to_i,
+              "owned_by" => "mlx-lm",
+            }
+          }
+        }
+      end
+
+      def to_json
+        JSON.generate(to_hash)
+      end
+    end
+
+    module_function
+
+    def start(model_path:, host: "127.0.0.1", port: 8080)
+      require "webrick"
+
+      model, tokenizer = LoadUtils.load(model_path)
+
+      server = WEBrick::HTTPServer.new(Port: port, BindAddress: host)
+
+      server.mount_proc "/v1/models" do |req, res|
+        res["Content-Type"] = "application/json"
+        resp = ModelsListResponse.new(models: [model_path])
+        res.body = resp.to_json
+      end
+
+      server.mount_proc "/v1/chat/completions" do |req, res|
+        body = JSON.parse(req.body)
+        chat_req = ChatCompletionRequest.from_hash(body)
+
+        prompt = ChatTemplate.apply(chat_req.messages)
+        sampler = SampleUtils.make_sampler(temp: chat_req.temperature, top_p: chat_req.top_p)
+
+        if chat_req.stream
+          res["Content-Type"] = "text/event-stream"
+          res["Cache-Control"] = "no-cache"
+
+          res.body = Enumerator.new { |yielder|
+            Generate.stream_generate(model, tokenizer, prompt,
+              max_tokens: chat_req.max_tokens, sampler: sampler).each do |resp|
+              chunk = ChatCompletionChunk.new(
+                model: chat_req.model,
+                content: resp.text,
+                finish_reason: resp.finish_reason
+              )
+              yielder << chunk.to_sse
+            end
+            yielder << "data: [DONE]\n\n"
+          }
+        else
+          text = Generate.generate(model, tokenizer, prompt,
+            max_tokens: chat_req.max_tokens, sampler: sampler)
+
+          res["Content-Type"] = "application/json"
+          resp = ChatCompletionResponse.new(
+            model: chat_req.model,
+            content: text,
+            prompt_tokens: prompt.length,
+            completion_tokens: text.length, # character count, not a true token count
+            finish_reason: "stop"
+          )
+          res.body = resp.to_json
+        end
+      end
+
+      trap("INT") { server.shutdown }
+      server.start
+    end
+  end
+end
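
To exercise the server, a sketch of booting it and calling the non-streaming endpoint with Ruby's standard library. The host, port, and model path are illustrative:

# In one process:
require "mlx_lm"
MlxLm::Server.start(model_path: "models/llama-3-8b-mlx", host: "127.0.0.1", port: 8080)

# From another process:
require "json"
require "net/http"

uri = URI("http://127.0.0.1:8080/v1/chat/completions")
req = Net::HTTP::Post.new(uri, "Content-Type" => "application/json")
req.body = JSON.generate(
  "model" => "local",
  "messages" => [{ "role" => "user", "content" => "Say hi in five words." }],
  "max_tokens" => 32
)
res = Net::HTTP.start(uri.hostname, uri.port) { |http| http.request(req) }
puts JSON.parse(res.body).dig("choices", 0, "message", "content")

Setting "stream" => true in the request body instead selects the SSE branch, which emits chat.completion.chunk events terminated by "data: [DONE]".
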
--- /dev/null
+++ b/data/lib/mlx_lm/tokenizer_utils.rb
@@ -0,0 +1,158 @@
+require "tokenizers"
+require "json"
+
+module MlxLm
+  # Wraps a HuggingFace tokenizer (loaded via the tokenizers gem)
+  # providing encode/decode and metadata access.
+  class TokenizerWrapper
+    attr_reader :tokenizer
+
+    # Can be initialized with:
+    # 1. A path string (directory containing tokenizer.json)
+    # 2. A Tokenizers::Tokenizer object (with optional eos_token/eos_token_id)
+    def initialize(path_or_tokenizer, eos_token: nil, eos_token_id: nil)
+      if path_or_tokenizer.is_a?(String)
+        tokenizer_json = File.join(path_or_tokenizer, "tokenizer.json")
+        @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)
+
+        config_path = File.join(path_or_tokenizer, "tokenizer_config.json")
+        @config = File.exist?(config_path) ? JSON.parse(File.read(config_path)) : {}
+      else
+        @tokenizer = path_or_tokenizer
+        @config = {}
+      end
+
+      @eos_token_override = eos_token
+      @eos_token_id_override = eos_token_id
+
+      @_detokenizer = nil
+    end
+
+    def encode(text, add_special_tokens: true)
+      @tokenizer.encode(text, add_special_tokens: add_special_tokens).ids
+    end
+
+    def decode(ids, skip_special_tokens: false)
+      @tokenizer.decode(ids, skip_special_tokens: skip_special_tokens)
+    end
+
+    def eos_token
+      return @eos_token_override if @eos_token_override
+      token = @config["eos_token"]
+      token = token["content"] if token.is_a?(Hash)
+      token
+    end
+
+    def eos_token_id
+      # Try override ids first
+      if @eos_token_id_override && !@eos_token_id_override.empty?
+        return @eos_token_id_override.first
+      end
+
+      # Try config
+      if @config["eos_token"]
+        token = @config["eos_token"]
+        token = token["content"] if token.is_a?(Hash)
+        id = @tokenizer.token_to_id(token)
+        return id if id
+      end
+
+      # Try eos_token string override
+      if @eos_token_override
+        id = @tokenizer.token_to_id(@eos_token_override)
+        return id if id
+      end
+
+      nil
+    end
+
+    # Returns a Set of all EOS token IDs
+    def eos_token_ids
+      ids = Set.new
+      if @eos_token_id_override
+        @eos_token_id_override.each { |id| ids << id if id }
+      end
+      base_id = eos_token_id
+      ids << base_id if base_id
+      ids
+    end
+
+    def bos_token
+      token = @config["bos_token"]
+      token = token["content"] if token.is_a?(Hash)
+      token
+    end
+
+    def bos_token_id
+      if @config["bos_token"]
+        token = @config["bos_token"]
+        token = token["content"] if token.is_a?(Hash)
+        id = @tokenizer.token_to_id(token)
+        return id if id
+      end
+      nil
+    end
+
+    def vocab_size
+      @tokenizer.vocab_size
+    end
+
+    def id_to_token(id)
+      @tokenizer.id_to_token(id)
+    end
+
+    def token_to_id(token)
+      @tokenizer.token_to_id(token)
+    end
+
+    def detokenizer
+      @_detokenizer ||= StreamingDetokenizer.new(self)
+    end
+
+    def has_chat_template
+      !!@config["chat_template"]
+    end
+  end
+
+  # Streaming detokenizer that emits text incrementally as tokens arrive.
+  # Uses a simple approach: maintain a buffer of token IDs, re-decode the full
+  # buffer on each token, and emit only the new characters since the last decode.
+  class StreamingDetokenizer
+    attr_reader :last_segment
+
+    def initialize(tokenizer_wrapper)
+      @tokenizer = tokenizer_wrapper
+      @token_ids = []
+      @prev_text = ""
+      @last_segment = ""
+    end
+
+    # Add a token and record the new text segment
+    def add_token(token_id)
+      @token_ids << token_id
+      current_text = @tokenizer.decode(@token_ids)
+      @last_segment = current_text[@prev_text.length..] || ""
+      @prev_text = current_text
+      @last_segment
+    end
+
+    # Finalize and record any remaining text
+    def finalize
+      return "" if @token_ids.empty?
+      final = @tokenizer.decode(@token_ids)
+      @last_segment = final[@prev_text.length..] || ""
+      @prev_text = final
+      @last_segment
+    end
+
+    def text
+      @prev_text
+    end
+
+    def reset
+      @token_ids = []
+      @prev_text = ""
+      @last_segment = ""
+    end
+  end
+end
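
Usage sketch for the streaming path, assuming a model directory containing tokenizer.json (path hypothetical). add_token returns only the newly decoded characters, which is exactly what a streaming UI wants to print:

require "mlx_lm"

tokenizer = MlxLm::TokenizerWrapper.new("models/llama-3-8b-mlx")
detok = tokenizer.detokenizer

tokenizer.encode("Hello, world!").each do |token_id|
  print detok.add_token(token_id) # only the fresh suffix, not the full text
end
print detok.finalize
puts
detok.reset # ready for the next generation
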
--- /dev/null
+++ b/data/lib/mlx_lm/tuner/lora.rb
@@ -0,0 +1,165 @@
+module MlxLm
+  module Tuner
+    # LoRA adapter for Linear layers.
+    # Forward: y = linear(x) + scale * (dropout(x) @ lora_a @ lora_b)
+    class LoRALinear < MLX::NN::Module
+      def self.from_base(linear, r: 8, dropout: 0.0, scale: 20.0)
+        if linear.is_a?(MLX::NN::QuantizedLinear)
+          input_dims = linear.instance_variable_get(:@weight).shape[1] * 32 /
+                       (linear.instance_variable_get(:@bits) || 4)
+          output_dims = linear.instance_variable_get(:@weight).shape[0]
+          bias = !linear.instance_variable_get(:@bias).nil?
+        else
+          weight = linear.weight
+          output_dims, input_dims = weight.shape
+          bias = !linear.respond_to?(:bias) || !linear.bias.nil? rescue false
+        end
+        lora = new(input_dims, output_dims, r: r, dropout: dropout, scale: scale, bias: bias)
+        lora.linear = linear
+        lora
+      end
+
+      def initialize(input_dims, output_dims, r: 8, dropout: 0.0, scale: 20.0, bias: false)
+        super()
+        mx = MLX::Core
+        @scale = scale
+        self.linear = MLX::NN::Linear.new(input_dims, output_dims, bias: bias)
+        self.dropout = MLX::NN::Dropout.new(dropout)
+
+        # Initialize LoRA matrices
+        lora_scale = 1.0 / Math.sqrt(input_dims)
+        self.lora_a = mx.random_uniform(
+          [input_dims, r], -lora_scale, lora_scale, mx.float32
+        )
+        self.lora_b = mx.zeros([r, output_dims])
+      end
+
+      def call(x)
+        mx = MLX::Core
+        y = linear.call(x)
+        z = dropout.call(x)
+        z = mx.matmul(mx.matmul(z, lora_a), lora_b)
+        y + z * @scale
+      end
+
+      def fuse(dequantize: false)
+        mx = MLX::Core
+        lin = linear
+
+        if dequantize && lin.is_a?(MLX::NN::QuantizedLinear)
+          lin = MlxLm::Quantize.linear_from_quantized(lin)
+        end
+
+        weight = lin.weight
+        bias_val = lin.respond_to?(:bias) ? lin.bias : nil
+
+        # Fuse: W' = W + scale * (lora_a @ lora_b)^T
+        lora_weight = mx.matmul(lora_a, lora_b)
+        fused_weight = weight + mx.transpose(lora_weight) * @scale
+
+        out_features, in_features = fused_weight.shape
+        result = MLX::NN::Linear.new(in_features, out_features, bias: !bias_val.nil?)
+        result.weight = fused_weight
+        result.bias = bias_val if bias_val
+        result
+      end
+    end
+
+    # LoRA adapter for Embedding layers.
+    class LoRAEmbedding < MLX::NN::Module
+      def self.from_base(embedding, r: 8, dropout: 0.0, scale: 20.0)
+        weight = embedding.weight
+        num_embeddings, dims = weight.shape
+        lora = new(num_embeddings, dims, r: r, dropout: dropout, scale: scale)
+        lora.embedding = embedding
+        lora
+      end
+
+      def initialize(num_embeddings, dims, r: 8, dropout: 0.0, scale: 20.0)
+        super()
+        mx = MLX::Core
+        @scale = scale
+        self.embedding = MLX::NN::Embedding.new(num_embeddings, dims)
+        self.dropout = MLX::NN::Dropout.new(dropout)
+
+        lora_scale = 1.0 / Math.sqrt(num_embeddings)
+        self.lora_a = mx.random_uniform(
+          [num_embeddings, r], -lora_scale, lora_scale, mx.float32
+        )
+        self.lora_b = mx.zeros([r, dims])
+      end
+
+      def call(x)
+        mx = MLX::Core
+        y = embedding.call(x)
+        # LoRA for embedding: look up lora_a rows, then multiply by lora_b
+        z = mx.matmul(mx.take(lora_a, x, 0), lora_b)
+        z = dropout.call(z)
+        y + z * @scale
+      end
+
+      def as_linear(x)
+        mx = MLX::Core
+        y = embedding.as_linear(x)
+        z = mx.matmul(mx.matmul(dropout.call(x), mx.transpose(lora_b)), mx.transpose(lora_a))
+        y + z * @scale
+      end
+
+      def fuse(dequantize: false)
+        mx = MLX::Core
+        embed = embedding
+
+        if dequantize && embed.is_a?(MLX::NN::QuantizedEmbedding)
+          embed = MlxLm::Quantize.embedding_from_quantized(embed)
+        end
+
+        weight = embed.weight
+        lora_weight = mx.matmul(lora_a, lora_b)
+        fused_weight = weight + lora_weight * @scale
+
+        num_embeddings, dims = fused_weight.shape
+        result = MLX::NN::Embedding.new(num_embeddings, dims)
+        result.weight = fused_weight
+        result
+      end
+    end
+
+    module_function
+
+    # Default LoRA target keys (layer names that get LoRA applied)
+    DEFAULT_LORA_KEYS = %w[self_attn.q_proj self_attn.k_proj self_attn.v_proj].freeze
+
+    # Apply LoRA layers to a model's last N layers.
+    def apply_lora_layers(model, num_layers: nil, config: {})
+      rank = config["rank"] || config[:rank] || 8
+      scale = config["scale"] || config[:scale] || 20.0
+      dropout = config["dropout"] || config[:dropout] || 0.0
+      keys = config["keys"] || config[:keys] || DEFAULT_LORA_KEYS
+
+      layers = model.layers
+      num_layers ||= layers.length
+      target_layers = layers.last(num_layers)
+
+      target_layers.each do |layer|
+        _apply_lora_to_module(layer, "", keys, rank: rank, scale: scale, dropout: dropout)
+      end
+    end
+
+    def _apply_lora_to_module(mod, prefix, keys, rank:, scale:, dropout:)
+      mod.state.each do |key, value|
+        full_key = prefix.empty? ? key : "#{prefix}.#{key}"
+
+        if value.is_a?(MLX::NN::Linear) && keys.any? { |k| full_key.end_with?(k) || full_key.include?(k) }
+          lora = LoRALinear.from_base(value, r: rank, scale: scale, dropout: dropout)
+          mod.state[key] = lora
+        elsif value.is_a?(MLX::NN::Embedding) && keys.any? { |k| full_key.end_with?(k) || full_key.include?(k) }
+          lora = LoRAEmbedding.from_base(value, r: rank, scale: scale, dropout: dropout)
+          mod.state[key] = lora
+        elsif value.is_a?(MLX::NN::Module)
+          _apply_lora_to_module(value, full_key, keys, rank: rank, scale: scale, dropout: dropout)
+        end
+      end
+    end
+    module_function :_apply_lora_to_module
+  end
+end
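
A sketch of wiring the adapters into a loaded model, under the assumption that model.layers exposes the transformer blocks as apply_lora_layers expects; the model path and the rank/scale/dropout values are illustrative:

require "mlx_lm"

model, _tokenizer = MlxLm::LoadUtils.load("models/llama-3-8b-mlx")

# Wrap q/k/v projections in the last 16 blocks with trainable LoRA adapters.
MlxLm::Tuner.apply_lora_layers(
  model,
  num_layers: 16,
  config: { "rank" => 8, "scale" => 20.0, "dropout" => 0.05 }
)

# After training, each adapter's #fuse folds the update back into a plain layer
# via W' = W + scale * (lora_a @ lora_b)^T, so inference needs no adapter code.
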