fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,264 @@
+ """
+ From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+ """
+
+ import dataclasses
+ from enum import IntEnum, auto
+ from typing import Dict, List
+
+
+ class SeparatorStyle(IntEnum):
+     """Separator styles."""
+
+     DeepSeek = auto()
+     DeepSeekV2 = auto()
+     PLAIN = auto()
+     ALIGNMENT = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that manages prompt templates and keeps all conversation history."""
+
+     # The name of this template
+     name: str
+     # The template of the system prompt
+     system_template: str = "{system_message}"
+     # The system message
+     system_message: str = ""
+     # The names of two roles
+     roles: List[str] = (("USER", "ASSISTANT"),)
+     # All messages. Each item is (role, message).
+     messages: List[List[str]] = ()
+     # The number of few shot examples
+     offset: int = 0
+     # The separator style and configurations
+     sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
+     sep: str = "\n"
+     sep2: str = None
+     # Stop criteria (the default one is EOS token)
+     stop_str: str = None
+     # Stops generation if meeting any token in this list
+     stop_token_ids: List[int] = None
+
+     def get_prompt(self) -> str:
+         """Get the prompt for generation."""
+         system_prompt = self.system_template.format(system_message=self.system_message)
+         if self.sep_style == SeparatorStyle.DeepSeek:
+             seps = [self.sep, self.sep2]
+             if system_prompt == "" or system_prompt is None:
+                 ret = ""
+             else:
+                 ret = system_prompt + seps[0]
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+             return ret
+         elif self.sep_style == SeparatorStyle.DeepSeekV2:
+             seps = [self.sep, self.sep2]
+             if system_prompt == "" or system_prompt is None:
+                 ret = ""
+             else:
+                 ret = system_prompt + seps[0]
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     if role == "User":
+                         ret += "<|sft▁begin|>\n" + message + self.sep
+                     else:
+                         ret += message + self.sep2
+                 else:
+                     ret = ret
+             return ret
+
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = ""
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i % 2 == 0:
+                         ret += message + seps[i % 2]
+                     else:
+                         ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+             return ret
+         elif self.sep_style == SeparatorStyle.ALIGNMENT:
+             seps = [self.sep, self.sep2]
+             ret = ""
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i % 2 == 0:
+                         ret += "<image>\n" + seps[i % 2]
+                     else:
+                         ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+             return ret
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+     def set_system_message(self, system_message: str):
+         """Set the system message."""
+         self.system_message = system_message
+
+     def append_message(self, role: str, message: str):
+         """Append a new message."""
+         self.messages.append([role, message])
+
+     def update_last_message(self, message: str):
+         """Update the last output.
+
+         The last message is typically set to be None when constructing the prompt,
+         so we need to update it in-place after getting the response from a model.
+         """
+         self.messages[-1][1] = message
+
+     def reset_message(self):
+         """Reset the message history."""
+         self.messages = []
+
+     def to_gradio_chatbot(self):
+         """Convert the conversation to gradio chatbot format."""
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset :]):
+             if i % 2 == 0:
+                 ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def to_openai_api_messages(self):
+         """Convert the conversation to OpenAI chat completion format."""
+         system_prompt = self.system_template.format(system_message=self.system_message)
+         ret = [{"role": "system", "content": system_prompt}]
+
+         for i, (_, msg) in enumerate(self.messages[self.offset :]):
+             if i % 2 == 0:
+                 ret.append({"role": "user", "content": msg})
+             else:
+                 if msg is not None:
+                     ret.append({"role": "assistant", "content": msg})
+         return ret
+
+     def copy(self):
+         return Conversation(
+             name=self.name,
+             system_template=self.system_template,
+             system_message=self.system_message,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             stop_str=self.stop_str,
+             stop_token_ids=self.stop_token_ids,
+         )
+
+     def dict(self):
+         return {
+             "template_name": self.name,
+             "system_message": self.system_message,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+         }
+
+
+ # A global registry for all conversation templates
+ conv_templates: Dict[str, Conversation] = {}
+
+
+ def register_conv_template(template: Conversation, override: bool = False):
+     """Register a new conversation template."""
+     if not override:
+         assert (
+             template.name not in conv_templates
+         ), f"{template.name} has been registered."
+
+     conv_templates[template.name] = template
+
+
+ def get_conv_template(name: str) -> Conversation:
+     """Get a conversation template."""
+     return conv_templates[name].copy()
+
+
+ register_conv_template(
+     Conversation(
+         name="deepseek",
+         system_template="{system_message}",
+         # system_message="You are a helpful assistant. Please answer truthfully and write out your "
+         # "thinking step by step to be sure you get the right answer.",
+         system_message="",
+         roles=("<|User|>", "<|Assistant|>"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.DeepSeek,
+         sep="\n\n",
+         sep2="<|end▁of▁sentence|>",
+         stop_token_ids=[100001],
+         stop_str=["User:", "<|end▁of▁sentence|>"],
+     )
+ )
+
+ register_conv_template(
+     Conversation(
+         name="deepseekv2",
+         system_template="{system_message}",
+         system_message="",
+         roles=("|<User>|", "|<Assistant>|"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.DeepSeekV2,
+         sep="\n<|sft▁end|>",
+         sep2="<|end▁of▁sentence|>",
+         stop_token_ids=[100001],
+         stop_str=["User:", "<|end▁of▁sentence|>"],
+     )
+ )
+
+
+ register_conv_template(
+     Conversation(
+         name="plain",
+         system_template="",
+         system_message="",
+         roles=("", ""),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.PLAIN,
+         sep="",
+         sep2="",
+         stop_token_ids=[100001],
+         stop_str=["</s>"],
+     )
+ )
+
+
+ register_conv_template(
+     Conversation(
+         name="alignment",
+         system_template="",
+         system_message="",
+         roles=("", ""),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.ALIGNMENT,
+         sep="",
+         sep2="",
+         stop_token_ids=[100001],
+         stop_str=["</s>"],
+     )
+ )
+
+
+ if __name__ == "__main__":
+     print("deepseek template:")
+     conv = get_conv_template("deepseek")
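
For orientation, a minimal usage sketch of the template registry above (editorial illustration, not part of the package diff): fetch the registered "deepseek" template, append a user turn and an open assistant turn, and render the prompt. The resulting string follows directly from the separators configured in the "deepseek" template.

conv = get_conv_template("deepseek")
conv.append_message(conv.roles[0], "Describe this image. <image>")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open
prompt = conv.get_prompt()
# With the empty system message and sep="\n\n" defined above, this yields:
# "<|User|>: Describe this image. <image>\n\n<|Assistant|>:"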
@@ -0,0 +1,371 @@
+ import math
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+ from transformers import AutoProcessor
+
+ from mlx_vlm.models.base import InputEmbeddingsFeatures
+
+ from .config import ModelConfig, ProjectorConfig, SAMViTConfig
+ from .language import LanguageModel
+ from .processing_deepseekocr import DeepseekOCRProcessor
+ from .sam import SAMEncoder
+ from .vision import VisionModel
+
+ AutoProcessor.register("deepseekocr", DeepseekOCRProcessor)
+
+
+ class MlpProjector(nn.Module):
+     def __init__(self, config: ProjectorConfig):
+         super().__init__()
+         self.config = config
+
+         if config.projector_config.projector_type == "linear":
+             modules = nn.Linear(
+                 config.projector_config.input_dim, config.projector_config.n_embed
+             )
+
+         elif config.projector_config.projector_type == "downsample_mlp_gelu":
+             mlp_depth = config.projector_config.depth
+             mlp_ratio = config.projector_config.mlp_ratio
+             modules = [
+                 nn.Linear(
+                     config.projector_config.input_dim
+                     * config.projector_config.downsample_ratio
+                     * config.projector_config.downsample_ratio,
+                     config.projector_config.n_embed * mlp_ratio,
+                 )
+             ]
+             for _ in range(1, mlp_depth - 1):
+                 modules.append(nn.GELU())
+                 modules.append(
+                     nn.Linear(
+                         config.projector_config.n_embed * mlp_ratio,
+                         config.projector_config.n_embed * mlp_ratio,
+                     )
+                 )
+             modules.append(nn.GELU())
+             modules.append(
+                 nn.Linear(
+                     config.projector_config.n_embed * mlp_ratio,
+                     config.projector_config.n_embed,
+                 )
+             )
+         else:
+             raise ValueError(
+                 f"Unknown projector type: {config.projector_config.projector_type}"
+             )
+
+         self.layers = modules
+
+     def __call__(self, x):
+         if self.config.projector_config.projector_type == "downsample_mlp_gelu":
+             bs, hw, input_dim = x.shape
+             h = w = int(math.sqrt(hw))
+
+             # Compute padding
+             pad = (
+                 0
+                 if h % self.config.projector_config.downsample_ratio == 0
+                 else self.config.projector_config.downsample_ratio
+                 - h % self.config.projector_config.downsample_ratio
+             )
+
+             x = mx.reshape(x, (bs, h, w, input_dim))
+             if pad > 0:
+                 x = mx.pad(x, [(0, 0), (0, pad), (0, pad), (0, 0)], constant_values=0)
+
+             x = mx.transpose(x, (0, 3, 1, 2))  # B, C, H, W
+
+             # Manual implementation of unfold for downsampling
+             h_pad, w_pad = x.shape[2], x.shape[3]
+             ds = self.config.projector_config.downsample_ratio
+             patches = []
+
+             for i in range(0, h_pad - ds + 1, ds):
+                 for j in range(0, w_pad - ds + 1, ds):
+                     patch = x[:, :, i : i + ds, j : j + ds]
+                     patches.append(mx.reshape(patch, (bs, -1)))
+
+             x = mx.stack(patches, axis=1)  # B, N_patches, C*ds*ds
+
+         if self.config.projector_config.projector_type == "linear":
+             x = self.layers(x)
+         else:
+             for layer in self.layers:
+                 x = layer(x)
+         return x
+
+
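
As an aside (editorial illustration, not part of the package), the "manual unfold" in MlpProjector.__call__ above regroups the feature grid into non-overlapping ds-by-ds blocks whose pixels are folded into the channel dimension. Below is a tiny standalone sketch of the same loop, assuming a 4x4 grid with 3 channels and downsample_ratio ds=2.

import mlx.core as mx

bs, h, w, C, ds = 1, 4, 4, 3, 2
x = mx.arange(bs * h * w * C).reshape(bs, h, w, C)
x = mx.transpose(x, (0, 3, 1, 2))                    # B, C, H, W
patches = []
for i in range(0, h - ds + 1, ds):
    for j in range(0, w - ds + 1, ds):
        patch = x[:, :, i : i + ds, j : j + ds]      # B, C, ds, ds
        patches.append(mx.reshape(patch, (bs, -1)))  # flatten the block into the channel dim
x = mx.stack(patches, axis=1)
print(x.shape)  # (1, 4, 12): a 2x2 grid of blocks, each carrying C*ds*ds features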
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.vision_model = VisionModel(config.vision_config)
+         sam_config = SAMViTConfig()
+         self.sam_model = SAMEncoder(
+             img_size=sam_config.image_size,
+             patch_size=sam_config.patch_size,
+             embed_dim=sam_config.width,
+             depth=sam_config.layers,
+             num_heads=sam_config.heads,
+             window_size=sam_config.window_size,
+             global_attn_indexes=sam_config.global_attn_indexes,
+         )
+         self.language_model = LanguageModel(config.text_config)
+         self.projector = MlpProjector(config)
+
+         self.tile_tag = config.tile_tag
+         self.global_view_pos = config.global_view_pos
+         # Special tokens used to format the image token sequence
+         embed_std = 1 / mx.sqrt(
+             mx.array(config.projector_config.n_embed, dtype=mx.float32)
+         )
+
+         if self.tile_tag == "2D":
+
+             # <|view_separator|>, <|\n|>
+             self.image_newline = mx.array(
+                 mx.random.normal((config.projector_config.n_embed,)) * embed_std
+             )
+             # fix the typo: view_seperater
+             self.view_separator = mx.array(
+                 mx.random.normal((config.projector_config.n_embed,)) * embed_std
+             )
+         else:
+             raise ValueError(
+                 f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
+             )
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         images_spatial_crop: Optional[mx.array] = None,
+         images_seq_mask: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         input_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(inputs_embeds=input_embeds)
+
+         # Only process images on prefill (input_ids.shape[1] != 1), not during autoregressive decoding
+         if (
+             self.sam_model is not None
+             and input_ids.shape[1] != 1
+             and mx.sum(pixel_values[1]).item() != 0
+         ):
+
+             idx = 0
+             patch_idx = 0  # Track patch offset for batch processing
+             all_patches = pixel_values[0]
+             all_image_ori = pixel_values[1]
+
+             for crop_shape in images_spatial_crop.tolist():
+                 images_in_this_batch = []
+                 width_crop_num, height_crop_num = int(crop_shape[0]), int(crop_shape[1])
+
+                 # Calculate number of patches for this image
+                 has_crops = width_crop_num > 1 or height_crop_num > 1
+                 num_patches = width_crop_num * height_crop_num if has_crops else 0
+
+                 # Extract patches for current image
+                 if has_crops and num_patches > 0:
+                     patches = all_patches[patch_idx : patch_idx + num_patches]
+                     patch_idx += num_patches
+                 else:
+                     patches = None
+
+                 # Extract global image for current image (one per batch item)
+                 image_ori = all_image_ori[idx : idx + 1]
+
+                 if patches is not None and mx.sum(patches).item() != 0:
+                     local_features_1 = self.sam_model(patches.transpose(0, 2, 3, 1))
+
+                     local_features_2 = self.vision_model(
+                         patches.transpose(0, 2, 3, 1), patch_embeds=local_features_1
+                     )
+
+                     local_features = mx.concatenate(
+                         (
+                             local_features_2[:, 1:],
+                             local_features_1.flatten(start_axis=1, end_axis=2),
+                         ),
+                         axis=-1,
+                     )
+
+                     local_features = self.projector(local_features)
+
+                     global_features_1 = self.sam_model(image_ori.transpose(0, 2, 3, 1))
+                     global_features_2 = self.vision_model(
+                         image_ori.transpose(0, 2, 3, 1), global_features_1
+                     )
+
+                     global_features = mx.concatenate(
+                         (
+                             global_features_2[:, 1:],
+                             global_features_1.flatten(start_axis=1, end_axis=2),
+                         ),
+                         axis=-1,
+                     )
+                     global_features = self.projector(global_features)
+
+                     # Remove batch dimension for single image processing
+                     global_features = global_features[0]  # (hw, n_dim)
+                     hw, n_dim = global_features.shape
+                     h = w = int(hw**0.5)
+
+                     _, hw2, n_dim2 = local_features.shape
+                     h2 = w2 = int(hw2**0.5)
+
+                     global_features = global_features.reshape(h, w, n_dim)
+
+                     global_features = mx.concatenate(
+                         [
+                             global_features,
+                             mx.broadcast_to(
+                                 self.image_newline[None, None, :], (h, 1, n_dim)
+                             ),
+                         ],
+                         axis=1,
+                     )
+
+                     global_features = global_features.reshape(-1, n_dim)
+
+                     local_features = (
+                         local_features.reshape(
+                             height_crop_num, width_crop_num, h2, w2, n_dim2
+                         )
+                         .transpose(0, 2, 1, 3, 4)
+                         .reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
+                     )
+                     local_features = mx.concatenate(
+                         [
+                             local_features,
+                             mx.broadcast_to(
+                                 self.image_newline[None, None, :],
+                                 (height_crop_num * h2, 1, n_dim2),
+                             ),
+                         ],
+                         axis=1,
+                     )
+                     local_features = local_features.reshape(-1, n_dim2)
+
+                     global_local_features = mx.concatenate(
+                         [local_features, global_features, self.view_separator[None, :]],
+                         axis=0,
+                     )
+
+                 else:
+                     global_features_1 = self.sam_model(image_ori.transpose(0, 2, 3, 1))
+                     global_features_2 = self.vision_model(
+                         image_ori.transpose(0, 2, 3, 1), global_features_1
+                     )
+                     global_features = mx.concatenate(
+                         (
+                             global_features_2[:, 1:],
+                             global_features_1.flatten(start_axis=1, end_axis=2),
+                         ),
+                         axis=-1,
+                     )
+                     global_features = self.projector(global_features)
+
+                     # Remove batch dimension for single image processing
+                     global_features = global_features[0]  # (hw, n_dim)
+                     hw, n_dim = global_features.shape
+                     h = w = int(hw**0.5)
+
+                     global_features = global_features.reshape(h, w, n_dim)
+
+                     global_features = mx.concatenate(
+                         [
+                             global_features,
+                             mx.broadcast_to(
+                                 self.image_newline[None, None, :], (h, 1, n_dim)
+                             ),
+                         ],
+                         axis=1,
+                     )
+
+                     global_features = global_features.reshape(-1, n_dim)
+
+                     global_local_features = mx.concatenate(
+                         [global_features, self.view_separator[None, :]], axis=0
+                     )
+
+                 images_in_this_batch.append(global_local_features)
+
+                 if images_in_this_batch:
+                     images_in_this_batch = mx.concatenate(images_in_this_batch, axis=0)
+                     # Find positions where images should be placed
+                     image_indices = np.where(images_seq_mask[idx])[0].tolist()
+                     # Directly assign the image features to those positions
+                     input_embeds[idx, image_indices] = images_in_this_batch
+
+                 idx += 1
+
+         return InputEmbeddingsFeatures(inputs_embeds=input_embeds)
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+
+         images_spatial_crop = kwargs.get("images_spatial_crop", None)
+         images_seq_mask = kwargs.get("images_seq_mask", None)
+
+         input_embeddings = self.get_input_embeddings(
+             input_ids, pixel_values, images_spatial_crop, images_seq_mask
+         )
+
+         logits = self.language_model(
+             input_ids, cache=cache, inputs_embeds=input_embeddings.inputs_embeds
+         )
+         return logits
+
+     @staticmethod
+     def sanitize(weights):
+         def transform_key(key):
+             if "model.layers" in key and "language_model" not in key:
+                 key = key.replace("model.layers", "language_model.model.layers")
+
+             if "model.embed_tokens" in key and "language_model" not in key:
+                 key = key.replace(
+                     "model.embed_tokens", "language_model.model.embed_tokens"
+                 )
+
+             if "model.norm" in key and "language_model" not in key:
+                 key = key.replace("model.norm", "language_model.model.norm")
+
+             if "model.vision_model" in key:
+                 key = key.replace("model.vision_model", "vision_model")
+
+             if "model.sam_model" in key:
+                 key = key.replace("model.sam_model", "sam_model")
+
+             if "model.projector" in key:
+                 key = key.replace("model.projector", "projector")
+
+             if "model.view_seperator" in key:
+                 key = key.replace("model.view_seperator", "view_separator")
+
+             if "model.image_newline" in key:
+                 key = key.replace("model.image_newline", "image_newline")
+
+             if "lm_head.weight" in key and "language_model" not in key:
+                 key = key.replace("lm_head.weight", "language_model.lm_head.weight")
+
+             return key
+
+         return {transform_key(k): v for k, v in weights.items()}
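
Finally, a hedged illustration of the key remapping performed by Model.sanitize above. The checkpoint key names below are hypothetical, chosen only to exercise each rewrite rule; the actual names depend on the checkpoint being converted.

# Hypothetical checkpoint keys, values omitted for brevity.
weights = {
    "model.embed_tokens.weight": None,
    "model.layers.0.self_attn.q_proj.weight": None,
    "model.sam_model.patch_embed.proj.weight": None,
    "model.view_seperator": None,
    "lm_head.weight": None,
}
print(list(Model.sanitize(weights)))
# ['language_model.model.embed_tokens.weight',
#  'language_model.model.layers.0.self_attn.q_proj.weight',
#  'sam_model.patch_embed.proj.weight',
#  'view_separator',
#  'language_model.lm_head.weight']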