fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,272 @@
+ """Language model decoder for Jina VLM in MLX."""
+
+ from typing import List, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ from mlx_lm.models.base import create_attention_mask, scaled_dot_product_attention
+ from mlx_lm.models.cache import KVCache
+
+ from ..base import LanguageModelOutput
+ from .config import TextConfig
+
+
+ class RMSNorm(nn.Module):
+     """RMS Layer Normalization."""
+
+     def __init__(self, dims: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = mx.ones((dims,))
+
+     def __call__(self, x: mx.array) -> mx.array:
+         rms = mx.sqrt(mx.mean(x * x, axis=-1, keepdims=True) + self.eps)
+         return self.weight * (x / rms)
+
+
+ class RoPE(nn.Module):
+     """Rotary Positional Embeddings."""
+
+     def __init__(self, dims: int, theta: float = 1000000.0):
+         super().__init__()
+         self.dims = dims
+         self.theta = theta
+         inv_freq = 1.0 / (theta ** (mx.arange(0, dims, 2).astype(mx.float32) / dims))
+         self._inv_freq = inv_freq
+
+     def __call__(self, x: mx.array, offset: int = 0) -> mx.array:
+         seq_len = x.shape[2]
+         positions = mx.arange(offset, offset + seq_len).astype(mx.float32)
+         freqs = positions[:, None] * self._inv_freq[None, :]
+         emb = mx.concatenate([freqs, freqs], axis=-1)
+         cos = mx.cos(emb)[None, None, :, :]
+         sin = mx.sin(emb)[None, None, :, :]
+         x1 = x[..., : self.dims // 2]
+         x2 = x[..., self.dims // 2 :]
+         rotated = mx.concatenate([-x2, x1], axis=-1)
+         return (x * cos + rotated * sin).astype(x.dtype)
+
+
+ class Attention(nn.Module):
+     """Multi-head attention with GQA and RoPE - matches weight naming: attn.qkv, attn.out"""
+
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.num_heads = config.num_attention_heads
+         self.num_kv_heads = config.num_key_value_heads
+         self.head_dim = config.head_dim
+         self.num_kv_groups = self.num_heads // self.num_kv_heads
+         self.scale = self.head_dim**-0.5
+
+         # Fused QKV projection - named to match weights
+         qkv_size = (
+             config.num_attention_heads + 2 * config.num_key_value_heads
+         ) * config.head_dim
+         self.qkv = nn.Linear(config.hidden_size, qkv_size, bias=False)
+         self.out = nn.Linear(
+             config.num_attention_heads * config.head_dim, config.hidden_size, bias=False
+         )
+
+         # QK normalization - named to match weights
+         if config.use_qk_norm:
+             self.q_norm = RMSNorm(config.head_dim, eps=config.rms_norm_eps)
+             self.k_norm = RMSNorm(config.head_dim, eps=config.rms_norm_eps)
+         else:
+             self.q_norm = None
+             self.k_norm = None
+
+         self.rope = RoPE(config.head_dim, theta=config.rope_theta)
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         B, L, _ = x.shape
+
+         # Compute fused QKV
+         qkv = self.qkv(x)
+         q_size = self.num_heads * self.head_dim
+         kv_size = self.num_kv_heads * self.head_dim
+
+         q = qkv[..., :q_size]
+         k = qkv[..., q_size : q_size + kv_size]
+         v = qkv[..., q_size + kv_size :]
+
+         q = q.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+         k = k.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+         v = v.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+         if self.q_norm is not None:
+             q = self.q_norm(q)
+             k = self.k_norm(k)
+
+         if cache is not None:
+             q = self.rope(q, offset=cache.offset)
+             k = self.rope(k, offset=cache.offset)
+             k, v = cache.update_and_fetch(k, v)
+         else:
+             q = self.rope(q)
+             k = self.rope(k)
+
+         output = scaled_dot_product_attention(
+             q, k, v, cache=cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.out(output)
+
+
+ class MLP(nn.Module):
+     """MLP with SwiGLU - matches weight naming: ffn.gate_up, ffn.down"""
+
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         # Fused gate and up projection - named to match weights
+         self.gate_up = nn.Linear(
+             config.hidden_size, 2 * config.intermediate_size, bias=False
+         )
+         self.down = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         gate_up = self.gate_up(x)
+         # Jina VLM convention: first half is value, second half is gate (activated)
+         up, gate = mx.split(gate_up, 2, axis=-1)
+         return self.down(nn.silu(gate) * up)
+
+
+ class TransformerBlock(nn.Module):
+     """Transformer block - matches weight naming: attn_norm, ffn_norm"""
+
+     def __init__(self, config: TextConfig, layer_idx: int = 0):
+         super().__init__()
+         # Named to match weights
+         self.attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.attn = Attention(config)
+         self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.ffn = MLP(config)
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         h = self.attn(self.attn_norm(x), mask=mask, cache=cache)
+         x = x + h
+         x = x + self.ffn(self.ffn_norm(x))
+         return x
+
+
+ class ExtendedEmbedding(nn.Module):
+     """Embedding with additional tokens - matches weight naming: embedding, new_embedding"""
+
+     def __init__(self, vocab_size: int, additional_size: int, dims: int):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.additional_size = additional_size
+         # Named to match weights
+         self.embedding = mx.zeros((vocab_size, dims))
+         self.new_embedding = mx.zeros((additional_size, dims))
+
+     def __call__(self, x: mx.array) -> mx.array:
+         full_embedding = mx.concatenate([self.embedding, self.new_embedding], axis=0)
+         return full_embedding[x]
+
+
+ class TextModel(nn.Module):
+     """Text decoder model - matches weight naming structure"""
+
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+
+         # Named to match weights: language_model.embedding
+         if config.additional_vocab_size > 0:
+             self.embedding = ExtendedEmbedding(
+                 config.vocab_size, config.additional_vocab_size, config.hidden_size
+             )
+         else:
+             self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+
+         self.layers = [
+             TransformerBlock(config, layer_idx=i)
+             for i in range(config.num_hidden_layers)
+         ]
+
+         # Named to match weights: language_model.ln_f
+         self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache: Optional[List[KVCache]] = None,
+     ) -> mx.array:
+         if inputs_embeds is None:
+             x = self.embedding(input_ids)
+         else:
+             x = inputs_embeds
+
+         for i, layer in enumerate(self.layers):
+             layer_cache = cache[i] if cache is not None else None
+             x = layer(x, mask=mask, cache=layer_cache)
+
+         return self.ln_f(x)
+
+
+ class LanguageModel(nn.Module):
+     """Language model wrapper - the TextModel is accessed as language_model in weights"""
+
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         # This will be loaded under "language_model" prefix
+         self.embedding = None  # Handled by sanitize
+         self.layers = None  # Handled by sanitize
+         self.ln_f = None  # Handled by sanitize
+
+         # Build the actual model components directly here
+         # They'll be found via language_model.embedding, language_model.layers, etc.
+         if config.additional_vocab_size > 0:
+             self.embedding = ExtendedEmbedding(
+                 config.vocab_size, config.additional_vocab_size, config.hidden_size
+             )
+         else:
+             self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+
+         self.layers = [
+             TransformerBlock(config, layer_idx=i)
+             for i in range(config.num_hidden_layers)
+         ]
+         self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache: Optional[List[KVCache]] = None,
+         **kwargs,
+     ) -> LanguageModelOutput:
+         if inputs_embeds is None:
+             x = self.embedding(inputs)
+         else:
+             x = inputs_embeds
+
+         # Initialize cache if needed
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         # Create causal attention mask
+         mask = create_attention_mask(x, cache)
+
+         for i, layer in enumerate(self.layers):
+             x = layer(x, mask=mask, cache=cache[i])
+
+         hidden_states = self.ln_f(x)
+         logits = self.lm_head(hidden_states)
+         return LanguageModelOutput(logits=logits)
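
The hunk above is mlx_vlm/models/jina_vlm/language.py (+272). For orientation, the following is a minimal smoke-test sketch of the decoder it defines, not part of the package: TinyTextConfig is a hypothetical stand-in for mlx_vlm.models.jina_vlm.config.TextConfig (config.py is not shown in this hunk) exposing only the fields the decoder reads, the sizes are illustrative rather than the shipped defaults, and the import assumes the wheel is installed so that mlx_vlm is importable.

from dataclasses import dataclass

import mlx.core as mx

from mlx_vlm.models.jina_vlm.language import TextModel


@dataclass
class TinyTextConfig:
    # Hypothetical stand-in config; values are illustrative only.
    vocab_size: int = 1000
    additional_vocab_size: int = 0
    hidden_size: int = 64
    num_hidden_layers: int = 2
    num_attention_heads: int = 4
    num_key_value_heads: int = 2  # GQA: two query heads share each KV head
    head_dim: int = 16
    intermediate_size: int = 128
    rms_norm_eps: float = 1e-6
    rope_theta: float = 1_000_000.0
    use_qk_norm: bool = True


model = TextModel(TinyTextConfig())
tokens = mx.array([[1, 2, 3, 4]])  # (batch, seq_len)

# Forward pass without a mask or KV cache; the LanguageModel wrapper is what
# builds the causal mask and applies lm_head to produce logits.
hidden = model(tokens)
print(hidden.shape)  # (1, 4, 64)

With four query heads over two KV heads, the fused qkv projection emits (4 + 2 * 2) * 16 = 128 features per token, which Attention then slices into q, k and v before grouped-query scaled dot-product attention.
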
@@ -0,0 +1,266 @@
+ """Processor for Jina VLM in MLX-VLM."""
+
+ from typing import Dict, List, Literal, Optional, Union
+
+ import mlx.core as mx
+ import numpy as np
+ from PIL import Image
+ from transformers.processing_utils import ProcessorMixin
+
+ from .image_processor import ImageProcessor
+
+
+ class JinaVLMProcessor(ProcessorMixin):
+     """Processor for Jina VLM that combines tokenizer and image processor."""
+
+     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+     attributes = ["tokenizer"]
+
+     def __init__(
+         self,
+         tokenizer,
+         image_token: str = "<|image|>",
+         chat_template: Optional[str] = None,
+         **kwargs,
+     ):
+         self.tokenizer = tokenizer
+         self.image_token = image_token
+         self._image_proc = ImageProcessor()  # Internal, not exposed as image_processor
+
+         # Get image token ID
+         self.image_token_id = self.tokenizer.convert_tokens_to_ids(image_token)
+
+         super().__init__(tokenizer, **kwargs)
+
+         # Set chat template AFTER super().__init__ - always set the default if not already set
+         default_chat_template = (
+             "{% for message in messages %}"
+             "{% if message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + '\n' }}"
+             "{% elif message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '\n' }}"
+             "{% elif message['role'] == 'assistant' %}{{ '<|assistant|>\n' + message['content'] + '\n' }}"
+             "{% endif %}{% endfor %}"
+             "{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}"
+         )
+         if chat_template is not None:
+             self.tokenizer.chat_template = chat_template
+         elif not self.tokenizer.chat_template:
+             self.tokenizer.chat_template = default_chat_template
+
+     @property
+     def chat_template(self):
+         return self.tokenizer.chat_template
+
+     @chat_template.setter
+     def chat_template(self, value):
+         self.tokenizer.chat_template = value
+
+     @property
+     def pad_token(self):
+         return self.tokenizer.pad_token
+
+     @property
+     def pad_token_id(self):
+         return self.tokenizer.pad_token_id
+
+     @property
+     def eos_token(self):
+         return self.tokenizer.eos_token
+
+     @property
+     def eos_token_id(self):
+         return self.tokenizer.eos_token_id
+
+     @property
+     def bos_token(self):
+         return self.tokenizer.bos_token
+
+     @property
+     def bos_token_id(self):
+         return self.tokenizer.bos_token_id
+
+     def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
+         return self.tokenizer.encode(text, add_special_tokens=add_special_tokens)
+
+     def decode(self, token_ids: List[int], **kwargs) -> str:
+         return self.tokenizer.decode(token_ids, **kwargs)
+
+     def batch_decode(self, token_ids, **kwargs) -> List[str]:
+         return self.tokenizer.batch_decode(token_ids, **kwargs)
+
+     def process_one(
+         self,
+         prompt: str,
+         images: Optional[List[Image.Image]] = None,
+         inference_mode: bool = True,
+     ) -> Dict:
+         """Process a single prompt with images."""
+         if images is None:
+             images = []
+
+         # Process images
+         if images:
+             image_outputs = self._image_proc.preprocess(images)
+             pixel_values_list = image_outputs["pixel_values"]
+             image_tokens = image_outputs["image_tokens"]
+             image_input_idx_list = image_outputs["image_input_idx"]
+             image_masks_list = image_outputs["image_masks"]
+         else:
+             pixel_values_list = None
+             image_tokens = []
+             image_input_idx_list = None
+             image_masks_list = None
+
+         # Split prompt by image token
+         text_splits = prompt.split(self.image_token)
+
+         # Build input_ids with image tokens interleaved
+         input_ids = []
+         current_image_idx = 0
+         updated_image_input_idx = []
+
+         for i, text_part in enumerate(text_splits):
+             # Encode text part
+             if text_part:
+                 text_tokens = self.encode(text_part, add_special_tokens=False)
+                 input_ids.extend(text_tokens)
+
+             # Add image tokens if not the last split and we have images
+             if i < len(text_splits) - 1 and current_image_idx < len(image_tokens):
+                 # Get image tokens for this image
+                 img_tokens = image_tokens[current_image_idx]
+                 # Offset image_input_idx by current position
+                 if image_input_idx_list is not None and current_image_idx < len(
+                     image_input_idx_list
+                 ):
+                     offset_idx = image_input_idx_list[current_image_idx] + len(
+                         input_ids
+                     )
+                     updated_image_input_idx.append(offset_idx)
+                 input_ids.extend(img_tokens.tolist())
+                 current_image_idx += 1
+
+         input_ids = mx.array(input_ids)
+
+         result = {
+             "input_ids": input_ids[None, :],  # Add batch dimension
+             "attention_mask": mx.ones_like(input_ids)[None, :],
+         }
+
+         if pixel_values_list is not None and len(pixel_values_list) > 0:
+             # Stack pixel values: (n_crops, n_patches, patch_dim)
+             result["pixel_values"] = mx.array(np.stack(pixel_values_list))
+             # Stack image_input_idx: (n_images, tokens_per_image)
+             result["image_input_idx"] = mx.array(np.stack(updated_image_input_idx))
+             # Stack image_masks: (n_crops, n_patches)
+             result["image_masks"] = mx.array(np.stack(image_masks_list))
+
+         return result
+
+     def __call__(
+         self,
+         text: Optional[Union[str, List[str]]] = None,
+         images: Optional[Union[Image.Image, List[Image.Image]]] = None,
+         inference_mode: bool = True,
+         return_tensors: Literal["np", "mx", "pt"] = "mx",
+         **kwargs,
+     ) -> Dict:
+         """Process text and images for Jina VLM.
+
+         When called with just text (like a tokenizer), returns tokenizer output.
+         When called with text and images, returns full processed inputs.
+
+         Args:
+             text: Input text or list of texts
+             images: Input image or list of images
+             inference_mode: Whether in inference mode
+             return_tensors: Type of tensors to return
+
+         Returns:
+             Dictionary containing processed inputs
+         """
+         # If called with just text (like a tokenizer), delegate to tokenizer
+         if text is not None and images is None:
+             return self.tokenizer(text, **kwargs)
+
+         if text is None:
+             raise ValueError("Text must be provided")
+
+         # Normalize inputs
+         if isinstance(text, str):
+             texts = [text]
+         else:
+             texts = text
+
+         if images is None:
+             images_list = [None] * len(texts)
+         elif isinstance(images, Image.Image):
+             images_list = [[images]]
+         elif isinstance(images, list) and len(images) > 0:
+             if isinstance(images[0], Image.Image):
+                 # Single list of images for single prompt
+                 images_list = [images]
+             else:
+                 images_list = images
+         else:
+             images_list = [None] * len(texts)
+
+         # Process each text-image pair
+         batch_results = []
+         for prompt, imgs in zip(texts, images_list):
+             result = self.process_one(prompt, imgs, inference_mode)
+             batch_results.append(result)
+
+         # Collate results
+         if len(batch_results) == 1:
+             return batch_results[0]
+         else:
+             return self._collate_batch(batch_results)
+
+     def _collate_batch(self, batch_results: List[Dict]) -> Dict:
+         """Collate multiple results into a batch."""
+         # Get max sequence length
+         max_len = max(r["input_ids"].shape[1] for r in batch_results)
+
+         padded_input_ids = []
+         padded_attention_mask = []
+
+         for r in batch_results:
+             seq_len = r["input_ids"].shape[1]
+             pad_len = max_len - seq_len
+
+             if pad_len > 0:
+                 input_ids = mx.concatenate(
+                     [mx.full((1, pad_len), self.pad_token_id), r["input_ids"]], axis=1
+                 )
+                 attention_mask = mx.concatenate(
+                     [mx.zeros((1, pad_len)), r["attention_mask"]], axis=1
+                 )
+             else:
+                 input_ids = r["input_ids"]
+                 attention_mask = r["attention_mask"]
+
+             padded_input_ids.append(input_ids)
+             padded_attention_mask.append(attention_mask)
+
+         result = {
+             "input_ids": mx.concatenate(padded_input_ids, axis=0),
+             "attention_mask": mx.concatenate(padded_attention_mask, axis=0),
+         }
+
+         # Combine pixel values if present
+         all_pixel_values = []
+         all_image_input_idx = []
+         all_image_masks = []
+
+         for r in batch_results:
+             if "pixel_values" in r:
+                 all_pixel_values.append(r["pixel_values"])
+                 all_image_input_idx.append(r["image_input_idx"])
+                 all_image_masks.append(r["image_masks"])
+
+         if all_pixel_values:
+             result["pixel_values"] = mx.concatenate(all_pixel_values, axis=0)
+             result["image_input_idx"] = mx.concatenate(all_image_input_idx, axis=0)
+             result["image_masks"] = mx.concatenate(all_image_masks, axis=0)
+
+         return result
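
The hunk above is mlx_vlm/models/jina_vlm/processing_jinavlm.py (+266). A short usage sketch follows, again not part of the package: the checkpoint path is a placeholder, and it assumes a Qwen2-style tokenizer that already defines the <|image|> special token (consistent with the tokenizer_class declaration above) plus an image file on disk.

from PIL import Image
from transformers import AutoTokenizer

from mlx_vlm.models.jina_vlm.processing_jinavlm import JinaVLMProcessor

# Placeholder path: any Qwen2-style tokenizer that knows <|image|>.
tokenizer = AutoTokenizer.from_pretrained("path/to/jina-vlm-checkpoint")
processor = JinaVLMProcessor(tokenizer)

# Text-only calls are delegated straight to the underlying tokenizer.
text_only = processor("Describe the scene.")

# Text plus image: the prompt is split on <|image|>, the per-image patch tokens
# are spliced in at that position, and image_input_idx is offset to the final
# token positions.
prompt = "<|user|>\nWhat is shown here? <|image|>\n<|assistant|>\n"
inputs = processor(text=prompt, images=Image.open("example.jpg"))
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)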