fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/glm4v/language.py
@@ -0,0 +1,574 @@
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from ..base import (
    LanguageModelOutput,
    create_attention_mask,
    scaled_dot_product_attention,
)
from .config import ModelConfig, TextConfig


def _compute_default_rope_parameters(
    config: Optional[TextConfig] = None,
    **rope_kwargs,
) -> tuple[mx.array, float]:

    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = (
            config.partial_rotary_factor
            if hasattr(config, "partial_rotary_factor")
            else 1.0
        )
        head_dim = (
            getattr(config, "head_dim", None)
            or config.hidden_size // config.num_attention_heads
        )
        dim = int(head_dim * partial_rotary_factor)

    attention_factor = 1.0  # Unused in this type of RoPE

    # Compute the inverse frequencies
    inv_freq = 1.0 / (
        base ** (mx.arange(0, dim, 2, dtype=mx.int64).astype(mx.float32) / dim)
    )
    return inv_freq, attention_factor

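For orientation (not part of the packaged file): the default RoPE parameters follow the usual formula inv_freq[i] = 1 / base^(2i/dim), one frequency per pair of rotary channels. A minimal standalone sketch, with an assumed base of 10000 and head dim of 128 (illustrative values, not read from any GLM-4V config):

import mlx.core as mx

# Illustrative values only; the real model takes base/dim from its text config.
base, dim = 10000.0, 128
inv_freq = 1.0 / (base ** (mx.arange(0, dim, 2, dtype=mx.int64).astype(mx.float32) / dim))
print(inv_freq.shape)                            # (64,) -> one frequency per channel pair
print(inv_freq[0].item(), inv_freq[-1].item())   # 1.0 down to ~1.2e-4
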
class GLM4VRotaryEmbedding(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()

        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_init_fn = _compute_default_rope_parameters

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config)
        self._inv_freq = mx.array(inv_freq, dtype=mx.float32)
        self._original_inv_freq = mx.array(inv_freq, dtype=mx.float32)

    def __call__(self, x, position_ids):
        inv_freq_expanded = self._inv_freq[None, None, :, None].astype(mx.float32)
        inv_freq_expanded = mx.broadcast_to(
            inv_freq_expanded, (3, position_ids.shape[1], self._inv_freq.shape[0], 1)
        )
        position_ids_expanded = position_ids[:, :, None, :].astype(mx.float32)

        freqs = (
            inv_freq_expanded.astype(mx.float32)
            @ position_ids_expanded.astype(mx.float32)
        ).transpose(0, 1, 3, 2)
        emb = mx.concatenate((freqs, freqs), axis=-1)
        cos = mx.cos(emb) * self.attention_scaling
        sin = mx.sin(emb) * self.attention_scaling

        return cos.astype(x.dtype), sin.astype(x.dtype)


def rotate_half_llm(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return mx.flatten(mx.stack([-x2, x1], axis=-1), start_axis=-2, end_axis=-1)

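This variant rotates interleaved (even, odd) channel pairs rather than the first and second halves used by the more common rotate_half. A tiny check of the behaviour (illustrative only, duplicating the definition above so it runs standalone):

import mlx.core as mx

def rotate_half_llm(x):
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return mx.flatten(mx.stack([-x2, x1], axis=-1), start_axis=-2, end_axis=-1)

x = mx.array([1.0, 2.0, 3.0, 4.0])
print(rotate_half_llm(x))  # [-2, 1, -4, 3]: each (even, odd) pair maps to (-odd, even)
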
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section):
    """
    Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
    Args:
        q (mx.array): The query tensor.
        k (mx.array): The key tensor.
        cos (mx.array): The cosine part of the rotary embedding.
        sin (mx.array): The sine part of the rotary embedding.
        mrope_section (List[int]): Multimodal rope section for channel dimension of temporal, height and width.
    Returns:
        tuple(mx.array): The rotated query and key tensors.
    """

    mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist()

    cos = mx.concatenate(
        [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))], axis=-1
    )[
        :, None, :, :
    ]  # unsqueeze dim 1
    sin = mx.concatenate(
        [m[i % 3] for i, m in enumerate(mx.split(sin, mrope_section, axis=-1))], axis=-1
    )[:, None, :, :]

    cos = mx.repeat(cos[..., : cos.shape[-1] // 2], repeats=2, axis=-1)
    sin = mx.repeat(sin[..., : sin.shape[-1] // 2], repeats=2, axis=-1)

    rotary_dim = cos.shape[-1]
    q_rot = q[..., :rotary_dim]
    q_pass = q[..., rotary_dim:]

    k_rot = k[..., :rotary_dim]
    k_pass = k[..., rotary_dim:]

    q_embed = (q_rot * cos) + (rotate_half_llm(q_rot) * sin)
    k_embed = (k_rot * cos) + (rotate_half_llm(k_rot) * sin)

    q_embed = mx.concatenate([q_embed, q_pass], axis=-1)
    k_embed = mx.concatenate([k_embed, k_pass], axis=-1)

    return q_embed, k_embed

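The mrope_section list is the per-axis channel budget for temporal, height, and width positions. Doubling it and taking the cumulative sum gives the split points over the cos/sin channel dimension; the resulting chunks are then drawn cyclically from the temporal (i % 3 == 0), height (1), and width (2) rows. A small illustration with an assumed section of [16, 24, 24] (hypothetical values, not taken from the packaged config):

import numpy as np

mrope_section = [16, 24, 24]                        # hypothetical t/h/w split
split_points = np.cumsum(mrope_section * 2)[:-1].tolist()
print(split_points)                                 # [16, 40, 64, 80, 104]
# Splitting cos/sin at these points yields 6 chunks that cycle t, h, w, t, h, w.
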
class Glm4vAttention(nn.Module):
    def __init__(self, args: TextConfig):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = getattr(
            args, "head_dim", args.hidden_size // args.num_attention_heads
        )
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.rope_scaling = args.rope_scaling

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
        position_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.n_heads, -1)
        keys = keys.reshape(B, L, self.n_kv_heads, -1)

        queries = queries.transpose(0, 2, 1, 3)
        keys = keys.transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        cos, sin = position_embeddings

        queries, keys = apply_multimodal_rotary_pos_emb(
            queries, keys, cos, sin, self.rope_scaling["mrope_section"]
        )

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        if mask is not None and isinstance(mask, mx.array):
            mask = mask[..., : keys.shape[-2]]

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class Glm4vMLP(nn.Module):
    def __init__(
        self, config: TextConfig, hidden_size: int = None, intermediate_size: int = None
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_size = (
            config.intermediate_size if intermediate_size is None else intermediate_size
        )

        self.gate_up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size * 2, bias=False
        )
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def __call__(self, x):
        x = self.gate_up_proj(x)
        gate, x = mx.split(x, 2, axis=-1)
        return self.down_proj(nn.silu(gate) * x)

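The MLP fuses the gate and up projections into a single gate_up_proj and splits the result, computing down_proj(silu(gate) * up). A minimal standalone sketch with assumed toy sizes (hidden=8, intermediate=16; purely illustrative, not the packaged configuration):

import mlx.core as mx
import mlx.nn as nn

hidden, intermediate = 8, 16                 # assumed toy sizes
gate_up = nn.Linear(hidden, intermediate * 2, bias=False)
down = nn.Linear(intermediate, hidden, bias=False)

x = mx.random.normal((1, 4, hidden))         # (batch, seq, hidden)
gate, up = mx.split(gate_up(x), 2, axis=-1)  # split the fused projection
y = down(nn.silu(gate) * up)
print(y.shape)                               # (1, 4, 8)
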
class Glm4vDecoderLayer(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.self_attn = Glm4vAttention(config)
        self.mlp = Glm4vMLP(config)

        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_self_attn_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_mlp_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
        position_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        r = x

        # Self Attention
        x = self.self_attn(self.input_layernorm(x), mask, cache, position_embeddings)

        x = self.post_self_attn_layernorm(x)
        x = r + x

        # Fully Connected
        r = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = self.post_mlp_layernorm(x)
        x = r + x
        return x


class GLM4VModel(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            Glm4vDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.start_idx = 0
        self.end_idx = len(self.layers)
        self.num_layers = self.end_idx

        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.rotary_emb = GLM4VRotaryEmbedding(config)

    def __call__(
        self,
        inputs: mx.array,
        inputs_embeds: Optional[mx.array] = None,
        cache: Optional[Any] = None,
        mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
    ) -> mx.array:

        if inputs_embeds is None:
            h = self.embed_tokens(inputs)
        else:
            h = inputs_embeds.astype(self.norm.weight.dtype)

        if position_ids is None:
            position_ids = mx.arange(cache[0].offset, cache[0].offset + h.shape[-2])
            position_ids = mx.expand_dims(position_ids, axis=0)
            position_ids = mx.tile(position_ids, (3, 1, 1))

        position_embeddings = self.rotary_emb(h, position_ids)

        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * self.num_layers

        for i in range(self.num_layers):
            h = self.layers[self.start_idx + i](h, mask, cache[i], position_embeddings)

        return self.norm(h)

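When no explicit position_ids are supplied, the same 1-D positions are repeated along a leading axis of size 3 (temporal, height, width), so multimodal RoPE degenerates to ordinary RoPE for text-only decoding. A quick shape check (the offset of 5 is an arbitrary example standing in for cache[0].offset, not a real cache value):

import mlx.core as mx

offset, seq_len = 5, 4                                # arbitrary example values
position_ids = mx.arange(offset, offset + seq_len)    # [5, 6, 7, 8]
position_ids = mx.expand_dims(position_ids, axis=0)   # shape (1, 4)
position_ids = mx.tile(position_ids, (3, 1, 1))       # shape (3, 1, 4), identical t/h/w rows
print(position_ids.shape)
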
class LanguageModel(nn.Module):
    def __init__(self, args: TextConfig, config: ModelConfig = None):
        super().__init__()
        self.args = args
        self.config = config
        self.model_type = args.model_type
        self.model = GLM4VModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        self._rope_deltas = None

    def get_rope_index(
        self,
        input_ids: mx.array,
        image_grid_thw: Optional[mx.array] = None,
        video_grid_thw: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
    ):
        # Calculate RoPE index for image/video tokens
        batch_size, seq_length = input_ids.shape
        position_ids = mx.arange(seq_length, dtype=mx.int32)
        position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (
            image_grid_thw is not None or video_grid_thw is not None
        ):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = mx.ones_like(input_ids)
            position_ids = mx.ones(
                (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
            )
            image_index, video_index = 0, 0
            for i, input_ids in enumerate(total_input_ids):
                input_ids = mx.where(
                    attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
                )
                image_nums, video_nums = 0, 0
                vision_start_indices = mx.sum(
                    mx.where(
                        input_ids == vision_start_token_id,
                        mx.arange(input_ids.shape[0]),
                        mx.zeros_like(input_ids),
                    )
                )
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum().item()
                video_nums = (vision_tokens == video_token_id).sum().item()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st
                    st_idx = (
                        llm_pos_ids_list[-1].max() + 1
                        if len(llm_pos_ids_list) > 0
                        else 0
                    )
                    index = mx.arange(text_len).reshape(1, text_len)
                    index = mx.broadcast_to(index, (3, text_len))
                    index = index + st_idx
                    llm_pos_ids_list.append(index)
                    t_index = mx.arange(llm_grid_t).reshape(
                        llm_grid_t, 1
                    )  # Equivalent to .view(-1, 1)
                    t_index = mx.broadcast_to(
                        t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
                    )  # Equivalent to expand()
                    t_index = t_index.flatten()  # Flattens to 1D

                    h_index = mx.arange(llm_grid_h).reshape(
                        1, llm_grid_h, 1
                    )  # Equivalent to .view(1, -1, 1)
                    h_index = mx.broadcast_to(
                        h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
                    )  # Equivalent to expand()
                    h_index = h_index.flatten()  # Flattens to 1D

                    w_index = mx.arange(llm_grid_w).reshape(
                        1, 1, llm_grid_w
                    )  # Equivalent to .view(1, 1, -1)
                    w_index = mx.broadcast_to(
                        w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
                    )  # Equivalent to expand()
                    w_index = w_index.flatten()  # Flattens to 1D

                    llm_pos_ids_list.append(
                        mx.stack([t_index, h_index, w_index]) + text_len + st_idx
                    )
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
                if st < len(input_tokens):
                    st_idx = (
                        llm_pos_ids_list[-1].max() + 1
                        if len(llm_pos_ids_list) > 0
                        else 0
                    )
                    text_len = len(input_tokens) - st

                    t_index = mx.arange(text_len).reshape(
                        1, text_len
                    )  # Equivalent to .view(1, -1)
                    t_index = mx.broadcast_to(
                        t_index, (3, text_len)
                    )  # Equivalent to expand(3, -1)

                    llm_pos_ids_list.append(t_index + st_idx)

                llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
                mask = mx.array(attention_mask[i] == 1)
                expanded_mask = mx.expand_dims(mask, axis=0)
                expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
                expanded_positions = mx.expand_dims(llm_positions, axis=1)
                new_positions = mx.where(
                    expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
                )
                updated_position_ids = mx.concatenate(
                    [
                        position_ids[:, :i, :],
                        new_positions,
                        position_ids[:, i + 1 :, :],
                    ],
                    axis=1,
                )
                position_ids = updated_position_ids
                mrope_position_deltas.append(
                    llm_positions.max() + 1 - len(total_input_ids[i])
                )
            mrope_position_deltas = mx.array(mrope_position_deltas)[0]
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
                position_ids = mx.where(
                    attention_mask == 0, mx.ones_like(position_ids), position_ids
                )
                position_ids = mx.expand_dims(position_ids[0], axis=0)
                position_ids = mx.tile(position_ids, (3, 1, 1))
                max_position_ids = position_ids.max(0, keepdims=False)[0].max(
                    -1, keepdims=True
                )[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
                position_ids = mx.broadcast_to(
                    position_ids, (3, input_ids.shape[0], input_ids.shape[1])
                )
                mrope_position_deltas = mx.zeros(
                    [input_ids.shape[0], 1],
                    dtype=input_ids.dtype,
                )
            return position_ids, mrope_position_deltas

    def __call__(
        self,
        inputs: mx.array,
        inputs_embeds: Optional[mx.array] = None,
        mask: Optional[mx.array] = None,
        cache=None,
        **kwargs,
    ):

        position_ids = kwargs.pop("position_ids", None)
        pixel_values = kwargs.pop("pixel_values", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)
        # Reset rope_deltas when processing a new image/video
        if pixel_values is not None:
            self._rope_deltas = None

        cache_offset = 0
        if cache and cache[0] is not None:
            offset = cache[0].offset
            if isinstance(offset, int):
                cache_offset = offset
            elif isinstance(offset, mx.array):
                cache_offset = (offset if offset.ndim == 0 else offset[0]).item()
            else:
                raise ValueError(f"Unexpected cache offset type: {type(offset)}")

        if position_ids is None and (mask is None or mask.ndim == 2):
            # Calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache is not None and cache[0] is not None and (cache_offset == 0))
                or self._rope_deltas is None
                or cache is None
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    inputs, image_grid_thw, video_grid_thw, mask
                )
                self._rope_deltas = rope_deltas
            else:
                # Use the previously calculated rope-deltas to get the correct position ids
                batch_size, seq_length = inputs.shape
                delta = mx.array(
                    cache_offset + self._rope_deltas if cache is not None else 0
                )
                position_ids = mx.arange(seq_length).reshape(1, -1)
                position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))

                if cache_offset is not None:
                    if delta.ndim == 0:
                        delta = mx.expand_dims(delta, axis=0)

                    if delta.shape[0] < batch_size:
                        delta = mx.tile(delta, (batch_size, 1))
                    else:
                        # Slice delta to match batch
                        delta = delta[:batch_size]

                position_ids = mx.add(position_ids, delta)[None, ...]
                position_ids = mx.broadcast_to(
                    position_ids, (3, batch_size, seq_length)
                )

        out = self.model(
            inputs, cache=cache, inputs_embeds=inputs_embeds, position_ids=position_ids
        )

        out = self.lm_head(out)
        return LanguageModelOutput(logits=out)

    def sanitize(self, weights):
        return weights

    @property
    def layers(self):
        return self.model.layers

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
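In practice this LanguageModel is driven by the glm4v model wrapper and the shared mlx_vlm generation utilities rather than called directly. A hypothetical usage sketch, assuming this package keeps the upstream mlx-vlm style Python API (load, generate, apply_chat_template) and that an MLX-converted GLM-4V checkpoint exists at the placeholder path; the names, paths, and signatures below are assumptions, not taken from this diff:

from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template

# Placeholder path; any MLX-converted GLM-4V checkpoint directory would stand in here.
model, processor = load("path/to/glm4v-mlx")
prompt = apply_chat_template(processor, model.config, "Describe this image.", num_images=1)
output = generate(model, processor, prompt, ["example.jpg"], verbose=False)
print(output)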