nexaai-1.0.29-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py
@@ -0,0 +1,1736 @@
+ #!/usr/bin/env python3
+
+ import sys
+ import time
+ import os
+ import shutil
+ import math
+ from pathlib import Path
+ from typing import List, Tuple
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from shapely.geometry import Polygon
+ import pyclipper
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ ## =============================== PREPROCESSING CLASSES =============================== #
+
+
+ class DetResizeForTest(object):
+     def __init__(self, **kwargs):
+         super(DetResizeForTest, self).__init__()
+         self.resize_type = 0
+         if "image_shape" in kwargs:
+             self.image_shape = kwargs["image_shape"]
+             self.resize_type = 1
+         elif "limit_side_len" in kwargs:
+             self.limit_side_len = kwargs["limit_side_len"]
+             self.limit_type = kwargs.get("limit_type", "min")
+         elif "resize_long" in kwargs:
+             self.resize_type = 2
+             self.resize_long = kwargs.get("resize_long", 960)
+         else:
+             self.limit_side_len = 736
+             self.limit_type = "min"
+
+     def __call__(self, data):
+         img = data["image"]
+         src_h, src_w, _ = img.shape
+
+         if self.resize_type == 0:
+             img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+         elif self.resize_type == 2:
+             img, [ratio_h, ratio_w] = self.resize_image_type2(img)
+         else:
+             img, [ratio_h, ratio_w] = self.resize_image_type1(img)
+         data["image"] = img
+         data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
+         return data
+
+     def resize_image_type1(self, img):
+         resize_h, resize_w = self.image_shape
+         ori_h, ori_w = img.shape[:2]
+         ratio_h = float(resize_h) / ori_h
+         ratio_w = float(resize_w) / ori_w
+         img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         return img, [ratio_h, ratio_w]
+
+     def resize_image_type0(self, img):
+         limit_side_len = self.limit_side_len
+         h, w, c = img.shape
+
+         if self.limit_type == "max":
+             if max(h, w) > limit_side_len:
+                 if h > w:
+                     ratio = float(limit_side_len) / h
+                 else:
+                     ratio = float(limit_side_len) / w
+             else:
+                 ratio = 1.0
+         elif self.limit_type == "min":
+             if min(h, w) < limit_side_len:
+                 if h < w:
+                     ratio = float(limit_side_len) / h
+                 else:
+                     ratio = float(limit_side_len) / w
+             else:
+                 ratio = 1.0
+         elif self.limit_type == "resize_long":
+             ratio = float(limit_side_len) / max(h, w)
+         else:
+             raise Exception("not support limit type, image ")
+         resize_h = int(h * ratio)
+         resize_w = int(w * ratio)
+
+         resize_h = max(int(round(resize_h / 32) * 32), 32)
+         resize_w = max(int(round(resize_w / 32) * 32), 32)
+
+         try:
+             if int(resize_w) <= 0 or int(resize_h) <= 0:
+                 return None, (None, None)
+             img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         except:
+             print(img.shape, resize_w, resize_h)
+             sys.exit(0)
+         ratio_h = resize_h / float(h)
+         ratio_w = resize_w / float(w)
+         return img, [ratio_h, ratio_w]
+
+     def resize_image_type2(self, img):
+         h, w, _ = img.shape
+         resize_w = w
+         resize_h = h
+
+         if resize_h > resize_w:
+             ratio = float(self.resize_long) / resize_h
+         else:
+             ratio = float(self.resize_long) / resize_w
+
+         resize_h = int(resize_h * ratio)
+         resize_w = int(resize_w * ratio)
+
+         max_stride = 128
+         resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+         resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+         img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         ratio_h = resize_h / float(h)
+         ratio_w = resize_w / float(w)
+
+         return img, [ratio_h, ratio_w]
+
+
+ class NormalizeImage(object):
+     def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+         if isinstance(scale, str):
+             scale = eval(scale)
+         self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+         mean = mean if mean is not None else [0.485, 0.456, 0.406]
+         std = std if std is not None else [0.229, 0.224, 0.225]
+
+         shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+         self.mean = np.array(mean).reshape(shape).astype("float32")
+         self.std = np.array(std).reshape(shape).astype("float32")
+
+     def __call__(self, data):
+         img = data["image"]
+         from PIL import Image
+
+         if isinstance(img, Image.Image):
+             img = np.array(img)
+         assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+         data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
+         return data
+
+
+ class ToCHWImage(object):
+     def __init__(self, **kwargs):
+         pass
+
+     def __call__(self, data):
+         img = data["image"]
+         from PIL import Image
+
+         if isinstance(img, Image.Image):
+             img = np.array(img)
+         data["image"] = img.transpose((2, 0, 1))
+         return data
+
+
+ class KeepKeys(object):
+     def __init__(self, keep_keys, **kwargs):
+         self.keep_keys = keep_keys
+
+     def __call__(self, data):
+         data_list = []
+         for key in self.keep_keys:
+             data_list.append(data[key])
+         return data_list
+
+
+ ## =============================== POSTPROCESSING CLASSES =============================== #
+
+
+ class DBPostProcess(object):
+     def __init__(
+         self,
+         thresh=0.3,
+         box_thresh=0.7,
+         max_candidates=1000,
+         unclip_ratio=2.0,
+         use_dilation=False,
+         score_mode="fast",
+         **kwargs,
+     ):
+         self.thresh = thresh
+         self.box_thresh = box_thresh
+         self.max_candidates = max_candidates
+         self.unclip_ratio = unclip_ratio
+         self.min_size = 3
+         self.score_mode = score_mode
+         assert score_mode in [
+             "slow",
+             "fast",
+         ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
+         self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
+
+     def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+         bitmap = _bitmap
+         height, width = bitmap.shape
+
+         outs = cv2.findContours(
+             (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
+         )
+         if len(outs) == 3:
+             img, contours, _ = outs[0], outs[1], outs[2]
+         elif len(outs) == 2:
+             contours, _ = outs[0], outs[1]
+
+         num_contours = min(len(contours), self.max_candidates)
+
+         boxes = []
+         scores = []
+         for index in range(num_contours):
+             contour = contours[index]
+             points, sside = self.get_mini_boxes(contour)
+             if sside < self.min_size:
+                 continue
+             points = np.array(points)
+             if self.score_mode == "fast":
+                 score = self.box_score_fast(pred, points.reshape(-1, 2))
+             else:
+                 score = self.box_score_slow(pred, contour)
+             if self.box_thresh > score:
+                 continue
+
+             box = self.unclip(points).reshape(-1, 1, 2)
+             box, sside = self.get_mini_boxes(box)
+             if sside < self.min_size + 2:
+                 continue
+             box = np.array(box)
+
+             box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
+             box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
+             boxes.append(box.astype(np.int16))
+             scores.append(score)
+         return np.array(boxes, dtype=np.int16), scores
+
+     def unclip(self, box):
+         unclip_ratio = self.unclip_ratio
+         poly = Polygon(box)
+         distance = poly.area * unclip_ratio / poly.length
+         offset = pyclipper.PyclipperOffset()
+         offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+         expanded = np.array(offset.Execute(distance))
+         return expanded
+
+     def get_mini_boxes(self, contour):
+         bounding_box = cv2.minAreaRect(contour)
+         points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+         index_1, index_2, index_3, index_4 = 0, 1, 2, 3
+         if points[1][1] > points[0][1]:
+             index_1 = 0
+             index_4 = 1
+         else:
+             index_1 = 1
+             index_4 = 0
+         if points[3][1] > points[2][1]:
+             index_2 = 2
+             index_3 = 3
+         else:
+             index_2 = 3
+             index_3 = 2
+
+         box = [points[index_1], points[index_2], points[index_3], points[index_4]]
+         return box, min(bounding_box[1])
+
+     def box_score_fast(self, bitmap, _box):
+         h, w = bitmap.shape[:2]
+         box = _box.copy()
+         xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
+         xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
+         ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
+         ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)
+
+         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+         box[:, 0] = box[:, 0] - xmin
+         box[:, 1] = box[:, 1] - ymin
+         cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
+         return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
+
+     def box_score_slow(self, bitmap, contour):
+         h, w = bitmap.shape[:2]
+         contour = contour.copy()
+         contour = np.reshape(contour, (-1, 2))
+
+         xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
+         xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
+         ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
+         ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
+
+         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+         contour[:, 0] = contour[:, 0] - xmin
+         contour[:, 1] = contour[:, 1] - ymin
+         cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
+         return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
+
+     def __call__(self, outs_dict, shape_list):
+         pred = outs_dict["maps"]
+         if hasattr(pred, "numpy"):  # Check if it has numpy method (for torch tensors)
+             pred = pred.numpy()
+         elif isinstance(pred, mx.array):  # For MLX arrays
+             pred = np.array(pred)
+         pred = pred[:, 0, :, :]
+         segmentation = pred > self.thresh
+
+         boxes_batch = []
+         for batch_index in range(pred.shape[0]):
+             src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
+             if self.dilation_kernel is not None:
+                 mask = cv2.dilate(
+                     np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel
+                 )
+             else:
+                 mask = segmentation[batch_index]
+             boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h)
+             boxes_batch.append({"points": boxes})
+         return boxes_batch
+
+
+ class BaseRecLabelDecode(object):
+     def __init__(self, character_dict_path=None, use_space_char=False):
+         self.beg_str = "sos"
+         self.end_str = "eos"
+         self.character_str = []
+         if character_dict_path is None:
+             self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+             dict_character = list(self.character_str)
+         else:
+             with open(character_dict_path, "rb") as fin:
+                 lines = fin.readlines()
+                 for line in lines:
+                     line = line.decode("utf-8").strip("\n").strip("\r\n")
+                     self.character_str.append(line)
+             if use_space_char:
+                 self.character_str.append(" ")
+             dict_character = list(self.character_str)
+
+         dict_character = self.add_special_char(dict_character)
+         self.dict = {}
+         for i, char in enumerate(dict_character):
+             self.dict[char] = i
+         self.character = dict_character
+
+     def add_special_char(self, dict_character):
+         return dict_character
+
+     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+         result_list = []
+         ignored_tokens = self.get_ignored_tokens()
+         batch_size = len(text_index)
+         for batch_idx in range(batch_size):
+             char_list = []
+             conf_list = []
+             for idx in range(len(text_index[batch_idx])):
+                 if text_index[batch_idx][idx] in ignored_tokens:
+                     continue
+                 if is_remove_duplicate:
+                     if idx > 0 and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]:
+                         continue
+                 char_list.append(self.character[int(text_index[batch_idx][idx])])
+                 if text_prob is not None:
+                     conf_list.append(text_prob[batch_idx][idx])
+                 else:
+                     conf_list.append(1)
+             text = "".join(char_list)
+             # Check if conf_list is empty before calculating mean
+             confidence = np.mean(conf_list) if len(conf_list) > 0 else 0.0
+             result_list.append((text, confidence))
+         return result_list
+
+     def get_ignored_tokens(self):
+         return [0]
+
+
+ class CTCLabelDecode(BaseRecLabelDecode):
+     def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+         super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)
+
+     def __call__(self, preds, label=None, *args, **kwargs):
+         if hasattr(preds, "numpy"):  # Check if it has numpy method (for torch tensors)
+             preds = preds.numpy()
+         elif isinstance(preds, mx.array):  # For MLX arrays
+             preds = np.array(preds)
+         preds_idx = preds.argmax(axis=2)
+         preds_prob = preds.max(axis=2)
+         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+
+         if label is None:
+             return text
+         label = self.decode(label)
+         return text, label
+
+     def add_special_char(self, dict_character):
+         dict_character = ["blank"] + dict_character
+         return dict_character
+
+
401
+ ## =============================== CONFIG CLASS =============================== #
402
+
403
+
404
+ class Config:
405
+ def __init__(self, model_path):
406
+ # Base paths
407
+ self.base_dir = os.path.abspath(os.path.dirname(__file__))
408
+
409
+ self.model_cache_dir = model_path
410
+
411
+ # Detection settings
412
+ self.det_algorithm = "DB"
413
+ # Use downloaded model files instead of local paths
414
+ self.det_model_path = os.path.join(
415
+ self.model_cache_dir, "ch_ptocr_v4_det_infer.safetensors"
416
+ )
417
+ self.det_limit_side_len = 960
418
+ self.det_limit_type = "max"
419
+ self.det_db_thresh = 0.3
420
+ self.det_db_box_thresh = 0.6
421
+ self.det_db_unclip_ratio = 1.5
422
+ self.use_dilation = False
423
+ self.det_db_score_mode = "fast"
424
+
425
+ # Recognition settings
426
+ self.rec_algorithm = "CRNN"
427
+ # Use downloaded model files instead of local paths
428
+ self.rec_model_path = os.path.join(
429
+ self.model_cache_dir, "ch_ptocr_v4_rec_infer_f16.safetensors"
430
+ )
431
+ self.rec_char_type = "ch"
432
+ self.rec_batch_num = 6
433
+ self.max_text_length = 25
434
+ # Use downloaded character dictionary
435
+ self.rec_char_dict_path = os.path.join(self.model_cache_dir, "ppocr_keys_v1.txt")
436
+
437
+ # Other settings
438
+ self.use_space_char = True
439
+ self.drop_score = 0.5
440
+ self.limited_max_width = 1280
441
+ self.limited_min_width = 16
442
+ # Use downloaded font file
443
+ self.vis_font_path = os.path.join(self.model_cache_dir, "simfang.ttf")
444
+
445
+
446
+ ## =============================== MODEL COMPONENTS =============================== #
447
+
448
+
449
+ class LearnableAffineBlock(nn.Module):
450
+ def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.1):
451
+ super().__init__()
452
+ # Match PyTorch parameter names exactly (lr_mult and lab_lr are ignored in MLX)
453
+ self.scale = mx.array([scale_value])
454
+ self.bias = mx.array([bias_value])
455
+
456
+ def __call__(self, x):
457
+ return self.scale * x + self.bias
458
+
459
+
460
+ class ConvBNLayer(nn.Module):
461
+ def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, lr_mult=1.0):
462
+ super().__init__()
463
+ # lr_mult is ignored in MLX - it's a PyTorch/PaddlePaddle concept
464
+ padding = (kernel_size - 1) // 2
465
+ self.conv = nn.Conv2d(
466
+ in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False
467
+ )
468
+ self.bn = nn.BatchNorm(out_channels)
469
+
470
+ def __call__(self, x):
471
+ x = self.conv(x)
472
+ x = self.bn(x)
473
+ return x
474
+
475
+
476
+ class Act(nn.Module):
477
+ def __init__(self, act="hswish", lr_mult=1.0, lab_lr=0.1):
478
+ super().__init__()
479
+ # lr_mult and lab_lr are ignored in MLX
480
+ self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
481
+
482
+ def __call__(self, x):
483
+ return self.lab(nn.hardswish(x))
484
+
485
+
486
+ class LearnableRepLayer(nn.Module):
487
+ def __init__(
488
+ self,
489
+ in_channels,
490
+ out_channels,
491
+ kernel_size,
492
+ stride=1,
493
+ groups=1,
494
+ num_conv_branches=4,
495
+ lr_mult=1.0,
496
+ lab_lr=0.1,
497
+ ):
498
+ super().__init__()
499
+ self.in_channels = in_channels
500
+ self.out_channels = out_channels
501
+ self.kernel_size = kernel_size
502
+ self.stride = stride
503
+ self.groups = groups
504
+ self.num_conv_branches = num_conv_branches
505
+
506
+ # Identity connection - only if channels match and stride is 1
507
+ self.identity = None
508
+ if out_channels == in_channels and stride == 1:
509
+ self.identity = nn.BatchNorm(in_channels)
510
+
511
+ # Create main conv branches using a list to match PyTorch structure
512
+ self.conv_kxk = []
513
+ for _ in range(num_conv_branches):
514
+ conv = ConvBNLayer(
515
+ in_channels, out_channels, kernel_size, stride, groups=groups, lr_mult=lr_mult
516
+ )
517
+ self.conv_kxk.append(conv)
518
+
519
+ # 1x1 conv branch - only if kernel > 1
520
+ self.conv_1x1 = None
521
+ if kernel_size > 1:
522
+ self.conv_1x1 = ConvBNLayer(
523
+ in_channels, out_channels, 1, stride, groups=groups, lr_mult=lr_mult
524
+ )
525
+
526
+ self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
527
+ self.act = Act(lr_mult=lr_mult, lab_lr=lab_lr)
528
+
529
+ def __call__(self, x):
530
+ out = 0
531
+
532
+ # Add identity if available
533
+ if self.identity is not None:
534
+ out = out + self.identity(x)
535
+
536
+ # Add 1x1 conv if available
537
+ if self.conv_1x1 is not None:
538
+ out = out + self.conv_1x1(x)
539
+
540
+ # Add all conv_kxk branches
541
+ for conv in self.conv_kxk:
542
+ out = out + conv(x)
543
+
544
+ # Apply learnable affine and activation
545
+ out = self.lab(out)
546
+ if self.stride != 2:
547
+ out = self.act(out)
548
+
549
+ return out
550
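+
+ # Shape sketch (hypothetical helper, toy dims): LearnableRepLayer sums parallel branches in the
+ # RepVGG style - an optional BatchNorm identity, an optional 1x1 conv, and num_conv_branches KxK
+ # convs - then applies the learnable affine (and the activation when stride != 2).
+ def _demo_rep_layer():
+     layer = LearnableRepLayer(in_channels=8, out_channels=8, kernel_size=3, stride=1)
+     x = mx.zeros((1, 5, 5, 8))   # NHWC toy input
+     return layer(x).shape        # -> (1, 5, 5, 8)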
+
551
+
552
+ class SELayer(nn.Module):
553
+ def __init__(self, channel, reduction=4, lr_mult=1.0):
554
+ super().__init__()
555
+ # lr_mult is ignored in MLX
556
+ reduced_channels = max(1, channel // reduction)
557
+ self.conv1 = nn.Conv2d(channel, reduced_channels, 1)
558
+ self.conv2 = nn.Conv2d(reduced_channels, channel, 1)
559
+
560
+ def __call__(self, x):
561
+ identity = x
562
+ se_input = mx.mean(x, axis=(1, 2), keepdims=True) # Changed from (2, 3) to (1, 2)
563
+ se_out = nn.relu(self.conv1(se_input))
564
+ se_out = self.conv2(se_out)
565
+ se_out = mx.clip(se_out + 3.0, 0.0, 6.0) / 6.0
566
+ se_out = identity * se_out
567
+ return se_out
568
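+
+ # Shape sketch (hypothetical helper): the SE branch averages over the spatial axes (1, 2) of the
+ # NHWC tensor, squeezes channels through two 1x1 convs, and rescales the input with the
+ # hard-sigmoid gate clip(x + 3, 0, 6) / 6, so the output keeps the input shape.
+ def _demo_se_gate():
+     se = SELayer(channel=8)
+     x = mx.ones((1, 4, 4, 8))   # NHWC toy input
+     return se(x).shape          # -> (1, 4, 4, 8)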
+
569
+
570
+ class LCNetV3Block(nn.Module):
571
+ def __init__(
572
+ self,
573
+ in_channels,
574
+ out_channels,
575
+ stride,
576
+ dw_size,
577
+ use_se=False,
578
+ conv_kxk_num=4,
579
+ lr_mult=1.0,
580
+ lab_lr=0.1,
581
+ ):
582
+ super().__init__()
583
+ self.use_se = use_se
584
+
585
+ # Depthwise convolution: in_channels -> in_channels with groups=in_channels
586
+ self.dw_conv = LearnableRepLayer(
587
+ in_channels=in_channels, # INPUT: 192
588
+ out_channels=in_channels, # OUTPUT: 192 (same as input for depthwise)
589
+ kernel_size=dw_size,
590
+ stride=stride,
591
+ groups=in_channels, # GROUPS: 192 (depthwise)
592
+ num_conv_branches=conv_kxk_num,
593
+ lr_mult=lr_mult,
594
+ lab_lr=lab_lr,
595
+ )
596
+
597
+ if use_se:
598
+ self.se = SELayer(in_channels, lr_mult=lr_mult)
599
+
600
+ # Pointwise convolution: in_channels -> out_channels with groups=1
601
+ self.pw_conv = LearnableRepLayer(
602
+ in_channels=in_channels, # INPUT: 192
603
+ out_channels=out_channels, # OUTPUT: 384
604
+ kernel_size=1,
605
+ stride=1,
606
+ groups=1, # GROUPS: 1 (pointwise)
607
+ num_conv_branches=conv_kxk_num,
608
+ lr_mult=lr_mult,
609
+ lab_lr=lab_lr,
610
+ )
611
+
612
+ def __call__(self, x):
613
+ x = self.dw_conv(x)
614
+ if self.use_se:
615
+ x = self.se(x)
616
+ x = self.pw_conv(x)
617
+ return x
618
+
619
+
620
+ def make_divisible(v, divisor=16):
621
+ return max(divisor, int(v + divisor / 2) // divisor * divisor)
622
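+
+ # Worked values (illustrative): with the default divisor of 16, make_divisible rounds channel
+ # counts to the nearest multiple of 16 with a floor of 16 - the widths used by the scale=0.75
+ # detection backbone and the scale=0.95 recognition backbone below.
+ def _demo_make_divisible():
+     assert make_divisible(16 * 0.75) == 16     # 12 rounds up to the 16 floor
+     assert make_divisible(64 * 0.75) == 48
+     assert make_divisible(512 * 0.75) == 384   # detection backbone final width
+     assert make_divisible(512 * 0.95) == 480   # recognition backbone final width (in_channels=480 below)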
+
623
+
624
+ # Backbone stage configurations; each block entry is [kernel_size, in_channels, out_channels, stride, use_se]
625
+ NET_CONFIG_det = {
626
+ "blocks2": [[3, 16, 32, 1, False]],
627
+ "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
628
+ "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
629
+ "blocks5": [
630
+ [3, 128, 256, 2, False],
631
+ [5, 256, 256, 1, False],
632
+ [5, 256, 256, 1, False],
633
+ [5, 256, 256, 1, False],
634
+ [5, 256, 256, 1, False],
635
+ ],
636
+ "blocks6": [
637
+ [5, 256, 512, 2, True],
638
+ [5, 512, 512, 1, True],
639
+ [5, 512, 512, 1, False],
640
+ [5, 512, 512, 1, False],
641
+ ],
642
+ }
643
+
644
+ NET_CONFIG_rec = {
645
+ "blocks2": [[3, 16, 32, 1, False]],
646
+ "blocks3": [[3, 32, 64, 1, False], [3, 64, 64, 1, False]],
647
+ "blocks4": [[3, 64, 128, (2, 1), False], [3, 128, 128, 1, False]],
648
+ "blocks5": [
649
+ [3, 128, 256, (1, 2), False],
650
+ [5, 256, 256, 1, False],
651
+ [5, 256, 256, 1, False],
652
+ [5, 256, 256, 1, False],
653
+ [5, 256, 256, 1, False],
654
+ ],
655
+ "blocks6": [
656
+ [5, 256, 512, (2, 1), True],
657
+ [5, 512, 512, 1, True],
658
+ [5, 512, 512, (2, 1), False],
659
+ [5, 512, 512, 1, False],
660
+ ],
661
+ }
662
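+
+ # Note (derived from the shapes below): tuple strides such as (2, 1) in the recognition config
+ # down-sample only the height, so a 48-pixel-high crop reaches height 3 before the final 3x2
+ # average pool and height 1 at Im2Seq, while the width keeps roughly one prediction step per
+ # 8 pixels of text for the CTC head.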
+
663
+
664
+ ## =================================== for the backbone of text recognition ===================================
665
+ class PPLCNetV3(nn.Module):
666
+ def __init__(
667
+ self,
668
+ scale=1.0,
669
+ conv_kxk_num=4,
670
+ lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
671
+ lab_lr=0.1,
672
+ det=False,
673
+ **kwargs,
674
+ ):
675
+ super().__init__()
676
+ self.scale = scale
677
+ self.lr_mult_list = lr_mult_list
678
+ self.det = det
679
+ self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec
680
+
681
+ assert isinstance(self.lr_mult_list, (list, tuple))
682
+ assert len(self.lr_mult_list) == 6
683
+
684
+ self.conv1 = ConvBNLayer(
685
+ in_channels=3,
686
+ out_channels=make_divisible(16 * scale),
687
+ kernel_size=3,
688
+ stride=2,
689
+ lr_mult=self.lr_mult_list[0],
690
+ )
691
+
692
+ # Build blocks2 - match PyTorch Sequential structure
693
+ blocks2_list = []
694
+ in_channels = make_divisible(16 * scale)
695
+ for i, (k, _, out_c, s, se) in enumerate(self.net_config["blocks2"]):
696
+ out_channels = make_divisible(out_c * scale)
697
+ block = LCNetV3Block(
698
+ in_channels=in_channels,
699
+ out_channels=out_channels,
700
+ dw_size=k,
701
+ stride=s,
702
+ use_se=se,
703
+ conv_kxk_num=conv_kxk_num,
704
+ lr_mult=self.lr_mult_list[1],
705
+ lab_lr=lab_lr,
706
+ )
707
+ blocks2_list.append(block)
708
+ in_channels = out_channels
709
+ self.blocks2 = blocks2_list
710
+
711
+ # Build blocks3
712
+ blocks3_list = []
713
+ for i, (k, _, out_c, s, se) in enumerate(self.net_config["blocks3"]):
714
+ out_channels = make_divisible(out_c * scale)
715
+ block = LCNetV3Block(
716
+ in_channels=in_channels,
717
+ out_channels=out_channels,
718
+ dw_size=k,
719
+ stride=s,
720
+ use_se=se,
721
+ conv_kxk_num=conv_kxk_num,
722
+ lr_mult=self.lr_mult_list[2],
723
+ lab_lr=lab_lr,
724
+ )
725
+ blocks3_list.append(block)
726
+ in_channels = out_channels
727
+ self.blocks3 = blocks3_list
728
+
729
+ # Build blocks4
730
+ blocks4_list = []
731
+ for i, (k, _, out_c, s, se) in enumerate(self.net_config["blocks4"]):
732
+ out_channels = make_divisible(out_c * scale)
733
+ block = LCNetV3Block(
734
+ in_channels=in_channels,
735
+ out_channels=out_channels,
736
+ dw_size=k,
737
+ stride=s,
738
+ use_se=se,
739
+ conv_kxk_num=conv_kxk_num,
740
+ lr_mult=self.lr_mult_list[3],
741
+ lab_lr=lab_lr,
742
+ )
743
+ blocks4_list.append(block)
744
+ in_channels = out_channels
745
+ self.blocks4 = blocks4_list
746
+
747
+ # Build blocks5
748
+ blocks5_list = []
749
+ for i, (k, _, out_c, s, se) in enumerate(self.net_config["blocks5"]):
750
+ out_channels = make_divisible(out_c * scale)
751
+ block = LCNetV3Block(
752
+ in_channels=in_channels,
753
+ out_channels=out_channels,
754
+ dw_size=k,
755
+ stride=s,
756
+ use_se=se,
757
+ conv_kxk_num=conv_kxk_num,
758
+ lr_mult=self.lr_mult_list[4],
759
+ lab_lr=lab_lr,
760
+ )
761
+ blocks5_list.append(block)
762
+ in_channels = out_channels
763
+ self.blocks5 = blocks5_list
764
+
765
+ # Build blocks6
766
+ blocks6_list = []
767
+ for i, (k, _, out_c, s, se) in enumerate(self.net_config["blocks6"]):
768
+ out_channels = make_divisible(out_c * scale)
769
+ block = LCNetV3Block(
770
+ in_channels=in_channels,
771
+ out_channels=out_channels,
772
+ dw_size=k,
773
+ stride=s,
774
+ use_se=se,
775
+ conv_kxk_num=conv_kxk_num,
776
+ lr_mult=self.lr_mult_list[5],
777
+ lab_lr=lab_lr,
778
+ )
779
+ blocks6_list.append(block)
780
+ in_channels = out_channels
781
+ self.blocks6 = blocks6_list
782
+
783
+ self.out_channels = make_divisible(512 * scale)
784
+
785
+ if self.det:
786
+ mv_c = [16, 24, 56, 480]
787
+ self.out_channels = [
788
+ make_divisible(self.net_config["blocks3"][-1][2] * scale),
789
+ make_divisible(self.net_config["blocks4"][-1][2] * scale),
790
+ make_divisible(self.net_config["blocks5"][-1][2] * scale),
791
+ make_divisible(self.net_config["blocks6"][-1][2] * scale),
792
+ ]
793
+
794
+ self.layer_list = []
795
+ for i in range(4):
796
+ layer = nn.Conv2d(self.out_channels[i], int(mv_c[i] * scale), 1, bias=True)
797
+ self.layer_list.append(layer)
798
+
799
+ self.out_channels = [
800
+ int(mv_c[0] * scale),
801
+ int(mv_c[1] * scale),
802
+ int(mv_c[2] * scale),
803
+ int(mv_c[3] * scale),
804
+ ]
805
+
806
+ def __call__(self, x):
807
+ out_list = []
808
+
809
+ ## Transpose to match the format required by MLX
810
+ x = mx.transpose(x, (0, 2, 3, 1))
811
+ x = self.conv1(x)
812
+
813
+ for block in self.blocks2:
814
+ x = block(x)
815
+
816
+ for block in self.blocks3:
817
+ x = block(x)
818
+ out_list.append(x)
819
+
820
+ for block in self.blocks4:
821
+ x = block(x)
822
+ out_list.append(x)
823
+
824
+ for block in self.blocks5:
825
+ x = block(x)
826
+ out_list.append(x)
827
+
828
+ for block in self.blocks6:
829
+ x = block(x)
830
+ out_list.append(x)
831
+
832
+ if self.det:
833
+ out_list[0] = self.layer_list[0](out_list[0])
834
+ out_list[1] = self.layer_list[1](out_list[1])
835
+ out_list[2] = self.layer_list[2](out_list[2])
836
+ out_list[3] = self.layer_list[3](out_list[3])
837
+ return out_list
838
+
839
+ B, H, W, C = x.shape
840
+
841
+ # Ensure dimensions are divisible by kernel size for clean pooling
842
+ H_out = H // 3
843
+ W_out = W // 2
844
+
845
+ # Trim to make dimensions divisible
846
+ x = x[:, : H_out * 3, : W_out * 2, :]
847
+
848
+ # Reshape for 3x2 average pooling
849
+ x = mx.reshape(x, (B, H_out, 3, W_out, 2, C))
850
+ x = mx.mean(x, axis=(2, 4)) # Average over the 3x2 kernel
851
+ return x
852
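+
+ # Shape sketch (hypothetical helper): the recognition branch ends with a 3x2 average pool done by
+ # trimming H and W to multiples of 3 and 2, reshaping the pooling window out into its own axes,
+ # and averaging over them.
+ def _demo_reshape_pool():
+     x = mx.arange(1 * 6 * 4 * 2, dtype=mx.float32).reshape((1, 6, 4, 2))  # (B, H, W, C)
+     B, H, W, C = x.shape
+     H_out, W_out = H // 3, W // 2
+     x = x[:, : H_out * 3, : W_out * 2, :]
+     pooled = mx.mean(mx.reshape(x, (B, H_out, 3, W_out, 2, C)), axis=(2, 4))
+     return pooled.shape  # -> (1, 2, 2, 2)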
+
853
+
854
+ ## =================================== for the neck of text detection ===================================
855
+ class IndexedContainer(nn.Module):
856
+ """Container that creates numbered attributes for MLX"""
857
+
858
+ def __init__(self):
859
+ super().__init__()
860
+ self._modules = []
861
+
862
+ def add_module(self, module):
863
+ idx = len(self._modules)
864
+ setattr(self, str(idx), module)
865
+ self._modules.append(module)
866
+ return idx
867
+
868
+ def __getitem__(self, idx):
869
+ return getattr(self, str(idx))
870
+
871
+
872
+ class SEModule(nn.Module):
873
+ def __init__(self, in_channels, reduction=4):
874
+ super().__init__()
875
+ reduced_channels = in_channels // reduction
876
+ self.conv1 = nn.Conv2d(in_channels, reduced_channels, 1, bias=True)
877
+ self.conv2 = nn.Conv2d(reduced_channels, in_channels, 1, bias=True)
878
+
879
+ def __call__(self, inputs):
880
+ outputs = mx.mean(inputs, axis=(1, 2), keepdims=True)
881
+ outputs = self.conv1(outputs)
882
+ outputs = nn.relu(outputs)
883
+ outputs = self.conv2(outputs)
884
+ # PaddlePaddle hard_sigmoid: F.relu6(1.2 * x + 3.) / 6.
885
+ outputs = mx.clip(1.2 * outputs + 3.0, 0.0, 6.0) / 6.0 # PaddlePaddle hard_sigmoid
886
+ outputs = inputs * outputs
887
+ return outputs
888
+
889
+
890
+ class RSELayer(nn.Module):
891
+ def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
892
+ super().__init__()
893
+ padding = kernel_size // 2
894
+ self.in_conv = nn.Conv2d(
895
+ in_channels, out_channels, kernel_size, padding=padding, bias=False
896
+ )
897
+ self.se_block = SEModule(out_channels)
898
+ self.shortcut = shortcut
899
+
900
+ def __call__(self, x):
901
+ conv_out = self.in_conv(x)
902
+ if self.shortcut:
903
+ return conv_out + self.se_block(conv_out)
904
+ else:
905
+ return self.se_block(conv_out)
906
+
907
+
908
+ class RSEFPN(nn.Module):
909
+ def __init__(self, in_channels, out_channels=96, shortcut=True):
910
+ super().__init__()
911
+ self.out_channels = out_channels
912
+
913
+ # Create container modules that inherit from nn.Module
914
+ self.ins_conv = IndexedContainer()
915
+ self.inp_conv = IndexedContainer()
916
+
917
+ # Add modules - this should create the correct parameter names
918
+ for i, in_ch in enumerate(in_channels):
919
+ self.ins_conv.add_module(
920
+ RSELayer(in_ch, out_channels, kernel_size=1, shortcut=shortcut)
921
+ )
922
+ self.inp_conv.add_module(
923
+ RSELayer(out_channels, out_channels // 4, kernel_size=3, shortcut=shortcut)
924
+ )
925
+
926
+ def __call__(self, x):
927
+ c2, c3, c4, c5 = x
928
+
929
+ in5 = self.ins_conv[3](c5)
930
+ in4 = self.ins_conv[2](c4)
931
+ in3 = self.ins_conv[1](c3)
932
+ in2 = self.ins_conv[0](c2)
933
+
934
+ # Upsample both H and W dimensions
935
+ up_in5 = mx.repeat(in5, 2, axis=1)
936
+ up_in5 = mx.repeat(up_in5, 2, axis=2)
937
+ out4 = in4 + up_in5
938
+
939
+ up_out4 = mx.repeat(out4, 2, axis=1)
940
+ up_out4 = mx.repeat(up_out4, 2, axis=2)
941
+ out3 = in3 + up_out4
942
+
943
+ up_out3 = mx.repeat(out3, 2, axis=1)
944
+ up_out3 = mx.repeat(up_out3, 2, axis=2)
945
+ out2 = in2 + up_out3
946
+
947
+ p5 = self.inp_conv[3](in5)
948
+ p4 = self.inp_conv[2](out4)
949
+ p3 = self.inp_conv[1](out3)
950
+ p2 = self.inp_conv[0](out2)
951
+
952
+ # Use target size from p2 for consistent upsampling
953
+ target_h, target_w = p2.shape[1], p2.shape[2]
954
+
955
+ # MLX has no F.upsample; emulate nearest-neighbor upsampling with mx.repeat, then crop to the target size
956
+ # P5: upsample by 8x to match p2 size
957
+ p5_h, p5_w = p5.shape[1], p5.shape[2]
958
+ p5_target_h, p5_target_w = min(target_h, p5_h * 8), min(target_w, p5_w * 8)
959
+
960
+ # Calculate exact repeat factors
961
+ h_repeat_p5 = p5_target_h // p5_h
962
+ w_repeat_p5 = p5_target_w // p5_w
963
+ p5 = mx.repeat(p5, h_repeat_p5, axis=1)
964
+ p5 = mx.repeat(p5, w_repeat_p5, axis=2)
965
+ p5 = p5[:, :target_h, :target_w]
966
+
967
+ # P4: upsample by 4x to match p2 size
968
+ p4_h, p4_w = p4.shape[1], p4.shape[2]
969
+ p4_target_h, p4_target_w = min(target_h, p4_h * 4), min(target_w, p4_w * 4)
970
+
971
+ h_repeat_p4 = p4_target_h // p4_h
972
+ w_repeat_p4 = p4_target_w // p4_w
973
+ p4 = mx.repeat(p4, h_repeat_p4, axis=1)
974
+ p4 = mx.repeat(p4, w_repeat_p4, axis=2)
975
+ p4 = p4[:, :target_h, :target_w]
976
+
977
+ # P3: upsample by 2x to match p2 size
978
+ p3_h, p3_w = p3.shape[1], p3.shape[2]
979
+ p3_target_h, p3_target_w = min(target_h, p3_h * 2), min(target_w, p3_w * 2)
980
+
981
+ h_repeat_p3 = p3_target_h // p3_h
982
+ w_repeat_p3 = p3_target_w // p3_w
983
+ p3 = mx.repeat(p3, h_repeat_p3, axis=1)
984
+ p3 = mx.repeat(p3, w_repeat_p3, axis=2)
985
+ p3 = p3[:, :target_h, :target_w]
986
+
987
+ fuse = mx.concatenate([p5, p4, p3, p2], axis=-1)
988
+ return fuse
989
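+
+ # Shape sketch (hypothetical helper): the FPN upsampling above emulates nearest-neighbor
+ # F.upsample with mx.repeat - each pixel is duplicated into a factor x factor block along H and W,
+ # and the final slice crops any overshoot so all pyramid levels match p2 exactly.
+ def _demo_repeat_upsample(factor=2, target_h=5, target_w=5):
+     x = mx.ones((1, 3, 3, 4))   # (B, H, W, C)
+     up = mx.repeat(mx.repeat(x, factor, axis=1), factor, axis=2)  # -> (1, 6, 6, 4)
+     return up[:, :target_h, :target_w].shape                      # -> (1, 5, 5, 4)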
+
990
+
991
+ ## =================================== for the head of text detection ===================================
992
+ class DetectionHead(nn.Module):
993
+ def __init__(self, in_channels):
994
+ super().__init__()
995
+ self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 3, padding=1, bias=False)
996
+ self.conv_bn1 = nn.BatchNorm(in_channels // 4)
997
+
998
+ self.conv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, stride=2)
999
+ self.conv_bn2 = nn.BatchNorm(in_channels // 4)
1000
+
1001
+ self.conv3 = nn.ConvTranspose2d(in_channels // 4, 1, 2, stride=2)
1002
+
1003
+ def __call__(self, x):
1004
+ x = nn.relu(self.conv_bn1(self.conv1(x)))
1005
+ x = nn.relu(self.conv_bn2(self.conv2(x)))
1006
+ x = self.conv3(x)
1007
+ x = nn.sigmoid(x)
1008
+ return x
1009
+
1010
+
1011
+ class DBHead(nn.Module):
1012
+ def __init__(self, in_channels, k=50):
1013
+ super().__init__()
1014
+ self.k = k
1015
+ self.binarize = DetectionHead(in_channels) # First branch
1016
+ self.thresh = DetectionHead(in_channels) # Threshold branch (defined to match the checkpoint; not used in __call__)
1017
+
1018
+ def step_function(self, x, y):
1019
+ return 1.0 / (1.0 + mx.exp(-self.k * (x - y)))
1020
+
1021
+ def __call__(self, x):
1022
+ shrink_maps = self.binarize(x)
1023
+ shrink_maps = mx.transpose(shrink_maps, (0, 3, 1, 2))
1024
+ return {"maps": shrink_maps}
1025
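+
+ # Illustrative sketch (hypothetical helper): step_function is DB's differentiable binarization,
+ # a steep sigmoid 1 / (1 + exp(-k * (x - y))) with k=50 that approximates the hard threshold
+ # x > y while staying differentiable; at inference only the shrink ("binarize") map is returned.
+ def _demo_db_step(k=50.0):
+     x = mx.array([0.2, 0.5, 0.8])   # predicted shrink probabilities
+     y = mx.array([0.5, 0.5, 0.5])   # threshold map
+     return 1.0 / (1.0 + mx.exp(-k * (x - y)))   # ~[0.0, 0.5, 1.0]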
+
1026
+
1027
+ class TextDetector(nn.Module):
1028
+ def __init__(self, args):
1029
+ super().__init__()
1030
+
1031
+ self.preprocess_op = [
1032
+ DetResizeForTest(
1033
+ limit_side_len=args.det_limit_side_len, limit_type=args.det_limit_type
1034
+ ),
1035
+ NormalizeImage(
1036
+ mean=[0.485, 0.456, 0.406],
1037
+ std=[0.229, 0.224, 0.225],
1038
+ scale=1.0 / 255.0,
1039
+ order="hwc",
1040
+ ),
1041
+ ToCHWImage(),
1042
+ KeepKeys(keep_keys=["image", "shape"]),
1043
+ ]
1044
+
1045
+ postprocess_params = {
1046
+ "thresh": args.det_db_thresh,
1047
+ "box_thresh": args.det_db_box_thresh,
1048
+ "max_candidates": 1000,
1049
+ "unclip_ratio": args.det_db_unclip_ratio,
1050
+ "use_dilation": args.use_dilation,
1051
+ "score_mode": args.det_db_score_mode,
1052
+ }
1053
+ self.postprocess_op = DBPostProcess(**postprocess_params)
1054
+
1055
+ # Match exact PyTorch model structure
1056
+ backbone_config = {"scale": 0.75, "det": True, "in_channels": 3}
1057
+ self.backbone = PPLCNetV3(**backbone_config)
1058
+
1059
+ # Use correct neck config - the backbone outputs these channels
1060
+ neck_config = {
1061
+ "out_channels": 96,
1062
+ "shortcut": True,
1063
+ "in_channels": self.backbone.out_channels, # Should be [12, 18, 42, 360]
1064
+ }
1065
+ self.neck = RSEFPN(**neck_config)
1066
+
1067
+ head_config = {"k": 50, "in_channels": 96}
1068
+ self.head = DBHead(**head_config)
1069
+
1070
+ def order_points_clockwise(self, pts):
1071
+ """
1072
+ reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
1073
+ # sort the points based on their x-coordinates
1074
+ """
1075
+ xSorted = pts[np.argsort(pts[:, 0]), :]
1076
+
1077
+ # grab the left-most and right-most points from the sorted
1078
+ # x-coordinate points
1079
+ leftMost = xSorted[:2, :]
1080
+ rightMost = xSorted[2:, :]
1081
+
1082
+ # now, sort the left-most coordinates according to their
1083
+ # y-coordinates so we can grab the top-left and bottom-left
1084
+ # points, respectively
1085
+ leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
1086
+ (tl, bl) = leftMost
1087
+
1088
+ rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
1089
+ (tr, br) = rightMost
1090
+
1091
+ rect = np.array([tl, tr, br, bl], dtype="float32")
1092
+ return rect
1093
+
1094
+ def clip_det_res(self, points, img_height, img_width):
1095
+ for pno in range(points.shape[0]):
1096
+ points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
1097
+ points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
1098
+ return points
1099
+
1100
+ def filter_tag_det_res(self, dt_boxes, image_shape):
1101
+ img_height, img_width = image_shape[0:2]
1102
+ dt_boxes_new = []
1103
+ for box in dt_boxes:
1104
+ box = self.order_points_clockwise(box)
1105
+ box = self.clip_det_res(box, img_height, img_width)
1106
+ rect_width = int(np.linalg.norm(box[0] - box[1]))
1107
+ rect_height = int(np.linalg.norm(box[0] - box[3]))
1108
+ if rect_width <= 3 or rect_height <= 3:
1109
+ continue
1110
+ dt_boxes_new.append(box)
1111
+ return np.array(dt_boxes_new) if dt_boxes_new else np.array([])
1112
+
1113
+ def forward(self, x):
1114
+ features = self.backbone(x)
1115
+ neck_out = self.neck(features)
1116
+ head_out = self.head(neck_out)
1117
+ return head_out
1118
+
1119
+ def __call__(self, img):
1120
+ ori_im = img.copy()
1121
+ data = {"image": img}
1122
+
1123
+ for op in self.preprocess_op:
1124
+ data = op(data)
1125
+
1126
+ img, shape_list = data
1127
+ if img is None:
1128
+ return None  # keep the return type consistent with the dt_boxes path below
1129
+
1130
+ img = np.expand_dims(img, axis=0)
1131
+ shape_list = np.expand_dims(shape_list, axis=0)
1132
+
1133
+ inp = mx.array(img.copy())
1134
+ outputs = self.forward(inp)
1135
+ preds = {"maps": np.array(outputs["maps"])}
1136
+
1137
+ post_result = self.postprocess_op(preds, shape_list)
1138
+ dt_boxes = post_result[0]["points"] if post_result else []
1139
+ dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
1140
+ return dt_boxes
1141
+
1142
+
1143
+ def test_detector(args):
1144
+ img = np.load(
1145
+ "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/det_inp.npy"
1146
+ )
1147
+ detector = TextDetector(args)
1148
+ detector.eval()
1149
+ detector.load_weights(
1150
+ "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/ch_ptocr_v4_det_infer.safetensors"
1151
+ )
1152
+ boxes = detector(img)
1153
+ print(f"Detected {len(boxes)} boxes")
1154
+
1155
+
1156
+ ## ==================================== End of Text Detection Components ==================================== #
1157
+
1158
+
1159
+ ## ==================================== Text Recognition Components ==================================== #
1160
+
1161
+
1162
+ class Im2Seq(nn.Module):
1163
+ def __init__(self, in_channels, **kwargs):
1164
+ super().__init__()
1165
+ self.out_channels = in_channels
1166
+
1167
+ def __call__(self, x):
1168
+ B, H, W, C = x.shape # MLX format: (B, H, W, C)
1169
+ assert H == 1
1170
+ x = mx.reshape(x, (B, H * W, C)) # (B, W, C) for sequence
1171
+ return x
1172
+
1173
+
1174
+ class SVTRConvBNLayer(nn.Module):
1175
+ def __init__(
1176
+ self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, groups=1, act="swish"
1177
+ ):
1178
+ super().__init__()
1179
+ self.conv = nn.Conv2d(
1180
+ in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False
1181
+ )
1182
+ self.norm = nn.BatchNorm(out_channels)
1183
+ self.act = act
1184
+
1185
+ def __call__(self, x):
1186
+ x = self.conv(x)
1187
+ x = self.norm(x)
1188
+ if self.act == "swish":
1189
+ x = x * mx.sigmoid(x)
1190
+ return x
1191
+
1192
+
1193
+ class EncoderWithSVTR(nn.Module):
1194
+ def __init__(
1195
+ self,
1196
+ in_channels,
1197
+ dims=64,
1198
+ depth=2,
1199
+ hidden_dims=120,
1200
+ kernel_size=[3, 3],
1201
+ use_guide=False,
1202
+ **kwargs,
1203
+ ):
1204
+ super().__init__()
1205
+ self.depth = depth
1206
+ self.use_guide = use_guide
1207
+
1208
+ # Match original PyTorch structure exactly
1209
+ self.conv1 = SVTRConvBNLayer(
1210
+ in_channels,
1211
+ in_channels // 8,
1212
+ kernel_size=(1, 3), # Match actual model: (1, 3) not 3
1213
+ padding=(0, 1), # Match actual model: (0, 1) not 1
1214
+ act="swish",
1215
+ )
1216
+ self.conv2 = SVTRConvBNLayer(
1217
+ in_channels // 8, hidden_dims, kernel_size=1, padding=0, act="swish"
1218
+ )
1219
+
1220
+ # SVTR mixing blocks (global attention over the flattened feature map)
1221
+ self.svtr_block = []
1222
+ for i in range(depth):
1223
+ block = Block(
1224
+ dim=hidden_dims,
1225
+ num_heads=8,
1226
+ mixer="Global",
1227
+ mlp_ratio=2.0,
1228
+ qkv_bias=True, # match the pretrained recognition weights
1229
+ act_layer="swish", # SVTR MLP uses swish
1230
+ **kwargs,
1231
+ )
1232
+ setattr(self, f"svtr_block_{i}", block)
1233
+ self.svtr_block.append(block)
1234
+
1235
+ self.norm = nn.LayerNorm(hidden_dims)
1236
+
1237
+ self.conv3 = SVTRConvBNLayer(
1238
+ hidden_dims, in_channels, kernel_size=1, padding=0, act="swish"
1239
+ )
1240
+ self.conv4 = SVTRConvBNLayer(
1241
+ 2 * in_channels, in_channels // 8, kernel_size=3, padding=1, act="swish"
1242
+ )
1243
+ self.conv1x1 = SVTRConvBNLayer(
1244
+ in_channels // 8, dims, kernel_size=1, padding=0, act="swish"
1245
+ )
1246
+
1247
+ self.out_channels = dims
1248
+
1249
+ def __call__(self, x):
1250
+ # Short cut
1251
+ h = x
1252
+
1253
+ # Reduce dim
1254
+ z = self.conv1(x)
1255
+ z = self.conv2(z)
1256
+
1257
+ # SVTR global blocks
1258
+ B, H, W, C = z.shape
1259
+ z = mx.reshape(z, (B, H * W, C)) # Flatten spatial dims
1260
+
1261
+ for block in self.svtr_block:
1262
+ z = block(z)
1263
+
1264
+ z = self.norm(z)
1265
+
1266
+ # Reshape back - CRITICAL: use original H, W
1267
+ z = mx.reshape(z, (B, H, W, C)) # Use the H, W from before SVTR blocks
1268
+ z = self.conv3(z)
1269
+
1270
+ # Concatenate with shortcut - dimensions should match now
1271
+ z = mx.concatenate([h, z], axis=-1)
1272
+ z = self.conv4(z)
1273
+ z = self.conv1x1(z)
1274
+
1275
+ return z
1276
+
1277
+
1278
+ class Mlp(nn.Module):
1279
+ def __init__(
1280
+ self, in_features, hidden_features=None, out_features=None, act_layer="swish", drop=0.0
1281
+ ):
1282
+ super().__init__()
1283
+ out_features = out_features or in_features
1284
+ hidden_features = hidden_features or in_features
1285
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True) # Add bias=True
1286
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True) # Add bias=True
1287
+ self.act_layer = act_layer
1288
+
1289
+ def __call__(self, x):
1290
+ x = self.fc1(x)
1291
+ # Use swish activation to match PyTorch
1292
+ if self.act_layer == "swish":
1293
+ x = x * mx.sigmoid(x) # Swish activation
1294
+ elif self.act_layer == "gelu":
1295
+ x = nn.gelu(x)
1296
+ x = self.fc2(x)
1297
+ return x
1298
+
1299
+
1300
+ class Attention(nn.Module):
1301
+ def __init__(
1302
+ self,
1303
+ dim,
1304
+ num_heads=8,
1305
+ mixer="Global",
1306
+ HW=None,
1307
+ local_k=[7, 11],
1308
+ qkv_bias=False,
1309
+ qk_scale=None,
1310
+ attn_drop=0.0,
1311
+ proj_drop=0.0,
1312
+ ):
1313
+ super().__init__()
1314
+ self.num_heads = num_heads
1315
+ head_dim = dim // num_heads
1316
+ self.scale = qk_scale or head_dim**-0.5
1317
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
1318
+ self.proj = nn.Linear(dim, dim, bias=True)
1319
+ self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0 else nn.Identity()
1320
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0 else nn.Identity()
1321
+ self.HW = HW
1322
+ self.mixer = mixer
1323
+
1324
+ # Set N and C if HW is provided (like in PyTorch)
1325
+ if HW is not None:
1326
+ H = HW[0]
1327
+ W = HW[1]
1328
+ self.N = H * W
1329
+ self.C = dim
1330
+
1331
+ def __call__(self, x):
1332
+ if self.HW is not None:
1333
+ N = self.N
1334
+ C = self.C
1335
+ else:
1336
+ _, N, C = x.shape
1337
+
1338
+ qkv = self.qkv(x)
1339
+ qkv = qkv.reshape((-1, N, 3, self.num_heads, C // self.num_heads))
1340
+ qkv = mx.transpose(qkv, (2, 0, 3, 1, 4)) # permute(2, 0, 3, 1, 4)
1341
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
1342
+
1343
+ attn = q @ mx.transpose(k, (0, 1, 3, 2)) # q.matmul(k.permute(0, 1, 3, 2))
1344
+ if self.mixer == "Local":
1345
+ # attn += self.mask # Would need to implement mask for Local
1346
+ pass
1347
+ attn = mx.softmax(attn, axis=-1) # nn.functional.softmax(attn, dim=-1)
1348
+ attn = self.attn_drop(attn)
1349
+
1350
+ x = (attn @ v).transpose(0, 2, 1, 3).reshape((-1, N, C)) # Match exact reshape
1351
+ x = self.proj(x)
1352
+ x = self.proj_drop(x)
1353
+ return x
1354
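+
+ # Shape sketch (hypothetical helper, toy dims): the "Global" mixer is plain multi-head
+ # self-attention, softmax(Q K^T / sqrt(d_head)) V over a (B, N, C) token sequence, projected
+ # back to C channels.
+ def _demo_global_attention():
+     attn = Attention(dim=16, num_heads=4, qkv_bias=True)
+     tokens = mx.zeros((2, 10, 16))   # (batch, tokens, channels)
+     return attn(tokens).shape        # -> (2, 10, 16)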
+
1355
+
1356
+ class Block(nn.Module):
1357
+ def __init__(
1358
+ self,
1359
+ dim,
1360
+ num_heads,
1361
+ mixer="Global",
1362
+ local_mixer=[7, 11],
1363
+ HW=None,
1364
+ mlp_ratio=4.0,
1365
+ qkv_bias=False,
1366
+ qk_scale=None,
1367
+ drop=0.0,
1368
+ attn_drop=0.0,
1369
+ drop_path=0.0,
1370
+ act_layer="gelu",
1371
+ norm_layer="nn.LayerNorm",
1372
+ epsilon=1e-6,
1373
+ prenorm=False, # Set to False to match PyTorch
1374
+ ):
1375
+ super().__init__()
1376
+ self.norm1 = nn.LayerNorm(dim, eps=epsilon)
1377
+ self.mixer = Attention(
1378
+ dim,
1379
+ num_heads=num_heads,
1380
+ mixer=mixer,
1381
+ HW=HW,
1382
+ local_k=local_mixer,
1383
+ qkv_bias=qkv_bias,
1384
+ qk_scale=qk_scale,
1385
+ attn_drop=attn_drop,
1386
+ proj_drop=drop,
1387
+ )
1388
+
1389
+ self.norm2 = nn.LayerNorm(dim, eps=epsilon)
1390
+ mlp_hidden_dim = int(dim * mlp_ratio)
1391
+ self.mlp = Mlp(
1392
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
1393
+ )
1394
+ self.prenorm = prenorm
1395
+ self.drop_path = drop_path
1396
+
1397
+ def __call__(self, x):
1398
+ if self.prenorm:
1399
+ x = self.norm1(x + self._drop_path(self.mixer(x)))
1400
+ x = self.norm2(x + self._drop_path(self.mlp(x)))
1401
+ else:
1402
+ # This is the path that will be taken (prenorm=False)
1403
+ x = x + self._drop_path(self.mixer(self.norm1(x)))
1404
+ x = x + self._drop_path(self.mlp(self.norm2(x)))
1405
+ return x
1406
+
1407
+ def _drop_path(self, x):
1408
+ # For inference, drop_path is disabled, so just return x
1409
+ return x
1410
+
1411
+
1412
+ class SequenceEncoder(nn.Module):
1413
+ def __init__(self, in_channels, encoder_type="svtr", **kwargs):
1414
+ super().__init__()
1415
+ self.encoder_type = encoder_type.lower()
1416
+ self.encoder_reshape = Im2Seq(in_channels)
1417
+
1418
+ if self.encoder_type == "svtr":
1419
+ self.encoder = EncoderWithSVTR(in_channels, **kwargs)
1420
+ self.out_channels = self.encoder.out_channels
1421
+ self.only_reshape = False
1422
+ else:
1423
+ self.out_channels = in_channels
1424
+ self.only_reshape = True
1425
+
1426
+ def __call__(self, x):
1427
+ if self.encoder_type == "svtr":
1428
+ # For SVTR: encoder works on 2D data first, then reshape
1429
+ x = self.encoder(x) # x is still (B, H, W, C)
1430
+ x = self.encoder_reshape(x) # Now reshape to (B, W, C)
1431
+ return x
1432
+ else:
1433
+ # For others: reshape first, then encoder
1434
+ x = self.encoder_reshape(x)
1435
+ if not self.only_reshape:
1436
+ x = self.encoder(x)
1437
+ return x
1438
+
1439
+
1440
+ class CTCHead(nn.Module):
1441
+ def __init__(
1442
+ self,
1443
+ in_channels,
1444
+ out_channels,
1445
+ fc_decay=0.0004,
1446
+ mid_channels=None,
1447
+ return_feats=False,
1448
+ **kwargs,
1449
+ ):
1450
+ super().__init__()
1451
+ self.return_feats = return_feats
1452
+ self.mid_channels = mid_channels
1453
+
1454
+ if mid_channels is None:
1455
+ self.fc = nn.Linear(in_channels, out_channels, bias=True)
1456
+ else:
1457
+ self.fc1 = nn.Linear(in_channels, mid_channels, bias=True)
1458
+ self.fc2 = nn.Linear(mid_channels, out_channels, bias=True)
1459
+
1460
+ self.out_channels = out_channels
1461
+
1462
+ def __call__(self, x):
1463
+ if self.mid_channels is None:
1464
+ predicts = self.fc(x)
1465
+ else:
1466
+ x = self.fc1(x)
1467
+ predicts = self.fc2(x)
1468
+
1469
+ if self.return_feats:
1470
+ result = (x, predicts)
1471
+ else:
1472
+ result = predicts
1473
+
1474
+ # Apply softmax for inference using MLX
1475
+ if not self.training:
1476
+ predicts = mx.softmax(predicts, axis=2)
1477
+ result = predicts
1478
+
1479
+ return result
1480
+
1481
+
1482
+ class MultiHead(nn.Module):
1483
+ def __init__(self, in_channels, out_channels_list, head_list, **kwargs):
1484
+ super().__init__()
1485
+ self.head_list = head_list
1486
+
1487
+ for idx, head_name in enumerate(self.head_list):
1488
+ name = list(head_name)[0]
1489
+ if name == "CTCHead":
1490
+ # No separate encoder_reshape - it's handled inside SequenceEncoder
1491
+ neck_args = self.head_list[idx][name]["Neck"].copy()
1492
+ encoder_type = neck_args.pop("name")
1493
+ self.ctc_encoder = SequenceEncoder(
1494
+ in_channels=in_channels, encoder_type=encoder_type, **neck_args
1495
+ )
1496
+ # CTC head
1497
+ head_args = self.head_list[idx][name].get("Head", {})
1498
+ if head_args is None:
1499
+ head_args = {}
1500
+ self.ctc_head = CTCHead(
1501
+ in_channels=self.ctc_encoder.out_channels,
1502
+ out_channels=out_channels_list["CTCLabelDecode"],
1503
+ **head_args,
1504
+ )
1505
+
1506
+ def __call__(self, x, data=None):
1507
+ # Direct call to ctc_encoder - let it handle reshaping internally
1508
+ ctc_encoder = self.ctc_encoder(x)
1509
+ ctc_out = self.ctc_head(ctc_encoder)
1510
+
1511
+ # Eval mode
1512
+ if not self.training:
1513
+ return ctc_out
1514
+
1515
+ head_out = dict()
1516
+ head_out["ctc"] = ctc_out
1517
+ head_out["res"] = ctc_out
1518
+ head_out["ctc_neck"] = ctc_encoder
1519
+ return head_out
1520
+
1521
+
1522
+ class TextRecognizer(nn.Module):
1523
+ def __init__(self, args, **kwargs):
1524
+ super().__init__()
1525
+
1526
+ self.rec_image_shape = [3, 48, 320]
1527
+ self.rec_batch_num = args.rec_batch_num
1528
+ self.limited_max_width = args.limited_max_width
1529
+ self.limited_min_width = args.limited_min_width
1530
+
1531
+ # Character dictionary path
1532
+ postprocess_params = {
1533
+ "character_type": args.rec_char_type,
1534
+ "character_dict_path": args.rec_char_dict_path,
1535
+ "use_space_char": args.use_space_char,
1536
+ }
1537
+ self.postprocess_op = CTCLabelDecode(**postprocess_params)
1538
+
1539
+ # Get character number
1540
+ char_num = len(getattr(self.postprocess_op, "character"))
1541
+
1542
+ # Recognition backbone - reuse existing PPLCNetV3 (already handles transpose)
1543
+ self.backbone = PPLCNetV3(scale=0.95, det=False)
1544
+
1545
+ # Recognition head
1546
+ head_config = {
1547
+ "head_list": [
1548
+ {
1549
+ "CTCHead": {
1550
+ "Neck": {
1551
+ "name": "svtr",
1552
+ "dims": 120,
1553
+ "depth": 2,
1554
+ "hidden_dims": 120,
1555
+ "kernel_size": [1, 3],
1556
+ "use_guide": True,
1557
+ },
1558
+ "Head": {"fc_decay": 1e-05},
1559
+ }
1560
+ },
1561
+ ],
1562
+ "out_channels_list": {
1563
+ "CTCLabelDecode": char_num,
1564
+ },
1565
+ "in_channels": 480, # PPLCNetV3 output channels
1566
+ }
1567
+ self.head = MultiHead(**head_config)
1568
+
1569
+ def resize_norm_img(self, img, max_wh_ratio):
1570
+ imgC, imgH, imgW = self.rec_image_shape
1571
+
1572
+ assert imgC == img.shape[2]
1573
+ max_wh_ratio = max(max_wh_ratio, imgW / imgH)
1574
+ imgW = int((imgH * max_wh_ratio))
1575
+ imgW = max(min(imgW, self.limited_max_width), self.limited_min_width)
1576
+ h, w = img.shape[:2]
1577
+ ratio = w / float(h)
1578
+ ratio_imgH = int(np.ceil(imgH * ratio))
1579
+ ratio_imgH = max(ratio_imgH, self.limited_min_width)
1580
+ if ratio_imgH > imgW:
1581
+ resized_w = imgW
1582
+ else:
1583
+ resized_w = int(ratio_imgH)
1584
+
1585
+ resized_image = cv2.resize(img, (resized_w, imgH))
1586
+ resized_image = resized_image.astype("float32")
1587
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
1588
+ resized_image -= 0.5
1589
+ resized_image /= 0.5
1590
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
1591
+ padding_im[:, :, 0:resized_w] = resized_image
1592
+ return padding_im
1593
+
1594
+ def __call__(self, img_list):
1595
+ img_num = len(img_list)
1596
+ # Calculate aspect ratio and sort for batching efficiency
1597
+ width_list = []
1598
+ for img in img_list:
1599
+ width_list.append(img.shape[1] / float(img.shape[0]))
1600
+ indices = np.argsort(np.array(width_list))
1601
+
1602
+ rec_res = [["", 0.0]] * img_num
1603
+ batch_num = self.rec_batch_num
1604
+ elapse = 0
1605
+
1606
+ for beg_img_no in range(0, img_num, batch_num):
1607
+ end_img_no = min(img_num, beg_img_no + batch_num)
1608
+ norm_img_batch = []
1609
+ max_wh_ratio = 0
1610
+
1611
+ # Calculate max width/height ratio for this batch
1612
+ for ino in range(beg_img_no, end_img_no):
1613
+ h, w = img_list[indices[ino]].shape[0:2]
1614
+ wh_ratio = w * 1.0 / h
1615
+ max_wh_ratio = max(max_wh_ratio, wh_ratio)
1616
+
1617
+ # Normalize images in batch
1618
+ for ino in range(beg_img_no, end_img_no):
1619
+ norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
1620
+ norm_img = norm_img[np.newaxis, :]
1621
+ norm_img_batch.append(norm_img)
1622
+
1623
+ norm_img_batch = np.concatenate(norm_img_batch)
1624
+
1625
+ starttime = time.time()
1626
+
1627
+ # Forward pass
1628
+ inp = mx.array(norm_img_batch)
1629
+ # PPLCNetV3 backbone already handles the transpose from (B, C, H, W) to (B, H, W, C)
1630
+ backbone_out = self.backbone(inp)
1631
+ head_out = self.head(backbone_out)
1632
+
1633
+ preds = np.array(head_out)
1634
+ rec_result = self.postprocess_op(preds)
1635
+ for rno in range(len(rec_result)):
1636
+ rec_res[indices[beg_img_no + rno]] = rec_result[rno]
1637
+ elapse += time.time() - starttime
1638
+
1639
+ return rec_res, elapse
1640
+
1641
+
1642
+ def test_recognizer(args):
1643
+ loaded = np.load(
1644
+ "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/rec_input.npz"
1645
+ )
1646
+ img_list = [loaded[f"arr_{i}"] for i in range(len(loaded.files))]
1647
+ recognizer = TextRecognizer(args)
1648
+ # recognizer.load_weights(
1649
+ # "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/ch_ptocr_v4_rec_infer.safetensors"
1650
+ # )
1651
+ # recognizer.save_weights(
1652
+ # "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/ch_ptocr_v4_rec_infer.safetensors"
1653
+ # )
1654
+ # recognizer.set_dtype(mx.float16)
1655
+ # recognizer.save_weights(
1656
+ # "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/ch_ptocr_v4_rec_infer_f16.safetensors"
1657
+ # )
1658
+ recognizer.load_weights(
1659
+ "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/ch_ptocr_v4_rec_infer_f16.safetensors"
1660
+ )
1661
+ recognizer.eval() # Important for BatchNorm behavior in MLX
1662
+
1663
+ rec_res, elapse = recognizer(img_list)
1664
+ print(f"Recognition results: {rec_res}")
1665
+ print(f"Recognition time: {elapse:.3f}s")
1666
+
1667
+
1668
+ class TextSystem:
1669
+ """OCR text detection and recognition system"""
1670
+ def __init__(self, args):
1671
+ self.det = TextDetector(args)
1672
+ self.rec = TextRecognizer(args)
1673
+ self.drop_score = args.drop_score
1674
+
1675
+ # Load weights from safetensors
1676
+ self.det.load_weights(args.det_model_path)
1677
+ self.rec.load_weights(args.rec_model_path)
1678
+
1679
+ self.det.eval()
1680
+ self.rec.eval()
1681
+
1682
+ @staticmethod
1683
+ def _order_boxes(boxes: np.ndarray) -> List[np.ndarray]:
1684
+ """Order detected boxes by position (top to bottom, left to right)"""
1685
+ return sorted(boxes, key=lambda b: (b[0][1], b[0][0]))
1686
+
1687
+ @staticmethod
1688
+ def _crop_rotated(img: np.ndarray, pts: np.ndarray) -> np.ndarray:
1689
+ """Crop rotated text region from image"""
1690
+ pts = pts.astype("float32")
1691
+ w = int(max(np.linalg.norm(pts[0] - pts[1]), np.linalg.norm(pts[2] - pts[3])))
1692
+ h = int(max(np.linalg.norm(pts[0] - pts[3]), np.linalg.norm(pts[1] - pts[2])))
1693
+ M = cv2.getPerspectiveTransform(
1694
+ pts, np.array([[0, 0], [w, 0], [w, h], [0, h]], dtype="float32")
1695
+ )
1696
+ dst = cv2.warpPerspective(img, M, (w, h), borderMode=cv2.BORDER_REPLICATE)
1697
+ if h / max(w, 1) > 1.5:
1698
+ dst = np.rot90(dst)
1699
+ return dst
1700
+
1701
+ def __call__(self, img: np.ndarray) -> Tuple[List[np.ndarray], List[Tuple[str, float]]]:
1702
+ """Perform OCR on input image"""
1703
+ boxes = self.det(img)
1704
+ if boxes is None or len(boxes) == 0:
1705
+ return [], []
1706
+
1707
+ boxes = self._order_boxes(boxes)
1708
+ crops = [self._crop_rotated(img, b.copy()) for b in boxes]
1709
+
1710
+ rec_res, _ = self.rec(crops)
1711
+
1712
+ keep_boxes, keep_txt = [], []
1713
+ for box, (txt, score) in zip(boxes, rec_res):
1714
+ if score >= self.drop_score:
1715
+ keep_boxes.append(box)
1716
+ keep_txt.append((txt, float(score)))
1717
+ return keep_boxes, keep_txt
1718
+
1719
+
1720
+ if __name__ == "__main__":
1721
+ # Config requires the model directory; take it from argv[1] or fall back to the script's directory
+ config = Config(sys.argv[1] if len(sys.argv) > 1 else os.path.abspath(os.path.dirname(__file__)))
1722
+ text_system = TextSystem(config)
1723
+ # Test with a sample image from model directory if available
1724
+ img_path = os.path.join(config.model_cache_dir, "1.jpg")
1725
+ if not os.path.exists(img_path):
1726
+ print("No test image found. Please provide an image path for testing.")
1727
+ sys.exit(1)
1728
+
1729
+ img = cv2.imread(img_path)
1730
+ if img is None:
1731
+ print(f"Error: Could not read image at {img_path}")
1732
+ sys.exit(1)
1733
+
1734
+ boxes, txts = text_system(img)
1735
+ print(f"Detected {len(boxes)} boxes")
1736
+ print(f"Recognized text: {txts}")