fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py
@@ -0,0 +1,706 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from mlx_vlm.models.qwen3_omni_moe.code2wav import Code2WavModel
+ from mlx_vlm.models.qwen3_omni_moe.talker import Talker
+ from mlx_vlm.models.qwen3_omni_moe.thinker import Thinker
+
+ from .config import ModelConfig
+
+
+ def masked_scatter(
+     final_embedding: mx.array,
+     image_mask_expanded: mx.array,
+     scaled_image_features: mx.array,
+ ):
+     final_embedding_shape = final_embedding.shape
+     scaled_image_features_flattened = mx.flatten(scaled_image_features)
+     final_embedding_flattened = mx.flatten(final_embedding)
+     image_mask_expanded_flattened = mx.flatten(image_mask_expanded)
+
+     image_positions = mx.array(np.where(image_mask_expanded_flattened)[0], mx.uint32)
+     final_embedding_flattened[image_positions] = scaled_image_features_flattened
+
+     final_embedding = mx.reshape(final_embedding_flattened, final_embedding_shape)
+
+     return final_embedding
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+
+         self.thinker = Thinker(config.thinker_config)
+         self.has_talker = config.enable_audio_output
+         if self.has_talker:
+             self.talker = Talker(config.talker_config)
+             self.code2wav = Code2WavModel(config.code2wav_config)
+         else:
+             self.talker = None
+             self.code2wav = None
+
+     def enable_talker(self):
+         if not self.has_talker:
+             self.talker = Talker(self.config.talker_config)
+             self.code2wav = Code2WavModel(self.config.code2wav_config)
+             self.has_talker = True
+
+     def disable_talker(self):
+         if self.has_talker:
+             self.talker = None
+             self.code2wav = None
+             self.has_talker = False
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         pixel_values_videos: Optional[mx.array] = None,
+         input_features: Optional[mx.array] = None,
+         input_features_mask: Optional[mx.array] = None,
+         image_grid_thw: Optional[mx.array] = None,
+         video_grid_thw: Optional[mx.array] = None,
+         audio_feature_lengths: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         return self.thinker.get_input_embeddings(
+             input_ids=input_ids,
+             pixel_values=pixel_values,
+             pixel_values_videos=pixel_values_videos,
+             image_grid_thw=image_grid_thw,
+             video_grid_thw=video_grid_thw,
+             input_features=input_features,
+             feature_attention_mask=input_features_mask,
+             audio_feature_lengths=audio_feature_lengths,
+         )
+
+     def get_audio_features(
+         self,
+         input_features: mx.array,
+         input_features_mask: Optional[mx.array] = None,
+         audio_feature_lengths: Optional[mx.array] = None,
+     ):
+         return self.thinker.get_audio_features(
+             input_features=input_features,
+             feature_attention_mask=input_features_mask,
+             audio_feature_lengths=audio_feature_lengths,
+         )
+
+     def get_image_features(
+         self,
+         pixel_values: mx.array,
+         image_grid_thw: Optional[mx.array] = None,
+     ):
+         dtype = self.thinker.vision_tower.patch_embed.proj.weight.dtype
+         pixel_values = pixel_values.astype(dtype)
+         vision_output = self.thinker.vision_tower(pixel_values, image_grid_thw)
+         if isinstance(vision_output, tuple):
+             return vision_output[0]
+         return vision_output
+
+     @property
+     def layers(self):
+         return self.thinker.language_model.layers
+
+     def extract_thinker_hidden_states(self, input_ids, target_layer_idx, **kwargs):
+         embed_kwargs = {
+             k: v
+             for k, v in kwargs.items()
+             if k
+             in [
+                 "pixel_values",
+                 "pixel_values_videos",
+                 "image_grid_thw",
+                 "video_grid_thw",
+                 "input_features",
+                 "feature_attention_mask",
+                 "audio_feature_lengths",
+             ]
+         }
+         inputs_embeds, _, _ = self.thinker.get_input_embeddings(
+             input_ids, **embed_kwargs
+         )
+
+         lm_kwargs = {
+             k: v for k, v in kwargs.items() if k in ["image_grid_thw", "video_grid_thw"]
+         }
+
+         outputs = self.thinker.language_model(
+             input_ids,
+             inputs_embeds=inputs_embeds,
+             output_hidden_states=True,
+             **lm_kwargs,
+         )
+
+         hidden_states = outputs.hidden_states[target_layer_idx + 1]
+
+         return hidden_states, inputs_embeds
+
+     def _get_talker_user_parts(
+         self,
+         im_start_index: int,
+         segment_end_index: int,
+         multimodal_mask: mx.array,
+         thinker_hidden: mx.array,
+         thinker_embed: mx.array,
+     ):
+         seq_len = segment_end_index - im_start_index
+         user_talker_part = mx.zeros(
+             (1, seq_len, self.config.talker_config.text_config.hidden_size),
+             dtype=thinker_embed.dtype,
+         )
+         user_mm_mask = multimodal_mask[:, im_start_index:segment_end_index]
+         user_thinker_hidden_mm = thinker_hidden[:, im_start_index:segment_end_index]
+         user_thinker_embed_seg = thinker_embed[:, im_start_index:segment_end_index]
+
+         if mx.any(user_mm_mask):
+             mm_indices = mx.array(
+                 np.where(np.array(mx.reshape(user_mm_mask, (-1,))))[0]
+             )
+             user_thinker_hidden_mm_flat = mx.reshape(
+                 user_thinker_hidden_mm, (-1, user_thinker_hidden_mm.shape[-1])
+             )
+             mm_hidden_flat = mx.take(user_thinker_hidden_mm_flat, mm_indices, axis=0)
+             mm_hidden = self.talker.hidden_projection(mm_hidden_flat)
+             user_talker_part_flat = mx.reshape(
+                 user_talker_part, (-1, user_talker_part.shape[-1])
+             )
+             user_talker_part_flat[mm_indices] = mm_hidden
+             user_talker_part = mx.reshape(user_talker_part_flat, user_talker_part.shape)
+
+         text_mask = ~user_mm_mask
+         if mx.any(text_mask):
+             text_indices = mx.array(np.where(np.array(mx.reshape(text_mask, (-1,))))[0])
+             user_thinker_embed_flat = mx.reshape(
+                 user_thinker_embed_seg, (-1, user_thinker_embed_seg.shape[-1])
+             )
+             text_embed_flat = mx.take(user_thinker_embed_flat, text_indices, axis=0)
+             user_text_hidden = self.talker.text_projection(text_embed_flat)
+             user_talker_part_flat = mx.reshape(
+                 user_talker_part, (-1, user_talker_part.shape[-1])
+             )
+             user_talker_part_flat[text_indices] = user_text_hidden
+             user_talker_part = mx.reshape(user_talker_part_flat, user_talker_part.shape)
+
+         return user_talker_part
+
+     def _get_talker_assistant_parts(
+         self,
+         im_start_index: int,
+         segment_end_index: int,
+         speaker_id: int,
+         thinker_embed: mx.array,
+         tts_pad_embed: mx.array,
+         tts_bos_embed: mx.array,
+         tts_eos_embed: mx.array,
+     ):
+         assistant_hidden = self.talker.text_projection(
+             thinker_embed[:, im_start_index:segment_end_index]
+         )
+         assistant_text_hidden = mx.concatenate(
+             (
+                 assistant_hidden[:, :3],
+                 mx.broadcast_to(tts_pad_embed, (1, 4, tts_pad_embed.shape[-1])),
+                 tts_bos_embed,
+                 assistant_hidden[:, 3:4],
+             ),
+             axis=1,
+         )
+         codec_special_tokens = mx.array(
+             [
+                 [
+                     self.config.talker_config.codec_nothink_id,
+                     self.config.talker_config.codec_think_bos_id,
+                     self.config.talker_config.codec_think_eos_id,
+                     speaker_id,
+                     self.config.talker_config.codec_pad_id,
+                     self.config.talker_config.codec_bos_id,
+                 ]
+             ],
+             dtype=mx.int32,
+         )
+         assistant_codec_hidden = mx.concatenate(
+             (
+                 mx.zeros(
+                     (1, 3, self.config.talker_config.text_config.hidden_size),
+                     dtype=thinker_embed.dtype,
+                 ),
+                 self.talker.model.codec_embedding(codec_special_tokens),
+             ),
+             axis=1,
+         )
+         trailing_text_hidden = mx.concatenate(
+             (
+                 assistant_hidden[:, 4:],
+                 tts_eos_embed,
+             ),
+             axis=1,
+         )
+         input_embeds = assistant_text_hidden + assistant_codec_hidden
+         input_ids = mx.full(
+             (1, assistant_text_hidden.shape[1]),
+             self.config.tts_pad_token_id,
+             dtype=mx.int32,
+         )
+         return input_embeds, input_ids, trailing_text_hidden
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: Optional[mx.array] = None,
+         pixel_values_videos: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+         return self.thinker(
+             input_ids=input_ids,
+             pixel_values=pixel_values,
+             pixel_values_videos=pixel_values_videos,
+             mask=mask,
+             cache=cache,
+             **kwargs,
+         )
+
+     def generate(
+         self,
+         input_ids: mx.array,
+         speaker: str = "Ethan",
+         use_audio_in_video: bool = False,
+         return_audio: Optional[bool] = None,
+         thinker_max_new_tokens: int = 1024,
+         thinker_eos_token_id: int = 151645,
+         talker_max_new_tokens: int = 4096,
+         talker_do_sample: bool = True,
+         talker_top_k: int = 50,
+         talker_top_p: float = 1.0,
+         talker_temperature: float = 0.9,
+         talker_repetition_penalty: float = 1.05,
+         **kwargs,
+     ):
+         if return_audio and not self.has_talker:
+             raise ValueError(
+                 "Cannot use talker when talker module not initialized. Use `enable_talker` method or set enable_audio_output in config to enable talker."
+             )
+         if return_audio is None:
+             return_audio = self.has_talker
+
+         if not return_audio:
+             from mlx_vlm.generate import generate_step
+
+             thinker_kwargs = {
+                 "max_tokens": thinker_max_new_tokens,
+                 "eos_tokens": [thinker_eos_token_id],
+             }
+             for key, value in kwargs.items():
+                 if key.startswith("thinker_"):
+                     thinker_kwargs[key[len("thinker_") :]] = value
+                 elif key in (
+                     "input_features",
+                     "feature_attention_mask",
+                     "audio_feature_lengths",
+                     "pixel_values",
+                     "pixel_values_videos",
+                     "image_grid_thw",
+                     "video_grid_thw",
+                 ):
+                     thinker_kwargs[key] = value
+
+             generator = generate_step(
+                 input_ids,
+                 self.thinker,
+                 thinker_kwargs.get("pixel_values"),
+                 kwargs.get("mask"),
+                 **{
+                     k: v
+                     for k, v in thinker_kwargs.items()
+                     if k not in ("pixel_values", "mask")
+                 },
+             )
+             sequences = [input_ids]
+             for token, _ in generator:
+                 sequences.append(mx.array([[token]]))
+                 if token == thinker_eos_token_id:
+                     break
+             thinker_result = type(
+                 "obj",
+                 (object,),
+                 {
+                     "sequences": mx.concatenate(sequences, axis=1),
+                     "hidden_states": None,
+                 },
+             )()
+             return thinker_result, None
+
+         if input_ids.shape[0] != 1:
+             raise NotImplementedError(
+                 "Qwen3-Omni currently does not support batched inference with audio output"
+             )
+
+         speaker_id = self.config.talker_config.speaker_id.get(speaker.lower())
+         if speaker_id is None:
+             raise NotImplementedError(f"Speaker {speaker} not implemented")
+
+         from mlx_vlm.generate import generate_step
+
+         thinker_kwargs = {
+             "max_tokens": thinker_max_new_tokens,
+             "eos_tokens": [thinker_eos_token_id],
+             "output_hidden_states": True,
+         }
+         for key, value in kwargs.items():
+             if key.startswith("thinker_"):
+                 thinker_kwargs[key[len("thinker_") :]] = value
+             elif key in (
+                 "input_features",
+                 "feature_attention_mask",
+                 "audio_feature_lengths",
+                 "pixel_values",
+                 "pixel_values_videos",
+                 "image_grid_thw",
+                 "video_grid_thw",
+             ):
+                 thinker_kwargs[key] = value
+
+         generator = generate_step(
+             input_ids,
+             self.thinker,
+             thinker_kwargs.get("pixel_values"),
+             kwargs.get("mask"),
+             **{
+                 k: v
+                 for k, v in thinker_kwargs.items()
+                 if k not in ("pixel_values", "mask", "output_hidden_states")
+             },
+         )
+         sequences = [input_ids]
+         hidden_states_list = []
+         for token, _ in generator:
+             sequences.append(mx.array([[token]]))
+             if token == thinker_eos_token_id:
+                 break
+
+         thinker_result_sequences = mx.concatenate(sequences, axis=1)
+
+         thinker_hidden_all, thinker_embed_all = self.extract_thinker_hidden_states(
+             thinker_result_sequences,
+             target_layer_idx=self.config.talker_config.accept_hidden_layer,
+             **kwargs,
+         )
+
+         im_start_indexes = mx.concatenate(
+             (
+                 mx.array(
+                     np.where(np.array(input_ids[0] == self.config.im_start_token_id))[0]
+                 ),
+                 mx.array([thinker_result_sequences.shape[-1]], dtype=mx.int32),
+             ),
+             axis=0,
+         )
+         multimodal_mask = (
+             (thinker_result_sequences == self.config.thinker_config.audio_token_id)
+             | (thinker_result_sequences == self.config.thinker_config.image_token_id)
+             | (thinker_result_sequences == self.config.thinker_config.video_token_id)
+         )
+
+         talker_special_tokens = mx.array(
+             [
+                 [
+                     self.config.tts_bos_token_id,
+                     self.config.tts_eos_token_id,
+                     self.config.tts_pad_token_id,
+                 ]
+             ],
+             dtype=input_ids.dtype,
+         )
+         talker_special_embeds = self.thinker.language_model.model.embed_tokens(
+             talker_special_tokens
+         )
+         talker_special_embeds_proj = self.talker.text_projection(talker_special_embeds)
+         tts_bos_embed = talker_special_embeds_proj[:, 0:1]
+         tts_eos_embed = talker_special_embeds_proj[:, 1:2]
+         tts_pad_embed = talker_special_embeds_proj[:, 2:3]
+
+         talker_input_embeds = []
+         talker_input_ids = []
+
+         for i in range(len(im_start_indexes) - 1):
+             im_start_index = int(im_start_indexes[i])
+             segment_end_index = int(im_start_indexes[i + 1])
+             role_token = int(input_ids[0, im_start_index + 1])
+
+             if role_token == self.config.system_token_id:
+                 continue
+             elif role_token == self.config.user_token_id:
+                 talker_user_part = self._get_talker_user_parts(
+                     im_start_index,
+                     segment_end_index,
+                     multimodal_mask,
+                     thinker_hidden_all,
+                     thinker_embed_all,
+                 )
+                 talker_input_embeds.append(talker_user_part)
+                 talker_input_ids.append(
+                     thinker_result_sequences[:, im_start_index:segment_end_index]
+                 )
+             elif (
+                 role_token == self.config.assistant_token_id
+                 and i == len(im_start_indexes) - 2
+             ):
+                 talker_assistant_embeds, talker_assistant_ids, trailing_text_hidden = (
+                     self._get_talker_assistant_parts(
+                         im_start_index,
+                         segment_end_index,
+                         speaker_id,
+                         thinker_embed_all,
+                         tts_pad_embed,
+                         tts_bos_embed,
+                         tts_eos_embed,
+                     )
+                 )
+                 talker_input_embeds.append(talker_assistant_embeds)
+                 talker_input_ids.append(talker_assistant_ids)
+             elif (
+                 role_token == self.config.assistant_token_id
+                 and i != len(im_start_indexes) - 2
+             ):
+                 continue
+             else:
+                 raise AssertionError(
+                     "Expect role id after <|im_start|> (assistant, user, system)"
+                 )
+
+         if len(talker_input_embeds) == 0:
+             return (
+                 type(
+                     "obj",
+                     (object,),
+                     {
+                         "sequences": thinker_result_sequences,
+                         "hidden_states": None,
+                     },
+                 )(),
+                 None,
+             )
+
+         talker_input_embed = mx.concatenate(talker_input_embeds, axis=1)
+         talker_input_id = mx.concatenate(talker_input_ids, axis=1)
+
+         talker_result = self.talker.generate(
+             inputs_embeds=talker_input_embed,
+             trailing_text_hidden=trailing_text_hidden,
+             tts_pad_embed=tts_pad_embed,
+             talker_input_ids=talker_input_id,
+             max_new_tokens=talker_max_new_tokens,
+             temperature=talker_temperature,
+             top_p=talker_top_p,
+         )
+
+         valid_codes = [
+             hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None
+         ]
+         if not valid_codes:
+             talker_wavs = mx.zeros((1, 1, 1000))
+         else:
+             talker_codes = mx.stack(valid_codes, axis=1).transpose(0, 2, 1)
+             talker_wavs = self.code2wav.chunked_decode(
+                 talker_codes, chunk_size=300, left_context_size=25
+             )
+
+         thinker_result = type(
+             "obj",
+             (object,),
+             {
+                 "sequences": thinker_result_sequences,
+                 "hidden_states": None,
+             },
+         )()
+
+         return thinker_result, talker_wavs.astype(mx.float32)
+
+     def generate_stream(
+         self,
+         input_ids: mx.array,
+         speaker: str = "Ethan",
+         thinker_max_new_tokens: int = 1024,
+         thinker_eos_token_id: int = 151645,
+         talker_max_new_tokens: int = 4096,
+         talker_top_p: float = 1.0,
+         talker_temperature: float = 0.9,
+         chunk_size: int = 300,
+         left_context_size: int = 25,
+         **kwargs,
+     ):
+         if not self.has_talker:
+             raise ValueError("Cannot stream audio without talker module")
+         if input_ids.shape[0] != 1:
+             raise NotImplementedError("Streaming does not support batched inference")
+
+         speaker_id = self.config.talker_config.speaker_id.get(speaker.lower())
+         if speaker_id is None:
+             raise NotImplementedError(f"Speaker {speaker} not implemented")
+
+         from mlx_vlm.generate import generate_step
+
+         thinker_kwargs = {
+             "max_tokens": thinker_max_new_tokens,
+             "eos_tokens": [thinker_eos_token_id],
+         }
+         for key, value in kwargs.items():
+             if key.startswith("thinker_"):
+                 thinker_kwargs[key[len("thinker_") :]] = value
+             elif key in (
+                 "input_features",
+                 "feature_attention_mask",
+                 "audio_feature_lengths",
+                 "pixel_values",
+                 "pixel_values_videos",
+                 "image_grid_thw",
+                 "video_grid_thw",
+             ):
+                 thinker_kwargs[key] = value
+
+         generator = generate_step(
+             input_ids,
+             self.thinker,
+             thinker_kwargs.get("pixel_values"),
+             kwargs.get("mask"),
+             **{
+                 k: v
+                 for k, v in thinker_kwargs.items()
+                 if k not in ("pixel_values", "mask")
+             },
+         )
+         sequences = [input_ids]
+         for token, _ in generator:
+             sequences.append(mx.array([[token]]))
+             if token == thinker_eos_token_id:
+                 break
+
+         thinker_result_sequences = mx.concatenate(sequences, axis=1)
+         thinker_hidden_all, thinker_embed_all = self.extract_thinker_hidden_states(
+             thinker_result_sequences,
+             target_layer_idx=self.config.talker_config.accept_hidden_layer,
+             **kwargs,
+         )
+
+         im_start_indexes = mx.concatenate(
+             (
+                 mx.array(
+                     np.where(np.array(input_ids[0] == self.config.im_start_token_id))[0]
+                 ),
+                 mx.array([thinker_result_sequences.shape[-1]], dtype=mx.int32),
+             ),
+             axis=0,
+         )
+         multimodal_mask = (
+             (thinker_result_sequences == self.config.thinker_config.audio_token_id)
+             | (thinker_result_sequences == self.config.thinker_config.image_token_id)
+             | (thinker_result_sequences == self.config.thinker_config.video_token_id)
+         )
+
+         talker_special_tokens = mx.array(
+             [
+                 [
+                     self.config.tts_bos_token_id,
+                     self.config.tts_eos_token_id,
+                     self.config.tts_pad_token_id,
+                 ]
+             ],
+             dtype=input_ids.dtype,
+         )
+         talker_special_embeds = self.thinker.language_model.model.embed_tokens(
+             talker_special_tokens
+         )
+         talker_special_embeds_proj = self.talker.text_projection(talker_special_embeds)
+         tts_bos_embed, tts_eos_embed, tts_pad_embed = (
+             talker_special_embeds_proj[:, 0:1],
+             talker_special_embeds_proj[:, 1:2],
+             talker_special_embeds_proj[:, 2:3],
+         )
+
+         talker_input_embeds, talker_input_ids = [], []
+         trailing_text_hidden = None
+
+         for i in range(len(im_start_indexes) - 1):
+             im_start_index, segment_end_index = int(im_start_indexes[i]), int(
+                 im_start_indexes[i + 1]
+             )
+             role_token = int(input_ids[0, im_start_index + 1])
+
+             if role_token == self.config.system_token_id:
+                 continue
+             elif role_token == self.config.user_token_id:
+                 talker_input_embeds.append(
+                     self._get_talker_user_parts(
+                         im_start_index,
+                         segment_end_index,
+                         multimodal_mask,
+                         thinker_hidden_all,
+                         thinker_embed_all,
+                     )
+                 )
+                 talker_input_ids.append(
+                     thinker_result_sequences[:, im_start_index:segment_end_index]
+                 )
+             elif (
+                 role_token == self.config.assistant_token_id
+                 and i == len(im_start_indexes) - 2
+             ):
+                 talker_assistant_embeds, talker_assistant_ids, trailing_text_hidden = (
+                     self._get_talker_assistant_parts(
+                         im_start_index,
+                         segment_end_index,
+                         speaker_id,
+                         thinker_embed_all,
+                         tts_pad_embed,
+                         tts_bos_embed,
+                         tts_eos_embed,
+                     )
+                 )
+                 talker_input_embeds.append(talker_assistant_embeds)
+                 talker_input_ids.append(talker_assistant_ids)
+
+         if not talker_input_embeds:
+             return
+
+         talker_input_embed = mx.concatenate(talker_input_embeds, axis=1)
+         talker_input_id = mx.concatenate(talker_input_ids, axis=1)
+
+         generated_tokens = thinker_result_sequences[0, input_ids.shape[1] :].tolist()
+         yield ("text", generated_tokens)
+
+         codes_list = []
+         decoded_len = 0
+
+         for residual_codes in self.talker.generate_stream(
+             inputs_embeds=talker_input_embed,
+             trailing_text_hidden=trailing_text_hidden,
+             tts_pad_embed=tts_pad_embed,
+             talker_input_ids=talker_input_id,
+             max_new_tokens=talker_max_new_tokens,
+             temperature=talker_temperature,
+             top_p=talker_top_p,
+         ):
+             codes_list.append(residual_codes)
+             if len(codes_list) >= chunk_size:
+                 codes_buffer = mx.stack(codes_list, axis=1).transpose(0, 2, 1)
+                 wav_chunk, decoded_len = self.code2wav.stream_decode(
+                     codes_buffer, chunk_size, left_context_size, decoded_len
+                 )
+                 if wav_chunk is not None:
+                     mx.eval(wav_chunk)
+                     yield ("audio", wav_chunk.astype(mx.float32))
+
+         if codes_list:
+             codes_buffer = mx.stack(codes_list, axis=1).transpose(0, 2, 1)
+             wav_chunk = self.code2wav.flush_decode(
+                 codes_buffer, left_context_size, decoded_len
+             )
+             if wav_chunk is not None:
+                 mx.eval(wav_chunk)
+                 yield ("audio", wav_chunk.astype(mx.float32))