fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/gemma3n/config.py
@@ -0,0 +1,130 @@
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Union
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class AudioConfig(BaseModelConfig):
+     input_feat_size: int = 80
+     hidden_size: int = 1536
+     conf_attention_chunk_size: int = 12
+     conf_attention_context_left: int = 13
+     conf_attention_context_right: int = 0
+     conf_attention_invalid_logits_value: float = -1e9
+     conf_attention_logit_cap: float = 50.0
+     conf_num_attention_heads: int = 8
+     conf_num_hidden_layers: int = 12
+     conf_conv_kernel_size: int = 5
+     conf_positional_bias_size: int = 256
+     conf_reduction_factor: int = 4
+     conf_residual_weight: float = 0.5
+     sscp_conv_channel_size: tuple[int, int] = (128, 32)
+     sscp_conv_group_norm_eps: float = 1e-3
+     sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = ((3, 3), (3, 3))
+     sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = ((2, 2), (2, 2))
+     vocab_size: int = 128
+     sscp_conv_eps: float = 1e-3
+     rms_norm_eps: float = 1e-6
+     gradient_clipping: float = 10000000000.0
+     vocab_offset: int = 262_144 + 128  # text vocab size + vision vocab size
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str = "gemma3n_vision"
+     num_hidden_layers: int = 12
+     hidden_size: int = 2048
+     intermediate_size: int = 8192
+     num_attention_heads: int = 16
+     patch_size: int = 16
+     image_size: int = 224
+     num_channels: int = 3
+     rms_norm_eps: float = 1e-6
+     vocab_size: int = 128
+     vocab_offset: int = 262_144
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str
+     hidden_size: int
+     num_hidden_layers: int
+     intermediate_size: int
+     num_attention_heads: int = 2
+     head_dim: int = 256
+     rms_norm_eps: float = 1.0e-6
+     vocab_size: int = 262400
+     vocab_size_per_layer_input: int = 262144
+     num_key_value_heads: int = 4
+     laurel_rank: int = 64
+     frac_shared_layers: float = 0.5
+     altup_active_idx: int = 0
+     pad_token_id: int = 0
+     altup_num_inputs: int = 4
+     altup_coef_clip: Optional[float] = None
+     altup_correct_scale: bool = True
+     hidden_size_per_layer_input: int = 1024
+     rope_local_base_freq: float = 10000.0
+     rope_traditional: bool = False
+     rope_theta: float = 1000000.0
+     query_pre_attn_scalar: float = 0.0625
+     sliding_window: int = 1024
+     rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
+     mm_tokens_per_image: int = 256
+     sliding_window_pattern: int = 5
+     activation_sparsity_pattern: Optional[List[float]] = None
+     final_logit_softcapping: float = 30.0
+     query_rescale_scalar: float = 1.0
+     num_kv_shared_layers: int = 0
+     max_position_embeddings: int = 32768
+     attn_logit_softcapping: float = 0.0
+     layer_types: List[str] = None
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     audio_config: AudioConfig
+     model_type: str
+     vocab_size: int = 257152
+     ignore_index: int = -100
+     image_token_index: int = 262145
+     audio_token_id: int = 262273
+     image_token_id: int = 262145
+     hidden_size: int = 2048
+     pad_token_id: int = 0
+     vision_soft_tokens_per_image: int = 256
+     audio_soft_tokens_per_image: int = 188
+     eos_token_id: Optional[List[int]] = None
+
+
+ @dataclass
+ class MultiQueryAttentionBlockConfig(BaseModelConfig):
+     num_heads: int = 8
+     kv_dim: int = 16
+     kv_strides: int = 1
+     mmqa_avg_pool_kv: bool = False
+     mmqa_dropout: float = 0.0
+     mmqa_dw_kernel_size: int = 3
+     is_multiscale: bool = False
+
+
+ @dataclass
+ class UniversalInvertedResidualConfig(BaseModelConfig):
+     start_dw_kernel_size: int = 0  # Zero size means no conv
+     mid_dw_kernel_size: int = 0  # Zero size means no conv
+     filters: int = 32
+     strides: int = 1
+     expand_ratio: float = 4.0
+     is_multiscale: bool = False
+
+
+ @dataclass
+ class EdgeResidualConfig(BaseModelConfig):
+     kernel_size: int = 3
+     filters: int = 32
+     strides: int = 1
+     expand_ratio: float = 4.0
+     is_multiscale: bool = False
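
For reference, the vocab fields above carve the token-id space into contiguous text, vision, and audio ranges: text ids sit below TextConfig.vocab_size_per_layer_input, vision ids start at VisionConfig.vocab_offset, and audio ids start at AudioConfig.vocab_offset (text plus vision vocab). The token routing in gemma3n.py below relies on exactly these boundaries. A minimal sketch of the partition (the token_modality helper is hypothetical, not part of the package):

    # Sketch only: classify a token id the same way Model.get_input_embeddings routes it.
    VISION_OFFSET = 262_144       # VisionConfig.vocab_offset (end of the text id range)
    AUDIO_OFFSET = 262_144 + 128  # AudioConfig.vocab_offset (text + vision vocab)

    def token_modality(token_id: int) -> str:
        if token_id >= AUDIO_OFFSET:
            return "audio"
        if token_id >= VISION_OFFSET:
            return "vision"
        return "text"

    assert token_modality(262_145) == "vision"  # ModelConfig.image_token_id
    assert token_modality(262_273) == "audio"   # ModelConfig.audio_token_id
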
mlx_vlm/models/gemma3n/gemma3n.py
@@ -0,0 +1,322 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import InputEmbeddingsFeatures
+ from .audio import AudioModel
+ from .config import ModelConfig, TextConfig
+ from .language import Gemma3nRMSNorm, LanguageModel
+ from .vision import VisionModel
+
+
+ def masked_scatter(input_tensor, mask, source):
+     """MLX implementation of PyTorch's masked_scatter"""
+
+     # Convert mask to boolean once
+     mask = mask.astype(mx.bool_)
+
+     # Early exit
+     if not mask.any():
+         return mx.broadcast_to(input_tensor, mask.shape)
+
+     # Flatten everything once
+     input_shape = mask.shape
+     result_flat = mx.broadcast_to(input_tensor, input_shape).flatten()
+     mask_flat = mask.flatten()
+     source_flat = source.flatten()
+
+     # Create selection indices using cumulative sum
+     selection_mask = mx.cumsum(mask_flat.astype(mx.int32)) - 1
+
+     # Bound check and create source selection
+     source_len = len(source_flat)
+     bounded_indices = selection_mask % source_len
+
+     # Vectorized selection from source
+     selected_values = source_flat[bounded_indices]
+
+     result_flat = mx.where(mask_flat, selected_values, result_flat)
+
+     return result_flat.reshape(input_shape)
+
+
+ class Gemma3nMultimodalEmbedder(nn.Module):
+     """Embeds token ids or soft tokens into language model space."""
+
+     def __init__(self, multimodal_config: ModelConfig, text_config: TextConfig):
+         super().__init__()
+
+         self.multimodal_hidden_size = multimodal_config.hidden_size
+         self.eps = multimodal_config.rms_norm_eps
+         self.vocab_offset = multimodal_config.vocab_offset
+         self.vocab_size = multimodal_config.vocab_size
+         self.text_hidden_size = text_config.hidden_size
+
+         self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
+         self.hard_embedding_norm = Gemma3nRMSNorm(
+             self.multimodal_hidden_size, eps=self.eps
+         )
+         self.soft_embedding_norm = Gemma3nRMSNorm(
+             self.multimodal_hidden_size, eps=self.eps
+         )
+         self.embedding_projection = nn.Linear(
+             self.multimodal_hidden_size, self.text_hidden_size, bias=False
+         )
+         self.embedding_post_projection_norm = Gemma3nRMSNorm(
+             self.text_hidden_size, eps=self.eps, with_scale=False
+         )
+
+     def __call__(
+         self, input_ids: mx.array = None, inputs_embeds: mx.array = None
+     ) -> mx.array:
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You must specify exactly one of input_ids or inputs_embeds"
+             )
+
+         if inputs_embeds is not None:
+             emb_norm = self.soft_embedding_norm(inputs_embeds)
+         else:
+             hard_emb = self.embedding(input_ids - self.vocab_offset)
+             emb_norm = self.hard_embedding_norm(hard_emb)
+
+         emb_norm_proj = self.embedding_projection(emb_norm)
+         projected = self.embedding_post_projection_norm(emb_norm_proj)
+         return projected
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         self.config = config
+
+         # Text
+         self.language_model = LanguageModel(config.text_config)
+         self.vocab_size = config.text_config.vocab_size
+         self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
+
+         # Vision
+         self.vision_tower = VisionModel(config.vision_config)
+         self.embed_vision = Gemma3nMultimodalEmbedder(
+             config.vision_config, text_config=config.text_config
+         )
+
+         # Audio
+         self.audio_tower = AudioModel(config.audio_config)
+         self.embed_audio = Gemma3nMultimodalEmbedder(
+             config.audio_config, text_config=config.text_config
+         )
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         input_features = kwargs.get("input_features", None)
+         input_features_mask = kwargs.get("input_features_mask", None)
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         per_layer_inputs_mask = mx.logical_and(
+             input_ids >= 0, input_ids < self.vocab_size_per_layer_input
+         )
+         per_layer_inputs_tokens = mx.where(
+             per_layer_inputs_mask, input_ids, mx.zeros_like(input_ids)
+         )
+         per_layer_inputs = self.language_model.model.get_per_layer_inputs(
+             per_layer_inputs_tokens
+         )
+         if pixel_values is None and input_features is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=inputs_embeds, per_layer_inputs=per_layer_inputs
+             )
+
+         if input_ids is not None:
+             # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
+             vision_mask = mx.logical_and(
+                 input_ids >= self.embed_vision.vocab_offset,
+                 input_ids < self.embed_audio.vocab_offset,
+             )
+             dummy_vision_token_id = (
+                 self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1
+             )
+             vision_tokens = mx.where(vision_mask, input_ids, dummy_vision_token_id)
+             vision_embeds_flat = self.embed_vision(input_ids=vision_tokens)
+             inputs_embeds = mx.where(
+                 vision_mask[..., None], vision_embeds_flat, inputs_embeds
+             )
+
+             # Handle audio tokens (>= embed_audio.vocab_offset)
+             audio_mask = input_ids >= self.embed_audio.vocab_offset
+             dummy_audio_token_id = (
+                 self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1
+             )
+
+             audio_tokens = mx.where(audio_mask, input_ids, dummy_audio_token_id)
+             audio_embeds_flat = self.embed_audio(input_ids=audio_tokens)
+             inputs_embeds = mx.where(
+                 audio_mask[..., None], audio_embeds_flat, inputs_embeds
+             )
+         else:
+             per_layer_inputs = None
+
+         # Vision features
+         if pixel_values is not None:
+             image_features = self.get_image_features(
+                 pixel_values, self.vision_tower, self.config, self.embed_vision
+             )
+
+             modality = "image"
+             inputs_embeds = self.merge_multimodal_and_text(
+                 inputs_embeds,
+                 image_features,
+                 self.construct_special_modality_mask(
+                     input_ids,
+                     inputs_embeds,
+                     self.config.image_token_id,
+                     modality=modality,
+                 ),
+                 modality=modality,
+             )
+
+         # Audio features
+         if input_features is not None:
+             audio_features, audio_mask = self.get_audio_features(
+                 input_features, ~input_features_mask
+             )
+             audio_padding_ids = mx.array([[self.vocab_size - 1]])
+             audio_padding_embs = self.embed_audio(input_ids=audio_padding_ids)
+             audio_features = mx.where(
+                 audio_mask[..., None], audio_padding_embs, audio_features
+             )
+
+             audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
+             extra_padding_tokens = (
+                 self.config.audio_soft_tokens_per_image - audio_seq_len
+             )
+             extra_padding_features = mx.broadcast_to(
+                 audio_padding_embs,
+                 (audio_batch_size, extra_padding_tokens, audio_embed_dim),
+             )
+
+             audio_features = mx.concatenate(
+                 (audio_features, extra_padding_features), axis=1
+             )
+             modality = "audio"
+             inputs_embeds = self.merge_multimodal_and_text(
+                 inputs_embeds,
+                 audio_features,
+                 self.construct_special_modality_mask(
+                     input_ids,
+                     inputs_embeds,
+                     self.config.audio_token_id,
+                     modality=modality,
+                 ),
+                 modality=modality,
+             )
+
+         return InputEmbeddingsFeatures(
+             inputs_embeds=inputs_embeds, per_layer_inputs=per_layer_inputs
+         )
+
+     def get_audio_features(self, input_features, input_features_mask):
+         audio_outputs, audio_mask = self.audio_tower(
+             input_features, input_features_mask
+         )
+         return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
+
+     @staticmethod
+     def get_image_features(pixel_values, vision_tower, config, embed_vision):
+         vision_outputs = vision_tower(
+             pixel_values,
+             output_hidden_states=True,
+         )
+         vision_outputs = vision_outputs.transpose(0, 3, 1, 2)
+         vision_outputs = vision_outputs.reshape(
+             vision_outputs.shape[0],
+             config.vision_config.hidden_size,
+             config.vision_soft_tokens_per_image,
+         ).transpose(0, 2, 1)
+
+         # Normalize and embed the soft tokens into language model space.
+         vision_outputs *= config.vision_config.hidden_size**0.5
+         return embed_vision(inputs_embeds=vision_outputs)
+
+     def construct_special_modality_mask(
+         self, input_ids, inputs_embeds, token_id, modality="image"
+     ):
+         if input_ids is None:
+             embed_fn = (
+                 self.embed_audio
+                 if modality == "audio"
+                 else self.language_model.model.embed_tokens
+             )
+             special_modality_mask = inputs_embeds == embed_fn(
+                 input_ids=mx.array([token_id])
+             )
+         else:
+             special_modality_mask = mx.expand_dims(input_ids == token_id, -1)
+             special_modality_mask = mx.broadcast_to(
+                 special_modality_mask, inputs_embeds.shape
+             )
+         return special_modality_mask
+
+     @staticmethod
+     def merge_multimodal_and_text(
+         inputs_embeds, features, special_modality_mask, modality="image"
+     ):
+         # Count special tokens by summing the mask
+         modality_tokens_in_text = special_modality_mask.sum()
+         feature_tokens = features.size
+
+         if modality_tokens_in_text != feature_tokens:
+             raise ValueError(
+                 f"Number of {modality}s does not match number of special {modality} tokens in the input text. "
+                 f"Got {modality_tokens_in_text} {modality} tokens in the text and "
+                 f"{feature_tokens} tokens from {modality} embeddings."
+             )
+         features = features.astype(inputs_embeds.dtype)
+
+         inputs_embeds = masked_scatter(inputs_embeds, special_modality_mask, features)
+         return inputs_embeds
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         # Audio features
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids=input_ids,
+             pixel_values=pixel_values,
+             **kwargs,
+         )
+
+         logits = self.language_model(
+             input_ids=None,
+             cache=cache,
+             inputs_embeds=input_embeddings_features.inputs_embeds,
+             per_layer_inputs=input_embeddings_features.per_layer_inputs,
+         )
+         return logits
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             # if "vision_tower" not in k and "embed_vision" not in k:
+             if k.startswith("model."):
+                 sanitized_weights[".".join(k.split(".")[1:])] = v
+             else:
+                 sanitized_weights[k] = v
+         return sanitized_weights
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
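
The masked_scatter helper above fills the True positions of the broadcast input, in order, with values taken from the flattened source array; merge_multimodal_and_text uses it to splice image and audio embeddings over their placeholder tokens. A minimal standalone check of that behaviour (a sketch only, assuming this wheel and mlx are installed so the module import resolves; the values are illustrative):

    import mlx.core as mx

    from mlx_vlm.models.gemma3n.gemma3n import masked_scatter

    inputs = mx.zeros((2, 3))
    mask = mx.array([[False, True, True], [True, False, False]])
    source = mx.array([10.0, 20.0, 30.0])

    # Masked positions are filled left to right from `source`.
    print(masked_scatter(inputs, mask, source))  # expected: [[0, 10, 20], [30, 0, 0]]
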