fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/qwen3_omni_moe/talker.py
@@ -0,0 +1,873 @@
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx_lm.models.switch_layers import SwitchGLU

from mlx_vlm.models.qwen3_omni_moe.config import (
    CodePredictorConfig,
    TalkerConfig,
    TextConfig,
)
from mlx_vlm.sample_utils import top_p_sampling

from ..base import create_attention_mask, scaled_dot_product_attention
from ..cache import KVCache
from .language import Attention, Qwen3OmniMoeThinkerTextRotaryEmbedding


class CodePredictorRotaryEmbedding:
    def __init__(self, config: CodePredictorConfig):
        self.config = config
        head_dim = config.head_dim
        inv_freq = 1.0 / (
            config.rope_theta
            ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim)
        )
        self.inv_freq = inv_freq
        self.attention_scaling = 1.0

    def __call__(
        self, x: mx.array, position_ids: mx.array
    ) -> Tuple[mx.array, mx.array]:
        batch_size = position_ids.shape[0]
        inv_freq_expanded = mx.broadcast_to(
            self.inv_freq[None, :, None].astype(mx.float32),
            (batch_size, self.inv_freq.shape[0], 1),
        )
        position_ids_expanded = mx.expand_dims(position_ids.astype(mx.float32), axis=1)
        freqs = inv_freq_expanded @ position_ids_expanded
        freqs = mx.swapaxes(freqs, 1, 2)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.cos(emb) * self.attention_scaling
        sin = mx.sin(emb) * self.attention_scaling
        return cos.astype(x.dtype), sin.astype(x.dtype)


def rotate_half_code(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return mx.concatenate([-x2, x1], axis=-1)


def apply_rotary_pos_emb_code(q, k, cos, sin):
    cos = mx.expand_dims(cos, axis=1)
    sin = mx.expand_dims(sin, axis=1)
    q_embed = (q * cos) + (rotate_half_code(q) * sin)
    k_embed = (k * cos) + (rotate_half_code(k) * sin)
    return q_embed, k_embed


class CodePredictorMLP(nn.Module):
    def __init__(self, config: CodePredictorConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

        if config.hidden_act == "silu":
            self.act_fn = nn.silu
        elif config.hidden_act == "gelu":
            self.act_fn = nn.gelu
        elif config.hidden_act == "gelu_pytorch_tanh":
            self.act_fn = nn.GELU(approx="precise")
        else:
            self.act_fn = nn.silu

    def __call__(self, x: mx.array) -> mx.array:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class CodePredictorAttention(nn.Module):
    def __init__(self, config: CodePredictorConfig, idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )
        self.q_norm = nn.RMSNorm(dims=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(dims=self.head_dim, eps=config.rms_norm_eps)
        self.sliding_window = (
            config.sliding_window
            if (
                hasattr(config, "layer_types")
                and config.layer_types
                and idx < len(config.layer_types)
                and config.layer_types[idx] == "sliding_attention"
            )
            else None
        )
        self.rotary_emb = CodePredictorRotaryEmbedding(config)

    def __call__(
        self,
        hidden_states: mx.array,
        position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        past_key_values: Optional[KVCache] = None,
        cache_position: Optional[mx.array] = None,
    ) -> Tuple[mx.array, Optional[mx.array]]:
        B, L, D = hidden_states.shape
        hidden_shape = (B, L, -1, self.head_dim)

        query_states = (
            self.q_proj(hidden_states).reshape(*hidden_shape).transpose(0, 2, 1, 3)
        )
        key_states = (
            self.k_proj(hidden_states).reshape(*hidden_shape).transpose(0, 2, 1, 3)
        )
        value_states = (
            self.v_proj(hidden_states).reshape(*hidden_shape).transpose(0, 2, 1, 3)
        )

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        if position_embeddings is None:
            if position_ids is None:
                if past_key_values is not None:
                    offset = (
                        past_key_values.offset
                        if hasattr(past_key_values, "offset")
                        else 0
                    )
                    position_ids = mx.arange(offset, offset + L)
                else:
                    position_ids = mx.arange(L)
                position_ids = mx.expand_dims(position_ids, axis=0)
            cos, sin = self.rotary_emb(hidden_states, position_ids)
        else:
            cos, sin = position_embeddings

        query_states, key_states = apply_rotary_pos_emb_code(
            query_states, key_states, cos, sin
        )

        if past_key_values is not None:
            key_states, value_states = past_key_values.update_and_fetch(
                key_states, value_states
            )

        if attention_mask is not None and isinstance(attention_mask, mx.array):
            kv_seq_len = key_states.shape[-2]
            if attention_mask.shape[-1] != kv_seq_len:
                attention_mask = attention_mask[..., :kv_seq_len]

        if self.is_causal and attention_mask is None:
            attention_mask = nn.MultiHeadAttention.create_additive_causal_mask(L)
            attention_mask = attention_mask.astype(query_states.dtype)

        attn_output = scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            past_key_values,
            scale=self.scaling,
            mask=attention_mask,
        )

        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, None


class CodePredictorDecoderLayer(nn.Module):
    def __init__(self, config: CodePredictorConfig, idx: int):
        super().__init__()
        self.self_attn = CodePredictorAttention(config, idx)
        self.mlp = CodePredictorMLP(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.attention_type = (
            config.layer_types[idx]
            if hasattr(config, "layer_types") and config.layer_types
            else "full_attention"
        )

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        past_key_values: Optional[KVCache] = None,
        cache_position: Optional[mx.array] = None,
        position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class CodePredictorModel(nn.Module):
    def __init__(self, config: CodePredictorConfig):
        super().__init__()
        self.config = config
        self.layers = [
            CodePredictorDecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = CodePredictorRotaryEmbedding(config)
        self.codec_embedding = [
            nn.Embedding(config.vocab_size, config.hidden_size)
            for _ in range(config.num_code_groups - 1)
        ]

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        past_key_values: Optional[list] = None,
        inputs_embeds: Optional[mx.array] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[mx.array] = None,
        generation_steps: Optional[int] = None,
    ) -> mx.array:
        if input_ids is not None:
            raise ValueError("`input_ids` is expected to be `None`")

        if use_cache and past_key_values is None:
            past_key_values = [KVCache() for _ in range(len(self.layers))]

        if cache_position is None:
            if past_key_values is not None and len(past_key_values) > 0:
                offset = (
                    past_key_values[0].offset
                    if hasattr(past_key_values[0], "offset")
                    else 0
                )
            else:
                offset = 0
            cache_position = mx.arange(offset, offset + inputs_embeds.shape[1])

        if position_ids is None:
            position_ids = mx.expand_dims(cache_position, axis=0)

        if attention_mask is None:
            attention_mask = create_attention_mask(
                inputs_embeds,
                past_key_values[0] if past_key_values else None,
            )

        if attention_mask is not None and not isinstance(attention_mask, dict):
            causal_mask_mapping = {
                "full_attention": attention_mask,
            }
        else:
            causal_mask_mapping = (
                attention_mask
                if isinstance(attention_mask, dict)
                else {"full_attention": None}
            )

        hidden_states = inputs_embeds

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping.get(
                    decoder_layer.attention_type,
                    causal_mask_mapping.get("full_attention"),
                ),
                position_ids=position_ids,
                past_key_values=past_key_values[i] if past_key_values else None,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states


class CodePredictor(nn.Module):
    def __init__(self, config: CodePredictorConfig):
        super().__init__()
        self.config = config
        self.model = CodePredictorModel(config)
        self.lm_head = [
            nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            for _ in range(config.num_code_groups - 1)
        ]

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        past_key_values: Optional[list] = None,
        inputs_embeds: Optional[mx.array] = None,
        labels: Optional[mx.array] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[mx.array] = None,
        generation_steps: Optional[int] = None,
    ):
        if (
            inputs_embeds is not None
            and inputs_embeds.shape[1] > 1
            and generation_steps is None
        ):
            generation_steps = inputs_embeds.shape[1] - 2
        elif input_ids is not None and generation_steps is not None:
            inputs_embeds = self.model.codec_embedding[generation_steps - 1](input_ids)

        if generation_steps is None:
            generation_steps = 0

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            generation_steps=generation_steps,
        )

        hidden_states = outputs
        logits = self.lm_head[generation_steps](hidden_states)

        return logits, hidden_states, inputs_embeds


class TalkerResizeMlp(nn.Module):
    def __init__(self, config: TalkerConfig):
        super().__init__()
        self.linear_fc1 = nn.Linear(
            config.thinker_hidden_size, config.text_config.intermediate_size, bias=True
        )
        self.linear_fc2 = nn.Linear(
            config.text_config.intermediate_size,
            config.text_config.hidden_size,
            bias=True,
        )

    def __call__(self, x: mx.array) -> mx.array:
        return self.linear_fc2(nn.silu(self.linear_fc1(x)))


class TalkerTextMlp(nn.Module):
    def __init__(self, config: TextConfig, intermediate_sz: int):
        super().__init__()
        if not intermediate_sz:
            intermediate_sz = config.intermediate_size

        self.gate_proj = nn.Linear(config.hidden_size, intermediate_sz, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, intermediate_sz, bias=False)
        self.down_proj = nn.Linear(intermediate_sz, config.hidden_size, bias=False)

        if config.hidden_act == "silu":
            self.act_fn = nn.silu
        elif config.hidden_act == "gelu":
            self.act_fn = nn.gelu
        else:
            self.act_fn = nn.silu

    def __call__(self, x: mx.array) -> mx.array:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class TalkerSparseMoeBlock(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.switch_mlp = SwitchGLU(
            config.hidden_size, config.moe_intermediate_size, config.num_experts
        )
        self.shared_expert = TalkerTextMlp(
            config, config.shared_expert_intermediate_size
        )
        self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False)

    def __call__(self, hidden_states: mx.array) -> Tuple[mx.array, mx.array]:
        router_logits = self.gate(hidden_states)
        routing_weights = mx.softmax(
            router_logits.astype(mx.float32), axis=-1, precise=True
        )

        k = self.top_k
        inds = mx.argpartition(routing_weights, kth=-k, axis=-1)[..., -k:]
        scores = mx.take_along_axis(routing_weights, inds, axis=-1)

        if self.norm_topk_prob:
            scores /= mx.sum(scores, axis=-1, keepdims=True)

        y = self.switch_mlp(hidden_states, inds)
        final_hidden_states = (y * scores[..., None]).sum(axis=-2)

        shared_expert_output = self.shared_expert(hidden_states)
        shared_expert_gate_output = nn.sigmoid(self.shared_expert_gate(hidden_states))
        shared_expert_output = shared_expert_gate_output * shared_expert_output

        final_hidden_states = final_hidden_states + shared_expert_output
        return final_hidden_states, router_logits


class TalkerModelDecoderLayer(nn.Module):
    def __init__(self, config: TextConfig, idx: int):
        super().__init__()
        self.self_attn = Attention(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.mlp = TalkerSparseMoeBlock(config)

    def __call__(
        self,
        hidden_states: mx.array,
        position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        past_key_values: Optional[KVCache] = None,
        cache_position: Optional[mx.array] = None,
    ) -> mx.array:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        if position_ids is not None and position_ids.ndim == 2:
            position_ids_3d = mx.tile(mx.expand_dims(position_ids, axis=0), (3, 1, 1))
        else:
            position_ids_3d = position_ids

        hidden_states = self.self_attn(
            hidden_states,
            mask=attention_mask,
            cache=past_key_values,
            position_ids=position_ids_3d,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, _ = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class TalkerModel(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.config = config
        self.layers = [
            TalkerModelDecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3OmniMoeThinkerTextRotaryEmbedding(
            config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
            rope_scaling=config.rope_scaling,
        )
        self.codec_embedding = nn.Embedding(config.vocab_size, config.hidden_size)

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        past_key_values: Optional[list] = None,
        inputs_embeds: Optional[mx.array] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[mx.array] = None,
        visual_pos_masks: Optional[mx.array] = None,
        deepstack_visual_embeds: Optional[list] = None,
    ) -> mx.array:
        if inputs_embeds is None:
            inputs_embeds = self.codec_embedding(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = [KVCache() for _ in range(len(self.layers))]

        if cache_position is None:
            if past_key_values is not None and len(past_key_values) > 0:
                offset = (
                    past_key_values[0].offset
                    if hasattr(past_key_values[0], "offset")
                    else 0
                )
            else:
                offset = 0
            cache_position = mx.arange(offset, offset + inputs_embeds.shape[1])

        if position_ids is None:
            position_ids = cache_position
            position_ids = mx.expand_dims(position_ids, axis=0)
            position_ids = mx.tile(position_ids, (3, 1, 1))

        if position_ids.ndim == 2:
            position_ids = mx.broadcast_to(
                position_ids[None, ...],
                (3, position_ids.shape[0], position_ids.shape[1]),
            )

        text_position_ids = (
            position_ids[0] if position_ids.shape[0] >= 1 else position_ids
        )

        if attention_mask is None:
            attention_mask = create_attention_mask(
                inputs_embeds,
                past_key_values if past_key_values else None,
            )

        hidden_states = inputs_embeds

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values[layer_idx] if past_key_values else None,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            if deepstack_visual_embeds is not None and layer_idx < len(
                deepstack_visual_embeds
            ):
                hidden_states = self._deepstack_process(
                    hidden_states,
                    visual_pos_masks,
                    deepstack_visual_embeds[layer_idx],
                )

            if layer_idx % 4 == 0:
                mx.eval(hidden_states)

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def _deepstack_process(
        self,
        hidden_states: mx.array,
        visual_pos_masks: mx.array,
        visual_embeds: mx.array,
    ):
        if visual_pos_masks.ndim == 3:
            visual_pos_masks = visual_pos_masks[..., 0]
        visual_embeds = visual_embeds.astype(hidden_states.dtype)
        visual_indices = np.where(visual_pos_masks)[0].tolist()
        local_this = hidden_states[:, visual_indices, :] + visual_embeds
        hidden_states[:, visual_indices, :] = local_this
        return hidden_states


class Talker(nn.Module):
    def __init__(self, config: TalkerConfig):
        super().__init__()
        self.config = config
        self.model = TalkerModel(config.text_config)
        self.text_projection = TalkerResizeMlp(config)
        self.hidden_projection = TalkerResizeMlp(config)
        self.code_predictor = CodePredictor(config.code_predictor_config)
        self.codec_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        past_key_values: Optional[list] = None,
        inputs_embeds: Optional[mx.array] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[mx.array] = None,
        visual_pos_masks: Optional[mx.array] = None,
        deepstack_visual_embeds: Optional[list] = None,
        generation_steps: Optional[int] = None,
        residual_codes: Optional[mx.array] = None,
        trailing_text_hidden: Optional[mx.array] = None,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.model.codec_embedding(input_ids)

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
        )

        hidden_states = outputs
        logits = self.codec_head(hidden_states)

        return logits, hidden_states

    def prepare_inputs_for_generation(
        self,
        input_ids: mx.array,
        past_hidden: mx.array,
        trailing_text_hidden: mx.array,
        tts_pad_embed: mx.array,
        generation_step: int,
        temperature: float = 1.0,
        top_p: float = 0.8,
    ):
        token = input_ids
        last_id_hidden = self.model.codec_embedding(token)

        cp_input_embeds = mx.concatenate([past_hidden, last_id_hidden], axis=1)
        cp_past_key_values = [
            KVCache() for _ in range(len(self.code_predictor.model.layers))
        ]

        cp_logits, cp_hidden, _ = self.code_predictor(
            inputs_embeds=cp_input_embeds,
            past_key_values=cp_past_key_values,
            use_cache=True,
        )

        if temperature == 0:
            cp_token = mx.argmax(cp_logits[:, -1, :], axis=-1)
        else:
            cp_token = top_p_sampling(cp_logits[:, -1, :], top_p, temperature)

        current_step_codes = [token, cp_token[:, None]]

        mid_residual_hiddens = []

        for cp_step in range(1, self.config.num_code_groups - 1):
            cp_logits, cp_hidden, cp_input_embeds_out = self.code_predictor(
                input_ids=cp_token[:, None],
                past_key_values=cp_past_key_values,
                use_cache=True,
                generation_steps=cp_step,
            )
            mid_residual_hiddens.append(cp_input_embeds_out)

            if temperature == 0:
                cp_token = mx.argmax(cp_logits[:, -1, :], axis=-1)
            else:
                cp_token = top_p_sampling(cp_logits[:, -1, :], top_p, temperature)

            current_step_codes.append(cp_token[:, None])

        last_residual_hidden = self.code_predictor.model.codec_embedding[-1](
            cp_token[:, None]
        )

        codec_hiddens = [last_id_hidden] + mid_residual_hiddens + [last_residual_hidden]
        codec_hiddens_stacked = mx.concatenate(codec_hiddens, axis=1)
        inputs_embeds = mx.sum(codec_hiddens_stacked, axis=1, keepdims=True)

        if generation_step < trailing_text_hidden.shape[1]:
            trailing = trailing_text_hidden[:, generation_step].reshape(1, 1, -1)
            inputs_embeds = inputs_embeds + trailing
        else:
            inputs_embeds = inputs_embeds + tts_pad_embed

        residual_codes = mx.concatenate(current_step_codes, axis=1)

        return inputs_embeds, residual_codes

    def generate(
        self,
        inputs_embeds: mx.array,
        trailing_text_hidden: mx.array,
        tts_pad_embed: mx.array,
        talker_input_ids: mx.array,
        max_new_tokens: int = 2048,
        temperature: float = 0.9,
        top_p: float = 1.0,
        **kwargs,
    ):
        past_key_values = [
            KVCache() for _ in range(self.config.text_config.num_hidden_layers)
        ]

        logits, hidden_states = self(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=True,
        )

        hidden_states_list = [(hidden_states, None)]

        if temperature == 0:
            token = mx.argmax(logits[:, -1, :], axis=-1)
        else:
            token = top_p_sampling(logits[:, -1, :], top_p, temperature)

        generation_step = 0

        for _ in range(max_new_tokens):
            token_scalar = token.item()
            if token_scalar == self.config.codec_eos_token_id:
                break

            past_hidden = hidden_states_list[-1][0][:, -1:]
            inputs_embeds, residual_codes = self.prepare_inputs_for_generation(
                input_ids=token[:, None],
                past_hidden=past_hidden,
                trailing_text_hidden=trailing_text_hidden,
                tts_pad_embed=tts_pad_embed,
                generation_step=generation_step,
                temperature=temperature,
                top_p=0.8,
            )

            logits, hidden_states = self(
                input_ids=None,
                inputs_embeds=inputs_embeds,
                past_key_values=past_key_values,
                use_cache=True,
            )

            hidden_states_list.append((hidden_states, residual_codes))

            if temperature == 0:
                token = mx.argmax(logits[:, -1, :], axis=-1)
            else:
                token = top_p_sampling(logits[:, -1, :], top_p, temperature)

            generation_step += 1

        class TalkerGenerateResult:
            def __init__(self, hidden_states):
                self.hidden_states = hidden_states

        return TalkerGenerateResult(hidden_states_list)

    def generate_stream(
        self,
        inputs_embeds: mx.array,
        trailing_text_hidden: mx.array,
        tts_pad_embed: mx.array,
        talker_input_ids: mx.array,
        max_new_tokens: int = 2048,
        temperature: float = 0.9,
        top_p: float = 1.0,
        **kwargs,
    ):
        past_key_values = [
            KVCache() for _ in range(self.config.text_config.num_hidden_layers)
        ]
        logits, hidden_states = self(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=True,
        )

        if temperature == 0:
            token = mx.argmax(logits[:, -1, :], axis=-1)
        else:
            token = top_p_sampling(logits[:, -1, :], top_p, temperature)

        generation_step = 0
        past_hidden = hidden_states[:, -1:]

        for _ in range(max_new_tokens):
            token_scalar = token.item()
            if token_scalar == self.config.codec_eos_token_id:
                break

            inputs_embeds, residual_codes = self.prepare_inputs_for_generation(
                input_ids=token[:, None],
                past_hidden=past_hidden,
                trailing_text_hidden=trailing_text_hidden,
                tts_pad_embed=tts_pad_embed,
                generation_step=generation_step,
                temperature=temperature,
                top_p=0.8,
            )

            logits, hidden_states = self(
                input_ids=None,
                inputs_embeds=inputs_embeds,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_hidden = hidden_states[:, -1:]

            yield residual_codes

            if temperature == 0:
                token = mx.argmax(logits[:, -1, :], axis=-1)
            else:
                token = top_p_sampling(logits[:, -1, :], top_p, temperature)

            generation_step += 1

    def sanitize(self, weights):
        for l in range(self.config.text_config.num_hidden_layers):
            prefix = f"talker.model.layers.{l}.mlp"
            for n in ["gate_proj", "down_proj", "up_proj"]:
                experts_weights = []
                for e in range(self.config.text_config.num_experts):
                    key = f"{prefix}.experts.{e}.{n}.weight"
                    if key in weights:
                        experts_weights.append(weights.pop(key))

                if experts_weights:
                    weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(
                        experts_weights, axis=0
                    )
        return weights