fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
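Although the distribution is named fount-vlm-nell-02, top_level.txt and the module paths above indicate that the wheel installs an mlx_vlm import package. The sketch below is hypothetical: it assumes the package mirrors upstream mlx-vlm's documented load/generate/apply_chat_template helpers (which would live in mlx_vlm/utils.py, mlx_vlm/generate.py, and mlx_vlm/prompt_utils.py above), and the model id and image path are placeholders rather than values taken from this diff.

# Hypothetical usage sketch -- assumes this wheel exposes the same public API as upstream mlx-vlm.
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"  # placeholder model id
model, processor = load(model_path)
config = load_config(model_path)

image = ["example.jpg"]  # placeholder image path
prompt = apply_chat_template(processor, config, "Describe this image.", num_images=len(image))
output = generate(model, processor, prompt, image, verbose=False)
print(output)

Of the 258 files, only two are reproduced in full below: mlx_vlm/models/florence2/config.py (+84) and mlx_vlm/models/florence2/florence2.py (+383).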
mlx_vlm/models/florence2/config.py (new file)
@@ -0,0 +1,84 @@
+from dataclasses import dataclass, field
+from typing import List, Optional, Tuple
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    """Configuration class for Florence2 Vision model (DaViT)."""
+
+    model_type: str = "davit"
+    in_chans: int = 3
+    num_classes: int = 1000
+    depths: List[int] = field(default_factory=lambda: [1, 1, 9, 1])
+    dim_embed: List[int] = field(default_factory=lambda: [128, 256, 512, 1024])
+    num_heads: List[int] = field(default_factory=lambda: [4, 8, 16, 32])
+    num_groups: List[int] = field(default_factory=lambda: [4, 8, 16, 32])
+    window_size: int = 12
+    mlp_ratio: float = 4.0
+    drop_path_rate: float = 0.1
+    patch_size: List[int] = field(default_factory=lambda: [7, 3, 3, 3])
+    patch_stride: List[int] = field(default_factory=lambda: [4, 2, 2, 2])
+    patch_padding: List[int] = field(default_factory=lambda: [3, 1, 1, 1])
+    patch_prenorm: List[bool] = field(
+        default_factory=lambda: [False, False, False, False]
+    )
+    qkv_bias: bool = True
+    conv_at_attn: bool = True
+    conv_at_ffn: bool = True
+    hidden_size: int = 768
+    image_size: Tuple[int, int] = (768, 768)
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    d_model: int = 768
+    model_type: str = "florence2"
+    encoder_attention_heads: int = 8
+    decoder_attention_heads: int = 8
+    encoder_ffn_dim: int = 3072
+    decoder_ffn_dim: int = 3072
+    dropout: float = 0.1
+    attention_dropout: float = 0.0
+    activation_dropout: float = 0.0
+    activation_function: str = "gelu"
+    init_std: float = 0.02
+    encoder_layerdrop: float = 0.0
+    decoder_layerdrop: float = 0.0
+    scale_embedding: bool = False
+    use_cache: bool = True
+    max_position_embeddings: int = 1024
+    vocab_size: int = 51289
+    pad_token_id: int = 1
+    bos_token_id: int = 0
+    eos_token_id: int = 2
+    decoder_start_token_id: int = 2
+    encoder_layers: int = 6
+    decoder_layers: int = 6
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    """Configuration class for Florence2."""
+
+    vision_config: VisionConfig
+    text_config: TextConfig
+    model_type: str = "florence2"
+    vocab_size: int = 50265
+    max_position_embeddings: int = 1024
+    pad_token_id: int = 1
+    bos_token_id: int = 0
+    eos_token_id: int = 2
+    image_token_id: int = 51289
+    image_token_index: int = 51289
+    image_feature_source: List[str] = field(
+        default_factory=lambda: ["temporal_avg_pool", "spatial_avg_pool"]
+    )
+    visual_temporal_embedding: Optional[dict] = field(
+        default_factory=lambda: {"type": "COSINE", "max_temporal_embeddings": 100}
+    )
+    image_pos_embed: Optional[dict] = field(
+        default_factory=lambda: {"type": "learned_abs_2d", "max_pos_embeddings": 50}
+    )
+    eos_token_id: Optional[List[int]] = None
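For orientation, the three dataclasses above nest as follows. This is a minimal sketch that assumes BaseModelConfig (imported from ..base, not part of this diff) adds no required fields of its own; note that ModelConfig annotates eos_token_id twice, so the later Optional[List[int]] = None wins and the effective default is None rather than 2.

from mlx_vlm.models.florence2.config import ModelConfig, TextConfig, VisionConfig

# VisionConfig and TextConfig are fully defaulted, so only the two nested
# configs are needed to construct a ModelConfig.
config = ModelConfig(vision_config=VisionConfig(), text_config=TextConfig())

print(config.vision_config.dim_embed[-1])  # 1024 -- width of the last DaViT stage
print(config.text_config.d_model)          # 768  -- BART-style encoder-decoder width
print(config.image_feature_source)         # ['temporal_avg_pool', 'spatial_avg_pool']
print(config.eos_token_id)                 # None -- the second annotation overrides the earlier int default of 2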
mlx_vlm/models/florence2/florence2.py (new file)
@@ -0,0 +1,383 @@
+import math
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_map
+
+from ..base import InputEmbeddingsFeatures
+
+# Import to apply Florence2Processor compatibility patch
+from . import processing_florence2  # noqa: F401
+from .config import ModelConfig
+from .language import LanguageModel
+from .vision import VisionModel
+
+
+def shift_tokens_right(
+    input_ids: mx.array, pad_token_id: int, decoder_start_token_id: int
+) -> mx.array:
+    """Shift input tokens right, adding decoder start token at beginning."""
+    shifted = mx.roll(input_ids, 1, axis=-1)
+    shifted = tree_map(lambda x: x.at[:, 0].set(decoder_start_token_id), shifted)
+    shifted = mx.where(shifted == -100, pad_token_id, shifted)
+    return shifted
+
+
+class LearnedPositionEmbedding2D(nn.Module):
+    """2D learned position embeddings."""
+
+    def __init__(self, embedding_dim: int = 256, num_pos: int = 50):
+        super().__init__()
+        self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
+        self.column_embeddings = nn.Embedding(
+            num_pos, embedding_dim - (embedding_dim // 2)
+        )
+
+    def __call__(self, x):
+        batch_size, height, width, channels = x.shape
+        width_pos = mx.arange(width)
+        height_pos = mx.arange(height)
+
+        x_emb = self.column_embeddings(width_pos)
+        y_emb = self.row_embeddings(height_pos)
+
+        pos = mx.concatenate(
+            [
+                mx.broadcast_to(x_emb[None, :, :], (height, width, x_emb.shape[-1])),
+                mx.broadcast_to(y_emb[:, None, :], (height, width, y_emb.shape[-1])),
+            ],
+            axis=-1,
+        )
+
+        return mx.broadcast_to(pos[None, ...], (batch_size, height, width, channels))
+
+
+class PositionalEmbeddingCosine1D(nn.Module):
+    """
+    MLX implementation of 1D cosine positional embeddings.
+
+    Args:
+        embed_dim: The dimension of the embeddings
+        max_seq_len: The maximum length to precompute the positional encodings
+    """
+
+    def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.max_seq_len = max_seq_len
+
+        # Generate position indices and dimension indices
+        position = mx.arange(max_seq_len)
+        dim_pos = mx.arange(0, embed_dim // 2)  # Half the dimensions for sin/cos pairs
+
+        # Calculate frequency bands
+        factor = math.log(10000)
+        denominator = mx.exp(-factor * dim_pos / embed_dim)
+
+        # Create position-frequency product matrix [max_seq_len, embed_dim//2]
+        frequencies = mx.reshape(position, (-1, 1)) * denominator
+
+        # Calculate sin and cos values [max_seq_len, embed_dim//2]
+        sin_values = mx.sin(frequencies)
+        cos_values = mx.cos(frequencies)
+
+        # Interleave sin and cos values to create final embeddings
+        pos_idx_to_embed = mx.zeros((max_seq_len, embed_dim))
+        pos_idx_to_embed = mx.concatenate(
+            [mx.expand_dims(sin_values, -1), mx.expand_dims(cos_values, -1)], axis=-1
+        ).reshape(max_seq_len, embed_dim)
+
+        # Store the positional embeddings
+        self.pos_idx_to_embed = pos_idx_to_embed
+
+    def __call__(self, seq_embeds: mx.array) -> mx.array:
+        """
+        Apply positional embeddings to the input sequence.
+
+        Args:
+            seq_embeds: Input sequence embeddings with shape:
+                - [T, D] where T is sequence length and D is embedding dimension
+                - [B, T, D] where B is batch size
+
+        Returns:
+            Positional embeddings matching input shape
+        """
+        shape_len = len(seq_embeds.shape)
+        assert 2 <= shape_len <= 3, "Input must be 2D or 3D tensor"
+
+        len_seq = seq_embeds.shape[-2]
+        assert (
+            len_seq <= self.max_seq_len
+        ), f"Sequence length {len_seq} exceeds maximum length {self.max_seq_len}"
+
+        # Get relevant portion of pre-computed embeddings
+        pos_embeds = self.pos_idx_to_embed[:len_seq]
+
+        # Add batch dimension if input is 3D
+        if shape_len == 3:
+            pos_embeds = mx.expand_dims(pos_embeds, 0)
+
+        return pos_embeds
+
+
+class Model(nn.Module):
+    """Florence-2 model for conditional generation."""
+
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+
+        # Initialize vision model
+        self.vision_tower = VisionModel(config.vision_config)
+
+        # Initialize language model
+        self.language_model = LanguageModel(config.text_config)
+
+        # Image projection layers
+        image_dim = config.vision_config.dim_embed[-1]
+        text_dim = config.text_config.d_model
+        self.image_projection = mx.zeros((image_dim, text_dim))
+
+        self.image_proj_norm = nn.LayerNorm(text_dim)
+
+        # Position embeddings
+        if config.image_pos_embed["type"] == "learned_abs_2d":
+            self.image_pos_embed = LearnedPositionEmbedding2D(
+                embedding_dim=image_dim,
+                num_pos=config.image_pos_embed["max_pos_embeddings"],
+            )
+        else:
+            raise NotImplementedError(
+                f"Position embedding type {config.image_pos_embed['type']} not supported"
+            )
+
+        # Temporal embeddings
+        if config.visual_temporal_embedding["type"] == "COSINE":
+            self.visual_temporal_embed = PositionalEmbeddingCosine1D(
+                embed_dim=image_dim,
+                max_seq_len=config.visual_temporal_embedding["max_temporal_embeddings"],
+            )
+        else:
+            raise NotImplementedError(
+                f"Temporal embedding type {config.visual_temporal_embedding['type']} not supported"
+            )
+
+        self.image_feature_source = config.image_feature_source
+
+    def _encode_image(self, pixel_values, extract_features=True):
+        """Encode image using vision model and add position embeddings."""
+        T = 1  # Single frame for now
+
+        # Get vision features
+        if extract_features:
+            batch_size, C, H, W = pixel_values.shape
+            x = self.vision_tower(pixel_values)
+        else:
+            x = pixel_values
+            batch_size = pixel_values.shape[0]
+
+        # Assuming this is part of a class method, keeping the same structure
+        if self.image_pos_embed is not None:
+            # Reshape to (batch_size * T, -1, feature_dim)
+            x = mx.reshape(x, (batch_size * T, -1, x.shape[-1]))
+            num_tokens = x.shape[-2]
+            h, w = int(num_tokens**0.5), int(num_tokens**0.5)
+            assert h * w == num_tokens, "only support square feature maps for now"
+            # Reshape to (batch_size * T, h, w, feature_dim)
+            x = mx.reshape(x, (batch_size * T, h, w, x.shape[-1]))
+            pos_embed = self.image_pos_embed(x)
+            x = x + pos_embed
+            # Reshape to (batch_size, T * h * w, feature_dim)
+            x = mx.reshape(x, (batch_size, T * h * w, x.shape[-1]))
+
+        if self.visual_temporal_embed is not None:
+            # Reshape for temporal embedding
+            x_temp = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
+            temporal_input = x_temp[:, :, 0]
+            visual_temporal_embed = self.visual_temporal_embed(temporal_input)
+            # Expand dims for broadcasting
+            visual_temporal_embed = mx.expand_dims(visual_temporal_embed, axis=2)
+            x = mx.reshape(x, (batch_size, T, -1, x.shape[-1])) + visual_temporal_embed
+
+        x_feat_dict = {}
+
+        # Spatial average pooling
+        x_spatial = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
+        spatial_avg_pool_x = mx.mean(x_spatial, axis=2)
+        x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x
+
+        # Temporal average pooling
+        x_temporal = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
+        temporal_avg_pool_x = mx.mean(x_temporal, axis=1)
+        x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x
+
+        # Last frame features
+        x_last = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
+        x = x_last[:, -1]
+        x_feat_dict["last_frame"] = x
+
+        # Gather features based on source configuration
+        new_x = []
+        for _image_feature_source in self.image_feature_source:
+            if _image_feature_source not in x_feat_dict:
+                raise ValueError(
+                    f"invalid image feature source: {_image_feature_source}"
+                )
+            new_x.append(x_feat_dict[_image_feature_source])
+
+        # Concatenate features
+        x = mx.concatenate(new_x, axis=1)
+
+        # Final projection and normalization
+        x = x @ self.image_projection
+        x = self.image_proj_norm(x)
+
+        return x
+
+    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds=None):
+        batch_size, image_token_length, _ = image_features.shape
+        image_attention_mask = mx.ones((batch_size, image_token_length))
+
+        if inputs_embeds is None:
+            return image_features, image_attention_mask
+
+        task_prefix_embeds = inputs_embeds
+        task_prefix_attention_mask = mx.ones((batch_size, task_prefix_embeds.shape[1]))
+
+        if len(task_prefix_attention_mask.shape) == 3:
+            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
+
+        # Concatenate image features and task prefix embeddings
+        inputs_embeds = mx.concatenate([image_features, task_prefix_embeds], axis=1)
+        attention_mask = mx.concatenate(
+            [image_attention_mask, task_prefix_attention_mask], axis=1
+        )
+        return inputs_embeds, attention_mask
+
+    @property
+    def layers(self):
+        return self.language_model.model.decoder.layers
+
+    def make_cache(self):
+        """Create cache for encoder-decoder model."""
+        return self.language_model.make_cache()
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ):
+
+        if input_ids is not None:
+            # Filter out image placeholder tokens and only embed the task prompt
+            # Create mask for non-image tokens
+            non_image_mask = input_ids != self.config.image_token_id
+
+            # Use boolean indexing to filter - convert to list for processing
+            batch_size = input_ids.shape[0]
+
+            # For batch_size=1, filter directly
+            if batch_size == 1:
+                # Get non-image token indices using argwhere-like approach
+                mask_flat = non_image_mask[0]
+                # Sum up mask to count non-image tokens
+                num_non_image = int(mx.sum(mask_flat).item())
+
+                if num_non_image > 0:
+                    # Extract non-image tokens by iterating (simple approach)
+                    input_list = input_ids[0].tolist()
+                    filtered_tokens = [
+                        t for t in input_list if t != self.config.image_token_id
+                    ]
+                    task_input_ids = mx.array([filtered_tokens])
+                    inputs_embeds = self.language_model.model.shared(task_input_ids)
+                else:
+                    inputs_embeds = None
+            else:
+                # For batch processing, embed all and handle later
+                inputs_embeds = self.language_model.model.shared(input_ids)
+        else:
+            inputs_embeds = None
+
+        attention_mask = None
+
+        # Process image if provided
+        if pixel_values is not None:
+            image_features = self._encode_image(pixel_values)
+
+            # Merge image features with text embeddings (task prompt only)
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
+                image_features, inputs_embeds
+            )
+
+        # For encoder-decoder models, prepare initial decoder input
+        # Use decoder_start_token_id from text_config (default 2 for Florence2/BART)
+        decoder_start_token_id = getattr(
+            self.config.text_config, "decoder_start_token_id", 2
+        )
+        decoder_input_ids = mx.array([[decoder_start_token_id]])
+        decoder_inputs_embeds = self.language_model.model.shared(decoder_input_ids)
+
+        return InputEmbeddingsFeatures(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,  # Use attention_mask for encoder-decoder
+            decoder_inputs_embeds=decoder_inputs_embeds,
+        )
+
+    def __call__(
+        self,
+        input_ids=None,
+        pixel_values=None,
+        cache=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        labels=None,
+        **kwargs,
+    ):
+        """Forward pass."""
+        attention_mask = None
+        decoder_inputs_embeds = None
+
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, **kwargs
+        )
+        inputs_embeds = input_embeddings_features.inputs_embeds
+        attention_mask = input_embeddings_features.attention_mask
+        # Handle decoder input IDs
+        if labels is not None and decoder_input_ids is None:
+            decoder_input_ids = shift_tokens_right(
+                labels, self.config.pad_token_id, self.config.bos_token_id
+            )
+
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            # Use decoder_start_token_id from text_config (default 2 for Florence2/BART)
+            decoder_start_token_id = getattr(
+                self.config.text_config, "decoder_start_token_id", 2
+            )
+            decoder_input_ids = mx.array([decoder_start_token_id])[None, :]
+            decoder_inputs_embeds = self.language_model.model.shared(decoder_input_ids)
+            decoder_input_ids = None
+
+        # Forward through language model
+        outputs = self.language_model(
+            inputs=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            decoder_attention_mask=decoder_attention_mask,
+            cache=cache,
+        )
+
+        return outputs
+
+    @staticmethod
+    def sanitize(weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "final_logits_bias" in k:
+                continue
+            sanitized_weights[k] = v
+        return sanitized_weights
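The two positional-embedding modules defined above are self-contained and can be sanity-checked in isolation. Below is a small sketch, assuming mlx is installed and the mlx_vlm layout from the file list; the shapes are illustrative, chosen to match the 1024-wide final DaViT stage and the 50/100 position limits from config.py.

import mlx.core as mx

from mlx_vlm.models.florence2.florence2 import (
    LearnedPositionEmbedding2D,
    PositionalEmbeddingCosine1D,
)

# 2D learned embeddings: (B, H, W, C) in, same shape out, added to the feature map.
pos2d = LearnedPositionEmbedding2D(embedding_dim=1024, num_pos=50)
features = mx.zeros((1, 16, 16, 1024))
print(pos2d(features).shape)  # (1, 16, 16, 1024)

# 1D cosine embeddings: a (B, T, D) input yields a (1, T, D) table that
# broadcasts over the batch when added inside Model._encode_image.
pos1d = PositionalEmbeddingCosine1D(embed_dim=1024, max_seq_len=100)
frames = mx.zeros((1, 1, 1024))  # T = 1 frame, as _encode_image assumes
print(pos1d(frames).shape)  # (1, 1, 1024)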