fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/jina_vlm/vision.py
@@ -0,0 +1,202 @@
+ """Vision encoder for Jina VLM in MLX."""
+
+ from typing import List, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from .config import VisionConfig
+
+
+ class PatchEmbedding(nn.Module):
+     """Patch embedding using linear projection."""
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.patch_size = config.patch_size
+         self.num_channels = config.num_channels
+         self.hidden_size = config.hidden_size
+
+         # Linear projection for patches - named to match weights
+         patch_dim = config.num_channels * config.patch_size * config.patch_size
+         self.proj = nn.Linear(patch_dim, config.hidden_size, bias=config.use_bias)
+
+     def __call__(self, x: mx.array) -> Tuple[mx.array, Tuple[int, int]]:
+         if x.ndim == 3:
+             # Already patchified: (B, n_patches, patch_dim)
+             B, n_patches, _ = x.shape
+             nH = nW = int(n_patches**0.5)
+             x = self.proj(x)
+         else:
+             # Image format: (B, C, H, W)
+             B, C, H, W = x.shape
+             pH, pW = self.patch_size, self.patch_size
+             nH, nW = H // pH, W // pW
+             x = x.reshape(B, C, nH, pH, nW, pW)
+             x = x.transpose(0, 2, 4, 1, 3, 5)
+             x = x.reshape(B, nH * nW, C * pH * pW)
+             x = self.proj(x)
+         return x, (nH, nW)
+
+
+ class VisionMLP(nn.Module):
+     """MLP for vision transformer - matches weight naming: ffn.up, ffn.down"""
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         # Named to match weights: ffn.up, ffn.down
+         self.up = nn.Linear(
+             config.hidden_size, config.intermediate_size, bias=config.use_bias
+         )
+         self.down = nn.Linear(
+             config.intermediate_size, config.hidden_size, bias=config.use_bias
+         )
+         # Use built-in GELU with tanh approximation
+         if config.activation == "gelu_pytorch_tanh":
+             self.gelu = nn.GELU(approx="tanh")
+         else:
+             self.gelu = nn.GELU()
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.up(x)
+         x = self.gelu(x)
+         x = self.down(x)
+         return x
+
+
+ class VisionAttention(nn.Module):
+     """Multi-head self-attention - matches weight naming: attn.qkv, attn.out"""
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.num_heads = config.num_attention_heads
+         self.head_dim = config.head_dim
+         self.scale = self.head_dim**-0.5
+
+         # Fused QKV projection - named to match weights
+         self.qkv = nn.Linear(
+             config.hidden_size,
+             3 * config.num_attention_heads * config.head_dim,
+             bias=config.use_bias,
+         )
+         self.out = nn.Linear(
+             config.num_attention_heads * config.head_dim,
+             config.hidden_size,
+             bias=config.use_bias,
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         B, L, _ = x.shape
+         qkv = self.qkv(x)
+         qkv = qkv.reshape(B, L, 3, self.num_heads, self.head_dim)
+         qkv = qkv.transpose(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]
+
+         attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
+         attn = mx.softmax(attn, axis=-1)
+         x = attn @ v
+
+         x = x.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         x = self.out(x)
+         return x
+
+
+ class VisionEncoderLayer(nn.Module):
+     """Transformer block - matches weight naming: attn_norm, ffn_norm"""
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         # Named to match weights: attn_norm, ffn_norm
+         self.attn_norm = nn.LayerNorm(
+             config.hidden_size, eps=config.layer_norm_eps, bias=config.use_bias
+         )
+         self.attn = VisionAttention(config)
+         self.ffn_norm = nn.LayerNorm(
+             config.hidden_size, eps=config.layer_norm_eps, bias=config.use_bias
+         )
+         self.ffn = VisionMLP(config)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = x + self.attn(self.attn_norm(x))
+         x = x + self.ffn(self.ffn_norm(x))
+         return x
+
+
+ class VisionModel(nn.Module):
+     """Vision encoder (SigLIP-style ViT)."""
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         self.hidden_size = config.hidden_size
+         self.vit_layers = config.vit_layers
+
+         # Named to match weights: patch_embed.proj
+         self.patch_embed = PatchEmbedding(config)
+
+         # Named to match weights: pos_embed (saved as 2D, not 3D)
+         num_patches = (config.image_size // config.patch_size) ** 2
+         if config.use_cls_token:
+             num_patches += 1
+             self.cls_token = mx.zeros((1, 1, config.hidden_size))
+         else:
+             self.cls_token = None
+         self.pos_embed = mx.zeros((num_patches, config.hidden_size))
+
+         # Transformer blocks
+         self.layers = [
+             VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)
+         ]
+
+         # Named to match weights: post_norm
+         if config.post_layer_norm:
+             self.post_norm = nn.LayerNorm(
+                 config.hidden_size, eps=config.layer_norm_eps, bias=config.use_bias
+             )
+         else:
+             self.post_norm = None
+
+     def __call__(self, x: mx.array) -> Tuple[mx.array, List[mx.array]]:
+         x, shape = self.patch_embed(x)
+
+         if self.cls_token is not None:
+             B = x.shape[0]
+             cls = mx.broadcast_to(self.cls_token, (B, 1, self.hidden_size))
+             x = mx.concatenate([cls, x], axis=1)
+
+         # pos_embed is (num_patches, hidden_size), add batch dim for broadcast
+         x = x + self.pos_embed[None, :, :]
+
+         hidden_states = []
+         for layer in self.layers:
+             x = layer(x)
+             hidden_states.append(x)
+
+         if self.post_norm is not None:
+             x = self.post_norm(x)
+             hidden_states.append(x)
+
+         return x, hidden_states
+
+     def get_features(self, images: mx.array) -> mx.array:
+         """Extract features from specific ViT layers.
+
+         Note: hidden_states includes all layer outputs plus the post_norm output.
+         vit_layers indices (e.g., [-4, -10]) are applied to this full list.
+         For 27 layers with post_norm, hidden_states has 28 elements:
+         - indices 0-26: layer 0-26 outputs
+         - index 27: post_norm output
+         So vit_layers=[-4, -10] extracts layers 24 and 18 (not 23 and 17).
+         """
+         _, hidden_states = self(images)
+         # Use full hidden_states including post_norm output for correct indexing
+
+         features = []
+         for layer_idx in self.vit_layers:
+             feats = hidden_states[layer_idx]
+             if self.cls_token is not None:
+                 feats = feats[:, 1:]
+             features.append(feats)
+
+         return mx.concatenate(features, axis=-1)
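
The get_features docstring above hinges on one indexing detail: hidden_states collects every encoder layer's output plus, when post_norm is enabled, one extra entry for the post_norm output, and the negative vit_layers indices are resolved against that extended list. A small standalone sketch of the arithmetic (the 27-layer count is the docstring's own example, not a value read from config.py):

    num_hidden_layers = 27  # example figure from the docstring above
    post_layer_norm = True

    # hidden_states is built as one entry per layer, plus the post_norm output.
    num_entries = num_hidden_layers + (1 if post_layer_norm else 0)  # 28

    for idx in (-4, -10):
        resolved = num_entries + idx  # Python negative indexing
        # Entries 0..26 are layer outputs; entry 27 is the post_norm output.
        print(f"vit_layers index {idx} -> layer output {resolved}")
    # Prints 24 and 18, matching the docstring.

Since get_features concatenates the selected feature maps along the last axis, callers see hidden_size * len(vit_layers) features per patch.
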
mlx_vlm/models/kernels.py
@@ -0,0 +1,447 @@
+ import mlx.core as mx
+
+
+ def nearest_interpolate(x, size=None, scale_factor=None):
+     """
+     Nearest neighbor interpolation that exactly matches PyTorch's behavior.
+     """
+     # Get input dimensions
+     batch_size, channels, in_h, in_w = x.shape
+
+     # Calculate output dimensions
+     if size is not None:
+         out_h, out_w = size
+     elif scale_factor is not None:
+         if isinstance(scale_factor, (int, float)):
+             scale_h = scale_w = scale_factor
+         else:
+             scale_h, scale_w = scale_factor
+         out_h, out_w = int(in_h * scale_h), int(in_w * scale_w)
+     else:
+         raise ValueError("Either size or scale_factor must be specified")
+
+     # Create dimensions tensor
+     dims = mx.array([batch_size, channels, in_h, in_w, out_h, out_w], dtype=mx.int32)
+
+     # Reshape input tensor to 1D for kernel processing
+     x_flat = x.reshape(-1)
+     input_dtype = x.dtype
+     if input_dtype != mx.float32:
+         x_flat = x_flat.astype(mx.float32)
+
+     # Metal kernel source that matches PyTorch's coordinate calculation
+     source = """
+         uint x_out = thread_position_in_grid.x;
+         uint y_out = thread_position_in_grid.y;
+         uint bc_idx = thread_position_in_grid.z;
+
+         int batch_size = dims[0];
+         int channels = dims[1];
+         int in_h = dims[2];
+         int in_w = dims[3];
+         int out_h = dims[4];
+         int out_w = dims[5];
+
+         if (x_out >= (uint)out_w || y_out >= (uint)out_h || bc_idx >= (uint)(batch_size * channels))
+             return;
+
+         int c = bc_idx % channels;
+         int b = bc_idx / channels;
+
+         // PyTorch's coordinate calculation for nearest neighbor
+         // This matches: torch.nn.functional.interpolate(..., mode='nearest')
+         float scale_h = float(in_h) / float(out_h);
+         float scale_w = float(in_w) / float(out_w);
+
+         // PyTorch uses floor for nearest neighbor coordinate mapping
+         int y_in = int(floor(float(y_out) * scale_h));
+         int x_in = int(floor(float(x_out) * scale_w));
+
+         // Clamp to bounds
+         y_in = max(0, min(y_in, in_h - 1));
+         x_in = max(0, min(x_in, in_w - 1));
+
+         int input_offset = ((b * channels + c) * in_h + y_in) * in_w + x_in;
+         int output_offset = ((b * channels + c) * out_h + y_out) * out_w + x_out;
+
+         output[output_offset] = input[input_offset];
+     """
+
+     # Create and run kernel
+     kernel = mx.fast.metal_kernel(
+         name="nearest_interpolation",
+         input_names=["input", "dims"],
+         output_names=["output"],
+         source=source,
+     )
+
+     threadgroup = get_optimal_threadgroup(out_w, out_h)
+     outputs = kernel(
+         inputs=[x_flat, dims],
+         grid=(out_w, out_h, batch_size * channels),
+         threadgroup=threadgroup,
+         output_shapes=[(batch_size * channels * out_h * out_w,)],
+         output_dtypes=[mx.float32],
+     )
+
+     result = outputs[0].reshape(batch_size, channels, out_h, out_w)
+     if input_dtype != mx.float32:
+         result = result.astype(input_dtype)
+
+     return result
+
+
+ def bicubic_interpolate(
+     x, size=None, scale_factor=None, align_corners=False, antialias=False
+ ):
+     """
+     Bicubic interpolation implemented as a custom Metal kernel, following
+     PyTorch's torch.nn.functional.interpolate coordinate conventions.
+
+     Args:
+         x: MLX tensor of shape [B, C, H, W]
+         size: Tuple of (out_h, out_w) or None
+         scale_factor: Float or tuple of (scale_h, scale_w) or None
+         align_corners: Whether to align corners
+         antialias: Whether to apply antialiasing
+
+     Returns:
+         Interpolated MLX tensor
+     """
+     # Get input dimensions
+     batch_size, channels, in_h, in_w = x.shape
+
+     # Calculate output dimensions
+     if size is not None:
+         out_h, out_w = size
+         scale_h, scale_w = out_h / in_h, out_w / in_w
+     elif scale_factor is not None:
+         if isinstance(scale_factor, (int, float)):
+             scale_h = scale_w = scale_factor
+         else:
+             scale_h, scale_w = scale_factor
+         out_h, out_w = int(in_h * scale_h), int(in_w * scale_w)
+     else:
+         raise ValueError("Either size or scale_factor must be specified")
+
+     # Calculate antialiasing parameters
+     # PyTorch uses support = 2.0 for bicubic when antialiasing
+     support = 2.0
+     antialias_flag = 1.0 if (antialias and (scale_h < 1.0 or scale_w < 1.0)) else 0.0
+
+     # When downsampling with antialias, PyTorch expands the filter support
+     if antialias and scale_h < 1.0:
+         filter_scale_h = 1.0 / scale_h
+     else:
+         filter_scale_h = 1.0
+
+     if antialias and scale_w < 1.0:
+         filter_scale_w = 1.0 / scale_w
+     else:
+         filter_scale_w = 1.0
+
+     # Create parameters tensor
+     params = mx.array(
+         [
+             scale_h,
+             scale_w,
+             1.0 if align_corners else 0.0,
+             antialias_flag,
+             filter_scale_h,
+             filter_scale_w,
+             support,
+         ],
+         dtype=mx.float32,
+     )
+
+     # Create dimensions tensor
+     dims = mx.array([batch_size, channels, in_h, in_w, out_h, out_w], dtype=mx.int32)
+
+     # Reshape input tensor to 1D for kernel processing
+     x_flat = x.reshape(-1)
+
+     # Convert to float32 for processing if needed
+     input_dtype = x.dtype
+     if input_dtype != mx.float32:
+         x_flat = x_flat.astype(mx.float32)
+
+     header = """
+         // Bicubic kernel function
+         float cubic_kernel(float x) {
+             float absx = fabs(x);
+             float absx2 = absx * absx;
+             float absx3 = absx2 * absx;
+
+             const float a = -0.5f;
+
+             if (absx <= 1.0f) {
+                 return (a + 2.0f) * absx3 - (a + 3.0f) * absx2 + 1.0f;
+             } else if (absx < 2.0f) {
+                 return a * absx3 - 5.0f * a * absx2 + 8.0f * a * absx - 4.0f * a;
+             }
+             return 0.0f;
+         }
+
+         // Antialiased bicubic kernel - scales the support region for downsampling
+         float cubic_kernel_antialias(float x, float scale) {
+             // When downsampling, we need to integrate over a wider region
+             // This matches PyTorch's antialiasing behavior
+             return cubic_kernel(x / scale);
+         }
+     """
+
+     # Metal kernel source code with antialiasing support
+     source = """
+         // Get thread position
+         uint x_out = thread_position_in_grid.x;
+         uint y_out = thread_position_in_grid.y;
+         uint bc_idx = thread_position_in_grid.z;
+
+         // Extract dimensions
+         int batch_size = dims[0];
+         int channels = dims[1];
+         int in_h = dims[2];
+         int in_w = dims[3];
+         int out_h = dims[4];
+         int out_w = dims[5];
+
+         // Extract parameters
+         float scale_h = params[0];
+         float scale_w = params[1];
+         bool align_corners = params[2] > 0.5f;
+         bool use_antialias = params[3] > 0.5f;
+         float filter_scale_h = params[4];
+         float filter_scale_w = params[5];
+         float support = params[6];
+
+         // Check bounds
+         if (x_out >= (uint)out_w || y_out >= (uint)out_h || bc_idx >= (uint)(batch_size * channels))
+             return;
+
+         // Calculate batch and channel indices
+         int c = bc_idx % channels;
+         int b = bc_idx / channels;
+
+         // Calculate input coordinates
+         float x_in, y_in;
+
+         if (align_corners && out_w > 1 && out_h > 1) {
+             x_in = float(x_out) * (in_w - 1) / (out_w - 1);
+             y_in = float(y_out) * (in_h - 1) / (out_h - 1);
+         } else {
+             // PyTorch's default coordinate mapping
+             x_in = ((float(x_out) + 0.5f) / float(out_w)) * float(in_w) - 0.5f;
+             y_in = ((float(y_out) + 0.5f) / float(out_h)) * float(in_h) - 0.5f;
+         }
+
+         // Calculate the support region based on antialiasing
+         float support_h = use_antialias ? support * filter_scale_h : support;
+         float support_w = use_antialias ? support * filter_scale_w : support;
+
+         // Calculate the range of input pixels to sample
+         int y_start = int(floor(y_in - support_h)) + 1;
+         int y_end = int(floor(y_in + support_h)) + 1;
+         int x_start = int(floor(x_in - support_w)) + 1;
+         int x_end = int(floor(x_in + support_w)) + 1;
+
+         // Clamp to valid range
+         y_start = max(0, y_start);
+         y_end = min(in_h, y_end);
+         x_start = max(0, x_start);
+         x_end = min(in_w, x_end);
+
+         // Perform bicubic interpolation with antialiasing
+         float result = 0.0f;
+         float weight_sum = 0.0f;
+
+         for (int y_pos = y_start; y_pos < y_end; y_pos++) {
+             float dy = float(y_pos) - y_in;
+             float wy = use_antialias ?
+                 cubic_kernel_antialias(dy, filter_scale_h) :
+                 cubic_kernel(dy);
+
+             for (int x_pos = x_start; x_pos < x_end; x_pos++) {
+                 float dx = float(x_pos) - x_in;
+                 float wx = use_antialias ?
+                     cubic_kernel_antialias(dx, filter_scale_w) :
+                     cubic_kernel(dx);
+
+                 float weight = wy * wx;
+
+                 // Calculate input tensor offset
+                 int input_offset = ((b * channels + c) * in_h + y_pos) * in_w + x_pos;
+
+                 // Add weighted contribution
+                 result += input[input_offset] * weight;
+                 weight_sum += weight;
+             }
+         }
+
+         // Normalize by weight sum
+         if (weight_sum > 1e-8f) {
+             result /= weight_sum;
+         }
+
+         // Calculate output tensor offset
+         int output_offset = ((b * channels + c) * out_h + y_out) * out_w + x_out;
+
+         // Assign the result to output
+         output[output_offset] = result;
+     """
+
+     # Create the kernel
+     kernel = mx.fast.metal_kernel(
+         name="bicubic_interpolation_antialias",
+         input_names=["input", "dims", "params"],
+         output_names=["output"],
+         source=source,
+         header=header,
+     )
+
+     # Run the kernel
+     threadgroup = get_optimal_threadgroup(out_w, out_h)
+     outputs = kernel(
+         inputs=[x_flat, dims, params],
+         grid=(out_w, out_h, batch_size * channels),
+         threadgroup=threadgroup,
+         output_shapes=[(batch_size * channels * out_h * out_w,)],
+         output_dtypes=[mx.float32],
+     )
+
+     # Reshape output back to 4D tensor and convert back to original dtype
+     result = outputs[0].reshape(batch_size, channels, out_h, out_w)
+     if input_dtype != mx.float32:
+         result = result.astype(input_dtype)
+
+     return result
+
+
+ def grid_sample(x, grid):
+     """
+     Bilinear grid sampling implemented as a custom Metal kernel.
+
+     Args:
+         x: MLX tensor of shape [B, H, W, C] (channels-last, as indexed by the kernel)
+         grid: MLX tensor of shape [B, gN, gM, 2] with normalized coordinates in [-1, 1]
+
+     Returns:
+         Sampled MLX tensor of shape [B, gN, gM, C]
+     """
+
+     assert x.ndim == 4, "`x` must be 4D."
+     assert grid.ndim == 4, "`grid` must be 4D."
+
+     B, _, _, C = x.shape
+     _, gN, gM, D = grid.shape
+     out_shape = (B, gN, gM, C)
+
+     assert D == 2, "Last dim of `grid` must be size 2."
+
+     source = """
+         uint elem = thread_position_in_grid.x;
+         int H = x_shape[1];
+         int W = x_shape[2];
+         int C = x_shape[3];
+         int gH = grid_shape[1];
+         int gW = grid_shape[2];
+
+         int w_stride = C;
+         int h_stride = W * w_stride;
+         int b_stride = H * h_stride;
+
+         uint grid_idx = elem / C * 2;
+         float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
+         float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
+
+         int ix_nw = floor(ix);
+         int iy_nw = floor(iy);
+
+         int ix_ne = ix_nw + 1;
+         int iy_ne = iy_nw;
+
+         int ix_sw = ix_nw;
+         int iy_sw = iy_nw + 1;
+
+         int ix_se = ix_nw + 1;
+         int iy_se = iy_nw + 1;
+
+         T nw = (ix_se - ix) * (iy_se - iy);
+         T ne = (ix - ix_sw) * (iy_sw - iy);
+         T sw = (ix_ne - ix) * (iy - iy_ne);
+         T se = (ix - ix_nw) * (iy - iy_nw);
+
+         int batch_idx = elem / C / gH / gW * b_stride;
+         int channel_idx = elem % C;
+         int base_idx = batch_idx + channel_idx;
+
+         T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
+         T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
+         T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
+         T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
+
+         I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
+         I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
+         I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
+         I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
+
+         out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
+     """
+
+     kernel = mx.fast.metal_kernel(
+         name="grid_sample",
+         input_names=["x", "grid"],
+         output_names=["out"],
+         source=source,
+     )
+
+     outputs = kernel(
+         inputs=[x, grid],
+         template=[("T", x.dtype)],
+         output_shapes=[out_shape],
+         output_dtypes=[x.dtype],
+         grid=(mx.prod(mx.array(out_shape)), 1, 1),
+         threadgroup=(256, 1, 1),
+     )
+     return outputs[0]
+
+
+ def get_optimal_threadgroup(out_w, out_h):
+     # Calculate optimal threadgroup dimensions based on output dimensions
+
+     # Maximum threadgroup size for most Metal GPUs
+     # This could be made more dynamic with Metal API queries if needed
+     MAX_THREADS_PER_GROUP = 1024
+     MAX_THREADS_PER_DIM = 1024
+
+     # Start with a reasonable default size for 2D workloads
+     default_threadgroup = (32, 32, 1)
+
+     try:
+         # Don't create threadgroups larger than the work dimensions
+         max_width = min(MAX_THREADS_PER_DIM, out_w)
+         max_height = min(MAX_THREADS_PER_DIM, out_h)
+
+         # Find largest power of 2 that fits within our dimensions
+         width = 2 ** (max_width.bit_length() - 1)
+         if width > max_width:
+             width = width // 2
+
+         height = 2 ** (max_height.bit_length() - 1)
+         if height > max_height:
+             height = height // 2
+
+         # Ensure we don't exceed maximum threads per threadgroup
+         while width * height > MAX_THREADS_PER_GROUP:
+             # Reduce the larger dimension first
+             if width >= height:
+                 width = width // 2
+             else:
+                 height = height // 2
+
+         # Ensure minimum size for efficiency
+         width = max(8, width)
+         height = max(8, height)
+
+         return (width, height, 1)
+
+     except Exception:
+         # Return safe defaults if calculation fails
+         return default_threadgroup
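
Both interpolation helpers in this hunk take NCHW MLX arrays, track torch.nn.functional.interpolate's coordinate conventions, and build custom kernels via mx.fast.metal_kernel, so they need MLX's Metal backend (Apple silicon). A hedged usage sketch; the import path is inferred from the file list above (mlx_vlm/models/kernels.py) rather than shown in the hunk itself:

    import mlx.core as mx

    # Import path assumed from the file list; adjust if the module lives elsewhere.
    from mlx_vlm.models.kernels import bicubic_interpolate, nearest_interpolate

    x = mx.arange(2 * 3 * 4 * 4, dtype=mx.float32).reshape(2, 3, 4, 4)  # NCHW

    up = nearest_interpolate(x, scale_factor=2)                 # -> (2, 3, 8, 8)
    down = bicubic_interpolate(x, size=(2, 2), antialias=True)  # -> (2, 3, 2, 2)
    print(up.shape, down.shape)

grid_sample is the exception: per its kernel it expects channels-last input of shape [B, H, W, C] plus a [B, gN, gM, 2] grid of coordinates in [-1, 1], and it returns [B, gN, gM, C].
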
mlx_vlm/models/kimi_vl/__init__.py
@@ -0,0 +1,4 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .kimi_vl import LanguageModel, Model, VisionModel
+ from .processing_kimi_vl import KimiVLImageProcessor as ImageProcessor
+ from .processing_kimi_vl import KimiVLProcessor as Processor
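
The final hunk shows how the kimi_vl package re-exports its classes under generic names (Model, Processor, ImageProcessor, and the config types), which appears to be the convention the other model packages in the file list follow as well. A purely illustrative consequence of those aliases (not an excerpt from the package):

    from mlx_vlm.models.kimi_vl import (
        ImageProcessor,  # alias of KimiVLImageProcessor
        LanguageModel,
        Model,
        ModelConfig,
        Processor,       # alias of KimiVLProcessor
        TextConfig,
        VisionConfig,
        VisionModel,
    )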