fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/moondream2/language.py
@@ -0,0 +1,267 @@
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from .config import TextConfig
+
+
+ class Attention(nn.Module):
+     """Multi-head attention with partial RoPE (32 dims)."""
+
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.num_kv_heads = config.num_key_value_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.scale = self.head_dim**-0.5
+
+         # Combined QKV projection (like original moondream)
+         qkv_dim = self.hidden_size + 2 * (self.num_kv_heads * self.head_dim)
+         self.qkv_proj = nn.Linear(self.hidden_size, qkv_dim, bias=True)
+         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+
+         # Partial RoPE: only apply to first half of head_dim (32 out of 64)
+         rope_dims = int(self.head_dim * config.partial_rotary_factor)
+         self.rope = nn.RoPE(
+             dims=rope_dims,
+             traditional=False,
+             base=config.rope_theta,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Tuple[mx.array, mx.array]] = None,
+     ) -> mx.array:
+         B, L, _ = x.shape
+
+         # Combined QKV projection
+         qkv = self.qkv_proj(x)
+
+         # Split into Q, K, V
+         q_dim = self.num_heads * self.head_dim
+         kv_dim = self.num_kv_heads * self.head_dim
+         q, k, v = mx.split(qkv, [q_dim, q_dim + kv_dim], axis=-1)
+
+         # Reshape for attention
+         q = q.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+         k = k.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+         v = v.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+         # Apply partial RoPE
+         if cache is not None:
+             q = self.rope(q, offset=cache.offset)
+             k = self.rope(k, offset=cache.offset)
+             k, v = cache.update_and_fetch(k, v)
+         else:
+             q = self.rope(q)
+             k = self.rope(k)
+
+         # Attention
+         output = scaled_dot_product_attention(
+             q, k, v, cache=cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+         return self.o_proj(output)
+
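Note on the hunk above: the fused qkv_proj output is carved into Q, K, and V with a single mx.split at [q_dim, q_dim + kv_dim], and nn.RoPE(dims=rope_dims) rotates only the leading rope_dims of each 64-dim head. A minimal standalone sketch of that split, using illustrative shapes (hidden_size 2048, 32 query and 32 KV heads) rather than values read from this package's TextConfig:

import mlx.core as mx

# Illustrative shapes only; the real values come from TextConfig.
B, L = 1, 8
hidden_size, num_heads, num_kv_heads = 2048, 32, 32
head_dim = hidden_size // num_heads                # 64
q_dim = num_heads * head_dim                       # 2048
kv_dim = num_kv_heads * head_dim                   # 2048

qkv = mx.zeros((B, L, q_dim + 2 * kv_dim))         # fused projection output
# Two split points yield three pieces: [:2048], [2048:4096], [4096:6144]
q, k, v = mx.split(qkv, [q_dim, q_dim + kv_dim], axis=-1)
print(q.shape, k.shape, v.shape)                   # each (1, 8, 2048)

# With partial_rotary_factor = 0.5, rope_dims = 32: nn.RoPE(dims=32) rotates
# the first 32 dims of each head and passes the remaining 32 through unchanged.
rope_dims = int(head_dim * 0.5)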
+
+ class MLP(nn.Module):
+     """Simple MLP with GELU activation (not gated like Phi3)."""
+
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
+         self.activation = nn.GELU(approx="precise")
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.activation(self.fc1(x))
+         x = self.fc2(x)
+         return x
+
+
+ class TransformerBlock(nn.Module):
+     """Transformer block with pre-norm using LayerNorm."""
+
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.self_attn = Attention(config)
+         self.mlp = MLP(config)
+         # Moondream uses a single LayerNorm before both attention and MLP
+         # The residual pattern is: x + attn(ln(x)) + mlp(ln(x))
+         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Tuple[mx.array, mx.array]] = None,
+     ) -> mx.array:
+         # Moondream uses parallel attention and MLP
+         # x = x + attn(ln(x)) + mlp(ln(x))
+         normalized = self.input_layernorm(x)
+         attn_out = self.self_attn(normalized, mask, cache)
+         mlp_out = self.mlp(normalized)
+         return x + attn_out + mlp_out
+
+
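The block above computes attention and MLP from the same normalized input and sums both branches into the residual stream, instead of the more common sequential pre-norm arrangement. A toy contrast, with stand-in callables for the real sub-modules:

def sequential_block(x, ln1, ln2, attn, mlp):
    # Llama-style: two norms, and the MLP sees the attention output.
    h = x + attn(ln1(x))
    return h + mlp(ln2(h))

def parallel_block(x, ln, attn, mlp):
    # Moondream-style: one shared LayerNorm, branches are independent,
    # so they can run concurrently and the block needs one fewer norm.
    normalized = ln(x)
    return x + attn(normalized) + mlp(normalized)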
+ class PhiModel(nn.Module):
+     """Core transformer model."""
+
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
+         self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+     ):
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         # Moondream uses a special "prefix attention" mask where the
+         # BOS+image patch tokens attend fully within the prefix.
+         prefix_len = getattr(self.config, "prefix_attn_len", None)
+         try:
+             # Only apply during prefill (offset == 0) and only if prefix fits.
+             cache0 = cache[0] if isinstance(cache, list) and len(cache) > 0 else None
+             offset0 = getattr(cache0, "offset", None)
+
+             if (
+                 prefix_len is not None
+                 and offset0 == 0
+                 and hasattr(mask, "ndim")
+                 and mask.ndim >= 4
+                 and h.shape[1] >= prefix_len
+             ):
+                 if str(mask.dtype) == "bool":
+                     mask[..., :prefix_len, :prefix_len] = True
+                 else:
+                     # For additive masks, 0 indicates "allowed"
+                     mask[..., :prefix_len, :prefix_len] = 0
+         except Exception:
+             pass
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c)
+
+         return self.norm(h)
+
+
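The prefix-attention override above widens a causal mask so the first prefix_len positions (BOS plus image patch tokens) are mutually visible during prefill. A small sketch that builds the equivalent additive mask directly, with illustrative sizes (in the real code the base mask comes from create_attention_mask):

import mlx.core as mx

L, prefix_len = 6, 3                                   # illustrative sizes
rows = mx.arange(L)[:, None]
cols = mx.arange(L)[None, :]

causal_ok = rows >= cols                               # token i sees tokens <= i
prefix_ok = (rows < prefix_len) & (cols < prefix_len)  # fully open prefix block

# Additive convention from the hunk above: 0 = allowed, -inf = blocked.
mask = mx.where(causal_ok | prefix_ok, 0.0, float("-inf"))
print(mask)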
+ class LanguageModel(nn.Module):
+     """Language model with LM head."""
+
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.model = PhiModel(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+         # #region agent log
+         import json
+         log_file = "/Users/zekieldee/Desktop/code/mlx-vlm/.cursor/debug.log"
+         def log_lm(location, message, data, hypothesis_id):
+             try:
+                 with open(log_file, "a") as f:
+                     f.write(json.dumps({"sessionId": "debug-session", "runId": "inference", "hypothesisId": hypothesis_id, "location": location, "message": message, "data": data, "timestamp": __import__("time").time_ns() // 1000000}) + "\n")
+             except: pass
+
+         if inputs_embeds is not None:
+             log_lm("language.py:lm_input_embeds", "Language model input embeddings", {
+                 "shape": str(inputs_embeds.shape),
+                 "dtype": str(inputs_embeds.dtype),
+                 "mean": float(mx.mean(inputs_embeds)),
+                 "std": float(mx.std(inputs_embeds)),
+                 "min": float(mx.min(inputs_embeds)),
+                 "max": float(mx.max(inputs_embeds)),
+                 "has_nan": bool(mx.any(mx.isnan(inputs_embeds))),
+                 "has_inf": bool(mx.any(mx.isinf(inputs_embeds)))
+             }, "H9")
+         # #endregion
+
+         h = self.model(inputs, inputs_embeds=inputs_embeds, mask=mask, cache=cache)
+
+         # #region agent log
+         log_lm("language.py:lm_hidden_states", "Language model hidden states after layers", {
+             "shape": str(h.shape),
+             "dtype": str(h.dtype),
+             "mean": float(mx.mean(h)),
+             "std": float(mx.std(h)),
+             "min": float(mx.min(h)),
+             "max": float(mx.max(h)),
+             "last_token_mean": float(mx.mean(h[:, -1, :])),
+             "last_token_std": float(mx.std(h[:, -1, :]))
+         }, "H9")
+         # #endregion
+
+         logits = self.lm_head(h)
+
+         # #region agent log
+         last_logits = logits[0, -1, :]
+         top5_idx = mx.argsort(last_logits)[-5:].tolist()
+         top5_val = mx.sort(last_logits)[-5:].tolist()
+         log_lm("language.py:lm_logits", "Language model logits output", {
+             "shape": str(logits.shape),
+             "dtype": str(logits.dtype),
+             "mean": float(mx.mean(logits)),
+             "std": float(mx.std(logits)),
+             "min": float(mx.min(logits)),
+             "max": float(mx.max(logits)),
+             "last_token_logits_mean": float(mx.mean(logits[:, -1, :])),
+             "last_token_logits_std": float(mx.std(logits[:, -1, :])),
+             "last_token_top5_indices": top5_idx,
+             "last_token_top5_values": top5_val
+         }, "H9,H10")
+         # #endregion
+
+         return LanguageModelOutput(logits=logits)
+
+     def sanitize(self, weights):
+         """Sanitize language model weights."""
+         return {
+             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+         }
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.config.hidden_size // self.config.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.config.num_key_value_heads
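A side effect worth noting: every forward pass of this LanguageModel appends JSON records to the absolute path hardcoded above. A minimal sketch for reading such a JSON-lines file back, assuming the field names log_lm actually writes (sessionId, runId, hypothesisId, location, message, data, timestamp):

import json

# The path hardcoded in the published file; it only exists on that machine.
LOG_FILE = "/Users/zekieldee/Desktop/code/mlx-vlm/.cursor/debug.log"

def read_records(path):
    """Yield one dict per well-formed JSON line, skipping malformed lines."""
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                continue

for rec in read_records(LOG_FILE):
    # "data" holds the per-tensor stats (shape, mean, std, nan/inf flags).
    print(rec["location"], rec["message"], rec["data"].get("shape"))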