fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/ernie4_5_moe_vl/language.py
@@ -0,0 +1,770 @@
+ """Language model for ERNIE 4.5 VL MoE."""
+
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ from mlx_lm.models.switch_layers import SwitchGLU
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache
+ from .config import ModelConfig, TextConfig
+
+
+ class Ernie4_5RotaryEmbedding:
+     """Rotary Position Embedding for ERNIE 4.5 VL with MRoPE support.
+
+     Matches PyTorch's implementation with pre-rotated inverse frequencies.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         base: float = 10000,
+         mrope_section: tuple = (22, 22, 20),
+     ):
+         self.dim = dim  # head_dim
+         self.base = base
+         self.mrope_section = mrope_section  # (h_dim, w_dim, t_dim)
+
+         # Pre-compute inverse frequencies
+         indices = mx.arange(0, self.dim, 2, dtype=mx.float32)
+         inv_freq = 1.0 / (self.base ** (indices / self.dim))
+
+         # Pre-rotate frequencies to match PyTorch's approach
+         # This avoids rotation during forward pass
+         hw_dim = mrope_section[0] + mrope_section[1]  # 44
+         t_dim = mrope_section[2]  # 20
+
+         inv_freq_3d = mx.zeros_like(inv_freq)
+         # Pre-rotate HW dimensions: [even, odd] -> interleaved during recomposition
+         hw_freqs = inv_freq[:-t_dim]  # First (dim/2 - t_dim) frequencies
+         inv_freq_3d = mx.concatenate(
+             [
+                 mx.concatenate([hw_freqs[0::2], hw_freqs[1::2]]),  # Pre-rotated HW
+                 inv_freq[-t_dim:],  # T frequencies unchanged
+             ]
+         )
+         self.inv_freq = inv_freq_3d
+
+     def _recomposition_to_3d(self, freq):
+         """Recompose frequencies for 3D positions matching PyTorch's approach.
+
+         Args:
+             freq: [3, batch, seq_len, dim//2] - frequencies for T, H, W dimensions
+
+         Returns:
+             Recomposed frequencies [batch, seq_len, dim]
+         """
+         # Split by mrope_section
+         h_dim, w_dim, t_dim = self.mrope_section
+
+         # freq shape: [3, batch, seq_len, half_dim]
+         # Split each dimension's frequencies
+         freq_parts = []
+         for i in range(3):
+             freq_parts.append(mx.split(freq[i], [h_dim, h_dim + w_dim], axis=-1))
+
+         # Recompose: freq_h from dim 1, freq_w from dim 2, freq_t from dim 0
+         # This matches PyTorch's (i + 1) % 3 indexing
+         freq_h = freq_parts[1][0]  # H from position 1
+         freq_w = freq_parts[2][1]  # W from position 2
+         freq_t = freq_parts[0][2]  # T from position 0
+
+         # Interleave H and W: [h0, w0, h1, w1, ...]
+         freq_hw = mx.stack([freq_h, freq_w], axis=-1).reshape(
+             freq_h.shape[0], freq_h.shape[1], -1
+         )
+
+         # Concatenate HW and T
+         freq_hwt = mx.concatenate([freq_hw, freq_t], axis=-1)
+
+         # Repeat interleave by 2 for full head_dim
+         freq_full = mx.repeat(freq_hwt, 2, axis=-1)
+
+         return freq_full
+
+     def __call__(self, x, position_ids):
+         """
+         Compute 3D rotary embeddings matching PyTorch's implementation.
+
+         Args:
+             x: Input tensor for dtype reference
+             position_ids: Position IDs, shape (batch, seq_len, 3) for 3D positions [T, H, W]
+
+         Returns:
+             cos, sin: [batch, seq_len, head_dim] ready for rotation
+         """
+         if position_ids.ndim == 2:
+             # 1D positions - expand to 3D with same values
+             position_ids = mx.stack([position_ids, position_ids, position_ids], axis=-1)
+
+         batch_size, seq_len, _ = position_ids.shape
+
+         # position_ids: [batch, seq_len, 3] -> [3, batch, seq_len]
+         position_ids = position_ids.transpose(2, 0, 1).astype(mx.float32)
+
+         # inv_freq: [dim//2] -> [1, 1, dim//2, 1] for broadcasting
+         inv_freq_expanded = self.inv_freq[None, None, :, None]  # [1, 1, dim//2, 1]
+         inv_freq_expanded = mx.broadcast_to(
+             inv_freq_expanded, (3, batch_size, self.dim // 2, 1)
+         )
+
+         # position_ids: [3, batch, seq_len] -> [3, batch, 1, seq_len]
+         position_ids_expanded = position_ids[:, :, None, :]
+
+         # freqs: [3, batch, dim//2, seq_len] -> [3, batch, seq_len, dim//2]
+         freqs = (inv_freq_expanded * position_ids_expanded).transpose(0, 1, 3, 2)
+
+         cos = mx.cos(freqs)
+         sin = mx.sin(freqs)
+
+         # Recompose to 3D
+         cos = self._recomposition_to_3d(cos)
+         sin = self._recomposition_to_3d(sin)
+
+         return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+ def rotate_half_interleaved(x):
+     """Rotates using interleaved pattern: [-x1, x0, -x3, x2, ...].
+
+     This matches PyTorch's rotation: stack([-x[1::2], x[0::2]], dim=-1).reshape()
+     """
+     x_even = x[..., 0::2]  # [x0, x2, x4, ...]
+     x_odd = x[..., 1::2]  # [x1, x3, x5, ...]
+     # Stack as [-odd, even] and reshape
+     rotated = mx.stack([-x_odd, x_even], axis=-1)
+     return rotated.reshape(x.shape)
+
+
+ def apply_rotary_pos_emb(q, k, cos_pos, sin_pos):
+     """Apply rotary position embeddings to queries and keys.
+
+     Uses interleaved rotation matching PyTorch's apply_rotary_3d.
+
+     Args:
+         q: [batch, n_heads, seq_len, head_dim]
+         k: [batch, n_kv_heads, seq_len, head_dim]
+         cos_pos: [batch, seq_len, head_dim]
+         sin_pos: [batch, seq_len, head_dim]
+     """
+     orig_dtype = q.dtype
+     # Expand for heads dimension
+
+     cos_pos = mx.expand_dims(cos_pos, axis=1)  # [batch, 1, seq_len, head_dim]
+     sin_pos = mx.expand_dims(sin_pos, axis=1)
+
+     # Apply rotation: q_rotated = q * cos + rotate_half(q) * sin
+     q_rotated = rotate_half_interleaved(q)
+     k_rotated = rotate_half_interleaved(k)
+
+     q_embed = (q.astype(mx.float32) * cos_pos) + (
+         q_rotated.astype(mx.float32) * sin_pos
+     )
+     k_embed = (k.astype(mx.float32) * cos_pos) + (
+         k_rotated.astype(mx.float32) * sin_pos
+     )
+
+     return q_embed.astype(orig_dtype), k_embed.astype(orig_dtype)
+
+
+ class Attention(nn.Module):
+     """Multi-headed attention for ERNIE 4.5 with MRoPE support."""
+
+     def __init__(self, args: TextConfig):
+         super().__init__()
+
+         dim = args.hidden_size
+         self.n_heads = n_heads = args.num_attention_heads
+         self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads
+
+         self.head_dim = head_dim = args.hidden_size // n_heads
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.use_bias)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.use_bias)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.use_bias)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.use_bias)
+
+         # Get mrope_section for 3D RoPE (H, W, T dimension allocation)
+         # Default [22, 22, 20] for head_dim=128
+         self.mrope_section = tuple(getattr(args, "mrope_section", [22, 22, 20]))
+
+         self.rotary_emb = Ernie4_5RotaryEmbedding(
+             head_dim,
+             base=args.rope_theta,
+             mrope_section=self.mrope_section,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries = self.q_proj(x)
+         keys = self.k_proj(x)
+         values = self.v_proj(x)
+
+         # Reshape and transpose: [B, L, n_heads, head_dim] -> [B, n_heads, L, head_dim]
+         queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+         keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+
+         # Handle position IDs
+         if position_ids is None:
+             offset = cache.offset if cache is not None else 0
+             position_ids = mx.arange(offset, offset + L)
+             position_ids = mx.expand_dims(position_ids, axis=0)
+
+         cos, sin = self.rotary_emb(values, position_ids)
+         queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
+
+         if cache is not None:
+             keys, values = cache.update_and_fetch(keys, values)
+
+         if mask is not None and isinstance(mask, mx.array):
+             mask = mask[..., : keys.shape[-2]]
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class Ernie4_5_MLP(nn.Module):
+     def __init__(self, dim, hidden_dim, use_bias=False):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=use_bias)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=use_bias)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=use_bias)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Ernie4_5_MoeMLP(nn.Module):
+     """Mixture of Experts MLP for ERNIE with dual expert groups."""
+
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.args = args
+         self.k = args.moe_k
+         self.norm_min = getattr(args, "moe_norm_min", 1e-12)
+
+         moe_num_experts = args.moe_num_experts
+         moe_intermediate_size = args.moe_intermediate_size
+
+         if isinstance(moe_num_experts, (list, tuple)) and len(moe_num_experts) == 2:
+             self.num_text_experts = moe_num_experts[0]
+             self.num_mm_experts = moe_num_experts[1]
+             self.has_dual_experts = True
+         else:
+             self.num_text_experts = (
+                 moe_num_experts
+                 if not isinstance(moe_num_experts, (list, tuple))
+                 else moe_num_experts[0]
+             )
+             self.num_mm_experts = 0
+             self.has_dual_experts = False
+
+         if (
+             isinstance(moe_intermediate_size, (list, tuple))
+             and len(moe_intermediate_size) == 2
+         ):
+             self.text_intermediate_size = moe_intermediate_size[0]
+             self.mm_intermediate_size = moe_intermediate_size[1]
+         else:
+             self.text_intermediate_size = (
+                 moe_intermediate_size
+                 if not isinstance(moe_intermediate_size, (list, tuple))
+                 else moe_intermediate_size[0]
+             )
+             self.mm_intermediate_size = self.text_intermediate_size
+
+         self.gate = nn.Linear(args.hidden_size, self.num_text_experts, bias=False)
+         self.e_score_correction_bias = mx.zeros((self.num_text_experts,))
+         self.switch_mlp = SwitchGLU(
+             args.hidden_size,
+             self.text_intermediate_size,
+             self.num_text_experts,
+             bias=args.use_bias,
+         )
+
+         if self.has_dual_experts and self.num_mm_experts > 0:
+             self.gate_1 = nn.Linear(args.hidden_size, self.num_mm_experts, bias=False)
+             self.e_score_correction_bias_1 = mx.zeros((self.num_mm_experts,))
+             self.switch_mlp_1 = SwitchGLU(
+                 args.hidden_size,
+                 self.mm_intermediate_size,
+                 self.num_mm_experts,
+                 bias=args.use_bias,
+             )
+
+         if getattr(args, "moe_num_shared_experts", 0) > 0:
+             shared_intermediate_size = (
+                 self.text_intermediate_size * args.moe_num_shared_experts
+             )
+             self.shared_experts = Ernie4_5_MLP(
+                 args.hidden_size, shared_intermediate_size, args.use_bias
+             )
+         else:
+             self.shared_experts = None
+
+     def _route_experts(
+         self, x: mx.array, gate: nn.Module, e_score_correction_bias: mx.array
+     ) -> tuple:
+         k = self.k
+         router_logits = gate(x).astype(mx.float32)
+         routing_weights = mx.softmax(router_logits, axis=-1)
+         routing_weights_with_bias = routing_weights + e_score_correction_bias
+
+         selected_experts = mx.stop_gradient(
+             mx.argpartition(-routing_weights_with_bias, kth=k - 1, axis=-1)[..., :k]
+         )
+         scores = mx.take_along_axis(routing_weights, selected_experts, axis=-1)
+         scores = scores / mx.maximum(scores.sum(axis=-1, keepdims=True), self.norm_min)
+
+         return selected_experts, scores
+
+     def __call__(
+         self, x: mx.array, token_type_ids: Optional[mx.array] = None
+     ) -> mx.array:
+         inds, scores = self._route_experts(x, self.gate, self.e_score_correction_bias)
+         y_text = self.switch_mlp(x, inds)
+         y_text = (y_text * scores[..., None]).sum(axis=-2).astype(y_text.dtype)
+
+         if (
+             not self.has_dual_experts
+             or self.num_mm_experts == 0
+             or token_type_ids is None
+         ):
+             y = y_text
+         else:
+             inds_mm, scores_mm = self._route_experts(
+                 x, self.gate_1, self.e_score_correction_bias_1
+             )
+             y_mm = self.switch_mlp_1(x, inds_mm)
+             y_mm = (y_mm * scores_mm[..., None]).sum(axis=-2).astype(y_mm.dtype)
+
+             is_text = token_type_ids == 0
+             is_text_expanded = mx.expand_dims(is_text, axis=-1)
+             y = mx.where(is_text_expanded, y_text, y_mm)
+
+         if self.shared_experts is not None:
+             y = y + self.shared_experts(x)
+
+         return y
+
+
+ class Ernie4_5VLDecoderLayer(nn.Module):
+     def __init__(self, args: TextConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = args.hidden_size
+         self.self_attn = Attention(args)
+
+         moe_layer_start_index = args.moe_layer_start_index
+         if isinstance(moe_layer_start_index, (tuple, list)):
+             moe_layer_start_index = min(moe_layer_start_index)
+
+         moe_layer_end_index = args.moe_layer_end_index
+         if moe_layer_end_index is None:
+             moe_layer_end_index = args.num_hidden_layers - 1
+         elif isinstance(moe_layer_end_index, (tuple, list)):
+             moe_layer_end_index = max(moe_layer_end_index)
+
+         use_moe = (
+             ((layer_idx + 1) % args.moe_layer_interval == 0)
+             and layer_idx >= moe_layer_start_index
+             and layer_idx <= moe_layer_end_index
+         )
+
+         if use_moe:
+             self.mlp = Ernie4_5_MoeMLP(args)
+         else:
+             self.mlp = Ernie4_5_MLP(
+                 args.hidden_size, args.intermediate_size, args.use_bias
+             )
+
+         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             args.hidden_size, eps=args.rms_norm_eps
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+         token_type_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids)
+         h = x + r
+         if isinstance(self.mlp, Ernie4_5_MoeMLP):
+             r = self.mlp(
+                 self.post_attention_layernorm(h), token_type_ids=token_type_ids
+             )
+         else:
+             r = self.mlp(self.post_attention_layernorm(h))
+         return h + r
+
+
+ class Ernie4_5Model(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.args = args
+         self.vocab_size = args.vocab_size
+         self.num_hidden_layers = args.num_hidden_layers
+
+         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+         self.layers = [
+             Ernie4_5VLDecoderLayer(args=args, layer_idx=i)
+             for i in range(args.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         position_ids: Optional[mx.array] = None,
+         token_type_ids: Optional[mx.array] = None,
+     ):
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c, position_ids, token_type_ids=token_type_ids)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, args: TextConfig, config: ModelConfig = None):
+         super().__init__()
+         self.args = args
+         self.config = config
+         self.model_type = args.model_type
+         self.model = Ernie4_5Model(args)
+         self._rope_deltas = None
+
+         if not args.tie_word_embeddings:
+             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+     def get_rope_index(
+         self,
+         input_ids: mx.array,
+         image_grid_thw: Optional[mx.array] = None,
+         video_grid_thw: Optional[mx.array] = None,
+         attention_mask: Optional[mx.array] = None,
+     ):
+         batch_size, seq_length = input_ids.shape
+         spatial_merge_size = self.config.vision_config.spatial_merge_size
+         image_token_id = self.config.image_token_id
+         video_token_id = self.config.video_token_id
+         vision_start_token_id = self.config.vision_start_token_id
+
+         if image_grid_thw is not None or video_grid_thw is not None:
+             batch_position_ids = []
+             mrope_position_deltas = []
+
+             image_index, video_index = 0, 0
+
+             for i in range(batch_size):
+                 input_tokens = input_ids[i].tolist()
+                 llm_pos_ids_list = []
+                 st = 0
+
+                 image_nums, video_nums = 0, 0
+                 for idx, token in enumerate(input_tokens):
+                     if token == vision_start_token_id and idx + 1 < len(input_tokens):
+                         next_token = input_tokens[idx + 1]
+                         if next_token == image_token_id:
+                             image_nums += 1
+                         elif next_token == video_token_id:
+                             video_nums += 1
+
+                 remain_images, remain_videos = image_nums, video_nums
+
+                 for _ in range(image_nums + video_nums):
+                     ed_image = (
+                         input_tokens.index(image_token_id, st)
+                         if image_token_id in input_tokens[st:] and remain_images > 0
+                         else len(input_tokens) + 1
+                     )
+                     ed_video = (
+                         input_tokens.index(video_token_id, st)
+                         if video_token_id in input_tokens[st:] and remain_videos > 0
+                         else len(input_tokens) + 1
+                     )
+
+                     if ed_image < ed_video:
+                         t, h, w = image_grid_thw[image_index].tolist()
+                         image_index += 1
+                         remain_images -= 1
+                         ed = ed_image
+                         vision_token = image_token_id
+                     else:
+                         t, h, w = video_grid_thw[video_index].tolist()
+                         video_index += 1
+                         remain_videos -= 1
+                         ed = ed_video
+                         vision_token = video_token_id
+
+                     llm_grid_t = t
+                     llm_grid_h = h // spatial_merge_size
+                     llm_grid_w = w // spatial_merge_size
+                     expected_vision_len = llm_grid_t * llm_grid_h * llm_grid_w
+
+                     actual_vision_len = 0
+                     for j in range(
+                         ed, min(ed + expected_vision_len, len(input_tokens))
+                     ):
+                         if input_tokens[j] == vision_token:
+                             actual_vision_len += 1
+                         else:
+                             break
+
+                     text_len = ed - st
+                     st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+
+                     text_pos = mx.arange(text_len) + st_idx
+                     text_pos_3d = mx.stack([text_pos, text_pos, text_pos], axis=0)
+                     llm_pos_ids_list.append(text_pos_3d)
+
+                     if actual_vision_len > 0:
+                         t_idx = mx.repeat(
+                             mx.arange(llm_grid_t).reshape(-1, 1),
+                             llm_grid_h * llm_grid_w,
+                             axis=1,
+                         ).flatten()[:actual_vision_len]
+                         h_idx = mx.tile(
+                             mx.arange(llm_grid_h).reshape(1, -1, 1),
+                             (llm_grid_t, 1, llm_grid_w),
+                         ).flatten()[:actual_vision_len]
+                         w_idx = mx.tile(
+                             mx.arange(llm_grid_w).reshape(1, 1, -1),
+                             (llm_grid_t, llm_grid_h, 1),
+                         ).flatten()[:actual_vision_len]
+
+                         vision_pos = (
+                             mx.stack([t_idx, h_idx, w_idx], axis=0) + text_len + st_idx
+                         )
+                         llm_pos_ids_list.append(vision_pos)
+
+                     st = ed + actual_vision_len
+
+                 # Handle remaining text
+                 if st < len(input_tokens):
+                     st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+                     text_len = len(input_tokens) - st
+                     text_pos = mx.arange(text_len) + st_idx
+                     text_pos_3d = mx.stack([text_pos, text_pos, text_pos], axis=0)
+                     llm_pos_ids_list.append(text_pos_3d)
+
+                 llm_positions = mx.concatenate(llm_pos_ids_list, axis=1)  # [3, seq_len]
+                 batch_position_ids.append(llm_positions.T)  # [seq_len, 3]
+                 mrope_position_deltas.append(llm_positions.max() + 1 - seq_length)
+
+             position_ids = mx.stack(batch_position_ids, axis=0)
+             mrope_position_deltas = mx.array(mrope_position_deltas)
+             return position_ids, mrope_position_deltas
+         else:
+             position_ids = mx.arange(seq_length)
+             position_ids = mx.broadcast_to(
+                 position_ids[None, :], (batch_size, seq_length)
+             )
+             position_ids = mx.stack([position_ids, position_ids, position_ids], axis=-1)
+             return position_ids, mx.zeros((batch_size,), dtype=mx.int32)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+         position_ids = kwargs.pop("position_ids", None)
+         pixel_values = kwargs.pop("pixel_values", None)
+         image_grid_thw = kwargs.pop("image_grid_thw", None)
+         video_grid_thw = kwargs.pop("video_grid_thw", None)
+
+         if pixel_values is not None:
+             self._rope_deltas = None
+
+         cache_offset = 0
+         if cache and cache[0] is not None:
+             offset = cache[0].offset
+             cache_offset = offset.item() if isinstance(offset, mx.array) else offset
+
+         if position_ids is None and (mask is None or mask.ndim == 2):
+             if (
+                 cache is None or cache[0] is None or cache_offset == 0
+             ) or self._rope_deltas is None:
+                 position_ids, rope_deltas = self.get_rope_index(
+                     inputs, image_grid_thw, video_grid_thw, mask
+                 )
+                 self._rope_deltas = rope_deltas
+             else:
+                 batch_size, seq_length = inputs.shape
+                 delta = cache_offset + self._rope_deltas if cache is not None else 0
+                 position_ids = mx.arange(seq_length) + delta
+                 position_ids = mx.broadcast_to(
+                     position_ids[None, :], (batch_size, seq_length)
+                 )
+                 position_ids = mx.stack(
+                     [position_ids, position_ids, position_ids], axis=-1
+                 )
+
+         token_type_ids = kwargs.pop("token_type_ids", None)
+
+         out = self.model(
+             inputs,
+             cache=cache,
+             inputs_embeds=inputs_embeds,
+             mask=mask,
+             position_ids=position_ids,
+             token_type_ids=token_type_ids,
+         )
+
+         if self.args.tie_word_embeddings:
+             out = self.model.embed_tokens.as_linear(out)
+         else:
+             out = self.lm_head(out)
+
+         return LanguageModelOutput(logits=out)
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.args.hidden_size // self.args.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.args.num_key_value_heads
+
+     def sanitize(self, weights):
+         """Sanitize weights for loading."""
+         remove_patterns = [
+             "mtp_block.",
+             "mtp_linear_proj.",
+             "mtp_hidden_norm.",
+             "mtp_emb_norm.",
+         ]
+
+         weights = {
+             key: value
+             for key, value in weights.items()
+             if not any(pattern in key for pattern in remove_patterns)
+         }
+
+         # Get expert configuration
+         moe_num_experts = self.args.moe_num_experts
+         if isinstance(moe_num_experts, (list, tuple)) and len(moe_num_experts) == 2:
+             num_text_experts = moe_num_experts[0]
+             num_mm_experts = moe_num_experts[1]
+         else:
+             num_text_experts = (
+                 moe_num_experts
+                 if not isinstance(moe_num_experts, (list, tuple))
+                 else moe_num_experts[0]
+             )
+             num_mm_experts = 0
+
+         for l in range(self.args.num_hidden_layers):
+             prefix = f"language_model.model.layers.{l}"
+
+             # Stack text experts (0 to num_text_experts-1) into switch_mlp
+             for m in ["gate_proj", "down_proj", "up_proj"]:
+                 for k in ["weight", "scales", "biases"]:
+                     if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
+                         to_join = [
+                             weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
+                             for e in range(num_text_experts)
+                         ]
+                         weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+
+             # Stack multimodal experts (num_text_experts to num_text_experts+num_mm_experts-1) into switch_mlp_1
+             if num_mm_experts > 0:
+                 for m in ["gate_proj", "down_proj", "up_proj"]:
+                     for k in ["weight", "scales", "biases"]:
+                         first_mm_expert = num_text_experts
+                         if f"{prefix}.mlp.experts.{first_mm_expert}.{m}.{k}" in weights:
+                             to_join = [
+                                 weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
+                                 for e in range(
+                                     num_text_experts, num_text_experts + num_mm_experts
+                                 )
+                             ]
+                             weights[f"{prefix}.mlp.switch_mlp_1.{m}.{k}"] = mx.stack(
+                                 to_join
+                             )
+
+             # Transpose gate weights if needed (HuggingFace uses [in, out], MLX uses [out, in])
+             # MLX nn.Linear(in=2560, out=64) expects shape (64, 2560), HF provides (2560, 64)
+             gate_key = f"{prefix}.mlp.gate.weight"
+             if gate_key in weights:
+                 w = weights[gate_key]
+                 # Only transpose if shape is (hidden_size, num_experts) not (num_experts, hidden_size)
+                 if w.shape[0] > w.shape[1]:  # (2560, 64) needs transpose
+                     weights[gate_key] = w.T
+
+             # Rename gate.weight_1 to gate_1.weight for multimodal gate and transpose
+             gate_1_key = f"{prefix}.mlp.gate.weight_1"
+             if gate_1_key in weights:
+                 w = weights.pop(gate_1_key)
+                 if w.shape[0] > w.shape[1]:  # Only transpose if needed
+                     w = w.T
+                 weights[f"{prefix}.mlp.gate_1.weight"] = w
+
+             # Handle e_score_correction_bias
+             # HuggingFace stores as [2, num_experts] - row 0 for text, row 1 for multimodal
+             bias_key = f"{prefix}.mlp.moe_statics.e_score_correction_bias"
+             if bias_key in weights:
+                 bias = weights.pop(bias_key)
+                 if bias.ndim == 2 and bias.shape[0] == 2:
+                     # Split into text and multimodal biases
+                     weights[f"{prefix}.mlp.e_score_correction_bias"] = bias[0]
+                     if num_mm_experts > 0:
+                         weights[f"{prefix}.mlp.e_score_correction_bias_1"] = bias[1]
+                 else:
+                     # Single bias (squeeze if needed)
+                     if bias.ndim > 1:
+                         bias = bias.squeeze()
+                     weights[f"{prefix}.mlp.e_score_correction_bias"] = bias
+
+         # Remove lm_head if tie_word_embeddings is True
+         if self.args.tie_word_embeddings:
+             lm_head_key = "language_model.lm_head.weight"
+             if lm_head_key in weights:
+                 weights.pop(lm_head_key)
+
+         return weights
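
For reference, a minimal standalone sketch (not part of the packaged file; assumes only mlx is installed) showing what the interleaved rotate-half used by apply_rotary_pos_emb above does to a toy vector:

import mlx.core as mx

def rotate_half_interleaved(x):
    # Same pattern as in language.py above: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    return mx.stack([-x_odd, x_even], axis=-1).reshape(x.shape)

print(rotate_half_interleaved(mx.array([1.0, 2.0, 3.0, 4.0])))  # expected: [-2, 1, -4, 3]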