mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/exaone.rb
@@ -0,0 +1,169 @@
+module MlxLm
+  module Models
+    module Exaone
+      class ModelArgs < BaseModelArgs
+        field :model_type
+        field :hidden_size
+        field :num_layers
+        field :intermediate_size
+        field :num_attention_heads
+        field :vocab_size
+        field :rope_theta
+        field :layer_norm_epsilon
+        field :num_key_value_heads
+        field :head_dim, default: nil
+        field :max_position_embeddings, default: nil
+        field :rope_traditional, default: false
+        field :rope_scaling, default: nil
+        field :tie_word_embeddings, default: true
+        field :attention_bias, default: false
+        field :mlp_bias, default: false
+
+        def initialize(**kwargs)
+          super
+          @head_dim ||= @hidden_size / @num_attention_heads
+        end
+      end
+
+      class AttentionModule < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = args.head_dim
+          @scale = @head_dim**(-0.5)
+
+          bias = args.attention_bias
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.out_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: bias)
+
+          self.rope = MlxLm::Models.initialize_rope(
+            @head_dim,
+            args.rope_theta,
+            args.rope_traditional,
+            args.rope_scaling,
+            max_position_embeddings: args.max_position_embeddings
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, d = x.shape
+
+          q = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          k = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          v = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            q = rope.call(q, offset: cache.offset)
+            k = rope.call(k, offset: cache.offset)
+            k, v = cache.update_and_fetch(k, v)
+          else
+            q = rope.call(q)
+            k = rope.call(k)
+          end
+
+          out = mx.scaled_dot_product_attention(q, k, v, @scale, mask)
+          out = out.transpose([0, 2, 1, 3]).reshape([b, l, d])
+          out_proj.call(out)
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.attention = AttentionModule.new(args)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          hidden_dim = args.intermediate_size
+          bias = args.mlp_bias
+          self.c_fc_0 = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+          self.c_fc_1 = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+          self.c_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias)
+        end
+
+        def call(x)
+          c_proj.call(MlxLm::Models::Activations.swiglu(c_fc_0.call(x), c_fc_1.call(x)))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.ln_1 = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon)
+          self.attn = Attention.new(args)
+          self.ln_2 = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon)
+          self.mlp = MLP.new(args)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          h = x + attn.attention.call(ln_1.call(x), mask: mask, cache: cache)
+          h + mlp.call(ln_2.call(h))
+        end
+      end
+
+      class ExaoneModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.wte = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.h = Array.new(args.num_layers) { TransformerBlock.new(args) }
+          self.ln_f = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon)
+        end
+
+        def call(inputs, cache: nil)
+          hidden = wte.call(inputs)
+          layer_cache = cache || [nil] * h.length
+
+          mask = nil
+          mask = "causal" if hidden.shape[1] > 1
+
+          h.each_with_index do |layer, i|
+            hidden = layer.call(hidden, mask: mask, cache: layer_cache[i])
+          end
+
+          ln_f.call(hidden)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.transformer = ExaoneModel.new(args)
+          unless args.tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil)
+          out = transformer.call(inputs, cache: cache)
+          if @args.tie_word_embeddings
+            transformer.wte.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def sanitize(weights)
+          result = weights.reject { |k, _| k.include?("rotary_emb.inv_freq") }
+          result.delete("lm_head.weight") if @args.tie_word_embeddings
+          result
+        end
+
+        def layers
+          transformer.h
+        end
+      end
+
+      Models.register("exaone", Model, ModelArgs)
+    end
+  end
+end
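A minimal plain-Ruby sketch (illustration only, no MLX required; the dimensions below are hypothetical) of the grouped-query-attention bookkeeping that ModelArgs#initialize and AttentionModule above rely on: head_dim defaults to hidden_size / num_attention_heads, queries project to n_heads * head_dim while keys and values project to the smaller n_kv_heads * head_dim, and scores are scaled by head_dim ** -0.5 inside scaled_dot_product_attention.

    # Illustration only: shape arithmetic mirroring ModelArgs/AttentionModule,
    # using hypothetical example dimensions.
    hidden_size         = 4096
    num_attention_heads = 32
    num_key_value_heads = 8

    head_dim = hidden_size / num_attention_heads   # 128, the ||= default
    q_width  = num_attention_heads * head_dim      # 4096 -> reshaped to [B, L, 32, 128]
    kv_width = num_key_value_heads * head_dim      # 1024 -> reshaped to [B, L, 8, 128]
    scale    = head_dim**(-0.5)                    # applied inside scaled_dot_product_attention

    puts({ head_dim: head_dim, q_width: q_width, kv_width: kv_width, scale: scale.round(4) })
    # e.g. {:head_dim=>128, :q_width=>4096, :kv_width=>1024, :scale=>0.0884}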
data/lib/mlx_lm/models/exaone4.rb
@@ -0,0 +1,233 @@
+module MlxLm
+  module Models
+    module Exaone4
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "exaone4"
+        field :hidden_size
+        field :num_hidden_layers
+        field :intermediate_size
+        field :num_attention_heads
+        field :rms_norm_eps
+        field :vocab_size
+        field :num_key_value_heads
+        field :max_position_embeddings
+        field :rope_theta
+        field :head_dim
+        field :tie_word_embeddings
+        field :rope_scaling, default: nil
+        field :sliding_window, default: nil
+        field :sliding_window_pattern, default: nil
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+          @head_dim ||= @hidden_size / @num_attention_heads
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        attr_reader :is_local
+
+        def initialize(args, is_local)
+          super()
+
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = args.head_dim
+          @scale = @head_dim**(-0.5)
+
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+
+          self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+          self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+
+          @is_local = is_local || false
+          @use_rope = is_local.nil? || @is_local
+          if @use_rope
+            self.rope = MlxLm::Models.initialize_rope(
+              @head_dim,
+              args.rope_theta,
+              false,
+              args.rope_scaling,
+              max_position_embeddings: args.max_position_embeddings
+            )
+          end
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x)
+          keys = k_proj.call(x)
+          values = v_proj.call(x)
+
+          queries = q_norm.call(queries.reshape([b, l, @n_heads, @head_dim])).transpose([0, 2, 1, 3])
+          keys = k_norm.call(keys.reshape([b, l, @n_kv_heads, @head_dim])).transpose([0, 2, 1, 3])
+          values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            if @use_rope
+              queries = rope.call(queries, offset: cache.offset)
+              keys = rope.call(keys, offset: cache.offset)
+            end
+            keys, values = cache.update_and_fetch(keys, values)
+          elsif @use_rope
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(dim, hidden_dim)
+          super()
+          self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+          self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+          self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+        end
+
+        def call(x)
+          down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args, is_local:)
+          super()
+          self.self_attn = Attention.new(args, is_local)
+          self.mlp = MLP.new(args.hidden_size, args.intermediate_size)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_feedforward_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(x, mask: mask, cache: cache)
+          h = x + post_attention_layernorm.call(r)
+          r = mlp.call(h)
+          h + post_feedforward_layernorm.call(r)
+        end
+      end
+
+      class ExaoneModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.vocab_size = args.vocab_size
+          self.num_hidden_layers = args.num_hidden_layers
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+
+          pattern = args.sliding_window_pattern
+          self.layers = Array.new(args.num_hidden_layers) do |i|
+            is_local = pattern ? (pattern[i % pattern.length] == "L") : nil
+            TransformerBlock.new(args, is_local: is_local)
+          end
+
+          if pattern
+            self.swa_idx = pattern.index("L")
+            self.full_idx = pattern.index("G")
+          else
+            self.swa_idx = nil
+            self.full_idx = 0
+          end
+
+          self.window_size = args.sliding_window
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(inputs, cache: nil)
+          h = embed_tokens.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+
+          global_mask = _create_attention_mask(h, layer_cache[full_idx])
+          if !swa_idx.nil?
+            swa_mask = _create_attention_mask(
+              h,
+              layer_cache[swa_idx],
+              window_size: window_size
+            )
+          else
+            swa_mask = nil
+          end
+
+          layers.each_with_index do |layer, i|
+            mask = layer.self_attn.is_local ? swa_mask : global_mask
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+
+        private
+
+        def _create_attention_mask(h, cache = nil, window_size: nil)
+          n = h.shape[1]
+          if cache && cache.respond_to?(:make_mask)
+            return cache.make_mask(n, window_size: window_size)
+          end
+          return nil if n == 1
+          return _create_causal_mask(n, window_size: window_size) if window_size && n > window_size
+
+          "causal"
+        end
+
+        def _create_causal_mask(n, offset: 0, window_size: nil)
+          mx = MLX::Core
+          rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
+          linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])
+
+          mask = mx.greater_equal(linds, rinds)
+          if window_size
+            mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+          end
+          mask
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.model = ExaoneModel.new(args)
+          unless args.tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          if @args.tie_word_embeddings
+            model.embed_tokens.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def make_cache
+          layers.map do |layer|
+            if layer.self_attn.is_local
+              RotatingKVCache.new(max_size: @args.sliding_window, keep: 0)
+            else
+              KVCache.new
+            end
+          end
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("exaone4", Model, ModelArgs)
+    end
+  end
+end
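For reference, a minimal plain-Ruby sketch (illustration only, no MLX required; the pattern string and sizes below are hypothetical) of the two mechanisms exaone4.rb layers on top of the dense Exaone model above: sliding_window_pattern is cycled with pattern[i % pattern.length] to mark each layer local ("L": RoPE plus sliding-window attention, backed by a RotatingKVCache) or global ("G": full attention without RoPE, backed by a plain KVCache), and _create_causal_mask combines the causal constraint with a window constraint for the local layers.

    # Illustration only: layer pattern and causal + sliding-window mask,
    # mirroring ExaoneModel#initialize and #_create_causal_mask with
    # hypothetical sizes.
    pattern    = "LLLG"   # hypothetical: three local layers, then one global layer
    num_layers = 8
    is_local = Array.new(num_layers) { |i| pattern[i % pattern.length] == "L" }
    # => [true, true, true, false, true, true, true, false]

    n, offset, window_size = 6, 0, 3
    mask = Array.new(n) do |q|            # rows: query positions (offset + q)
      Array.new(offset + n) do |k|        # columns: key positions
        (offset + q) >= k && (offset + q) < k + window_size
      end
    end
    mask.each { |row| puts row.map { |keep| keep ? 1 : 0 }.join(" ") }
    # 1 0 0 0 0 0
    # 1 1 0 0 0 0
    # 1 1 1 0 0 0
    # 0 1 1 1 0 0
    # 0 0 1 1 1 0
    # 0 0 0 1 1 1
    # Each query attends to at most window_size of the most recent keys; dropping
    # the second condition leaves the purely causal mask used by the global layers.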