nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py
@@ -0,0 +1,1022 @@
+ import inspect
+ from collections.abc import Sequence
+ from dataclasses import dataclass
+ from math import sqrt
+ from typing import Dict, List, Optional, Tuple, Type
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from .config import VisionConfig
+
+ from ..base import check_array_shape
+ from ..kernels import bicubic_interpolate, nearest_interpolate
+
+
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/mobilenetv5.py#L24
+ class MobileNetV5MultiScaleFusionAdapter(nn.Module):
+     """Multi-layer fusion token adapter.
+     Attributes:
+       out_filters: The number of output filters.
+       output_resolution: The output resolution.
+       activation: The activation function.
+       expansion_ratio: The expansion ratio.
+       upsampling_interpolation: The upsampling interpolation.
+       use_layer_scale: Whether to use layer scale.
+       layer_scale_init_value: The initial value of the layer scale.
+       skip_projection: Whether to skip the projection.
+       name: The name of the module.
+       upsize: The upsampling fn.
+       downsize: The downsampling fn.
+     """
+
+     def __init__(
+         self,
+         in_chs: List[int],
+         out_chs: int,
+         output_resolution: int,
+         expansion_ratio: float = 2.0,
+         interpolation_mode: str = "nearest",
+         use_layer_scale: bool = False,
+         layer_scale_init_value: float = 1e-5,
+         noskip: bool = True,
+     ):
+         super().__init__()
+         self.in_channels = sum(in_chs) if isinstance(in_chs, Sequence) else in_chs
+         self.out_channels = out_chs
+         self.output_resolution = to_2tuple(output_resolution)
+         self.expansion_ratio = expansion_ratio
+         self.interpolation_mode = interpolation_mode
+         self.use_layer_scale = use_layer_scale
+         self.layer_scale_init_value = layer_scale_init_value
+         self.noskip = noskip
+
+         norm_layer = RMSNormAct2d
+         self.ffn = UniversalInvertedResidual(
+             in_chs=self.in_channels,
+             out_chs=self.out_channels,
+             dw_kernel_size_mid=0,
+             exp_ratio=self.expansion_ratio,
+             norm_layer=norm_layer,
+             noskip=self.noskip,
+             layer_scale_init_value=(
+                 self.layer_scale_init_value if self.use_layer_scale else None
+             ),
+         )
+
+         self.norm = norm_layer(self.out_channels, eps=1e-6, apply_act=False)
+
+     def __call__(self, inputs: list[mx.array]) -> mx.array:
+         inputs = [i.transpose(0, 3, 1, 2) for i in inputs]
+         high_resolution = inputs[0].shape[
+             -2:
+         ]  # Assuming the first input is the highest resolution.
+         resized_inputs = []
+
+         for _, img in enumerate(inputs):
+             if any([r < hr for r, hr in zip(img.shape[-2:], high_resolution)]):
+                 img = nearest_interpolate(img, size=high_resolution)
+
+             resized_inputs.append(img)
+
+         channel_cat_imgs = mx.concatenate(
+             resized_inputs, axis=1
+         )  # Cat on channel dim, must equal self.in_channels
+         img = self.ffn(channel_cat_imgs.swapaxes(1, 3)).swapaxes(1, 3)
+
+         if any([ro != rh for ro, rh in zip(high_resolution, self.output_resolution)]):
+             if (
+                 high_resolution[0] % self.output_resolution[0] != 0
+                 or high_resolution[1] % self.output_resolution[1] != 0
+             ):
+                 img = bicubic_interpolate(img, self.output_resolution)
+             else:
+                 h_strides = high_resolution[0] // self.output_resolution[0]
+                 w_strides = high_resolution[1] // self.output_resolution[1]
+
+                 img = nn.AvgPool2d(
+                     kernel_size=(h_strides, w_strides),
+                     stride=(h_strides, w_strides),
+                 )(img.swapaxes(1, 3))
+
+         img = self.norm(img) if self.noskip else img
+
+         return img
+
+
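# Illustrative sketch, not part of the packaged file: the adapter above resizes
# every feature map to the highest input resolution, concatenates on the channel
# axis, and pools down to the output grid. The same flow with plain mlx ops,
# assuming NCHW toy shapes and integer scale factors (mx.repeat stands in for
# the packaged nearest_interpolate kernel):
import mlx.core as mx

feats = [mx.random.normal((1, 8, 16, 16)), mx.random.normal((1, 16, 8, 8))]
target_h, target_w = feats[0].shape[-2:]
resized = []
for f in feats:
    sh, sw = target_h // f.shape[-2], target_w // f.shape[-1]
    # Nearest-neighbour upsampling by integer factors via repeat.
    resized.append(mx.repeat(mx.repeat(f, sh, axis=2), sw, axis=3))
fused = mx.concatenate(resized, axis=1)  # channel concat -> (1, 24, 16, 16)
b, c, h, w = fused.shape
# Stride-2 average pooling to an 8x8 grid, mirroring the divisible-resolution branch.
pooled = fused.reshape(b, c, 8, h // 8, 8, w // 8).mean(axis=(3, 5))
print(pooled.shape)  # (1, 24, 8, 8)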
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/layer_scale.py#L22
+ class LayerScale2d(nn.Module):
+     def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = init_values * mx.ones((dim,))
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+ def rms_norm2d(
+     x: mx.array,
+     normalized_shape: List[int],
+     weight: Optional[mx.array] = None,
+     eps: float = 1e-5,
+ ):
+     assert len(normalized_shape) == 1
+     dtype = x.dtype
+     v = mx.power(x, 2)
+     v = mx.mean(v, axis=1, keepdims=True)
+     x = x * mx.rsqrt(v + eps)
+     if weight is not None:
+         x = x.astype(dtype) * weight.reshape(1, -1, 1, 1)
+     return x
+
+
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/norm_act.py#L504
+ class RMSNormAct2d(nn.RMSNorm):
+     def __init__(
+         self,
+         num_channels,
+         eps=1e-6,
+         apply_act: bool = True,
+     ):
+         super().__init__(dims=num_channels, eps=eps)
+         self.normalized_shape = [num_channels]
+         self.drop = nn.Identity()
+         self.act = nn.GELU() if apply_act else nn.Identity()
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = x.transpose(0, 3, 1, 2)  # Convert from NHWC to NCHW
+         x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
+         x = self.drop(x)
+         x = self.act(x)
+         x = x.transpose(0, 2, 3, 1)  # Convert back to NHWC
+         return x
+
+
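# Illustrative check, not part of the packaged file: rms_norm2d normalizes each
# spatial position across the channel axis (axis=1 in NCHW), i.e.
# y = x * rsqrt(mean(x**2, axis=1) + eps), optionally scaled by a per-channel weight.
import mlx.core as mx

x = mx.arange(8, dtype=mx.float32).reshape(1, 2, 2, 2)  # toy NCHW tensor
v = mx.mean(mx.power(x, 2), axis=1, keepdims=True)      # per-pixel mean square over channels
y = x * mx.rsqrt(v + 1e-6)
print(mx.mean(mx.power(y, 2), axis=1))                  # ~1.0 at every (h, w) position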
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L310
+ class UniversalInvertedResidual(nn.Module):
+     def __init__(
+         self,
+         in_chs: int,
+         out_chs: int,
+         dw_kernel_size_start: int = 0,
+         dw_kernel_size_mid: int = 3,
+         dw_kernel_size_end: int = 0,
+         stride: int = 1,
+         dilation: int = 1,
+         group_size: int = 1,
+         pad_type: str = "",
+         noskip: bool = False,
+         exp_ratio: float = 1.0,
+         norm_layer=RMSNormAct2d,
+         conv_kwargs: Optional[Dict] = None,
+         drop_path_rate: float = 0.0,
+         layer_scale_init_value: Optional[float] = 1e-5,
+     ):
+         super().__init__()
+         self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
+         if stride > 1:
+             assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end
+
+         if dw_kernel_size_start:
+             dw_start_stride = stride if not dw_kernel_size_mid else 1
+             dw_start_groups = num_groups(group_size, in_chs)
+             self.dw_start = ConvNormAct(
+                 nn.Conv2d,
+                 in_chs,
+                 in_chs,
+                 kernel_size=dw_kernel_size_start,
+                 stride=dw_start_stride,
+                 padding=(dw_kernel_size_start - 1) // 2,
+                 dilation=dilation,
+                 groups=dw_start_groups,
+                 bias=False,
+                 apply_act=False,
+                 eps=1e-05,
+             )
+         else:
+             self.dw_start = nn.Identity()
+
+         mid_chs = make_divisible(in_chs * exp_ratio)
+         self.pw_exp = ConvNormAct(
+             nn.Conv2d,
+             in_chs,
+             mid_chs,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             groups=1,
+             bias=False,
+             eps=1e-05,
+         )
+
+         if dw_kernel_size_mid:
+             dw_mid_groups = num_groups(group_size, mid_chs)
+             self.dw_mid = ConvNormAct(
+                 Conv2dSame,
+                 mid_chs,
+                 mid_chs,
+                 kernel_size=dw_kernel_size_mid,
+                 stride=stride,
+                 padding=0,
+                 dilation=dilation,
+                 groups=dw_mid_groups,
+                 bias=False,
+                 eps=1e-05,
+             )
+         else:
+             self.dw_mid = nn.Identity()
+
+         self.pw_proj = ConvNormAct(
+             nn.Conv2d,
+             mid_chs,
+             out_chs,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             groups=1,
+             bias=False,
+             apply_act=False,
+             eps=1e-05,
+         )
+         if layer_scale_init_value is not None:
+             self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
+         else:
+             self.layer_scale = nn.Identity()
+
+     def __call__(self, x: mx.array) -> mx.array:
+         shortcut = x
+         x = self.dw_start(x)
+         x = self.pw_exp(x)
+         x = self.dw_mid(x)
+         x = self.pw_proj(x)
+         x = self.layer_scale(x)
+         if self.has_skip:
+             x = x + shortcut
+         return x
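# Illustrative sketch, not part of the packaged file: the inverted-residual
# pattern above is pointwise expansion -> depthwise conv -> pointwise projection,
# with a skip connection when shapes allow. Standalone with mlx (NHWC layout,
# assumed sizes; MLX Conv2d's groups argument provides the depthwise step):
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((1, 8, 8, 32))
pw_exp = nn.Conv2d(32, 64, kernel_size=1)                        # expand 2x
dw_mid = nn.Conv2d(64, 64, kernel_size=3, padding=1, groups=64)  # depthwise
pw_proj = nn.Conv2d(64, 32, kernel_size=1)                       # project back
y = pw_proj(dw_mid(pw_exp(x))) + x                               # residual skip
print(y.shape)  # (1, 8, 8, 32)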
+
+
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/conv_bn_act.py#L15
+ class ConvNormAct(nn.Module):
+     def __init__(
+         self,
+         conv_cls,
+         in_chs: int,
+         out_chs: int,
+         kernel_size: int = 3,
+         stride: int = 1,
+         padding: int = 0,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = False,
+         apply_act: bool = True,
+         eps: float = 1e-6,
+     ):
+         super().__init__()
+         self.out_chs = out_chs
+         self.conv = conv_cls(
+             in_chs,
+             out_chs,
+             kernel_size,
+             stride,
+             padding,
+             (dilation, dilation),
+             groups,
+             bias,
+         )
+         self.bn = RMSNormAct2d(out_chs, eps=eps, apply_act=apply_act)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         c = self.conv(x)
+         r = self.bn(c)
+         return r
+
+
+ def pad_same(
+     x,
+     kernel_size: List[int],
+     stride: List[int],
+     dilation: List[int] = (1, 1),
+     value: float = 0,
+ ):
+     """
+     Input should be in MLX format
+     """
+     ih, iw = x.shape[1:3]
+     pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
+     pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
+
+     # MLX pad format: [(low, high), (low, high), ...] for each axis
+     # Padding order is reversed compared to PyTorch F.pad
+     pad_widths = [
+         (0, 0),  # No padding for batch dimension
+         (pad_h // 2, pad_h - pad_h // 2),  # Height padding
+         (pad_w // 2, pad_w - pad_w // 2),  # Width padding
+         (0, 0),  # No padding for channel dimension
+     ]
+
+     x = mx.pad(x, pad_widths, constant_values=value)
+     return x
+
+
+ def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
+     dynamic = False
+     if isinstance(padding, str):
+         # for any string padding, the padding will be calculated for you, one of three ways
+         padding = padding.lower()
+         if padding == "same":
+             # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
+             if is_static_pad(kernel_size, **kwargs):
+                 # static case, no extra overhead
+                 padding = get_padding(kernel_size, **kwargs)
+             else:
+                 # dynamic 'SAME' padding, has runtime/GPU memory overhead
+                 padding = 0
+                 dynamic = True
+         elif padding == "valid":
+             # 'VALID' padding, same as padding=0
+             padding = 0
+         else:
+             # Default to PyTorch style 'same'-ish symmetric padding
+             padding = get_padding(kernel_size, **kwargs)
+     return padding, dynamic
+
+
+ def get_same_padding(
+     input_size: int, kernel_size: int, stride: int, dilation: int = 1
+ ) -> int:
+     """Calculate padding needed for 'same' output size."""
+     effective_kernel_size = dilation * (kernel_size - 1) + 1
+     output_size = (input_size + stride - 1) // stride
+     total_padding = max(
+         0, (output_size - 1) * stride + effective_kernel_size - input_size
+     )
+     return total_padding
+
+
+ def get_padding(kernel_size, stride=1, dilation=1, **_):
+     """Get symmetric padding for given kernel size."""
+     if isinstance(kernel_size, int):
+         kernel_size = [kernel_size, kernel_size]
+     if isinstance(stride, int):
+         stride = [stride, stride]
+     if isinstance(dilation, int):
+         dilation = [dilation, dilation]
+
+     padding = []
+     for k, d in zip(kernel_size, dilation):
+         effective_k = d * (k - 1) + 1
+         pad_total = effective_k - 1
+         padding.append(pad_total // 2)
+     return tuple(padding)
+
+
+ def is_static_pad(kernel_size, stride=1, dilation=1, **_):
+     """Check if padding can be calculated statically."""
+     if isinstance(kernel_size, int):
+         kernel_size = [kernel_size, kernel_size]
+     if isinstance(stride, int):
+         stride = [stride, stride]
+     if isinstance(dilation, int):
+         dilation = [dilation, dilation]
+
+     # Static padding is possible when stride is 1 for all dimensions
+     return all(s == 1 for s in stride)
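# Worked example, not part of the packaged file, assuming the helpers above are
# in scope. For input_size=7, kernel_size=3, stride=2: ceil(7/2) = 4 output
# positions, so total 'SAME' padding = (4 - 1) * 2 + 3 - 7 = 2, which pad_same
# splits as (1, 1).
assert get_same_padding(7, 3, 2) == 2
assert get_padding(3) == (1, 1)        # static symmetric padding for k=3, s=1
assert not is_static_pad(3, stride=2)  # stride > 1 forces dynamic 'SAME' pads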
+
+
+ class Conv2dSame(nn.Conv2d):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.kernel_size = self.weight.shape[1:3]
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = pad_same(x, self.kernel_size, self.stride, self.dilation)
+         y = mx.conv2d(
+             x, self.weight, self.stride, self.padding, self.dilation, self.groups
+         )
+         if "bias" in self:
+             y = y + self.bias
+         return y
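# Illustrative usage, not part of the packaged file: Conv2dSame pads the NHWC
# input so the output grid is ceil(input / stride), regardless of kernel size.
# Toy shapes assumed; dilation is passed as a tuple, matching how the classes
# above construct it:
import mlx.core as mx

conv = Conv2dSame(8, 16, kernel_size=3, stride=2, dilation=(1, 1))
x = mx.random.normal((1, 15, 15, 8))
print(conv(x).shape)  # (1, 8, 8, 16) since ceil(15 / 2) == 8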
401
+
402
+
403
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L629
404
+ class EdgeResidual(nn.Module):
405
+ def __init__(
406
+ self,
407
+ in_chs: int,
408
+ out_chs: int,
409
+ exp_kernel_size: int = 3,
410
+ stride: int = 1,
411
+ dilation: int = 1,
412
+ group_size: int = 0,
413
+ pad_type: str = "",
414
+ force_in_chs: int = 0,
415
+ noskip: bool = False,
416
+ expand_ratio: float = 1.0,
417
+ pw_kernel_size: int = 1,
418
+ norm_layer=RMSNormAct2d,
419
+ ):
420
+ super().__init__()
421
+
422
+ if force_in_chs > 0:
423
+ mid_chs = make_divisible(force_in_chs * expand_ratio)
424
+ else:
425
+ mid_chs = make_divisible(in_chs * expand_ratio)
426
+
427
+ groups = num_groups(group_size, mid_chs)
428
+
429
+ self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
430
+
431
+ self.conv_exp = Conv2dSame(
432
+ in_chs,
433
+ mid_chs,
434
+ kernel_size=exp_kernel_size,
435
+ stride=stride,
436
+ padding=0,
437
+ dilation=(dilation, dilation),
438
+ groups=groups,
439
+ bias=False,
440
+ )
441
+
442
+ self.bn1 = norm_layer(mid_chs, eps=1e-05) if norm_layer else nn.Identity()
443
+
444
+ # Point-wise linear projection
445
+ padding_pwl = (pw_kernel_size - 1) // 2
446
+ self.conv_pwl = nn.Conv2d(
447
+ mid_chs,
448
+ out_chs,
449
+ kernel_size=pw_kernel_size,
450
+ padding=padding_pwl,
451
+ bias=False,
452
+ )
453
+
454
+ self.bn2 = (
455
+ norm_layer(out_chs, eps=1e-05, apply_act=False)
456
+ if norm_layer
457
+ else nn.Identity()
458
+ )
459
+
460
+ def __call__(self, x: mx.array) -> mx.array:
461
+ shortcut = x
462
+ x = self.conv_exp(x)
463
+ x = self.bn1(x)
464
+ x = self.conv_pwl(x)
465
+ x = self.bn2(x)
466
+ if self.has_skip:
467
+ x = x + shortcut
468
+ return x
469
+
470
+
471
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L449
+class MobileAttention(nn.Module):
+    def __init__(
+        self,
+        in_chs: int,
+        out_chs: int,
+        stride: int = 1,
+        dw_kernel_size: int = 3,
+        dilation: int = 1,
+        group_size: int = 1,
+        pad_type: str = "",
+        num_heads: int = 8,
+        key_dim: int = 64,
+        value_dim: int = 64,
+        use_multi_query: bool = True,
+        query_strides: Tuple[int, int] = (1, 1),
+        kv_stride: int = 1,
+        cpe_dw_kernel_size: int = 3,
+        noskip: bool = False,
+        act_layer=nn.GELU,
+        aa_layer=None,
+        drop_path_rate: float = 0.0,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+        layer_scale_init_value: Optional[float] = 1e-5,
+        use_bias: bool = False,
+    ):
+        super().__init__()
+        self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
+        self.query_strides = to_2tuple(query_strides)
+        self.kv_stride = kv_stride
+        self.has_query_stride = any(s > 1 for s in self.query_strides)
+
+        # Normalization layer
+        self.norm = RMSNormAct2d(
+            in_chs,
+            eps=1e-05,
+            apply_act=False,
+        )
+        # Determine number of heads if not provided
+        if num_heads is None:
+            assert in_chs % key_dim == 0
+            num_heads = in_chs // key_dim
+
+        # Attention layer
+        if use_multi_query:
+            self.attn = MultiQueryAttention2d(
+                in_chs,
+                dim_out=out_chs,
+                num_heads=num_heads,
+                key_dim=key_dim,
+                value_dim=value_dim,
+                query_strides=query_strides,
+                kv_stride=kv_stride,
+                dilation=dilation,
+                padding=pad_type,
+                dw_kernel_size=dw_kernel_size,
+                attn_drop=attn_drop,
+                proj_drop=proj_drop,
+            )
+        else:
+            raise NotImplementedError("attention not implemented")
+
+        # Layer scaling
+        if layer_scale_init_value is not None:
+            self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
+        else:
+            self.layer_scale = nn.Identity()
+
+        # Drop path for residual connection
+        self.drop_path = nn.Identity()
+
+    def __call__(self, x: mx.array) -> mx.array:
+        shortcut = x
+        x = self.norm(x)
+        x = self.attn(x)
+        x = self.layer_scale(x)
+
+        # Apply skip connection if available
+        if self.has_skip:
+            x = self.drop_path(x) + shortcut
+        return x
+
+
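The block above is a pre-norm residual unit: RMS norm, then attention, then layer scale, with the shortcut added only when the shape is preserved. A hedged usage sketch, mirroring how `VisionTower.build()` instantiates it below:

import mlx.core as mx

attn_blk = MobileAttention(
    in_chs=640, out_chs=640, num_heads=12, key_dim=64, value_dim=64, kv_stride=2
)
x = mx.random.normal((1, 24, 24, 640))
y = attn_blk(x)
assert y.shape == x.shape  # stride == 1 and in_chs == out_chs, so has_skip is True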
+def create_conv2d(
+    in_channels,
+    out_channels,
+    kernel_size,
+    stride=1,
+    dilation=1,
+    depthwise=False,
+    bias=False,
+    **kwargs,
+):
+    """Helper function to create a 2D convolution with common parameters."""
+    if depthwise:
+        # Depthwise convolution
+        return nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=(kernel_size - 1) // 2 * dilation,
+            dilation=dilation,
+            groups=in_channels,
+            bias=bias,
+        )
+    else:
+        # Regular convolution
+        return nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=(kernel_size - 1) // 2 * dilation,
+            dilation=dilation,
+            bias=bias,
+        )
+
+
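Both branches use symmetric padding `(kernel_size - 1) // 2 * dilation`, which preserves spatial size only for odd kernels at stride 1; `depthwise=True` simply sets `groups=in_channels`. For example:

dw = create_conv2d(32, 32, kernel_size=3, depthwise=True)  # 32 groups: each channel convolved independently
pw = create_conv2d(32, 64, kernel_size=1)                  # 1x1 pointwise projection, 32 -> 64 channels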
+def to_2tuple(x):
+    """Convert input to a 2-tuple."""
+    if isinstance(x, tuple):
+        return x
+    return (x, x)
+
+
+class NamedSequential(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._order = []
+
+    def add_module(self, name, module):
+        setattr(self, name, module)
+        self._order.append(name)
+
+    def __call__(self, x):
+        for name in self._order:
+            x = getattr(self, name)(x)
+        return x
+
+
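`NamedSequential` stands in for PyTorch's named `nn.Sequential`, so parameter paths such as `key.down_conv.weight` and `key.proj.weight` line up with the upstream timm checkpoint, while `__call__` applies modules in insertion order. Usage sketch:

seq = NamedSequential()
seq.add_module("down_conv", create_conv2d(64, 64, kernel_size=3, stride=2, depthwise=True))
seq.add_module("proj", create_conv2d(64, 32, kernel_size=1))
# seq(x) runs down_conv first, then proj, exactly in add_module order.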
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/attention2d.py#L82
+class MultiQueryAttention2d(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        num_heads: int = 8,
+        key_dim: int = 64,
+        value_dim: int = 64,
+        query_strides: Tuple[int, int] = (1, 1),
+        kv_stride: int = 1,
+        dilation: int = 1,
+        padding: str = "",
+        dw_kernel_size: int = 3,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ):
+        super().__init__()
+        dim_out = dim_out or dim
+        self.num_heads = num_heads
+        self.query_strides = to_2tuple(query_strides)
+        self.kv_stride = kv_stride
+        self.fused_attn = True
+        self.key_dim = key_dim
+        self.value_dim = value_dim
+        head_dim = key_dim
+        self.scale = head_dim**-0.5
+
+        self.query = NamedSequential()
+        self.query.add_module(
+            "proj",
+            create_conv2d(
+                dim,
+                self.num_heads * self.key_dim,
+                kernel_size=1,
+            ),
+        )
+        self.key = NamedSequential()
+        if kv_stride > 1:
+            self.key.add_module(
+                "down_conv",
+                create_conv2d(
+                    dim,
+                    dim,
+                    kernel_size=dw_kernel_size,
+                    stride=kv_stride,
+                    dilation=dilation,
+                    padding=padding,
+                    depthwise=True,
+                ),
+            )
+        self.key.add_module("norm", RMSNormAct2d(dim, eps=1e-6, apply_act=False))
+        self.key.add_module(
+            "proj", create_conv2d(dim, key_dim, kernel_size=1, bias=False)
+        )
+
+        self.value = NamedSequential()
+        if kv_stride > 1:
+            self.value.add_module(
+                "down_conv",
+                create_conv2d(
+                    dim,
+                    dim,
+                    kernel_size=dw_kernel_size,
+                    stride=kv_stride,
+                    dilation=dilation,
+                    padding=padding,
+                    depthwise=True,
+                ),
+            )
+        self.value.add_module("norm", RMSNormAct2d(dim, eps=1e-6, apply_act=False))
+        self.value.add_module(
+            "proj", create_conv2d(dim, value_dim, kernel_size=1, bias=False)
+        )
+
+        # Attention dropout
+        self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0 else nn.Identity()
+
+        # Output projection
+        self.output = NamedSequential()
+        self.output.add_module(
+            "proj",
+            create_conv2d(
+                value_dim * num_heads,
+                dim_out,
+                kernel_size=1,
+                stride=1,
+                bias=False,
+            ),
+        )
+        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0 else nn.Identity()
+
+    def _reshape_input(self, t: mx.array):
+        """
+        Input shape MLX: [B, H, W, C]
+        Input shape PyTorch: [B, C, H, W]
+
+        PyTorch reshape: [B, C, H, W] -> [B, C, -1] -> [B, -1, C] -> [B, 1, -1, C] -> SDPA
+        MLX reshape: [B, H, W, C] -> [B, -1, C] -> [B, 1, -1, C] -> SDPA
+        """
+        s = t.shape
+        return t.reshape(s[0], -1, s[3])[:, None, :, :]
+
+    def _reshape_projected_query(self, t: mx.array, num_heads: int, key_dim: int):
+        """
+        Input shape MLX: [B, H, W, C] where C = num_heads * key_dim
+        Output shape: [B, num_heads, H * W, key_dim]
+        """
+        B, H, W, C = t.shape
+        t = t.reshape(B, H * W, num_heads, key_dim)
+        return t.transpose(0, 2, 1, 3)
+
+    def _reshape_output(self, t: mx.array, num_heads: int, h_px: int, w_px: int):
+        """
+        Input shape: [B, NH, L, D] where L = h_px * w_px
+        Output shape MLX: [B, H, W, C] where C = NH * D
+        """
+        B, NH, L, D = t.shape
+        # First transpose to [B, L, NH, D]
+        t = t.transpose(0, 2, 1, 3)
+        # Then reshape to [B, H, W, NH*D]
+        t = t.reshape(B, h_px, w_px, NH * D)
+        return t
+
+    def __call__(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array:
+        B, H, W, C = x.shape
+        q = self.query(x)
+        q = self._reshape_projected_query(q, self.num_heads, self.key_dim)
+
+        k = self.key(x)
+        k = self._reshape_input(k)
+
+        v = self.value(x)
+        v = self._reshape_input(v)
+
+        if self.fused_attn:
+            o = mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=1.0 / sqrt(q.shape[-1]),
+            )
+        else:
+            raise NotImplementedError("unfused attention not implemented")
+
+        o = self._reshape_output(
+            o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1]
+        )
+        x = self.output(o)
+        return x
+
+
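This is multi-query attention: each head has its own query projection, while a single shared key/value head (`[B, 1, L, D]` after `_reshape_input`) is broadcast across the query heads inside the fused kernel. A standalone shape sketch with arbitrary sizes (assumes MLX's grouped-head broadcasting in `mx.fast.scaled_dot_product_attention`):

import mlx.core as mx

q = mx.random.normal((1, 12, 576, 64))  # [B, num_heads, H*W, key_dim]
k = mx.random.normal((1, 1, 144, 64))   # one shared KV head; kv_stride=2 halved H and W
v = mx.random.normal((1, 1, 144, 64))
o = mx.fast.scaled_dot_product_attention(q, k, v, scale=64 ** -0.5)
assert o.shape == (1, 12, 576, 64)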
+def num_groups(group_size: Optional[int], channels: int) -> int:
+    if not group_size:  # 0 or None
+        return 1  # normal conv with 1 group
+    else:
+        # NOTE: group_size == 1 -> depthwise conv
+        assert channels % group_size == 0
+        return channels // group_size
+
+
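`group_size` is channels-per-group, so the returned group count is its inverse:

assert num_groups(0, 64) == 1    # plain convolution, one group
assert num_groups(1, 64) == 64   # depthwise: one channel per group
assert num_groups(16, 64) == 4   # grouped conv, 16 channels per group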
+def make_divisible(v, divisor: int = 8, min_value=None, round_limit: float = 0.9):
+    min_value = min_value or divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that rounding down does not remove more than 10%.
+    if new_v < round_limit * v:
+        new_v += divisor
+    return new_v
+
+
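Worked example with the default divisor of 8: the value is rounded to the nearest multiple, and the guard adds a divisor back whenever rounding down lost more than 10%:

assert make_divisible(100.8) == 104  # int(100.8 + 4) // 8 * 8 = 104, within 10%
assert make_divisible(8.9) == 16     # rounds to 8, but 8 < 0.9 * 8.9, so bump to 16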
+@dataclass(frozen=True)
+class EdgeResidualConfig:
+    kernel_size: int = 3
+    filters: int = 32
+    strides: int = 1
+    expand_ratio: float = 4.0
+    is_multiscale: bool = False
+
+
+def _er(kernel_size, filters, strides=1, expand_ratio=4.0, is_multiscale=False):
+    return EdgeResidualConfig(
+        kernel_size=kernel_size,
+        filters=filters,
+        strides=strides,
+        expand_ratio=expand_ratio,
+        is_multiscale=is_multiscale,
+    )
+
+
+@dataclass(frozen=True)
+class UniversalInvertedResidualConfig:
+    start_dw_kernel_size: int = 0  # Zero size means no conv
+    mid_dw_kernel_size: int = 0  # Zero size means no conv
+    filters: int = 32
+    strides: int = 1
+    expand_ratio: float = 4.0
+    is_multiscale: bool = False
+
+
+def _uir(
+    start_dw_kernel_size,
+    mid_dw_kernel_size,
+    filters,
+    strides=1,
+    expand_ratio=4.0,
+    is_multiscale=False,
+):
+    return UniversalInvertedResidualConfig(
+        start_dw_kernel_size=start_dw_kernel_size,
+        mid_dw_kernel_size=mid_dw_kernel_size,
+        filters=filters,
+        strides=strides,
+        expand_ratio=expand_ratio,
+        is_multiscale=is_multiscale,
+    )
+
+
+@dataclass(frozen=True)
+class MultiQueryAttentionBlockConfig:
+    num_heads: int = 8
+    kv_dim: int = 16
+    kv_strides: int = 1
+    mmqa_avg_pool_kv: bool = False
+    mmqa_dropout: float = 0.0
+    mmqa_dw_kernel_size: int = 3
+    is_multiscale: bool = False
+
+
+def _mmqa(
+    num_heads,
+    kv_dim,
+    kv_strides,
+    mmqa_avg_pool_kv=False,
+    is_multiscale=False,
+):
+    conf = MultiQueryAttentionBlockConfig(
+        num_heads=num_heads,
+        kv_dim=kv_dim,
+        kv_strides=kv_strides,
+        mmqa_avg_pool_kv=mmqa_avg_pool_kv,
+        is_multiscale=is_multiscale,
+    )
+    return conf
+
+
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/mobilenetv5.py#L596
+def gemma3n_mobilenet_def():
+    return [
+        # Stage 1: Edge Residuals
+        [_er(3, 128, 2)] + [_er(3, 128, 1)] * 2,
+        # Stage 2: Universal Inverted Residuals
+        [_uir(3, 5, 256, 2, 6.0)] + [_uir(k, 0, 256) for k in [5, 3, 5, 3]],
+        # Stage 3: Universal Inverted Residuals with Multi-Query Attention
+        [_uir(5, 5, 640, 2, 6.0)]
+        + [_uir(5, 0, 640)] * 7
+        + [_uir(0, 0, 640, 1, 1.0)]
+        + [_mmqa(12, 64, 2), _uir(0, 0, 640, 1, 2.0)] * 13
+        + [_mmqa(12, 64, 2), _uir(0, 0, 640, 1, 2.0, is_multiscale=True)],
+        # Stage 4: Universal Inverted Residuals with Multi-Query Attention
+        [_uir(5, 5, 1280, 2, 6.0)]
+        + [_mmqa(16, 96, 1), _uir(0, 0, 1280, 1, 2.0)] * 18
+        + [_mmqa(16, 96, 1), _uir(0, 0, 1280, 1, 2.0, is_multiscale=True)],
+    ]
+
+
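Since each stage is plain list concatenation, the stage depths can be checked directly; under the definition above the expected counts are:

stages = gemma3n_mobilenet_def()
assert [len(s) for s in stages] == [3, 5, 37, 39]
# Stage 3: 1 + 7 + 1 + 13 * 2 + 2 = 37; Stage 4: 1 + 18 * 2 + 2 = 39.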
+class VisionTower(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.conv_stem = ConvNormAct(
+            Conv2dSame,
+            in_chs=3,
+            out_chs=64,
+            kernel_size=3,
+            stride=2,
+            padding=0,
+            eps=1e-05,
+            bias=True,
+        )
+        msfa_indices = (3, 4)
+        msfa_output_resolution = (16, 16)
+
+        (num_features, self.blocks) = self.build()
+        self.num_features = self.head_hidden_size = (
+            num_features  # output of msfa is output of forward_features()
+        )
+        self.msfa_indices = msfa_indices
+        self.msfa_output_resolution = msfa_output_resolution
+
+        self.msfa = MobileNetV5MultiScaleFusionAdapter(
+            in_chs=[1920],
+            out_chs=2048,
+            output_resolution=self.msfa_output_resolution,
+        )
+
+    def build(self):
+        blocks = []
+        in_chs = self.conv_stem.out_chs
+        for stage, block_config in enumerate(gemma3n_mobilenet_def()):
+            block_group = []
+            for config in block_config:
+                match config:
+                    case EdgeResidualConfig(
+                        kernel_size, filters, strides, expand_ratio, is_multiscale
+                    ):
+                        x = EdgeResidual(
+                            exp_kernel_size=kernel_size,
+                            in_chs=in_chs,
+                            out_chs=filters,
+                            stride=strides,
+                            expand_ratio=expand_ratio,
+                        )
+                        in_chs = filters  # in_chs of next is out_chs of prev
+                        block_group.append(x)
+                    case UniversalInvertedResidualConfig(
+                        start_dw_kernel_size,
+                        mid_dw_kernel_size,
+                        filters,
+                        strides,
+                        expand_ratio,
+                        is_multiscale,
+                    ):
+                        x = UniversalInvertedResidual(
+                            in_chs=in_chs,
+                            out_chs=filters,
+                            dw_kernel_size_start=start_dw_kernel_size,
+                            dw_kernel_size_mid=mid_dw_kernel_size,
+                            stride=strides,
+                            exp_ratio=expand_ratio,
+                        )
+                        in_chs = filters
+                        block_group.append(x)
+                    case MultiQueryAttentionBlockConfig(
+                        num_heads,
+                        kv_dim,
+                        kv_strides,
+                        mmqa_avg_pool_kv,
+                        is_multiscale,
+                    ):
+                        x = MobileAttention(
+                            in_chs=in_chs,
+                            out_chs=in_chs,
+                            stride=1,
+                            num_heads=num_heads,
+                            key_dim=kv_dim,
+                            value_dim=kv_dim,
+                            kv_stride=kv_strides,
+                            act_layer=None,
+                        )
+                        block_group.append(x)
+                    case _:
+                        continue
+            blocks.append(block_group)
+        return (in_chs, blocks)
+
+    def __call__(
+        self, x: mx.array, output_hidden_states: Optional[bool] = None
+    ) -> mx.array:
+        feat_idx = 0
+        x = x.transpose(0, 2, 3, 1)  # Convert from NCHW to NHWC
+        x = self.conv_stem(x)
+        intermediates = []
+
+        if feat_idx in self.msfa_indices:
+            intermediates.append(x)
+
+        # MBV5 is constructed of 4 stages; each stage is a group of blocks.
+        for block_group in self.blocks:
+            feat_idx += 1
+            for block in block_group:
+                x = block(x)
+
+            if feat_idx in self.msfa_indices:
+                intermediates.append(x)
+
+        x = self.msfa(intermediates)
+        return x
+
+
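An end-to-end sketch of the tower, assuming a 768x768 RGB input (the resolution used for Gemma 3n vision) and a suitable `VisionConfig`; the stem and the four stages each halve the resolution, and the fusion adapter resamples to the fixed 16x16 grid:

import mlx.core as mx

tower = VisionTower(config)                # config: a VisionConfig instance
img = mx.random.normal((1, 3, 768, 768))   # NCHW, as received from the caller
feats = tower(img)
# 768 -> 384 (stem) -> 192 -> 96 -> 48 -> 24, then the MSFA fuses the
# stage-3/stage-4 features into a (1, 16, 16, 2048) NHWC grid.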
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        if self.model_type not in ["gemma3", "gemma3_vision", "gemma3n_vision"]:
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+
+        self.timm_model = VisionTower(config)
+
+    def __call__(
+        self, x: mx.array, output_hidden_states: Optional[bool] = None
+    ) -> mx.array:
+        return self.timm_model(x, output_hidden_states)
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        skip_transpose = False
+        _, H, _, C = weights["vision_tower.timm_model.blocks.0.0.conv_exp.weight"].shape
+        if C > H:
+            # Axis 3 is already the (larger) channel axis, so the weights are
+            # already in MLX layout and must not be transposed again.
+            skip_transpose = True
+
+        for k, v in weights.items():
+            # PyTorch conv2d weight: [out_channels, in_channels, kH, kW]
+            # MLX conv2d weight: [out_channels, kH, kW, in_channels]
+            if ("conv" in k and "weight" in k) or ("attn" in k and "proj.weight" in k):
+                if len(v.shape) == 4 and not skip_transpose:
+                    v = v.transpose(0, 2, 3, 1)
+            sanitized_weights[k] = v
+
+        return sanitized_weights
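The transpose in `sanitize` is the usual PyTorch-to-MLX conv-weight permutation; a self-contained illustration with a dummy array:

import mlx.core as mx

w_pt = mx.zeros((64, 32, 3, 3))     # PyTorch layout: [O, I, kH, kW]
w_mlx = w_pt.transpose(0, 2, 3, 1)  # MLX layout: [O, kH, kW, I]
assert w_mlx.shape == (64, 3, 3, 32)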