fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/gemma3/gemma3.py
@@ -0,0 +1,194 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..base import InputEmbeddingsFeatures
+ from .config import ModelConfig
+ from .language import LanguageModel, RMSNorm
+ from .vision import VisionModel
+
+
+ class Gemma3MultiModalProjector(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.mm_input_projection_weight = mx.ones(
+             (config.vision_config.hidden_size, config.text_config.hidden_size)
+         )
+
+         self.mm_soft_emb_norm = RMSNorm(
+             config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
+         )
+         self.patches_per_image = int(
+             config.vision_config.image_size // config.vision_config.patch_size
+         )
+         self.tokens_per_side = int(config.text_config.mm_tokens_per_image**0.5)
+         self.kernel_size = self.patches_per_image // self.tokens_per_side
+         self.avg_pool = nn.AvgPool2d(
+             kernel_size=self.kernel_size, stride=self.kernel_size
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         b, _, l = x.shape
+
+         reshaped_vision_outputs = x.transpose(0, 2, 1)
+         reshaped_vision_outputs = reshaped_vision_outputs.reshape(
+             b, l, self.patches_per_image, self.patches_per_image
+         )
+
+         # Transpose to place h, w in indices 1, 2
+         reshaped_vision_outputs = reshaped_vision_outputs.transpose(0, 2, 3, 1)
+         pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
+         pooled_vision_outputs = pooled_vision_outputs.transpose(0, 3, 1, 2).flatten(2)
+         pooled_vision_outputs = pooled_vision_outputs.transpose(0, 2, 1)
+
+         normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
+
+         projected_vision_outputs = mx.einsum(
+             "btm,md->btd", normed_vision_outputs, self.mm_input_projection_weight
+         )
+         return projected_vision_outputs.astype(x.dtype)
+
+
+ def masked_scatter(
+     final_embedding: mx.array,
+     image_mask_expanded: mx.array,
+     scaled_image_features: mx.array,
+ ):
+     # Reshape the tensors to 1D
+     final_embedding_shape = final_embedding.shape
+     scaled_image_features_flattened = mx.flatten(scaled_image_features)
+     final_embedding_flattened = mx.flatten(final_embedding)
+     image_mask_expanded_flattened = mx.flatten(image_mask_expanded)
+
+     # Scatter the scaled image features into the special image token positions
+     image_positions = mx.array(np.where(image_mask_expanded_flattened)[0], mx.uint32)
+     final_embedding_flattened[image_positions] = scaled_image_features_flattened
+
+     # Reshape back to the original shape
+     final_embedding = mx.reshape(final_embedding_flattened, final_embedding_shape)
+
+     return final_embedding
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         self.config = config
+
+         self.vision_tower = VisionModel(config.vision_config)
+         self.language_model = LanguageModel(config.text_config)
+         self.multi_modal_projector = Gemma3MultiModalProjector(config)
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+             )
+
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         hidden_state, _, _ = self.vision_tower(
+             pixel_values.transpose(0, 2, 3, 1).astype(inputs_embeds.dtype),
+             output_hidden_states=True,
+         )
+
+         image_features = self.multi_modal_projector(hidden_state)
+
+         final_inputs_embeds, final_attention_mask_4d = (
+             self.prepare_inputs_for_multimodal(
+                 self.config.hidden_size,
+                 self.config.pad_token_id,
+                 self.config.image_token_index,
+                 image_features,
+                 inputs_embeds,
+                 input_ids,
+                 mask,
+             )
+         )
+         return InputEmbeddingsFeatures(
+             inputs_embeds=final_inputs_embeds, attention_mask_4d=final_attention_mask_4d
+         )
+
+     @staticmethod
+     def prepare_inputs_for_multimodal(
+         hidden_size,
+         pad_token_id,
+         image_token_index,
+         image_features,
+         inputs_embeds,
+         input_ids,
+         attention_mask,
+     ):
+         _, _, embed_dim = image_features.shape
+
+         batch_size, sequence_length = input_ids.shape
+         scaled_image_features = image_features / (hidden_size**0.5)
+         final_embedding = mx.zeros((batch_size, sequence_length, embed_dim))
+
+         pad_token_id = pad_token_id
+         pad_token_id = pad_token_id if pad_token_id is not None else 0
+         text_mask = (input_ids != image_token_index) & (input_ids != pad_token_id)
+         image_mask = input_ids == image_token_index
+         pad_mask = input_ids == pad_token_id
+
+         # expand masks to match embedding dimension
+         text_mask_expanded = mx.expand_dims(text_mask, -1)
+         text_mask_expanded = mx.repeat(text_mask_expanded, embed_dim, axis=-1)
+         pad_mask_expanded = mx.expand_dims(pad_mask, -1)
+         pad_mask_expanded = mx.repeat(pad_mask_expanded, embed_dim, axis=-1)
+         image_mask_expanded = mx.expand_dims(image_mask, -1)
+         image_mask_expanded = mx.repeat(image_mask_expanded, embed_dim, axis=-1)
+
+         # insert padding and text token embeddings
+         final_embedding = mx.where(text_mask_expanded, inputs_embeds, final_embedding)
+         final_embedding = mx.where(
+             pad_mask_expanded, mx.zeros_like(final_embedding), final_embedding
+         )
+
+         # insert image token embeddings
+         final_embedding = masked_scatter(
+             final_embedding, image_mask_expanded, scaled_image_features
+         )
+
+         attention_mask_expanded_1 = mx.expand_dims(attention_mask, 1)
+         attention_mask_expanded_2 = mx.expand_dims(attention_mask, 2)
+         final_attention_mask_4d = attention_mask_expanded_1 * attention_mask_expanded_2
+         final_attention_mask_4d = final_attention_mask_4d
+         final_attention_mask_4d = mx.expand_dims(final_attention_mask_4d, 1)
+         final_embedding = mx.array(final_embedding)
+         return final_embedding.astype(inputs_embeds.dtype), final_attention_mask_4d
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids, pixel_values, mask
+         )
+         inputs_embeds = input_embeddings_features.inputs_embeds
+         attention_mask = input_embeddings_features.attention_mask_4d
+
+         logits = self.language_model(
+             inputs=input_ids,
+             cache=cache,
+             inputs_embeds=inputs_embeds,
+             mask=attention_mask,
+         )
+         return logits
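The Gemma3MultiModalProjector above average-pools the vision tower's patch grid down to mm_tokens_per_image tokens and projects them into the text embedding width before masked_scatter places them at the image-token positions. A minimal shape walk-through in plain Python; the sizes are illustrative assumptions, not values read from this wheel's config files:

# Illustrative sizes (assumed), following the arithmetic in Gemma3MultiModalProjector
image_size, patch_size = 896, 14
hidden_size_vision, hidden_size_text = 1152, 2560
mm_tokens_per_image = 256

patches_per_image = image_size // patch_size        # 64 patches per side
tokens_per_side = int(mm_tokens_per_image ** 0.5)   # 16
kernel_size = patches_per_image // tokens_per_side  # 4x4 average pooling window

# Net effect on shapes:
#   vision tower output (batch, 64*64, hidden_size_vision)
#   -> patch grid (batch, 64, 64, hidden_size_vision)
#   -> AvgPool2d(4)  (batch, 16, 16, hidden_size_vision)
#   -> flatten       (batch, 256, hidden_size_vision)
#   -> RMSNorm + einsum projection -> (batch, 256, hidden_size_text)
print(patches_per_image, tokens_per_side, kernel_size)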
mlx_vlm/models/gemma3/language.py
@@ -0,0 +1,293 @@
+ from functools import partial
+ from typing import Any, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache, RotatingKVCache
+ from .config import TextConfig
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dims: int, eps: float = 1e-5):
+         super().__init__()
+         self.weight = mx.ones((dims,))
+         self.eps = eps
+
+     def __call__(self, x):
+         return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: int):
+         super().__init__()
+
+         dim = config.hidden_size
+         self.n_heads = n_heads = config.num_attention_heads
+         self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+         self.repeats = n_heads // n_kv_heads
+         self.head_dim = head_dim = config.head_dim
+         self.layer_idx = layer_idx
+
+         self.scale = config.query_pre_attn_scalar**-0.5
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+         self.q_norm = RMSNorm(dims=head_dim, eps=config.rms_norm_eps)
+         self.k_norm = RMSNorm(dims=head_dim, eps=config.rms_norm_eps)
+         self.is_sliding = (layer_idx + 1) % config.sliding_window_pattern != 0
+
+         self.rope = nn.RoPE(
+             head_dim,
+             traditional=config.rope_traditional,
+             base=(
+                 config.rope_local_base_freq
+                 if self.is_sliding
+                 else config.rope_global_base_freq
+             ),
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+     ) -> mx.array:
+         B, L, _ = x.shape
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+
+         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+         queries = self.q_norm(queries)
+         keys = self.k_norm(keys)
+
+         if cache is not None:
+             queries = self.rope(queries, offset=cache.offset)
+             keys = self.rope(keys, offset=cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+         else:
+             queries = self.rope(queries)
+             keys = self.rope(keys)
+
+         # Sliding window
+         if mask is not None and isinstance(mask, mx.array):
+             if mask.shape[-1] != keys.shape[-2]:
+                 mask = mask[..., -keys.shape[-2] :]
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         # This should not be GELU approx, jax.nn.gelu
+         return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
+
+
+ @partial(mx.compile, shapeless=True)
+ def clip_residual(x, y=None):
+     bound = mx.finfo(mx.float16).max
+     if y is None:
+         if x.dtype == mx.float16:
+             return mx.clip(x.astype(mx.float32), -bound, bound).astype(mx.float16)
+         else:
+             return x
+
+     if x.dtype != mx.float16:
+         return x + y
+
+     return mx.clip(x.astype(mx.float32) + y.astype(mx.float32), -bound, bound).astype(
+         mx.float16
+     )
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: int):
+         super().__init__()
+         self.num_attention_heads = config.num_attention_heads
+         self.hidden_size = config.hidden_size
+         self.self_attn = Attention(config, layer_idx)
+         self.mlp = MLP(config.hidden_size, config.intermediate_size)
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.pre_feedforward_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_feedforward_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+     ) -> mx.array:
+
+         # Clip the input to avoid overflow in float16
+         # Float16 has a max value of 65504. When values exceed this limit, they become inf.
+         # Example: If x contains 70000.0 in float16, it becomes inf, causing gradient issues.
+         # We upcast to float32 for operations that might exceed the limit, then clip and
+         # convert back to float16 to maintain numerical stability.
+
+         # Clip input to avoid overflow in float16
+         x = clip_residual(x)
+
+         # Self-attention block
+         r = self.self_attn(self.input_layernorm(x), mask, cache)
+         h = self.post_attention_layernorm(r)
+
+         # Add residual connection with overflow protection for float16
+         h = clip_residual(x + h)
+
+         # MLP block
+         r = self.mlp(self.pre_feedforward_layernorm(h))
+         out = self.post_feedforward_layernorm(r)
+
+         # Add residual connection with overflow protection for float16
+         out = clip_residual(h + out)
+
+         return out
+
+
+ class Gemma3Model(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.window_size = config.sliding_window
+         self.sliding_window_pattern = config.sliding_window_pattern
+         self.num_hidden_layers = config.num_hidden_layers
+         assert self.vocab_size > 0
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = [
+             TransformerBlock(config=config, layer_idx=layer_idx)
+             for layer_idx in range(config.num_hidden_layers)
+         ]
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: mx.array = None,
+         mask: mx.array = None,
+         cache=None,
+     ):
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             global_mask = create_attention_mask(
+                 h, cache[self.sliding_window_pattern - 1]
+             )
+
+             if self.sliding_window_pattern > 1:
+                 sliding_window_mask = create_attention_mask(
+                     h,
+                     cache[0],
+                     window_size=self.window_size,
+                 )
+             else:
+                 sliding_window_mask = None
+
+         for i, (layer, c) in enumerate(zip(self.layers, cache)):
+             is_global = (
+                 i % self.sliding_window_pattern == self.sliding_window_pattern - 1
+             )
+
+             local_mask = mask
+             if mask is None and is_global:
+                 local_mask = global_mask
+             elif mask is None:
+                 local_mask = sliding_window_mask
+
+             h = layer(h, local_mask, c)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         self.model = Gemma3Model(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+         out = self.model(inputs, inputs_embeds=inputs_embeds, mask=mask, cache=cache)
+         out = self.lm_head(out)
+         return LanguageModelOutput(logits=out)
+
+     def sanitize(self, weights):
+         if "lm_head.weight" not in weights:
+             weights["language_model.lm_head.weight"] = weights[
+                 "language_model.model.embed_tokens.weight"
+             ]
+         return {
+             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+         }
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.config.head_dim
+
+     @property
+     def n_kv_heads(self):
+         return self.config.num_key_value_heads
+
+     def make_cache(self):
+         caches = []
+         for i in range(self.config.num_hidden_layers):
+             if (
+                 i % self.config.sliding_window_pattern
+                 == self.config.sliding_window_pattern - 1
+             ):
+                 caches.append(KVCache())
+             else:
+                 caches.append(
+                     RotatingKVCache(
+                         max_size=self.config.sliding_window,
+                         keep=0,
+                     )
+                 )
+         return caches
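The language model above alternates local sliding-window attention with periodic global attention: every sliding_window_pattern-th layer is global and receives a plain KVCache from make_cache(), while the remaining layers use a RotatingKVCache bounded by sliding_window. A small sketch of that layout in plain Python, assuming sliding_window_pattern = 6 and 12 layers (illustrative values, not this model's actual config):

# Mirrors the index test used in make_cache() and in Gemma3Model.__call__
sliding_window_pattern = 6
num_hidden_layers = 12

layout = [
    "global (KVCache)"
    if i % sliding_window_pattern == sliding_window_pattern - 1
    else "local (RotatingKVCache)"
    for i in range(num_hidden_layers)
]
# With these assumed values, layers 5 and 11 are global; the rest attend
# only within the sliding window.
print(layout)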
mlx_vlm/models/gemma3/vision.py
@@ -0,0 +1,215 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) != 4:
+         return False
+
+     out_channels, kH, KW, _ = shape
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dims: int,
+         num_heads: int,
+         query_input_dims: Optional[int] = None,
+         key_input_dims: Optional[int] = None,
+         value_input_dims: Optional[int] = None,
+         value_dims: Optional[int] = None,
+         value_output_dims: Optional[int] = None,
+         bias: bool = True,
+     ):
+         super().__init__()
+
+         if (dims % num_heads) != 0:
+             raise ValueError(
+                 "The input feature dimensions should be divisible by the "
+                 f"number of heads ({dims} % {num_heads}) != 0"
+             )
+
+         query_input_dims = query_input_dims or dims
+         key_input_dims = key_input_dims or dims
+         value_input_dims = value_input_dims or key_input_dims
+         value_dims = value_dims or dims
+         value_output_dims = value_output_dims or dims
+
+         self.num_heads = num_heads
+         head_dim = dims // num_heads
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+         self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+         self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+         self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+     def __call__(self, x, mask=None):
+         queries = self.q_proj(x)
+         keys = self.k_proj(x)
+         values = self.v_proj(x)
+
+         num_heads = self.num_heads
+         B, L, D = queries.shape
+         _, S, _ = keys.shape
+         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+         output = mx.fast.scaled_dot_product_attention(
+             queries, keys, values, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.out_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.activation_fn = nn.GELU(approx="precise")
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.fc1(x)
+         x = self.activation_fn(x)
+         x = self.fc2(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.self_attn = Attention(
+             config.hidden_size, config.num_attention_heads, bias=True
+         )
+         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = MLP(config)
+         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         r = self.self_attn(self.layer_norm1(x), mask)
+         h = x + r
+         r = self.mlp(self.layer_norm2(h))
+         return h + r
+
+
+ class Encoder(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+     def __call__(
+         self,
+         x: mx.array,
+         output_hidden_states: Optional[bool] = None,
+         mask: Optional[mx.array] = None,
+     ) -> mx.array:
+         encoder_states = (x,) if output_hidden_states else None
+         h = x
+         for l in self.layers:
+             x = l(x, mask=mask)
+             if output_hidden_states:
+                 encoder_states = encoder_states + (x,)
+
+             h = x
+
+         return (h, encoder_states)
+
+
+ class VisionEmbeddings(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+         )
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches
+         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         patch_embeddings = self.patch_embedding(x)
+         patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+         position_ids = mx.array(np.arange(self.num_positions)[None, :])
+         embeddings = patch_embeddings
+         embeddings += self.position_embedding(position_ids)
+         return embeddings
+
+
+ class SigLipVisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.embeddings = VisionEmbeddings(config)
+         self.encoder = Encoder(config)
+         self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+     def __call__(
+         self,
+         x: mx.array,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+         x = self.embeddings(x)
+
+         encoder_outputs = self.encoder(
+             x=x, output_hidden_states=output_hidden_states, mask=None
+         )
+
+         pooler_output = self.post_layernorm(encoder_outputs[0])
+
+         return pooler_output, x, encoder_outputs[-1]
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         if self.model_type not in ["siglip_vision_model", "gemma3", "gemma3_vision"]:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         self.vision_model = SigLipVisionModel(config)
+
+     def __call__(
+         self, x: mx.array, output_hidden_states: Optional[bool] = None
+     ) -> mx.array:
+         return self.vision_model(x, output_hidden_states)
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "patch_embedding.weight" in k:
+                 # PyTorch conv2d weight tensors have shape:
+                 #   [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight be of shape:
+                 #   [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
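VisionModel.sanitize() above converts patch-embedding weights from the PyTorch conv2d layout (out_channels, in_channels, kH, kW) to the MLX layout (out_channels, kH, kW, in_channels), using check_array_shape() to skip weights that are already in MLX layout. A minimal NumPy sketch of that transpose; the sizes are assumed for illustration only:

import numpy as np

# Assumed PyTorch-layout conv weight: (out_channels, in_channels, kH, kW)
pt_weight = np.zeros((1152, 3, 14, 14))
# Same permutation used in sanitize(): move channels to the last axis
mlx_weight = pt_weight.transpose(0, 2, 3, 1)
print(mlx_weight.shape)  # (1152, 14, 14, 3)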
mlx_vlm/models/gemma3n/__init__.py
@@ -0,0 +1,2 @@
+ from .config import AudioConfig, ModelConfig, TextConfig, VisionConfig
+ from .gemma3n import AudioModel, LanguageModel, Model, VisionModel