fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,200 @@
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+ from PIL import Image
+
+ from .config import VisionConfig
+
+
+ def create_patches(pixel_values: mx.array, patch_size: int) -> mx.array:
+     """
+     Convert [B, C, H, W] images to [B, num_patches, patch_dim] patch sequences.
+
+     This matches the PyTorch reference implementation exactly.
+     """
+     B, C, H, W = pixel_values.shape
+     P = patch_size
+
+     # Reshape to [B, C, H/P, P, W/P, P]
+     x = pixel_values.reshape(B, C, H // P, P, W // P, P)
+
+     # Permute to [B, H/P, W/P, C, P, P]
+     x = x.transpose(0, 2, 4, 1, 3, 5)
+
+     # Flatten to [B, (H/P)*(W/P), C*P*P]
+     num_patches = (H // P) * (W // P)
+     patch_dim = C * P * P
+     x = x.reshape(B, num_patches, patch_dim)
+
+     return x
+
+
+ class Attention(nn.Module):
+     """Multi-head attention with combined QKV projection."""
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.scale = self.head_dim**-0.5
+
+         # Combined QKV projection (like original moondream)
+         self.qkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
+         self.proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         B, L, _ = x.shape
+
+         # Combined QKV projection then split
+         qkv = self.qkv(x)
+         q, k, v = mx.split(qkv, 3, axis=-1)
+
+         # Reshape for multi-head attention
+         q = q.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+         k = k.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+         v = v.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+         # Scaled dot-product attention
+         output = mx.fast.scaled_dot_product_attention(
+             q, k, v, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+         return self.proj(output)
+
+
+ class MLP(nn.Module):
+     """Feed-forward network with GELU activation."""
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
+         self.activation = nn.GELU(approx="precise")
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.activation(self.fc1(x))
+         x = self.fc2(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     """
+     Single transformer encoder layer with POST-NORM architecture.
+
+     Key difference from standard: residual addition happens BEFORE normalization.
+     Pattern: x = x + attn(ln(x))
+     """
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.attn = Attention(config)
+         self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.mlp = MLP(config)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         # Post-norm: x = x + sublayer(norm(x))
+         x = x + self.attn(self.ln1(x), mask)
+         x = x + self.mlp(self.ln2(x))
+         return x
+
+
+ class VisionEncoder(nn.Module):
+     """Vision encoder with transformer layers."""
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         output_hidden_states: bool = False,
+     ):
+         encoder_states = (x,) if output_hidden_states else None
+
+         for layer in self.layers:
+             x = layer(x, mask)
+             if output_hidden_states:
+                 encoder_states = encoder_states + (x,)
+
+         return x, encoder_states
+
+
+ class VisionModel(nn.Module):
+     """
+     Moondream2 vision encoder.
+
+     Architecture:
+     1. Linear patch embedding (not Conv2d): 588 (14×14×3) -> 1152
+     2. Add learnable positional embeddings
+     3. 27 transformer layers with post-norm
+     4. Final layer norm
+
+     Reference: moondream2/vision.py (PyTorch)
+     """
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+
+         # Patch embedding: linear projection of flattened patches
+         patch_dim = config.patch_size * config.patch_size * config.num_channels  # 588
+         self.patch_emb = nn.Linear(patch_dim, config.hidden_size, bias=True)
+
+         # Transformer encoder
+         self.encoder = VisionEncoder(config)
+
+         # Post layer norm
+         self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+         # Positional embedding: [1, num_patches, hidden_size]
+         # This will be loaded from weights
+         num_patches = (config.image_size // config.patch_size) ** 2  # 729
+         self.position_embedding = mx.zeros((1, num_patches, config.hidden_size))
+
+     def __call__(
+         self,
+         pixel_values: mx.array,
+         output_hidden_states: bool = False,
+     ) -> mx.array:
+         """
+         Args:
+             pixel_values: [B, C, H, W] input images (normalized to [-1, 1])
+         Returns:
+             [B, num_patches, hidden_size] vision features
+         """
+         # Create patches: [B, C, H, W] -> [B, num_patches, patch_dim]
+         x = create_patches(pixel_values, self.config.patch_size)
+
+         # Linear projection: [B, num_patches, patch_dim] -> [B, num_patches, hidden_size]
+         x = self.patch_emb(x)
+
+         # Add positional embedding
+         x = x + self.position_embedding
+
+         # Encode through transformer layers
+         x, encoder_states = self.encoder(x, output_hidden_states=output_hidden_states)
+
+         # Final layer norm
+         x = self.post_layernorm(x)
+
+         if output_hidden_states:
+             return x, encoder_states
+         return x
+
+     def sanitize(self, weights):
+         """Sanitize vision encoder weights."""
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Skip position_ids
+                 continue
+             else:
+                 sanitized_weights[k] = v
+         return sanitized_weights
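
The patchify path above is easiest to follow as a shape walk-through. The sketch below repeats the three reshape/transpose steps of `create_patches` with the sizes implied by the docstring comments (a 378×378 input, 14-pixel patches, hence 729 patches of 588 values each); those concrete numbers are assumptions taken from the `# 588` and `# 729` comments, not read from a config file in this diff.

```python
# Shape walk-through for create_patches (a sketch; B/C/H/W/P are assumed
# from the docstring's "# 588" and "# 729" comments above).
import mlx.core as mx

B, C, H, W, P = 1, 3, 378, 378, 14
pixels = mx.zeros((B, C, H, W))

x = pixels.reshape(B, C, H // P, P, W // P, P)    # [1, 3, 27, 14, 27, 14]
x = x.transpose(0, 2, 4, 1, 3, 5)                 # [1, 27, 27, 3, 14, 14]
x = x.reshape(B, (H // P) * (W // P), C * P * P)  # [1, 729, 588]

assert x.shape == (1, 729, 588)  # 729 patches, each 14 * 14 * 3 = 588 values
# patch_emb then maps 588 -> hidden_size (1152 per the docstring), and
# position_embedding is broadcast over the batch as [1, 729, hidden_size].
```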
@@ -0,0 +1,4 @@
+ from .config import ModelConfig, ProjectorConfig, TextConfig, VisionConfig
+ from .language import LanguageModel
+ from .multi_modality import ImageProcessor, Model
+ from .vision import VisionModel
@@ -0,0 +1,108 @@
+ import inspect
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple, Union
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class ProjectorConfig(BaseModelConfig):
+     cls: str
+     model_type: str
+     params: dict
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str
+     hidden_size: int = 4096
+     num_hidden_layers: int = 32
+     intermediate_size: int = 11008
+     num_attention_heads: int = 32
+     rms_norm_eps: float = 1e-6
+     vocab_size: int = 102400
+     num_key_value_heads: int = None
+     rope_theta: float = 10000
+     rope_traditional: bool = False
+     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+     max_position_embeddings: int = 4096
+
+     def __post_init__(self):
+         if self.num_key_value_heads is None:
+             self.num_key_value_heads = self.num_attention_heads
+
+         if self.rope_scaling:
+             required_keys = {"factor", "type"}
+             if not all(key in self.rope_scaling for key in required_keys):
+                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+             if self.rope_scaling["type"] != "linear":
+                 raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str
+     num_hidden_layers: int = 24
+     hidden_size: int = 1024
+     intermediate_size: int = 4096
+     num_attention_heads: int = 16
+     image_size: int = 384
+     patch_size: int = 16
+     num_channels: int = 3
+     layer_norm_eps: float = 1e-5
+     cls: str = None
+     params: dict = None
+
+     def __post_init__(self):
+         if "high_res_cfg" in self.params:
+             self.image_size = self.params["high_res_cfg"]["image_size"]
+
+
+ @dataclass
+ class MLPConfig(BaseModelConfig):
+     hidden_size: int
+     intermediate_size: int
+
+
+ @dataclass
+ class SAMViTCfg:
+     image_size: Union[Tuple[int, int], int] = 1024
+     width: int = 768
+     layers: int = 12
+     heads: int = 12
+     patch_size: int = 16
+     window_size: int = 14
+     prompt_embed_dim: int = 256
+     global_attn_indexes: Union[List[int], Tuple[int]] = (2, 5, 8, 11)
+     downsample_channels: Union[List[int], Tuple[int]] = (512, 1024)
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     projector_config: ProjectorConfig
+     model_type: str
+     ignore_index: int = -100
+     image_token_index: int = 100015
+     vision_feature_select_strategy: str = "default"
+     select_layer: int = -1
+     pad_id: int = 100001
+     num_image_tokens: int = 576
+     vocab_size: int = 32000
+     eos_token_id: Optional[List[int]] = None
+
+     @classmethod
+     def from_dict(cls, params):
+         if "aligner_config" in params:
+             params["projector_config"] = params["aligner_config"]
+             del params["aligner_config"]
+
+         return cls(
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+             }
+         )
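
To make the dataclass behaviour above concrete, the sketch below constructs a `TextConfig` directly with made-up values, assuming the wheel is installed so that the `mlx_vlm.models.multi_modality` package from the file list (whose `__init__.py` is shown above) is importable. It exercises the two things `__post_init__` does: defaulting `num_key_value_heads` to `num_attention_heads`, and rejecting any `rope_scaling` whose type is not "linear".

```python
# Sketch of TextConfig.__post_init__ validation (hypothetical values).
from mlx_vlm.models.multi_modality import TextConfig

cfg = TextConfig(model_type="llama", rope_scaling={"type": "linear", "factor": 4.0})
assert cfg.num_key_value_heads == cfg.num_attention_heads == 32  # filled in post-init

try:
    TextConfig(model_type="llama", rope_scaling={"type": "dynamic", "factor": 4.0})
except ValueError as err:
    print(err)  # rope_scaling 'type' currently only supports 'linear'
```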
@@ -0,0 +1,191 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache
+ from .config import TextConfig
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+
+         dim = config.hidden_size
+         self.n_heads = n_heads = config.num_attention_heads
+         self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+         self.repeats = n_heads // n_kv_heads
+
+         head_dim = config.hidden_size // n_heads
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+         rope_scale = (
+             1 / config.rope_scaling["factor"]
+             if config.rope_scaling is not None
+             and config.rope_scaling["type"] == "linear"
+             else 1
+         )
+         self.rope = nn.RoPE(
+             head_dim,
+             traditional=config.rope_traditional,
+             base=config.rope_theta,
+             scale=rope_scale,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         # Prepare the queries, keys and values for the attention computation
+         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+         if cache is not None:
+             queries = self.rope(queries, offset=cache.offset)
+             keys = self.rope(keys, offset=cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+         else:
+             queries = self.rope(queries)
+             keys = self.rope(keys)
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.num_attention_heads = config.num_attention_heads
+         self.hidden_size = config.hidden_size
+         self.self_attn = Attention(config)
+         self.mlp = MLP(config.hidden_size, config.intermediate_size)
+         self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.config = config
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         r = self.self_attn(self.input_layernorm(x), mask, cache)
+         h = x + r
+         r = self.mlp(self.post_attention_layernorm(h))
+         out = h + r
+         return out
+
+
+ class Llama(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.num_hidden_layers = config.num_hidden_layers
+         assert self.vocab_size > 0
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = [
+             TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+     ):
+         # for passing merged input embeddings
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         if self.model_type != "llama":
+             raise ValueError(
+                 f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
+             )
+         self.model = Llama(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+     ):
+         out = self.model(inputs, mask=mask, cache=cache, inputs_embeds=inputs_embeds)
+         logits = self.lm_head(out)
+         return LanguageModelOutput(logits=logits)
+
+     @staticmethod
+     def sanitize(weights):
+         # Remove unused precomputed rotary freqs
+         return {
+             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+         }
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.config.hidden_size // self.config.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.config.num_key_value_heads
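
The text tower above is a plain Llama decoder with grouped-query attention, RoPE, and an optional KV cache; when no cache is passed, `Llama.__call__` builds its own causal mask via `create_attention_mask`. The sketch below runs a single forward pass with deliberately tiny, made-up hyperparameters just to show the input and output shapes, again assuming the installed `mlx_vlm` package; real checkpoints use the defaults in `TextConfig`.

```python
# Tiny smoke test of the language tower (a sketch; hyperparameters are made up).
import mlx.core as mx
from mlx_vlm.models.multi_modality import LanguageModel, TextConfig

config = TextConfig(
    model_type="llama",      # the only model_type LanguageModel accepts
    hidden_size=64,
    num_hidden_layers=2,
    intermediate_size=128,
    num_attention_heads=4,
    num_key_value_heads=2,   # grouped-query attention: 2 KV heads shared by 4 query heads
    vocab_size=100,
)
model = LanguageModel(config)

tokens = mx.array([[1, 2, 3, 4]])  # [batch=1, seq_len=4]
out = model(tokens)                # cache=None, so a causal mask is created internally
print(out.logits.shape)            # (1, 4, 100): one vocab distribution per position
```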