fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
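
The file list above mirrors the upstream mlx-vlm layout: a top-level mlx_vlm package with per-model subpackages, a server, trainers, and eval scripts. The sketch below is a hypothetical quick-start for the installed wheel, not taken from this package's documentation; the load/generate/apply_chat_template names follow upstream mlx-vlm's public API, and the model repo id, keyword arguments, and output handling are assumptions that may differ in this fork.

```python
# Hypothetical quick-start (assumptions: upstream-style API, example repo id).
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template

# Assumed model path; any mlx-vlm compatible checkpoint should work similarly.
model, processor = load("mlx-community/Kimi-VL-A3B-Instruct-4bit")
prompt = apply_chat_template(
    processor, model.config, "Describe this image.", num_images=1
)
result = generate(model, processor, prompt, image=["example.jpg"], max_tokens=128)
print(result)
```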
mlx_vlm/models/kimi_vl/config.py
@@ -0,0 +1,84 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: "TextConfig"
+    vision_config: "VisionConfig"
+    model_type: str
+    ignore_index: int = -100
+    vocab_size: int = 128259
+    scale_factor: int = 2
+    media_placeholder_token_id: int = 163606
+    image_token_index: Optional[int] = None
+    eos_token_id: Optional[List[int]] = None
+
+    def __post_init__(self):
+        if self.image_token_index is None:
+            self.image_token_index = self.media_placeholder_token_id
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str = "deepseek_v3"
+    vocab_size: int = 102400
+    hidden_size: int = 4096
+    intermediate_size: int = 11008
+    moe_intermediate_size: int = 1407
+    num_hidden_layers: int = 30
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 32
+    n_shared_experts: Optional[int] = None
+    n_routed_experts: Optional[int] = None
+    routed_scaling_factor: float = 1.0
+    kv_lora_rank: int = 512
+    q_lora_rank: int = 1536
+    qk_rope_head_dim: int = 64
+    v_head_dim: int = 128
+    qk_nope_head_dim: int = 128
+    topk_method: str = "noaux_tc"
+    scoring_func: str = "sigmoid"
+    norm_topk_prob: bool = True
+    n_group: Optional[int] = None
+    topk_group: Optional[int] = None
+    num_experts_per_tok: Optional[int] = None
+    moe_layer_freq: int = 1
+    first_k_dense_replace: int = 0
+    max_position_embeddings: int = 2048
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 10000.0
+    rope_scaling: Dict = None
+    attention_bias: bool = False
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    model_type: str = "moonvit"
+    depth: int = 27
+    embed_dim: int = 1152
+    hidden_size: int = 1152
+    num_heads: int = 16
+    image_size: int = 384
+    patch_size: int = 14
+    vocab_size: int = 32000
+    mlp_ratio: float = 4.0
+    num_channels: int = 3
+    layer_norm_eps: float = 1e-6
+    intermediate_size: int = 4304
+    init_pos_emb_height: int = 64
+    init_pos_emb_width: int = 64
+    spatial_patch_size: int = 14
+    spatial_merge_size: int = 2
+    temporal_patch_size: int = 2
+    merge_kernel_size: list[int, int] = None
+
+    def __post_init__(self):
+        if self.merge_kernel_size is None:
+            self.merge_kernel_size = (self.spatial_merge_size, self.spatial_merge_size)
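
A minimal usage sketch (not part of the diff) showing how the three dataclasses above compose and which defaults __post_init__ derives; it assumes the wheel is installed so the mlx_vlm.models.kimi_vl.config import path resolves.

```python
# Minimal sketch (assumption: package installed): composing the kimi_vl config
# dataclasses and checking the defaults derived in __post_init__.
from mlx_vlm.models.kimi_vl.config import ModelConfig, TextConfig, VisionConfig

config = ModelConfig(
    text_config=TextConfig(),      # DeepSeek-V3 style language backbone
    vision_config=VisionConfig(),  # MoonViT vision tower
    model_type="kimi_vl",
)

# image_token_index falls back to media_placeholder_token_id (163606)
print(config.image_token_index)
# merge_kernel_size defaults to (spatial_merge_size, spatial_merge_size) == (2, 2)
print(config.vision_config.merge_kernel_size)
```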
mlx_vlm/models/kimi_vl/kimi_vl.py
@@ -0,0 +1,127 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from transformers import AutoImageProcessor, AutoProcessor
+
+from ..base import InputEmbeddingsFeatures
+from .config import ModelConfig
+from .language import LanguageModel
+from .processing_kimi_vl import KimiVLImageProcessor, KimiVLProcessor
+from .vision import VisionModel
+
+# Register custom processor classes for kimi_vl model type
+try:
+    MODEL_TYPE = "kimi_vl"
+    AutoImageProcessor.register(
+        MODEL_TYPE, slow_image_processor_class=KimiVLImageProcessor
+    )
+    AutoProcessor.register(MODEL_TYPE, KimiVLProcessor)
+except Exception:
+    raise Exception("Failed to register kimi_vl processor")
+
+
+class KimiVLMultiModalProjector(nn.Module):
+
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+
+        self.hidden_size = (
+            config.vision_config.hidden_size
+            * config.vision_config.merge_kernel_size[0]
+            * config.vision_config.merge_kernel_size[1]
+        )
+
+        self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=1e-05)
+        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.act = nn.GELU()
+        self.linear_2 = nn.Linear(
+            self.hidden_size, config.text_config.hidden_size, bias=True
+        )
+
+    def __call__(self, image_features: list[mx.array]) -> mx.array:
+        image_features = mx.concatenate(image_features, axis=0)
+        h = self.pre_norm(image_features).reshape(-1, self.hidden_size)
+        h = self.linear_1(h)
+        h = self.act(h)
+        h = self.linear_2(h)
+        return h
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        self.config = config
+
+        self.vision_tower = VisionModel(config.vision_config)
+        self.language_model = LanguageModel(config.text_config)
+        self.multi_modal_projector = KimiVLMultiModalProjector(config)
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        image_grid_thw = kwargs.pop("image_grid_hws", None)
+        video_grid_thw = kwargs.pop("video_grid_hws", None)
+        grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw
+
+        if pixel_values is None:
+            inputs_embeds = self.language_model.embed_tokens(input_ids)
+            return InputEmbeddingsFeatures(inputs_embeds=inputs_embeds)
+
+        inputs_embeds = self.language_model.embed_tokens(input_ids)
+
+        hidden_state = self.vision_tower(
+            pixel_values.transpose(0, 2, 3, 1),
+            output_hidden_states=True,
+            grid_thw=grid_thw,
+        )
+
+        image_features = self.multi_modal_projector(hidden_state)
+
+        final_inputs_embeds = self._prepare_inputs_for_multimodal(
+            image_features, inputs_embeds, input_ids
+        )
+        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+    def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
+        image_token_index = self.config.image_token_index
+
+        # Positions of <image> tokens in input_ids, assuming batch size is 1
+        image_positions = np.where(input_ids == image_token_index)[1].tolist()
+
+        inputs_embeds[:, image_positions, :] = image_features
+
+        return inputs_embeds
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        cache=None,
+        **kwargs,
+    ):
+
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, **kwargs
+        )
+        logits = self.language_model(
+            inputs=input_ids,
+            cache=cache,
+            inputs_embeds=input_embeddings_features.inputs_embeds,
+        )
+        return logits
+
+    def sanitize(self, weights):
+        return {
+            k.replace("encoder.", "") if "vision_tower" in k else k: v
+            for k, v in weights.items()
+        }
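
The projector above flattens each merge_kernel_size group of MoonViT patch features into one vector before projecting into the language model's embedding space. The standalone sketch below (not part of the package) re-implements only that shape contract with plain mlx layers to make the dimensions concrete; the sizes come from the default VisionConfig and TextConfig values in config.py.

```python
# Standalone sketch (assumption: default kimi_vl config sizes). Each 2x2 merge
# group of 1152-dim MoonViT patch features is flattened to 1152 * 2 * 2 = 4608
# values and projected to the 4096-dim text embedding space.
import mlx.core as mx
import mlx.nn as nn

vision_hidden, merge, text_hidden = 1152, (2, 2), 4096
proj_in = vision_hidden * merge[0] * merge[1]  # 4608

pre_norm = nn.LayerNorm(vision_hidden, eps=1e-5)
linear_1 = nn.Linear(proj_in, proj_in)
linear_2 = nn.Linear(proj_in, text_hidden)

# 64 merged groups, each built from 4 neighbouring patch features
features = mx.zeros((64 * 4, vision_hidden))
h = pre_norm(features).reshape(-1, proj_in)  # (64, 4608)
out = linear_2(nn.gelu(linear_1(h)))         # (64, 4096)
print(out.shape)
```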
mlx_vlm/models/kimi_vl/language.py
@@ -0,0 +1,460 @@
+import math
+from functools import partial
+from typing import Any, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.switch_layers import SwitchGLU
+
+from ..base import (
+    LanguageModelOutput,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from .config import TextConfig
+
+
+def yarn_find_correction_dim(
+    num_rotations, dim, base=10000, max_position_embeddings=2048
+):
+    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+        2 * math.log(base)
+    )
+
+
+def yarn_find_correction_range(
+    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
+):
+    low = math.floor(
+        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
+    )
+    high = math.ceil(
+        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
+    )
+    return max(low, 0), min(high, dim - 1)
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def yarn_linear_ramp_mask(min_val, max_val, dim):
+    if min_val == max_val:
+        max_val += 0.001  # Prevent singularity
+
+    linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
+    return mx.clip(linear_func, 0, 1)
+
+
+class DeepseekV3YarnRotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        super().__init__()
+        self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
+            scaling_factor, mscale_all_dim
+        )
+        freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
+        freq_inter = scaling_factor * freq_extra
+        low, high = yarn_find_correction_range(
+            beta_fast,
+            beta_slow,
+            dim,
+            base,
+            original_max_position_embeddings,
+        )
+        freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
+        self._freqs = (freq_inter * freq_extra) / (
+            freq_inter * freq_mask + freq_extra * (1 - freq_mask)
+        )
+
+    def __call__(self, x, offset=0):
+        if self.mscale != 1.0:
+            x = self.mscale * x
+        return mx.fast.rope(
+            x,
+            x.shape[-1],
+            traditional=True,
+            base=None,
+            scale=1.0,
+            offset=offset,
+            freqs=self._freqs,
+        )
+
+
+# A clipped silu to prevent fp16 from overflowing
+@partial(mx.compile, shapeless=True)
+def clipped_silu(x, gate):
+    return mx.clip(gate * mx.sigmoid(gate), a_min=-100, a_max=100) * x
+
+
+class DeepseekV3Attention(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.q_lora_rank = config.q_lora_rank
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.kv_lora_rank = config.kv_lora_rank
+        self.v_head_dim = config.v_head_dim
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+
+        self.scale = self.q_head_dim**-0.5
+
+        if self.q_lora_rank is None:
+            self.q_proj = nn.Linear(
+                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
+            )
+        else:
+            self.q_a_proj = nn.Linear(
+                self.hidden_size, self.q_lora_rank, bias=config.attention_bias
+            )
+            self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank, eps=1e-6)
+            self.q_b_proj = nn.Linear(
+                self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
+            )
+
+        self.kv_a_proj_with_mqa = nn.Linear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=config.attention_bias,
+        )
+        self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank, eps=1e-6)
+        self.kv_b_proj = nn.Linear(
+            self.kv_lora_rank,
+            self.num_heads
+            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
+            bias=False,
+        )
+
+        self.o_proj = nn.Linear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=config.attention_bias,
+        )
+
+        if self.config.rope_scaling is not None:
+            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+            scaling_factor = self.config.rope_scaling["factor"]
+            if mscale_all_dim:
+                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+                self.scale = self.scale * mscale * mscale
+
+            rope_kwargs = {
+                key: self.config.rope_scaling[key]
+                for key in [
+                    "original_max_position_embeddings",
+                    "beta_fast",
+                    "beta_slow",
+                    "mscale",
+                    "mscale_all_dim",
+                ]
+                if key in self.config.rope_scaling
+            }
+            self.rope = DeepseekV3YarnRotaryEmbedding(
+                dim=self.qk_rope_head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                scaling_factor=scaling_factor,
+                base=self.rope_theta,
+                **rope_kwargs,
+            )
+        else:
+            self.rope = nn.RoPE(
+                dims=self.qk_rope_head_dim,
+                base=self.rope_theta,
+                traditional=True,
+            )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(x)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
+
+        q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
+        q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
+        compressed_kv = self.kv_a_proj_with_mqa(x)
+        compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
+        k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
+        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+        kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+
+        k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
+
+        if cache is not None:
+            q_pe = self.rope(q_pe, cache.offset)
+            k_pe = self.rope(k_pe, cache.offset)
+            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
+            keys, values = cache.update_and_fetch(
+                mx.concatenate([k_nope, k_pe], axis=-1), values
+            )
+        else:
+            q_pe = self.rope(q_pe)
+            k_pe = self.rope(k_pe)
+            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
+            keys = mx.concatenate([k_nope, k_pe], axis=-1)
+
+        queries = mx.concatenate([q_nope, q_pe], axis=-1)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class DeepseekV3MLP(nn.Module):
+    def __init__(
+        self, config: TextConfig, hidden_size: int = None, intermediate_size: int = None
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = (
+            config.intermediate_size if intermediate_size is None else intermediate_size
+        )
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+    def __call__(self, x):
+        down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+@mx.compile
+def group_expert_select(
+    gates,
+    e_score_correction_bias,
+    top_k,
+    n_group,
+    topk_group,
+    routed_scaling_factor,
+    norm_topk_prob,
+):
+
+    k = top_k
+    scores = mx.sigmoid(gates.astype(mx.float32))
+    orig_scores = scores
+    scores = scores + e_score_correction_bias
+    scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
+    group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
+    k = n_group - topk_group
+    group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
+    scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
+    scores = mx.flatten(scores, -2, -1)
+
+    k = top_k
+    inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
+    scores = mx.take_along_axis(orig_scores, inds, axis=-1)
+    if top_k > 1 and norm_topk_prob:
+        denominator = scores.sum(axis=-1, keepdims=True)
+        scores = scores / denominator
+    scores = scores * routed_scaling_factor
+
+    return inds, scores
+
+
+class MoEGate(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.top_k = config.num_experts_per_tok
+        self.norm_topk_prob = config.norm_topk_prob
+        self.n_routed_experts = config.n_routed_experts
+        self.routed_scaling_factor = config.routed_scaling_factor
+        self.n_group = config.n_group
+        self.topk_group = config.topk_group
+        self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
+        self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
+        assert config.topk_method == "noaux_tc", "Unsupported topk method."
+
+    def __call__(self, x):
+        return group_expert_select(
+            x @ self.weight.T,
+            self.e_score_correction_bias,
+            self.top_k,
+            self.n_group,
+            self.topk_group,
+            self.routed_scaling_factor,
+            self.norm_topk_prob,
+        )
+
+
+class DeepseekV3MoE(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.num_experts_per_tok = config.num_experts_per_tok
+        self.switch_mlp = SwitchGLU(
+            config.hidden_size,
+            config.moe_intermediate_size,
+            config.n_routed_experts,
+            activation=clipped_silu,
+        )
+
+        self.gate = MoEGate(config)
+        if config.n_shared_experts is not None:
+            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+            self.shared_experts = DeepseekV3MLP(
+                config=config, intermediate_size=intermediate_size
+            )
+
+    def __call__(self, x):
+        inds, scores = self.gate(x)
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
+        if self.config.n_shared_experts is not None:
+            y = y + self.shared_experts(x)
+
+        return y
+
+
+class DeepseekV3DecoderLayer(nn.Module):
+    def __init__(self, config: TextConfig, layer_idx: int):
+        super().__init__()
+        self.self_attn = DeepseekV3Attention(config)
+        self.mlp = (
+            DeepseekV3MoE(config)
+            if (
+                config.n_routed_experts is not None
+                and layer_idx >= config.first_k_dense_replace
+                and layer_idx % config.moe_layer_freq == 0
+            )
+            else DeepseekV3MLP(config)
+        )
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        return h + r
+
+
+class DeepseekV3Model(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = [
+            DeepseekV3DecoderLayer(config, idx)
+            for idx in range(config.num_hidden_layers)
+        ]
+        self.start_idx = 0
+        self.end_idx = len(self.layers)
+        self.num_layers = self.end_idx
+
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        x: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
+    ) -> mx.array:
+
+        if inputs_embeds is None:
+            h = self.embed_tokens(x)
+        else:
+            h = inputs_embeds
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        if cache is None:
+            cache = [None] * self.num_layers
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        return self.norm(h)
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        self.model = DeepseekV3Model(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
+        **kwargs,  # Accept and ignore extra kwargs like image_grid_hws
+    ):
+        out = self.model(inputs, inputs_embeds=inputs_embeds, cache=cache, mask=mask)
+        out = self.lm_head(out)
+        return LanguageModelOutput(logits=out)
+
+    def sanitize(self, weights):
+        def keep(key):
+            return "rotary_emb" not in key
+
+        weights = {k: v for k, v in weights.items() if keep(k)}
+        # Stack experts
+        for l in range(self.config.num_hidden_layers):
+            prefix = f"language_model.model.layers.{l}"
+            for m in [("gate_proj"), ("down_proj"), ("up_proj")]:
+                for k in ["weight", "scales", "biases"]:
+                    if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
+                        to_join = [
+                            weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
+                            for e in range(self.config.n_routed_experts)
+                        ]
+                        weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+
+        return weights
+
+    def embed_tokens(self, x):
+        return self.model.embed_tokens(x)
+
+    @property
+    def layers(self):
+        return self.model.layers[self.model.start_idx : self.model.end_idx]
+
+    @property
+    def n_kv_heads(self):
+        return self.config.num_key_value_heads
+
+    def cast_predicate(self):
+        def predicate(k):
+            return "e_score_correction_bias" not in k
+
+        return predicate
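
The YaRN helpers at the top of language.py decide which rotary frequencies are interpolated when rope_scaling stretches the context window. The standalone sketch below (not part of the package) copies those pure-math helpers to show the correction range and the attention-scale adjustment for an illustrative scaling factor; the parameter values are examples, not taken from any released checkpoint.

```python
# Standalone sketch (assumption: illustrative parameters only), mirroring the
# yarn_* helpers defined in language.py above.
import math


def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)


def yarn_get_mscale(scale=1, mscale=1):
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0


dim, beta_fast, beta_slow = 64, 32, 1  # qk_rope_head_dim and the default betas
low, high = yarn_find_correction_range(beta_fast, beta_slow, dim)
print(low, high)                   # boundaries of the interpolation ramp over rotary frequencies
print(yarn_get_mscale(scale=4.0))  # extra attention scaling applied for a 4x context extension
```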