fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between package versions.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,547 @@
+ import math
+ from typing import Any, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ from mlx_lm.models.switch_layers import SwitchGLU
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from .config import TextConfig
+
+
+ def yarn_find_correction_dim(
+     num_rotations, dim, base=10000, max_position_embeddings=2048
+ ):
+     return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+         2 * math.log(base)
+     )
+
+
+ def yarn_find_correction_range(
+     low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
+ ):
+     low = math.floor(
+         yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
+     )
+     high = math.ceil(
+         yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
+     )
+     return max(low, 0), min(high, dim - 1)
+
+
+ def yarn_get_mscale(scale=1, mscale=1):
+     if scale <= 1:
+         return 1.0
+     return 0.1 * mscale * math.log(scale) + 1.0
+
+
+ def yarn_linear_ramp_mask(min_val, max_val, dim):
+     if min_val == max_val:
+         max_val += 0.001  # Prevent singularity
+
+     linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
+     return mx.clip(linear_func, 0, 1)
+
+
+ class DeepseekV2YarnRotaryEmbedding(nn.Module):
+     def __init__(
+         self,
+         dim,
+         max_position_embeddings=2048,
+         base=10000,
+         scaling_factor=1.0,
+         original_max_position_embeddings=4096,
+         beta_fast=32,
+         beta_slow=1,
+         mscale=1,
+         mscale_all_dim=0,
+     ):
+         super().__init__()
+         self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
+             scaling_factor, mscale_all_dim
+         )
+         freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
+         freq_inter = scaling_factor * base ** (
+             mx.arange(0, dim, 2, dtype=mx.float32) / dim
+         )
+         low, high = yarn_find_correction_range(
+             beta_fast,
+             beta_slow,
+             dim,
+             base,
+             original_max_position_embeddings,
+         )
+         freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
+         self._freqs = (freq_inter * freq_extra) / (
+             freq_inter * freq_mask + freq_extra * (1 - freq_mask)
+         )
+
+     def __call__(self, x, offset=0):
+         if self.mscale != 1.0:
+             x = self.mscale * x
+         return mx.fast.rope(
+             x,
+             x.shape[-1],
+             traditional=True,
+             base=None,
+             scale=1.0,
+             offset=offset,
+             freqs=self._freqs,
+         )
+
+
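The YaRN rotary embedding above blends two sets of per-dimension frequencies: the original RoPE frequencies (extrapolation) and frequencies stretched by the scaling factor (interpolation), with the linear ramp mask deciding which regime each rotary pair falls in. Below is a minimal, standalone sketch of that blend; the dimension, scaling factor, and ramp bounds are illustrative numbers, not values taken from this package's config.

```python
import mlx.core as mx

dim, base, factor = 64, 10000.0, 4.0  # illustrative values, not from this package's config
low, high = 4.0, 24.0                 # would normally come from yarn_find_correction_range

# Per-dimension denominators, as built in DeepseekV2YarnRotaryEmbedding.__init__.
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)  # original RoPE
freq_inter = factor * freq_extra                                     # position-interpolated

# Ramp over the dim // 2 rotary pairs: 0 below `low`, 1 above `high`.
ramp = mx.clip((mx.arange(dim // 2, dtype=mx.float32) - low) / (high - low), 0, 1)
freq_mask = 1.0 - ramp

# Pairs below `low` (fast-rotating) keep freq_extra; pairs above `high` use freq_inter.
freqs = (freq_inter * freq_extra) / (freq_inter * freq_mask + freq_extra * (1 - freq_mask))
print(freqs[:3], freqs[-3:])
```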
+ class DeepseekV2Attention(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.max_position_embeddings = config.max_position_embeddings
+         self.rope_theta = config.rope_theta
+         self.q_lora_rank = config.q_lora_rank
+         self.qk_rope_head_dim = config.qk_rope_head_dim
+         self.kv_lora_rank = config.kv_lora_rank
+         self.v_head_dim = config.v_head_dim
+         self.qk_nope_head_dim = config.qk_nope_head_dim
+         self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+
+         self.scale = self.q_head_dim**-0.5
+
+         if self.q_lora_rank is None:
+             self.q_proj = nn.Linear(
+                 self.hidden_size, self.num_heads * self.q_head_dim, bias=False
+             )
+         else:
+             self.q_a_proj = nn.Linear(
+                 self.hidden_size, self.q_lora_rank, bias=config.attention_bias
+             )
+             self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
+             self.q_b_proj = nn.Linear(
+                 self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
+             )
+
+         self.kv_a_proj_with_mqa = nn.Linear(
+             self.hidden_size,
+             self.kv_lora_rank + self.qk_rope_head_dim,
+             bias=config.attention_bias,
+         )
+         self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
+         self.kv_b_proj = nn.Linear(
+             self.kv_lora_rank,
+             self.num_heads
+             * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
+             bias=False,
+         )
+
+         self.o_proj = nn.Linear(
+             self.num_heads * self.v_head_dim,
+             self.hidden_size,
+             bias=config.attention_bias,
+         )
+
+         if self.config.rope_scaling is None:
+             self.rope = nn.RoPE(
+                 self.qk_rope_head_dim,
+                 traditional=self.config.rope_traditional,
+                 base=self.rope_theta,
+             )
+         else:
+             mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+             scaling_factor = self.config.rope_scaling.get("factor", 1)
+             if mscale_all_dim:
+                 mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+                 self.scale = self.scale * mscale * mscale
+
+             rope_kwargs = {
+                 key: self.config.rope_scaling[key]
+                 for key in [
+                     "original_max_position_embeddings",
+                     "beta_fast",
+                     "beta_slow",
+                     "mscale",
+                     "mscale_all_dim",
+                 ]
+                 if key in self.config.rope_scaling
+             }
+             self.rope = DeepseekV2YarnRotaryEmbedding(
+                 dim=self.qk_rope_head_dim,
+                 max_position_embeddings=self.max_position_embeddings,
+                 scaling_factor=scaling_factor,
+                 base=self.rope_theta,
+                 **rope_kwargs,
+             )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         if self.q_lora_rank is None:
+             q = self.q_proj(x)
+         else:
+             q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
+
+         q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
+         q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
+         compressed_kv = self.kv_a_proj_with_mqa(x)
+         compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
+         k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
+         kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+         kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+
+         k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
+
+         if cache is not None:
+             q_pe = self.rope(q_pe, cache.offset)
+             k_pe = self.rope(k_pe, cache.offset)
+             k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
+             keys, values = cache.update_and_fetch(
+                 mx.concatenate([k_nope, k_pe], axis=-1), values
+             )
+         else:
+             q_pe = self.rope(q_pe)
+             k_pe = self.rope(k_pe)
+             k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
+             keys = mx.concatenate([k_nope, k_pe], axis=-1)
+
+         queries = mx.concatenate([q_nope, q_pe], axis=-1)
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
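As a quick orientation to the projection shapes in DeepseekV2Attention: each query head gets a non-rotary ("nope") part and a rotary ("rope") part, while keys and values are reconstructed from a compressed latent by kv_b_proj. The arithmetic sketch below uses DeepSeek-V2-Lite-style sizes purely as an example; they are not read from this package's config.py.

```python
# Example head-dimension bookkeeping for the attention module above.
# The sizes are illustrative (DeepSeek-V2-Lite-like), not this package's actual config.
num_heads = 16
qk_nope_head_dim = 128   # per-head query/key width without rotary encoding
qk_rope_head_dim = 64    # per-head width that receives RoPE
v_head_dim = 128
kv_lora_rank = 512       # width of the compressed KV latent

q_head_dim = qk_nope_head_dim + qk_rope_head_dim                      # 192 per query head
kv_a_out = kv_lora_rank + qk_rope_head_dim                            # 576: latent + shared k_pe
kv_b_out = num_heads * (q_head_dim - qk_rope_head_dim + v_head_dim)   # 16 * 256 = 4096

print(q_head_dim, kv_a_out, kv_b_out)  # 192 576 4096
```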
+ class LlamaAttention(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+
+         dim = config.hidden_size
+         self.n_heads = n_heads = config.num_attention_heads
+         self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+         self.head_dim = head_dim = config.hidden_size // n_heads
+
+         self.scale = head_dim**-0.5
+         if config.attention_bias:
+             attention_bias = config.attention_bias
+         else:
+             attention_bias = False
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+
+         rope_scale = (
+             1 / config.rope_scaling["factor"]
+             if config.rope_scaling is not None
+             and config.rope_scaling["type"] == "linear"
+             else 1
+         )
+         self.rope = nn.RoPE(
+             head_dim,
+             traditional=config.rope_traditional,
+             base=config.rope_theta,
+             scale=rope_scale,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         # Prepare the queries, keys and values for the attention computation
+         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+         if cache is not None:
+             queries = self.rope(queries, offset=cache.offset)
+             keys = self.rope(keys, offset=cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+         else:
+             queries = self.rope(queries)
+             keys = self.rope(keys)
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class DeepseekV2MLP(nn.Module):
+     def __init__(
+         self, config: TextConfig, hidden_size: int = None, intermediate_size: int = None
+     ):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+         self.intermediate_size = (
+             config.intermediate_size if intermediate_size is None else intermediate_size
+         )
+
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+     def __call__(self, x):
+         down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+
+ class MoEGate(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.scoring_func = config.scoring_func
+         self.top_k = config.num_experts_per_tok
+         self.n_routed_experts = config.n_routed_experts
+         self.routed_scaling_factor = config.routed_scaling_factor
+         self.topk_method = config.topk_method
+         self.n_group = config.n_group
+         self.topk_group = config.topk_group
+         if self.topk_method == "noaux_tc":
+             self.e_score_correction_bias = mx.zeros((self.n_routed_experts))
+         self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
+
+     def __call__(self, x):
+         bsz, seq_len = x.shape[:2]
+         # Layer-specific forced gate takes precedence
+         layer_idx = getattr(self, "_layer_index", None)
+
+         gates = x @ self.weight.T
+
+         if self.scoring_func == "softmax":
+             scores = mx.softmax(gates, axis=-1, precise=True)  # type: ignore[call-arg]
+         elif self.scoring_func == "sigmoid":
+             scores = mx.sigmoid(gates)
+         else:
+             raise ValueError(f"Unknown scoring function: {self.scoring_func}")
+
+         if self.topk_method == "greedy":
+             flat_scores = scores
+             k = self.top_k
+             inds = mx.argpartition(flat_scores, kth=-k, axis=-1)[..., -k:]
+             scores_selected = mx.take_along_axis(flat_scores, inds, axis=-1)
+
+         elif self.topk_method == "noaux_tc":
+             experts_per_group = self.n_routed_experts // self.n_group
+             topk_group = self.topk_group if self.topk_group else self.n_group
+
+             scores_bt = scores.reshape(bsz, seq_len, self.n_routed_experts)
+             bias = self.e_score_correction_bias.astype(mx.float32)
+             bias = mx.reshape(bias, (1, 1, self.n_routed_experts))
+             corrected = scores_bt + bias
+             corrected_groups = corrected.reshape(
+                 bsz, seq_len, self.n_group, experts_per_group
+             )
+
+             sorted_groups = mx.sort(corrected_groups, axis=-1)
+             if experts_per_group >= 2:
+                 top_values = sorted_groups[..., -2:]
+             else:
+                 top_values = sorted_groups
+             group_scores = mx.sum(top_values, axis=-1)
+
+             if topk_group < self.n_group:
+                 kth_group = self.n_group - topk_group
+                 group_idx = mx.argpartition(group_scores, kth=kth_group, axis=-1)[
+                     ..., -topk_group:
+                 ]
+                 group_axis = mx.reshape(
+                     mx.arange(self.n_group, dtype=mx.int32), (1, 1, self.n_group)
+                 )
+                 mask = mx.zeros(group_scores.shape, dtype=mx.bool_)
+                 for slot in range(topk_group):
+                     idx = mx.expand_dims(group_idx[..., slot], axis=-1)
+                     mask = mx.logical_or(mask, mx.equal(idx, group_axis))
+             else:
+                 mask = mx.ones(group_scores.shape, dtype=mx.bool_)
+
+             mask_expanded = mx.expand_dims(mask, axis=-1)
+             corrected_masked = mx.where(
+                 mask_expanded,
+                 corrected_groups,
+                 mx.zeros_like(corrected_groups),
+             )
+
+             corrected_flat = corrected_masked.reshape(bsz, seq_len, -1)
+             raw_flat = scores_bt
+
+             total_experts = corrected_flat.shape[-1]
+             kth_value = max(total_experts - self.top_k, 0)
+             inds = mx.argpartition(corrected_flat, kth=kth_value, axis=-1)[
+                 ..., -self.top_k :
+             ].astype(mx.int32)
+             scores_selected = mx.take_along_axis(raw_flat, inds, axis=-1)
+
+         else:
+             raise ValueError(f"Unknown topk method: {self.topk_method}")
+
+         scores_selected = scores_selected * self.routed_scaling_factor
+         return inds, scores_selected
+
+
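To see what MoEGate returns, here is a toy, self-contained run of its "greedy" routing path: softmax over expert logits, then the top_k expert indices and their probabilities per token. All tensors and sizes below are made up for illustration and do not come from this package.

```python
import mlx.core as mx

# Toy run of the "greedy" branch of MoEGate above; shapes and values are made up.
x = mx.random.normal((1, 3, 8))      # (batch, seq_len, hidden_size)
weight = mx.random.normal((4, 8))    # (n_routed_experts, hidden_size) gate weight
top_k, routed_scaling_factor = 2, 1.0

scores = mx.softmax(x @ weight.T, axis=-1)                          # (1, 3, 4) expert probabilities
inds = mx.argpartition(scores, kth=-top_k, axis=-1)[..., -top_k:]   # top-2 expert ids per token
selected = mx.take_along_axis(scores, inds, axis=-1) * routed_scaling_factor
print(inds.shape, selected.shape)  # (1, 3, 2) (1, 3, 2)
```

DeepseekV2MoE then passes inds to SwitchGLU and uses selected as the mixture weights for the chosen experts.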
+ class DeepseekV2MoE(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.num_experts_per_tok = config.num_experts_per_tok
+         self.switch_mlp = SwitchGLU(
+             config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
+         )
+
+         self.gate = MoEGate(config)
+         if config.n_shared_experts is not None:
+             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+             self.shared_experts = DeepseekV2MLP(
+                 config=config, intermediate_size=intermediate_size
+             )
+
+     def __call__(self, x):
+         inds, scores = self.gate(x)
+         y = self.switch_mlp(x, inds)
+         y = (y * scores[..., None]).sum(axis=-2)
+         if self.config.n_shared_experts is not None:
+             y = y + self.shared_experts(x)
+
+         return y
+
+
+ class DeepseekV2DecoderLayer(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: int):
+         super().__init__()
+         self.attn_type = config.attn_type
+         self.self_attn = (
+             DeepseekV2Attention(config)
+             if self.attn_type == "DeepseekV2Attention"
+             else LlamaAttention(config)
+         )
+         self.mlp = (
+             DeepseekV2MoE(config)
+             if (
+                 config.n_routed_experts is not None
+                 and layer_idx >= config.first_k_dense_replace
+                 and layer_idx % config.moe_layer_freq == 0
+             )
+             else DeepseekV2MLP(config)
+         )
+         self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+     ) -> mx.array:
+         r = self.self_attn(self.input_layernorm(x), mask, cache)
+         h = x + r
+         r = self.mlp(self.post_attention_layernorm(h))
+         out = h + r
+         return out
+
+
+ class DeepseekV2Model(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.vocab_size = config.vocab_size
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = [
+             DeepseekV2DecoderLayer(config, idx)
+             for idx in range(config.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         inputs_embeds: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+     ) -> mx.array:
+
+         if inputs_embeds is None:
+             h = self.embed_tokens(x)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         mask = create_attention_mask(h, cache[0])
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         self.model = DeepseekV2Model(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+         **kwargs,
+     ):
+         out = self.model(inputs, mask=mask, inputs_embeds=inputs_embeds, cache=cache)
+         out = self.lm_head(out)
+         return LanguageModelOutput(logits=out)
+
+     def sanitize(self, weights):
+         for l in range(self.config.num_hidden_layers):
+             prefix = f"language_model.model.layers.{l}"
+             for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+                 for k in ["weight", "scales", "biases"]:
+                     if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
+                         to_join = [
+                             weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
+                             for e in range(self.config.n_routed_experts)
+                         ]
+                         weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+         return weights
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         if self.config.attn_type == "DeepseekV2Attention":
+             return (
+                 self.config.qk_nope_head_dim + self.config.qk_rope_head_dim,
+                 self.config.v_head_dim,
+             )
+         else:
+             return self.config.hidden_size // self.config.num_key_value_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.config.num_key_value_heads
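The sanitize hook above exists so that checkpoints storing one set of projection weights per expert can be repacked into the stacked layout that SwitchGLU expects. A minimal sketch of that repacking with dummy arrays (two experts, one layer; the shapes are placeholders, not real checkpoint sizes):

```python
import mlx.core as mx

# Dummy per-expert weights keyed the way sanitize() expects to find them.
prefix = "language_model.model.layers.0"
n_routed_experts = 2
weights = {
    f"{prefix}.mlp.experts.{e}.gate_proj.weight": mx.zeros((16, 8))
    for e in range(n_routed_experts)
}

# Stack experts 0..N-1 into a single (n_experts, out, in) tensor under switch_mlp.
if f"{prefix}.mlp.experts.0.gate_proj.weight" in weights:
    to_join = [
        weights.pop(f"{prefix}.mlp.experts.{e}.gate_proj.weight")
        for e in range(n_routed_experts)
    ]
    weights[f"{prefix}.mlp.switch_mlp.gate_proj.weight"] = mx.stack(to_join)

print({k: v.shape for k, v in weights.items()})
# {'language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight': (2, 16, 8)}
```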