fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
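
The package layout above mirrors upstream mlx-vlm (generate.py, prompt_utils.py, utils.py, and a models/ tree). Assuming this wheel keeps that project's public entry points — load, generate, and apply_chat_template are assumptions based on the upstream layout, not something this diff confirms — a typical invocation of the newly added HunYuan-VL model might look like the sketch below.

# Hypothetical usage sketch. Assumes this wheel re-exports the upstream
# mlx-vlm API (load / generate / apply_chat_template); exact signatures vary
# between releases, so treat this as illustrative rather than authoritative.
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model_path = "org/hunyuan-vl-mlx"  # placeholder model id, not a real checkpoint
model, processor = load(model_path)
config = load_config(model_path)

prompt = apply_chat_template(processor, config, "Describe this image.", num_images=1)
output = generate(model, processor, prompt, ["example.jpg"], max_tokens=128, verbose=False)
print(output)
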
mlx_vlm/models/hunyuan_vl/hunyuan_vl.py
@@ -0,0 +1,181 @@
+ from typing import Dict, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import InputEmbeddingsFeatures, check_array_shape
+ from .config import ModelConfig
+ from .language import LanguageModel
+ from .vision import VisionModel
+
+ try:
+     from transformers import AutoImageProcessor, AutoProcessor
+
+     from .processing_hunyuan_vl import HunYuanVLImageProcessor, HunYuanVLProcessor
+
+     MODEL_TYPE = "hunyuan_vl"
+
+     AutoImageProcessor.register(
+         MODEL_TYPE, slow_image_processor_class=HunYuanVLImageProcessor
+     )
+     AutoProcessor.register(MODEL_TYPE, HunYuanVLProcessor)
+
+ except Exception as e:
+     raise e
+
+
+ class Model(nn.Module):
+
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         self.vision_tower = VisionModel(config.vision_config)
+         self.language_model = LanguageModel(config)
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ) -> mx.array:
+
+         image_grid_thw = kwargs.pop("image_grid_thw", None)
+
+         position_ids_from_processor = kwargs.pop("position_ids", None)
+
+         # Get text embeddings
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         # If no image, return text embeddings
+         if pixel_values is None:
+             # Reset stored position_ids when no image
+             self.language_model._position_ids = None
+             return InputEmbeddingsFeatures(inputs_embeds=inputs_embeds)
+
+         # Get vision features
+         vision_features = self.vision_tower(pixel_values, image_grid_thw)
+
+         # Find image token positions and replace with vision features
+         image_token_id = self.config.image_token_id
+         image_mask = input_ids == image_token_id
+
+         # Get number of image tokens expected
+         num_image_tokens = image_mask.sum().item()
+         num_vision_tokens = vision_features.shape[1]
+
+         if num_image_tokens != num_vision_tokens:
+             raise ValueError(
+                 f"Number of image placeholders ({num_image_tokens}) does not match "
+                 f"number of vision tokens ({num_vision_tokens}). "
+                 f"Expected token count based on grid: {num_vision_tokens}"
+             )
+
+         B, L, _ = inputs_embeds.shape
+
+         output_parts = []
+
+         for b in range(B):
+             mask_b = image_mask[b]  # (L,) boolean mask
+             text_embeds_b = inputs_embeds[b]  # (L, D)
+             vis_feats_b = vision_features[b]  # (num_vis_tokens, D)
+
+             # Build sequence for this batch
+             vis_idx = 0
+             seq_parts = []
+             for pos in range(L):
+                 if mask_b[pos].item():
+                     # Use vision feature
+                     seq_parts.append(vis_feats_b[vis_idx : vis_idx + 1])
+                     vis_idx += 1
+                 else:
+                     # Use text embedding
+                     seq_parts.append(text_embeds_b[pos : pos + 1])
+
+             # Concatenate all parts for this batch
+             batch_embeds = mx.concatenate(seq_parts, axis=0)  # (L, D)
+             output_parts.append(batch_embeds[None, :, :])  # (1, L, D)
+
+         # Stack batches
+         inputs_embeds = mx.concatenate(output_parts, axis=0)  # (B, L, D)
+
+         # Pre-calculate position_ids for chunked prefill
+         if position_ids_from_processor is not None:
+             self.language_model._position_ids = position_ids_from_processor
+         elif image_grid_thw is not None:
+             position_ids = self.language_model.get_xdrope_input_positions(
+                 input_tokens=input_ids[0].tolist(),
+                 image_grid_thw=image_grid_thw,
+                 image_token_id=self.config.image_token_id,
+                 spatial_merge_size=self.config.vision_config.spatial_merge_size,
+             )[None, ...]
+             self.language_model._position_ids = position_ids
+
+         return InputEmbeddingsFeatures(inputs_embeds=inputs_embeds)
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     @property
+     def head_dim(self):
+         return self.config.text_config.head_dim
+
+     @property
+     def n_kv_heads(self):
+         return self.config.text_config.num_key_value_heads
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+
+         # Get embeddings (with vision features merged if image provided)
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids=input_ids,
+             pixel_values=pixel_values,
+             **kwargs,
+         )
+
+         # Forward through language model
+         return self.language_model(
+             inputs=input_ids,
+             inputs_embeds=input_embeddings_features.inputs_embeds,
+             mask=mask,
+             cache=cache,
+             image_grid_thw=kwargs.get("image_grid_thw", None),
+         )
+
+     def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
+
+         sanitized = {}
+
+         for key, value in weights.items():
+             new_key = key
+
+             # Language model mappings
+             if key.startswith("model."):
+                 new_key = "language_model." + key
+
+             # Vision tower mappings
+             elif key.startswith("vit."):
+                 new_key = key.replace("vit.", "vision_tower.", 1)
+
+             # Handle Conv2d weight transposition for MLX
+             # PyTorch Conv2d: [out_channels, in_channels, kH, kW]
+             # MLX Conv2d: [out_channels, kH, kW, in_channels]
+             if (
+                 "patch_embedding.weight" in new_key
+                 or "proj.0.weight" in new_key
+                 or "proj.2.weight" in new_key
+             ):
+                 if not check_array_shape(value):
+                     value = value.transpose(0, 2, 3, 1)
+
+             sanitized[new_key] = value
+
+         return sanitized
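
Model.sanitize() above only renames checkpoint keys and moves PyTorch's channels-second Conv2d kernels into MLX's channels-last layout. A minimal standalone sketch of the same remapping follows; check_array_shape (imported from ..base in the real code) is omitted here, and the tensor shape is made up.

import mlx.core as mx

# Standalone sketch of the remapping done by Model.sanitize(); the real code
# additionally guards the transpose with check_array_shape().
def remap_weights(weights: dict) -> dict:
    out = {}
    for key, value in weights.items():
        if key.startswith("model."):
            key = "language_model." + key  # text backbone lives under language_model.model.*
        elif key.startswith("vit."):
            key = key.replace("vit.", "vision_tower.", 1)
        # PyTorch Conv2d stores [out, in, kH, kW]; MLX expects [out, kH, kW, in].
        if "patch_embedding.weight" in key and value.ndim == 4:
            value = value.transpose(0, 2, 3, 1)
        out[key] = value
    return out

# Fake patch-embedding kernel: 3 input channels, 14x14 patches (shapes are illustrative).
w = {"vit.patch_embedding.weight": mx.zeros((1152, 3, 14, 14))}
print(remap_weights(w)["vision_tower.patch_embedding.weight"].shape)  # (1152, 14, 14, 3)
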
mlx_vlm/models/hunyuan_vl/language.py
@@ -0,0 +1,509 @@
+ from typing import List, Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache
+ from .config import ModelConfig, TextConfig
+
+
+ class HunyuanRotaryEmbedding:
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.dim = config.head_dim
+         self.max_position_embeddings = config.max_position_embeddings
+         self.base = config.rope_theta
+
+         # Handle xdrope/dynamic scaling (guard against rope_scaling being absent)
+         rope_scaling = config.rope_scaling or {}
+         self.xdrope_section = rope_scaling.get("xdrope_section")
+         self.rope_type = rope_scaling.get("type")
+         alpha = rope_scaling.get("alpha")
+
+         if self.rope_type in ["xdrope", "dynamic"]:
+             if alpha:
+                 self.base = self.base * (alpha ** (self.dim / (self.dim - 2)))
+
+         inv_freq = 1.0 / (
+             self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+         )
+         self._inv_freq = inv_freq
+         self._cos_cached = None
+         self._sin_cached = None
+         self._cached_seq_len = 0
+
+     def _update_cache(self, seq_len: int, dtype: mx.Dtype):
+         if seq_len > self._cached_seq_len:
+             self._cached_seq_len = seq_len
+             t = mx.arange(seq_len, dtype=mx.float32)
+             freqs = mx.outer(t, self._inv_freq)
+             emb = mx.concatenate([freqs, freqs], axis=-1)
+             self._cos_cached = mx.cos(emb).astype(dtype)
+             self._sin_cached = mx.sin(emb).astype(dtype)
+
+     def __call__(self, x: mx.array, seq_len: int) -> Tuple[mx.array, mx.array]:
+         self._update_cache(seq_len, x.dtype)
+         return self._cos_cached[:seq_len], self._sin_cached[:seq_len]
+
+
+ def rotate_half(x: mx.array) -> mx.array:
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_rotary_pos_emb_xdrope(
+     q: mx.array,
+     k: mx.array,
+     cos: mx.array,
+     sin: mx.array,
+     position_ids: mx.array,
+     xdrope_section: list,
+     output_size: tuple,
+ ) -> Tuple[mx.array, mx.array]:
+     """Applies XD Rotary Position Embedding."""
+
+     x_dim = len(xdrope_section)
+     cos = (
+         cos[position_ids, ...]
+         .transpose(0, 2, 1, 3)
+         .reshape(output_size[0], output_size[2], x_dim, -1)
+     )
+     sin = (
+         sin[position_ids, ...]
+         .transpose(0, 2, 1, 3)
+         .reshape(output_size[0], output_size[2], x_dim, -1)
+     )
+
+     xdrope_section = xdrope_section * 2
+
+     # for xd concat
+     assert sum(xdrope_section) == cos.shape[-1], "Illegal partition for xd rope"
+
+     # Convert split sizes to split indices for MLX
+     split_indices = [
+         sum(xdrope_section[: i + 1]) for i in range(len(xdrope_section) - 1)
+     ]
+     cos_splits = mx.split(cos, split_indices, axis=-1)
+     sin_splits = mx.split(sin, split_indices, axis=-1)
+
+     cos = mx.concatenate(
+         [m[:, :, i % x_dim, :] for i, m in enumerate(cos_splits)], axis=-1
+     )
+     sin = mx.concatenate(
+         [m[:, :, i % x_dim, :] for i, m in enumerate(sin_splits)], axis=-1
+     )
+
+     # for head repeat
+     cos = cos.reshape(output_size[0], 1, output_size[2], -1)
+     sin = sin.reshape(output_size[0], 1, output_size[2], -1)
+
+     origin_dtype = q.dtype
+     q, k = q.astype(mx.float32), k.astype(mx.float32)
+     cos, sin = cos.astype(mx.float32), sin.astype(mx.float32)
+
+     q_out = (q * cos) + (rotate_half(q) * sin)
+     k_out = (k * cos) + (rotate_half(k) * sin)
+
+     return q_out.astype(origin_dtype), k_out.astype(origin_dtype)
+
+
+ def apply_rotary_pos_emb(
+     q: mx.array, k: mx.array, cos: mx.array, sin: mx.array, unsqueeze_dim: int = 1
+ ) -> Tuple[mx.array, mx.array]:
+     """Standard rotary position embedding.
+
+     Args:
+         q: Queries with shape (batch, n_heads, seq_len, head_dim)
+         k: Keys with shape (batch, n_heads, seq_len, head_dim)
+         cos: Cosine values with shape (seq_len, head_dim)
+         sin: Sine values with shape (seq_len, head_dim)
+     """
+     # Expand cos/sin to (1, 1, seq_len, head_dim) for broadcasting
+     cos = cos[None, None, :, :]
+     sin = sin[None, None, :, :]
+
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+
+     return q_embed, k_embed
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+
+         self.hidden_size = config.hidden_size
+         self.n_heads = config.num_attention_heads
+         self.n_kv_heads = config.num_key_value_heads
+         self.head_dim = config.head_dim
+         self.scale = self.head_dim**-0.5
+
+         self.q_proj = nn.Linear(
+             self.hidden_size, self.n_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.k_proj = nn.Linear(
+             self.hidden_size,
+             self.n_kv_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.v_proj = nn.Linear(
+             self.hidden_size,
+             self.n_kv_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.o_proj = nn.Linear(
+             self.n_heads * self.head_dim,
+             config.hidden_size,
+             bias=config.attention_bias,
+         )
+
+         if config.use_qk_norm:
+             self.query_layernorm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+             self.key_layernorm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+         self.rotary_emb = HunyuanRotaryEmbedding(config=config)
+
+         self.xdrope_section = None
+         if config.rope_scaling is not None:
+             self.xdrope_section = config.rope_scaling.get("xdrope_section")
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         B, L, _ = x.shape
+
+         # Project Q, K, V
+         queries = self.q_proj(x)
+         keys = self.k_proj(x)
+         values = self.v_proj(x)
+
+         # Reshape to (B, n_heads, L, head_dim)
+         queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+         keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+
+         kv_seq_len = L
+         offset = 0
+         if cache is not None:
+             offset = cache.offset
+             kv_seq_len += offset
+
+         cos, sin = self.rotary_emb(values, seq_len=kv_seq_len)
+
+         # Apply rotary embeddings
+         if self.xdrope_section is not None and (cache is None or offset == 0):
+             # XD RoPE for prefill (first forward pass)
+             output_size = (B, self.n_heads, L, L)
+             queries, keys = apply_rotary_pos_emb_xdrope(
+                 queries,
+                 keys,
+                 cos,
+                 sin,
+                 position_ids,
+                 self.xdrope_section,
+                 output_size,
+             )
+         else:
+             # Standard RoPE for decode (subsequent tokens)
+             if cache is not None and offset > 0:
+                 cos = cos[-L:]
+                 sin = sin[-L:]
+             queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
+
+         # Apply QK normalization if configured
+         if self.config.use_qk_norm:
+             queries = self.query_layernorm(queries)
+             keys = self.key_layernorm(keys)
+
+         # Update cache
+         if cache is not None:
+             keys, values = cache.update_and_fetch(keys, values)
+
+         # Apply mask
+         if mask is not None and isinstance(mask, mx.array):
+             mask = mask[..., : keys.shape[-2]]
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache=cache, scale=self.scale, mask=mask
+         )
+
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+
+         self.gate_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size, bias=config.mlp_bias
+         )
+         self.up_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size, bias=config.mlp_bias
+         )
+         self.down_proj = nn.Linear(
+             self.intermediate_size, self.hidden_size, bias=config.mlp_bias
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class DecoderLayer(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = Attention(config)
+         self.mlp = MLP(config)
+         self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         # Self-attention with residual
+         r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids)
+         h = x + r
+
+         # MLP with residual
+         r = self.mlp(self.post_attention_layernorm(h))
+         out = h + r
+
+         return out
+
+
+ class HunyuanModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.num_hidden_layers = config.num_hidden_layers
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = [DecoderLayer(config) for _ in range(config.num_hidden_layers)]
+         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def __call__(
+         self,
+         input_ids: Optional[mx.array] = None,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+
+         if inputs_embeds is None:
+             h = self.embed_tokens(input_ids)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c, position_ids)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: ModelConfig = None):
+         super().__init__()
+         self.args = config.text_config
+         self.config = config
+         self.model_type = self.args.model_type
+         self.model = HunyuanModel(self.args)
+         self._position_ids = None
+
+         if not self.args.tie_word_embeddings:
+             self.lm_head = nn.Linear(
+                 self.args.hidden_size, self.args.vocab_size, bias=False
+             )
+
+     def get_xdrope_input_positions(
+         self,
+         input_tokens: List[int],
+         image_grid_thw: Optional[mx.array],
+         image_token_id: int,
+         spatial_merge_size: int,
+     ) -> mx.array:
+         """Compute XD-RoPE position IDs for image-text interleaved inputs."""
+
+         xd_num = len(self.args.rope_scaling["xdrope_section"])
+
+         input_tokens_arr = np.array(input_tokens)
+         image_start_indices = np.where(input_tokens_arr == image_token_id)[0].tolist()
+
+         seq_len = len(input_tokens)
+         p_index = np.arange(seq_len)
+         w_index = np.arange(seq_len)
+         h_index = np.arange(seq_len)
+         t_index = np.arange(seq_len)
+
+         # Process image positions if we have images
+         if image_grid_thw is not None and len(image_start_indices) > 0:
+             for image_index in range(len(image_start_indices)):
+                 # +1: start filling grid positions right after the image start token
+                 pos = int(image_start_indices[image_index]) + 1
+                 _, h, w = image_grid_thw.flatten().tolist()
+
+                 llm_grid_h = h // spatial_merge_size
+                 llm_grid_w = w // spatial_merge_size
+
+                 token_num = (llm_grid_w + 1) * llm_grid_h
+
+                 # Ensure we don't go out of bounds
+                 end_pos = min(pos + token_num, seq_len)
+                 actual_token_num = end_pos - pos
+
+                 if actual_token_num > 0:
+                     # w_index: [0, 1, ..., grid_w, 0, 1, ..., grid_w, ...] repeated for each row
+                     w_pattern = np.tile(np.arange(llm_grid_w + 1), llm_grid_h)[
+                         :actual_token_num
+                     ]
+                     w_index[pos:end_pos] = w_pattern
+
+                     # h_index: [0, 0, ..., 0, 1, 1, ..., 1, ...] each repeated (grid_w + 1) times
+                     h_pattern = np.repeat(np.arange(llm_grid_h), llm_grid_w + 1)[
+                         :actual_token_num
+                     ]
+                     h_index[pos:end_pos] = h_pattern
+
+                     # t_index: image index for temporal dimension
+                     t_index[pos:end_pos] = image_index
+
+         # Stack based on number of xdrope dimensions
+         if xd_num == 4:
+             llm_positions = mx.stack(
+                 [
+                     mx.array(p_index),
+                     mx.array(t_index),
+                     mx.array(h_index),
+                     mx.array(w_index),
+                 ]
+             )
+         elif xd_num == 3:
+             llm_positions = mx.stack(
+                 [
+                     mx.array(t_index),
+                     mx.array(h_index),
+                     mx.array(w_index),
+                 ]
+             )
+         else:
+             # Fallback: just use sequential positions
+             llm_positions = mx.stack([mx.array(p_index)] * xd_num)
+
+         return llm_positions
+
+     def __call__(
+         self,
+         inputs: Optional[mx.array] = None,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ) -> LanguageModelOutput:
+
+         kwargs_position_ids = kwargs.pop("position_ids", None)
+
+         # Compute cache offset
+         cache_offset = 0
+         if cache is not None and cache[0] is not None:
+             offset = cache[0].offset
+             if isinstance(offset, int):
+                 cache_offset = offset
+             elif isinstance(offset, mx.array):
+                 cache_offset = (offset if offset.ndim == 0 else offset[0]).item()
+             else:
+                 cache_offset = int(offset)
+
+         # Determine sequence length from inputs or inputs_embeds
+         if inputs_embeds is not None:
+             seq_length = inputs_embeds.shape[1]
+         elif inputs is not None:
+             seq_length = inputs.shape[1]
+         else:
+             seq_length = 0
+
+         position_ids = None
+         if cache is None or cache_offset == 0:
+             # Prefill phase - need xdrope position_ids
+             if self._position_ids is not None:
+                 # Use stored position_ids (sliced for chunked prefill)
+                 position_ids = self._position_ids[
+                     :, :, cache_offset : cache_offset + seq_length
+                 ]
+             elif kwargs_position_ids is not None:
+                 # Use position_ids from kwargs (e.g., from processor)
+                 if not isinstance(kwargs_position_ids, mx.array):
+                     kwargs_position_ids = mx.array(kwargs_position_ids)
+                 # Store for potential future chunks and slice for current chunk
+                 self._position_ids = kwargs_position_ids
+                 position_ids = self._position_ids[
+                     :, :, cache_offset : cache_offset + seq_length
+                 ]
+             elif inputs is not None:
+                 # Compute position_ids on the fly (for non-chunked prefill)
+                 position_ids = self.get_xdrope_input_positions(
+                     input_tokens=inputs[0].tolist(),
+                     image_grid_thw=kwargs.get("image_grid_thw", None),
+                     image_token_id=self.config.image_token_id,
+                     spatial_merge_size=self.config.vision_config.spatial_merge_size,
+                 )[None, ...]
+                 # Store for potential future chunks
+                 self._position_ids = position_ids
+
+         out = self.model(
+             input_ids=inputs,
+             inputs_embeds=inputs_embeds,
+             mask=mask,
+             cache=cache,
+             position_ids=position_ids,
+         )
+
+         if self.args.tie_word_embeddings:
+             logits = self.model.embed_tokens.as_linear(out)
+         else:
+             logits = self.lm_head(out)
+
+         return LanguageModelOutput(logits=logits)
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.args.head_dim
+
+     @property
+     def n_kv_heads(self):
+         return self.args.num_key_value_heads
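
For intuition on the positions built by get_xdrope_input_positions above: every token carries one index per RoPE axis (sequence, image/temporal, row, column); text tokens keep sequential values on all axes, while image tokens get grid coordinates on the t/h/w axes. A tiny standalone NumPy sketch with made-up sizes (4 text tokens followed by one 2x3 image grid plus a per-row separator column):

import numpy as np

# Toy illustration of the XD-RoPE index layout; grid sizes and token counts are made up.
seq_len = 12                      # 4 text tokens + 8 image tokens (2 rows x (3 cols + 1 separator))
llm_grid_h, llm_grid_w = 2, 3
pos = 4                           # image tokens start right after the 4 text tokens

p_index = np.arange(seq_len)      # plain sequential positions
t_index = np.arange(seq_len)
h_index = np.arange(seq_len)
w_index = np.arange(seq_len)

token_num = (llm_grid_w + 1) * llm_grid_h
end = min(pos + token_num, seq_len)
n = end - pos

# Column index cycles 0..grid_w (plus separator) within each row; row index repeats per row.
w_index[pos:end] = np.tile(np.arange(llm_grid_w + 1), llm_grid_h)[:n]
h_index[pos:end] = np.repeat(np.arange(llm_grid_h), llm_grid_w + 1)[:n]
t_index[pos:end] = 0              # single image -> temporal index 0

# Rows: p, t, h, w -- this 4xL array matches the xd_num == 4 stacking above.
print(np.stack([p_index, t_index, h_index, w_index]))
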