nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
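
nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py — the hunk below appears to correspond to entry 449 above (a new file, matching its +1038 line count); the excerpt ends partway through the file.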
@@ -0,0 +1,1038 @@
import math
from typing import Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from ..base import check_array_shape
from .config import AudioConfig, ModelConfig
from .language import Gemma3nRMSNorm


def convert_torch_to_mlx_pad_width(padding, input_shape):
    """Convert a PyTorch padding tuple to the MLX pad_width format."""
    ndim = len(input_shape)

    # Initialize with no padding for all dimensions.
    pad_width = [(0, 0)] * ndim

    # Set padding only for the dimensions that exist in the input.
    # PyTorch p2d format: (left, right, top, bottom, front, back, ...)
    # For a 2D tensor with padding (12, 11, 0, 0):
    #   - the last dim gets (left=12, right=11)
    #   - the second-to-last dim gets (top=0, bottom=0)

    if ndim >= 1 and len(padding) >= 2:
        # Last dimension
        pad_width[-1] = (padding[0], padding[1])
    if ndim >= 2 and len(padding) >= 4:
        # Second-to-last dimension
        pad_width[-2] = (padding[2], padding[3])
    if ndim >= 3 and len(padding) >= 6:
        # Third-to-last dimension
        pad_width[-3] = (padding[4], padding[5])
    if ndim >= 4 and len(padding) >= 8:
        # Fourth-to-last dimension
        pad_width[-4] = (padding[6], padding[7])

    return pad_width

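# A minimal worked example (shapes and values assumed for illustration):
# padding only the last dimension of a rank-2 array, mirroring
# torch.nn.functional.pad with p2d = (12, 11, 0, 0):
#
#   x = mx.zeros((3, 10))
#   pad_width = convert_torch_to_mlx_pad_width((12, 11, 0, 0), x.shape)
#   # pad_width == [(0, 0), (12, 11)]
#   mx.pad(x, pad_width).shape  # (3, 33)
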
class Gemma3nAudioRelativePositionEmbedding(nn.Module):

    def __init__(self, config: AudioConfig, *args, **kwargs):
        super().__init__()
        self.config = config

        self.num_heads = self.config.conf_num_attention_heads
        self.channels = self.config.hidden_size
        self.head_dim = self.channels // self.num_heads
        self.max_backward = (
            self.config.conf_attention_context_left - 1
            if self.config.conf_attention_context_left > 0
            else 0
        )
        self.max_forward = self.config.conf_attention_context_right

        self.pos_proj = nn.Linear(
            self.channels, self.num_heads * self.head_dim, bias=False
        )

        min_timescale = 1.0
        max_timescale = 1.0e4
        num_timescales = self.channels // 2
        log_timescale_increment = math.log(
            float(max_timescale) / float(min_timescale)
        ) / max(num_timescales - 1, 1)
        inv_timescales = min_timescale * mx.exp(
            mx.arange(num_timescales) * -log_timescale_increment
        )

        self._inv_timescales = mx.array(inv_timescales)[None, None, ...]

    def _get_timing_signal_1d_pos(self, position: mx.array, dtype) -> mx.array:
        assert position.ndim == 2
        position = mx.expand_dims(position.astype(mx.float32), axis=-1)

        scaled_time = position * self._inv_timescales
        timing_signal = mx.concatenate(
            [mx.sin(scaled_time), mx.cos(scaled_time)], axis=-1
        )
        return timing_signal.astype(dtype)

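    # Note: this is the standard sinusoidal positional encoding, applied here
    # to relative positions: each position p is mapped to
    # concat(sin(p * inv_timescales), cos(p * inv_timescales)), giving a
    # [1, F_span, channels] signal that pos_proj then maps to per-head terms.
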
    def _relative_shift(
        self,
        term_bd_before_shift: mx.array,
        batch_size: int,
        num_heads: int,
        num_query_blocks: int,
        query_block_size: int,
        key_context_size: int,
        max_span_plus_1: int,
    ) -> mx.array:
        pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1

        # We only pad the last dimension, on the right.
        padding_tuple = (0, pad_amount_last_dim)

        term_bd_padded = mx.pad(
            term_bd_before_shift,
            convert_torch_to_mlx_pad_width(padding_tuple, term_bd_before_shift.shape),
        )
        # Shape after pad: [B, N, U, W, C+1]
        # Reshape for slicing (emulating JAX's behavior): [B, N, U, W * (C+1)]
        term_bd_reshaped = term_bd_padded.reshape(
            (
                batch_size,
                num_heads,
                num_query_blocks,
                query_block_size * (key_context_size + 1),
            )
        )

        # Slice to the effective [B, N, U, W * C]
        term_bd_sliced = term_bd_reshaped[
            :, :, :, : query_block_size * key_context_size
        ]

        # Reshape back to [B, N, U, W, C]
        term_bd_shifted = term_bd_sliced.reshape(
            (
                batch_size,
                num_heads,
                num_query_blocks,
                query_block_size,
                key_context_size,
            )
        )
        return term_bd_shifted

    def __call__(self, queries: mx.array, keys: mx.array) -> mx.array:
        # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size,
        #          num_heads, head_dim)
        # keys:    [B, U, C, N, H] (batch, num_query_blocks, key_context_size,
        #          num_heads, head_dim)
        # C = W + L + R (key_context_size)
        # F_span = L + R + 1 (max_span + 1)

        batch_size, num_query_blocks, query_block_size, num_heads, head_dim = (
            queries.shape
        )
        _, _, key_context_size, _, _ = keys.shape

        # Relative positions for the sinusoidal embeddings: [L, L-1, ..., -R].
        # Length is L + R + 1 = max_span + 1.
        pos_indices = mx.expand_dims(
            mx.arange(self.max_backward, -self.max_forward - 1, -1), axis=0
        )  # Shape [1, F_span]

        max_span_plus_1 = pos_indices.shape[1]  # F_span

        sin_emb_timing_signal = self._get_timing_signal_1d_pos(
            pos_indices, dtype=queries.dtype
        )  # Shape [1, F_span, self.channels]

        # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
        projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
        # Reshape to [1, F_span, N, H], then squeeze to [F_span, N, H]
        sin_emb = projected_sin_emb.reshape(
            1, max_span_plus_1, self.num_heads, self.head_dim
        ).squeeze(0)  # Shape [F, N, H]

        # term_ac: query-key content interaction.
        # queries: [B, U, W, N, H] -> transpose to [B, N, U, W, H] for matmul
        # keys:    [B, U, C, N, H] -> transpose to [B, N, U, H, C] for matmul
        queries_p = queries.transpose(0, 3, 1, 2, 4)  # [B, N, U, W, H]
        keys_p_t = keys.transpose(0, 3, 1, 4, 2)  # [B, N, U, H, C]
        term_ac = mx.matmul(queries_p, keys_p_t)  # [B, N, U, W, C]

        # term_bd: query-position interaction.
        # Original einsum:
        #   term_bd_unshifted = mx.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
        # queries shape: [B, U, W, N, H]
        # sin_emb shape: [F, N, H]
        # Target output shape: [B, N, U, W, F]

        # Transpose queries to [B, N, U, W, H] for easier broadcasting with sin_emb.
        q_transposed = queries.transpose(0, 3, 1, 2, 4)

        # Permute sin_emb from [F, N, H] to [N, H, F] to prepare for matmul.
        s_transposed = sin_emb.transpose(1, 2, 0)  # Shape: [N, H, F]

        # Reshape queries for matmul: [B, N, U*W, H]
        q_reshaped = q_transposed.reshape(
            batch_size, num_heads, num_query_blocks * query_block_size, head_dim
        )

        # Perform matmul: [B, N, U*W, H] @ [N, H, F].
        # s_transposed ([N, H, F]) is broadcast to [B, N, H, F].
        # Result: [B, N, U*W, F]
        term_bd_unshifted_matmul = mx.matmul(q_reshaped, s_transposed)

        # Reshape to the target [B, N, U, W, F]
        term_bd_unshifted = term_bd_unshifted_matmul.reshape(
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            max_span_plus_1,
        )

        # Apply the relative shift to term_bd_unshifted.
        term_bd_shifted = self._relative_shift(
            term_bd_unshifted,
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            key_context_size,
            max_span_plus_1,
        )  # Shape [B, N, U, W, C]

        return term_ac + term_bd_shifted

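# Shape walk-through of the pad/flatten/slice relative shift above (sizes
# assumed for illustration): with query_block_size W = 2, key_context_size
# C = 3 and span F = 2, each [W, F] block [[a, b], [c, d]] is padded to
# [W, C + 1] = [[a, b, 0, 0], [c, d, 0, 0]], flattened to
# [a, b, 0, 0, c, d, 0, 0], truncated to W * C = 6 elements, and reshaped to
# [W, C] = [[a, b, 0], [0, c, d]]: each query row's positional scores land
# one key position later than the previous row's, which is the
# Transformer-XL-style relative shift without an explicit gather.
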
class Gemma3nAudioAttention(nn.Module):
    def __init__(self, config: AudioConfig, *args, **kwargs):
        super().__init__()
        self.config = config

        self.num_heads = self.config.conf_num_attention_heads
        self.hidden_size = self.config.hidden_size
        self.head_dim = self.hidden_size // self.num_heads

        self.chunk_size = self.config.conf_attention_chunk_size
        self.max_future_horizon = self.config.conf_attention_context_right
        self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
        self.attention_invalid_logits_value = (
            self.config.conf_attention_invalid_logits_value
        )
        self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
        self.context_size = (
            self.chunk_size + self.max_past_horizon + self.max_future_horizon
        )

        self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
        self.per_dim_scale = mx.zeros((self.head_dim,))

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )

        q_scale = self.head_dim**-0.5
        # MLX has no nn.softplus, so the softplus factor is handled manually:
        # softplus(x) = log(1 + exp(x)), hence 1 / softplus(0) = 1 / log(2).
        r_softplus_0 = 1.0 / mx.log(2.0)
        self._q_scale = q_scale * r_softplus_0

        lower_causal_mask = mx.tril(
            mx.ones((self.context_size, self.chunk_size), dtype=mx.bool_),
            k=0,
        ).T
        upper_causal_mask = mx.tril(
            mx.ones((self.chunk_size, self.context_size), dtype=mx.bool_),
            k=self.max_past_horizon + self.max_future_horizon,
        )
        local_causal_valid_mask = mx.ones(
            (self.chunk_size, self.context_size), dtype=mx.bool_
        )
        local_causal_valid_mask = (
            local_causal_valid_mask * lower_causal_mask * upper_causal_mask
        )
        self._local_causal_valid_mask = local_causal_valid_mask

        self._softcap = mx.array(self.attention_logits_soft_cap, dtype=mx.float32)

    def _pad_dim1(
        self,
        x: mx.array,
        dim10_val: int,
        dim11_val: int,
    ) -> mx.array:
        # Build a PyTorch-style padding tuple that pads only dimension 1.
        padding_tuple = [0] * x.ndim * 2
        dim_idx_from_end = x.ndim - 2
        start_idx_for_dim = 2 * dim_idx_from_end
        padding_tuple[start_idx_for_dim] = dim10_val
        padding_tuple[start_idx_for_dim + 1] = dim11_val

        return mx.pad(x, convert_torch_to_mlx_pad_width(tuple(padding_tuple), x.shape))

    def _convert_to_block(
        self, x: mx.array, padding_val: Union[bool, float] = 0.0
    ) -> mx.array:
        # Split the time dimension into blocks of chunk_size, padding the tail.
        shape = x.shape
        b, t = shape[:2]
        num_blocks = (t + self.chunk_size - 1) // self.chunk_size

        if (padding_len := num_blocks * self.chunk_size - t) > 0:
            x = self._pad_dim1(x, 0, padding_len)

        permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
        return x.reshape(permute_dims)

    def unfold_mlx(self, x, dimension, size, step):
        # Get the shape and determine the number of sliding windows.
        shape = x.shape
        dim_size = shape[dimension]
        num_windows = (dim_size - size) // step + 1

        # Slice out each window.
        windows = []
        for i in range(num_windows):
            start_idx = i * step
            end_idx = start_idx + size

            # Create slice objects for all dimensions.
            slices = [slice(None)] * len(shape)
            slices[dimension] = slice(start_idx, end_idx)

            windows.append(x[tuple(slices)])

        # Stack along a new dimension.
        return mx.stack(windows, axis=dimension + 1)

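    # Worked example (sizes assumed for illustration): unfolding an array of
    # shape (2, 7, 4) along dimension 1 with size=3, step=2 produces
    # (7 - 3) // 2 + 1 = 3 windows; mx.stack(..., axis=2) then yields shape
    # (2, 3, 3, 4), with the window index on the new axis at dimension + 1.
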
    def _extract_block_context(self, x: mx.array) -> mx.array:
        pad_left = self.max_past_horizon
        pad_right = self.max_future_horizon + self.chunk_size - 1
        x = self._pad_dim1(x, pad_left, pad_right)

        frame_len = self.context_size
        frame_step = self.chunk_size

        # Create windows over the time dimension (axis 1) with a sliding-window
        # unfold; x has shape (batch, time, ...).
        x_unfolded = self.unfold_mlx(x, 1, frame_len, frame_step)

        if x.ndim > 2 and x_unfolded.ndim > 3:
            x_unfolded = x_unfolded.transpose(0, 2, 1, 3, 4)

        return x_unfolded

    def __call__(self, x: mx.array, mask: mx.array) -> mx.array:
        query_states = self.q_proj(x).reshape(
            *x.shape[:-1], self.num_heads, self.head_dim
        )
        key_states = self.k_proj(x).reshape(
            *x.shape[:-1], self.num_heads, self.head_dim
        )
        value_states = self.v_proj(x).reshape(
            *x.shape[:-1], self.num_heads, self.head_dim
        )

        # softplus(per_dim_scale) via logaddexp(x, 0) = log(1 + exp(x)).
        per_dim_scale_sp = mx.logaddexp(self.per_dim_scale, 0.0)

        broadcast_shape = (1, 1, 1, self.head_dim)
        per_dim_scale_sp_broadcast = per_dim_scale_sp.reshape(broadcast_shape)
        query_states = query_states * self._q_scale * per_dim_scale_sp_broadcast

        batch_size, q_time = query_states.shape[:2]

        query_blocks = self._convert_to_block(query_states)
        key_blocks = self._extract_block_context(key_states)
        value_blocks = self._extract_block_context(value_states)
        num_query_blocks = query_blocks.shape[1]

        # 1. Create a mask indicating the originally valid positions.
        original_valid_mask = ~mask  # True for valid, False for padded

        # 2. Extract blocks from this validity mask.
        extracted_valid_mask_blocks = self._extract_block_context(
            original_valid_mask
        ).transpose(0, 2, 1)

        # If a subframe factor was used in _extract_block_context for a [B, T]
        # input mask, the shape might be [B, U, C/SF, SF]; reshape to [B, U, C].
        # batch_size and num_query_blocks are known from query_blocks;
        # self.context_size is C.
        if (
            extracted_valid_mask_blocks.ndim == 4
            and extracted_valid_mask_blocks.shape[0] == batch_size
            and extracted_valid_mask_blocks.shape[1] == num_query_blocks
            and extracted_valid_mask_blocks.shape[2]
            * extracted_valid_mask_blocks.shape[3]
            == self.context_size
        ):
            extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
                batch_size, num_query_blocks, self.context_size
            )
        # After the potential reshape, ensure the mask is [B, U, C] when it came
        # from a [B, T] mask. This check might be too strict if
        # _extract_block_context handled higher-rank inputs differently, but it
        # should hold for the mask case.
        if extracted_valid_mask_blocks.shape != (
            batch_size,
            num_query_blocks,
            self.context_size,
        ):
            raise ValueError(
                "Shape of extracted_valid_mask_blocks"
                f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
                f" {num_query_blocks}, {self.context_size}) after potential reshape."
            )

        # 3. Expand dimensions for broadcasting with the logits and causal mask.
        # Target shape for broadcasting with logits [B, N, U, W, C]:
        # expand extracted_valid_mask_blocks to [B, 1, U, 1, C].
        condition_from_input_validity = mx.expand_dims(
            extracted_valid_mask_blocks, axis=1
        )
        condition_from_input_validity = mx.expand_dims(
            condition_from_input_validity, axis=-2
        )

        # self._local_causal_valid_mask is [W, C], True where the local window
        # allows attention. Expand to [1, 1, 1, W, C].
        condition_from_causality = self._local_causal_valid_mask[None, None, None, ...]

        # 4. Combine the two conditions: final_condition is True where a key is
        # *both* originally valid *and* causally accessible.
        # Broadcasts to [B, 1, U, W, C].
        final_condition_for_where = mx.logical_and(
            condition_from_input_validity,
            condition_from_causality,
        )

        # Embed queries and keys.
        logits = self.relative_position_embedding(query_blocks, key_blocks)

        # Apply the attention logit soft cap.
        logits = logits / self._softcap
        logits = nn.tanh(logits)
        logits = logits * self._softcap

        # Apply the combined mask; final_condition_for_where broadcasts with
        # logits [B, N, U, W, C].
        logits = mx.where(
            final_condition_for_where, logits, self.attention_invalid_logits_value
        )
        probabilities = mx.softmax(logits.astype(mx.float32), axis=-1).astype(
            value_blocks.dtype
        )

        # context_vectors is adapted from
        # jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
        b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
        h_dim = value_blocks.shape[-1]
        prob_bun = probabilities.transpose(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
        v_bun = value_blocks.transpose(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
        result_bmm = mx.matmul(prob_bun, v_bun)
        context_vectors = result_bmm.reshape(
            b_dim, u_dim, n_dim, w_dim, h_dim
        ).transpose(0, 1, 3, 2, 4)
        context_vectors = context_vectors.reshape(
            (
                batch_size,
                num_query_blocks * self.chunk_size,
                self.num_heads,
                self.head_dim,
            )
        )
        context_vectors = context_vectors[:, :q_time]

        return context_vectors

465
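# Editor's sketch (illustrative, not part of the package): a standalone check
# that the transpose/reshape/matmul path above reproduces the original JAX
# einsum "BNuwc,BucNH->BuwNH". Shapes are arbitrary; assumes an MLX version
# that provides mx.einsum.
import mlx.core as mx

B, N, U, W, C, H = 2, 4, 3, 5, 7, 6
probs = mx.random.normal((B, N, U, W, C))
vals = mx.random.normal((B, U, C, N, H))

ref = mx.einsum("bnuwc,bucnh->buwnh", probs, vals)
p = probs.transpose(0, 2, 1, 3, 4).reshape(-1, W, C)  # [B*U*N, W, C]
v = vals.transpose(0, 1, 3, 2, 4).reshape(-1, C, H)   # [B*U*N, C, H]
out = (p @ v).reshape(B, U, N, W, H).transpose(0, 1, 3, 2, 4)
assert mx.allclose(ref, out, atol=1e-5)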
+ class Gemma3nCumulativeGroupNorm(nn.Module):
+     """Applies Group Normalization cumulatively over the time dimension.
+
+     This layer normalizes the input by calculating the mean and variance
+     cumulatively over the time dimension (dim 1). The statistics are computed
+     over all feature dimensions (specified by `feature_dims` and `num_channels`)
+     for elements marked as valid by the optional `mask`.
+
+     If a `mask` is provided (True for valid, False for invalid/padded),
+     invalid time steps do not contribute to the statistics calculation, and
+     their corresponding output values are zeroed out.
+
+     Scale and bias, if enabled, are applied per-channel (last dimension).
+     This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
+     and `cumulative=True`.
+     """
+
+     def __init__(
+         self,
+         num_channels: int,  # Number of channels (size of the last dimension)
+         feature_dims: Tuple[
+             int, ...
+         ],  # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
+         eps: float = 1e-3,
+         use_scale: bool = True,
+         use_bias: bool = False,
+     ):
+         super().__init__()
+         self.num_channels = num_channels
+         self.feature_dims = tuple(feature_dims)
+         self.eps = eps
+         self.use_scale = use_scale
+         self.use_bias = use_bias
+
+         if self.use_scale:
+             # Scale parameter depends only on the channel dimension
+             self.weight = mx.ones(num_channels)
+         else:
+             self.weight = None
+
+         if self.use_bias:
+             # Bias parameter depends only on the channel dimension
+             self.bias = mx.zeros(num_channels)
+         else:
+             self.bias = None
+
+         # Axes for normalization: all dimensions except Batch (0) and Time (1).
+         # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
+         self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         """Applies cumulative group norm, optionally using a mask.
+
+         Args:
+             x: Input tensor, shape [B, T, *feature_dims, C].
+             mask: Optional boolean mask, shape [B, T]. True indicates a valid
+                 (non-padded) time step. If None, all time steps are considered valid.
+
+         Returns:
+             Normalized tensor with the same shape as x.
+         """
+         expected_input_suffix = self.feature_dims + (self.num_channels,)
+         if x.shape[2:] != expected_input_suffix:
+             raise ValueError(
+                 f"Input tensor shape suffix {x.shape[2:]} does not match expected"
+                 f" suffix (feature_dims + num_channels) {expected_input_suffix}"
+             )
+
+         if mask is not None:
+             if mask.shape != x.shape[:2]:
+                 raise ValueError(
+                     f"Mask shape {mask.shape} must match input Batch/Time dimensions {x.shape[:2]}"
+                 )
+             if mask.dtype != mx.bool_:  # mx.bool_ is MLX's boolean dtype
+                 raise TypeError("Mask must be a boolean tensor.")
+
+         input_dtype = x.dtype
+         # Calculations are performed in float32 for numerical stability.
+         calc_dtype = mx.float32
+         x_calc = x.astype(calc_dtype)
+
+         # Prepare a broadcastable mask (`mask_calc`).
+         # If no mask is provided, treat all elements as valid
+         # (mask_calc is all ones).
+         # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
+         if mask is not None:
+             mask_suffix_shape = (1,) * len(expected_input_suffix)
+             mask_calc = mask.reshape(mask.shape + mask_suffix_shape).astype(calc_dtype)
+         else:
+             mask_calc = mx.ones_like(x_calc).astype(calc_dtype)
+
+         # Mask the input for sum calculation: only valid elements contribute.
+         x_masked_for_sum = x_calc * mask_calc
+
+         # Cumulative Statistics Calculation
+         # 1. Sum of values over reduction axes at each time step.
+         sum_values_at_t = mx.sum(
+             x_masked_for_sum, axis=self.reduction_axes, keepdims=True
+         )
+         # 2. Cumulative sum of values over time.
+         cum_sum_values = mx.cumsum(sum_values_at_t, axis=1)
+
+         # 3. Count of valid elements in the normalization group at each time step.
+         #    (A "group" here consists of all features at a given Batch, Time).
+         elements_in_group_at_t = mx.sum(
+             mask_calc, axis=self.reduction_axes, keepdims=True
+         )
+         # 4. Cumulative count of valid elements over time.
+         cum_count_elements = mx.cumsum(elements_in_group_at_t, axis=1)
+         # Avoid division by zero if all preceding elements were masked.
+         safe_cum_count_elements = mx.clip(cum_count_elements, 1, None)
+
+         # 5. Cumulative mean.
+         cum_mean = cum_sum_values / safe_cum_count_elements
+
+         # 6. Sum of squared differences from the cumulative mean.
+         #    Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc.
+         #    Using x_calc here for the difference, as cum_mean already accounts for masking.
+         squared_diff_from_mean = (x_calc - cum_mean) ** 2
+         sum_sq_diff_at_t = mx.sum(
+             squared_diff_from_mean * mask_calc,
+             axis=self.reduction_axes,
+             keepdims=True,
+         )
+         # 7. Cumulative sum of squared differences over time.
+         cum_sum_sq_diff = mx.cumsum(sum_sq_diff_at_t, axis=1)
+
+         # 8. Cumulative variance.
+         cum_variance = cum_sum_sq_diff / safe_cum_count_elements
+
+         # Normalize the input using the calculated cumulative statistics:
+         # (x - E[x]) / sqrt(Var[x] + eps)
+         normalized_x = (x_calc - cum_mean) * mx.rsqrt(cum_variance + self.eps)
+
+         # Apply affine transformation (scale and bias) if enabled.
+         # Scale and bias are applied per-channel (last dimension).
+         if self.use_scale and self.weight is not None:
+             scale = self.weight.astype(calc_dtype)
+             # Reshape for broadcasting: [C] -> [1, ..., 1, C]
+             scale_view_shape = [1] * (x.ndim - 1) + [self.num_channels]
+             normalized_x = normalized_x * scale.reshape(scale_view_shape)
+
+         if self.use_bias and self.bias is not None:
+             bias = self.bias.astype(calc_dtype)
+             bias_view_shape = [1] * (x.ndim - 1) + [self.num_channels]
+             normalized_x = normalized_x + bias.reshape(bias_view_shape)
+
+         # Zero out outputs for time steps that were originally masked (where mask_calc is 0).
+         # This ensures padded/invalid positions in the input result in zero output.
+         final_output = normalized_x * mask_calc
+
+         return final_output.astype(input_dtype)
+
+
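# Editor's sketch (illustrative, not part of the package): the cumulative trick
# above in miniature. Per-step sums followed by a cumsum over time give, at each
# step t, the mean over everything seen up to t. Shapes and values are arbitrary.
import mlx.core as mx

B, T, C = 1, 4, 3
x = mx.random.normal((B, T, C))
sum_at_t = mx.sum(x, axis=-1, keepdims=True)           # [B, T, 1]
count_at_t = C * mx.arange(1, T + 1).reshape(1, T, 1)  # cumulative element count
cum_mean = mx.cumsum(sum_at_t, axis=1) / count_at_t    # prefix mean at each t
assert mx.allclose(cum_mean[0, -1, 0], mx.mean(x[0]), atol=1e-5)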
+ class Gemma3nAudioSSCPConvBlock(nn.Module):
+     def __init__(
+         self,
+         idx: int,
+         input_freq_dim: int,
+         config: AudioConfig,
+         manual_padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
+         *args,
+         **kwargs,
+     ):
+         super().__init__()
+         self.config = config
+         self.manual_padding = manual_padding
+
+         # in_channels is 1 for the first block, or C_out from the previous block's conv
+         in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
+         out_channels = self.config.sscp_conv_channel_size[idx]
+         kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
+         stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
+
+         self.conv = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=(
+                 kernel_h,
+                 kernel_w,
+             ),  # Kernel (kH, kW) operates on (Time, Freq_dim)
+             stride=(stride_h, stride_w),
+             padding=(0, 0),  # Manual padding is used
+             bias=False,
+         )
+
+         # Calculate the output frequency dimension (f_out_conv) after this convolution.
+         # input_freq_dim is the unpadded width (feature dimension).
+         # self.manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom).
+         f_in_padded = (
+             input_freq_dim
+             + self.manual_padding[0]  # pad_F_left
+             + self.manual_padding[1]  # pad_F_right
+         )
+         f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
+
+         self.norm = Gemma3nCumulativeGroupNorm(
+             num_channels=out_channels,  # Channels of the conv output
+             feature_dims=(f_out_conv,),  # The frequency dimension size after conv
+             eps=self.config.sscp_conv_eps,
+             use_scale=True,
+             use_bias=False,
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         # Input x is [B, C_in, T_in, F_in] (e.g., C_in=1).
+         # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom);
+         # the torch-style pad width applies to the last two dims: F_in, then T_in.
+         audio_encodings_padded = mx.pad(
+             x, convert_torch_to_mlx_pad_width(self.manual_padding, x.shape)
+         )
+
+         # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
+         # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
+         audio_encodings_conv = self.conv(audio_encodings_padded.transpose(0, 2, 3, 1))
+         # MLX convs are channels-last, so the conv output is [B, T_out, F_out, C_out],
+         # which is exactly the layout the cumulative group norm expects.
+         x_normed = self.norm(audio_encodings_conv)
+         # Permute back to [B, C_out, T_out, F_out]
+         audio_encodings_normed = x_normed.transpose(0, 3, 1, 2)
+         return nn.relu(audio_encodings_normed)
+
+
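# Editor's sketch (illustrative, not part of the package): the frequency-size
# arithmetic above, with assumed values input_freq_dim=128, kernel_w=3,
# stride_w=2, and manual frequency padding (1, 1):
f_in_padded = 128 + 1 + 1                # 130
f_out_conv = (f_in_padded - 3) // 2 + 1  # (130 - 3) // 2 + 1 = 64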
+ class Gemma3nAudioSubSampleConvProjection(nn.Module):
+
+     def __init__(self, config: AudioConfig, *args, **kwargs):
+         super().__init__()
+         self.config = config
+
+         current_f_for_block_input = (
+             config.input_feat_size
+         )  # Start with the original feature dim
+         calculated_block_padding = []
+         calculated_f_out_dims = []  # Tracks frequency-dimension output sizes
+
+         for i in range(2):  # Assuming 2 conv layers, as per the sscp_conv_... arrays
+             kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
+             stride_h, stride_w = config.sscp_conv_stride_size[i]
+             # A dilation rate of 1 is assumed for the frequency dimension,
+             # as it is not in the config.
+
+             # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL-like.
+             # JAX 'reverse_causal' padding is (0, kernel_size - 1).
+             pad_t_top = 0
+             pad_t_bottom = kernel_h - 1
+
+             # Frequency Padding (Width for Conv2d).
+             # Based on the JAX effective padding (1, 1) for F_in=10, K_w=3, S_w=2
+             # and the successful test configuration. If the kernel, stride, or input
+             # frequency changes, this may need re-evaluation to match generic JAX
+             # 'SAME' behavior if it differs.
+             pad_f_left = 1
+             pad_f_right = 1
+
+             manual_padding_tuple = (
+                 pad_f_left,
+                 pad_f_right,
+                 pad_t_top,
+                 pad_t_bottom,
+             )
+             calculated_block_padding.append(manual_padding_tuple)
+
+             # Calculate the output frequency dimension after this convolution,
+             # using the actual padding applied and the kernel/stride.
+             f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
+             f_out_after_conv = (
+                 f_in_padded - kernel_w
+             ) // stride_w + 1  # Assuming dilation_w = 1
+             calculated_f_out_dims.append(f_out_after_conv)
+             current_f_for_block_input = f_out_after_conv
+
+         self.conv_0 = Gemma3nAudioSSCPConvBlock(
+             idx=0,
+             input_freq_dim=config.input_feat_size,  # Pass the original feature dim
+             config=config,
+             manual_padding=calculated_block_padding[0],
+         )
+         self.conv_1 = Gemma3nAudioSSCPConvBlock(
+             idx=1,
+             input_freq_dim=calculated_f_out_dims[0],  # Output freq dim from conv_0
+             config=config,
+             manual_padding=calculated_block_padding[1],
+         )
+         final_c_out = config.sscp_conv_channel_size[-1]
+         final_f_out = calculated_f_out_dims[-1]  # Final frequency dimension
+         self.input_proj_in_features = final_c_out * final_f_out
+         self.input_proj_linear = nn.Linear(
+             self.input_proj_in_features, self.config.hidden_size, bias=False
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         # x is [B, T, F_in].
+         # Reshape to [B, 1, T, F_in] (Batch, Channels=1, Height=Time, Width=F_in).
+         audio_encodings_reshaped = mx.expand_dims(x, 1)
+         x = self.conv_0(audio_encodings_reshaped)
+         x = self.conv_1(x)
+         # x from conv_1 is [B, C_out_1, T_out_1, F_out_1]
+         b, c_out, t_out, f_out = x.shape
+         # Permute to [B, T_out_1, F_out_1, C_out_1], then flatten F_out_1 and C_out_1
+         x_transposed = x.transpose(0, 2, 3, 1)
+         output_flattened = x_transposed.reshape(b, t_out, f_out * c_out)
+         output = self.input_proj_linear(output_flattened)
+         return output
+
+
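# Editor's sketch (illustrative, not part of the package): the reverse-causal
# time padding above, with assumed values T_in=100, kernel_h=3, stride_h=2.
# Padding (0, kernel_h - 1) preserves ceil(T_in / stride_h) output frames
# without looking past the start of each receptive field:
T_padded = 100 + 0 + (3 - 1)     # 102
T_out = (T_padded - 3) // 2 + 1  # 99 // 2 + 1 = 50 = ceil(100 / 2)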
+ class Gemma3nAudioConformerAttention(nn.Module):
+     def __init__(self, config: AudioConfig, *args, **kwargs):
+         super().__init__()
+         self.config = config
+
+         head_dim = self.config.hidden_size // self.config.conf_num_attention_heads
+         self.post_in_shape = (self.config.conf_num_attention_heads, head_dim)
+         self.post_in_features = self.config.hidden_size
+
+         self._gradient_clipping = mx.array(self.config.gradient_clipping)
+
+         self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
+         self.attn = Gemma3nAudioAttention(config)
+         self.post = nn.Linear(
+             self.post_in_features, self.config.hidden_size, bias=False
+         )
+         self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+     def __call__(self, x: mx.array, mask: mx.array) -> mx.array:
+         audio_encodings_input_to_attn = x
+         x = mx.clip(x, -self._gradient_clipping, self._gradient_clipping)
+         audio_encodings_norm = self.pre_attn_norm(x)
+         # Output of self.attn is [B, T, NumHeads, HeadDim]
+         audio_encodings_attn_out = self.attn(audio_encodings_norm, mask)
+
+         # Reshape from [B, T, NumHeads, HeadDim] to [B, T, NumHeads * HeadDim],
+         # where NumHeads * HeadDim = hidden_size.
+         b, t, num_heads, head_dim = audio_encodings_attn_out.shape
+         audio_encodings_reshaped = audio_encodings_attn_out.reshape(
+             b, t, num_heads * head_dim
+         )
+
+         x = self.post(audio_encodings_reshaped)
+         x = mx.clip(x, -self._gradient_clipping, self._gradient_clipping)
+         return audio_encodings_input_to_attn + self.post_norm(x)
+
+
+ class Gemma3nAudioConformerFeedForward(nn.Module):
+     def __init__(self, config: AudioConfig, *args, **kwargs):
+         super().__init__()
+         self.config = config
+
+         self._gradient_clipping = mx.array(self.config.gradient_clipping)
+
+         self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+         self.ffw_layer_1 = nn.Linear(
+             self.config.hidden_size, self.config.hidden_size * 4, bias=False
+         )
+         self.ffw_layer_2 = nn.Linear(
+             self.config.hidden_size * 4, self.config.hidden_size, bias=False
+         )
+         self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+         self._post_layer_scale = mx.array(self.config.conf_residual_weight)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         residual = x
+         x = mx.clip(x, -self._gradient_clipping, self._gradient_clipping)
+         x = self.pre_layer_norm(x)
+         x: mx.array = self.ffw_layer_1(x)  # jax.numpy.einsum("...a,ab->...b")
+         x = nn.silu(x)  # SiLU (Swish) activation
+         x: mx.array = self.ffw_layer_2(x)  # jax.numpy.einsum("...a,ab->...b")
+         x = mx.clip(x, -self._gradient_clipping, self._gradient_clipping)
+         x = self.post_layer_norm(x)
+         return residual + (x * self._post_layer_scale)
+
+
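# Editor's sketch (illustrative, not part of the package): the feed-forward
# pattern above in isolation (norms and clipping omitted): 4x expansion with
# SiLU and a weighted residual; Conformer conventionally uses a 0.5 residual
# weight. The hidden size is an assumption.
import mlx.core as mx
import mlx.nn as nn

D = 8
ffw1 = nn.Linear(D, 4 * D, bias=False)
ffw2 = nn.Linear(4 * D, D, bias=False)
x = mx.random.normal((2, 5, D))
y = x + 0.5 * ffw2(nn.silu(ffw1(x)))  # residual + scaled feed-forward branch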
+ class Gemma3nAudioConformerLightConv1d(nn.Module):
+     def __init__(self, config: AudioConfig, *args, **kwargs):
+         super().__init__()
+         self.config = config
+
+         self.pre_layer_norm = Gemma3nRMSNorm(
+             self.config.hidden_size, eps=self.config.rms_norm_eps
+         )
+         self.linear_start = nn.Linear(
+             self.config.hidden_size, self.config.hidden_size * 2, bias=False
+         )
+         self.depthwise_conv1d = nn.Conv1d(
+             in_channels=self.config.hidden_size,
+             out_channels=self.config.hidden_size,
+             kernel_size=self.config.conf_conv_kernel_size,
+             stride=1,
+             padding=0,  # Manual causal padding
+             groups=self.config.hidden_size,  # Depthwise
+             bias=False,
+         )
+         self._gradient_clipping = mx.array(self.config.gradient_clipping)
+         self.conv_norm = Gemma3nRMSNorm(
+             self.config.hidden_size, eps=self.config.rms_norm_eps
+         )
+         self.linear_end = nn.Linear(
+             self.config.hidden_size, self.config.hidden_size, bias=False
+         )
+
+         self.causal_padding = self.config.conf_conv_kernel_size - 1
+
+     def __call__(self, audio_encodings: mx.array) -> mx.array:
+         audio_encodings_residual = audio_encodings  # Save for the residual connection
+
+         audio_encodings = self.pre_layer_norm(audio_encodings)
+         audio_encodings = self.linear_start(audio_encodings)
+         audio_encodings = nn.glu(audio_encodings, axis=-1)
+         # Transpose [B, T, D] -> [B, D, T] so the torch-style pad width below
+         # lands on the time axis (the last axis).
+         audio_encodings_transposed = audio_encodings.transpose(0, 2, 1)
+         # Apply manual causal padding (left-pad time by kernel_size - 1).
+         audio_encodings_transposed_padded = mx.pad(
+             audio_encodings_transposed,
+             convert_torch_to_mlx_pad_width(
+                 (self.causal_padding, 0), audio_encodings_transposed.shape
+             ),
+         )
+         # Transpose back to channels-last [B, T_padded, D] for MLX's Conv1d.
+         audio_encodings = self.depthwise_conv1d(
+             audio_encodings_transposed_padded.transpose(0, 2, 1)
+         )
+         audio_encodings = mx.clip(
+             audio_encodings, -self._gradient_clipping, self._gradient_clipping
+         )
+         audio_encodings = self.conv_norm(audio_encodings)
+         audio_encodings = nn.silu(audio_encodings)
+         audio_encodings = self.linear_end(audio_encodings)
+         output = audio_encodings + audio_encodings_residual
+         return output
+
+
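# Editor's sketch (illustrative, not part of the package): the manual causal
# padding above in miniature. Left-padding time by kernel_size - 1 keeps the
# output length equal to the input length while each output frame sees only
# current and past frames. Assumes an MLX version whose Conv1d supports groups.
import mlx.core as mx
import mlx.nn as nn

B, T, D, K = 1, 6, 4, 3
conv = nn.Conv1d(D, D, kernel_size=K, groups=D, bias=False)  # depthwise
x = mx.random.normal((B, T, D))
x_padded = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])  # pad left of the time axis
y = conv(x_padded)
assert y.shape == (B, T, D)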
+ class Gemma3nAudioConformerBlock(nn.Module):
+
+     def __init__(self, config: AudioConfig, *args, **kwargs):
+         super().__init__()
+         self.config = config
+
+         self.ffw_layer_start = Gemma3nAudioConformerFeedForward(self.config)
+         self.attention = Gemma3nAudioConformerAttention(self.config)
+         self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
+         self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
+         self._gradient_clipping = mx.array(self.config.gradient_clipping)
+         self.norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+     def __call__(self, audio_encodings: mx.array, audio_mel_mask: mx.array) -> mx.array:
+         audio_encodings = self.ffw_layer_start(audio_encodings)
+         audio_encodings = self.attention(audio_encodings, audio_mel_mask)
+         validity_mask_for_lconv = ~audio_mel_mask  # True for valid
+         audio_encodings_for_lconv_input = audio_encodings * mx.expand_dims(
+             validity_mask_for_lconv, -1
+         ).astype(audio_encodings.dtype)
+         audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
+
+         audio_encodings = self.ffw_layer_end(audio_encodings)
+         audio_encodings = mx.clip(
+             audio_encodings, -self._gradient_clipping, self._gradient_clipping
+         )
+         output = self.norm(audio_encodings)
+         return output
+
+
+ class AudioModel(nn.Module):
+     def __init__(self, config: AudioConfig, *args, **kwargs):
+         super().__init__()
+         self.config = config
+
+         self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
+         self.conformer = [
+             Gemma3nAudioConformerBlock(config)
+             for _ in range(config.conf_num_hidden_layers)
+         ]
+
+     def __call__(
+         self, audio_mel: mx.array, audio_mel_mask: mx.array
+     ) -> Tuple[mx.array, mx.array]:
+         audio_encodings = self.subsample_conv_projection(
+             audio_mel
+         )  # audio_encodings: [B, T_sub, D]
+
+         # Subsample the input audio_mel_mask to match the time dimension of
+         # audio_encodings (T_sub).
+         t_sub = audio_encodings.shape[1]
+
+         time_stride_product = 1
+         for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
+             time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
+
+         # Create indices for gathering from the original mask. These indices map to
+         # the original time steps corresponding to the start of each receptive field
+         # in the subsampled output.
+         indices = mx.arange(t_sub) * time_stride_product
+         indices = mx.clip(
+             indices, None, a_max=audio_mel_mask.shape[1] - 1
+         )  # Ensure indices are valid
+
+         # Expand indices for batch compatibility if B > 1 and indices is 1D.
+         if audio_mel_mask.ndim > 1 and indices.ndim == 1:
+             indices = indices[None, :]
+             indices = mx.broadcast_to(
+                 indices, (audio_mel_mask.shape[0], indices.shape[1])
+             )  # [B, T_sub]
+         elif (
+             audio_mel_mask.ndim == indices.ndim
+             and audio_mel_mask.shape[0] == 1
+             and indices.shape[0] != 1
+             and t_sub == indices.shape[0]
+         ):
+             # Handle the case where B=1 but indices became [T_sub] instead of [1, T_sub]
+             indices = indices[None, :]
+
+         current_mask = mx.take_along_axis(audio_mel_mask, indices, axis=1)  # [B, T_sub]
+
+         # Fallback: ensure the mask length matches the feature length after the gather.
+         if current_mask.shape[1] != t_sub:
+             print(
+                 f"Warning: Subsampled mask length {current_mask.shape[1]} mismatch"
+                 f" with feature length {t_sub} after gather. Adjusting."
+             )
+             if current_mask.shape[1] > t_sub:
+                 current_mask = current_mask[:, :t_sub]
+             else:  # current_mask.shape[1] < t_sub
+                 padding_needed = t_sub - current_mask.shape[1]
+                 current_mask = mx.pad(
+                     current_mask,
+                     convert_torch_to_mlx_pad_width(
+                         (0, padding_needed), current_mask.shape
+                     ),
+                 )
+
+         for block in self.conformer:
+             audio_encodings = block(audio_encodings, current_mask)  # Pass the processed mask
+
+         if self.config.conf_reduction_factor > 1:
+             audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
+             # Reduce the mask as well
+             current_mask = current_mask[:, :: self.config.conf_reduction_factor]
+
+         # Final masking of audio_encodings based on the final current_mask.
+         # Ensure current_mask length matches the finally reduced audio_encodings length.
+         if current_mask.shape[1] != audio_encodings.shape[1]:
+             target_len = audio_encodings.shape[1]
+             mask_current_len = current_mask.shape[1]
+             if target_len > mask_current_len:
+                 padding_needed = target_len - mask_current_len
+                 current_mask = mx.pad(
+                     current_mask,
+                     convert_torch_to_mlx_pad_width(
+                         (0, padding_needed), current_mask.shape
+                     ),
+                 )
+             elif mask_current_len > target_len:  # mask is longer
+                 current_mask = current_mask[:, :target_len]
+
+         # current_mask is True at padded positions, so zero those out.
+         audio_encodings = mx.where(current_mask[..., None], 0.0, audio_encodings)
+         return audio_encodings, current_mask
+
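# Editor's sketch (illustrative, not part of the package): the mask subsampling
# above in miniature, assuming a time-stride product of 4 (two stride-2 SSCP
# convs): gather every 4th step of the original padding mask.
import mlx.core as mx

T_orig, t_sub, stride_product = 16, 4, 4
mask = mx.arange(T_orig)[None, :] >= 10  # [1, 16]; True marks padded steps
indices = mx.clip(mx.arange(t_sub) * stride_product, None, T_orig - 1)
sub_mask = mx.take_along_axis(mask, indices[None, :], axis=1)  # [1, 4]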
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "conv.weight" in k:
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     # Torch Conv2d weights are [O, C_in, kH, kW]; MLX expects
+                     # channels-last [O, kH, kW, C_in].
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+             elif "conv1d.weight" in k:
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     # Torch Conv1d weights are [O, C_in, kW]; MLX expects [O, kW, C_in].
+                     sanitized_weights[k] = v.transpose(0, 2, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
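# Editor's sketch (illustrative, not part of the package): the layout
# conversion sanitize() performs, applied to a dummy weight. Shapes are arbitrary.
import mlx.core as mx

w_torch_style = mx.zeros((32, 1, 3, 3))            # [O, C_in, kH, kW]
w_mlx_style = w_torch_style.transpose(0, 2, 3, 1)  # [O, kH, kW, C_in]
assert w_mlx_style.shape == (32, 3, 3, 1)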