fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
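
The module layout above mirrors the upstream mlx-vlm project: weight loading and conversion in mlx_vlm/utils.py and mlx_vlm/convert.py, generation in mlx_vlm/generate.py, and one package per architecture under mlx_vlm/models/. A minimal usage sketch follows, assuming this distribution exposes the same top-level API as upstream mlx-vlm; the load/generate/apply_chat_template names, their argument order, and the model id are assumptions based on upstream conventions, not anything confirmed by this diff.

    # Hypothetical usage sketch, assuming the upstream mlx-vlm API; every name,
    # argument, and model id below is an assumption, not confirmed by this diff.
    from mlx_vlm import load, generate
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import load_config

    model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"  # example model id
    model, processor = load(model_path)
    config = load_config(model_path)

    images = ["example.jpg"]  # hypothetical local image
    prompt = apply_chat_template(
        processor, config, "Describe this image.", num_images=len(images)
    )
    output = generate(model, processor, prompt, images, verbose=False)
    print(output)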

mlx_vlm/models/lfm2_vl/vision.py
@@ -0,0 +1,223 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..kernels import bicubic_interpolate
+from .config import VisionConfig
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        query_input_dims: Optional[int] = None,
+        key_input_dims: Optional[int] = None,
+        value_input_dims: Optional[int] = None,
+        value_dims: Optional[int] = None,
+        value_output_dims: Optional[int] = None,
+        bias: bool = True,
+    ):
+        super().__init__()
+
+        if (dims % num_heads) != 0:
+            raise ValueError(
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({dims} % {num_heads}) != 0"
+            )
+
+        query_input_dims = query_input_dims or dims
+        key_input_dims = key_input_dims or dims
+        value_input_dims = value_input_dims or key_input_dims
+        value_dims = value_dims or dims
+        value_output_dims = value_output_dims or dims
+
+        self.num_heads = num_heads
+        head_dim = dims // num_heads
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+    def __call__(self, x, mask=None):
+        queries = self.q_proj(x)
+        keys = self.k_proj(x)
+        values = self.v_proj(x)
+
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.out_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.activation_fn = nn.GELU(approx="precise")
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.fc1(x)
+        x = self.activation_fn(x)
+        x = self.fc2(x)
+        return x
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = Attention(
+            config.hidden_size, config.num_attention_heads, bias=True
+        )
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = MLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        r = self.self_attn(self.layer_norm1(x), mask)
+        h = x + r
+        r = self.mlp(self.layer_norm2(h))
+        return h + r
+
+
+class Encoder(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+    def __call__(
+        self,
+        x: mx.array,
+        output_hidden_states: Optional[bool] = None,
+        mask: Optional[mx.array] = None,
+    ) -> mx.array:
+        encoder_states = (x,) if output_hidden_states else None
+        h = x
+        for l in self.layers:
+            x = l(x, mask=mask)
+            if output_hidden_states:
+                encoder_states = encoder_states + (x,)
+
+            h = x
+
+        return encoder_states
+
+
+class VisionEmbeddings(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Linear(
+            input_dims=config.num_channels * self.patch_size * self.patch_size,
+            output_dims=self.embed_dim,
+        )
+
+        self.num_patches = config.num_patches
+        self.position_embedding_size = int(self.num_patches**0.5)
+        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
+
+    @staticmethod
+    def resize_positional_embeddings(
+        positional_embeddings: mx.array,
+        spatial_shapes: mx.array,
+        max_length: int,
+    ) -> mx.array:
+        batch_size = spatial_shapes.shape[0]
+        embed_dim = positional_embeddings.shape[-1]
+        source_dtype = positional_embeddings.dtype
+
+        resulted_positional_embeddings = mx.zeros(
+            (batch_size, max_length, embed_dim),
+            dtype=source_dtype,
+        )
+
+        # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
+        positional_embeddings = positional_embeddings.transpose(2, 0, 1)[None, :]
+        for i in range(batch_size):
+
+            height, width = spatial_shapes[i].tolist()
+
+            resized_embeddings = bicubic_interpolate(
+                positional_embeddings,
+                size=(height, width),
+            )
+
+            resized_embeddings = resized_embeddings.reshape(
+                embed_dim, height * width
+            ).transpose(1, 0)
+
+            resulted_positional_embeddings[i, : height * width] = resized_embeddings
+            resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
+
+        return resulted_positional_embeddings
+
+    def __call__(
+        self, pixel_values: mx.array, spatial_shapes: mx.array = None
+    ) -> mx.array:
+
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.astype(target_dtype))
+
+        positional_embeddings = self.position_embedding.weight.reshape(
+            self.position_embedding_size, self.position_embedding_size, -1
+        )
+        resized_positional_embeddings = self.resize_positional_embeddings(
+            positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
+        )
+
+        embeddings = patch_embeds + resized_positional_embeddings
+        return embeddings
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        if self.model_type not in ["lfm2_vl", "siglip2_vision_model"]:
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+
+        self.embeddings = VisionEmbeddings(config)
+        self.encoder = Encoder(config)
+        self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        output_hidden_states: Optional[bool] = None,
+        spatial_shapes: Optional[mx.array] = None,
+    ) -> mx.array:
+        x = self.embeddings(x, spatial_shapes=spatial_shapes)
+        x = x.astype(self.embeddings.patch_embedding.weight.dtype)
+        encoder_outputs = self.encoder(
+            x=x, output_hidden_states=output_hidden_states, mask=None
+        )
+        last_hidden_state = self.post_layernorm(encoder_outputs[-1])
+        return encoder_outputs, x, last_hidden_state
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+
+                continue
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
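
For readers less familiar with MLX, the standalone sketch below (not part of the package) shows the shape flow the Attention module above relies on around mx.fast.scaled_dot_product_attention; the tensor sizes are made up for illustration.

    # Standalone sketch: the (B, L, D) <-> (B, heads, L, head_dim) reshapes
    # wrapped around mx.fast.scaled_dot_product_attention. Sizes are illustrative.
    import mlx.core as mx

    B, L, num_heads, head_dim = 1, 16, 4, 32
    x = mx.random.normal((B, L, num_heads * head_dim))
    q = k = v = x  # stand-in for the q_proj/k_proj/v_proj outputs

    q = q.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
    k = k.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
    v = v.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=head_dim**-0.5, mask=None)
    out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)
    print(out.shape)  # (1, 16, 128)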

mlx_vlm/models/llama4/__init__.py
@@ -0,0 +1,2 @@
+from .config import ModelConfig, TextConfig, VisionConfig
+from .llama4 import LanguageModel, Model, VisionModel

mlx_vlm/models/llama4/config.py
@@ -0,0 +1,83 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str
+    hidden_size: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
+    vocab_size: int
+    num_key_value_heads: int
+    rope_theta: float = 500000.0
+    num_hidden_layers: int = 48
+    rope_traditional: bool = False
+    rope_scaling: Optional[dict] = None
+    tie_word_embeddings: bool = False
+    head_dim: int = 128
+    hidden_act: str = "silu"
+    intermediate_size_mlp: int = 16384
+    max_position_embeddings: int = 10485760
+    num_experts_per_tok: int = 1
+    num_local_experts: int = 16
+    attention_dropout: float = 0.0
+    use_qk_norm: bool = True
+    bos_token_id: int = 200000
+    eos_token_id: list = None
+    pad_token_id: int = 200018
+    attention_chunk_size: int = 8192
+    attention_bias: bool = False
+    interleave_moe_layer_step: int = 1
+    no_rope_layers: list = 4
+    output_router_logits: bool = False
+    router_aux_loss_coef: float = 0.001
+    router_jitter_noise: float = 0.0
+    attn_temperature_tuning: int = 4
+    floor_scale: float = 8192
+    attn_scale: float = 0.1
+    moe_layers: list = None
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    model_type: str
+    hidden_size: int
+    image_size: int
+    initializer_range: float
+    intermediate_size: int
+    norm_eps: float
+    num_attention_heads: int
+    num_channels: int
+    num_hidden_layers: int
+    patch_size: int
+    pixel_shuffle_ratio: float
+    projector_dropout: float
+    projector_input_dim: int
+    projector_output_dim: int
+    rope_theta: float
+    vision_feature_layer: int
+    vision_feature_select_strategy: str
+    vision_output_dim: int
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: TextConfig
+    vision_config: VisionConfig
+    model_type: str
+    ignore_index: int = -100
+    image_token_id: int = 200092
+    image_token_index: Optional[int] = None
+    eos_token_id: Optional[List[int]] = None
+
+    def __post_init__(self):
+        if self.image_token_index is None:
+            self.image_token_index = self.image_token_id
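
A standalone sketch (not part of the package) of constructing TextConfig above. The import path follows the file layout shown earlier; the model_type string and sizes are illustrative assumptions, and it assumes BaseModelConfig adds no required fields of its own. It mainly demonstrates the __post_init__ fallback for num_key_value_heads; ModelConfig.__post_init__ similarly mirrors image_token_id into image_token_index when the latter is left as None.

    # Standalone sketch; sizes and model_type are illustrative only.
    from mlx_vlm.models.llama4.config import TextConfig  # path per the file list above

    cfg = TextConfig(
        model_type="llama4_text",  # illustrative value
        hidden_size=512,
        intermediate_size=1024,
        num_attention_heads=8,
        rms_norm_eps=1e-5,
        vocab_size=32000,
        num_key_value_heads=None,  # __post_init__ falls back to num_attention_heads
    )
    assert cfg.num_key_value_heads == 8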

mlx_vlm/models/llama4/language.py
@@ -0,0 +1,334 @@
+from typing import Any, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.rope_utils import initialize_rope
+from mlx_lm.models.switch_layers import SwitchGLU
+
+from ..base import (
+    LanguageModelOutput,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from ..cache import ChunkedKVCache, KVCache
+from .config import TextConfig
+
+
+class Attention(nn.Module):
+    def __init__(self, config: TextConfig, layer_idx: int):
+        super().__init__()
+
+        dim = config.hidden_size
+        self.n_heads = n_heads = config.num_attention_heads
+        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+        self.use_rope = int((layer_idx + 1) % 4 != 0)  # rope unused for dense layers
+        self.attn_temperature_tuning = config.attn_temperature_tuning
+        self.floor_scale = config.floor_scale
+        self.attn_scale = config.attn_scale
+
+        self.head_dim = head_dim = config.head_dim or config.hidden_size // n_heads
+
+        self.scale = head_dim**-0.5
+        if hasattr(config, "attention_bias"):
+            attention_bias = config.attention_bias
+        else:
+            attention_bias = False
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+
+        self.use_qk_norm = config.use_qk_norm and self.use_rope
+
+        if self.use_rope:
+            self.rope = initialize_rope(
+                head_dim,
+                config.rope_theta,
+                traditional=True,
+                scaling_config=config.rope_scaling,
+                max_position_embeddings=config.max_position_embeddings,
+            )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            offset = cache.offset
+        else:
+            offset = 0
+
+        if self.use_rope:
+            queries = self.rope(queries, offset=offset)
+            keys = self.rope(keys, offset=offset)
+
+        if self.use_qk_norm:
+            queries = mx.fast.rms_norm(queries, weight=None, eps=1e-6)
+            keys = mx.fast.rms_norm(keys, weight=None, eps=1e-6)
+
+        if self.attn_temperature_tuning and not self.use_rope:
+            attn_scales = (
+                mx.log(
+                    mx.floor(mx.arange(offset + 1, offset + L + 1) / self.floor_scale)
+                    + 1.0
+                )
+                * self.attn_scale
+                + 1.0
+            )
+            attn_scales = attn_scales[:, None]
+            queries = (queries * attn_scales).astype(queries.dtype)
+
+        if cache is not None:
+            keys, values = cache.update_and_fetch(keys, values)
+
+        if self.use_rope and isinstance(mask, mx.array):
+            key_len = keys.shape[-2]
+            if mask.shape[-1] != key_len:
+                mask = mask[..., -key_len:]
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, config: TextConfig, intermediate_size: int = None):
+        super().__init__()
+
+        dim = config.hidden_size
+        hidden_dim = intermediate_size or config.intermediate_size
+
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class MoE(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_local_experts
+        self.experts = SwitchGLU(
+            config.hidden_size, config.intermediate_size, self.num_experts
+        )
+        self.router = nn.Linear(
+            config.hidden_size, config.num_local_experts, bias=False
+        )
+        self.shared_expert = MLP(config)
+
+    def __call__(self, x) -> mx.array:
+        logits = self.router(x)
+        k = self.top_k
+        indices = mx.argpartition(-logits, kth=k - 1, axis=-1)[..., :k]
+        scores = mx.take_along_axis(logits, indices, axis=-1)
+        scores = mx.sigmoid(scores.astype(mx.float32)).astype(x.dtype)
+
+        out = self.experts(x * scores, indices).squeeze(2)
+        return out + self.shared_expert(x)
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: TextConfig, layer_idx: int):
+        super().__init__()
+        self.num_attention_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.self_attn = Attention(config, layer_idx)
+        self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)
+        self.is_moe_layer = (layer_idx % config.interleave_moe_layer_step) == (
+            config.interleave_moe_layer_step - 1
+        )
+        if self.is_moe_layer:
+            self.feed_forward = MoE(config)
+        else:
+            self.feed_forward = MLP(config, config.intermediate_size_mlp)
+
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.config = config
+
+        self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)  # <=> use rope
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.feed_forward(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class LlamaModel(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.num_hidden_layers = config.num_hidden_layers
+        assert self.vocab_size > 0
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = [
+            TransformerBlock(config, i) for i in range(config.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def create_chunked_attention_mask(
+        self, seq_len: int, attention_chunk_size: int, start: int = 0, offset: int = 0
+    ) -> mx.array:
+        """
+        Generate the following:
+
+        'What'      :  0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ |
+        '▁is'       :  1 ■ ■ ⬚ ⬚ ⬚ ⬚ |
+        '▁ch'       :  2 ■ ■ ■ ⬚ ⬚ ⬚ |
+        'unked'     :  3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ |
+        '▁attention':  4 ⬚ ⬚ ⬚ ■ ■ ⬚ |
+        '?'         :  5 ⬚ ⬚ ⬚ ■ ■ ■ |
+
+        If the chunk size is 3.
+        This can just be applied over the already created attention mask
+        """
+
+        end = offset + seq_len
+        linds = mx.arange(start, end)
+        rinds = mx.arange(offset, end)[:, None]
+        block_pos = mx.abs(
+            (linds // attention_chunk_size) - (rinds // attention_chunk_size)
+        )
+        token_pos = linds <= rinds
+        mask = (block_pos == 0) & (token_pos)
+        return mask
+
+    def __call__(
+        self,
+        input_ids: mx.array = None,
+        input_embeds: mx.array = None,
+        mask: mx.array = None,
+        cache=None,
+    ):
+        if input_embeds is None:
+            h = self.embed_tokens(input_ids)
+        else:
+            h = input_embeds
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        if cache is not None:
+            for idx, c in enumerate(cache):
+                if (idx + 1) % 4 != 0:
+                    c.maybe_trim_front()
+            start = cache[0].start_position
+            offset = cache[0].offset
+        else:
+            start = 0
+            offset = 0
+
+        # Create a mask for the chunked attention
+        chunk_mask = self.create_chunked_attention_mask(
+            h.shape[1], self.config.attention_chunk_size, start, offset
+        )
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for idx, (layer, c) in enumerate(zip(self.layers, cache)):
+            use_chunked_attention = (idx + 1) % 4 != 0
+            if use_chunked_attention:
+                local_mask = chunk_mask
+            else:
+                local_mask = mask
+            h = layer(h, local_mask, cache=c)
+
+        return self.norm(h)
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        self.model = LlamaModel(self.config)
+        self.lm_head = nn.Linear(
+            self.config.hidden_size, self.config.vocab_size, bias=False
+        )
+
+    def __call__(
+        self,
+        inputs: mx.array = None,
+        inputs_embeds: mx.array = None,
+        mask: mx.array = None,
+        cache=None,
+        **kwargs,
+    ):
+        out = self.model(
+            input_ids=inputs,
+            input_embeds=inputs_embeds,
+            mask=mask,
+            cache=cache,
+        )
+        out = self.lm_head(out)
+        return LanguageModelOutput(logits=out)
+
+    def sanitize(self, weights):
+        # Rename expert weights for SwitchGLU
+        for l in range(self.config.num_hidden_layers):
+            prefix = f"language_model.model.layers.{l}.feed_forward.experts"
+            if f"{prefix}.gate_up_proj" in weights:
+                v = weights.pop(f"{prefix}.gate_up_proj")
+                gate_k = f"{prefix}.gate_proj.weight"
+                up_k = f"{prefix}.up_proj.weight"
+                gate_proj, up_proj = mx.split(v, 2, axis=-1)
+                weights[gate_k] = mx.swapaxes(gate_proj, 1, 2)
+                weights[up_k] = mx.swapaxes(up_proj, 1, 2)
+            if f"{prefix}.down_proj" in weights:
+                down_proj = weights.pop(f"{prefix}.down_proj")
+                weights[f"{prefix}.down_proj.weight"] = mx.swapaxes(down_proj, 1, 2)
+        return weights
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def n_kv_heads(self):
+        return self.config.num_key_value_heads
+
+    @property
+    def head_dim(self):
+        return (
+            self.config.head_dim
+            if self.config.head_dim
+            else self.config.hidden_size // self.config.num_attention_heads
+        )
+
+    def make_cache(self):
+        caches = []
+        for i in range(self.config.num_hidden_layers):
+            if (i + 1) % 4 != 0:
+                caches.append(ChunkedKVCache(self.config.attention_chunk_size))
+            else:
+                caches.append(KVCache())  # no chunking for dense layers
+        return caches
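
A quick standalone check of the masking logic in LlamaModel.create_chunked_attention_mask above, reproducing the 6-token, chunk-size-3 diagram from its docstring; only mlx is assumed, and the mask math is copied from the method body.

    # Standalone reproduction of the chunked mask for seq_len=6, chunk=3, no offset.
    import mlx.core as mx

    seq_len, chunk, start, offset = 6, 3, 0, 0
    end = offset + seq_len
    linds = mx.arange(start, end)            # key (column) positions
    rinds = mx.arange(offset, end)[:, None]  # query (row) positions
    block_pos = mx.abs((linds // chunk) - (rinds // chunk))  # 0 within the same chunk
    mask = (block_pos == 0) & (linds <= rinds)               # same chunk and causal
    print(mask.astype(mx.int32))
    # expected pattern, matching the docstring diagram:
    # [[1 0 0 0 0 0]
    #  [1 1 0 0 0 0]
    #  [1 1 1 0 0 0]
    #  [0 0 0 1 0 0]
    #  [0 0 0 1 1 0]
    #  [0 0 0 1 1 1]]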