mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/bailing_moe_linear.rb
@@ -0,0 +1,91 @@
+ require_relative "bailing_moe"
+
+ module MlxLm
+   module Models
+     module BailingMoeLinear
+       class ModelArgs < BailingMoe::ModelArgs
+         field :model_type, default: "bailing_moe_linear"
+         field :layer_group_size, default: nil
+         field :group_norm_size, default: nil
+         field :use_rmsnorm, default: nil
+         field :head_dim, default: nil
+         field :rope_traditional, default: false
+
+         # Forwards only the config keys the dense BailingMoe model understands;
+         # the linear-variant fields above are accepted but not passed through.
+         def to_bailing_moe_dict
+           {
+             "model_type" => @model_type,
+             "hidden_size" => @hidden_size,
+             "intermediate_size" => @intermediate_size,
+             "max_position_embeddings" => @max_position_embeddings,
+             "moe_intermediate_size" => @moe_intermediate_size,
+             "num_experts" => @num_experts,
+             "num_shared_experts" => @num_shared_experts,
+             "norm_topk_prob" => @norm_topk_prob,
+             "num_attention_heads" => @num_attention_heads,
+             "num_experts_per_tok" => @num_experts_per_tok,
+             "num_hidden_layers" => @num_hidden_layers,
+             "num_key_value_heads" => @num_key_value_heads,
+             "rms_norm_eps" => @rms_norm_eps,
+             "rope_theta" => @rope_theta,
+             "vocab_size" => @vocab_size,
+             "first_k_dense_replace" => @first_k_dense_replace,
+             "rope_scaling" => @rope_scaling,
+             "use_bias" => @use_bias,
+             "use_qkv_bias" => @use_qkv_bias,
+             "norm_head" => @norm_head,
+             "norm_softmax" => @norm_softmax,
+             "use_qk_norm" => @use_qk_norm,
+             "tie_word_embeddings" => @tie_word_embeddings,
+             "partial_rotary_factor" => @partial_rotary_factor,
+             "rotary_dim" => @rotary_dim,
+             "moe_router_enable_expert_bias" => @moe_router_enable_expert_bias,
+             "moe_router_enable_routed_scaling" => @moe_router_enable_routed_scaling,
+             "routed_scaling_factor" => @routed_scaling_factor,
+             "score_function" => @score_function,
+             "n_group" => @n_group,
+             "topk_group" => @topk_group,
+             "moe_shared_expert_intermediate_size" => @moe_shared_expert_intermediate_size,
+             "moe_router_enable_shared_expert" => @moe_router_enable_shared_expert,
+           }
+         end
+       end
+
+       # Thin wrapper that delegates the whole forward pass to BailingMoe::Model.
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.wrapped_model = BailingMoe::Model.new(BailingMoe::ModelArgs.from_dict(args.to_bailing_moe_dict))
+         end
+
+         def call(inputs, cache: nil)
+           wrapped_model.call(inputs, cache: cache)
+         end
+
+         def sanitize(weights)
+           wrapped_model.sanitize(weights)
+         end
+
+         def layers
+           wrapped_model.layers
+         end
+
+         def make_cache
+           return nil unless wrapped_model.respond_to?(:make_cache)
+
+           wrapped_model.make_cache
+         end
+
+         def cast_predicate
+           wrapped_model.cast_predicate
+         end
+
+         def quant_predicate
+           wrapped_model.quant_predicate
+         end
+       end
+
+       Models.register("bailing_moe_linear", Model, ModelArgs)
+     end
+   end
+ end
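
Note on the pattern above: bailing_moe_linear registers its own model type but reuses the dense BailingMoe graph wholesale, translating its config and delegating every hook (call, sanitize, layers, cache construction, quantization predicates). As a reading aid, here is a minimal, dependency-free Ruby sketch of the same register-and-delegate shape; Registry, BaseModel, and LinearVariant are hypothetical stand-ins, not names from the gem:

    # Hypothetical sketch: a model registry plus a variant that delegates to a base model.
    module Registry
      TABLE = {}

      def self.register(type, klass)
        TABLE[type] = klass
      end

      def self.build(type, config)
        TABLE.fetch(type).new(config)
      end
    end

    class BaseModel
      def initialize(config)
        @config = config
      end

      def call(inputs)
        "base(#{inputs}) over #{@config[:hidden_size]} dims"
      end
    end

    class LinearVariant
      def initialize(config)
        # Accept variant-only keys, but hand the wrapped model just what it knows.
        @wrapped = BaseModel.new(config.slice(:hidden_size))
      end

      def call(inputs)
        @wrapped.call(inputs)
      end
    end

    Registry.register("base", BaseModel)
    Registry.register("base_linear", LinearVariant)

    model = Registry.build("base_linear", { hidden_size: 4096, layer_group_size: 8 })
    puts model.call("tokens") # => base(tokens) over 4096 dims

The real wrapper does the same thing via BailingMoe::ModelArgs.from_dict, plus pass-through hooks so caching and quantization treat it exactly like the wrapped model.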
data/lib/mlx_lm/models/bitlinear_layers.rb
@@ -0,0 +1,108 @@
+ module MlxLm
+   module Models
+     # 1.58-bit linear layer: four 2-bit ternary weights packed into each byte.
+     class BitLinear < MLX::NN::Module
+       attr_reader :in_features, :out_features, :invert_weight_scales
+
+       def initialize(
+         in_features,
+         out_features,
+         bias: true,
+         invert_weight_scales: false
+       )
+         super()
+         mx = MLX::Core
+
+         @in_features = in_features
+         @out_features = out_features
+         @invert_weight_scales = invert_weight_scales
+
+         packed_out_features = (out_features + 3) / 4
+         self.weight = mx.zeros([packed_out_features, in_features], mx.uint8)
+         self.weight_scale = mx.array([1.0], dtype: mx.float32)
+         self.bias = mx.zeros([out_features], mx.float32) if bias
+       end
+
+       def call(x)
+         y = execute_matmul_kernel(x, weight)
+         state.key?("bias") ? MLX::Core.add(y, bias) : y
+       end
+
+       def execute_matmul_kernel(x, packed_weights)
+         # TODO(phase1e): switch to a custom Metal kernel once MLX Ruby exposes
+         # a stable fast-kernel API equivalent to Python's mx.fast.metal_kernel.
+         execute_matmul_fallback(x, packed_weights)
+       end
+
+       private
+
+       def execute_matmul_fallback(x, packed_weights)
+         input_dims = x.shape[-1]
+         unless input_dims == @in_features
+           raise ArgumentError, "Expected input features #{@in_features}, got #{input_dims}"
+         end
+
+         ternary_weight = unpack_packed_weights(packed_weights, x.dtype)
+         out = MLX::Core.matmul(x, ternary_weight.T)
+
+         scale = weight_scale.astype(x.dtype)
+         scale = MLX::Core.divide(1.0, scale) if invert_weight_scales
+         MLX::Core.multiply(out, scale)
+       end
+
+       def unpack_packed_weights(packed_weights, dtype)
+         mx = MLX::Core
+
+         # Each byte holds four 2-bit fields storing weight + 1, so 0..2 maps to -1..1.
+         w0 = (mx.bitwise_and(packed_weights, 0x03).astype(dtype) - 1.0)
+         w1 = (mx.bitwise_and(mx.right_shift(packed_weights, 2), 0x03).astype(dtype) - 1.0)
+         w2 = (mx.bitwise_and(mx.right_shift(packed_weights, 4), 0x03).astype(dtype) - 1.0)
+         w3 = (mx.bitwise_and(mx.right_shift(packed_weights, 6), 0x03).astype(dtype) - 1.0)
+
+         expanded = mx.concatenate([w0, w1, w2, w3], 0)
+         return expanded if expanded.shape[0] == @out_features
+
+         # Drop the padding rows when out_features is not a multiple of four.
+         keep = mx.arange(0, @out_features, 1, mx.int32)
+         mx.take(expanded, keep, 0)
+       end
+     end
+
+     module_function
+
+     # Replaces every eligible MLX::NN::Linear in the model with a BitLinear.
+     def bitnet_quantize(model, quantization_config = {})
+       modules_to_not_convert = Array(config_value(quantization_config, "modules_to_not_convert", []))
+         .map(&:to_s)
+       invert_weight_scales = config_value(quantization_config, "linear_class", "").to_s != "autobitlinear"
+
+       replacements = []
+       leaves = model.leaf_modules
+       flat = MLX::Utils.tree_flatten(leaves, is_leaf: lambda { |node| node.is_a?(MLX::NN::Module) })
+
+       flat.each do |path, layer|
+         path_s = path.to_s
+         next if modules_to_not_convert.include?(path_s)
+         next unless layer.is_a?(MLX::NN::Linear)
+
+         out_features, in_features = layer.weight.shape
+         replacements << [
+           path_s,
+           BitLinear.new(
+             in_features,
+             out_features,
+             bias: layer.state.key?("bias"),
+             invert_weight_scales: invert_weight_scales
+           ),
+         ]
+       end
+
+       model.update_modules(MLX::Utils.tree_unflatten(replacements)) unless replacements.empty?
+       model
+     end
+
+     # Looks up a config key by string first, then by symbol.
+     def config_value(config, key, default = nil)
+       return default if config.nil?
+       return config[key] if config.key?(key)
+
+       config.fetch(key.to_sym, default)
+     end
+     private_class_method :config_value
+   end
+ end
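
The unpacking above implies a planar 2-bit layout: with P = ceil(out_features / 4) packed rows, weight row f * P + p is stored, biased by +1, in bit field f of packed row p, and the padding rows are dropped after expansion. The packing side is not part of this diff; the following plain-Ruby sketch reconstructs it from the unpack logic (so treat the pack function as an inferred assumption) and round-trips a small ternary matrix:

    # Inferred packer for BitLinear's planar 2-bit layout, plus its inverse.
    def pack_ternary(rows)
      out, cols = rows.length, rows.first.length
      p_rows = (out + 3) / 4
      padded = rows + Array.new(p_rows * 4 - out) { [0] * cols } # pad with zero weights
      Array.new(p_rows) do |p|
        Array.new(cols) do |c|
          # 2-bit field f of packed row p stores weight row f * p_rows + p, biased by +1.
          (0..3).sum { |f| (padded[f * p_rows + p][c] + 1) << (2 * f) }
        end
      end
    end

    def unpack_ternary(packed, out_features)
      expanded = (0..3).flat_map do |f|
        packed.map { |row| row.map { |b| ((b >> (2 * f)) & 0x03) - 1 } }
      end
      expanded.first(out_features) # mirrors the concatenate-then-take in unpack_packed_weights
    end

    w = [[1, -1, 0], [0, 1, 1], [-1, -1, 1], [1, 0, 0], [0, 0, -1]]
    raise "round trip failed" unless unpack_ternary(pack_ternary(w), w.length) == w

The planar layout is why unpack_packed_weights can concatenate the four bit fields along axis 0 and then simply keep the first out_features rows.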
data/lib/mlx_lm/models/bitnet.rb
@@ -0,0 +1,176 @@
+ module MlxLm
+   module Models
+     module Bitnet
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "bitnet"
+         field :hidden_size, default: 4096
+         field :num_hidden_layers, default: 32
+         field :intermediate_size, default: 11_008
+         field :num_attention_heads, default: 32
+         field :num_key_value_heads, default: nil
+         field :rms_norm_eps, default: 1e-6
+         field :vocab_size, default: 32_000
+         field :head_dim, default: nil
+         field :max_position_embeddings, default: nil
+         field :attention_bias, default: false
+         field :mlp_bias, default: false
+         field :rope_theta, default: 10_000.0
+         field :rope_traditional, default: false
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: true
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+
+           bias = args.attention_bias
+           self.q_proj = BitLinear.new(dim, @n_heads * @head_dim, bias: bias)
+           self.k_proj = BitLinear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+           self.v_proj = BitLinear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+           self.o_proj = BitLinear.new(@n_heads * @head_dim, dim, bias: bias)
+
+           self.rope = MlxLm::Models.initialize_rope(
+             @head_dim,
+             args.rope_theta,
+             args.rope_traditional,
+             args.rope_scaling,
+             max_position_embeddings: args.max_position_embeddings
+           )
+           # BitNet's extra sub-norm, applied to the attention output before o_proj.
+           self.attn_sub_norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(attn_sub_norm.call(output))
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           hidden_dim = args.intermediate_size
+           bias = args.mlp_bias
+
+           self.gate_proj = BitLinear.new(dim, hidden_dim, bias: bias)
+           self.down_proj = BitLinear.new(hidden_dim, dim, bias: bias)
+           self.up_proj = BitLinear.new(dim, hidden_dim, bias: bias)
+           self.ffn_sub_norm = MLX::NN::RMSNorm.new(hidden_dim, eps: args.rms_norm_eps)
+         end
+
+         def call(x)
+           # Squared-ReLU (relu2) gating, with a sub-norm before the down projection.
+           h = MLX::NN.relu2(gate_proj.call(x)) * up_proj.call(x)
+           down_proj.call(ffn_sub_norm.call(h))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class BitnetModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask)
+           return nil if h.shape[1] == 1
+
+           # Returning "causal" lets the attention op build the mask itself.
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = BitnetModel.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+           result
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("bitnet", Model, ModelArgs)
+     end
+   end
+ end
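
For readers tracing the MLP: assuming relu2 is the squared-ReLU activation (an assumption based on the name and the BitNet recipe, not confirmed by this diff), the gate path zeroes negative activations, squares the rest, and scales the up projection elementwise. A scalar sketch on plain floats:

    # Squared-ReLU gating on plain Ruby floats, just to pin down the arithmetic.
    relu2 = ->(v) { r = [v, 0.0].max; r * r }

    gate = [0.5, -1.2, 2.0]
    up   = [1.0,  3.0, 0.5]
    h = gate.zip(up).map { |g, u| relu2.call(g) * u }
    p h # => [0.25, 0.0, 2.0] (negative gate values zero out before squaring)

In the model itself, h then passes through ffn_sub_norm and down_proj, and with tie_word_embeddings the final hidden states are projected to logits through the embedding matrix via as_linear rather than a separate lm_head.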