fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/deepseek_vl_v2/language.py
@@ -0,0 +1,539 @@
+import math
+from typing import Any, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.switch_layers import SwitchGLU
+
+from ..base import (
+    LanguageModelOutput,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from .config import TextConfig
+
+
+def yarn_find_correction_dim(
+    num_rotations, dim, base=10000, max_position_embeddings=2048
+):
+    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+        2 * math.log(base)
+    )
+
+
+def yarn_find_correction_range(
+    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
+):
+    low = math.floor(
+        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
+    )
+    high = math.ceil(
+        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
+    )
+    return max(low, 0), min(high, dim - 1)
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def yarn_linear_ramp_mask(min_val, max_val, dim):
+    if min_val == max_val:
+        max_val += 0.001  # Prevent singularity
+
+    linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
+    return mx.clip(linear_func, 0, 1)
+
+
+class DeepseekV2YarnRotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        super().__init__()
+        self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
+            scaling_factor, mscale_all_dim
+        )
+        freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
+        freq_inter = scaling_factor * base ** (
+            mx.arange(0, dim, 2, dtype=mx.float32) / dim
+        )
+        low, high = yarn_find_correction_range(
+            beta_fast,
+            beta_slow,
+            dim,
+            base,
+            original_max_position_embeddings,
+        )
+        freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
+        self._freqs = (freq_inter * freq_extra) / (
+            freq_inter * freq_mask + freq_extra * (1 - freq_mask)
+        )
+
+    def __call__(self, x, offset=0):
+        if self.mscale != 1.0:
+            x = self.mscale * x
+        return mx.fast.rope(
+            x,
+            x.shape[-1],
+            traditional=True,
+            base=None,
+            scale=1.0,
+            offset=offset,
+            freqs=self._freqs,
+        )
+
+
+class DeepseekV2Attention(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.q_lora_rank = config.q_lora_rank
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.kv_lora_rank = config.kv_lora_rank
+        self.v_head_dim = config.v_head_dim
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+
+        self.scale = self.q_head_dim**-0.5
+
+        if self.q_lora_rank is None:
+            self.q_proj = nn.Linear(
+                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
+            )
+        else:
+            self.q_a_proj = nn.Linear(
+                self.hidden_size, self.q_lora_rank, bias=config.attention_bias
+            )
+            self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
+            self.q_b_proj = nn.Linear(
+                self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
+            )
+
+        self.kv_a_proj_with_mqa = nn.Linear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=config.attention_bias,
+        )
+        self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
+        self.kv_b_proj = nn.Linear(
+            self.kv_lora_rank,
+            self.num_heads
+            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
+            bias=False,
+        )
+
+        self.o_proj = nn.Linear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=config.attention_bias,
+        )
+
+        if self.config.rope_scaling is None:
+            self.rope = nn.RoPE(
+                self.qk_rope_head_dim,
+                traditional=self.config.rope_traditional,
+                base=self.rope_theta,
+            )
+        else:
+            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+            scaling_factor = self.config.rope_scaling.get("factor", 1)
+            if mscale_all_dim:
+                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+                self.scale = self.scale * mscale * mscale
+
+            rope_kwargs = {
+                key: self.config.rope_scaling[key]
+                for key in [
+                    "original_max_position_embeddings",
+                    "beta_fast",
+                    "beta_slow",
+                    "mscale",
+                    "mscale_all_dim",
+                ]
+                if key in self.config.rope_scaling
+            }
+            self.rope = DeepseekV2YarnRotaryEmbedding(
+                dim=self.qk_rope_head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                scaling_factor=scaling_factor,
+                base=self.rope_theta,
+                **rope_kwargs,
+            )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(x)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
+
+        q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
+        q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
+        compressed_kv = self.kv_a_proj_with_mqa(x)
+        compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
+        k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
+        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+        kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+
+        k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
+
+        if cache is not None:
+            q_pe = self.rope(q_pe, cache.offset)
+            k_pe = self.rope(k_pe, cache.offset)
+            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
+            keys, values = cache.update_and_fetch(
+                mx.concatenate([k_nope, k_pe], axis=-1), values
+            )
+        else:
+            q_pe = self.rope(q_pe)
+            k_pe = self.rope(k_pe)
+            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
+            keys = mx.concatenate([k_nope, k_pe], axis=-1)
+
+        queries = mx.concatenate([q_nope, q_pe], axis=-1)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class LlamaAttention(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+
+        dim = config.hidden_size
+        self.n_heads = n_heads = config.num_attention_heads
+        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+        self.head_dim = head_dim = config.hidden_size // n_heads
+
+        self.scale = head_dim**-0.5
+        if config.attention_bias:
+            attention_bias = config.attention_bias
+        else:
+            attention_bias = False
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+
+        rope_scale = (
+            1 / config.rope_scaling["factor"]
+            if config.rope_scaling is not None
+            and config.rope_scaling["type"] == "linear"
+            else 1
+        )
+        self.rope = nn.RoPE(
+            head_dim,
+            traditional=config.rope_traditional,
+            base=config.rope_theta,
+            scale=rope_scale,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache, scale=self.scale, mask=mask
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class DeepseekV2MLP(nn.Module):
+    def __init__(
+        self, config: TextConfig, hidden_size: int = None, intermediate_size: int = None
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = (
+            config.intermediate_size if intermediate_size is None else intermediate_size
+        )
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+    def __call__(self, x):
+        down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class MoEGate(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.scoring_func = config.scoring_func
+        self.top_k = config.num_experts_per_tok
+        self.n_routed_experts = config.n_routed_experts
+        self.routed_scaling_factor = config.routed_scaling_factor
+        self.topk_method = config.topk_method
+        self.n_group = config.n_group
+        self.topk_group = config.topk_group
+        if self.topk_method == "noaux_tc":
+            self.e_score_correction_bias = mx.zeros((self.n_routed_experts))
+        self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
+
+    def __call__(self, x):
+        gates = x @ self.weight.T
+
+        if self.scoring_func == "softmax":
+            scores = mx.softmax(gates, axis=-1, precise=True)
+        elif self.scoring_func == "sigmoid":
+            scores = mx.sigmoid(gates)
+        else:
+            raise ValueError(f"Unknown scoring function: {self.scoring_func}")
+
+        if self.topk_method == "greedy":
+            bsz, seq_len = x.shape[:2]
+            scores = scores.reshape(bsz, seq_len, self.n_group, -1)
+            group_scores = scores.max(axis=-1)
+
+            # Get top-k groups
+            k = self.n_group - self.topk_group
+            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
+            batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
+            seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
+
+            # Mask out top-k groups
+            scores[batch_idx, seq_idx, group_idx] = 0.0
+            scores = scores.reshape(bsz, seq_len, -1)
+
+            # Get top-k indices and weights
+            k = self.top_k
+            inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
+            scores = mx.take_along_axis(scores, inds, axis=-1)
+
+        elif self.topk_method == "noaux_tc":
+            bsz, seq_len = x.shape[:2]
+
+            # Add bias correction
+            scores_for_choice = scores.reshape(bsz * seq_len, -1) + mx.expand_dims(
+                self.e_score_correction_bias, 0
+            )
+
+            # Calculate group scores using top-2 sum per group
+            scores_reshaped = scores_for_choice.reshape(bsz * seq_len, self.n_group, -1)
+
+            # Get top 2 scores per group
+            group_scores = mx.topk(scores_reshaped, 2, axis=-1).sum(axis=-1)
+
+            # Get top groups
+            k = self.n_group - self.topk_group
+
+            # Create mask for selected groups
+            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
+            batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
+
+            seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
+            scores[batch_idx, seq_idx, group_idx] = 0.0
+
+            # Get top-k indices and weights
+            k = self.top_k
+            inds = mx.argpartition(scores, kth=-k, axis=-1)[..., -k:]
+
+            # Gather original scores for the selected indices
+            scores_flat = scores.reshape(bsz * seq_len, -1)
+            batch_idx = mx.expand_dims(mx.arange(bsz * seq_len), 1)
+            scores = mx.take(scores_flat, inds + batch_idx * scores_flat.shape[1])
+        else:
+            raise ValueError(f"Unknown topk method: {self.topk_method}")
+
+        scores = scores * self.routed_scaling_factor
+        return inds, scores
+
+
+class DeepseekV2MoE(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.num_experts_per_tok = config.num_experts_per_tok
+        self.switch_mlp = SwitchGLU(
+            config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
+        )
+
+        self.gate = MoEGate(config)
+        if config.n_shared_experts is not None:
+            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+            self.shared_experts = DeepseekV2MLP(
+                config=config, intermediate_size=intermediate_size
+            )
+
+    def __call__(self, x):
+        inds, scores = self.gate(x)
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)
+        if self.config.n_shared_experts is not None:
+            y = y + self.shared_experts(x)
+
+        return y
+
+
+class DeepseekV2DecoderLayer(nn.Module):
+    def __init__(self, config: TextConfig, layer_idx: int):
+        super().__init__()
+        self.attn_type = config.attn_type
+        self.self_attn = (
+            DeepseekV2Attention(config)
+            if self.attn_type == "DeepseekV2Attention"
+            else LlamaAttention(config)
+        )
+        self.mlp = (
+            DeepseekV2MoE(config)
+            if (
+                config.n_routed_experts is not None
+                and layer_idx >= config.first_k_dense_replace
+                and layer_idx % config.moe_layer_freq == 0
+            )
+            else DeepseekV2MLP(config)
+        )
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class DeepseekV2Model(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = [
+            DeepseekV2DecoderLayer(config, idx)
+            for idx in range(config.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        inputs_embeds: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+
+        if inputs_embeds is None:
+            h = self.embed_tokens(x)
+        else:
+            h = inputs_embeds
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        return self.norm(h)
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        self.model = DeepseekV2Model(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+        **kwargs,
+    ):
+        out = self.model(inputs, mask=mask, inputs_embeds=inputs_embeds, cache=cache)
+        out = self.lm_head(out)
+        return LanguageModelOutput(logits=out)
+
+    def sanitize(self, weights):
+        for l in range(self.config.num_hidden_layers):
+            prefix = f"language_model.model.layers.{l}"
+            for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+                for k in ["weight", "scales", "biases"]:
+                    if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
+                        to_join = [
+                            weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
+                            for e in range(self.config.n_routed_experts)
+                        ]
+                        weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+        return weights
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def head_dim(self):
+        if self.config.attn_type == "DeepseekV2Attention":
+            return (
+                self.config.qk_nope_head_dim + self.config.qk_rope_head_dim,
+                self.config.v_head_dim,
+            )
+        else:
+            return self.config.hidden_size // self.config.num_key_value_heads
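
For orientation only (not part of the package diff), the YaRN rotary embedding defined in this file can be exercised on its own. The class name and module path come from the diff above; the dimensions and scaling parameters below are placeholder values chosen for illustration, not values taken from any shipped model config.

import mlx.core as mx

from mlx_vlm.models.deepseek_vl_v2.language import DeepseekV2YarnRotaryEmbedding

# Placeholder settings for illustration; real values come from the model's
# TextConfig and its rope_scaling entries.
rope = DeepseekV2YarnRotaryEmbedding(
    dim=64,                                  # qk_rope_head_dim
    max_position_embeddings=163840,          # context length after scaling
    base=10000,
    scaling_factor=40.0,                     # rope_scaling["factor"]
    original_max_position_embeddings=4096,
    beta_fast=32,
    beta_slow=1,
    mscale=1,
    mscale_all_dim=1,
)

x = mx.random.normal((1, 8, 16, 64))  # (batch, heads, seq_len, rope head dim)
y = rope(x, offset=0)                 # applies RoPE with the YaRN-blended frequencies
print(y.shape)                        # (1, 8, 16, 64)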