fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
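
The diff body below covers one file from this listing. For orientation, here is a minimal, hypothetical usage sketch written against the upstream mlx-vlm API that this wheel appears to repackage (the `load`/`generate` helpers correspond to `mlx_vlm/utils.py` and `mlx_vlm/generate.py` above); the model path and image URL are placeholders, and the renamed distribution may expose a slightly different surface:

from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

# Placeholder checkpoint; any converted MLX vision-language model would do.
model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"
model, processor = load(model_path)
config = load_config(model_path)

image = ["http://images.cocodataset.org/val2017/000000039769.jpg"]  # placeholder URL
prompt = apply_chat_template(processor, config, "Describe this image.", num_images=len(image))
output = generate(model, processor, prompt, image, verbose=False)
print(output)
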
mlx_vlm/models/glm4v_moe/language.py
@@ -0,0 +1,674 @@
+from typing import Any, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from mlx_lm.models.switch_layers import SwitchGLU
+
+from ..base import (
+    LanguageModelOutput,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from .config import ModelConfig, TextConfig
+
+
+def _compute_default_rope_parameters(
+    config: Optional[TextConfig] = None,
+    **rope_kwargs,
+) -> tuple[mx.array, float]:
+
+    if config is not None and len(rope_kwargs) > 0:
+        raise ValueError(
+            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+            f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
+        )
+    if len(rope_kwargs) > 0:
+        base = rope_kwargs["base"]
+        dim = rope_kwargs["dim"]
+    elif config is not None:
+        base = config.rope_theta or config.rope_parameters["rope_theta"]
+        partial_rotary_factor = (
+            config.partial_rotary_factor
+            if hasattr(config, "partial_rotary_factor")
+            else 1.0
+        )
+        head_dim = (
+            getattr(config, "head_dim", None)
+            or config.hidden_size // config.num_attention_heads
+        )
+        dim = int(head_dim * partial_rotary_factor)
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Compute the inverse frequencies
+    inv_freq = 1.0 / (
+        base ** (mx.arange(0, dim, 2, dtype=mx.int64).astype(mx.float32) / dim)
+    )
+    return inv_freq, attention_factor
+
+
+class GLM4VRotaryEmbedding(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type")
+            )
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+
+        self.rope_init_fn = _compute_default_rope_parameters
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config)
+        self._inv_freq = mx.array(inv_freq, dtype=mx.float32)
+        self._original_inv_freq = mx.array(inv_freq, dtype=mx.float32)
+
+    def __call__(self, x, position_ids):
+        inv_freq_expanded = self._inv_freq[None, None, :, None].astype(mx.float32)
+        inv_freq_expanded = mx.broadcast_to(
+            inv_freq_expanded, (3, position_ids.shape[1], self._inv_freq.shape[0], 1)
+        )
+        position_ids_expanded = position_ids[:, :, None, :].astype(mx.float32)
+
+        freqs = (
+            inv_freq_expanded.astype(mx.float32)
+            @ position_ids_expanded.astype(mx.float32)
+        ).transpose(0, 1, 3, 2)
+        emb = mx.concatenate((freqs, freqs), axis=-1)
+        cos = mx.cos(emb) * self.attention_scaling
+        sin = mx.sin(emb) * self.attention_scaling
+
+        return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+def rotate_half_llm(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section):
+    """
+    Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
+    Args:
+        q (mx.array): The query tensor.
+        k (mx.array): The key tensor.
+        cos (mx.array): The cosine part of the rotary embedding.
+        sin (mx.array): The sine part of the rotary embedding.
+        mrope_section (List[int]): Multimodal rope section for channel dimension of temporal, height and width.
+        unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.
+    Returns:
+        tuple(mx.array): The rotated query and key tensors.
+    """
+
+    mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist()
+
+    cos = mx.concatenate(
+        [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))], axis=-1
+    )[
+        :, None, :, :
+    ]  # unsqueeze dim 1
+    sin = mx.concatenate(
+        [m[i % 3] for i, m in enumerate(mx.split(sin, mrope_section, axis=-1))], axis=-1
+    )[:, None, :, :]
+
+    rotary_dim = cos.shape[-1]
+    q_rot = q[..., :rotary_dim]
+    q_pass = q[..., rotary_dim:]
+
+    k_rot = k[..., :rotary_dim]
+    k_pass = k[..., rotary_dim:]
+
+    q_embed = (q_rot * cos) + (rotate_half_llm(q_rot) * sin)
+    k_embed = (k_rot * cos) + (rotate_half_llm(k_rot) * sin)
+
+    q_embed = mx.concatenate([q_embed, q_pass], axis=-1)
+    k_embed = mx.concatenate([k_embed, k_pass], axis=-1)
+
+    return q_embed, k_embed
+
+
+class Glm4vMoeAttention(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+        head_dim = args.head_dim
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+        self.rope_parameters = args.rope_parameters or args.rope_scaling
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+        position_embeddings: Optional[mx.array] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        queries = queries.reshape(B, L, self.n_heads, -1)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1)
+
+        queries = queries.transpose(0, 2, 1, 3)
+        keys = keys.transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        cos, sin = position_embeddings
+
+        queries, keys = apply_multimodal_rotary_pos_emb(
+            queries, keys, cos, sin, self.rope_parameters["mrope_section"]
+        )
+
+        if cache is not None:
+            keys, values = cache.update_and_fetch(keys, values)
+
+        if mask is not None and isinstance(mask, mx.array):
+            mask = mask[..., : keys.shape[-2]]
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class Glm4vMoeMLP(nn.Module):
+    def __init__(
+        self, config: TextConfig, hidden_size: int = None, intermediate_size: int = None
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = (
+            config.intermediate_size if intermediate_size is None else intermediate_size
+        )
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+    def __call__(self, x):
+        down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+@mx.compile
+def group_expert_select(
+    gates,
+    e_score_correction_bias,
+    top_k,
+    n_group,
+    topk_group,
+    routed_scaling_factor,
+    norm_topk_prob,
+):
+
+    scores = mx.sigmoid(gates.astype(mx.float32))
+    orig_scores = scores
+    scores = scores + e_score_correction_bias
+    if n_group > 1:
+        k = top_k
+        scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
+        group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
+        k = n_group - topk_group
+        group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
+        scores = mx.put_along_axis(
+            scores, mx.stop_gradient(group_idx), mx.array(0.0), axis=-2
+        )
+        scores = mx.flatten(scores, -2, -1)
+
+    k = top_k
+    inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
+    scores = mx.take_along_axis(orig_scores, inds, axis=-1)
+    if top_k > 1 and norm_topk_prob:
+        denominator = scores.sum(axis=-1, keepdims=True)
+        scores = scores / denominator
+    scores = scores * routed_scaling_factor
+
+    return inds, scores
+
+
+class MoEGate(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.top_k = config.num_experts_per_tok
+        self.norm_topk_prob = config.norm_topk_prob
+        self.n_routed_experts = config.n_routed_experts
+        self.routed_scaling_factor = config.routed_scaling_factor
+        self.n_group = config.n_group
+        self.topk_group = config.topk_group
+        self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
+        self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
+        assert config.topk_method == "noaux_tc", "Unsupported topk method."
+
+    def __call__(self, x):
+        return group_expert_select(
+            x @ self.weight.T,
+            self.e_score_correction_bias,
+            self.top_k,
+            self.n_group,
+            self.topk_group,
+            self.routed_scaling_factor,
+            self.norm_topk_prob,
+        )
+
+
+class MoE(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.num_experts_per_tok = config.num_experts_per_tok
+        self.switch_mlp = SwitchGLU(
+            config.hidden_size,
+            config.moe_intermediate_size,
+            config.n_routed_experts,
+        )
+
+        self.gate = MoEGate(config)
+        if config.n_shared_experts is not None:
+            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+            self.shared_experts = Glm4vMoeMLP(
+                config=config, intermediate_size=intermediate_size
+            )
+
+    def __call__(self, x):
+        inds, scores = self.gate(x)
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
+        if self.config.n_shared_experts is not None:
+            y = y + self.shared_experts(x)
+
+        return y
+
+
+class Glm4vMoeDecoderLayer(nn.Module):
+    def __init__(self, config: TextConfig, layer_idx: int):
+        super().__init__()
+        self.self_attn = Glm4vMoeAttention(config)
+        self.mlp = (
+            MoE(config)
+            if (
+                config.n_routed_experts is not None
+                and layer_idx >= config.first_k_dense_replace
+            )
+            else Glm4vMoeMLP(config)
+        )
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+        position_embeddings: Optional[mx.array] = None,
+    ) -> mx.array:
+
+        r = self.self_attn(self.input_layernorm(x), mask, cache, position_embeddings)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        return h + r
+
+
+class GLM4VModel(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = [
+            Glm4vMoeDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
+        ]
+        self.start_idx = 0
+        self.end_idx = len(self.layers)
+        self.num_layers = self.end_idx
+
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.rotary_emb = GLM4VRotaryEmbedding(config)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+    ) -> mx.array:
+
+        if inputs_embeds is None:
+            h = self.embed_tokens(inputs)
+        else:
+            h = inputs_embeds.astype(self.norm.weight.dtype)
+
+        if position_ids is None:
+            position_ids = mx.arange(cache[0].offset, cache[0].offset + h.shape[-2])
+            position_ids = mx.expand_dims(position_ids, axis=0)
+            position_ids = mx.tile(position_ids, (3, 1, 1))
+
+        position_embeddings = self.rotary_emb(h, position_ids)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        if cache is None:
+            cache = [None] * self.num_layers
+
+        for i in range(self.num_layers):
+            h = self.layers[self.start_idx + i](h, mask, cache[i], position_embeddings)
+
+        return self.norm(h)
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, args: TextConfig, config: ModelConfig = None):
+        super().__init__()
+        self.args = args
+        self.config = config
+        self.model_type = args.model_type
+        self.model = GLM4VModel(args)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self._rope_deltas = None
+
+    def get_rope_index(
+        self,
+        input_ids: mx.array,
+        image_grid_thw: Optional[mx.array] = None,
+        video_grid_thw: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+    ):
+        # Calculate RoPE index for image/video tokens
+        batch_size, seq_length = input_ids.shape
+        position_ids = mx.arange(seq_length, dtype=mx.int32)
+        position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
+        spatial_merge_size = self.config.vision_config.spatial_merge_size
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = mx.ones_like(input_ids)
+            position_ids = mx.ones(
+                (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
+            )
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                input_ids = mx.where(
+                    attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
+                )
+                image_nums, video_nums = 0, 0
+                vision_start_indices = mx.sum(
+                    mx.where(
+                        input_ids == vision_start_token_id,
+                        mx.arange(input_ids.shape[0]),
+                        mx.zeros_like(input_ids),
+                    )
+                )
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum().item()
+                video_nums = (vision_tokens == video_token_id).sum().item()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    index = mx.arange(text_len).reshape(1, text_len)
+                    index = mx.broadcast_to(index, (3, text_len))
+                    index = index + st_idx
+                    llm_pos_ids_list.append(index)
+                    t_index = mx.arange(llm_grid_t).reshape(
+                        llm_grid_t, 1
+                    )  # Equivalent to .view(-1, 1)
+                    t_index = mx.broadcast_to(
+                        t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                    )  # Equivalent to expand()
+                    t_index = t_index.flatten()  # Flattens to 1D
+
+                    h_index = mx.arange(llm_grid_h).reshape(
+                        1, llm_grid_h, 1
+                    )  # Equivalent to .view(1, -1)
+                    h_index = mx.broadcast_to(
+                        h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    )  # Equivalent to expand()
+                    h_index = h_index.flatten()  # Flattens to 1D
+
+                    w_index = mx.arange(llm_grid_w).reshape(
+                        1, 1, llm_grid_w
+                    )  # Equivalent to .view(1, -1)
+                    w_index = mx.broadcast_to(
+                        w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    )  # Equivalent to expand()
+                    w_index = w_index.flatten()  # Flattens to 1D
+
+                    llm_pos_ids_list.append(
+                        mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+
+                    t_index = mx.arange(text_len).reshape(
+                        1, text_len
+                    )  # Equivalent to .view(-1, 1)
+                    t_index = mx.broadcast_to(
+                        t_index, (3, text_len)
+                    )  # Equivalent to expand(3, -1)
+
+                    llm_pos_ids_list.append(t_index + st_idx)
+
+                llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+                mask = mx.array(attention_mask[i] == 1)
+                expanded_mask = mx.expand_dims(mask, axis=0)
+                expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
+                expanded_positions = mx.expand_dims(llm_positions, axis=1)
+                new_positions = mx.where(
+                    expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
+                )
+                updated_position_ids = mx.concatenate(
+                    [
+                        position_ids[:, :i, :],
+                        new_positions,
+                        position_ids[:, i + 1 :, :],
+                    ],
+                    axis=1,
+                )
+                position_ids = updated_position_ids
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = mx.array(mrope_position_deltas)[0]
+            return position_ids, mrope_position_deltas
+        else:
+            if attention_mask is not None:
+                position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
+                position_ids = mx.where(
+                    attention_mask == 0, mx.ones_like(position_ids), position_ids
+                )
+                position_ids = mx.expand_dims(position_ids[0], axis=0)
+                position_ids = mx.tile(position_ids, (3, 1, 1))
+                max_position_ids = position_ids.max(0, keepdims=False)[0].max(
+                    -1, keepdims=True
+                )[0]
+                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+            else:
+                position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
+                position_ids = mx.broadcast_to(
+                    position_ids, (3, input_ids.shape[0], input_ids.shape[1])
+                )
+                mrope_position_deltas = mx.zeros(
+                    [input_ids.shape[0], 1],
+                    dtype=input_ids.dtype,
+                )
+            return position_ids, mrope_position_deltas
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        **kwargs,
+    ):
+
+        position_ids = kwargs.pop("position_ids", None)
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_grid_thw = kwargs.pop("image_grid_thw", None)
+        video_grid_thw = kwargs.pop("video_grid_thw", None)
+        # reset rope_deltas when processing a new image/video
+        if pixel_values is not None:
+            self._rope_deltas = None
+
+        cache_offset = 0
+        if cache and cache[0] is not None:
+            offset = cache[0].offset
+            if isinstance(offset, int):
+                cache_offset = offset
+            elif isinstance(offset, mx.array):
+                cache_offset = (offset if offset.ndim == 0 else offset[0]).item()
+            else:
+                raise ValueError(f"Unexpected cache offset type: {type(offset)}")
+
+        if position_ids is None and (mask is None or mask.ndim == 2):
+            # Calculate RoPE index once per generation in the pre-fill stage only
+            if (
+                (cache is not None and cache[0] is not None and (cache_offset == 0))
+                or self._rope_deltas is None
+                or cache is None
+            ):
+                position_ids, rope_deltas = self.get_rope_index(
+                    inputs, image_grid_thw, video_grid_thw, mask
+                )
+                self._rope_deltas = rope_deltas
+            else:
+                # Use the prev pre-calculated rope-deltas to get the correct position ids
+                batch_size, seq_length = inputs.shape
+                delta = mx.array(
+                    cache_offset + self._rope_deltas if cache is not None else 0
+                )
+                position_ids = mx.arange(seq_length).reshape(1, -1)
+                position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
+
+                if cache_offset is not None:
+                    if delta.ndim == 0:
+                        delta = mx.expand_dims(delta, axis=0)
+
+                    if delta.shape[0] < batch_size:
+                        delta = mx.tile(delta, (batch_size, 1))
+                    else:
+                        # Slice delta to match batch
+                        delta = delta[:batch_size]
+
+                position_ids = mx.add(position_ids, delta)[None, ...]
+                position_ids = mx.broadcast_to(
+                    position_ids, (3, batch_size, seq_length)
+                )
+
+        out = self.model(
+            inputs, cache=cache, inputs_embeds=inputs_embeds, position_ids=position_ids
+        )
+
+        out = self.lm_head(out)
+        return LanguageModelOutput(logits=out)
+
+    def sanitize(self, weights):
+        mpt_layer = self.args.num_hidden_layers
+
+        # Stack experts
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"language_model.model.layers.{l}"
+            for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+                for k in ["weight", "scales", "biases"]:
+                    if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
+                        to_join = [
+                            weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
+                            for e in range(self.args.n_routed_experts)
+                        ]
+                        weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+
+        # Remove multi-token prediction layer
+        return {
+            k: v
+            for k, v in weights.items()
+            if not k.startswith(f"language_model.model.layers.{mpt_layer}")
+        }
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def cast_predicate(self):
+        def predicate(k):
+            return "e_score_correction_bias" not in k
+
+        return predicate
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
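
A note on apply_multimodal_rotary_pos_emb above: the cos/sin tables arrive with a leading axis of size 3 (temporal, height, width), and mrope_section describes how the rotary channels are divided among those three axes. Because mrope_section * 2 is Python list repetition, the split covers the full head dimension and chunk i is drawn from table i % 3. The self-contained NumPy sketch below illustrates this; the section sizes are illustrative, not taken from this release's config:

import numpy as np

mrope_section = [16, 24, 24]  # illustrative t/h/w split; list repetition doubles it to cover 128 channels
split_points = np.cumsum(mrope_section * 2)[:-1].tolist()
print(split_points)  # [16, 40, 64, 80, 104]

cos = np.zeros((3, 1, 10, 128))  # (axis, batch, seq, head_dim): one table per t/h/w axis
cos[0], cos[1], cos[2] = 0.0, 1.0, 2.0  # tag each table so the interleaving is visible

chunks = np.split(cos, split_points, axis=-1)
merged = np.concatenate([c[i % 3] for i, c in enumerate(chunks)], axis=-1)
print(merged.shape)  # (1, 10, 128)
print(merged[0, 0, 0], merged[0, 0, 16], merged[0, 0, 40])  # 0.0 (t), 1.0 (h), 2.0 (w)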
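
The router in group_expert_select implements the "noaux_tc" scheme: sigmoid gate scores plus a learned correction bias are used only for expert selection (first keeping the topk_group strongest groups, ranked by the sum of each group's two largest biased scores, then taking the global top-k among the survivors), while the mixing weights come from the un-biased sigmoid scores, optionally normalized and scaled. The following is a NumPy re-statement for a single token with made-up sizes, purely to illustrate the control flow, not the mx.compile kernel itself:

import numpy as np

def route(gates, bias, top_k=2, n_group=4, topk_group=2,
          routed_scaling_factor=1.0, norm_topk_prob=True):
    scores = 1.0 / (1.0 + np.exp(-gates))  # sigmoid gate scores
    orig = scores.copy()
    scores = scores + bias                 # bias influences selection only
    groups = scores.reshape(n_group, -1)
    group_rank = np.sort(groups, axis=-1)[:, -2:].sum(axis=-1)  # top-2 sum per group
    keep = np.argsort(-group_rank)[:topk_group]
    masked = np.zeros_like(groups)
    masked[keep] = groups[keep]            # experts in dropped groups are zeroed out
    inds = np.argsort(-masked.reshape(-1))[:top_k]
    weights = orig[inds]                   # mixing weights from the un-biased scores
    if top_k > 1 and norm_topk_prob:
        weights = weights / weights.sum()
    return inds, weights * routed_scaling_factor

rng = np.random.default_rng(0)
inds, weights = route(rng.normal(size=16), np.zeros(16))
print(inds, weights)  # two expert indices; weights sum to 1.0 before scaling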
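
Finally, sanitize repacks per-expert checkpoint weights into the stacked layout that SwitchGLU consumes and drops the trailing multi-token-prediction layer. A tiny, hypothetical illustration of the key mapping (the layer index, expert count, and matrix shape below are made up):

import mlx.core as mx

num_experts = 4  # illustrative; the real value comes from TextConfig.n_routed_experts
per_expert = {
    f"language_model.model.layers.3.mlp.experts.{e}.gate_proj.weight": mx.zeros((8, 16))
    for e in range(num_experts)
}
stacked = mx.stack(list(per_expert.values()))
print(stacked.shape)  # (4, 8, 16), stored under ...mlp.switch_mlp.gate_proj.weight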