fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/qwen3_omni_moe/code2wav.py
@@ -0,0 +1,542 @@
+ import math
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+ from mlx_lm.models.base import scaled_dot_product_attention
+
+ from mlx_vlm.models.qwen3_omni_moe.config import Code2WavConfig
+
+
+ class SnakeBeta(nn.Module):
+     def __init__(self, in_features, alpha=1.0):
+         super().__init__()
+         self.in_features = in_features
+         self.alpha = mx.zeros((in_features,)) * alpha
+         self.beta = mx.zeros((in_features,)) * alpha
+         self.no_div_by_zero = 0.000000001
+
+     def __call__(self, hidden_states):
+         alpha = mx.expand_dims(mx.expand_dims(self.alpha, axis=0), axis=-1)
+         beta = mx.expand_dims(mx.expand_dims(self.beta, axis=0), axis=-1)
+         alpha = mx.exp(alpha)
+         beta = mx.exp(beta)
+         hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * mx.power(
+             mx.sin(hidden_states * alpha), 2
+         )
+         return hidden_states
+
+
+ class LayerScale(nn.Module):
+     def __init__(self, config: Code2WavConfig):
+         super().__init__()
+         channels = config.hidden_size
+         initial_scale = config.layer_scale_initial_scale
+         self.scale = mx.full((channels,), initial_scale)
+
+     def __call__(self, x: mx.array):
+         return self.scale * x
+
+
+ class RoPE(nn.Module):
+     def __init__(self, config: Code2WavConfig):
+         super().__init__()
+         self.config = config
+         head_dim = config.hidden_size // config.num_attention_heads
+         dim = head_dim
+         inv_freq = 1.0 / (
+             config.rope_theta ** (np.arange(0, dim, 2, dtype=np.float32) / dim)
+         )
+         self.inv_freq = inv_freq
+         self.attention_scaling = 1.0
+
+     def __call__(self, x: mx.array, position_ids: mx.array):
+         batch_size = position_ids.shape[0]
+         inv_freq_mx = mx.array(self.inv_freq)
+         inv_freq_expanded = mx.broadcast_to(
+             inv_freq_mx[None, :, None].astype(mx.float32),
+             (batch_size, inv_freq_mx.shape[0], 1),
+         )
+         position_ids_expanded = mx.expand_dims(position_ids.astype(mx.float32), axis=1)
+         freqs = inv_freq_expanded @ position_ids_expanded
+         freqs = mx.swapaxes(freqs, 1, 2)
+         emb = mx.concatenate([freqs, freqs], axis=-1)
+         cos = mx.cos(emb) * self.attention_scaling
+         sin = mx.sin(emb) * self.attention_scaling
+         return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+ class CausalConvNet(nn.Module):
+     def __init__(self, in_chn, out_chn, kernel_sz, dilation=1, stride=1, groups=1):
+         super().__init__()
+         self.conv = nn.Conv1d(
+             in_chn, out_chn, kernel_sz, stride=stride, dilation=dilation, groups=groups
+         )
+         self.stride = stride
+         self.kernel_size = (kernel_sz - 1) * dilation + 1
+         self.dilation = dilation
+         self.padding = self.kernel_size - self.stride
+
+     def _get_extra_padding_for_conv1d(self, length: int) -> int:
+         n_frames = (length - self.kernel_size + self.padding) / self.stride + 1
+         ideal_length = (math.ceil(n_frames) - 1) * self.stride + (
+             self.kernel_size - self.padding
+         )
+         return int(ideal_length - length)
+
+     def __call__(self, hidden_state: mx.array) -> mx.array:
+         length = hidden_state.shape[-1]
+         extra_padding = self._get_extra_padding_for_conv1d(length)
+         hidden_state = hidden_state.transpose(0, 2, 1)
+         pad_width = [(0, 0), (self.padding, extra_padding), (0, 0)]
+         hidden_state = mx.pad(
+             hidden_state, pad_width, mode="constant", constant_values=0
+         )
+         output = self.conv(hidden_state)
+         return output.transpose(0, 2, 1)
+
+
+ class CausalTransConvNet(nn.Module):
+     def __init__(self, in_chn, out_chn, kernel_sz, stride=1):
+         super().__init__()
+         self.conv = nn.ConvTranspose1d(in_chn, out_chn, kernel_sz, stride=stride)
+         pad = kernel_sz - stride
+         self.left_pad = 0
+         self.right_pad = pad
+
+     def __call__(self, hidden_state: mx.array) -> mx.array:
+         hidden_state = hidden_state.transpose(0, 2, 1)
+         hidden_state = self.conv(hidden_state)
+         length = hidden_state.shape[-2]
+         hidden_state = hidden_state[:, self.left_pad : length - self.right_pad, :]
+         return hidden_state.transpose(0, 2, 1)
+
+
+ class ConvNeXtBlock(nn.Module):
+     def __init__(self, dim: int):
+         super().__init__()
+
+         self.dwconv = CausalConvNet(dim, dim, kernel_sz=7, groups=dim, dilation=1)
+         self.norm = nn.LayerNorm(dim, eps=1e-6)
+         self.pwconv1 = nn.Linear(dim, 4 * dim)
+         self.pwconv2 = nn.Linear(4 * dim, dim)
+         self.gamma = mx.full((dim,), 1e-6)
+
+     def __call__(self, hidden_states: mx.array) -> mx.array:
+         input = hidden_states
+         hidden_states = self.dwconv(hidden_states)
+         hidden_states = hidden_states.transpose(0, 2, 1)
+         hidden_states = self.norm(hidden_states)
+         hidden_states = self.pwconv1(hidden_states)
+         hidden_states = nn.gelu(hidden_states)
+         hidden_states = self.pwconv2(hidden_states)
+         hidden_states = self.gamma * hidden_states
+         hidden_states = hidden_states.transpose(0, 2, 1)
+         hidden_states = input + hidden_states
+         return hidden_states
+
+
+ class Code2WavDecoderResUnit(nn.Module):
+     def __init__(self, dim: int, dilation: int = 1):
+         super().__init__()
+
+         self.act1 = SnakeBeta(dim)
+         self.conv1 = CausalConvNet(dim, dim, kernel_sz=7, dilation=dilation)
+         self.act2 = SnakeBeta(dim)
+         self.conv2 = CausalConvNet(dim, dim, kernel_sz=1)
+
+     def __call__(self, hidden_state: mx.array) -> mx.array:
+         residual = hidden_state
+         hidden_state = self.act1(hidden_state)
+         hidden_state = self.conv1(hidden_state)
+         hidden_state = self.act2(hidden_state)
+         hidden_state = self.conv2(hidden_state)
+         return hidden_state + residual
+
+
+ class Code2WavDecoderBlock(nn.Module):
+     def __init__(self, config: Code2WavConfig, idx: int):
+         super().__init__()
+
+         in_dim = config.decoder_dim // 2**idx
+         out_dim = config.decoder_dim // 2 ** (idx + 1)
+         upsample_rate = config.upsample_rates[idx]
+
+         self.block = [
+             SnakeBeta(in_dim),
+             CausalTransConvNet(in_dim, out_dim, 2 * upsample_rate, upsample_rate),
+         ]
+         self.block.extend(
+             [Code2WavDecoderResUnit(out_dim, dilation) for dilation in (1, 3, 9)]
+         )
+
+     def __call__(self, hidden: mx.array) -> mx.array:
+         for block in self.block:
+             hidden = block(hidden)
+         return hidden
+
+
+ def rotate_half(x):
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     cos = mx.expand_dims(cos, axis=1)
+     sin = mx.expand_dims(sin, axis=1)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class Code2WavAttention(nn.Module):
+     def __init__(self, config: Code2WavConfig, idx: int):
+         super().__init__()
+
+         self.config = config
+         self.layer_idx = idx
+         self.head_dim = getattr(
+             config, "head_dim", config.hidden_size // config.num_attention_heads
+         )
+         self.num_key_value_groups = (
+             config.num_attention_heads // config.num_key_value_heads
+         )
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = config.attention_dropout
+         self.is_causal = True
+
+         self.q_proj = nn.Linear(
+             config.hidden_size,
+             config.num_attention_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.k_proj = nn.Linear(
+             config.hidden_size,
+             config.num_key_value_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.v_proj = nn.Linear(
+             config.hidden_size,
+             config.num_key_value_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.o_proj = nn.Linear(
+             config.num_attention_heads * self.head_dim,
+             config.hidden_size,
+             bias=config.attention_bias,
+         )
+         self.q_norm = nn.Identity()
+         self.k_norm = nn.Identity()
+         self.sliding_window = config.sliding_window
+         self.rotary_emb = RoPE(config)
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
+         attention_mask: Optional[mx.array] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> Tuple[mx.array, Optional[mx.array]]:
+         B, L, D = hidden_states.shape
+         hidden_shape = (B, L, -1, self.head_dim)
+
+         query_states = self.q_norm(
+             self.q_proj(hidden_states).reshape(*hidden_shape)
+         ).transpose(0, 2, 1, 3)
+         key_states = self.k_norm(
+             self.k_proj(hidden_states).reshape(*hidden_shape)
+         ).transpose(0, 2, 1, 3)
+         value_states = (
+             self.v_proj(hidden_states).reshape(*hidden_shape).transpose(0, 2, 1, 3)
+         )
+
+         if position_embeddings is None:
+             if position_ids is None:
+                 position_ids = mx.arange(L)
+                 position_ids = mx.expand_dims(position_ids, axis=0)
+             cos, sin = self.rotary_emb(hidden_states, position_ids)
+         else:
+             cos, sin = position_embeddings
+
+         query_states, key_states = apply_rotary_pos_emb(
+             query_states, key_states, cos, sin
+         )
+
+         if attention_mask is not None and isinstance(attention_mask, mx.array):
+             kv_seq_len = key_states.shape[-2]
+             if attention_mask.shape[-1] != kv_seq_len:
+                 attention_mask = attention_mask[..., :kv_seq_len]
+
+         if self.is_causal and attention_mask is None:
+             attention_mask = nn.MultiHeadAttention.create_additive_causal_mask(L)
+             attention_mask = attention_mask.astype(query_states.dtype)
+
+         attn_output = scaled_dot_product_attention(
+             query_states,
+             key_states,
+             value_states,
+             None,
+             scale=self.scaling,
+             mask=attention_mask,
+         )
+
+         attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         attn_output = self.o_proj(attn_output)
+         return attn_output, None
+
+
+ class Code2WavMlp(nn.Module):
+     def __init__(self, config: Code2WavConfig):
+         super().__init__()
+
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+         if config.hidden_act == "silu":
+             self.act_fn = nn.silu
+         elif config.hidden_act == "gelu":
+             self.act_fn = nn.gelu
+         elif config.hidden_act == "gelu_pytorch_tanh":
+             self.act_fn = nn.GELU(approx="precise")
+         else:
+             self.act_fn = nn.silu
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Code2WavTransformerLayer(nn.Module):
+     def __init__(self, config: Code2WavConfig, idx: int):
+         super().__init__()
+         self.self_attn = Code2WavAttention(config, idx)
+         self.mlp = Code2WavMlp(config)
+         self.input_layernorm = nn.RMSNorm(config.hidden_size, config.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             config.hidden_size, config.rms_norm_eps
+         )
+         self.self_attn_layer_scale = LayerScale(config)
+         self.mlp_layer_scale = LayerScale(config)
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         attention_mask: Optional[mx.array] = None,
+         position_ids: Optional[mx.array] = None,
+         position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
+     ) -> mx.array:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states, _ = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             position_embeddings=position_embeddings,
+         )
+         hidden_states = residual + self.self_attn_layer_scale(hidden_states)
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + self.mlp_layer_scale(hidden_states)
+
+         return hidden_states
+
+
+ class Code2WavTransformerModel(nn.Module):
+     def __init__(self, config: Code2WavConfig):
+         super().__init__()
+
+         self.layers = [
+             Code2WavTransformerLayer(config, idx)
+             for idx in range(config.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.rotary_emb = RoPE(config)
+
+     def __call__(
+         self,
+         inputs_embeds: mx.array,
+         attention_mask: Optional[mx.array] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         hidden_states = inputs_embeds
+
+         if position_ids is None:
+             position_ids = mx.arange(hidden_states.shape[1])
+             position_ids = mx.expand_dims(position_ids, axis=0)
+
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         for layer in self.layers:
+             hidden_states = layer(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 position_embeddings=position_embeddings,
+             )
+
+         hidden_states = self.norm(hidden_states)
+         return hidden_states
+
+
+ class Code2WavModel(nn.Module):
+     def __init__(self, config: Code2WavConfig):
+         super().__init__()
+
+         self.pre_transformer = Code2WavTransformerModel(config)
+         self.code_embedding = nn.Embedding(
+             config.codebook_size * config.num_quantizers, config.hidden_size
+         )
+         self.upsample = [
+             [
+                 CausalTransConvNet(
+                     config.hidden_size, config.hidden_size, factor, factor
+                 ),
+                 ConvNeXtBlock(config.hidden_size),
+             ]
+             for factor in config.upsampling_ratios
+         ]
+         self.decoder = [CausalConvNet(config.hidden_size, config.decoder_dim, 7)]
+         self.decoder.extend(
+             [
+                 Code2WavDecoderBlock(config, idx)
+                 for idx in range(len(config.upsample_rates))
+             ]
+         )
+         output_dim = config.decoder_dim // 2 ** len(config.upsample_rates)
+         self.decoder.extend([SnakeBeta(output_dim), CausalConvNet(output_dim, 1, 7)])
+         self.config = config
+         self.code_offset = (
+             np.arange(config.num_quantizers).reshape(1, -1, 1) * config.codebook_size
+         )
+
+     def __call__(
+         self, codes: mx.array = None, input_embeds: mx.array = None
+     ) -> mx.array:
+         if input_embeds is not None:
+             hidden = input_embeds
+         elif codes is not None:
+             if codes.shape[1] != self.config.num_quantizers:
+                 raise ValueError(
+                     f"Expected {self.config.num_quantizers} layer of codes, got {codes.shape[1]}"
+                 )
+             hidden = self.code_embedding(codes + mx.array(self.code_offset)).mean(1)
+         else:
+             raise ValueError("Must provide codes or input_embeds")
+
+         hidden = self.pre_transformer(inputs_embeds=hidden)
+         hidden = hidden.transpose(0, 2, 1)
+         for blocks in self.upsample:
+             for block in blocks:
+                 hidden = block(hidden)
+         wav = hidden
+         for block in self.decoder:
+             wav = block(wav)
+         return mx.clip(wav, -1, 1)
+
+     def chunked_decode(self, codes, chunk_size=300, left_context_size=25):
+         total_upsample_factor = 1
+         for r in self.config.upsampling_ratios:
+             total_upsample_factor *= r
+         for r in self.config.upsample_rates:
+             total_upsample_factor *= r
+
+         B, Q, L = codes.shape
+         final_wav_list = []
+
+         for start in range(0, L, chunk_size):
+             end = min(start + chunk_size, L)
+             context_start = max(0, start - left_context_size)
+             chunk_codes = codes[:, :, context_start:end]
+             wav_chunk = self(codes=chunk_codes)
+             context_len_tokens = start - context_start
+             valid_start_sample = context_len_tokens * total_upsample_factor
+             current_chunk_valid_len_tokens = end - start
+             valid_len_samples = current_chunk_valid_len_tokens * total_upsample_factor
+             chunk_valid_wav = wav_chunk[
+                 :, :, valid_start_sample : valid_start_sample + valid_len_samples
+             ]
+             final_wav_list.append(chunk_valid_wav)
+
+         return mx.concatenate(final_wav_list, axis=-1)
+
+     def stream_decode(
+         self, codes_buffer, chunk_size=300, left_context_size=25, decoded_len=0
+     ):
+         total_upsample_factor = 1
+         for r in self.config.upsampling_ratios:
+             total_upsample_factor *= r
+         for r in self.config.upsample_rates:
+             total_upsample_factor *= r
+
+         L = codes_buffer.shape[2]
+         start = decoded_len
+         context_start = max(0, start - left_context_size)
+         context_len = start - context_start
+         new_tokens = chunk_size - context_len
+         if L - start < new_tokens:
+             return None, decoded_len
+
+         end = start + new_tokens
+         chunk_codes = codes_buffer[:, :, context_start:end]
+         wav_chunk = self(codes=chunk_codes)
+         context_len_tokens = start - context_start
+         valid_start_sample = context_len_tokens * total_upsample_factor
+         current_chunk_valid_len_tokens = end - start
+         valid_len_samples = current_chunk_valid_len_tokens * total_upsample_factor
+         chunk_valid_wav = wav_chunk[
+             :, :, valid_start_sample : valid_start_sample + valid_len_samples
+         ]
+         return chunk_valid_wav, end
+
+     def flush_decode(self, codes_buffer, left_context_size=25, decoded_len=0):
+         total_upsample_factor = 1
+         for r in self.config.upsampling_ratios:
+             total_upsample_factor *= r
+         for r in self.config.upsample_rates:
+             total_upsample_factor *= r
+
+         L = codes_buffer.shape[2]
+         if decoded_len >= L:
+             return None
+
+         start = decoded_len
+         context_start = max(0, start - left_context_size)
+         chunk_codes = codes_buffer[:, :, context_start:]
+         wav_chunk = self(codes=chunk_codes)
+         context_len_tokens = start - context_start
+         valid_start_sample = context_len_tokens * total_upsample_factor
+         return wav_chunk[:, :, valid_start_sample:]
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if ("upsample" in k and "conv.weight" in k and "dwconv" not in k) or (
+                 "decoder" in k
+                 and "block" in k
+                 and "conv.weight" in k
+                 and "conv1" not in k
+                 and "conv2" not in k
+             ):
+                 sanitized_weights[k] = v.transpose(1, 2, 0)
+             elif (
+                 ("dwconv.conv.weight" in k)
+                 or ("decoder" in k and "conv.weight" in k and "block" not in k)
+                 or (
+                     "decoder" in k
+                     and "block" in k
+                     and ("conv1.conv.weight" in k or "conv2.conv.weight" in k)
+                 )
+             ):
+                 sanitized_weights[k] = v.transpose(0, 2, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
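
Notes on the code2wav.py implementation above.

SnakeBeta is the periodic activation familiar from BigVGAN-style vocoder decoders: per channel it computes x + (1/beta) * sin^2(alpha * x), with alpha and beta stored in log scale, so the zero initialization in __init__ corresponds to alpha = beta = 1. A minimal numpy sketch of the same computation, not part of the package:

    import numpy as np

    def snake_beta(x, alpha_log=0.0, beta_log=0.0, eps=1e-9):
        # alpha and beta live in log scale, matching SnakeBeta above
        alpha, beta = np.exp(alpha_log), np.exp(beta_log)
        return x + (1.0 / (beta + eps)) * np.sin(x * alpha) ** 2

    x = np.linspace(-2.0, 2.0, 5)
    print(snake_beta(x))  # at init this reduces to x + sin(x)**2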
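CausalConvNet pads only on the left by kernel_size - stride, which keeps the convolution causal, and _get_extra_padding_for_conv1d adds just enough right padding that no trailing samples are dropped; together they yield exactly ceil(length / stride) output frames. A standalone check of that arithmetic (the helper below is hypothetical, not in the package):

    import math

    def causal_out_len(length, kernel_sz, stride=1, dilation=1):
        # Mirrors CausalConvNet: effective kernel size, causal left pad,
        # plus the extra right pad from _get_extra_padding_for_conv1d.
        kernel_size = (kernel_sz - 1) * dilation + 1
        padding = kernel_size - stride
        n_frames = (length - kernel_size + padding) / stride + 1
        ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding)
        extra = ideal_length - length
        padded = length + padding + extra
        return (padded - kernel_size) // stride + 1

    for length in (10, 11, 16, 100):
        assert causal_out_len(length, kernel_sz=7, stride=2) == math.ceil(length / 2)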
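chunked_decode, stream_decode, and flush_decode share one windowing scheme: each chunk is decoded together with up to left_context_size preceding codes so chunk boundaries stay continuous, and only the samples produced by the new codes are kept. Because every code maps to a fixed number of output samples (the product of upsampling_ratios and upsample_rates), the bookkeeping is pure index arithmetic. A standalone sketch; chunk_size, left_context, and factor here are illustrative values, not the model's actual configuration:

    def chunk_windows(num_codes, chunk_size=300, left_context=25, factor=480):
        # Yields (code slice decoded, waveform slice kept) per chunk,
        # mirroring the index arithmetic in chunked_decode above.
        for start in range(0, num_codes, chunk_size):
            end = min(start + chunk_size, num_codes)
            context_start = max(0, start - left_context)
            valid_start = (start - context_start) * factor  # skip context samples
            valid_len = (end - start) * factor              # keep new samples only
            yield (context_start, end), (valid_start, valid_start + valid_len)

    for code_slice, wav_slice in chunk_windows(700):
        print(code_slice, wav_slice)
    # (0, 300) (0, 144000)
    # (275, 600) (12000, 156000)
    # (575, 700) (12000, 60000)

The kept slices sum to exactly 700 * 480 samples, so concatenating them reconstructs the full waveform with no overlap.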
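Finally, sanitize appears to reconcile PyTorch and MLX convolution weight layouts: MLX stores Conv1d weights as (out_channels, kernel, in_channels), so regular conv weights in PyTorch's (out_channels, in_channels, kernel) layout are remapped with transpose(0, 2, 1), while ConvTranspose1d weights, saved as (in_channels, out_channels, kernel), need transpose(1, 2, 0). The substring checks route each checkpoint key accordingly: the transposed convolutions live under upsample and under decoder block entries other than conv1/conv2, and every other conv.weight key is a regular convolution.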