fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,452 @@
+ import math
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import SimpleKVCache
+ from .config import TextConfig
+
+
+ class Florence2Attention(nn.Module):
+     def __init__(
+         self, config: TextConfig, is_decoder: bool = False, is_causal: bool = False
+     ):
+         super().__init__()
+         self.embed_dim = config.d_model
+         self.num_heads = (
+             config.decoder_attention_heads
+             if is_decoder
+             else config.encoder_attention_heads
+         )
+         self.is_decoder = is_decoder
+         self.is_causal = is_causal
+         self.head_dim = self.embed_dim // self.num_heads
+         self.scaling = self.head_dim**-0.5
+
+         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+     def __call__(
+         self,
+         hidden_states,
+         key_value_states=None,
+         cache: Optional[SimpleKVCache] = None,
+         attention_mask=None,
+         layer_head_mask=None,
+     ):
+         batch_size, tgt_len, _ = hidden_states.shape
+
+         q = (
+             self.q_proj(hidden_states)
+             .reshape(batch_size, tgt_len, self.num_heads, self.head_dim)
+             .transpose(0, 2, 1, 3)
+         )
+
+         is_cross_attention = key_value_states is not None
+
+         batch_size, tgt_len, _ = hidden_states.shape
+         src_len = (
+             key_value_states.shape[1]
+             if key_value_states is not None
+             else hidden_states.shape[1]
+         )
+
+         if (
+             is_cross_attention
+             and cache is not None
+             and cache.cache_length > 0
+             and cache.keys.shape[2] == key_value_states.shape[1]
+         ):
+             # Cross-attention with cached keys/values - reuse them
+             k = cache.keys
+             v = cache.values
+
+         elif is_cross_attention:
+             # Cross attention - compute and cache keys/values from encoder
+             k = (
+                 self.k_proj(key_value_states)
+                 .reshape(batch_size, src_len, self.num_heads, self.head_dim)
+                 .transpose(0, 2, 1, 3)
+             )
+             v = (
+                 self.v_proj(key_value_states)
+                 .reshape(batch_size, src_len, self.num_heads, self.head_dim)
+                 .transpose(0, 2, 1, 3)
+             )
+             # Cache the cross-attention keys/values
+             if cache is not None:
+                 cache.update(k, v)
+
+         elif cache is not None:
+             # Self-attention with cache - compute new k,v and concatenate with cache
+             k = (
+                 self.k_proj(hidden_states)
+                 .reshape(batch_size, src_len, self.num_heads, -1)
+                 .transpose(0, 2, 1, 3)
+             )
+             v = (
+                 self.v_proj(hidden_states)
+                 .reshape(batch_size, src_len, self.num_heads, -1)
+                 .transpose(0, 2, 1, 3)
+             )
+             # update_and_fetch handles cache concatenation
+             k, v = cache.update_and_fetch(k, v)
+
+         else:
+             # Self attention without cache (encoder)
+             k = (
+                 self.k_proj(hidden_states)
+                 .reshape(batch_size, src_len, self.num_heads, self.head_dim)
+                 .transpose(0, 2, 1, 3)
+             )
+             v = (
+                 self.v_proj(hidden_states)
+                 .reshape(batch_size, src_len, self.num_heads, self.head_dim)
+                 .transpose(0, 2, 1, 3)
+             )
+
+         if self.is_causal and self.is_decoder:
+             causal_mask = create_attention_mask(hidden_states)
+             attention_mask = causal_mask
+
+         attn_output = (
+             scaled_dot_product_attention(
+                 q, k, v, cache=cache, scale=self.scaling, mask=attention_mask
+             )
+             .transpose(0, 2, 1, 3)
+             .reshape(batch_size, tgt_len, -1)
+         )
+
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class Florence2EncoderLayer(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.embed_dim = config.d_model
+         self.self_attn = Florence2Attention(config, is_decoder=False, is_causal=False)
+         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.activation_fn = nn.GELU()
+         self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+         self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+         self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+     def __call__(self, hidden_states, attention_mask=None):
+         residual = hidden_states
+         hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
+         hidden_states = residual + hidden_states
+         hidden_states = self.self_attn_layer_norm(hidden_states)
+
+         residual = hidden_states
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = residual + hidden_states
+         hidden_states = self.final_layer_norm(hidden_states)
+
+         return hidden_states
+
+
+ class Florence2DecoderLayer(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.embed_dim = config.d_model
+         self.self_attn = Florence2Attention(config, is_decoder=True, is_causal=True)
+         self.dropout = config.dropout
+         self.activation_fn = nn.GELU()
+         self.activation_dropout = config.activation_dropout
+
+         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.encoder_attn = Florence2Attention(config, is_decoder=True)
+         self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+         self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+         self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+     def __call__(
+         self,
+         hidden_states,
+         encoder_hidden_states,
+         attention_mask=None,
+         encoder_attention_mask=None,
+         cache: Optional[Tuple[SimpleKVCache, SimpleKVCache]] = None,
+     ):
+         residual = hidden_states
+
+         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+         self_attn_cache = cache[0] if cache[0] is not None else None
+
+         hidden_states = self.self_attn(
+             hidden_states, attention_mask=attention_mask, cache=self_attn_cache
+         )
+
+         hidden_states = residual + hidden_states
+         hidden_states = self.self_attn_layer_norm(hidden_states)
+
+         if encoder_hidden_states is not None:
+             residual = hidden_states
+
+             # cross_attn cached key/values tuple is at positions 3,4 of cache tuple
+             cross_attn_cache = cache[-1] if cache[-1] is not None else None
+
+             hidden_states = self.encoder_attn(
+                 hidden_states,
+                 key_value_states=encoder_hidden_states,
+                 attention_mask=encoder_attention_mask,
+                 cache=cross_attn_cache,
+             )
+             hidden_states = residual + hidden_states
+             hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = residual + hidden_states
+         hidden_states = self.final_layer_norm(hidden_states)
+
+         return hidden_states
+
+
+ class Florence2Encoder(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.dropout = config.dropout
+         self.layerdrop = config.encoder_layerdrop
+
+         embed_dim = config.d_model
+         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+         self.offset = 2
+         self.embed_positions = nn.Embedding(
+             config.max_position_embeddings + self.offset, embed_dim
+         )
+         self.layers = [
+             Florence2EncoderLayer(config) for _ in range(config.encoder_layers)
+         ]
+         self.layernorm_embedding = nn.LayerNorm(embed_dim)
+
+     def __call__(self, input_ids=None, inputs_embeds=None, attention_mask=None):
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+             input_shape = inputs_embeds.shape
+         else:
+             input_shape = inputs_embeds.shape
+
+         positions = mx.arange(input_shape[1])
+
+         if positions.ndim == 1:
+             positions = mx.expand_dims(positions, axis=0)
+
+         embed_pos = self.embed_positions(positions + self.offset)
+
+         hidden_states = inputs_embeds + embed_pos
+         hidden_states = self.layernorm_embedding(hidden_states)
+
+         for encoder_layer in self.layers:
+             # Add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+             dropout_probability = mx.random.uniform()
+             if self.training and (dropout_probability < self.layerdrop):
+                 continue
+             hidden_states = encoder_layer(hidden_states, attention_mask)
+
+         return hidden_states
+
+
+ class Florence2Decoder(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.dropout = config.dropout
+         self.layerdrop = config.decoder_layerdrop
+         self.padding_idx = config.pad_token_id
+         self.max_target_positions = config.max_position_embeddings
+         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+         self.offset = 2
+         self.embed_positions = nn.Embedding(
+             config.max_position_embeddings + self.offset, config.d_model
+         )
+         self.layers = [
+             Florence2DecoderLayer(config) for _ in range(config.decoder_layers)
+         ]
+         self.layernorm_embedding = nn.LayerNorm(config.d_model)
+
+     def __call__(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         head_mask=None,
+         cross_attn_head_mask=None,
+         inputs_embeds=None,
+         cache=None,
+     ):
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+             )
+         elif input_ids is not None:
+             inputs_embeds = self.embed_tokens(input_ids)
+             input_shape = inputs_embeds.shape  # for 2d masks
+             positions = input_ids
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.shape[:-1]  # for 4d masks
+             positions = inputs_embeds[:, :, -1]
+         else:
+             raise ValueError(
+                 "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+             )
+
+         if positions.ndim == 1:
+             positions = mx.expand_dims(positions, axis=0)
+
+         cache_length = cache[0][0].keys.shape[2] if cache[0][0].cache_length > 0 else 0
+
+         bsz, seq_len = inputs_embeds.shape[:2]
+         positions = mx.arange(
+             cache_length,
+             cache_length + seq_len,
+             dtype=mx.int64,
+         )
+         positions = mx.expand_dims(positions, axis=0)
+
+         embed_pos = self.embed_positions(positions + self.offset)
+
+         hidden_states = inputs_embeds + embed_pos
+         hidden_states = self.layernorm_embedding(hidden_states)
+
+         for decoder_layer, c in zip(self.layers, cache):
+             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+             dropout_probability = mx.random.uniform()
+             if self.training and (dropout_probability < self.layerdrop):
+                 continue
+             hidden_states = decoder_layer(
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 attention_mask=attention_mask,
+                 encoder_attention_mask=encoder_attention_mask,
+                 cache=c,
+             )
+
+         return hidden_states
+
+
+ class Florence2LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.shared = nn.Embedding(config.vocab_size, config.d_model)
+         self.encoder = Florence2Encoder(config)
+         self.decoder = Florence2Decoder(config)
+         if config.scale_embedding:
+             self.embed_scale = math.sqrt(config.d_model)
+         else:
+             self.embed_scale = 1.0
+
+     def __call__(
+         self,
+         input_ids=None,
+         inputs_embeds=None,
+         decoder_input_ids=None,
+         decoder_inputs_embeds=None,
+         attention_mask=None,
+         decoder_attention_mask=None,
+         encoder_outputs=None,
+         cache=None,
+     ):
+         self.encoder.embed_tokens = self.shared
+         self.decoder.embed_tokens = self.shared
+
+         if decoder_input_ids is None and decoder_inputs_embeds is None:
+             if input_ids is None:
+                 raise ValueError(
+                     "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+                     "passed, `input_ids` cannot be `None`. Please pass either "
+                     "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+                 )
+
+             decoder_input_ids = mx.zeros_like(input_ids)
+             decoder_input_ids[:, 1:] = input_ids[:, :-1]
+             decoder_input_ids[:, 0] = self.config.bos_token_id
+
+         if inputs_embeds is not None:
+             inputs_embeds = inputs_embeds * self.embed_scale
+
+         if cache is None:
+             cache = [(SimpleKVCache(), SimpleKVCache())] * len(self.decoder.layers)
+
+         if encoder_outputs is None:
+             encoder_outputs = self.encoder(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=attention_mask,
+             )
+
+         decoder_outputs = self.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             encoder_hidden_states=encoder_outputs,
+             encoder_attention_mask=attention_mask,
+             inputs_embeds=decoder_inputs_embeds,
+             cache=cache,
+         )
+         return decoder_outputs, encoder_outputs
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.model = Florence2LanguageModel(config)
+         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+     def __call__(
+         self,
+         inputs=None,
+         inputs_embeds=None,
+         decoder_input_ids=None,
+         decoder_inputs_embeds=None,
+         attention_mask=None,
+         decoder_attention_mask=None,
+         encoder_outputs=None,
+         cache=None,
+         **kwargs,
+     ):
+         decoder_outputs, encoder_outputs = self.model(
+             inputs,
+             inputs_embeds,
+             decoder_input_ids,
+             decoder_inputs_embeds,
+             attention_mask,
+             decoder_attention_mask,
+             encoder_outputs,
+             cache,
+         )
+         out = self.lm_head(decoder_outputs)
+         return LanguageModelOutput(logits=out, encoder_outputs=encoder_outputs)
+
+     @property
+     def layers(self):
+         return range(self.model.config.decoder_layers)
+
+     @property
+     def head_dim(self):
+         return self.config.d_model // self.config.decoder_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.config.decoder_attention_heads
+
+     def make_cache(self):
+         return [(SimpleKVCache(), SimpleKVCache()) for n in self.layers]
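
A minimal sketch of how the cross-attention cache in Florence2Attention above behaves: the first call projects the encoder hidden states into keys/values and stores them in a SimpleKVCache, and later decode steps reuse them instead of recomputing. It assumes only what the diff shows (the mlx_vlm module paths from the file list and the config fields the class actually reads); the SimpleNamespace config and random inputs are stand-ins for illustration, not part of the package.

from types import SimpleNamespace

import mlx.core as mx

from mlx_vlm.models.cache import SimpleKVCache
from mlx_vlm.models.florence2.language import Florence2Attention

# Stand-in config exposing only the fields Florence2Attention reads.
config = SimpleNamespace(
    d_model=64, encoder_attention_heads=4, decoder_attention_heads=4
)
attn = Florence2Attention(config, is_decoder=True, is_causal=False)

decoder_states = mx.random.normal((1, 3, 64))  # (batch, tgt_len, d_model)
encoder_states = mx.random.normal((1, 7, 64))  # (batch, src_len, d_model)
cache = SimpleKVCache()

# First call hits the `elif is_cross_attention` branch and caches k/v.
prefill = attn(decoder_states, key_value_states=encoder_states, cache=cache)
# Later steps see cache_length > 0 with a matching src_len, so cached k/v are reused.
step = attn(mx.random.normal((1, 1, 64)), key_value_states=encoder_states, cache=cache)
print(prefill.shape, step.shape, cache.keys.shape)
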
@@ -0,0 +1,30 @@
+ from transformers.models.florence2.processing_florence2 import Florence2Processor
+
+ # Store the original __init__
+ _original_init = Florence2Processor.__init__
+
+
+ def _patched_init(self, image_processor=None, tokenizer=None, **kwargs):
+     """Patched __init__ that adds image_token attributes to tokenizer if missing."""
+     if tokenizer is not None:
+         # Ensure tokenizer has image_token attribute
+         if not hasattr(tokenizer, "image_token"):
+             tokenizer.image_token = "<image>"
+
+         # Ensure tokenizer has image_token_id attribute
+         if not hasattr(tokenizer, "image_token_id"):
+             vocab = tokenizer.get_vocab()
+             if tokenizer.image_token in vocab:
+                 tokenizer.image_token_id = vocab[tokenizer.image_token]
+             else:
+                 tokenizer.add_tokens([tokenizer.image_token], special_tokens=True)
+                 tokenizer.image_token_id = tokenizer.convert_tokens_to_ids(
+                     tokenizer.image_token
+                 )
+
+     # Call original __init__
+     _original_init(self, image_processor=image_processor, tokenizer=tokenizer, **kwargs)
+
+
+ # Apply the patch
+ Florence2Processor.__init__ = _patched_init
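
The processing_florence2 module above works purely by side effect: importing it swaps Florence2Processor.__init__ for the patched wrapper, so the import has to happen before any processor is constructed (presumably mlx_vlm's own loading path does this). A hedged usage sketch; the import paths come from the file list, and it assumes a transformers version that ships the native florence2 processor, as the module itself requires.

# Side-effect import: applies the monkey patch to the transformers class.
import mlx_vlm.models.florence2.processing_florence2  # noqa: F401

from transformers.models.florence2.processing_florence2 import Florence2Processor

# The class-level __init__ is now the patched wrapper; any processor built from
# here on gets image_token / image_token_id filled in on its tokenizer if missing.
print(Florence2Processor.__init__.__name__)  # "_patched_init"
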