mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
checksums.yaml ADDED
@@ -0,0 +1,7 @@
+ ---
+ SHA256:
+   metadata.gz: 9a305e7fa3e70c61d29f8e997a98d4f7da6ff344d2df698778f56980dae27a34
+   data.tar.gz: f81553b3050391f1585f5ba6822c1cd28936f5ddaccc0b852bcfae30ea41bdf4
+ SHA512:
+   metadata.gz: d5e6be331e95c5323b7fdb93b8defb2242a5ad969269e8e5442cc8f499d1fb2beabe03f81d14e7ee8abc3c61bf5d144ea5755bda7462cfdb9bcddb0afe9695f0
+   data.tar.gz: b333ba48c0e390b217ec4320e149f03b1e507caefe856701bf85245fa301b55d0fec2d75ecc2bdd92e367b49e3c82454b1b0baffe1d9c161223fb9580635a20e
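The digests above cover the gem's `metadata.gz` and `data.tar.gz` archives. As a quick illustration of checking them locally — a sketch, assuming the downloaded `.gem` file has already been unpacked (e.g. with `tar -xf`) so the two archives sit in the working directory:

```ruby
# Recompute the SHA256 digests listed in checksums.yaml.
# Paths are assumptions: metadata.gz and data.tar.gz extracted
# from the .gem archive into the current directory.
require "digest"

%w[metadata.gz data.tar.gz].each do |name|
  printf("%-12s %s\n", name, Digest::SHA256.file(name).hexdigest)
end
```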
data/LICENSE.txt ADDED
@@ -0,0 +1,21 @@
+ The MIT License (MIT)
+
+ Copyright (c) 2025 Alex Skryl
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,83 @@
+ # mlx-ruby-lm
+
+ Ruby LLM inference toolkit built on the `mlx` gem.
+
+ ## Index
+
+ - [Documentation Index](docs/index.md)
+ - [Installation](docs/installation.md)
+ - [CLI Usage](docs/cli.md)
+ - [Ruby APIs](docs/ruby-apis.md)
+ - [Models](docs/models.md)
+
+ For full reference pages and deep dives, start at [docs/index.md](docs/index.md).
+
+ ## Installation
+
+ ```bash
+ gem install mlx-ruby-lm
+ ```
+
+ Or add it to a project:
+
+ ```bash
+ bundle add mlx-ruby-lm
+ ```
+
+ See [docs/installation.md](docs/installation.md) for requirements and source installs.
+
+ ## CLI Usage
+
+ Executable: `mlx_lm`
+
+ Commands:
+
+ - `mlx_lm generate`
+ - `mlx_lm chat`
+ - `mlx_lm server`
+
+ Quick examples:
+
+ ```bash
+ mlx_lm generate --model /path/to/model --prompt "Hello"
+ mlx_lm chat --model /path/to/model --system-prompt "You are concise."
+ mlx_lm server --model /path/to/model --host 127.0.0.1 --port 8080
+ ```
+
+ See [docs/cli.md](docs/cli.md) for options, defaults, and current parser/behavior caveats.
+
+ ## High-Level Ruby API Usage
+
+ ```ruby
+ require "mlx"
+ require "mlx_lm"
+
+ model, tokenizer = MlxLm::LoadUtils.load("/path/to/model")
+ text = MlxLm::Generate.generate(model, tokenizer, "Hello", max_tokens: 64)
+ puts text
+ ```
+
+ Streaming:
+
+ ```ruby
+ MlxLm::Generate.stream_generate(model, tokenizer, "Hello", max_tokens: 64).each do |resp|
+   print resp.text
+ end
+ puts
+ ```
+
+ See [docs/ruby-apis.md](docs/ruby-apis.md) for the full API inventory.
+
+ ## High-Level Model Usage
+
+ `LoadUtils.load` expects a local model directory with files such as `config.json`,
+ `tokenizer.json`, and `model*.safetensors`.
+
+ To inspect supported model keys at runtime:
+
+ ```ruby
+ require "mlx_lm"
+ puts MlxLm::Models::REGISTRY.keys.sort
+ ```
+
+ See [docs/models.md](docs/models.md) for full registry keys and remapping behavior.
data/exe/mlx_lm ADDED
@@ -0,0 +1,7 @@
+ #!/usr/bin/env ruby
+ # frozen_string_literal: true
+
+ require "mlx"
+ require "mlx_lm"
+
+ MlxLm::CLI.run(ARGV)
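The executable is a thin wrapper around `MlxLm::CLI.run`, so the same entry point can be driven from Ruby. A minimal sketch (the model path is a placeholder):

```ruby
# Invoke the entry point the mlx_lm executable uses, passing an
# argv-style array. /path/to/model is a placeholder, not a real model.
require "mlx"
require "mlx_lm"

MlxLm::CLI.run(%w[generate --model /path/to/model --prompt Hello --max-tokens 32])
```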
data/lib/mlx_lm/benchmark.rb ADDED
@@ -0,0 +1,67 @@
+ module MlxLm
+   module Benchmark
+     module_function
+
+     # Measure generation performance (tokens/sec).
+     def measure_generation(model, prompt_tokens: 32, gen_tokens: 64, vocab_size: 32000)
+       mx = MLX::Core
+
+       # Create random prompt tokens
+       prompt = mx.random_uniform([prompt_tokens], 0.0, (vocab_size - 1).to_f, mx.float32).astype(mx.int32)
+       mx.eval(prompt)
+
+       # Create cache
+       cache = Cache.make_prompt_cache(model)
+
+       # Measure prompt processing
+       prompt_input = prompt.reshape([1, prompt_tokens])
+       prompt_start = Process.clock_gettime(Process::CLOCK_MONOTONIC)
+       logits = model.call(prompt_input, cache: cache)
+       mx.eval(logits)
+       mx.eval(*cache.map(&:state).flatten.compact)
+       prompt_elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - prompt_start
+       prompt_tps = prompt_tokens.to_f / [prompt_elapsed, 1e-9].max
+
+       # Get first generated token
+       last_logits = logits.reshape([prompt_tokens, logits.shape[-1]])
+       # Take last position
+       last_pos = mx.split(last_logits, [prompt_tokens - 1], 0)[1]
+       y = mx.argmax(last_pos, -1)
+       mx.eval(y)
+
+       # Measure generation
+       gen_start = Process.clock_gettime(Process::CLOCK_MONOTONIC)
+       gen_tokens.times do
+         y_input = y.reshape([1, 1])
+         logits = model.call(y_input, cache: cache)
+         mx.eval(logits)
+         mx.eval(*cache.map(&:state).flatten.compact)
+         y = mx.argmax(logits.reshape([1, logits.shape[-1]]), -1)
+         mx.eval(y)
+       end
+       gen_elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - gen_start
+       gen_tps = gen_tokens.to_f / [gen_elapsed, 1e-9].max
+
+       {
+         prompt_tokens: prompt_tokens,
+         prompt_time: prompt_elapsed,
+         prompt_tps: prompt_tps,
+         generation_tokens: gen_tokens,
+         generation_time: gen_elapsed,
+         generation_tps: gen_tps,
+       }
+     end
+
+     # Get model statistics (parameter count, etc.)
+     def model_stats(model)
+       params = MLX::Utils.tree_flatten(model.parameters)
+       total = 0
+       params.each { |_, v| total += v.size }
+
+       {
+         total_params: total,
+         num_layers: model.respond_to?(:layers) ? model.layers.length : 0,
+       }
+     end
+   end
+ end
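`Benchmark.measure_generation` returns a plain hash, so reporting is one formatting call away. A sketch, assuming a model loaded via `LoadUtils.load` (the path is a placeholder):

```ruby
require "mlx"
require "mlx_lm"

# Load a model and time prompt processing vs. token generation.
model, _tokenizer = MlxLm::LoadUtils.load("/path/to/model")
stats = MlxLm::Benchmark.measure_generation(model, prompt_tokens: 32, gen_tokens: 64)

puts format("prompt: %.1f tok/s, generation: %.1f tok/s",
            stats[:prompt_tps], stats[:generation_tps])
puts "total params: #{MlxLm::Benchmark.model_stats(model)[:total_params]}"
```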
data/lib/mlx_lm/chat_template.rb ADDED
@@ -0,0 +1,41 @@
+ module MlxLm
+   module ChatTemplate
+     module_function
+
+     # Apply a simple chat template to format messages into a prompt string.
+     # This is a default/fallback template. Model-specific templates (like
+     # Jinja-based ones from tokenizer_config.json) can override this.
+     def apply(messages, template: :default)
+       case template
+       when :default
+         apply_default(messages)
+       when :chatml
+         apply_chatml(messages)
+       else
+         apply_default(messages)
+       end
+     end
+
+     # Default template: ChatML-like format
+     # <|im_start|>system
+     # content<|im_end|>
+     # <|im_start|>user
+     # content<|im_end|>
+     # <|im_start|>assistant
+     def apply_default(messages)
+       parts = []
+       messages.each do |msg|
+         role = msg["role"] || msg[:role]
+         content = msg["content"] || msg[:content]
+         parts << "<|im_start|>#{role}\n#{content}<|im_end|>"
+       end
+       parts << "<|im_start|>assistant"
+       parts.join("\n")
+     end
+
+     # ChatML template (same as default, widely used)
+     def apply_chatml(messages)
+       apply_default(messages)
+     end
+   end
+ end
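For reference, the prompt `apply_default` builds from a typical message list (the roles and contents here are illustrative):

```ruby
require "mlx_lm"

messages = [
  { "role" => "system", "content" => "You are concise." },
  { "role" => "user",   "content" => "Hello" },
]
puts MlxLm::ChatTemplate.apply(messages)
# Output:
# <|im_start|>system
# You are concise.<|im_end|>
# <|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant
```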
data/lib/mlx_lm/cli.rb ADDED
@@ -0,0 +1,113 @@
+ require "optparse"
+
+ module MlxLm
+   module CLI
+     COMMANDS = %w[generate chat server].freeze
+
+     module_function
+
+     def parse_args(argv)
+       command = argv.shift
+       unless COMMANDS.include?(command)
+         raise ArgumentError, "Unknown command '#{command}'. Valid commands: #{COMMANDS.join(', ')}"
+       end
+
+       args = default_args.merge(command: command)
+
+       parser = OptionParser.new do |opts|
+         opts.banner = "Usage: mlx_lm #{command} [options]"
+
+         opts.on("--model MODEL", "Model path or HuggingFace ID") { |v| args[:model] = v }
+         opts.on("--prompt PROMPT", "Input prompt") { |v| args[:prompt] = v }
+         opts.on("--max-tokens N", Integer, "Maximum tokens to generate") { |v| args[:max_tokens] = v }
+         opts.on("--temp TEMP", Float, "Sampling temperature") { |v| args[:temp] = v }
+         opts.on("--top-p P", Float, "Top-p (nucleus) sampling") { |v| args[:top_p] = v }
+         opts.on("--seed N", Integer, "Random seed") { |v| args[:seed] = v }
+         opts.on("--repetition-penalty F", Float, "Repetition penalty") { |v| args[:repetition_penalty] = v }
+         opts.on("--repetition-context-size N", Integer, "Repetition context size") { |v| args[:repetition_context_size] = v }
+         opts.on("--host HOST", "Server host") { |v| args[:host] = v }
+         opts.on("--port PORT", Integer, "Server port") { |v| args[:port] = v }
+         opts.on("--system-prompt PROMPT", "System prompt for chat") { |v| args[:system_prompt] = v }
+         opts.on("--verbose", "Verbose output") { args[:verbose] = true }
+       end
+
+       parser.parse!(argv)
+       args
+     end
+
+     def default_args
+       {
+         command: nil,
+         model: nil,
+         prompt: "",
+         max_tokens: 256,
+         temp: 0.0,
+         top_p: 1.0,
+         seed: nil,
+         repetition_penalty: nil,
+         repetition_context_size: 20,
+         host: "127.0.0.1",
+         port: 8080,
+         system_prompt: nil,
+         verbose: false,
+       }
+     end
+
+     def run(argv = ARGV)
+       args = parse_args(argv.dup)
+
+       case args[:command]
+       when "generate"
+         run_generate(args)
+       when "chat"
+         run_chat(args)
+       when "server"
+         run_server(args)
+       end
+     end
+
+     def run_generate(args)
+       model, tokenizer = LoadUtils.load(args[:model])
+       sampler = SampleUtils.make_sampler(temp: args[:temp], top_p: args[:top_p])
+       text = Generate.generate(model, tokenizer, args[:prompt],
+         max_tokens: args[:max_tokens], sampler: sampler, verbose: args[:verbose])
+       puts text unless args[:verbose]
+     end
+
+     def run_chat(args)
+       model, tokenizer = LoadUtils.load(args[:model])
+       messages = []
+       if args[:system_prompt]
+         messages << { "role" => "system", "content" => args[:system_prompt] }
+       end
+
+       loop do
+         print "> "
+         $stdout.flush
+         input = $stdin.gets
+         break if input.nil?
+         input = input.strip
+         break if input.empty? || input == "exit" || input == "quit"
+
+         messages << { "role" => "user", "content" => input }
+         prompt = ChatTemplate.apply(messages)
+
+         sampler = SampleUtils.make_sampler(temp: args[:temp])
+         text = ""
+         Generate.stream_generate(model, tokenizer, prompt,
+           max_tokens: args[:max_tokens], sampler: sampler).each do |resp|
+           print resp.text
+           $stdout.flush
+           text += resp.text
+         end
+         puts
+
+         messages << { "role" => "assistant", "content" => text }
+       end
+     end
+
+     def run_server(args)
+       Server.start(model_path: args[:model], host: args[:host], port: args[:port])
+     end
+   end
+ end
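`parse_args` consumes the leading subcommand, then merges parsed flags over `default_args`, so anything not passed keeps its default. A sketch of the resulting hash:

```ruby
require "mlx_lm"

args = MlxLm::CLI.parse_args(%w[generate --model /path/to/model --temp 0.7])
args[:command]    # => "generate"
args[:temp]       # => 0.7
args[:max_tokens] # => 256 (from default_args)
args[:top_p]      # => 1.0 (from default_args)
```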
data/lib/mlx_lm/config.rb ADDED
@@ -0,0 +1,30 @@
+ require "json"
+
+ module MlxLm
+   module Config
+     module_function
+
+     # Load model config from a directory containing config.json
+     # and optionally generation_config.json.
+     # Mirrors Python mlx_lm.utils.load_config
+     def load(model_path)
+       config_path = File.join(model_path, "config.json")
+       config = JSON.parse(File.read(config_path))
+
+       gen_config_path = File.join(model_path, "generation_config.json")
+       if File.exist?(gen_config_path)
+         begin
+           gen_config = JSON.parse(File.read(gen_config_path))
+         rescue JSON::ParserError
+           gen_config = {}
+         end
+
+         if (eos = gen_config["eos_token_id"])
+           config["eos_token_id"] = eos
+         end
+       end
+
+       config
+     end
+   end
+ end
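`Config.load` returns the parsed `config.json` hash, with `eos_token_id` overridden from `generation_config.json` when that file exists and parses. A sketch (the directory and the `model_type` value are placeholders):

```ruby
require "mlx_lm"

config = MlxLm::Config.load("/path/to/model")
config["model_type"]   # e.g. "llama" -- whatever config.json declares
config["eos_token_id"] # value from generation_config.json, if present
```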
data/lib/mlx_lm/convert_utils.rb ADDED
@@ -0,0 +1,51 @@
+ module MlxLm
+   module ConvertUtils
+     DTYPE_MAP = {
+       "float32" => :float32,
+       "float16" => :float16,
+       "bfloat16" => :bfloat16,
+       "int8" => :int8,
+       "int32" => :int32,
+     }.freeze
+
+     module_function
+
+     # Convert an MLX array to a different dtype.
+     def convert_dtype(array, target_dtype)
+       if target_dtype.is_a?(MLX::Core::Dtype)
+         return array.astype(target_dtype)
+       end
+       dtype_sym = target_dtype.is_a?(Symbol) ? target_dtype : DTYPE_MAP[target_dtype.to_s]
+       raise ArgumentError, "Unknown dtype: #{target_dtype}" unless dtype_sym
+       array.astype(MLX::Core::Dtype.new(dtype_sym))
+     end
+
+     # Count total number of parameters in a model.
+     def count_parameters(model)
+       params = MLX::Utils.tree_flatten(model.parameters)
+       total = 0
+       params.each { |_, v| total += v.size }
+       total
+     end
+
+     # Estimate total model size in bytes.
+     def model_size_bytes(model)
+       mx = MLX::Core
+       params = MLX::Utils.tree_flatten(model.parameters)
+       total = 0
+       params.each do |_, v|
+         bytes_per_elem = case v.dtype
+                          when mx.float32 then 4
+                          when mx.float16, mx.bfloat16 then 2
+                          when mx.int32 then 4
+                          when mx.int8, mx.uint8 then 1
+                          when mx.int16, mx.uint16 then 2
+                          when mx.int64 then 8
+                          else 4
+                          end
+         total += v.size * bytes_per_elem
+       end
+       total
+     end
+   end
+ end
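Combining the two helpers gives a rough memory footprint for a loaded model. A sketch, assuming `LoadUtils.load` and a placeholder path:

```ruby
require "mlx"
require "mlx_lm"

model, _tokenizer = MlxLm::LoadUtils.load("/path/to/model")
params = MlxLm::ConvertUtils.count_parameters(model)
bytes  = MlxLm::ConvertUtils.model_size_bytes(model)
puts format("%.2fB params, ~%.2f GiB", params / 1e9, bytes / 1024.0**3)
```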
data/lib/mlx_lm/generate.rb ADDED
@@ -0,0 +1,204 @@
+ module MlxLm
+   # Response object yielded during streaming generation
+   GenerationResponse = Struct.new(
+     :text,
+     :token,
+     :logprobs,
+     :prompt_tokens,
+     :prompt_tps,
+     :generation_tokens,
+     :generation_tps,
+     :peak_memory,
+     :finish_reason,
+     keyword_init: true
+   )
+
+   module Generate
+     module_function
+
+     # A generator producing token ids based on the given prompt from the model.
+     # Yields [token_id, logprobs] for each generated token.
+     def generate_step(
+       prompt,
+       model,
+       max_tokens: 256,
+       sampler: nil,
+       logits_processors: nil,
+       max_kv_size: nil,
+       prompt_cache: nil,
+       prefill_step_size: 2048
+     )
+       mx = MLX::Core
+
+       raise ArgumentError, "prompt must not be empty" if prompt.size == 0
+
+       tokens = nil
+
+       # Create the KV cache for generation
+       prompt_cache ||= Cache.make_prompt_cache(model, max_kv_size: max_kv_size)
+
+       sampler ||= ->(x) { mx.argmax(x, -1) }
+
+       model_call = ->(input_tokens_2d) {
+         model.call(input_tokens_2d, cache: prompt_cache)
+       }
+
+       step = ->(input_tokens_1d) {
+         seq_len = input_tokens_1d.size
+         input_2d = input_tokens_1d.reshape([1, seq_len])
+         logits = model_call.call(input_2d)
+
+         # Take the last token's logits
+         last_dim = logits.shape[1]
+         if last_dim > 1
+           logits = mx.split(logits, [last_dim - 1], 1)[1]
+         end
+         vocab_size = logits.shape[-1]
+         logits = logits.reshape([1, vocab_size])
+
+         if logits_processors && input_tokens_1d.size > 0
+           tokens = if tokens.nil?
+             input_tokens_1d
+           else
+             mx.concatenate([tokens, input_tokens_1d], 0)
+           end
+           logits_processors.each { |processor| logits = processor.call(tokens, logits) }
+         end
+
+         logprobs = logits - mx.logsumexp(logits, -1, true)
+         sampled = sampler.call(logprobs)
+         [sampled, logprobs.reshape([vocab_size])]
+       }
+
+       # Prompt prefilling - process prompt in chunks
+       prompt_arr = prompt.is_a?(::Array) ? mx.array(prompt, dtype: mx.uint32) : prompt
+       total_prompt_tokens = prompt_arr.size
+
+       # Process prompt chunks (all but last token)
+       while total_prompt_tokens > 1
+         remaining = total_prompt_tokens - 1
+         n_to_process = [prefill_step_size, remaining].min
+         chunk = mx.split(prompt_arr, [n_to_process], 0)[0]
+         chunk_len = chunk.size
+         model_call.call(chunk.reshape([1, chunk_len]))
+         mx.eval(*prompt_cache.map(&:state).flatten.compact)
+         prompt_arr = mx.split(prompt_arr, [n_to_process], 0)[1]
+         total_prompt_tokens -= n_to_process
+       end
+
+       # Process last token and get first generated token
+       y, logprobs = step.call(prompt_arr)
+       mx.eval(y, logprobs)
+
+       Enumerator.new do |yielder|
+         n = 0
+         loop do
+           break if n == max_tokens
+
+           y_1d = y.ndim > 1 ? y.reshape([y.size]) : y
+           next_y, next_logprobs = step.call(y_1d)
+           mx.eval(next_y, next_logprobs)
+
+           yielder.yield [y.item, logprobs]
+           y, logprobs = next_y, next_logprobs
+           n += 1
+         end
+       end
+     end
+
+     # Stream text generation from the model.
+     # Yields GenerationResponse objects with text segments.
+     def stream_generate(model, tokenizer, prompt, max_tokens: 256, **kwargs)
+       tokenizer = TokenizerWrapper.new(tokenizer) unless tokenizer.is_a?(TokenizerWrapper)
+
+       unless prompt.is_a?(MLX::Core::Array)
+         if prompt.is_a?(String)
+           prompt = tokenizer.encode(prompt)
+         end
+         prompt = MLX::Core.array(prompt, dtype: MLX::Core.uint32)
+       end
+
+       detokenizer = tokenizer.detokenizer
+
+       token_generator = generate_step(prompt, model, max_tokens: max_tokens, **kwargs)
+
+       tic = Process.clock_gettime(Process::CLOCK_MONOTONIC)
+       prompt_tps = 0.0
+
+       Enumerator.new do |yielder|
+         n = 0
+         last_token = nil
+         token_generator.each do |token, logprobs|
+           if n == 0
+             prompt_time = Process.clock_gettime(Process::CLOCK_MONOTONIC) - tic
+             prompt_tps = prompt.size.to_f / [prompt_time, 1e-9].max
+             tic = Process.clock_gettime(Process::CLOCK_MONOTONIC)
+           end
+
+           last_token = token
+
+           if tokenizer.eos_token_ids.include?(token)
+             detokenizer.finalize
+             elapsed = [Process.clock_gettime(Process::CLOCK_MONOTONIC) - tic, 1e-9].max
+             yielder.yield GenerationResponse.new(
+               text: detokenizer.last_segment,
+               token: token,
+               logprobs: logprobs,
+               prompt_tokens: prompt.size,
+               prompt_tps: prompt_tps,
+               generation_tokens: n + 1,
+               generation_tps: (n + 1).to_f / elapsed,
+               peak_memory: 0.0,
+               finish_reason: "stop"
+             )
+             break
+           end
+
+           detokenizer.add_token(token)
+           elapsed = [Process.clock_gettime(Process::CLOCK_MONOTONIC) - tic, 1e-9].max
+
+           yielder.yield GenerationResponse.new(
+             text: detokenizer.last_segment,
+             token: token,
+             logprobs: logprobs,
+             prompt_tokens: prompt.size,
+             prompt_tps: prompt_tps,
+             generation_tokens: n + 1,
+             generation_tps: (n + 1).to_f / elapsed,
+             peak_memory: 0.0,
+             finish_reason: ((n + 1) == max_tokens ? "length" : nil)
+           )
+
+           n += 1
+           break if n == max_tokens
+         end
+       end
+     end
+
+     # Non-streaming generation, returns complete text.
+     def generate(model, tokenizer, prompt, verbose: false, **kwargs)
+       text = ""
+       response = nil
+       stream_generate(model, tokenizer, prompt, **kwargs).each do |resp|
+         text += resp.text
+         response = resp
+         if verbose
+           print resp.text
+           $stdout.flush
+         end
+       end
+
+       if verbose
+         puts
+         puts "=" * 10
+         if text.empty?
+           puts "No text generated for this prompt"
+           return text
+         end
+         puts "Prompt: #{response.prompt_tokens} tokens, #{'%.3f' % response.prompt_tps} tokens-per-sec"
+         puts "Generation: #{response.generation_tokens} tokens, #{'%.3f' % response.generation_tps} tokens-per-sec"
+       end
+       text
+     end
+   end
+ end
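`generate_step` is the low-level loop underneath `stream_generate`: it yields raw `[token_id, logprobs]` pairs and leaves detokenization to the caller. A sketch with an explicit sampler (the model path is a placeholder; `tokenizer.encode` mirrors the call `stream_generate` makes, and `SampleUtils.make_sampler` is the factory used by `cli.rb` above):

```ruby
require "mlx"
require "mlx_lm"

model, tokenizer = MlxLm::LoadUtils.load("/path/to/model")
sampler = MlxLm::SampleUtils.make_sampler(temp: 0.7, top_p: 0.9)

tokens = []
prompt = tokenizer.encode("Hello")
MlxLm::Generate.generate_step(prompt, model, max_tokens: 16, sampler: sampler)
  .each { |token, _logprobs| tokens << token }

# tokens holds raw ids; stream_generate (see the README) wraps this same
# loop with the detokenizer to produce text segments.
p tokens
```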