mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/jamba.rb
@@ -0,0 +1,158 @@
+require_relative "falcon_h1"
+
+module MlxLm
+  module Models
+    module Jamba
+      class ModelArgs < FalconH1::ModelArgs
+        field :model_type, default: "jamba"
+        field :attn_layer_offset, default: 1
+        field :attn_layer_period, default: 2
+        field :expert_layer_offset, default: 1
+        field :expert_layer_period, default: 2
+        field :mamba_d_state, default: nil
+        field :mamba_expand, default: nil
+        field :num_experts, default: 1
+        field :num_experts_per_tok, default: 1
+        field :mamba_dt_rank, default: "auto"
+        field :mamba_proj_bias, default: false
+        field :mamba_conv_bias, default: true
+        field :layers_block_type, default: nil
+
+        def initialize(**kwargs)
+          super
+          @mamba_d_conv ||= 4
+          @num_key_value_heads ||= @num_attention_heads
+          @layers_block_type ||= _default_layers_block_type
+          @num_hidden_layers ||= Array(@layers_block_type).length
+          @block_types ||= _to_block_types
+        end
+
+        def to_falcon_h1_dict
+          hidden_size = @hidden_size
+          attention_heads = @num_attention_heads
+          inferred_head_dim = if !@head_dim.nil?
+            @head_dim
+          elsif !hidden_size.nil? && attention_heads.to_i > 0
+            hidden_size / attention_heads
+          else
+            64
+          end
+
+          {
+            "model_type" => @model_type,
+            "attention_bias" => @attention_bias,
+            "head_dim" => inferred_head_dim,
+            "hidden_size" => hidden_size,
+            "intermediate_size" => @intermediate_size,
+            "max_position_embeddings" => @max_position_embeddings,
+            "mamba_d_conv" => @mamba_d_conv,
+            "num_attention_heads" => attention_heads,
+            "num_hidden_layers" => @num_hidden_layers,
+            "num_key_value_heads" => @num_key_value_heads,
+            "rms_norm_eps" => @rms_norm_eps,
+            "rope_theta" => @rope_theta,
+            "vocab_size" => @vocab_size,
+            "tie_word_embeddings" => @tie_word_embeddings,
+            "attention_window_size" => @attention_window_size,
+            "block_types" => @block_types,
+          }
+        end
+
+        private
+
+        def _default_layers_block_type
+          count = @num_hidden_layers.to_i
+          return nil if count <= 0
+
+          period = @attn_layer_period.to_i
+          offset = @attn_layer_offset.to_i
+          period = 1 if period <= 0
+
+          Array.new(count) do |idx|
+            (idx % period == offset) ? "attention" : "mamba"
+          end
+        end
+
+        def _to_block_types
+          return @block_types if @block_types.is_a?(Array) && !@block_types.empty?
+          return nil unless @layers_block_type.is_a?(Array) && !@layers_block_type.empty?
+
+          @layers_block_type.map { |layer_type| layer_type.to_s == "mamba" ? "recurrent" : "attention" }
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.wrapped_model = FalconH1::Model.new(
+            FalconH1::ModelArgs.from_dict(args.to_falcon_h1_dict)
+          )
+        end
+
+        def call(inputs, cache: nil)
+          wrapped_model.call(inputs, cache: cache)
+        end
+
+        def sanitize(weights)
+          normalized = weights.dup
+          _stack_experts!(normalized)
+
+          remapped = {}
+          normalized.each do |key, value|
+            remapped[_remap_weight_key(key)] = value
+          end
+          wrapped_model.sanitize(remapped)
+        end
+
+        def layers
+          wrapped_model.layers
+        end
+
+        def make_cache
+          return nil unless wrapped_model.respond_to?(:make_cache)
+
+          wrapped_model.make_cache
+        end
+
+        private
+
+        def _stack_experts!(weights)
+          mx = MLX::Core
+
+          @args.num_hidden_layers.to_i.times do |layer_idx|
+            prefix = "model.layers.#{layer_idx}.feed_forward"
+            %w[gate_proj up_proj down_proj].each do |projection|
+              %w[weight bias scales biases].each do |param|
+                pattern = /\A#{Regexp.escape(prefix)}\.experts\.(\d+)\.#{projection}\.#{param}\z/
+                matches = weights.keys.filter_map do |key|
+                  match = pattern.match(key)
+                  next nil unless match
+
+                  [match[1].to_i, key]
+                end
+                next if matches.empty?
+
+                stacked = matches.sort_by(&:first).map do |(_, key)|
+                  weights.delete(key)
+                end
+                weights["#{prefix}.switch_mlp.#{projection}.#{param}"] = mx.stack(stacked)
+              end
+            end
+          end
+        end
+
+        def _remap_weight_key(key)
+          mapped = key.dup
+          mapped = mapped.gsub("model.norm.", "model.final_layernorm.")
+          mapped = mapped.gsub(".mixer.", ".mamba.")
+          mapped = mapped.gsub(".feed_forward.router.", ".feed_forward.gate.")
+          mapped
+        end
+      end
+
+      Models.register("jamba", Model, ModelArgs)
+    end
+  end
+end
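
The notable transformation above is `_stack_experts!`: checkpoints store one tensor per expert (`...experts.N.gate_proj.weight`), while the wrapped FalconH1 switch-MLP expects them stacked along a leading axis under a single `switch_mlp` key. A minimal standalone sketch of that key transformation, with plain Ruby arrays standing in for `MLX::Core` arrays and hypothetical weight keys:

weights = {
  "model.layers.0.feed_forward.experts.1.gate_proj.weight" => [[3.0]],
  "model.layers.0.feed_forward.experts.0.gate_proj.weight" => [[1.0]],
  "model.layers.0.feed_forward.router.weight" => [[9.0]],
}

prefix = "model.layers.0.feed_forward"
pattern = /\A#{Regexp.escape(prefix)}\.experts\.(\d+)\.gate_proj\.weight\z/

# Collect (expert_index, key) pairs, sort numerically, and delete the
# per-expert entries so only the stacked key remains.
matches = weights.keys.filter_map { |k| (m = pattern.match(k)) && [m[1].to_i, k] }
stacked = matches.sort_by(&:first).map { |(_, k)| weights.delete(k) }
weights["#{prefix}.switch_mlp.gate_proj.weight"] = stacked

p weights
# => {"model.layers.0.feed_forward.router.weight"=>[[9.0]],
#     "model.layers.0.feed_forward.switch_mlp.gate_proj.weight"=>[[[1.0]], [[3.0]]]}

In the gem itself, `_remap_weight_key` then renames `feed_forward.router` to `feed_forward.gate` (and `model.norm` to `model.final_layernorm`) before handing the hash to `FalconH1#sanitize`.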
data/lib/mlx_lm/models/kimi_k25.rb
@@ -0,0 +1,98 @@
+require_relative "deepseek"
+
+module MlxLm
+  module Models
+    module KimiK25
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "kimi_k25"
+        field :text_config, default: nil
+
+        def self.from_dict(params)
+          has_text_config = params.key?("text_config") || params.key?(:text_config)
+          return super if has_text_config
+
+          model_type = params["model_type"] || params[:model_type] || "kimi_k25"
+          new(model_type: model_type, text_config: params)
+        end
+
+        def initialize(**kwargs)
+          super
+          @text_config = _stringify_keys(@text_config || {})
+          @text_config["model_type"] ||= "deepseek"
+        end
+
+        private
+
+        def _stringify_keys(hash)
+          hash.each_with_object({}) do |(key, value), out|
+            out[key.to_s] = value
+          end
+        end
+      end
+
+      class Model < MLX::NN::Module
+        MULTIMODAL_PREFIXES = %w[
+          vision_tower
+          vision_model
+          multi_modal_projector
+          mm_projector
+        ].freeze
+
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.language_model = DeepSeek::Model.new(
+            DeepSeek::ModelArgs.from_dict(args.text_config)
+          )
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          language_model.call(inputs, cache: cache)
+        end
+
+        def sanitize(weights)
+          language_weights = {}
+          flat_weights = weights.is_a?(Hash) ? weights : weights.to_h
+
+          flat_weights.each do |key, value|
+            next if _multimodal_key?(key)
+
+            normalized_key = key.start_with?("language_model.") ? key.delete_prefix("language_model.") : key
+            language_weights[normalized_key] = value
+          end
+
+          sanitized_language = if language_model.respond_to?(:sanitize)
+            language_model.sanitize(language_weights)
+          else
+            language_weights
+          end
+
+          sanitized_language.each_with_object({}) do |(key, value), out|
+            out["language_model.#{key}"] = value
+          end
+        end
+
+        def model
+          language_model.model
+        end
+
+        def layers
+          model.layers
+        end
+
+        def cast_predicate
+          lambda { |key| !key.include?("e_score_correction_bias") }
+        end
+
+        private
+
+        def _multimodal_key?(key)
+          MULTIMODAL_PREFIXES.any? { |prefix| key == prefix || key.start_with?("#{prefix}.") }
+        end
+      end
+
+      Models.register("kimi_k25", Model, ModelArgs)
+    end
+  end
+end
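
`KimiK25::Model#sanitize` is a round-trip: drop anything under a multimodal prefix, strip the `language_model.` prefix so the wrapped DeepSeek model sees the key layout it expects, then re-add the prefix on the way out. A small sketch of the filter-and-strip step on hypothetical keys, with symbols standing in for tensors:

MULTIMODAL_PREFIXES = %w[vision_tower vision_model multi_modal_projector mm_projector].freeze

def multimodal_key?(key)
  MULTIMODAL_PREFIXES.any? { |prefix| key == prefix || key.start_with?("#{prefix}.") }
end

weights = {
  "vision_tower.patch_embed.weight" => :drop,
  "language_model.model.embed_tokens.weight" => :keep,
  "lm_head.weight" => :keep, # unprefixed keys pass through unchanged
}

# delete_prefix is a no-op when the prefix is absent, so one pass suffices.
language = weights.reject { |k, _| multimodal_key?(k) }
                  .transform_keys { |k| k.delete_prefix("language_model.") }

p language.keys
# => ["model.embed_tokens.weight", "lm_head.weight"]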
data/lib/mlx_lm/models/kimi_linear.rb
@@ -0,0 +1,124 @@
+require_relative "bailing_moe_linear"
+
+module MlxLm
+  module Models
+    module KimiLinear
+      class ModelArgs < BailingMoeLinear::ModelArgs
+        field :model_type, default: "kimi_linear"
+        field :hidden_dim, default: nil
+        field :ffn_hidden_size, default: nil
+        field :num_layers, default: nil
+        field :num_heads, default: nil
+        field :num_kv_heads, default: nil
+        field :num_local_experts, default: nil
+        field :n_routed_experts, default: nil
+        field :n_shared_experts, default: nil
+        field :top_k, default: nil
+        field :score_func, default: nil
+
+        def self.from_dict(params)
+          normalized = params.each_with_object({}) do |(key, value), out|
+            out[key.to_s] = value
+          end
+
+          {
+            "hidden_dim" => "hidden_size",
+            "ffn_hidden_size" => "intermediate_size",
+            "num_layers" => "num_hidden_layers",
+            "num_heads" => "num_attention_heads",
+            "num_kv_heads" => "num_key_value_heads",
+            "num_local_experts" => "num_experts",
+            "n_routed_experts" => "num_experts",
+            "n_shared_experts" => "num_shared_experts",
+            "top_k" => "num_experts_per_tok",
+            "score_func" => "score_function",
+          }.each do |source_key, target_key|
+            next unless normalized.key?(source_key)
+
+            normalized[target_key] = normalized[source_key] unless normalized.key?(target_key)
+          end
+
+          normalized["model_type"] ||= "kimi_linear"
+          super(normalized)
+        end
+
+        def initialize(**kwargs)
+          super
+          @hidden_size = @hidden_dim if kwargs.key?(:hidden_dim) && !kwargs.key?(:hidden_size) && !@hidden_dim.nil?
+          @intermediate_size = @ffn_hidden_size if kwargs.key?(:ffn_hidden_size) && !kwargs.key?(:intermediate_size) && !@ffn_hidden_size.nil?
+          @num_hidden_layers = @num_layers if kwargs.key?(:num_layers) && !kwargs.key?(:num_hidden_layers) && !@num_layers.nil?
+          @num_attention_heads = @num_heads if kwargs.key?(:num_heads) && !kwargs.key?(:num_attention_heads) && !@num_heads.nil?
+          @num_key_value_heads = @num_kv_heads if kwargs.key?(:num_kv_heads) && !kwargs.key?(:num_key_value_heads) && !@num_kv_heads.nil?
+          @num_experts = @num_local_experts if kwargs.key?(:num_local_experts) && !kwargs.key?(:num_experts) && !@num_local_experts.nil?
+          @num_experts = @n_routed_experts if kwargs.key?(:n_routed_experts) && !kwargs.key?(:num_experts) && !kwargs.key?(:num_local_experts) && !@n_routed_experts.nil?
+          @num_shared_experts = @n_shared_experts if kwargs.key?(:n_shared_experts) && !kwargs.key?(:num_shared_experts) && !@n_shared_experts.nil?
+          @num_experts_per_tok = @top_k if kwargs.key?(:top_k) && !kwargs.key?(:num_experts_per_tok) && !@top_k.nil?
+          @score_function = @score_func if kwargs.key?(:score_func) && !kwargs.key?(:score_function) && !@score_func.nil?
+          @num_key_value_heads ||= @num_attention_heads
+        end
+
+        def to_bailing_moe_linear_dict
+          to_bailing_moe_dict
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.wrapped_model = BailingMoeLinear::Model.new(
+            BailingMoeLinear::ModelArgs.from_dict(args.to_bailing_moe_linear_dict)
+          )
+        end
+
+        def call(inputs, cache: nil)
+          wrapped_model.call(inputs, cache: cache)
+        end
+
+        def sanitize(weights)
+          remapped = {}
+          flat_weights = weights.is_a?(Hash) ? weights : weights.to_h
+          flat_weights.each do |key, value|
+            remapped[_remap_weight_key(key)] = value
+          end
+          wrapped_model.sanitize(remapped)
+        end
+
+        def layers
+          wrapped_model.layers
+        end
+
+        def make_cache
+          return wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache)
+
+          nil
+        end
+
+        def cast_predicate
+          return wrapped_model.cast_predicate if wrapped_model.respond_to?(:cast_predicate)
+
+          lambda { |_key| true }
+        end
+
+        def quant_predicate
+          return wrapped_model.quant_predicate if wrapped_model.respond_to?(:quant_predicate)
+
+          lambda { |_key, _value| true }
+        end
+
+        private
+
+        def _remap_weight_key(key)
+          mapped = key.dup
+          mapped = mapped.gsub(".mlp.router.", ".mlp.gate.")
+          mapped = mapped.gsub("model.embed_tokens.", "model.word_embeddings.")
+          mapped = mapped.gsub("model.tok_embeddings.", "model.word_embeddings.")
+          mapped
+        end
+      end
+
+      Models.register("kimi_linear", Model, ModelArgs)
+    end
+  end
+end
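
`KimiLinear::ModelArgs.from_dict` accepts Kimi-style config keys (`hidden_dim`, `num_layers`, `top_k`, ...) and copies each onto the canonical `BailingMoeLinear` name only when the canonical key is absent, so explicit canonical values always win. A minimal sketch of that aliasing pass on a hypothetical config hash, using a subset of the alias table above:

ALIASES = {
  "hidden_dim" => "hidden_size",
  "num_layers" => "num_hidden_layers",
  "top_k"      => "num_experts_per_tok",
}.freeze

params = { "hidden_dim" => 2048, "num_layers" => 24, "hidden_size" => 4096 }

ALIASES.each do |source, target|
  next unless params.key?(source)

  # Fill the canonical key only if the config did not set it explicitly.
  params[target] = params[source] unless params.key?(target)
end

p params
# => {"hidden_dim"=>2048, "num_layers"=>24, "hidden_size"=>4096, "num_hidden_layers"=>24}
# "hidden_size" keeps its explicit 4096; only the missing canonical key is filled.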
data/lib/mlx_lm/models/kimi_vl.rb
@@ -0,0 +1,93 @@
+require_relative "deepseek"
+
+module MlxLm
+  module Models
+    module KimiVL
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "kimi_vl"
+        field :text_config, default: nil
+
+        def self.from_dict(params)
+          has_text_config = params.key?("text_config") || params.key?(:text_config)
+          return super if has_text_config
+
+          model_type = params["model_type"] || params[:model_type] || "kimi_vl"
+          new(model_type: model_type, text_config: params)
+        end
+
+        def initialize(**kwargs)
+          super
+          @text_config = _stringify_keys(@text_config || {})
+          @text_config["model_type"] ||= "deepseek"
+        end
+
+        private
+
+        def _stringify_keys(hash)
+          hash.each_with_object({}) do |(key, value), out|
+            out[key.to_s] = value
+          end
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.language_model = DeepSeek::Model.new(
+            DeepSeek::ModelArgs.from_dict(args.text_config)
+          )
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          language_model.call(inputs, cache: cache)
+        end
+
+        def sanitize(weights)
+          language_weights = {}
+          flat_weights = weights.is_a?(Hash) ? weights : weights.to_h
+
+          flat_weights.each do |key, value|
+            next if _drop_key?(key)
+
+            normalized_key = key.start_with?("language_model.") ? key.delete_prefix("language_model.") : key
+            language_weights[normalized_key] = value
+          end
+
+          sanitized_language = if language_model.respond_to?(:sanitize)
+            language_model.sanitize(language_weights)
+          else
+            language_weights
+          end
+
+          sanitized_language.each_with_object({}) do |(key, value), out|
+            out["language_model.#{key}"] = value
+          end
+        end
+
+        def model
+          language_model.model
+        end
+
+        def layers
+          model.layers
+        end
+
+        def cast_predicate
+          lambda { |key| !key.include?("e_score_correction_bias") }
+        end
+
+        private
+
+        def _drop_key?(key)
+          key.include?("vision_tower") ||
+            key.include?("multi_modal_projector") ||
+            key.include?("rotary_emb")
+        end
+      end
+
+      Models.register("kimi_vl", Model, ModelArgs)
+    end
+  end
+end
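
Unlike `KimiK25`, which anchors its multimodal filter at the start of the key, `KimiVL#_drop_key?` uses substring matching, so `rotary_emb` buffers are discarded wherever they appear in the key path. A quick sketch of the difference on hypothetical keys:

keys = [
  "vision_tower.encoder.layers.0.attn.q_proj.weight",
  "language_model.model.layers.0.self_attn.rotary_emb.inv_freq",
  "language_model.model.layers.0.self_attn.q_proj.weight",
]

# Substring matching, as in KimiVL#_drop_key?.
drop = ->(key) do
  key.include?("vision_tower") ||
    key.include?("multi_modal_projector") ||
    key.include?("rotary_emb")
end

p keys.reject(&drop)
# => ["language_model.model.layers.0.self_attn.q_proj.weight"]
# A prefix-anchored check would have kept the nested rotary_emb key.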