mlx-ruby-lm 0.30.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.txt +21 -0
- data/README.md +83 -0
- data/exe/mlx_lm +7 -0
- data/lib/mlx_lm/benchmark.rb +67 -0
- data/lib/mlx_lm/chat_template.rb +41 -0
- data/lib/mlx_lm/cli.rb +113 -0
- data/lib/mlx_lm/config.rb +30 -0
- data/lib/mlx_lm/convert_utils.rb +51 -0
- data/lib/mlx_lm/generate.rb +204 -0
- data/lib/mlx_lm/load_utils.rb +87 -0
- data/lib/mlx_lm/model_args.rb +54 -0
- data/lib/mlx_lm/models/activations.rb +46 -0
- data/lib/mlx_lm/models/afm7.rb +131 -0
- data/lib/mlx_lm/models/afmoe.rb +421 -0
- data/lib/mlx_lm/models/apertus.rb +179 -0
- data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
- data/lib/mlx_lm/models/bailing_moe.rb +399 -0
- data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
- data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
- data/lib/mlx_lm/models/bitnet.rb +176 -0
- data/lib/mlx_lm/models/cache.rb +792 -0
- data/lib/mlx_lm/models/cohere.rb +150 -0
- data/lib/mlx_lm/models/cohere2.rb +224 -0
- data/lib/mlx_lm/models/dbrx.rb +286 -0
- data/lib/mlx_lm/models/deepseek.rb +239 -0
- data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
- data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
- data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
- data/lib/mlx_lm/models/dots1.rb +292 -0
- data/lib/mlx_lm/models/ernie4_5.rb +165 -0
- data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
- data/lib/mlx_lm/models/exaone.rb +169 -0
- data/lib/mlx_lm/models/exaone4.rb +233 -0
- data/lib/mlx_lm/models/exaone_moe.rb +421 -0
- data/lib/mlx_lm/models/falcon_h1.rb +102 -0
- data/lib/mlx_lm/models/gated_delta.rb +136 -0
- data/lib/mlx_lm/models/gemma.rb +159 -0
- data/lib/mlx_lm/models/gemma2.rb +198 -0
- data/lib/mlx_lm/models/gemma3.rb +85 -0
- data/lib/mlx_lm/models/gemma3_text.rb +270 -0
- data/lib/mlx_lm/models/gemma3n.rb +79 -0
- data/lib/mlx_lm/models/glm.rb +164 -0
- data/lib/mlx_lm/models/glm4.rb +180 -0
- data/lib/mlx_lm/models/glm4_moe.rb +343 -0
- data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
- data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
- data/lib/mlx_lm/models/gpt2.rb +166 -0
- data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
- data/lib/mlx_lm/models/gpt_neox.rb +178 -0
- data/lib/mlx_lm/models/gpt_oss.rb +319 -0
- data/lib/mlx_lm/models/granite.rb +170 -0
- data/lib/mlx_lm/models/granitemoe.rb +58 -0
- data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
- data/lib/mlx_lm/models/helium.rb +158 -0
- data/lib/mlx_lm/models/hunyuan.rb +378 -0
- data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
- data/lib/mlx_lm/models/internlm2.rb +160 -0
- data/lib/mlx_lm/models/internlm3.rb +237 -0
- data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
- data/lib/mlx_lm/models/jamba.rb +158 -0
- data/lib/mlx_lm/models/kimi_k25.rb +98 -0
- data/lib/mlx_lm/models/kimi_linear.rb +124 -0
- data/lib/mlx_lm/models/kimi_vl.rb +93 -0
- data/lib/mlx_lm/models/klear.rb +283 -0
- data/lib/mlx_lm/models/lfm2.rb +120 -0
- data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
- data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
- data/lib/mlx_lm/models/lille_130m.rb +148 -0
- data/lib/mlx_lm/models/llama.rb +183 -0
- data/lib/mlx_lm/models/llama4.rb +357 -0
- data/lib/mlx_lm/models/llama4_text.rb +195 -0
- data/lib/mlx_lm/models/longcat_flash.rb +153 -0
- data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
- data/lib/mlx_lm/models/mamba.rb +301 -0
- data/lib/mlx_lm/models/mamba2.rb +292 -0
- data/lib/mlx_lm/models/mimo.rb +174 -0
- data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
- data/lib/mlx_lm/models/minicpm.rb +169 -0
- data/lib/mlx_lm/models/minicpm3.rb +237 -0
- data/lib/mlx_lm/models/minimax.rb +282 -0
- data/lib/mlx_lm/models/ministral3.rb +304 -0
- data/lib/mlx_lm/models/mistral3.rb +84 -0
- data/lib/mlx_lm/models/mixtral.rb +192 -0
- data/lib/mlx_lm/models/mla.rb +75 -0
- data/lib/mlx_lm/models/nanochat.rb +167 -0
- data/lib/mlx_lm/models/nemotron.rb +202 -0
- data/lib/mlx_lm/models/nemotron_h.rb +212 -0
- data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
- data/lib/mlx_lm/models/olmo.rb +165 -0
- data/lib/mlx_lm/models/olmo2.rb +169 -0
- data/lib/mlx_lm/models/olmo3.rb +254 -0
- data/lib/mlx_lm/models/olmoe.rb +64 -0
- data/lib/mlx_lm/models/openelm.rb +208 -0
- data/lib/mlx_lm/models/phi.rb +156 -0
- data/lib/mlx_lm/models/phi3.rb +171 -0
- data/lib/mlx_lm/models/phi3small.rb +196 -0
- data/lib/mlx_lm/models/phimoe.rb +206 -0
- data/lib/mlx_lm/models/phixtral.rb +208 -0
- data/lib/mlx_lm/models/pipeline.rb +37 -0
- data/lib/mlx_lm/models/pixtral.rb +47 -0
- data/lib/mlx_lm/models/plamo.rb +169 -0
- data/lib/mlx_lm/models/plamo2.rb +173 -0
- data/lib/mlx_lm/models/qwen.rb +175 -0
- data/lib/mlx_lm/models/qwen2.rb +162 -0
- data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
- data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
- data/lib/mlx_lm/models/qwen3.rb +167 -0
- data/lib/mlx_lm/models/qwen3_5.rb +69 -0
- data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
- data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
- data/lib/mlx_lm/models/qwen3_next.rb +147 -0
- data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
- data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
- data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
- data/lib/mlx_lm/models/rope_utils.rb +316 -0
- data/lib/mlx_lm/models/rwkv7.rb +101 -0
- data/lib/mlx_lm/models/seed_oss.rb +167 -0
- data/lib/mlx_lm/models/smollm3.rb +89 -0
- data/lib/mlx_lm/models/solar_open.rb +79 -0
- data/lib/mlx_lm/models/ssm.rb +162 -0
- data/lib/mlx_lm/models/stablelm.rb +160 -0
- data/lib/mlx_lm/models/starcoder2.rb +161 -0
- data/lib/mlx_lm/models/step3p5.rb +479 -0
- data/lib/mlx_lm/models/switch_layers.rb +221 -0
- data/lib/mlx_lm/models/telechat3.rb +192 -0
- data/lib/mlx_lm/models/youtu_llm.rb +230 -0
- data/lib/mlx_lm/models.rb +33 -0
- data/lib/mlx_lm/perplexity.rb +48 -0
- data/lib/mlx_lm/quantize.rb +131 -0
- data/lib/mlx_lm/sample_utils.rb +159 -0
- data/lib/mlx_lm/server.rb +190 -0
- data/lib/mlx_lm/tokenizer_utils.rb +158 -0
- data/lib/mlx_lm/tuner/lora.rb +165 -0
- data/lib/mlx_lm/version.rb +3 -0
- data/lib/mlx_lm/weight_utils.rb +170 -0
- data/lib/mlx_lm.rb +135 -0
- metadata +272 -0
checksums.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
---
|
|
2
|
+
SHA256:
|
|
3
|
+
metadata.gz: 9a305e7fa3e70c61d29f8e997a98d4f7da6ff344d2df698778f56980dae27a34
|
|
4
|
+
data.tar.gz: f81553b3050391f1585f5ba6822c1cd28936f5ddaccc0b852bcfae30ea41bdf4
|
|
5
|
+
SHA512:
|
|
6
|
+
metadata.gz: d5e6be331e95c5323b7fdb93b8defb2242a5ad969269e8e5442cc8f499d1fb2beabe03f81d14e7ee8abc3c61bf5d144ea5755bda7462cfdb9bcddb0afe9695f0
|
|
7
|
+
data.tar.gz: b333ba48c0e390b217ec4320e149f03b1e507caefe856701bf85245fa301b55d0fec2d75ecc2bdd92e367b49e3c82454b1b0baffe1d9c161223fb9580635a20e
|
data/LICENSE.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
The MIT License (MIT)
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Alex Skryl
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in
|
|
13
|
+
all copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|
21
|
+
THE SOFTWARE.
|
data/README.md
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# mlx-ruby-lm
|
|
2
|
+
|
|
3
|
+
Ruby LLM inference toolkit built on the `mlx` gem.
|
|
4
|
+
|
|
5
|
+
## Index
|
|
6
|
+
|
|
7
|
+
- [Documentation Index](docs/index.md)
|
|
8
|
+
- [Installation](docs/installation.md)
|
|
9
|
+
- [CLI Usage](docs/cli.md)
|
|
10
|
+
- [Ruby APIs](docs/ruby-apis.md)
|
|
11
|
+
- [Models](docs/models.md)
|
|
12
|
+
|
|
13
|
+
For full reference pages and deep dives, start at [docs/index.md](docs/index.md).
|
|
14
|
+
|
|
15
|
+
## Installation
|
|
16
|
+
|
|
17
|
+
```bash
|
|
18
|
+
gem install mlx-ruby-lm
|
|
19
|
+
```
|
|
20
|
+
|
|
21
|
+
Or add it to a project:
|
|
22
|
+
|
|
23
|
+
```bash
|
|
24
|
+
bundle add mlx-ruby-lm
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
See [docs/installation.md](docs/installation.md) for requirements and source installs.
|
|
28
|
+
|
|
29
|
+
## CLI Usage
|
|
30
|
+
|
|
31
|
+
Executable: `mlx_lm`
|
|
32
|
+
|
|
33
|
+
Commands:
|
|
34
|
+
|
|
35
|
+
- `mlx_lm generate`
|
|
36
|
+
- `mlx_lm chat`
|
|
37
|
+
- `mlx_lm server`
|
|
38
|
+
|
|
39
|
+
Quick examples:
|
|
40
|
+
|
|
41
|
+
```bash
|
|
42
|
+
mlx_lm generate --model /path/to/model --prompt "Hello"
|
|
43
|
+
mlx_lm chat --model /path/to/model --system-prompt "You are concise."
|
|
44
|
+
mlx_lm server --model /path/to/model --host 127.0.0.1 --port 8080
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
See [docs/cli.md](docs/cli.md) for options, defaults, and current parser/behavior caveats.
|
|
48
|
+
|
|
49
|
+
## High-Level Ruby API Usage
|
|
50
|
+
|
|
51
|
+
```ruby
|
|
52
|
+
require "mlx"
|
|
53
|
+
require "mlx_lm"
|
|
54
|
+
|
|
55
|
+
model, tokenizer = MlxLm::LoadUtils.load("/path/to/model")
|
|
56
|
+
text = MlxLm::Generate.generate(model, tokenizer, "Hello", max_tokens: 64)
|
|
57
|
+
puts text
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
Streaming:
|
|
61
|
+
|
|
62
|
+
```ruby
|
|
63
|
+
MlxLm::Generate.stream_generate(model, tokenizer, "Hello", max_tokens: 64).each do |resp|
|
|
64
|
+
print resp.text
|
|
65
|
+
end
|
|
66
|
+
puts
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
See [docs/ruby-apis.md](docs/ruby-apis.md) for the full API inventory.
|
|
70
|
+
|
|
71
|
+
## High-Level Model Usage
|
|
72
|
+
|
|
73
|
+
`LoadUtils.load` expects a local model directory with files such as `config.json`,
|
|
74
|
+
`tokenizer.json`, and `model*.safetensors`.
|
|
75
|
+
|
|
76
|
+
To inspect supported model keys at runtime:
|
|
77
|
+
|
|
78
|
+
```ruby
|
|
79
|
+
require "mlx_lm"
|
|
80
|
+
puts MlxLm::Models::REGISTRY.keys.sort
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
See [docs/models.md](docs/models.md) for full registry keys and remapping behavior.
|
data/exe/mlx_lm
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
module MlxLm
  # Quick throughput and size probes for a loaded model.
  module Benchmark
    module_function

    # Synthetic benchmark: prefill a random prompt of `prompt_tokens` ids, then
    # greedily decode `gen_tokens` further tokens, timing both phases with a
    # monotonic clock. Logits and every cache entry's state are forced with
    # `eval` inside each timed region so lazy evaluation is included in the
    # measured time.
    #
    # @param model [#call] invoked as `model.call(tokens_2d, cache: cache)`
    # @param prompt_tokens [Integer] length of the random prompt
    # @param gen_tokens [Integer] number of timed decode steps
    # @param vocab_size [Integer] bound used when drawing random prompt ids
    # @return [Hash] :prompt_tokens, :prompt_time, :prompt_tps,
    #   :generation_tokens, :generation_time, :generation_tps
    def measure_generation(model, prompt_tokens: 32, gen_tokens: 64, vocab_size: 32000)
      mx = MLX::Core
      now = -> { Process.clock_gettime(Process::CLOCK_MONOTONIC) }
      sync_cache = ->(kv) { mx.eval(*kv.map(&:state).flatten.compact) }

      # Random ids: float32 draws between 0.0 and vocab_size - 1, cast to int32.
      ids = mx.random_uniform([prompt_tokens], 0.0, (vocab_size - 1).to_f, mx.float32).astype(mx.int32)
      mx.eval(ids)

      cache = Cache.make_prompt_cache(model)

      # --- prefill ---
      batch = ids.reshape([1, prompt_tokens])
      prefill_began = now.call
      logits = model.call(batch, cache: cache)
      mx.eval(logits)
      sync_cache.call(cache)
      prompt_time = now.call - prefill_began

      # Greedy pick from the logits of the final prompt position.
      flat = logits.reshape([prompt_tokens, logits.shape[-1]])
      token = mx.argmax(mx.split(flat, [prompt_tokens - 1], 0)[1], -1)
      mx.eval(token)

      # --- decode ---
      decode_began = now.call
      gen_tokens.times do
        logits = model.call(token.reshape([1, 1]), cache: cache)
        mx.eval(logits)
        sync_cache.call(cache)
        token = mx.argmax(logits.reshape([1, logits.shape[-1]]), -1)
        mx.eval(token)
      end
      generation_time = now.call - decode_began

      {
        prompt_tokens: prompt_tokens,
        prompt_time: prompt_time,
        prompt_tps: prompt_tokens.to_f / [prompt_time, 1e-9].max,
        generation_tokens: gen_tokens,
        generation_time: generation_time,
        generation_tps: gen_tokens.to_f / [generation_time, 1e-9].max,
      }
    end

    # Summarise a model's size.
    #
    # @param model [Object] must respond to #parameters; #layers is optional
    # @return [Hash] :total_params (sum of #size over the flattened parameter
    #   leaves) and :num_layers (model.layers.length, or 0 without #layers)
    def model_stats(model)
      total = MLX::Utils.tree_flatten(model.parameters).sum { |_path, leaf| leaf.size }
      { total_params: total, num_layers: model.respond_to?(:layers) ? model.layers.length : 0 }
    end
  end
end
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
module MlxLm
  # Built-in, string-based chat prompt formatting. Model-specific templates
  # (such as Jinja templates from tokenizer_config.json) can override this.
  module ChatTemplate
    module_function

    # Render a list of chat messages into a single prompt string.
    #
    # @param messages [Array<Hash>] each with "role"/"content" (String or Symbol keys)
    # @param template [Symbol] :chatml or :default; any other value uses :default
    # @return [String]
    def apply(messages, template: :default)
      case template
      when :chatml then apply_chatml(messages)
      else apply_default(messages)
      end
    end

    # ChatML-like layout. Turns are joined with "\n" and the prompt ends with an
    # open assistant header (no trailing newline):
    #   <|im_start|>system
    #   content<|im_end|>
    #   <|im_start|>user
    #   content<|im_end|>
    #   <|im_start|>assistant
    def apply_default(messages)
      turns = messages.map do |message|
        role = message["role"] || message[:role]
        content = message["content"] || message[:content]
        "<|im_start|>#{role}\n#{content}<|im_end|>"
      end
      [*turns, "<|im_start|>assistant"].join("\n")
    end

    # Explicit ChatML entry point; identical output to apply_default.
    def apply_chatml(messages)
      apply_default(messages)
    end
  end
end
|
data/lib/mlx_lm/cli.rb
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
require "optparse"

module MlxLm
  # Command-line front end for the `mlx_lm` executable.
  module CLI
    COMMANDS = %w[generate chat server].freeze

    module_function

    # Parse command-line arguments into an options hash.
    #
    # `argv` is mutated: the command word is shifted off and recognised options
    # are consumed by OptionParser#parse!.
    #
    # @param argv [Array<String>] e.g. ["generate", "--model", "/path"]
    # @return [Hash] default_args merged with :command and the parsed options
    # @raise [ArgumentError] if the first element is not one of COMMANDS
    # @raise [OptionParser::ParseError] on unknown or malformed options
    def parse_args(argv)
      command = argv.shift
      unless COMMANDS.include?(command)
        raise ArgumentError, "Unknown command '#{command}'. Valid commands: #{COMMANDS.join(', ')}"
      end

      args = default_args.merge(command: command)

      parser = OptionParser.new do |opts|
        opts.banner = "Usage: mlx_lm #{command} [options]"

        opts.on("--model MODEL", "Model path or HuggingFace ID") { |v| args[:model] = v }
        opts.on("--prompt PROMPT", "Input prompt") { |v| args[:prompt] = v }
        opts.on("--max-tokens N", Integer, "Maximum tokens to generate") { |v| args[:max_tokens] = v }
        opts.on("--temp TEMP", Float, "Sampling temperature") { |v| args[:temp] = v }
        opts.on("--top-p P", Float, "Top-p (nucleus) sampling") { |v| args[:top_p] = v }
        opts.on("--seed N", Integer, "Random seed") { |v| args[:seed] = v }
        opts.on("--repetition-penalty F", Float, "Repetition penalty") { |v| args[:repetition_penalty] = v }
        opts.on("--repetition-context-size N", Integer, "Repetition context size") { |v| args[:repetition_context_size] = v }
        opts.on("--host HOST", "Server host") { |v| args[:host] = v }
        opts.on("--port PORT", Integer, "Server port") { |v| args[:port] = v }
        opts.on("--system-prompt PROMPT", "System prompt for chat") { |v| args[:system_prompt] = v }
        opts.on("--verbose", "Verbose output") { args[:verbose] = true }
      end

      parser.parse!(argv)
      args
    end

    # Default option values, before command-line overrides.
    #
    # NOTE(review): :seed, :repetition_penalty and :repetition_context_size are
    # parsed but not yet passed to generation by run_generate/run_chat.
    def default_args
      {
        command: nil,
        model: nil,
        prompt: "",
        max_tokens: 256,
        temp: 0.0,
        top_p: 1.0,
        seed: nil,
        repetition_penalty: nil,
        repetition_context_size: 20,
        host: "127.0.0.1",
        port: 8080,
        system_prompt: nil,
        verbose: false,
      }
    end

    # Entry point: parse a copy of `argv` and dispatch to the selected command.
    def run(argv = ARGV)
      args = parse_args(argv.dup)

      case args[:command]
      when "generate"
        run_generate(args)
      when "chat"
        run_chat(args)
      when "server"
        run_server(args)
      end
    end

    # One-shot generation. With --verbose, Generate.generate streams and prints
    # the text itself, so it is not printed again here.
    #
    # @raise [ArgumentError] if --model was not given
    def run_generate(args)
      raise ArgumentError, "--model is required for 'generate'" if args[:model].nil?

      model, tokenizer = LoadUtils.load(args[:model])
      sampler = SampleUtils.make_sampler(temp: args[:temp], top_p: args[:top_p])
      text = Generate.generate(model, tokenizer, args[:prompt],
                               max_tokens: args[:max_tokens], sampler: sampler, verbose: args[:verbose])
      puts text unless args[:verbose]
    end

    # Interactive chat loop on stdin/stdout. The loop ends on EOF, an empty
    # line, "exit" or "quit". The full conversation is re-rendered through
    # ChatTemplate on every turn.
    #
    # @raise [ArgumentError] if --model was not given
    def run_chat(args)
      raise ArgumentError, "--model is required for 'chat'" if args[:model].nil?

      model, tokenizer = LoadUtils.load(args[:model])
      # Honour the same sampling flags as `generate` (top_p was previously
      # ignored here). Built once because the settings never change per turn.
      sampler = SampleUtils.make_sampler(temp: args[:temp], top_p: args[:top_p])

      messages = []
      messages << { "role" => "system", "content" => args[:system_prompt] } if args[:system_prompt]

      loop do
        print "> "
        $stdout.flush
        input = $stdin.gets
        break if input.nil?

        input = input.strip
        break if input.empty? || input == "exit" || input == "quit"

        messages << { "role" => "user", "content" => input }
        prompt = ChatTemplate.apply(messages)

        reply = +""
        Generate.stream_generate(model, tokenizer, prompt,
                                 max_tokens: args[:max_tokens], sampler: sampler).each do |resp|
          print resp.text
          $stdout.flush
          reply << resp.text
        end
        puts

        messages << { "role" => "assistant", "content" => reply }
      end
    end

    # Start the HTTP server. --model is passed through as-is (it may be nil).
    def run_server(args)
      Server.start(model_path: args[:model], host: args[:host], port: args[:port])
    end
  end
end
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
require "json"

module MlxLm
  # Loading of model configuration files from a model directory.
  module Config
    module_function

    # Load model config from a directory containing config.json and,
    # optionally, generation_config.json. Mirrors Python mlx_lm.utils.load_config.
    #
    # If generation_config.json exists, parses to a JSON object and contains a
    # truthy "eos_token_id", that value overrides config["eos_token_id"].
    # A generation config that is not valid JSON, or is valid JSON but not an
    # object (e.g. `null`, an array, a string), is ignored.
    #
    # @param model_path [String] model directory
    # @return [Hash] parsed config.json, possibly with eos_token_id overridden
    # @raise [Errno::ENOENT] if config.json does not exist
    # @raise [JSON::ParserError] if config.json is not valid JSON
    def load(model_path)
      config = JSON.parse(File.read(File.join(model_path, "config.json")))

      gen_config_path = File.join(model_path, "generation_config.json")
      if File.exist?(gen_config_path)
        gen_config =
          begin
            JSON.parse(File.read(gen_config_path))
          rescue JSON::ParserError
            {}
          end
        # Only a JSON object can carry eos_token_id. Anything else would make
        # #[] raise (nil, Array) or return a bogus substring (String).
        gen_config = {} unless gen_config.is_a?(Hash)

        eos = gen_config["eos_token_id"]
        config["eos_token_id"] = eos if eos
      end

      config
    end
  end
end
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
module MlxLm
  # Dtype conversion and parameter-size accounting helpers.
  module ConvertUtils
    # String dtype names accepted by convert_dtype, mapped to the symbols
    # handed to MLX::Core::Dtype.new.
    DTYPE_MAP = {
      "float32" => :float32,
      "float16" => :float16,
      "bfloat16" => :bfloat16,
      "int8" => :int8,
      "int32" => :int32,
    }.freeze

    module_function

    # Cast `array` to `target_dtype`.
    #
    # @param array [#astype]
    # @param target_dtype [MLX::Core::Dtype, Symbol, #to_s] a Dtype is used
    #   directly; a Symbol goes straight to Dtype.new (not checked against
    #   DTYPE_MAP); anything else is looked up in DTYPE_MAP by its #to_s
    # @raise [ArgumentError] if a non-Symbol name is not in DTYPE_MAP
    def convert_dtype(array, target_dtype)
      return array.astype(target_dtype) if target_dtype.is_a?(MLX::Core::Dtype)

      name =
        case target_dtype
        when Symbol then target_dtype
        else DTYPE_MAP[target_dtype.to_s]
        end
      raise ArgumentError, "Unknown dtype: #{target_dtype}" unless name

      array.astype(MLX::Core::Dtype.new(name))
    end

    # Total element count across all flattened model parameters.
    def count_parameters(model)
      MLX::Utils.tree_flatten(model.parameters).sum { |_path, tensor| tensor.size }
    end

    # Approximate parameter storage in bytes: element count times a per-dtype
    # width. Dtypes not listed below are counted as 4 bytes per element.
    def model_size_bytes(model)
      mx = MLX::Core
      MLX::Utils.tree_flatten(model.parameters).sum do |_path, tensor|
        width =
          case tensor.dtype
          when mx.float32 then 4
          when mx.float16, mx.bfloat16 then 2
          when mx.int32 then 4
          when mx.int8, mx.uint8 then 1
          when mx.int16, mx.uint16 then 2
          when mx.int64 then 8
          else 4
          end
        tensor.size * width
      end
    end
  end
end
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
module MlxLm
  # Response object yielded during streaming generation.
  #
  # Fields, as filled in by Generate.stream_generate:
  #   text              - text newly emitted by the detokenizer for this step (may be "")
  #   token             - token id produced at this step
  #   logprobs          - log-probabilities returned by generate_step for this step
  #   prompt_tokens     - number of prompt tokens
  #   prompt_tps        - prompt throughput, tokens/sec
  #   generation_tokens - tokens generated so far, including this one
  #   generation_tps    - generation throughput, tokens/sec
  #   peak_memory       - currently always 0.0 (not measured)
  #   finish_reason     - "stop" (EOS), "length" (max_tokens reached) or nil
  GenerationResponse = Struct.new(
    :text,
    :token,
    :logprobs,
    :prompt_tokens,
    :prompt_tps,
    :generation_tokens,
    :generation_tps,
    :peak_memory,
    :finish_reason,
    keyword_init: true
  )

  # Token-level and text-level generation loops.
  module Generate
    module_function

    # Produce token ids for `prompt` from `model`.
    #
    # NOTE: unlike a Python generator, the prompt prefill and the first forward
    # step run eagerly when this method is called. Only later decode steps are
    # lazy. The returned Enumerator yields `[token_id, logprobs]` exactly
    # `max_tokens` times. A negative max_tokens never matches the stop check,
    # so decoding is unbounded.
    #
    # @param prompt [MLX::Core::Array, Array<Integer>] prompt token ids
    # @param model [#call] invoked as `model.call(tokens_2d, cache: prompt_cache)`
    # @param max_tokens [Integer] number of tokens to yield
    # @param sampler [#call, nil] maps logprobs to a token; defaults to argmax
    # @param logits_processors [Array<#call>, nil] each called as
    #   `processor.call(tokens, logits)` and must return the new logits
    # @param max_kv_size [Integer, nil] forwarded to Cache.make_prompt_cache
    # @param prompt_cache [Array, nil] existing cache; created when nil
    # @param prefill_step_size [Integer] maximum prompt tokens per prefill chunk
    # @return [Enumerator] yielding [Integer, logprobs]
    # @raise [ArgumentError] if the prompt is empty or prefill_step_size < 1
    def generate_step(
      prompt,
      model,
      max_tokens: 256,
      sampler: nil,
      logits_processors: nil,
      max_kv_size: nil,
      prompt_cache: nil,
      prefill_step_size: 2048
    )
      mx = MLX::Core

      raise ArgumentError, "prompt must not be empty" if prompt.size == 0
      # A non-positive chunk size would never advance the prefill loop below.
      raise ArgumentError, "prefill_step_size must be at least 1" if prefill_step_size < 1

      # Token history handed to logits processors. It only accumulates tokens
      # passed through `step` (the last prompt token plus generated tokens),
      # not the prefilled chunks.
      # NOTE(review): matches the upstream Python port; confirm it is intended.
      tokens = nil

      prompt_cache ||= Cache.make_prompt_cache(model, max_kv_size: max_kv_size)
      sampler ||= ->(x) { mx.argmax(x, -1) }

      model_call = ->(input_tokens_2d) { model.call(input_tokens_2d, cache: prompt_cache) }

      # Run one forward pass and return [sampled_token, logprobs] for the last position.
      step = lambda do |input_tokens_1d|
        seq_len = input_tokens_1d.size
        logits = model_call.call(input_tokens_1d.reshape([1, seq_len]))

        # Keep only the final position's logits.
        last_dim = logits.shape[1]
        logits = mx.split(logits, [last_dim - 1], 1)[1] if last_dim > 1
        vocab_size = logits.shape[-1]
        logits = logits.reshape([1, vocab_size])

        if logits_processors && input_tokens_1d.size > 0
          tokens = tokens.nil? ? input_tokens_1d : mx.concatenate([tokens, input_tokens_1d], 0)
          logits_processors.each { |processor| logits = processor.call(tokens, logits) }
        end

        logprobs = logits - mx.logsumexp(logits, -1, true)
        [sampler.call(logprobs), logprobs.reshape([vocab_size])]
      end

      prompt_arr = prompt.is_a?(::Array) ? mx.array(prompt, dtype: mx.uint32) : prompt
      total_prompt_tokens = prompt_arr.size

      # Prefill every prompt token except the last, in chunks, forcing cache
      # state after each chunk. The last token goes through `step` so its
      # logits produce the first generated token.
      while total_prompt_tokens > 1
        n_to_process = [prefill_step_size, total_prompt_tokens - 1].min
        chunk, prompt_arr = mx.split(prompt_arr, [n_to_process], 0)
        model_call.call(chunk.reshape([1, chunk.size]))
        mx.eval(*prompt_cache.map(&:state).flatten.compact)
        total_prompt_tokens -= n_to_process
      end

      y, logprobs = step.call(prompt_arr)
      mx.eval(y, logprobs)

      Enumerator.new do |yielder|
        n = 0
        loop do
          break if n == max_tokens

          # Compute the next token before handing out the current one.
          y_1d = y.ndim > 1 ? y.reshape([y.size]) : y
          next_y, next_logprobs = step.call(y_1d)
          mx.eval(next_y, next_logprobs)

          yielder.yield [y.item, logprobs]
          y, logprobs = next_y, next_logprobs
          n += 1
        end
      end
    end

    # Stream text generation from the model.
    #
    # Yields one GenerationResponse per generated token. The final response has
    # finish_reason "stop" (EOS token, whose text is not rendered) or "length"
    # (the max_tokens-th token). In both cases the detokenizer is finalized
    # first, so any text it was still holding back is included.
    #
    # @param model [#call]
    # @param tokenizer [TokenizerWrapper, Object] wrapped if not already a TokenizerWrapper
    # @param prompt [String, Array<Integer>, MLX::Core::Array]
    # @param max_tokens [Integer]
    # @param kwargs forwarded to generate_step (sampler:, logits_processors:, ...)
    # @return [Enumerator<GenerationResponse>]
    def stream_generate(model, tokenizer, prompt, max_tokens: 256, **kwargs)
      tokenizer = TokenizerWrapper.new(tokenizer) unless tokenizer.is_a?(TokenizerWrapper)

      unless prompt.is_a?(MLX::Core::Array)
        prompt = tokenizer.encode(prompt) if prompt.is_a?(String)
        prompt = MLX::Core.array(prompt, dtype: MLX::Core.uint32)
      end

      detokenizer = tokenizer.detokenizer

      # generate_step prefills eagerly before returning its enumerator, so the
      # clock must start before the call or prefill is missing from prompt_tps.
      tic = Process.clock_gettime(Process::CLOCK_MONOTONIC)
      token_generator = generate_step(prompt, model, max_tokens: max_tokens, **kwargs)
      prompt_tps = 0.0

      Enumerator.new do |yielder|
        token_generator.each_with_index do |(token, logprobs), n|
          if n.zero?
            now = Process.clock_gettime(Process::CLOCK_MONOTONIC)
            prompt_tps = prompt.size.to_f / [now - tic, 1e-9].max
            tic = now
          end

          eos = tokenizer.eos_token_ids.include?(token)
          at_limit = (n + 1) == max_tokens

          detokenizer.add_token(token) unless eos
          # Flush held-back text on whichever step ends the stream.
          detokenizer.finalize if eos || at_limit

          elapsed = [Process.clock_gettime(Process::CLOCK_MONOTONIC) - tic, 1e-9].max
          finish_reason =
            if eos
              "stop"
            elsif at_limit
              "length"
            end

          yielder.yield GenerationResponse.new(
            text: detokenizer.last_segment,
            token: token,
            logprobs: logprobs,
            prompt_tokens: prompt.size,
            prompt_tps: prompt_tps,
            generation_tokens: n + 1,
            generation_tps: (n + 1).to_f / elapsed,
            peak_memory: 0.0,
            finish_reason: finish_reason
          )
          break if finish_reason
        end
      end
    end

    # Non-streaming generation; returns the complete generated text.
    #
    # With verbose: true, the text is echoed as it streams, followed by
    # prompt/generation token counts and throughput.
    #
    # @param kwargs forwarded to stream_generate (max_tokens:, sampler:, ...)
    # @return [String]
    def generate(model, tokenizer, prompt, verbose: false, **kwargs)
      text = +""
      response = nil
      stream_generate(model, tokenizer, prompt, **kwargs).each do |resp|
        text << resp.text
        response = resp
        next unless verbose

        print resp.text
        $stdout.flush
      end

      if verbose
        puts
        puts "=" * 10
        if text.empty?
          puts "No text generated for this prompt"
          return text
        end
        puts "Prompt: #{response.prompt_tokens} tokens, #{'%.3f' % response.prompt_tps} tokens-per-sec"
        puts "Generation: #{response.generation_tokens} tokens, #{'%.3f' % response.generation_tps} tokens-per-sec"
      end
      text
    end
  end
end
|