mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138) hide show
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
@@ -0,0 +1,87 @@
1
+ require "json"
2
+
3
module MlxLm
  # Helpers for loading a model, its weights, and its tokenizer from a
  # local directory laid out in the Hugging Face style (config.json,
  # *.safetensors shards, tokenizer.json, tokenizer_config.json).
  module LoadUtils
    module_function

    # Load a model and tokenizer from a local directory.
    #
    # @param model_path [String] Path to the model directory
    # @param tokenizer_config [Hash, nil] Additional tokenizer config overrides,
    #   merged over the on-disk tokenizer_config.json (string keys expected)
    # @return [Array(nn::Module, TokenizerWrapper)] The loaded model and tokenizer
    def load(model_path, tokenizer_config: nil)
      # The config is already folded into the model; callers only need the
      # model/tokenizer pair here.
      model, _config = load_model(model_path)
      tokenizer = load_tokenizer(model_path, tokenizer_config: tokenizer_config)
      [model, tokenizer]
    end

    # Load model from a local directory containing config.json and safetensors.
    #
    # @param model_path [String] Path to the model directory
    # @return [Array(nn::Module, Hash)] The loaded model and config
    def load_model(model_path)
      config = Config.load(model_path)

      # Get model and args classes from registry
      model_class, args_class = Models.get_classes(config)

      # Instantiate model args from config
      model_args = args_class.from_dict(config)

      # Create model
      model = model_class.new(model_args)

      # Load weights
      weights = WeightUtils.load_sharded_safetensors(model_path)

      # Apply model-specific weight sanitization (key renames etc.)
      weights = model.sanitize(weights) if model.respond_to?(:sanitize)

      # Apply quantization if the config specifies it. This must happen
      # BEFORE load_weights so the module tree matches the quantized layout.
      quantization = config["quantization"]
      if quantization
        group_size = quantization["group_size"] || 64
        bits = quantization["bits"] || 4
        Quantize.quantize_model(model, group_size: group_size, bits: bits, weights: weights)
      end

      # Load weights into model. strict: false tolerates keys pruned by
      # sanitize or added by quantization.
      model.load_weights(weights, strict: false)

      [model, config]
    end

    # Load tokenizer from a local directory.
    #
    # @param model_path [String] Path containing tokenizer files
    # @param tokenizer_config [Hash, nil] Overrides merged over
    #   tokenizer_config.json (string keys expected)
    # @return [TokenizerWrapper] The loaded tokenizer
    # @raise [RuntimeError] If tokenizer.json is missing
    def load_tokenizer(model_path, tokenizer_config: nil)
      tokenizer_path = File.join(model_path, "tokenizer.json")
      raise "Tokenizer not found at #{tokenizer_path}" unless File.exist?(tokenizer_path)

      tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_path)

      # Determine the EOS token from tokenizer_config.json plus any
      # caller-supplied overrides.
      config_path = File.join(model_path, "tokenizer_config.json")
      tc = File.exist?(config_path) ? JSON.parse(File.read(config_path)) : {}
      tc = tc.merge(tokenizer_config) if tokenizer_config

      eos_token = tc["eos_token"]
      # HF configs sometimes store the token as {"content" => "...", ...}
      eos_token = eos_token["content"] if eos_token.is_a?(Hash)

      # Try to get eos_token_id from config.json; normalize to an Array
      # (possibly empty when the key is absent or nil).
      eos_token_id = nil
      model_config_path = File.join(model_path, "config.json")
      if File.exist?(model_config_path)
        mc = JSON.parse(File.read(model_config_path))
        eos_token_id = mc["eos_token_id"]
        eos_token_id = eos_token_id.is_a?(::Array) ? eos_token_id : [eos_token_id].compact
      end

      TokenizerWrapper.new(tokenizer, eos_token: eos_token, eos_token_id: eos_token_id)
    end
  end
end
@@ -0,0 +1,54 @@
1
module MlxLm
  # Base class for model configuration arguments.
  # Mirrors Python's BaseModelArgs dataclass with from_dict filtering:
  # unknown config keys are silently dropped, and declared fields pick up
  # their defaults when absent.
  class BaseModelArgs
    # Registry of declared fields: { name_symbol => default or :__required__ }.
    def self.fields
      @fields ||= {}
    end

    # Declare a config field with an optional default. Fields declared
    # without a default are required (no value is auto-filled for them).
    def self.field(name, default: :__required__)
      fields[name] = default

      attr_accessor name

      # Expose the default for introspection when one was given.
      define_method(:"default_#{name}") { default } unless default == :__required__
    end

    def self.inherited(subclass)
      super
      # Each subclass starts with a copy of its parent's field registry so
      # later field declarations don't leak across the hierarchy.
      subclass.instance_variable_set(:@fields, fields.dup)
    end

    # Build an instance from a config hash. Accepts string OR symbol keys
    # (parsed JSON uses strings; programmatic configs often use symbols);
    # string keys win when both are present. Keys that don't correspond to
    # a declared field are ignored.
    def self.from_dict(params)
      known = {}
      fields.each do |name, default|
        if params.key?(name.to_s)
          known[name] = params[name.to_s]
        elsif params.key?(name)
          known[name] = params[name]
        elsif default != :__required__
          known[name] = default
        end
      end
      new(**known)
    end

    # Set declared fields from kwargs, then fill defaults for anything
    # the caller omitted. Unknown kwargs are silently dropped.
    def initialize(**kwargs)
      kwargs.each do |k, v|
        instance_variable_set(:"@#{k}", v) if self.class.fields.key?(k)
      end
      self.class.fields.each do |name, default|
        next if kwargs.key?(name)
        next if default == :__required__
        instance_variable_set(:"@#{name}", default)
      end
    end
  end
end
@@ -0,0 +1,46 @@
1
module MlxLm
  module Models
    # Activation functions shared across the model implementations.
    module Activations
      module_function

      # SwiGLU gating: silu(gate) multiplied elementwise with x.
      def swiglu(gate, x)
        activated = MLX::NN.silu(gate)
        activated * x
      end

      # xIELU activation. alpha_p / alpha_n are pre-softplus parameters;
      # beta and eps are scalars (eps is expected negative — it caps the
      # argument of expm1 on the negative branch).
      def xielu(x, alpha_p, alpha_n, beta, eps)
        mx = MLX::Core
        pos_alpha = MLX::NN.softplus(alpha_p)
        neg_alpha = beta + MLX::NN.softplus(alpha_n)

        positive_branch = pos_alpha * mx.square(x) + beta * x
        negative_branch = (mx.expm1(mx.minimum(x, eps)) - x) * neg_alpha + beta * x

        mx.where(mx.greater(x, 0.0), positive_branch, negative_branch)
      end

      # Module wrapper around xielu with learnable parameters.
      class XieLU < MLX::NN::Module
        def initialize(alpha_p_init: 0.8,
                       alpha_n_init: 0.8,
                       beta: 0.5,
                       eps: -1e-6)
          super()
          mx = MLX::Core
          # Parameters are stored pre-softplus: log(exp(v) - 1) is the
          # softplus inverse of the requested init value.
          self.alpha_p = mx.log(mx.exp(mx.array(alpha_p_init)) - 1.0)
          self.alpha_n = mx.log(mx.exp(mx.array(alpha_n_init - beta)) - 1.0)
          self.beta = mx.array(beta)
          self.eps = mx.array(eps)
        end

        def call(x)
          Activations.xielu(x, alpha_p, alpha_n, beta, eps)
        end
      end
    end
  end
end
@@ -0,0 +1,131 @@
1
+ require_relative "afmoe"
2
+
3
module MlxLm
  module Models
    # afm7: an adapter over the Afmoe implementation. afm7 configs use
    # alternative key names (hidden_dim, num_layers, ...) which are remapped
    # onto the Afmoe schema; the forward pass delegates to a wrapped
    # Afmoe::Model.
    module Afm7
      class ModelArgs < Afmoe::ModelArgs
        field :model_type, default: "afm7"
        field :hidden_dim, default: nil
        field :num_layers, default: nil
        field :num_kv_reuse_layers, default: 0
        field :num_heads, default: nil
        field :num_kv_heads, default: nil
        field :hidden_dim_scale_factor, default: nil

        def initialize(**kwargs)
          # Detect afm7-style keys before super so the check sees the raw
          # incoming kwargs.
          afm7_style = _afm7_style_kwargs?(kwargs)
          super

          # Remap afm7 aliases onto the Afmoe field names. An alias wins
          # only when it was explicitly passed and is non-nil.
          @hidden_size = @hidden_dim if kwargs.key?(:hidden_dim) && !@hidden_dim.nil?
          @num_hidden_layers = @num_layers if kwargs.key?(:num_layers) && !@num_layers.nil?
          @num_attention_heads = @num_heads if kwargs.key?(:num_heads) && !@num_heads.nil?
          @num_key_value_heads = @num_kv_heads if kwargs.key?(:num_kv_heads) && !@num_kv_heads.nil?

          # intermediate_size can be derived from a scale factor on the
          # hidden size.
          if kwargs.key?(:hidden_dim_scale_factor) && !@hidden_dim_scale_factor.nil? && !@hidden_size.nil?
            @intermediate_size = (@hidden_size * @hidden_dim_scale_factor.to_f).to_i
          end

          # Integer division, matching head_dim = hidden / heads.
          if !@hidden_size.nil? && !@num_attention_heads.nil? && @num_attention_heads.to_i > 0
            @head_dim = @hidden_size / @num_attention_heads
          end

          # KV-reuse layers are subtracted from the dense-layer count
          # (clamped at zero); otherwise afm7 configs treat every layer
          # as dense.
          if kwargs.key?(:num_kv_reuse_layers) && !@num_hidden_layers.nil?
            @num_dense_layers = [@num_hidden_layers.to_i - @num_kv_reuse_layers.to_i, 0].max
          elsif afm7_style && !@num_hidden_layers.nil?
            @num_dense_layers = @num_hidden_layers
          end

          # For afm7-style configs, default to single-expert, full-attention
          # settings unless those keys were given explicitly.
          if afm7_style
            @num_experts = 1 unless kwargs.key?(:num_experts)
            @num_experts_per_tok = 1 unless kwargs.key?(:num_experts_per_tok)
            @num_shared_experts = 0 unless kwargs.key?(:num_shared_experts)
            @mup_enabled = false unless kwargs.key?(:mup_enabled)
            @layer_types = Array.new(@num_hidden_layers) { "full_attention" } unless kwargs.key?(:layer_types)
          end

          # Final fallbacks shared by both config styles.
          @num_key_value_heads ||= @num_attention_heads
          @layer_types ||= Array.new(@num_hidden_layers) { "full_attention" } unless @num_hidden_layers.nil?
        end

        # Serialize back into the plain-hash form Afmoe::ModelArgs.from_dict
        # consumes.
        def to_afmoe_dict
          {
            "model_type" => @model_type,
            "layer_types" => @layer_types,
            "vocab_size" => @vocab_size,
            "hidden_size" => @hidden_size,
            "intermediate_size" => @intermediate_size,
            "moe_intermediate_size" => @moe_intermediate_size,
            "num_hidden_layers" => @num_hidden_layers,
            "num_attention_heads" => @num_attention_heads,
            "num_key_value_heads" => @num_key_value_heads,
            "head_dim" => @head_dim,
            "max_position_embeddings" => @max_position_embeddings,
            "rms_norm_eps" => @rms_norm_eps,
            "rope_theta" => @rope_theta,
            "rope_scaling" => @rope_scaling,
            "tie_word_embeddings" => @tie_word_embeddings,
            "num_experts" => @num_experts,
            "num_experts_per_tok" => @num_experts_per_tok,
            "num_shared_experts" => @num_shared_experts,
            "num_dense_layers" => @num_dense_layers,
            "route_norm" => @route_norm,
            "route_scale" => @route_scale,
            "score_func" => @score_func,
            "n_group" => @n_group,
            "topk_group" => @topk_group,
            "sliding_window" => @sliding_window,
            "mup_enabled" => @mup_enabled,
          }
        end

        private

        # True when any afm7-style alias key is present in the raw kwargs.
        def _afm7_style_kwargs?(kwargs)
          %i[
            hidden_dim
            num_layers
            num_heads
            num_kv_heads
            num_kv_reuse_layers
            hidden_dim_scale_factor
          ].any? { |key| kwargs.key?(key) }
        end
      end

      # Thin wrapper that delegates everything to an inner Afmoe::Model
      # built from the remapped arguments.
      class Model < MLX::NN::Module
        def initialize(args)
          super()
          @args = args
          self.model_type = args.model_type
          self.wrapped_model = Afmoe::Model.new(Afmoe::ModelArgs.from_dict(args.to_afmoe_dict))
        end

        def call(inputs, cache: nil)
          wrapped_model.call(inputs, cache: cache)
        end

        def sanitize(weights)
          wrapped_model.sanitize(weights)
        end

        def layers
          wrapped_model.layers
        end

        # nil when the wrapped model doesn't implement make_cache.
        def make_cache
          wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache)
        end

        def cast_predicate
          wrapped_model.cast_predicate
        end

        def quant_predicate
          wrapped_model.quant_predicate
        end
      end

      Models.register("afm7", Model, ModelArgs)
    end
  end
end
+ end