nexaai-1.0.4rc10-py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. Click here for more details.

Files changed (519)
  1. nexaai/__init__.py +71 -0
  2. nexaai/_version.py +4 -0
  3. nexaai/asr.py +60 -0
  4. nexaai/asr_impl/__init__.py +0 -0
  5. nexaai/asr_impl/mlx_asr_impl.py +91 -0
  6. nexaai/asr_impl/pybind_asr_impl.py +43 -0
  7. nexaai/base.py +39 -0
  8. nexaai/binds/__init__.py +3 -0
  9. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  10. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/libnexa_bridge.dylib +0 -0
  12. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  13. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  14. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  15. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  16. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  17. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  18. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  19. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  20. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  21. nexaai/binds/nexa_mlx/py-lib/ml.py +842 -0
  22. nexaai/binds/nexa_mlx/py-lib/mlx_audio/__init__.py +0 -0
  23. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/__init__.py +1 -0
  24. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  25. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  26. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  27. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  28. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  29. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  30. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  31. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  32. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  33. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  34. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  35. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  36. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  37. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  38. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  39. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  40. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  41. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  42. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  43. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  44. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  45. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  46. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  47. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  48. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  49. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  50. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  51. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  52. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  53. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  54. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  55. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  56. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  57. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  58. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  59. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  60. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  61. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  62. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  63. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  64. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  65. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  66. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  67. nexaai/binds/nexa_mlx/py-lib/mlx_audio/server.py +525 -0
  68. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/__init__.py +0 -0
  69. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  70. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  71. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/__init__.py +0 -0
  72. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/generate.py +174 -0
  73. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  74. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  75. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  76. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  77. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  78. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  79. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  80. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  81. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  82. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  83. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  84. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  85. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  86. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  87. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  88. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  89. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  90. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  91. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  92. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  93. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/utils.py +195 -0
  94. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/__init__.py +1 -0
  95. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/audio_player.py +120 -0
  96. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/convert.py +71 -0
  97. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/generate.py +449 -0
  98. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  99. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  100. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  101. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  102. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  103. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/base.py +84 -0
  104. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  105. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  106. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  107. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  108. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  109. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  110. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  111. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  112. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  113. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  114. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  115. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  116. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  117. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  118. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  119. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  120. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  121. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  122. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  123. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  124. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  125. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  126. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  127. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  128. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  129. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  130. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  131. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  132. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  133. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  134. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  135. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  136. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  137. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  138. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  139. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  140. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  141. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  142. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  143. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  144. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  145. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  146. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  147. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  148. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  149. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  150. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  151. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  152. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  153. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  154. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  155. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  156. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  157. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  158. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  159. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  160. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  161. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  162. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  163. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  164. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  165. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  166. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  167. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  168. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  169. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/utils.py +337 -0
  170. nexaai/binds/nexa_mlx/py-lib/mlx_audio/utils.py +237 -0
  171. nexaai/binds/nexa_mlx/py-lib/mlx_audio/version.py +1 -0
  172. nexaai/binds/nexa_mlx/py-lib/profiling.py +239 -0
  173. nexaai/common.py +61 -0
  174. nexaai/cv.py +87 -0
  175. nexaai/cv_impl/__init__.py +0 -0
  176. nexaai/cv_impl/mlx_cv_impl.py +88 -0
  177. nexaai/cv_impl/pybind_cv_impl.py +31 -0
  178. nexaai/embedder.py +68 -0
  179. nexaai/embedder_impl/__init__.py +0 -0
  180. nexaai/embedder_impl/mlx_embedder_impl.py +114 -0
  181. nexaai/embedder_impl/pybind_embedder_impl.py +91 -0
  182. nexaai/image_gen.py +136 -0
  183. nexaai/image_gen_impl/__init__.py +0 -0
  184. nexaai/image_gen_impl/mlx_image_gen_impl.py +291 -0
  185. nexaai/image_gen_impl/pybind_image_gen_impl.py +84 -0
  186. nexaai/llm.py +89 -0
  187. nexaai/llm_impl/__init__.py +0 -0
  188. nexaai/llm_impl/mlx_llm_impl.py +249 -0
  189. nexaai/llm_impl/pybind_llm_impl.py +207 -0
  190. nexaai/mlx_backend/asr/__init__.py +12 -0
  191. nexaai/mlx_backend/asr/interface.py +122 -0
  192. nexaai/mlx_backend/common/__init__.py +0 -0
  193. nexaai/mlx_backend/common/utils.py +25 -0
  194. nexaai/mlx_backend/cv/__init__.py +0 -0
  195. nexaai/mlx_backend/cv/generate.py +195 -0
  196. nexaai/mlx_backend/cv/interface.py +151 -0
  197. nexaai/mlx_backend/cv/main.py +81 -0
  198. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  199. nexaai/mlx_backend/embedding/__init__.py +0 -0
  200. nexaai/mlx_backend/embedding/generate.py +130 -0
  201. nexaai/mlx_backend/embedding/interface.py +312 -0
  202. nexaai/mlx_backend/embedding/main.py +82 -0
  203. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  204. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  205. nexaai/mlx_backend/llm/__init__.py +0 -0
  206. nexaai/mlx_backend/llm/generate.py +149 -0
  207. nexaai/mlx_backend/llm/interface.py +764 -0
  208. nexaai/mlx_backend/llm/main.py +68 -0
  209. nexaai/mlx_backend/ml.py +842 -0
  210. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  211. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  212. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  213. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  214. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  215. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  216. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  217. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  218. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  219. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  220. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  221. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  222. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  223. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  224. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  225. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  226. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  227. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  228. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  229. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  230. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  231. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  232. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  233. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  234. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  235. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  236. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  237. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  238. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  239. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  240. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  241. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  242. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  243. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  244. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  245. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  246. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  247. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  249. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  250. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  251. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  252. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  253. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  254. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  255. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  256. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  257. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  258. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  259. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  260. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  261. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  262. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  264. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  265. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  266. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  267. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  268. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  269. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  270. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  271. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  272. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  273. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  274. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  275. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  276. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  277. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  278. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  279. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  280. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  281. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  282. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  283. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  284. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  285. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  286. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  287. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  288. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  289. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  290. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  291. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  292. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  293. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  294. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  295. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  296. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  297. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  298. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  299. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  300. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  301. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  302. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  303. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  304. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  305. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  306. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  307. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  308. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  309. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  310. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  311. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  312. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  313. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  314. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  315. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  316. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  317. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  318. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  319. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  320. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  321. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  322. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  353. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  354. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  355. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  356. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  357. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  358. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  359. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  360. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  361. nexaai/mlx_backend/profiling.py +239 -0
  362. nexaai/mlx_backend/rerank/__init__.py +0 -0
  363. nexaai/mlx_backend/rerank/generate.py +174 -0
  364. nexaai/mlx_backend/rerank/interface.py +287 -0
  365. nexaai/mlx_backend/rerank/main.py +127 -0
  366. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  367. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  368. nexaai/mlx_backend/sd/__init__.py +1 -0
  369. nexaai/mlx_backend/sd/interface.py +362 -0
  370. nexaai/mlx_backend/sd/main.py +286 -0
  371. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  372. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  373. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  374. nexaai/mlx_backend/sd/modeling/model_io.py +330 -0
  375. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  376. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  377. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  378. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  379. nexaai/mlx_backend/tts/__init__.py +12 -0
  380. nexaai/mlx_backend/tts/interface.py +276 -0
  381. nexaai/mlx_backend/vlm/__init__.py +3 -0
  382. nexaai/mlx_backend/vlm/generate.py +572 -0
  383. nexaai/mlx_backend/vlm/interface.py +406 -0
  384. nexaai/mlx_backend/vlm/main.py +157 -0
  385. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  386. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  387. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  388. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  389. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  390. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  391. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  392. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  393. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  394. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  395. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  396. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  397. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  398. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  399. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  400. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  401. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  402. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  403. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  404. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  405. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  406. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  407. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  408. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  409. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  410. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  411. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  412. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  413. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  414. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  415. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  416. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  417. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  418. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  419. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  420. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  421. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  422. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  423. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  424. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  425. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  426. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  427. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  428. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  429. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  430. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  431. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  432. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  433. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  434. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  435. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  436. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  437. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  438. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  439. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  440. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  441. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  442. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  443. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  444. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  445. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  446. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  447. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  448. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  449. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  450. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  451. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  452. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  453. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  454. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  455. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  456. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  457. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  458. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  459. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  460. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  461. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  462. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  463. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  464. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  465. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  466. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  467. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  468. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  470. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  471. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  472. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  473. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  474. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  475. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  476. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  477. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  478. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  479. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  480. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  481. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  482. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  483. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  484. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  485. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  486. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  487. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  488. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  489. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  490. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  491. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  492. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  493. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  494. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  495. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  496. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  497. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  498. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  499. nexaai/rerank.py +51 -0
  500. nexaai/rerank_impl/__init__.py +0 -0
  501. nexaai/rerank_impl/mlx_rerank_impl.py +91 -0
  502. nexaai/rerank_impl/pybind_rerank_impl.py +42 -0
  503. nexaai/runtime.py +64 -0
  504. nexaai/tts.py +70 -0
  505. nexaai/tts_impl/__init__.py +0 -0
  506. nexaai/tts_impl/mlx_tts_impl.py +93 -0
  507. nexaai/tts_impl/pybind_tts_impl.py +42 -0
  508. nexaai/utils/avatar_fetcher.py +104 -0
  509. nexaai/utils/decode.py +18 -0
  510. nexaai/utils/model_manager.py +1195 -0
  511. nexaai/utils/progress_tracker.py +372 -0
  512. nexaai/vlm.py +120 -0
  513. nexaai/vlm_impl/__init__.py +0 -0
  514. nexaai/vlm_impl/mlx_vlm_impl.py +205 -0
  515. nexaai/vlm_impl/pybind_vlm_impl.py +228 -0
  516. nexaai-1.0.4rc10.dist-info/METADATA +26 -0
  517. nexaai-1.0.4rc10.dist-info/RECORD +519 -0
  518. nexaai-1.0.4rc10.dist-info/WHEEL +5 -0
  519. nexaai-1.0.4rc10.dist-info/top_level.txt +1 -0
@@ -0,0 +1,239 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from dataclasses import dataclass, field
5
+ from typing import Any, Optional
6
+ from enum import IntEnum
7
+
8
+ # --------------------------------------------------------------------------------------
9
+ # Stop reason constants matching profile.h
10
+ # --------------------------------------------------------------------------------------
11
+
12
+ class StopReason(IntEnum):
13
+ """Stop reason constants matching profile.h"""
14
+ ML_STOP_REASON_UNKNOWN = 0
15
+ ML_STOP_REASON_EOS = 1
16
+ ML_STOP_REASON_LENGTH = 2
17
+ ML_STOP_REASON_USER = 3
18
+ ML_STOP_REASON_STOP_SEQUENCE = 4
19
+ ML_STOP_REASON_COMPLETED = 5
20
+
21
+ # --------------------------------------------------------------------------------------
22
+ # Profiling data structure
23
+ # --------------------------------------------------------------------------------------
24
+
25
+ @dataclass
26
+ class ProfilingData:
27
+ """Profiling data for performance metrics."""
28
+ ttft_us: int = 0 # Time to first token (us)
29
+ total_time_us: int = 0 # Total generation time (us)
30
+ prompt_time_us: int = 0 # Prompt processing time (us)
31
+ decode_time_us: int = 0 # Token generation time (us)
32
+ tokens_per_second: float = 0.0 # Decoding speed (tokens/sec)
33
+ total_tokens: int = 0 # Total tokens generated
34
+ prompt_tokens: int = 0 # Number of prompt tokens
35
+ generated_tokens: int = 0 # Number of generated tokens
36
+ stop_reason: int = StopReason.ML_STOP_REASON_UNKNOWN # Stop reason (numeric)
37
+
38
+ def reset(self):
39
+ """Reset all profiling data."""
40
+ self.ttft_us = 0
41
+ self.total_time_us = 0
42
+ self.prompt_time_us = 0
43
+ self.decode_time_us = 0
44
+ self.tokens_per_second = 0.0
45
+ self.total_tokens = 0
46
+ self.prompt_tokens = 0
47
+ self.generated_tokens = 0
48
+ self.stop_reason = StopReason.ML_STOP_REASON_UNKNOWN
49
+
50
+ # --------------------------------------------------------------------------------------
51
+ # Profiling context (similar to ml_ProfilingContext in profile.h)
52
+ # --------------------------------------------------------------------------------------
53
+
54
+ @dataclass
55
+ class ProfilingContext:
56
+ """Profiling context for tracking timing and state."""
57
+ start_time: Optional[float] = None
58
+ prompt_start_time: Optional[float] = None
59
+ prompt_end_time: Optional[float] = None
60
+ decode_start_time: Optional[float] = None
61
+ decode_end_time: Optional[float] = None
62
+ first_token_time: Optional[float] = None
63
+ end_time: Optional[float] = None
64
+
65
+ ttft_recorded: bool = False
66
+ stop_reason: int = StopReason.ML_STOP_REASON_UNKNOWN
67
+ prompt_tokens: int = 0
68
+ generated_tokens: int = 0
69
+
70
+ def reset(self):
71
+ """Reset profiling context."""
72
+ self.start_time = None
73
+ self.prompt_start_time = None
74
+ self.prompt_end_time = None
75
+ self.decode_start_time = None
76
+ self.decode_end_time = None
77
+ self.first_token_time = None
78
+ self.end_time = None
79
+ self.ttft_recorded = False
80
+ self.stop_reason = StopReason.ML_STOP_REASON_UNKNOWN
81
+ self.prompt_tokens = 0
82
+ self.generated_tokens = 0
83
+
84
+ # --------------------------------------------------------------------------------------
85
+ # Profiling functions (similar to profile.h functions)
86
+ # --------------------------------------------------------------------------------------
87
+
88
+ def profiling_reset(ctx: ProfilingContext) -> None:
89
+ """Reset profiling context (ml_profiling_reset)."""
90
+ ctx.reset()
91
+
92
+ def profiling_start(ctx: ProfilingContext) -> None:
93
+ """Start profiling (ml_profiling_start)."""
94
+ ctx.start_time = time.perf_counter()
95
+ ctx.prompt_start_time = ctx.start_time
96
+
97
+ def profiling_prompt_start(ctx: ProfilingContext) -> None:
98
+ """Start prompt processing timing (ml_profiling_prompt_start)."""
99
+ ctx.prompt_start_time = time.perf_counter()
100
+
101
+ def profiling_prompt_end(ctx: ProfilingContext) -> None:
102
+ """End prompt processing timing (ml_profiling_prompt_end)."""
103
+ ctx.prompt_end_time = time.perf_counter()
104
+
105
+ def profiling_decode_start(ctx: ProfilingContext) -> None:
106
+ """Start decode timing (ml_profiling_decode_start)."""
107
+ ctx.decode_start_time = time.perf_counter()
108
+
109
+ def profiling_decode_end(ctx: ProfilingContext) -> None:
110
+ """End decode timing (ml_profiling_decode_end)."""
111
+ ctx.decode_end_time = time.perf_counter()
112
+
113
+ def profiling_record_ttft(ctx: ProfilingContext) -> None:
114
+ """Record time to first token (ml_profiling_record_ttft)."""
115
+ if not ctx.ttft_recorded and ctx.start_time is not None:
116
+ ctx.first_token_time = time.perf_counter()
117
+ ctx.ttft_recorded = True
118
+
119
+ def profiling_update_prompt_tokens(ctx: ProfilingContext, prompt_tokens: int) -> None:
120
+ """Update prompt token count (ml_profiling_update_prompt_tokens)."""
121
+ ctx.prompt_tokens = prompt_tokens
122
+
123
+ def profiling_update_generated_tokens(ctx: ProfilingContext, generated_tokens: int) -> None:
124
+ """Update generated token count (ml_profiling_update_generated_tokens)."""
125
+ ctx.generated_tokens = generated_tokens
126
+
127
+ def profiling_stop_reason(ctx: ProfilingContext, stop_reason: int) -> None:
128
+ """Set stop reason (ml_profiling_stop_reason)."""
129
+ ctx.stop_reason = stop_reason
130
+
131
+ def profiling_end(ctx: ProfilingContext) -> None:
132
+ """End profiling (ml_profiling_end)."""
133
+ ctx.end_time = time.perf_counter()
134
+
135
+ def profiling_gen_data(ctx: ProfilingContext) -> ProfilingData:
136
+ """Generate profiling data from context (ml_profiling_gen_data)."""
137
+ data = ProfilingData()
138
+
139
+ if ctx.start_time is None or ctx.end_time is None:
140
+ return data
141
+
142
+ # Calculate total time
143
+ data.total_time_us = int((ctx.end_time - ctx.start_time) * 1_000_000)
144
+
145
+ # Calculate prompt time
146
+ if ctx.prompt_start_time is not None and ctx.prompt_end_time is not None:
147
+ data.prompt_time_us = int((ctx.prompt_end_time - ctx.prompt_start_time) * 1_000_000)
148
+
149
+ # Calculate decode time
150
+ if ctx.decode_start_time is not None and ctx.decode_end_time is not None:
151
+ data.decode_time_us = int((ctx.decode_end_time - ctx.decode_start_time) * 1_000_000)
152
+
153
+ # Calculate TTFT
154
+ if ctx.first_token_time is not None and ctx.start_time is not None:
155
+ data.ttft_us = int((ctx.first_token_time - ctx.start_time) * 1_000_000)
156
+
157
+ # Set token counts
158
+ data.prompt_tokens = ctx.prompt_tokens
159
+ data.generated_tokens = ctx.generated_tokens
160
+ data.total_tokens = ctx.prompt_tokens + ctx.generated_tokens
161
+
162
+ # Calculate tokens per second
163
+ if data.decode_time_us > 0:
164
+ data.tokens_per_second = (data.generated_tokens * 1_000_000.0) / data.decode_time_us
165
+
166
+ # Set stop reason
167
+ data.stop_reason = ctx.stop_reason
168
+
169
+ return data
170
+
171
+ def stop_reason_to_string(reason: int) -> str:
172
+ """Convert stop reason to string (stop_reason_to_string)."""
173
+ try:
174
+ return StopReason(reason).name
175
+ except ValueError:
176
+ return f"UNKNOWN({reason})"
177
+
178
+ # --------------------------------------------------------------------------------------
179
+ # Profiling mixin for model classes
180
+ # --------------------------------------------------------------------------------------
181
+
182
+ class ProfilingMixin:
183
+ """Mixin class to add profiling capabilities to model classes."""
184
+
185
+ def __init__(self):
186
+ """Initialize profiling mixin."""
187
+ self._profiling_context = ProfilingContext()
188
+ self._profiling_data = ProfilingData()
189
+
190
+ def _start_profiling(self) -> None:
191
+ """Start profiling for an operation."""
192
+ profiling_reset(self._profiling_context)
193
+ profiling_start(self._profiling_context)
194
+
195
+ def _prompt_start(self) -> None:
196
+ """Start prompt processing timing."""
197
+ profiling_prompt_start(self._profiling_context)
198
+
199
+ def _prompt_end(self) -> None:
200
+ """End prompt processing timing."""
201
+ profiling_prompt_end(self._profiling_context)
202
+
203
+ def _decode_start(self) -> None:
204
+ """Start decode timing."""
205
+ profiling_decode_start(self._profiling_context)
206
+
207
+ def _decode_end(self) -> None:
208
+ """End decode timing."""
209
+ profiling_decode_end(self._profiling_context)
210
+
211
+ def _record_ttft(self) -> None:
212
+ """Record time to first token."""
213
+ profiling_record_ttft(self._profiling_context)
214
+
215
+ def _update_prompt_tokens(self, prompt_tokens: int) -> None:
216
+ """Update prompt token count."""
217
+ profiling_update_prompt_tokens(self._profiling_context, prompt_tokens)
218
+
219
+ def _update_generated_tokens(self, generated_tokens: int) -> None:
220
+ """Update generated token count."""
221
+ profiling_update_generated_tokens(self._profiling_context, generated_tokens)
222
+
223
+ def _set_stop_reason(self, stop_reason: int) -> None:
224
+ """Set stop reason."""
225
+ profiling_stop_reason(self._profiling_context, stop_reason)
226
+
227
+ def _end_profiling(self) -> ProfilingData:
228
+ """End profiling and return data."""
229
+ profiling_end(self._profiling_context)
230
+ self._profiling_data = profiling_gen_data(self._profiling_context)
231
+ return self._profiling_data
232
+
233
+ def get_profiling_data(self) -> ProfilingData:
234
+ """Get profiling data for the last operation."""
235
+ return self._profiling_data
236
+
237
+ def reset_profiling(self) -> None:
238
+ """Reset profiling data."""
239
+ self._profiling_data.reset()
nexaai/common.py ADDED
@@ -0,0 +1,61 @@
1
+ from dataclasses import dataclass
2
+ from typing import TypedDict, Literal, Optional, List
3
+
4
+
5
+ class ChatMessage(TypedDict):
6
+ role: Literal["user", "assistant", "system"]
7
+ content: str
8
+
9
+ class MultiModalMessageContent(TypedDict):
10
+ type: Literal["text", "image", "audio", "video"]
11
+ text: Optional[str]
12
+ url: Optional[str]
13
+ path: Optional[str]
14
+
15
+ class MultiModalMessage(TypedDict):
16
+ role: Literal["user", "assistant", "system"]
17
+ content: List[MultiModalMessageContent]
18
+
19
+
20
+ @dataclass
21
+ class SamplerConfig:
22
+ temperature: float = 0.8
23
+ top_p: float = 0.95
24
+ top_k: int = 40
25
+ repetition_penalty: float = 1.0
26
+ presence_penalty: float = 0.0
27
+ frequency_penalty: float = 0.0
28
+ seed: int = -1
29
+ grammar_path: str = None
30
+ grammar_string: str = None
31
+
32
+ @dataclass
33
+ class GenerationConfig:
34
+ max_tokens: int = 1024
35
+ stop_words: list[str] = None
36
+ sampler_config: SamplerConfig = None
37
+ image_paths: list[str] = None
38
+ audio_paths: list[str] = None
39
+
40
+ @dataclass
41
+ class ModelConfig:
42
+ n_ctx: int = 4096
43
+ n_threads: int = None
44
+ n_threads_batch: int = None
45
+ n_batch: int = 512
46
+ n_ubatch: int = 512
47
+ n_seq_max: int = 1
48
+ n_gpu_layers: int = 999
49
+ chat_template_path: str = None
50
+ chat_template_content: str = None
51
+
52
+
53
+ @dataclass(frozen=True) # Read-only
54
+ class ProfilingData:
55
+ start_time: int
56
+ end_time: int
57
+ prompt_start_time: int = None
58
+ prompt_end_time: int = None
59
+ decode_start_time: int = None
60
+ decode_ent_time: int = None
61
+ first_token_time: int = None
nexaai/cv.py ADDED
@@ -0,0 +1,87 @@
1
+ from typing import List, Optional
2
+ from abc import abstractmethod
3
+ from dataclasses import dataclass
4
+
5
+ from nexaai.base import BaseModel
6
+
7
+
8
+ @dataclass
9
+ class BoundingBox:
10
+ """Generic bounding box structure."""
11
+ x: float # X coordinate (normalized or pixel, depends on model)
12
+ y: float # Y coordinate (normalized or pixel, depends on model)
13
+ width: float # Width
14
+ height: float # Height
15
+
16
+
17
+ @dataclass
18
+ class CVResult:
19
+ """Generic detection/classification result."""
20
+ image_paths: Optional[List[str]] = None # Output image paths
21
+ image_count: int = 0 # Number of output images
22
+ class_id: int = 0 # Class ID (example: ConvNext)
23
+ confidence: float = 0.0 # Confidence score [0.0-1.0]
24
+ bbox: Optional[BoundingBox] = None # Bounding box (example: YOLO)
25
+ text: Optional[str] = None # Text result (example: OCR)
26
+ embedding: Optional[List[float]] = None # Feature embedding (example: CLIP embedding)
27
+ embedding_dim: int = 0 # Embedding dimension
28
+
29
+
30
+ @dataclass
31
+ class CVResults:
32
+ """Generic CV inference result."""
33
+ results: List[CVResult] # Array of CV results
34
+ result_count: int # Number of CV results
35
+
36
+
37
+ class CVCapabilities:
38
+ """CV capabilities enum."""
39
+ OCR = 0 # OCR
40
+ CLASSIFICATION = 1 # Classification
41
+ SEGMENTATION = 2 # Segmentation
42
+ CUSTOM = 3 # Custom task
43
+
44
+
45
+ @dataclass
46
+ class CVModelConfig:
47
+ """CV model preprocessing configuration."""
48
+ capabilities: int # CVCapabilities
49
+
50
+ # MLX-OCR
51
+ det_model_path: Optional[str] = None # Detection model path
52
+ rec_model_path: Optional[str] = None # Recognition model path
53
+
54
+ # QNN
55
+ model_path: Optional[str] = None # Model path
56
+ system_library_path: Optional[str] = None # System library path
57
+ backend_library_path: Optional[str] = None # Backend library path
58
+ extension_library_path: Optional[str] = None # Extension library path
59
+ config_file_path: Optional[str] = None # Config file path
60
+ char_dict_path: Optional[str] = None # Character dictionary path
61
+
62
+
63
+ class CVModel(BaseModel):
64
+ """Abstract base class for generic computer vision models."""
65
+
66
+ def __init__(self):
67
+ """Initialize base CV model class."""
68
+ pass
69
+
70
+ @classmethod
71
+ def _load_from(cls,
72
+ config: CVModelConfig,
73
+ plugin_id: str = "llama_cpp",
74
+ device_id: Optional[str] = None
75
+ ) -> 'CVModel':
76
+ """Load CV model from configuration, routing to appropriate implementation."""
77
+ if plugin_id == "mlx":
78
+ from nexaai.cv_impl.mlx_cv_impl import MLXCVImpl
79
+ return MLXCVImpl._load_from(config, plugin_id, device_id)
80
+ else:
81
+ from nexaai.cv_impl.pybind_cv_impl import PyBindCVImpl
82
+ return PyBindCVImpl._load_from(config, plugin_id, device_id)
83
+
84
+ @abstractmethod
85
+ def infer(self, input_image_path: str) -> CVResults:
86
+ """Perform inference on image."""
87
+ pass
File without changes
@@ -0,0 +1,88 @@
1
+ # Note: This code is generated by Cursor, not tested yet.
2
+
3
+ from typing import Optional
4
+ import os
5
+
6
+ from nexaai.cv import CVModel, CVModelConfig, CVResults
7
+ from nexaai.mlx_backend.cv.interface import CVModel as MLXCVInterface, create_cv_model
8
+
9
+
10
+ class MLXCVImpl(CVModel):
11
+ def __init__(self):
12
+ """Initialize MLX CV implementation."""
13
+ super().__init__()
14
+ self._mlx_cv = None
15
+
16
+ @classmethod
17
+ def _load_from(cls,
18
+ config: CVModelConfig,
19
+ plugin_id: str = "mlx",
20
+ device_id: Optional[str] = None
21
+ ) -> 'MLXCVImpl':
22
+ """Load CV model from configuration using MLX backend."""
23
+ try:
24
+ # Get MLX config class
25
+ from nexaai.mlx_backend.ml import CVModelConfig as MLXCVModelConfig
26
+
27
+ # Convert our config to MLX format
28
+ mlx_config = MLXCVModelConfig(
29
+ capabilities=config.capabilities,
30
+ det_model_path=config.det_model_path,
31
+ rec_model_path=config.rec_model_path,
32
+ model_path=config.model_path,
33
+ system_library_path=config.system_library_path,
34
+ backend_library_path=config.backend_library_path,
35
+ extension_library_path=config.extension_library_path,
36
+ config_file_path=config.config_file_path,
37
+ char_dict_path=config.char_dict_path
38
+ )
39
+
40
+ # Create instance and load MLX CV model
41
+ instance = cls()
42
+ instance._mlx_cv = create_cv_model(mlx_config, device_id)
43
+
44
+ return instance
45
+ except Exception as e:
46
+ raise RuntimeError(f"Failed to load MLX CV: {str(e)}")
47
+
48
+ def eject(self):
49
+ """Destroy the model and free resources."""
50
+ if self._mlx_cv:
51
+ self._mlx_cv.destroy()
52
+ self._mlx_cv = None
53
+
54
+ def infer(self, input_image_path: str) -> CVResults:
55
+ """Perform inference on image."""
56
+ if not self._mlx_cv:
57
+ raise RuntimeError("MLX CV not loaded")
58
+
59
+ try:
60
+ # Use MLX CV inference
61
+ result = self._mlx_cv.infer(input_image_path)
62
+
63
+ # Convert MLX result to our format
64
+ from nexaai.cv import CVResult
65
+
66
+ our_results = []
67
+ for mlx_result in result.results:
68
+ our_result = CVResult(
69
+ image_paths=mlx_result.image_paths,
70
+ image_count=mlx_result.image_count,
71
+ class_id=mlx_result.class_id,
72
+ confidence=mlx_result.confidence,
73
+ bbox=mlx_result.bbox,
74
+ text=mlx_result.text,
75
+ embedding=mlx_result.embedding,
76
+ embedding_dim=mlx_result.embedding_dim
77
+ )
78
+ our_results.append(our_result)
79
+
80
+ return CVResults(
81
+ results=our_results,
82
+ result_count=result.result_count
83
+ )
84
+
85
+ except Exception as e:
86
+ raise RuntimeError(f"Failed to perform CV inference: {str(e)}")
87
+
88
+
@@ -0,0 +1,31 @@
1
+ from typing import Optional
2
+
3
+ from nexaai.cv import CVModel, CVModelConfig, CVResults
4
+
5
+
6
+ class PyBindCVImpl(CVModel):
7
+ def __init__(self):
8
+ """Initialize PyBind CV implementation."""
9
+ super().__init__()
10
+ # TODO: Add PyBind-specific initialization
11
+
12
+ @classmethod
13
+ def _load_from(cls,
14
+ config: CVModelConfig,
15
+ plugin_id: str = "llama_cpp",
16
+ device_id: Optional[str] = None
17
+ ) -> 'PyBindCVImpl':
18
+ """Load CV model from configuration using PyBind backend."""
19
+ # TODO: Implement PyBind CV loading
20
+ instance = cls()
21
+ return instance
22
+
23
+ def eject(self):
24
+ """Destroy the model and free resources."""
25
+ # TODO: Implement PyBind CV cleanup
26
+ pass
27
+
28
+ def infer(self, input_image_path: str) -> CVResults:
29
+ """Perform inference on image."""
30
+ # TODO: Implement PyBind CV inference
31
+ raise NotImplementedError("PyBind CV inference not yet implemented")
nexaai/embedder.py ADDED
@@ -0,0 +1,68 @@
1
+ from typing import List, Union
2
+ from dataclasses import dataclass
3
+ from abc import abstractmethod
4
+ import numpy as np
5
+
6
+ from nexaai.base import BaseModel
7
+
8
+
9
+ @dataclass
10
+ class EmbeddingConfig:
11
+ batch_size: int = 32
12
+ normalize: bool = True
13
+ normalize_method: str = "l2"
14
+
15
+
16
+ class Embedder(BaseModel):
17
+ def __init__(self):
18
+ """
19
+ Internal initializer
20
+ """
21
+ pass
22
+
23
+ @classmethod
24
+ def _load_from(cls, model_path: str, tokenizer_file: str = "tokenizer.json", plugin_id: str = "llama_cpp"):
25
+ """
26
+ Load an embedder from model files, routing to appropriate implementation.
27
+
28
+ Args:
29
+ model_path: Path to the model file
30
+ tokenizer_file: Path to the tokenizer file (default: "tokenizer.json")
31
+ plugin_id: Plugin ID to use for the model (default: "llama_cpp")
32
+
33
+ Returns:
34
+ Embedder instance
35
+ """
36
+ if plugin_id == "mlx":
37
+ from nexaai.embedder_impl.mlx_embedder_impl import MLXEmbedderImpl
38
+ return MLXEmbedderImpl._load_from(model_path, tokenizer_file, plugin_id)
39
+ else:
40
+ from nexaai.embedder_impl.pybind_embedder_impl import PyBindEmbedderImpl
41
+ return PyBindEmbedderImpl._load_from(model_path, tokenizer_file, plugin_id)
42
+
43
+ @abstractmethod
44
+ def generate(self, texts: Union[List[str], str] = None, config: EmbeddingConfig = EmbeddingConfig(), input_ids: Union[List[int], List[List[int]]] = None) -> np.ndarray:
45
+ """
46
+ Generate embeddings for the given texts or input_ids.
47
+
48
+ Args:
49
+ texts: List of strings or single string to embed
50
+ input_ids: Pre-tokenized input as:
51
+ - Single sequence: list of integers [1, 2, 3, 4]
52
+ - Multiple sequences: list of lists [[1, 2, 3], [4, 5, 6]]
53
+ config: Configuration for embedding generation
54
+
55
+ Returns:
56
+ numpy array of embeddings with shape (num_sequences, embedding_dim)
57
+ """
58
+ pass
59
+
60
+ @abstractmethod
61
+ def get_embedding_dim(self) -> int:
62
+ """
63
+ Get the embedding dimension of the model
64
+
65
+ Returns:
66
+ The embedding dimension in int
67
+ """
68
+ pass
File without changes
@@ -0,0 +1,114 @@
1
+ from typing import List, Union
2
+ import numpy as np
3
+
4
+ from nexaai.embedder import Embedder, EmbeddingConfig
5
+ from nexaai.mlx_backend.embedding.interface import Embedder as MLXEmbedderInterface
6
+ from nexaai.mlx_backend.ml import ModelConfig as MLXModelConfig, SamplerConfig as MLXSamplerConfig, GenerationConfig as MLXGenerationConfig, EmbeddingConfig
7
+
8
+
9
+ class MLXEmbedderImpl(Embedder):
10
+ def __init__(self):
11
+ """Initialize MLX Embedder implementation."""
12
+ super().__init__()
13
+ self._mlx_embedder = None
14
+
15
+ @classmethod
16
+ def _load_from(cls, model_path: str, tokenizer_file: str = "tokenizer.json", plugin_id: str = "mlx"):
17
+ """
18
+ Load an embedder from model files using MLX backend.
19
+
20
+ Args:
21
+ model_path: Path to the model file
22
+ tokenizer_file: Path to the tokenizer file (default: "tokenizer.json")
23
+ plugin_id: Plugin ID to use for the model (default: "mlx")
24
+
25
+ Returns:
26
+ MLXEmbedderImpl instance
27
+ """
28
+ try:
29
+ # MLX interface is already imported
30
+
31
+ # Create instance and load MLX embedder
32
+ instance = cls()
33
+ instance._mlx_embedder = MLXEmbedderInterface(
34
+ model_path=model_path,
35
+ tokenizer_path=tokenizer_file
36
+ )
37
+
38
+ # Load the model
39
+ success = instance._mlx_embedder.load_model(model_path)
40
+ if not success:
41
+ raise RuntimeError("Failed to load MLX embedder model")
42
+
43
+ return instance
44
+ except Exception as e:
45
+ raise RuntimeError(f"Failed to load MLX Embedder: {str(e)}")
46
+
47
+ def eject(self):
48
+ """
49
+ Clean up resources and destroy the embedder
50
+ """
51
+ if self._mlx_embedder:
52
+ self._mlx_embedder.destroy()
53
+ self._mlx_embedder = None
54
+
55
+ def generate(self, texts: Union[List[str], str] = None, config: EmbeddingConfig = EmbeddingConfig(), input_ids: Union[List[int], List[List[int]]] = None) -> np.ndarray:
56
+ """
57
+ Generate embeddings for the given texts or input_ids.
58
+
59
+ Args:
60
+ texts: List of strings or single string to embed
61
+ input_ids: Pre-tokenized input as:
62
+ - Single sequence: list of integers [1, 2, 3, 4]
63
+ - Multiple sequences: list of lists [[1, 2, 3], [4, 5, 6]]
64
+ config: Configuration for embedding generation
65
+
66
+ Returns:
67
+ numpy array of embeddings with shape (num_sequences, embedding_dim)
68
+ """
69
+ if not self._mlx_embedder:
70
+ raise RuntimeError("MLX Embedder not loaded")
71
+
72
+ if texts is None and input_ids is None:
73
+ raise ValueError("Either texts or input_ids must be provided")
74
+
75
+ # MLX embedder currently only supports text input, not pre-tokenized input_ids
76
+ if input_ids is not None:
77
+ raise NotImplementedError("MLX embedder does not support input_ids, only text input")
78
+
79
+ try:
80
+ # Convert single string to list if needed
81
+ if isinstance(texts, str):
82
+ texts = [texts]
83
+
84
+ # MLX config classes are already imported
85
+
86
+ # Convert our config to MLX config
87
+ mlx_config = EmbeddingConfig()
88
+ mlx_config.batch_size = config.batch_size
89
+ mlx_config.normalize = config.normalize
90
+ mlx_config.normalize_method = config.normalize_method
91
+
92
+ # Generate embeddings using MLX
93
+ embeddings = self._mlx_embedder.embed(texts, mlx_config)
94
+
95
+ # Convert to numpy array
96
+ return np.array(embeddings, dtype=np.float32)
97
+
98
+ except Exception as e:
99
+ raise RuntimeError(f"Failed to generate embeddings: {str(e)}")
100
+
101
+ def get_embedding_dim(self) -> int:
102
+ """
103
+ Get the embedding dimension of the model
104
+
105
+ Returns:
106
+ The embedding dimension in int
107
+ """
108
+ if not self._mlx_embedder:
109
+ raise RuntimeError("MLX Embedder not loaded")
110
+
111
+ try:
112
+ return self._mlx_embedder.embedding_dim()
113
+ except Exception as e:
114
+ raise RuntimeError(f"Failed to get embedding dimension: {str(e)}")