fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
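
A minimal usage sketch, assuming this wheel behaves like upstream mlx_vlm (top_level.txt declares mlx_vlm, and generate.py, utils.py, and prompt_utils.py are all present in the list above). The model id and image path are placeholders, and the load/generate calling pattern follows the upstream mlx_vlm README rather than anything documented for this package:

    from mlx_vlm import load, generate
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import load_config

    model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"  # placeholder model id
    model, processor = load(model_path)
    config = load_config(model_path)

    image = ["example.jpg"]  # placeholder local image path
    prompt = apply_chat_template(processor, config, "Describe this image.", num_images=len(image))
    output = generate(model, processor, prompt, image, verbose=False)
    print(output)
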
@@ -0,0 +1,85 @@
+ import inspect
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Union
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str
+     hidden_size: int
+     num_hidden_layers: int
+     intermediate_size: int
+     num_attention_heads: int
+     rms_norm_eps: float
+     vocab_size: int
+     attention_bias: bool = True
+     num_key_value_heads: int = None
+     rope_theta: float = 1000000
+     rope_traditional: bool = False
+     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+     max_position_embeddings: int = 4096
+     tie_word_embeddings: bool = True
+
+     def __post_init__(self):
+         if self.num_key_value_heads is None:
+             self.num_key_value_heads = self.num_attention_heads
+
+         if self.rope_scaling:
+             required_keys = {"factor", "type"}
+             if not all(key in self.rope_scaling for key in required_keys):
+                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+             if self.rope_scaling["type"] != "linear":
+                 raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str
+     num_hidden_layers: int = 27
+     hidden_size: int = 1152
+     intermediate_size: int = 4304
+     num_attention_heads: int = 16
+     image_size: int = 384
+     patch_size: int = 14
+     projection_dim: int = 768
+     vocab_size: int = 32000
+     num_channels: int = 3
+     layer_norm_eps: float = 1e-6
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     model_type: str
+     auto_map: dict
+     hidden_size: int
+     mm_hidden_size: int
+     mm_projector_type: str = "mlp2x_gelu"
+     ignore_index: int = -100
+     image_token_index: int = -200
+     vocab_size: int = 151936
+     eos_token_id: Optional[List[int]] = None
+
+     @classmethod
+     def from_dict(cls, params):
+         if not params.get("text_config", {}):
+             # Copy text config parameters from root level
+             excluded_keys = {"vision_config"}
+             params["text_config"] = dict(
+                 filter(lambda x: x[0] not in excluded_keys, params.items())
+             )
+         if not params.get("vision_config", {}).get("model_type", {}):
+             # Set default model type
+             params["vision_config"]["model_type"] = "siglip_vision_model"
+
+         return cls(
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+             }
+         )
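
The hunk above appears to match the 85-line mlx_vlm/models/llava_bunny/config.py entry in the file list. A minimal sketch of ModelConfig.from_dict, assuming the wheel is importable under that module path; the flat, Hugging Face-style config dict and its values are made up for illustration:

    from mlx_vlm.models.llava_bunny.config import ModelConfig

    # Flat config with the text-model fields at the root, which is the case
    # from_dict handles when no "text_config" key is present (values invented).
    raw = {
        "model_type": "llava-qwen2",
        "auto_map": {},
        "hidden_size": 1024,
        "mm_hidden_size": 1152,
        "num_hidden_layers": 24,
        "intermediate_size": 2816,
        "num_attention_heads": 16,
        "rms_norm_eps": 1e-6,
        "vocab_size": 151936,
        "vision_config": {},  # empty, so model_type gets defaulted
    }

    cfg = ModelConfig.from_dict(raw)
    print(cfg.text_config["hidden_size"])   # 1024: root-level keys were copied in
    print(cfg.vision_config["model_type"])  # "siglip_vision_model": filled in by from_dict
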
@@ -0,0 +1,194 @@
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from .config import TextConfig
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+
+         dim = config.hidden_size
+         self.n_heads = n_heads = config.num_attention_heads
+         self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+         head_dim = config.hidden_size // n_heads
+         self.scale = head_dim**-0.5
+
+         if hasattr(config, "attention_bias"):
+             attention_bias = config.attention_bias
+         else:
+             attention_bias = False
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+         rope_scale = (
+             1 / config.rope_scaling["factor"]
+             if config.rope_scaling is not None
+             and config.rope_scaling["type"] == "linear"
+             else 1
+         )
+         self.rope = nn.RoPE(
+             head_dim,
+             traditional=config.rope_traditional,
+             base=config.rope_theta,
+             scale=rope_scale,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Tuple[mx.array, mx.array]] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         # Prepare the queries, keys and values for the attention computation
+         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+         if cache is not None:
+             queries = self.rope(queries, offset=cache.offset)
+             keys = self.rope(keys, offset=cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+         else:
+             queries = self.rope(queries)
+             keys = self.rope(keys)
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.num_attention_heads = config.num_attention_heads
+         self.hidden_size = config.hidden_size
+         self.self_attn = Attention(config)
+         self.mlp = MLP(config.hidden_size, config.intermediate_size)
+         self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.config = config
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache=None,
+     ) -> mx.array:
+         r = self.self_attn(self.input_layernorm(x), mask, cache)
+         h = x + r
+         r = self.mlp(self.post_attention_layernorm(h))
+         out = h + r
+         return out
+
+
+ class Qwen2Model(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.num_hidden_layers = config.num_hidden_layers
+         assert self.vocab_size > 0
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = [
+             TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+     ):
+         # for passing merged input embeddings
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c)
+
+         return self.lm_head(self.norm(h))
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         self.model = Qwen2Model(config)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+     ):
+         out = self.model(inputs, mask=mask, cache=cache, inputs_embeds=inputs_embeds)
+         return LanguageModelOutput(logits=out)
+
+     def sanitize(self, weights):
+         if (
+             self.config.tie_word_embeddings
+             and "language_model.model.lm_head.weight" not in weights
+         ):
+             weights["language_model.model.lm_head.weight"] = weights[
+                 "language_model.model.embed_tokens.weight"
+             ]
+
+         return {
+             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+         }
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.config.hidden_size // self.config.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.config.num_key_value_heads
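
This 194-line hunk appears to match mlx_vlm/models/llava_bunny/language.py in the file list. A minimal sketch that runs the LanguageModel on its own, assuming mlx is installed and the modules import as shown; the TextConfig values are tiny, invented numbers so the randomly initialized model is cheap to run:

    import mlx.core as mx

    from mlx_vlm.models.llava_bunny.config import TextConfig
    from mlx_vlm.models.llava_bunny.language import LanguageModel

    config = TextConfig(
        model_type="qwen2",
        hidden_size=64,
        num_hidden_layers=2,
        intermediate_size=128,
        num_attention_heads=4,
        rms_norm_eps=1e-6,
        vocab_size=256,
    )

    model = LanguageModel(config)      # weights are randomly initialized
    tokens = mx.array([[1, 2, 3, 4]])  # (batch, sequence) of token ids
    out = model(tokens)                # no cache, so a causal mask is built internally
    print(out.logits.shape)            # expected (batch, sequence, vocab_size) = (1, 4, 256)
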
@@ -0,0 +1,217 @@
+ import re
+ from functools import partial, reduce
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ from PIL import Image
+ from transformers.image_transforms import (
+     convert_to_rgb,
+     normalize,
+     rescale,
+     resize,
+     to_channel_dimension_format,
+ )
+ from transformers.image_utils import to_numpy_array
+
+ from ..base import BaseImageProcessor, InputEmbeddingsFeatures
+ from .config import ModelConfig, VisionConfig
+ from .language import LanguageModel
+ from .vision import VisionModel
+
+
+ class ImageProcessor(BaseImageProcessor):
+     def preprocess(self, images):
+         if isinstance(images, Image.Image):
+             images = [images]
+         else:
+             assert isinstance(images, list)
+
+         transforms = [
+             convert_to_rgb,
+             to_numpy_array,
+             partial(
+                 resize,
+                 size=self.size,
+                 resample=self.resample,
+                 data_format=self.data_format,
+             ),
+             partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
+             partial(
+                 normalize,
+                 mean=self.image_mean,
+                 std=self.image_std,
+                 data_format=self.data_format,
+             ),
+             partial(
+                 to_channel_dimension_format,
+                 channel_dim=self.data_format,
+                 input_channel_dim=self.data_format,
+             ),
+         ]
+
+         images = reduce(lambda x, f: [*map(f, x)], transforms, images)
+
+         return images
+
+
+ class LlavaMultiModalProjector(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.linear_1 = nn.Linear(
+             config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
+         )
+         self.gelu = nn.GELU()
+         self.linear_2 = nn.Linear(
+             config.text_config.hidden_size, config.text_config.hidden_size, bias=True
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.linear_1(x)
+         x = self.gelu(x)
+         x = self.linear_2(x)
+         return x
+
+
+ class SigLipVisionTower(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.vision_tower = VisionModel(config)
+
+     def __call__(
+         self, x: mx.array, output_hidden_states: Optional[bool] = None
+     ) -> mx.array:
+         return self.vision_tower(x, output_hidden_states)
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         self.config = config
+
+         self.vision_tower = SigLipVisionTower(config.vision_config)
+         self.language_model = LanguageModel(config.text_config)
+         self.mm_projector = LlavaMultiModalProjector(config)
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+             )
+
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         *_, hidden_state = self.vision_tower(
+             pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
+         )
+
+         image_features = hidden_state[-1].astype(pixel_values.dtype)
+         assert image_features.shape[-2] == 729
+
+         image_features = self.mm_projector(image_features)
+
+         final_inputs_embeds = self._prepare_inputs_for_multimodal(
+             image_features, inputs_embeds, input_ids
+         )
+         return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+     def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
+         image_token_index = self.config.image_token_index
+         num_images, num_image_patches, embed_dim = image_features.shape
+
+         batch_size, seq_length, embed_dim = inputs_embeds.shape
+         num_images, num_image_patches, _ = image_features.shape
+
+         # Positions of <image> tokens in input_ids for each batch
+         image_positions = mx.argmax(input_ids == image_token_index, axis=1)
+
+         final_embeddings = []
+         for b in range(batch_size):
+             text_segments = []
+             start_idx = 0
+             position = int(image_positions[b].item())
+
+             text_segments.append(inputs_embeds[b : b + 1, start_idx:position])
+             text_segments.append(image_features[b : b + 1])
+             text_segments.append(inputs_embeds[b : b + 1, position + 1 :])
+
+             batch_embeddings = mx.concatenate(text_segments, axis=1)
+             final_embeddings.append(batch_embeddings)
+
+         # Create a final embedding of shape
+         # (batch_size, num_image_patches + sequence_len, embed_dim)
+         return mx.concatenate(final_embeddings, axis=0)
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Tuple[mx.array, mx.array]] = None,
+         **kwargs,
+     ):
+         input_embeddings_features = self.get_input_embeddings(input_ids, pixel_values)
+         logits = self.language_model(
+             inputs=input_ids,
+             cache=cache,
+             inputs_embeds=input_embeddings_features.inputs_embeds,
+             mask=None,  # TODO: add mask
+         )
+         return logits
+
+     def sanitize(self, weights):
+         weights = {
+             (
+                 f"{k.split('.', 1)[1]}"
+                 if re.match(r"^model\.vision_tower", k)
+                 else (
+                     f"mm_projector.linear_1.{k.split('.')[-1]}"
+                     if re.match(r"^model\.mm_projector\.0", k)
+                     else (
+                         f"mm_projector.linear_2.{k.split('.')[-1]}"
+                         if re.match(r"^model\.mm_projector\.2", k)
+                         else (
+                             f"language_model.model.{k}"
+                             if re.match(r"^lm_head", k)
+                             else (
+                                 f"language_model.{k}"
+                                 if re.match(r"^model\.(embed_tokens|norm|layers)", k)
+                                 else k
+                             )
+                         )
+                     )
+                 )
+             ): v
+             for k, v in weights.items()
+         }
+
+         weights = {
+             (
+                 f"vision_tower.vision_tower.vision_model.head.attention.in_proj.bias"
+                 if re.match(
+                     r"^vision_tower\.vision_tower\.vision_model\.head\.attention\.in_proj_bias",
+                     k,
+                 )
+                 else (
+                     f"vision_tower.vision_tower.vision_model.head.attention.in_proj.weight"
+                     if re.match(
+                         r"^vision_tower\.vision_tower\.vision_model\.head\.attention\.in_proj_weight",
+                         k,
+                     )
+                     else k
+                 )
+             ): v
+             for k, v in weights.items()
+         }
+
+         return weights
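
This 217-line hunk appears to match mlx_vlm/models/llava_bunny/llava_bunny.py in the file list. A minimal sketch of the key renaming that Model.sanitize performs on a Hugging Face-style checkpoint; the method never reads self, so it is called unbound here purely for illustration, and the dummy keys and import path are assumptions:

    from mlx_vlm.models.llava_bunny.llava_bunny import Model

    dummy = {
        "model.vision_tower.vision_tower.vision_model.embeddings.patch_embedding.weight": 0,
        "model.mm_projector.0.weight": 0,
        "model.mm_projector.2.bias": 0,
        "lm_head.weight": 0,
        "model.embed_tokens.weight": 0,
    }

    # sanitize never touches self, so None stands in for the instance here
    for old, new in zip(dummy, Model.sanitize(None, dict(dummy))):
        print(old, "->", new)

    # model.vision_tower.*        -> vision_tower.*  (leading "model." stripped)
    # model.mm_projector.0.weight -> mm_projector.linear_1.weight
    # model.mm_projector.2.bias   -> mm_projector.linear_2.bias
    # lm_head.weight              -> language_model.model.lm_head.weight
    # model.embed_tokens.weight   -> language_model.model.embed_tokens.weight
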