fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/cache.py
@@ -0,0 +1,238 @@
+from typing import Any, List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.cache import (
+    ArraysCache,
+    BatchKVCache,
+    BatchRotatingKVCache,
+    ChunkedKVCache,
+    KVCache,
+    RotatingKVCache,
+    _BaseCache,
+)
+
+
+def make_prompt_cache(
+    model: nn.Module,
+    max_kv_size: Optional[int] = None,
+) -> List[Any]:
+    """
+    Construct the model's cache for use in generation.
+
+    This function will defer the cache construction to the model if it has a
+    ``make_cache`` method, otherwise it will make a default KV cache.
+
+    Args:
+        model (nn.Module): The language model.
+        max_kv_size (Optional[int]): If provided and the model does not have a
+            ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
+            size of ``max_kv_size``
+    """
+    if hasattr(model, "make_cache"):
+        return model.make_cache()
+
+    num_layers = len(model.layers)
+
+    if max_kv_size is not None:
+        return [
+            RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
+        ]
+    else:
+        return [KVCache() for _ in range(num_layers)]
+
+
+class SimpleKVCache:
+    """A simple key-value cache for transformer attention layers.
+
+    Stores and concatenates key/value tensors along sequence dimension.
+    """
+
+    def __init__(self):
+        self.keys = None
+        self.values = None
+        self.cache_length = 0
+
+    def update_and_fetch(self, keys, values):
+        """Update cache with new key/value tensors and return full cache.
+
+        Args:
+            keys: New key tensor to add [batch, heads, seq_len, head_dim]
+            values: New value tensor to add [batch, heads, seq_len, head_dim]
+
+        Returns:
+            Tuple of (cached_keys, cached_values) containing full cache history
+        """
+        if self.cache_length == 0:
+            # First update - just store tensors
+            self.keys = keys
+            self.values = values
+        else:
+            # Concatenate with existing cache along sequence dimension
+            self.keys = mx.concatenate([self.keys, keys], axis=2)
+            self.values = mx.concatenate([self.values, values], axis=2)
+
+        self.cache_length += keys.shape[2]
+        return self.keys, self.values
+
+    def fetch(self):
+        return self.keys, self.values
+
+    def update(self, keys, values):
+        """Update cache with new key/value tensors without returning.
+
+        Args:
+            keys: New key tensor to store
+            values: New value tensor to store
+        """
+        self.keys = keys
+        self.values = values
+        self.cache_length += keys.shape[2]
+
+
+class SlidingWindowCache(_BaseCache):
+    """A sliding window cache for local attention layers."""
+
+    def __init__(self, max_size: int, step: int = 256):
+        self.max_size = max_size
+        self.step = step
+        self.keys = None
+        self.values = None
+        self.offset = 0
+
+    def update_and_fetch(
+        self, keys: mx.array, values: mx.array
+    ) -> Tuple[mx.array, mx.array]:
+        B, n_kv_heads, seq_len, k_head_dim = keys.shape
+        v_head_dim = values.shape[-1]
+
+        if self.keys is None:
+            # Initialize cache
+            k_shape = (B, n_kv_heads, self.max_size, k_head_dim)
+            v_shape = (B, n_kv_heads, self.max_size, v_head_dim)
+            self.keys = mx.zeros(k_shape, dtype=keys.dtype)
+            self.values = mx.zeros(v_shape, dtype=values.dtype)
+
+        # Simple sliding window: keep only the last max_size tokens
+        if self.offset + seq_len <= self.max_size:
+            # Fits within current window
+            start_idx = self.offset
+            end_idx = self.offset + seq_len
+            self.keys[:, :, start_idx:end_idx, :] = keys
+            self.values[:, :, start_idx:end_idx, :] = values
+            self.offset += seq_len
+        else:
+            # Need to slide the window
+            if seq_len < self.max_size:
+                # Shift existing content left
+                shift_amount = min(seq_len, self.max_size - 1)
+                self.keys[:, :, :-shift_amount, :] = self.keys[:, :, shift_amount:, :]
+                self.values[:, :, :-shift_amount, :] = self.values[
+                    :, :, shift_amount:, :
+                ]
+                # Add new tokens at the end
+                self.keys[:, :, -shift_amount:, :] = keys[:, :, -shift_amount:, :]
+                self.values[:, :, -shift_amount:, :] = values[:, :, -shift_amount:, :]
+            else:
+                # New sequence is larger than cache, just keep the last max_size tokens
+                self.keys = keys[:, :, -self.max_size :, :]
+                self.values = values[:, :, -self.max_size :, :]
+            self.offset = self.max_size
+
+        return self.keys, self.values
+
+    @property
+    def state(self):
+        if self.keys is None:
+            return None, None
+        return self.keys, self.values
+
+    @state.setter
+    def state(self, v):
+        if v is not None and len(v) == 2:
+            self.keys, self.values = v
+            if self.keys is not None:
+                self.offset = self.max_size
+
+    def get_max_cache_shape(self):
+        return self.max_size
+
+    @property
+    def meta_state(self):
+        return tuple(map(str, (self.max_size, self.step, self.offset)))
+
+    @meta_state.setter
+    def meta_state(self, v):
+        self.max_size, self.step, self.offset = map(int, v)
+
+    def is_trimmable(self):
+        return False
+
+    def trim(self, n):
+        return 0
+
+
+class StaticKVCache(_BaseCache):
+    """A static cache that grows to accommodate all tokens."""
+
+    def __init__(self, max_size: int, step: int = 256):
+        self.max_size = max_size
+        self.step = step
+        self.keys = None
+        self.values = None
+        self.offset = 0
+
+    def update_and_fetch(
+        self, keys: mx.array, values: mx.array
+    ) -> Tuple[mx.array, mx.array]:
+        B, n_kv_heads, seq_len, k_head_dim = keys.shape
+        v_head_dim = values.shape[-1]
+
+        # Initialize cache if needed
+        if self.keys is None:
+            k_shape = (B, n_kv_heads, self.max_size, k_head_dim)
+            v_shape = (B, n_kv_heads, self.max_size, v_head_dim)
+            self.keys = mx.zeros(k_shape, dtype=keys.dtype)
+            self.values = mx.zeros(v_shape, dtype=values.dtype)
+
+        # Update cache
+        end_pos = min(self.offset + seq_len, self.max_size)
+        actual_seq_len = end_pos - self.offset
+
+        if actual_seq_len > 0:
+            self.keys[:, :, self.offset : end_pos, :] = keys[:, :, :actual_seq_len, :]
+            self.values[:, :, self.offset : end_pos, :] = values[
+                :, :, :actual_seq_len, :
+            ]
+            self.offset = end_pos
+
+        return self.keys, self.values
+
+    @property
+    def state(self):
+        if self.keys is None:
+            return None, None
+        return self.keys, self.values
+
+    @state.setter
+    def state(self, v):
+        if v is not None and len(v) == 2:
+            self.keys, self.values = v
+            if self.keys is not None:
+                self.offset = self.max_size
+
+    @property
+    def meta_state(self):
+        return tuple(map(str, (self.max_size, self.step, self.offset)))
+
+    @meta_state.setter
+    def meta_state(self, v):
+        self.max_size, self.step, self.offset = map(int, v)
+
+    def is_trimmable(self):
+        return True
+
+    def trim(self, n):
+        n = min(self.offset, n)
+        self.offset -= n
+        return n
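The cache types above split along one axis: SimpleKVCache grows without bound by concatenating every step, while SlidingWindowCache and StaticKVCache preallocate a max_size buffer and write into it in place. A minimal sketch of the difference (illustrative only, not part of the package; it assumes mlx is installed and the classes from the hunk above are in scope):

import mlx.core as mx

B, n_kv_heads, head_dim = 1, 4, 64

# SimpleKVCache: every call returns the full history so far.
cache = SimpleKVCache()
k = mx.random.normal((B, n_kv_heads, 10, head_dim))  # 10 prompt tokens
v = mx.random.normal((B, n_kv_heads, 10, head_dim))
keys, values = cache.update_and_fetch(k, v)
assert keys.shape[2] == 10

k1 = mx.random.normal((B, n_kv_heads, 1, head_dim))  # one decode step
v1 = mx.random.normal((B, n_kv_heads, 1, head_dim))
keys, values = cache.update_and_fetch(k1, v1)
assert keys.shape[2] == 11 and cache.cache_length == 11

# SlidingWindowCache: a 10-token prompt into an 8-slot window keeps only
# the last 8 tokens; the returned arrays never exceed max_size.
window = SlidingWindowCache(max_size=8)
keys, values = window.update_and_fetch(k, v)
assert keys.shape[2] == 8 and window.offset == 8

# StaticKVCache: also preallocated, but trimmable (the offset rolls back).
static = StaticKVCache(max_size=16)
static.update_and_fetch(k, v)
assert static.trim(3) == 3 and static.offset == 7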
mlx_vlm/models/deepseek_vl_v2/__init__.py
@@ -0,0 +1,2 @@
+from .config import MLPConfig, ModelConfig, ProjectorConfig, TextConfig, VisionConfig
+from .deepseek_vl_v2 import DeepseekVLV2Processor, LanguageModel, Model, VisionModel
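Given the single entry in top_level.txt above, the wheel installs one mlx_vlm package, so these re-exports let downstream code pull the whole DeepSeek-VL2 surface from one place. An illustrative import (not taken from the package's docs):

from mlx_vlm.models.deepseek_vl_v2 import Model, ModelConfig, VisionModel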
mlx_vlm/models/deepseek_vl_v2/config.py
@@ -0,0 +1,159 @@
+import inspect
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str = "deepseek_v2"
+    vocab_size: int = 102400
+    hidden_size: int = 1280
+    intermediate_size: int = 6848
+    moe_intermediate_size: int = 896
+    num_hidden_layers: int = 30
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 32
+    n_shared_experts: Optional[int] = 2
+    n_routed_experts: Optional[int] = 64
+    routed_scaling_factor: float = 1.0
+    kv_lora_rank: int = 512
+    q_lora_rank: int = 1536
+    qk_rope_head_dim: int = 64
+    v_head_dim: int = 128
+    qk_nope_head_dim: int = 128
+    topk_method: str = "greedy"
+    n_group: Optional[int] = 1
+    topk_group: Optional[int] = 1
+    num_experts_per_tok: Optional[int] = 6
+    moe_layer_freq: int = 1
+    first_k_dense_replace: int = 0
+    max_position_embeddings: int = 2048
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 10000.0
+    rope_traditional: bool = True
+    rope_scaling: Dict = None
+    attention_bias: bool = False
+    scoring_func: str = "softmax"
+    attn_type: str = "DeepseekV2Attention"
+
+    def __post_init__(self):
+        if self.qk_nope_head_dim == 0:
+            self.attn_type = "LlamaAttention"
+
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    model_type: str
+    layers: int = 27
+    width: int = 1152
+    intermediate_size: int = 4304
+    num_attention_heads: int = 16
+    image_size: int = 384
+    patch_size: int = 16
+    num_channels: int = 3
+    layer_norm_eps: float = 1e-6
+    mlp_ratio: float = 3.7362
+    cls: str = None
+    params: dict = None
+
+
+@dataclass
+class MLPConfig(BaseModelConfig):
+    width: int
+    intermediate_size: int
+    hidden_act: str = "gelu"
+
+
+@dataclass
+class ProjectorConfig(BaseModelConfig):
+    projector_type: str = "downsample_mlp_gelu"
+    input_dim: int = 1152
+    n_embed: int = 2048
+    depth: int = 2
+    mlp_ratio: int = 1
+    downsample_ratio: int = 2
+    token_pooling: bool = False
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: TextConfig
+    vision_config: VisionConfig
+    projector_config: ProjectorConfig
+    model_type: str
+    ignore_index: int = -100
+    image_token_index: int = 100015
+    vision_feature_select_strategy: str = "default"
+    select_layer: int = -1
+    pad_id: int = 100001
+    num_image_tokens: int = 576
+    vocab_size: int = 32000
+    tile_tag: str = "2D"
+    global_view_pos: str = "head"
+    eos_token_id: Optional[List[int]] = None
+    quantization: Optional[Dict] = None
+
+    @classmethod
+    def from_dict(cls, params):
+        if "language_config" in params:
+            params["text_config"] = params["language_config"]
+            del params["language_config"]
+
+        return cls(
+            text_config=TextConfig.from_dict(params["text_config"]),
+            vision_config=VisionConfig.from_dict(params["vision_config"]),
+            projector_config=ProjectorConfig.from_dict(params["projector_config"]),
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+                and k not in ["text_config", "vision_config", "projector_config"]
+            },
+        )
+
+
+@dataclass
+class Conversation:
+    """A class that represents a conversation."""
+
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: int
+    sep: str
+    sep2: str
+    version: str = "Unknown"
+
+
+@dataclass
+class VLChatProcessorOutput:
+    """
+    Output of the VL chat processor.
+    """
+
+    sft_format: str
+    input_ids: List[int]
+    pixel_values: List
+    num_image_tokens: List[int]
+    image_grid_thw: List[List[int]]
+    image_sizes: Optional[List[List[int]]] = None
+    videos: Optional[List] = None
+    aspect_ratio_ids: Optional[List[int]] = None
+    aspect_ratio_mask: Optional[List[List[int]]] = None
+    cross_attention_mask: Optional[List[List[List[int]]]] = None
+    attention_mask: Optional[List[int]] = None
+    labels: Optional[List[int]] = None
+
+
+@dataclass
+class BatchCollateOutput:
+    input_ids: List
+    labels: List
+    attention_mask: List
+    pixel_values: List
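Note how ModelConfig.from_dict above bridges checkpoint naming: DeepSeek-VL2 configs ship their text settings under language_config, which is renamed to text_config before construction, and any keys the dataclass does not declare are filtered out through inspect.signature. A minimal sketch (illustrative only; it assumes BaseModelConfig.from_dict applies the same signature filtering to the nested configs):

params = {
    "model_type": "deepseek_vl_v2",
    "language_config": {"model_type": "deepseek_v2", "hidden_size": 1280},
    "vision_config": {"model_type": "vision"},  # model_type has no default
    "projector_config": {"projector_type": "downsample_mlp_gelu"},
    "unused_key": 123,  # dropped by the signature filter, not an error
}
config = ModelConfig.from_dict(params)
assert config.text_config.hidden_size == 1280
assert config.image_token_index == 100015  # dataclass default retained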
mlx_vlm/models/deepseek_vl_v2/conversation.py
@@ -0,0 +1,264 @@
+"""
+From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+"""
+
+import dataclasses
+from enum import IntEnum, auto
+from typing import Dict, List
+
+
+class SeparatorStyle(IntEnum):
+    """Separator styles."""
+
+    DeepSeek = auto()
+    DeepSeekV2 = auto()
+    PLAIN = auto()
+    ALIGNMENT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that manages prompt templates and keeps all conversation history."""
+
+    # The name of this template
+    name: str
+    # The template of the system prompt
+    system_template: str = "{system_message}"
+    # The system message
+    system_message: str = ""
+    # The names of two roles
+    roles: List[str] = (("USER", "ASSISTANT"),)
+    # All messages. Each item is (role, message).
+    messages: List[List[str]] = ()
+    # The number of few shot examples
+    offset: int = 0
+    # The separator style and configurations
+    sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
+    sep: str = "\n"
+    sep2: str = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: str = None
+    # Stops generation if meeting any token in this list
+    stop_token_ids: List[int] = None
+
+    def get_prompt(self) -> str:
+        """Get the prompt for generation."""
+        system_prompt = self.system_template.format(system_message=self.system_message)
+        if self.sep_style == SeparatorStyle.DeepSeek:
+            seps = [self.sep, self.sep2]
+            if system_prompt == "" or system_prompt is None:
+                ret = ""
+            else:
+                ret = system_prompt + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.DeepSeekV2:
+            seps = [self.sep, self.sep2]
+            if system_prompt == "" or system_prompt is None:
+                ret = ""
+            else:
+                ret = system_prompt + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    if role == "User":
+                        ret += "<|sft▁begin|>\n" + message + self.sep
+                    else:
+                        ret += message + self.sep2
+                else:
+                    ret = ret
+            return ret
+
+        elif self.sep_style == SeparatorStyle.PLAIN:
+            seps = [self.sep, self.sep2]
+            ret = ""
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    if i % 2 == 0:
+                        ret += message + seps[i % 2]
+                    else:
+                        ret += message + seps[i % 2]
+                else:
+                    ret += ""
+            return ret
+        elif self.sep_style == SeparatorStyle.ALIGNMENT:
+            seps = [self.sep, self.sep2]
+            ret = ""
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    if i % 2 == 0:
+                        ret += "<image>\n" + seps[i % 2]
+                    else:
+                        ret += message + seps[i % 2]
+                else:
+                    ret += ""
+            return ret
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+    def set_system_message(self, system_message: str):
+        """Set the system message."""
+        self.system_message = system_message
+
+    def append_message(self, role: str, message: str):
+        """Append a new message."""
+        self.messages.append([role, message])
+
+    def update_last_message(self, message: str):
+        """Update the last output.
+
+        The last message is typically set to be None when constructing the prompt,
+        so we need to update it in-place after getting the response from a model.
+        """
+        self.messages[-1][1] = message
+
+    def reset_message(self):
+        """Reset a new message."""
+        self.messages = []
+
+    def to_gradio_chatbot(self):
+        """Convert the conversation to gradio chatbot format."""
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def to_openai_api_messages(self):
+        """Convert the conversation to OpenAI chat completion format."""
+        system_prompt = self.system_template.format(system_message=self.system_message)
+        ret = [{"role": "system", "content": system_prompt}]
+
+        for i, (_, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                ret.append({"role": "user", "content": msg})
+            else:
+                if msg is not None:
+                    ret.append({"role": "assistant", "content": msg})
+        return ret
+
+    def copy(self):
+        return Conversation(
+            name=self.name,
+            system_template=self.system_template,
+            system_message=self.system_message,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            stop_str=self.stop_str,
+            stop_token_ids=self.stop_token_ids,
+        )
+
+    def dict(self):
+        return {
+            "template_name": self.name,
+            "system_message": self.system_message,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+        }
+
+
+# A global registry for all conversation templates
+conv_templates: Dict[str, Conversation] = {}
+
+
+def register_conv_template(template: Conversation, override: bool = False):
+    """Register a new conversation template."""
+    if not override:
+        assert (
+            template.name not in conv_templates
+        ), f"{template.name} has been registered."
+
+    conv_templates[template.name] = template
+
+
+def get_conv_template(name: str) -> Conversation:
+    """Get a conversation template."""
+    return conv_templates[name].copy()
+
+
+register_conv_template(
+    Conversation(
+        name="deepseek",
+        system_template="{system_message}",
+        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
+        # "thinking step by step to be sure you get the right answer.",
+        system_message="",
+        roles=("<|User|>", "<|Assistant|>"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.DeepSeek,
+        sep="\n\n",
+        sep2="<|end▁of▁sentence|>",
+        stop_token_ids=[100001],
+        stop_str=["User:", "<|end▁of▁sentence|>"],
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="deepseekv2",
+        system_template="{system_message}",
+        system_message="",
+        roles=("|<User>|", "|<Assistant>|"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.DeepSeekV2,
+        sep="\n<|sft▁end|>",
+        sep2="<|end▁of▁sentence|>",
+        stop_token_ids=[100001],
+        stop_str=["User:", "<|end▁of▁sentence|>"],
+    )
+)
+
+
+register_conv_template(
+    Conversation(
+        name="plain",
+        system_template="",
+        system_message="",
+        roles=("", ""),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.PLAIN,
+        sep="",
+        sep2="",
+        stop_token_ids=[100001],
+        stop_str=["</s>"],
+    )
+)
+
+
+register_conv_template(
+    Conversation(
+        name="alignment",
+        system_template="",
+        system_message="",
+        roles=("", ""),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.ALIGNMENT,
+        sep="",
+        sep2="",
+        stop_token_ids=[100001],
+        stop_str=["</s>"],
+    )
+)
+
+
+if __name__ == "__main__":
+    print("deepseek template:")
+    conv = get_conv_template("deepseek")
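Because get_conv_template returns a copy(), callers can mutate the fetched template without touching the global registry. A short sketch of the round trip (illustrative only, not part of the package):

conv = get_conv_template("deepseek")
conv.append_message(conv.roles[0], "What is in this image? <image>")
conv.append_message(conv.roles[1], None)  # empty slot for the model's reply
prompt = conv.get_prompt()
# With SeparatorStyle.DeepSeek this renders as:
#   "<|User|>: What is in this image? <image>\n\n<|Assistant|>:"
print(prompt)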