nexaai 1.0.19rc6-cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
- nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/mlx_backend/vlm/interface.py +21 -4
- nexaai/mlx_backend/vlm/main.py +6 -2
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/utils/manifest_utils.py +222 -15
- nexaai/utils/model_manager.py +83 -7
- nexaai/utils/model_types.py +2 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +224 -24
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0
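The largest single addition in this release is the bundled MLX OCR module `nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py` (+1736 lines); the portion of its diff rendered on this page is reproduced below. Its detection preprocessing resizes an image so the longest side respects a limit and both sides land on multiples of 32 before normalization. As a rough orientation only, here is a minimal standalone sketch of that resize rule (NumPy/OpenCV; the function name `resize_for_db_detector` is ours and is not code from the wheel):

    # Minimal sketch (not from the wheel): the "limit the longest side, then snap
    # H and W to multiples of 32" rule used by DetResizeForTest in the diff below.
    import cv2
    import numpy as np

    def resize_for_db_detector(img: np.ndarray, limit_side_len: int = 960):
        h, w = img.shape[:2]
        ratio = min(1.0, float(limit_side_len) / max(h, w))  # "max" limit type
        resize_h = max(int(round(h * ratio / 32) * 32), 32)
        resize_w = max(int(round(w * ratio / 32) * 32), 32)
        resized = cv2.resize(img, (resize_w, resize_h))
        return resized, (resize_h / float(h), resize_w / float(w))

    # Example: a 720x1280 frame becomes 544x960 (both sides multiples of 32).
    dummy = np.zeros((720, 1280, 3), dtype=np.uint8)
    out, (ratio_h, ratio_w) = resize_for_db_detector(dummy)
    print(out.shape, ratio_h, ratio_w)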
nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py
@@ -0,0 +1,1736 @@
+#!/usr/bin/env python3
+
+import sys
+import time
+import os
+import shutil
+import math
+from pathlib import Path
+from typing import List, Tuple
+import cv2
+import numpy as np
+from PIL import Image
+from shapely.geometry import Polygon
+import pyclipper
+
+import mlx.core as mx
+import mlx.nn as nn
+
+## =============================== PREPROCESSING CLASSES =============================== #
+
+
+class DetResizeForTest(object):
+    def __init__(self, **kwargs):
+        super(DetResizeForTest, self).__init__()
+        self.resize_type = 0
+        if "image_shape" in kwargs:
+            self.image_shape = kwargs["image_shape"]
+            self.resize_type = 1
+        elif "limit_side_len" in kwargs:
+            self.limit_side_len = kwargs["limit_side_len"]
+            self.limit_type = kwargs.get("limit_type", "min")
+        elif "resize_long" in kwargs:
+            self.resize_type = 2
+            self.resize_long = kwargs.get("resize_long", 960)
+        else:
+            self.limit_side_len = 736
+            self.limit_type = "min"
+
+    def __call__(self, data):
+        img = data["image"]
+        src_h, src_w, _ = img.shape
+
+        if self.resize_type == 0:
+            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+        elif self.resize_type == 2:
+            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
+        else:
+            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
+        data["image"] = img
+        data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
+        return data
+
+    def resize_image_type1(self, img):
+        resize_h, resize_w = self.image_shape
+        ori_h, ori_w = img.shape[:2]
+        ratio_h = float(resize_h) / ori_h
+        ratio_w = float(resize_w) / ori_w
+        img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        return img, [ratio_h, ratio_w]
+
+    def resize_image_type0(self, img):
+        limit_side_len = self.limit_side_len
+        h, w, c = img.shape
+
+        if self.limit_type == "max":
+            if max(h, w) > limit_side_len:
+                if h > w:
+                    ratio = float(limit_side_len) / h
+                else:
+                    ratio = float(limit_side_len) / w
+            else:
+                ratio = 1.0
+        elif self.limit_type == "min":
+            if min(h, w) < limit_side_len:
+                if h < w:
+                    ratio = float(limit_side_len) / h
+                else:
+                    ratio = float(limit_side_len) / w
+            else:
+                ratio = 1.0
+        elif self.limit_type == "resize_long":
+            ratio = float(limit_side_len) / max(h, w)
+        else:
+            raise Exception("not support limit type, image ")
+        resize_h = int(h * ratio)
+        resize_w = int(w * ratio)
+
+        resize_h = max(int(round(resize_h / 32) * 32), 32)
+        resize_w = max(int(round(resize_w / 32) * 32), 32)
+
+        try:
+            if int(resize_w) <= 0 or int(resize_h) <= 0:
+                return None, (None, None)
+            img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        except:
+            print(img.shape, resize_w, resize_h)
+            sys.exit(0)
+        ratio_h = resize_h / float(h)
+        ratio_w = resize_w / float(w)
+        return img, [ratio_h, ratio_w]
+
+    def resize_image_type2(self, img):
+        h, w, _ = img.shape
+        resize_w = w
+        resize_h = h
+
+        if resize_h > resize_w:
+            ratio = float(self.resize_long) / resize_h
+        else:
+            ratio = float(self.resize_long) / resize_w
+
+        resize_h = int(resize_h * ratio)
+        resize_w = int(resize_w * ratio)
+
+        max_stride = 128
+        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+        img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        ratio_h = resize_h / float(h)
+        ratio_w = resize_w / float(w)
+
+        return img, [ratio_h, ratio_w]
+
+
+class NormalizeImage(object):
+    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+        if isinstance(scale, str):
+            scale = eval(scale)
+        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        mean = mean if mean is not None else [0.485, 0.456, 0.406]
+        std = std if std is not None else [0.229, 0.224, 0.225]
+
+        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype("float32")
+        self.std = np.array(std).reshape(shape).astype("float32")
+
+    def __call__(self, data):
+        img = data["image"]
+        from PIL import Image
+
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+        data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
+        return data
+
+
+class ToCHWImage(object):
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, data):
+        img = data["image"]
+        from PIL import Image
+
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+        data["image"] = img.transpose((2, 0, 1))
+        return data
+
+
+class KeepKeys(object):
+    def __init__(self, keep_keys, **kwargs):
+        self.keep_keys = keep_keys
+
+    def __call__(self, data):
+        data_list = []
+        for key in self.keep_keys:
+            data_list.append(data[key])
+        return data_list
+
+
+## =============================== POSTPROCESSING CLASSES =============================== #
+
+
+class DBPostProcess(object):
+    def __init__(
+        self,
+        thresh=0.3,
+        box_thresh=0.7,
+        max_candidates=1000,
+        unclip_ratio=2.0,
+        use_dilation=False,
+        score_mode="fast",
+        **kwargs,
+    ):
+        self.thresh = thresh
+        self.box_thresh = box_thresh
+        self.max_candidates = max_candidates
+        self.unclip_ratio = unclip_ratio
+        self.min_size = 3
+        self.score_mode = score_mode
+        assert score_mode in [
+            "slow",
+            "fast",
+        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
+        self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
+
+    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+        bitmap = _bitmap
+        height, width = bitmap.shape
+
+        outs = cv2.findContours(
+            (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
+        )
+        if len(outs) == 3:
+            img, contours, _ = outs[0], outs[1], outs[2]
+        elif len(outs) == 2:
+            contours, _ = outs[0], outs[1]
+
+        num_contours = min(len(contours), self.max_candidates)
+
+        boxes = []
+        scores = []
+        for index in range(num_contours):
+            contour = contours[index]
+            points, sside = self.get_mini_boxes(contour)
+            if sside < self.min_size:
+                continue
+            points = np.array(points)
+            if self.score_mode == "fast":
+                score = self.box_score_fast(pred, points.reshape(-1, 2))
+            else:
+                score = self.box_score_slow(pred, contour)
+            if self.box_thresh > score:
+                continue
+
+            box = self.unclip(points).reshape(-1, 1, 2)
+            box, sside = self.get_mini_boxes(box)
+            if sside < self.min_size + 2:
+                continue
+            box = np.array(box)
+
+            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
+            box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
+            boxes.append(box.astype(np.int16))
+            scores.append(score)
+        return np.array(boxes, dtype=np.int16), scores
+
+    def unclip(self, box):
+        unclip_ratio = self.unclip_ratio
+        poly = Polygon(box)
+        distance = poly.area * unclip_ratio / poly.length
+        offset = pyclipper.PyclipperOffset()
+        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+        expanded = np.array(offset.Execute(distance))
+        return expanded
+
+    def get_mini_boxes(self, contour):
+        bounding_box = cv2.minAreaRect(contour)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_1 = 0
+            index_4 = 1
+        else:
+            index_1 = 1
+            index_4 = 0
+        if points[3][1] > points[2][1]:
+            index_2 = 2
+            index_3 = 3
+        else:
+            index_2 = 3
+            index_3 = 2
+
+        box = [points[index_1], points[index_2], points[index_3], points[index_4]]
+        return box, min(bounding_box[1])
+
+    def box_score_fast(self, bitmap, _box):
+        h, w = bitmap.shape[:2]
+        box = _box.copy()
+        xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
+        xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
+        ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
+        ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)
+
+        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+        box[:, 0] = box[:, 0] - xmin
+        box[:, 1] = box[:, 1] - ymin
+        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
+        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
+
+    def box_score_slow(self, bitmap, contour):
+        h, w = bitmap.shape[:2]
+        contour = contour.copy()
+        contour = np.reshape(contour, (-1, 2))
+
+        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
+        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
+        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
+        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
+
+        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+        contour[:, 0] = contour[:, 0] - xmin
+        contour[:, 1] = contour[:, 1] - ymin
+        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
+        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
+
+    def __call__(self, outs_dict, shape_list):
+        pred = outs_dict["maps"]
+        if hasattr(pred, "numpy"):  # Check if it has numpy method (for torch tensors)
+            pred = pred.numpy()
+        elif isinstance(pred, mx.array):  # For MLX arrays
+            pred = np.array(pred)
+        pred = pred[:, 0, :, :]
+        segmentation = pred > self.thresh
+
+        boxes_batch = []
+        for batch_index in range(pred.shape[0]):
+            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
+            if self.dilation_kernel is not None:
+                mask = cv2.dilate(
+                    np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel
+                )
+            else:
+                mask = segmentation[batch_index]
+            boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h)
+            boxes_batch.append({"points": boxes})
+        return boxes_batch
+
+
+class BaseRecLabelDecode(object):
+    def __init__(self, character_dict_path=None, use_space_char=False):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        self.character_str = []
+        if character_dict_path is None:
+            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+            dict_character = list(self.character_str)
+        else:
+            with open(character_dict_path, "rb") as fin:
+                lines = fin.readlines()
+                for line in lines:
+                    line = line.decode("utf-8").strip("\n").strip("\r\n")
+                    self.character_str.append(line)
+            if use_space_char:
+                self.character_str.append(" ")
+            dict_character = list(self.character_str)
+
+        dict_character = self.add_special_char(dict_character)
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+
+    def add_special_char(self, dict_character):
+        return dict_character
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        result_list = []
+        ignored_tokens = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            char_list = []
+            conf_list = []
+            for idx in range(len(text_index[batch_idx])):
+                if text_index[batch_idx][idx] in ignored_tokens:
+                    continue
+                if is_remove_duplicate:
+                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]:
+                        continue
+                char_list.append(self.character[int(text_index[batch_idx][idx])])
+                if text_prob is not None:
+                    conf_list.append(text_prob[batch_idx][idx])
+                else:
+                    conf_list.append(1)
+            text = "".join(char_list)
+            # Check if conf_list is empty before calculating mean
+            confidence = np.mean(conf_list) if len(conf_list) > 0 else 0.0
+            result_list.append((text, confidence))
+        return result_list
+
+    def get_ignored_tokens(self):
+        return [0]
+
+
+class CTCLabelDecode(BaseRecLabelDecode):
+    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+        super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        if hasattr(preds, "numpy"):  # Check if it has numpy method (for torch tensors)
+            preds = preds.numpy()
+        elif isinstance(preds, mx.array):  # For MLX arrays
+            preds = np.array(preds)
+        preds_idx = preds.argmax(axis=2)
+        preds_prob = preds.max(axis=2)
+        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+
+        if label is None:
+            return text
+        label = self.decode(label)
+        return text, label
+
+    def add_special_char(self, dict_character):
+        dict_character = ["blank"] + dict_character
+        return dict_character
+
+
+## =============================== CONFIG CLASS =============================== #
+
+
+class Config:
+    def __init__(self, model_path):
+        # Base paths
+        self.base_dir = os.path.abspath(os.path.dirname(__file__))
+
+        self.model_cache_dir = model_path
+
+        # Detection settings
+        self.det_algorithm = "DB"
+        # Use downloaded model files instead of local paths
+        self.det_model_path = os.path.join(
+            self.model_cache_dir, "ch_ptocr_v4_det_infer.safetensors"
+        )
+        self.det_limit_side_len = 960
+        self.det_limit_type = "max"
+        self.det_db_thresh = 0.3
+        self.det_db_box_thresh = 0.6
+        self.det_db_unclip_ratio = 1.5
+        self.use_dilation = False
+        self.det_db_score_mode = "fast"
+
+        # Recognition settings
+        self.rec_algorithm = "CRNN"
+        # Use downloaded model files instead of local paths
+        self.rec_model_path = os.path.join(
+            self.model_cache_dir, "ch_ptocr_v4_rec_infer_f16.safetensors"
+        )
+        self.rec_char_type = "ch"
+        self.rec_batch_num = 6
+        self.max_text_length = 25
+        # Use downloaded character dictionary
+        self.rec_char_dict_path = os.path.join(self.model_cache_dir, "ppocr_keys_v1.txt")
+
+        # Other settings
+        self.use_space_char = True
+        self.drop_score = 0.5
+        self.limited_max_width = 1280
+        self.limited_min_width = 16
+        # Use downloaded font file
+        self.vis_font_path = os.path.join(self.model_cache_dir, "simfang.ttf")
+
+
+## =============================== MODEL COMPONENTS =============================== #
+
+
+class LearnableAffineBlock(nn.Module):
+    def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.1):
+        super().__init__()
+        # Match PyTorch parameter names exactly (lr_mult and lab_lr are ignored in MLX)
+        self.scale = mx.array([scale_value])
+        self.bias = mx.array([bias_value])
+
+    def __call__(self, x):
+        return self.scale * x + self.bias
+
+
+class ConvBNLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, lr_mult=1.0):
+        super().__init__()
+        # lr_mult is ignored in MLX - it's a PyTorch/PaddlePaddle concept
+        padding = (kernel_size - 1) // 2
+        self.conv = nn.Conv2d(
+            in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False
+        )
+        self.bn = nn.BatchNorm(out_channels)
+
+    def __call__(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+
+class Act(nn.Module):
+    def __init__(self, act="hswish", lr_mult=1.0, lab_lr=0.1):
+        super().__init__()
+        # lr_mult and lab_lr are ignored in MLX
+        self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
+
+    def __call__(self, x):
+        return self.lab(nn.hardswish(x))
+
+
+class LearnableRepLayer(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        groups=1,
+        num_conv_branches=4,
+        lr_mult=1.0,
+        lab_lr=0.1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.groups = groups
+        self.num_conv_branches = num_conv_branches
+
+        # Identity connection - only if channels match and stride is 1
+        self.identity = None
+        if out_channels == in_channels and stride == 1:
+            self.identity = nn.BatchNorm(in_channels)
+
+        # Create main conv branches using a list to match PyTorch structure
+        self.conv_kxk = []
+        for _ in range(num_conv_branches):
+            conv = ConvBNLayer(
+                in_channels, out_channels, kernel_size, stride, groups=groups, lr_mult=lr_mult
+            )
+            self.conv_kxk.append(conv)
+
+        # 1x1 conv branch - only if kernel > 1
+        self.conv_1x1 = None
+        if kernel_size > 1:
+            self.conv_1x1 = ConvBNLayer(
+                in_channels, out_channels, 1, stride, groups=groups, lr_mult=lr_mult
+            )
+
+        self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
+        self.act = Act(lr_mult=lr_mult, lab_lr=lab_lr)
+
+    def __call__(self, x):
+        out = 0
+
+        # Add identity if available
+        if self.identity is not None:
+            out = out + self.identity(x)
+
+        # Add 1x1 conv if available
+        if self.conv_1x1 is not None:
+            out = out + self.conv_1x1(x)
+
+        # Add all conv_kxk branches
+        for conv in self.conv_kxk:
+            out = out + conv(x)
+
+        # Apply learnable affine and activation
+        out = self.lab(out)
+        if self.stride != 2:
+            out = self.act(out)
+
+        return out
+
+
+class SELayer(nn.Module):
+    def __init__(self, channel, reduction=4, lr_mult=1.0):
+        super().__init__()
+        # lr_mult is ignored in MLX
+        reduced_channels = max(1, channel // reduction)
+        self.conv1 = nn.Conv2d(channel, reduced_channels, 1)
+        self.conv2 = nn.Conv2d(reduced_channels, channel, 1)
+
+    def __call__(self, x):
+        identity = x
+        se_input = mx.mean(x, axis=(1, 2), keepdims=True)  # Changed from (2, 3) to (1, 2)
+        se_out = nn.relu(self.conv1(se_input))
+        se_out = self.conv2(se_out)
+        se_out = mx.clip(se_out + 3.0, 0.0, 6.0) / 6.0
+        se_out = identity * se_out
+        return se_out
+
+
+class LCNetV3Block(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        stride,
+        dw_size,
+        use_se=False,
+        conv_kxk_num=4,
+        lr_mult=1.0,
+        lab_lr=0.1,
+    ):
+        super().__init__()
+        self.use_se = use_se
+
+        # Depthwise convolution: in_channels -> in_channels with groups=in_channels
+        self.dw_conv = LearnableRepLayer(
+            in_channels=in_channels,  # INPUT: 192
+            out_channels=in_channels,  # OUTPUT: 192 (same as input for depthwise)
+            kernel_size=dw_size,
+            stride=stride,
+            groups=in_channels,  # GROUPS: 192 (depthwise)
+            num_conv_branches=conv_kxk_num,
+            lr_mult=lr_mult,
+            lab_lr=lab_lr,
+        )
+
+        if use_se:
+            self.se = SELayer(in_channels, lr_mult=lr_mult)
+
+        # Pointwise convolution: in_channels -> out_channels with groups=1
+        self.pw_conv = LearnableRepLayer(
+            in_channels=in_channels,  # INPUT: 192
+            out_channels=out_channels,  # OUTPUT: 384
+            kernel_size=1,
+            stride=1,
+            groups=1,  # GROUPS: 1 (pointwise)
+            num_conv_branches=conv_kxk_num,
+            lr_mult=lr_mult,
+            lab_lr=lab_lr,
+        )
+
+    def __call__(self, x):
+        x = self.dw_conv(x)
+        if self.use_se:
+            x = self.se(x)
+        x = self.pw_conv(x)
+        return x
+
+
+def make_divisible(v, divisor=16):
+    return max(divisor, int(v + divisor / 2) // divisor * divisor)
+
+
+# Add the NET_CONFIG_det at the top
+NET_CONFIG_det = {
+    "blocks2": [[3, 16, 32, 1, False]],
+    "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
+    "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
+    "blocks5": [
+        [3, 128, 256, 2, False],
+        [5, 256, 256, 1, False],
+        [5, 256, 256, 1, False],
+        [5, 256, 256, 1, False],
+        [5, 256, 256, 1, False],
+    ],
+    "blocks6": [
+        [5, 256, 512, 2, True],
+        [5, 512, 512, 1, True],
+        [5, 512, 512, 1, False],
+        [5, 512, 512, 1, False],
+    ],
+}
+
+NET_CONFIG_rec = {
+    "blocks2": [[3, 16, 32, 1, False]],
+    "blocks3": [[3, 32, 64, 1, False], [3, 64, 64, 1, False]],
+    "blocks4": [[3, 64, 128, (2, 1), False], [3, 128, 128, 1, False]],
+    "blocks5": [
+        [3, 128, 256, (1, 2), False],
+        [5, 256, 256, 1, False],
+        [5, 256, 256, 1, False],
+        [5, 256, 256, 1, False],
+        [5, 256, 256, 1, False],
+    ],
+    "blocks6": [
+        [5, 256, 512, (2, 1), True],
+        [5, 512, 512, 1, True],
+        [5, 512, 512, (2, 1), False],
+        [5, 512, 512, 1, False],
+    ],
+}
+
+
+## =================================== for the backbone of text recognition ===================================
+class PPLCNetV3(nn.Module):
+    def __init__(
+        self,
+        scale=1.0,
+        conv_kxk_num=4,
+        lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        lab_lr=0.1,
+        det=False,
+        **kwargs,
+    ):
+        super().__init__()
+        self.scale = scale
+        self.lr_mult_list = lr_mult_list
+        self.det = det
+        self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec
+
+        assert isinstance(self.lr_mult_list, (list, tuple))
+        assert len(self.lr_mult_list) == 6
+
+        self.conv1 = ConvBNLayer(
+            in_channels=3,
+            out_channels=make_divisible(16 * scale),
+            kernel_size=3,
+            stride=2,
+            lr_mult=self.lr_mult_list[0],
+        )
+
+        # Build blocks2 - match PyTorch Sequential structure
+        blocks2_list = []
+        in_channels = make_divisible(16 * scale)
+        for i, (k, _, out_c, s, se) in enumerate(self.net_config["blocks2"]):
+            out_channels = make_divisible(out_c * scale)
+            block = LCNetV3Block(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                dw_size=k,
+                stride=s,
+                use_se=se,
+                conv_kxk_num=conv_kxk_num,
+                lr_mult=self.lr_mult_list[1],
+                lab_lr=lab_lr,
+            )
+            blocks2_list.append(block)
+            in_channels = out_channels
+        self.blocks2 = blocks2_list
+
+        # Build blocks3
+        blocks3_list = []
+        for i, (k, _, out_c, s, se) in enumerate(self.net_config["blocks3"]):
+            out_channels = make_divisible(out_c * scale)
+            block = LCNetV3Block(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                dw_size=k,
+                stride=s,
+                use_se=se,
+                conv_kxk_num=conv_kxk_num,
+                lr_mult=self.lr_mult_list[2],
+                lab_lr=lab_lr,
+            )
+            blocks3_list.append(block)
+            in_channels = out_channels
+        self.blocks3 = blocks3_list
+
+        # Build blocks4
+        blocks4_list = []
+        for i, (k, _, out_c, s, se) in enumerate(self.net_config["blocks4"]):
+            out_channels = make_divisible(out_c * scale)
+            block = LCNetV3Block(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                dw_size=k,
+                stride=s,
+                use_se=se,
+                conv_kxk_num=conv_kxk_num,
+                lr_mult=self.lr_mult_list[3],
+                lab_lr=lab_lr,
+            )
+            blocks4_list.append(block)
+            in_channels = out_channels
+        self.blocks4 = blocks4_list
+
+        # Build blocks5
+        blocks5_list = []
+        for i, (k, _, out_c, s, se) in enumerate(self.net_config["blocks5"]):
+            out_channels = make_divisible(out_c * scale)
+            block = LCNetV3Block(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                dw_size=k,
+                stride=s,
+                use_se=se,
+                conv_kxk_num=conv_kxk_num,
+                lr_mult=self.lr_mult_list[4],
+                lab_lr=lab_lr,
+            )
+            blocks5_list.append(block)
+            in_channels = out_channels
+        self.blocks5 = blocks5_list
+
+        # Build blocks6
+        blocks6_list = []
+        for i, (k, _, out_c, s, se) in enumerate(self.net_config["blocks6"]):
+            out_channels = make_divisible(out_c * scale)
+            block = LCNetV3Block(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                dw_size=k,
+                stride=s,
+                use_se=se,
+                conv_kxk_num=conv_kxk_num,
+                lr_mult=self.lr_mult_list[5],
+                lab_lr=lab_lr,
+            )
+            blocks6_list.append(block)
+            in_channels = out_channels
+        self.blocks6 = blocks6_list
+
+        self.out_channels = make_divisible(512 * scale)
+
+        if self.det:
+            mv_c = [16, 24, 56, 480]
+            self.out_channels = [
+                make_divisible(self.net_config["blocks3"][-1][2] * scale),
+                make_divisible(self.net_config["blocks4"][-1][2] * scale),
+                make_divisible(self.net_config["blocks5"][-1][2] * scale),
+                make_divisible(self.net_config["blocks6"][-1][2] * scale),
+            ]
+
+            self.layer_list = []
+            for i in range(4):
+                layer = nn.Conv2d(self.out_channels[i], int(mv_c[i] * scale), 1, bias=True)
+                self.layer_list.append(layer)
+
+            self.out_channels = [
+                int(mv_c[0] * scale),
+                int(mv_c[1] * scale),
+                int(mv_c[2] * scale),
+                int(mv_c[3] * scale),
+            ]
+
+    def __call__(self, x):
+        out_list = []
+
+        ## Transpose to match the format required by MLX
+        x = mx.transpose(x, (0, 2, 3, 1))
+        x = self.conv1(x)
+
+        for block in self.blocks2:
+            x = block(x)
+
+        for block in self.blocks3:
+            x = block(x)
+        out_list.append(x)
+
+        for block in self.blocks4:
+            x = block(x)
+        out_list.append(x)
+
+        for block in self.blocks5:
+            x = block(x)
+        out_list.append(x)
+
+        for block in self.blocks6:
+            x = block(x)
+        out_list.append(x)
+
+        if self.det:
+            out_list[0] = self.layer_list[0](out_list[0])
+            out_list[1] = self.layer_list[1](out_list[1])
+            out_list[2] = self.layer_list[2](out_list[2])
+            out_list[3] = self.layer_list[3](out_list[3])
+            return out_list
+
+        B, H, W, C = x.shape
+
+        # Ensure dimensions are divisible by kernel size for clean pooling
+        H_out = H // 3
+        W_out = W // 2
+
+        # Trim to make dimensions divisible
+        x = x[:, : H_out * 3, : W_out * 2, :]
+
+        # Reshape for 3x2 average pooling
+        x = mx.reshape(x, (B, H_out, 3, W_out, 2, C))
+        x = mx.mean(x, axis=(2, 4))  # Average over the 3x2 kernel
+        return x
+
+
+## =================================== for the neck of text detection ===================================
+class IndexedContainer(nn.Module):
+    """Container that creates numbered attributes for MLX"""
+
+    def __init__(self):
+        super().__init__()
+        self._modules = []
+
+    def add_module(self, module):
+        idx = len(self._modules)
+        setattr(self, str(idx), module)
+        self._modules.append(module)
+        return idx
+
+    def __getitem__(self, idx):
+        return getattr(self, str(idx))
+
+
+class SEModule(nn.Module):
+    def __init__(self, in_channels, reduction=4):
+        super().__init__()
+        reduced_channels = in_channels // reduction
+        self.conv1 = nn.Conv2d(in_channels, reduced_channels, 1, bias=True)
+        self.conv2 = nn.Conv2d(reduced_channels, in_channels, 1, bias=True)
+
+    def __call__(self, inputs):
+        outputs = mx.mean(inputs, axis=(1, 2), keepdims=True)
+        outputs = self.conv1(outputs)
+        outputs = nn.relu(outputs)
+        outputs = self.conv2(outputs)
+        # PaddlePaddle hard_sigmoid: F.relu6(1.2 * x + 3.) / 6.
+        outputs = mx.clip(1.2 * outputs + 3.0, 0.0, 6.0) / 6.0  # PaddlePaddle hard_sigmoid
+        outputs = inputs * outputs
+        return outputs
+
+
+class RSELayer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
+        super().__init__()
+        padding = kernel_size // 2
+        self.in_conv = nn.Conv2d(
+            in_channels, out_channels, kernel_size, padding=padding, bias=False
+        )
+        self.se_block = SEModule(out_channels)
+        self.shortcut = shortcut
+
+    def __call__(self, x):
+        conv_out = self.in_conv(x)
+        if self.shortcut:
+            return conv_out + self.se_block(conv_out)
+        else:
+            return self.se_block(conv_out)
+
+
+class RSEFPN(nn.Module):
+    def __init__(self, in_channels, out_channels=96, shortcut=True):
+        super().__init__()
+        self.out_channels = out_channels
+
+        # Create container modules that inherit from nn.Module
+        self.ins_conv = IndexedContainer()
+        self.inp_conv = IndexedContainer()
+
+        # Add modules - this should create the correct parameter names
+        for i, in_ch in enumerate(in_channels):
+            self.ins_conv.add_module(
+                RSELayer(in_ch, out_channels, kernel_size=1, shortcut=shortcut)
+            )
+            self.inp_conv.add_module(
+                RSELayer(out_channels, out_channels // 4, kernel_size=3, shortcut=shortcut)
+            )
+
+    def __call__(self, x):
+        c2, c3, c4, c5 = x
+
+        in5 = self.ins_conv[3](c5)
+        in4 = self.ins_conv[2](c4)
+        in3 = self.ins_conv[1](c3)
+        in2 = self.ins_conv[0](c2)
+
+        # Upsample both H and W dimensions
+        up_in5 = mx.repeat(in5, 2, axis=1)
+        up_in5 = mx.repeat(up_in5, 2, axis=2)
+        out4 = in4 + up_in5
+
+        up_out4 = mx.repeat(out4, 2, axis=1)
+        up_out4 = mx.repeat(up_out4, 2, axis=2)
+        out3 = in3 + up_out4
+
+        up_out3 = mx.repeat(out3, 2, axis=1)
+        up_out3 = mx.repeat(up_out3, 2, axis=2)
+        out2 = in2 + up_out3
+
+        p5 = self.inp_conv[3](in5)
+        p4 = self.inp_conv[2](out4)
+        p3 = self.inp_conv[1](out3)
+        p2 = self.inp_conv[0](out2)
+
+        # Use target size from p2 for consistent upsampling
+        target_h, target_w = p2.shape[1], p2.shape[2]
+
+        # MLX doesn't have F.upsample, but we can calculate target sizes and use repeat more carefully
+        # P5: upsample by 8x to match p2 size
+        p5_h, p5_w = p5.shape[1], p5.shape[2]
+        p5_target_h, p5_target_w = min(target_h, p5_h * 8), min(target_w, p5_w * 8)
+
+        # Calculate exact repeat factors
+        h_repeat_p5 = p5_target_h // p5_h
+        w_repeat_p5 = p5_target_w // p5_w
+        p5 = mx.repeat(p5, h_repeat_p5, axis=1)
+        p5 = mx.repeat(p5, w_repeat_p5, axis=2)
+        p5 = p5[:, :target_h, :target_w]
+
+        # P4: upsample by 4x to match p2 size
+        p4_h, p4_w = p4.shape[1], p4.shape[2]
+        p4_target_h, p4_target_w = min(target_h, p4_h * 4), min(target_w, p4_w * 4)
+
+        h_repeat_p4 = p4_target_h // p4_h
+        w_repeat_p4 = p4_target_w // p4_w
+        p4 = mx.repeat(p4, h_repeat_p4, axis=1)
+        p4 = mx.repeat(p4, w_repeat_p4, axis=2)
+        p4 = p4[:, :target_h, :target_w]
+
+        # P3: upsample by 2x to match p2 size
+        p3_h, p3_w = p3.shape[1], p3.shape[2]
+        p3_target_h, p3_target_w = min(target_h, p3_h * 2), min(target_w, p3_w * 2)
+
+        h_repeat_p3 = p3_target_h // p3_h
+        w_repeat_p3 = p3_target_w // p3_w
+        p3 = mx.repeat(p3, h_repeat_p3, axis=1)
+        p3 = mx.repeat(p3, w_repeat_p3, axis=2)
+        p3 = p3[:, :target_h, :target_w]
+
+        fuse = mx.concatenate([p5, p4, p3, p2], axis=-1)
+        return fuse
+
+
|
|
991
|
+
## =================================== for the head of text detection ===================================
|
|
992
|
+
class DetectionHead(nn.Module):
|
|
993
|
+
def __init__(self, in_channels):
|
|
994
|
+
super().__init__()
|
|
995
|
+
self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 3, padding=1, bias=False)
|
|
996
|
+
self.conv_bn1 = nn.BatchNorm(in_channels // 4)
|
|
997
|
+
|
|
998
|
+
self.conv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, stride=2)
|
|
999
|
+
self.conv_bn2 = nn.BatchNorm(in_channels // 4)
|
|
1000
|
+
|
|
1001
|
+
self.conv3 = nn.ConvTranspose2d(in_channels // 4, 1, 2, stride=2)
|
|
1002
|
+
|
|
1003
|
+
def __call__(self, x):
|
|
1004
|
+
x = nn.relu(self.conv_bn1(self.conv1(x)))
|
|
1005
|
+
x = nn.relu(self.conv_bn2(self.conv2(x)))
|
|
1006
|
+
x = self.conv3(x)
|
|
1007
|
+
x = nn.sigmoid(x)
|
|
1008
|
+
return x
|
|
1009
|
+
|
|
1010
|
+
|
|
1011
|
+
class DBHead(nn.Module):
|
|
1012
|
+
def __init__(self, in_channels, k=50):
|
|
1013
|
+
super().__init__()
|
|
1014
|
+
self.k = k
|
|
1015
|
+
self.binarize = DetectionHead(in_channels) # First branch
|
|
1016
|
+
self.thresh = DetectionHead(in_channels) # Second branch (was missing!)
|
|
1017
|
+
|
|
1018
|
+
def step_function(self, x, y):
|
|
1019
|
+
return 1.0 / (1.0 + mx.exp(-self.k * (x - y)))
|
|
1020
|
+
|
|
1021
|
+
def __call__(self, x):
|
|
1022
|
+
shrink_maps = self.binarize(x)
|
|
1023
|
+
shrink_maps = mx.transpose(shrink_maps, (0, 3, 1, 2))
|
|
1024
|
+
return {"maps": shrink_maps}
|
|
1025
|
+
|
|
1026
|
+
|
|
1027
|
+
class TextDetector(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.preprocess_op = [
            DetResizeForTest(
                limit_side_len=args.det_limit_side_len, limit_type=args.det_limit_type
            ),
            NormalizeImage(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                scale=1.0 / 255.0,
                order="hwc",
            ),
            ToCHWImage(),
            KeepKeys(keep_keys=["image", "shape"]),
        ]

        postprocess_params = {
            "thresh": args.det_db_thresh,
            "box_thresh": args.det_db_box_thresh,
            "max_candidates": 1000,
            "unclip_ratio": args.det_db_unclip_ratio,
            "use_dilation": args.use_dilation,
            "score_mode": args.det_db_score_mode,
        }
        self.postprocess_op = DBPostProcess(**postprocess_params)

        # Match exact PyTorch model structure
        backbone_config = {"scale": 0.75, "det": True, "in_channels": 3}
        self.backbone = PPLCNetV3(**backbone_config)

        # Use correct neck config - the backbone outputs these channels
        neck_config = {
            "out_channels": 96,
            "shortcut": True,
            "in_channels": self.backbone.out_channels,  # Should be [12, 18, 42, 360]
        }
        self.neck = RSEFPN(**neck_config)

        head_config = {"k": 50, "in_channels": 96}
        self.head = DBHead(**head_config)

    def order_points_clockwise(self, pts):
        """
        reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
        # sort the points based on their x-coordinates
        """
        xSorted = pts[np.argsort(pts[:, 0]), :]

        # grab the left-most and right-most points from the sorted
        # x-coordinate points
        leftMost = xSorted[:2, :]
        rightMost = xSorted[2:, :]

        # now, sort the left-most coordinates according to their
        # y-coordinates so we can grab the top-left and bottom-left
        # points, respectively
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
        (tl, bl) = leftMost

        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
        (tr, br) = rightMost

        rect = np.array([tl, tr, br, bl], dtype="float32")
        return rect

    def clip_det_res(self, points, img_height, img_width):
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        return np.array(dt_boxes_new) if dt_boxes_new else np.array([])

    def forward(self, x):
        features = self.backbone(x)
        neck_out = self.neck(features)
        head_out = self.head(neck_out)
        return head_out

    def __call__(self, img):
        ori_im = img.copy()
        data = {"image": img}

        for op in self.preprocess_op:
            data = op(data)

        img, shape_list = data
        if img is None:
            return None, 0

        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)

        inp = mx.array(img.copy())
        outputs = self.forward(inp)
        preds = {"maps": np.array(outputs["maps"])}

        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]["points"] if post_result else []
        dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
        return dt_boxes


def test_detector(args):
    img = np.load(
        "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/det_inp.npy"
    )
    detector = TextDetector(args)
    detector.eval()
    detector.load_weights(
        "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/ch_ptocr_v4_det_infer.safetensors"
    )
    boxes = detector(img)
    print(f"Detected {len(boxes)} boxes")


## ==================================== Now the text det works ==================================== #


## ==================================== Text Recognition Components ==================================== #


class Im2Seq(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.out_channels = in_channels

    def __call__(self, x):
        B, H, W, C = x.shape  # MLX format: (B, H, W, C)
        assert H == 1
        x = mx.reshape(x, (B, H * W, C))  # (B, W, C) for sequence
        return x


class SVTRConvBNLayer(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, groups=1, act="swish"
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False
        )
        self.norm = nn.BatchNorm(out_channels)
        self.act = act

    def __call__(self, x):
        x = self.conv(x)
        x = self.norm(x)
        if self.act == "swish":
            x = x * mx.sigmoid(x)
        return x


class EncoderWithSVTR(nn.Module):
    def __init__(
        self,
        in_channels,
        dims=64,
        depth=2,
        hidden_dims=120,
        kernel_size=[3, 3],
        use_guide=False,
        **kwargs,
    ):
        super().__init__()
        self.depth = depth
        self.use_guide = use_guide

        # Match original PyTorch structure exactly
        self.conv1 = SVTRConvBNLayer(
            in_channels,
            in_channels // 8,
            kernel_size=(1, 3),  # Match actual model: (1, 3) not 3
            padding=(0, 1),  # Match actual model: (0, 1) not 1
            act="swish",
        )
        self.conv2 = SVTRConvBNLayer(
            in_channels // 8, hidden_dims, kernel_size=1, padding=0, act="swish"
        )

        # SVTR blocks - ADD THIS BACK!
        self.svtr_block = []
        for i in range(depth):
            block = Block(
                dim=hidden_dims,
                num_heads=8,
                mixer="Global",
                mlp_ratio=2.0,
                qkv_bias=True,  # Change from False to True
                act_layer="swish",  # Add this
                **kwargs,
            )
            setattr(self, f"svtr_block_{i}", block)
            self.svtr_block.append(block)

        self.norm = nn.LayerNorm(hidden_dims)

        self.conv3 = SVTRConvBNLayer(
            hidden_dims, in_channels, kernel_size=1, padding=0, act="swish"
        )
        self.conv4 = SVTRConvBNLayer(
            2 * in_channels, in_channels // 8, kernel_size=3, padding=1, act="swish"
        )
        self.conv1x1 = SVTRConvBNLayer(
            in_channels // 8, dims, kernel_size=1, padding=0, act="swish"
        )

        self.out_channels = dims

    def __call__(self, x):
        # Shortcut
        h = x

        # Reduce dim
        z = self.conv1(x)
        z = self.conv2(z)

        # SVTR global blocks
        B, H, W, C = z.shape
        z = mx.reshape(z, (B, H * W, C))  # Flatten spatial dims

        for block in self.svtr_block:
            z = block(z)

        z = self.norm(z)

        # Reshape back - CRITICAL: use original H, W
        z = mx.reshape(z, (B, H, W, C))  # Use the H, W from before SVTR blocks
        z = self.conv3(z)

        # Concatenate with shortcut - dimensions should match now
        z = mx.concatenate([h, z], axis=-1)
        z = self.conv4(z)
        z = self.conv1x1(z)

        return z


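# Note (illustrative, not part of the upstream file): EncoderWithSVTR mixes local
# conv features with global self-attention by flattening the (B, H, W, C) map into a
# (B, H*W, C) token sequence, running the SVTR Blocks on it, and reshaping back
# before the shortcut concat. Rough shape walk-through for the recognition branch
# (H=1 after the backbone; W is the downsampled width, an assumption used only for
# illustration):
#
#     z: (B, 1, W, hidden_dims) -> reshape -> (B, W, hidden_dims)
#       -> Block x depth -> LayerNorm -> reshape -> (B, 1, W, hidden_dims)
#       -> conv3 -> concat with shortcut -> conv4 -> conv1x1 -> (B, 1, W, dims)
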
class Mlp(nn.Module):
    def __init__(
        self, in_features, hidden_features=None, out_features=None, act_layer="swish", drop=0.0
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)  # Add bias=True
        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)  # Add bias=True
        self.act_layer = act_layer

    def __call__(self, x):
        x = self.fc1(x)
        # Use swish activation to match PyTorch
        if self.act_layer == "swish":
            x = x * mx.sigmoid(x)  # Swish activation
        elif self.act_layer == "gelu":
            x = nn.gelu(x)
        x = self.fc2(x)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        mixer="Global",
        HW=None,
        local_k=[7, 11],
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=True)
        self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0 else nn.Identity()
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0 else nn.Identity()
        self.HW = HW
        self.mixer = mixer

        # Set N and C if HW is provided (like in PyTorch)
        if HW is not None:
            H = HW[0]
            W = HW[1]
            self.N = H * W
            self.C = dim

    def __call__(self, x):
        if self.HW is not None:
            N = self.N
            C = self.C
        else:
            _, N, C = x.shape

        qkv = self.qkv(x)
        qkv = qkv.reshape((-1, N, 3, self.num_heads, C // self.num_heads))
        qkv = mx.transpose(qkv, (2, 0, 3, 1, 4))  # permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

        attn = q @ mx.transpose(k, (0, 1, 3, 2))  # q.matmul(k.permute(0, 1, 3, 2))
        if self.mixer == "Local":
            # attn += self.mask  # Would need to implement mask for Local
            pass
        attn = mx.softmax(attn, axis=-1)  # nn.functional.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(0, 2, 1, 3).reshape((-1, N, C))  # Match exact reshape
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


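# Note (illustrative, not part of the upstream file): Attention above is standard
# multi-head scaled dot-product attention; scaling q by head_dim**-0.5 before the
# matmul is equivalent to softmax(q @ k^T / sqrt(head_dim)) @ v. Shape sketch for
# dim=120, num_heads=8 (so head_dim=15) and sequence length N:
#
#     qkv:  (B, N, 360) -> (3, B, 8, N, 15)
#     attn: softmax(q @ k^T)            -> (B, 8, N, N)
#     out:  attn @ v -> (B, 8, N, 15) -> transpose/reshape -> (B, N, 120) -> proj
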
class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mixer="Global",
        local_mixer=[7, 11],
        HW=None,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer="gelu",
        norm_layer="nn.LayerNorm",
        epsilon=1e-6,
        prenorm=False,  # Set to False to match PyTorch
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=epsilon)
        self.mixer = Attention(
            dim,
            num_heads=num_heads,
            mixer=mixer,
            HW=HW,
            local_k=local_mixer,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.norm2 = nn.LayerNorm(dim, eps=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )
        self.prenorm = prenorm
        self.drop_path = drop_path

    def __call__(self, x):
        if self.prenorm:
            x = self.norm1(x + self._drop_path(self.mixer(x)))
            x = self.norm2(x + self._drop_path(self.mlp(x)))
        else:
            # This is the path that will be taken (prenorm=False)
            x = x + self._drop_path(self.mixer(self.norm1(x)))
            x = x + self._drop_path(self.mlp(self.norm2(x)))
        return x

    def _drop_path(self, x):
        # For inference, drop_path is disabled, so just return x
        return x


class SequenceEncoder(nn.Module):
    def __init__(self, in_channels, encoder_type="svtr", **kwargs):
        super().__init__()
        self.encoder_type = encoder_type.lower()
        self.encoder_reshape = Im2Seq(in_channels)

        if self.encoder_type == "svtr":
            self.encoder = EncoderWithSVTR(in_channels, **kwargs)
            self.out_channels = self.encoder.out_channels
            self.only_reshape = False
        else:
            self.out_channels = in_channels
            self.only_reshape = True

    def __call__(self, x):
        if self.encoder_type == "svtr":
            # For SVTR: encoder works on 2D data first, then reshape
            x = self.encoder(x)  # x is still (B, H, W, C)
            x = self.encoder_reshape(x)  # Now reshape to (B, W, C)
            return x
        else:
            # For others: reshape first, then encoder
            x = self.encoder_reshape(x)
            if not self.only_reshape:
                x = self.encoder(x)
            return x


class CTCHead(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        fc_decay=0.0004,
        mid_channels=None,
        return_feats=False,
        **kwargs,
    ):
        super().__init__()
        self.return_feats = return_feats
        self.mid_channels = mid_channels

        if mid_channels is None:
            self.fc = nn.Linear(in_channels, out_channels, bias=True)
        else:
            self.fc1 = nn.Linear(in_channels, mid_channels, bias=True)
            self.fc2 = nn.Linear(mid_channels, out_channels, bias=True)

        self.out_channels = out_channels

    def __call__(self, x):
        if self.mid_channels is None:
            predicts = self.fc(x)
        else:
            x = self.fc1(x)
            predicts = self.fc2(x)

        if self.return_feats:
            result = (x, predicts)
        else:
            result = predicts

        # Apply softmax for inference using MLX
        if not self.training:
            predicts = mx.softmax(predicts, axis=2)
            result = predicts

        return result


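# Note (illustrative, not part of the upstream file): at eval time CTCHead emits
# per-timestep character probabilities of shape (B, W, num_classes); CTCLabelDecode
# (defined elsewhere in this file) is expected to greedy-decode them by taking the
# per-step argmax, collapsing repeated indices, and dropping the blank class. A
# minimal sketch of that decode rule, assuming index 0 is the blank:
#
#     idx = np.argmax(probs, axis=-1)                        # (W,)
#     keep = np.concatenate([[True], idx[1:] != idx[:-1]])   # drop repeats
#     chars = [charset[i] for i in idx[keep] if i != 0]      # drop blanks
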
class MultiHead(nn.Module):
    def __init__(self, in_channels, out_channels_list, head_list, **kwargs):
        super().__init__()
        self.head_list = head_list

        for idx, head_name in enumerate(self.head_list):
            name = list(head_name)[0]
            if name == "CTCHead":
                # No separate encoder_reshape - it's handled inside SequenceEncoder
                neck_args = self.head_list[idx][name]["Neck"].copy()
                encoder_type = neck_args.pop("name")
                self.ctc_encoder = SequenceEncoder(
                    in_channels=in_channels, encoder_type=encoder_type, **neck_args
                )
                # CTC head
                head_args = self.head_list[idx][name].get("Head", {})
                if head_args is None:
                    head_args = {}
                self.ctc_head = CTCHead(
                    in_channels=self.ctc_encoder.out_channels,
                    out_channels=out_channels_list["CTCLabelDecode"],
                    **head_args,
                )

    def __call__(self, x, data=None):
        # Direct call to ctc_encoder - let it handle reshaping internally
        ctc_encoder = self.ctc_encoder(x)
        ctc_out = self.ctc_head(ctc_encoder)

        # Eval mode
        if not self.training:
            return ctc_out

        head_out = dict()
        head_out["ctc"] = ctc_out
        head_out["res"] = ctc_out
        head_out["ctc_neck"] = ctc_encoder
        return head_out


class TextRecognizer(nn.Module):
    def __init__(self, args, **kwargs):
        super().__init__()

        self.rec_image_shape = [3, 48, 320]
        self.rec_batch_num = args.rec_batch_num
        self.limited_max_width = args.limited_max_width
        self.limited_min_width = args.limited_min_width

        # Character dictionary path
        postprocess_params = {
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char,
        }
        self.postprocess_op = CTCLabelDecode(**postprocess_params)

        # Get character number
        char_num = len(getattr(self.postprocess_op, "character"))

        # Recognition backbone - reuse existing PPLCNetV3 (already handles transpose)
        self.backbone = PPLCNetV3(scale=0.95, det=False)

        # Recognition head
        head_config = {
            "head_list": [
                {
                    "CTCHead": {
                        "Neck": {
                            "name": "svtr",
                            "dims": 120,
                            "depth": 2,
                            "hidden_dims": 120,
                            "kernel_size": [1, 3],
                            "use_guide": True,
                        },
                        "Head": {"fc_decay": 1e-05},
                    }
                },
            ],
            "out_channels_list": {
                "CTCLabelDecode": char_num,
            },
            "in_channels": 480,  # PPLCNetV3 output channels
        }
        self.head = MultiHead(**head_config)

    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape

        assert imgC == img.shape[2]
        max_wh_ratio = max(max_wh_ratio, imgW / imgH)
        imgW = int((imgH * max_wh_ratio))
        imgW = max(min(imgW, self.limited_max_width), self.limited_min_width)
        h, w = img.shape[:2]
        ratio = w / float(h)
        ratio_imgH = int(np.ceil(imgH * ratio))
        ratio_imgH = max(ratio_imgH, self.limited_min_width)
        if ratio_imgH > imgW:
            resized_w = imgW
        else:
            resized_w = int(ratio_imgH)

        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype("float32")
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        img_num = len(img_list)
        # Calculate aspect ratio and sort for batching efficiency
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        indices = np.argsort(np.array(width_list))

        rec_res = [["", 0.0]] * img_num
        batch_num = self.rec_batch_num
        elapse = 0

        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            max_wh_ratio = 0

            # Calculate max width/height ratio for this batch
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)

            # Normalize images in batch
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)

            norm_img_batch = np.concatenate(norm_img_batch)

            starttime = time.time()

            # Forward pass
            inp = mx.array(norm_img_batch)
            # PPLCNetV3 backbone already handles the transpose from (B, C, H, W) to (B, H, W, C)
            backbone_out = self.backbone(inp)
            head_out = self.head(backbone_out)

            preds = np.array(head_out)
            rec_result = self.postprocess_op(preds)
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
            elapse += time.time() - starttime

        return rec_res, elapse


def test_recognizer(args):
    loaded = np.load(
        "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/rec_input.npz"
    )
    img_list = [loaded[f"arr_{i}"] for i in range(len(loaded.files))]
    recognizer = TextRecognizer(args)
    # recognizer.load_weights(
    #     "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/ch_ptocr_v4_rec_infer.safetensors"
    # )
    # recognizer.save_weights(
    #     "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/ch_ptocr_v4_rec_infer.safetensors"
    # )
    # recognizer.set_dtype(mx.float16)
    # recognizer.save_weights(
    #     "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/ch_ptocr_v4_rec_infer_f16.safetensors"
    # )
    recognizer.load_weights(
        "/Users/alexchen/Desktop/LocalDev/nexaml-mlx/examples/paddle_ocr/modelfiles/ch_ptocr_v4_rec_infer_f16.safetensors"
    )
    recognizer.eval()  # Important for BatchNorm behavior in MLX

    rec_res, elapse = recognizer(img_list)
    print(f"Recognition results: {rec_res}")
    print(f"Recognition time: {elapse:.3f}s")


class TextSystem:
    """OCR text detection and recognition system"""

    def __init__(self, args):
        self.det = TextDetector(args)
        self.rec = TextRecognizer(args)
        self.drop_score = args.drop_score

        # Load weights from safetensors
        self.det.load_weights(args.det_model_path)
        self.rec.load_weights(args.rec_model_path)

        self.det.eval()
        self.rec.eval()

    @staticmethod
    def _order_boxes(boxes: np.ndarray) -> List[np.ndarray]:
        """Order detected boxes by position (top to bottom, left to right)"""
        return sorted(boxes, key=lambda b: (b[0][1], b[0][0]))

    @staticmethod
    def _crop_rotated(img: np.ndarray, pts: np.ndarray) -> np.ndarray:
        """Crop rotated text region from image"""
        pts = pts.astype("float32")
        w = int(max(np.linalg.norm(pts[0] - pts[1]), np.linalg.norm(pts[2] - pts[3])))
        h = int(max(np.linalg.norm(pts[0] - pts[3]), np.linalg.norm(pts[1] - pts[2])))
        M = cv2.getPerspectiveTransform(
            pts, np.array([[0, 0], [w, 0], [w, h], [0, h]], dtype="float32")
        )
        dst = cv2.warpPerspective(img, M, (w, h), borderMode=cv2.BORDER_REPLICATE)
        if h / max(w, 1) > 1.5:
            dst = np.rot90(dst)
        return dst

    def __call__(self, img: np.ndarray) -> Tuple[List[np.ndarray], List[Tuple[str, float]]]:
        """Perform OCR on input image"""
        boxes = self.det(img)
        if boxes is None or len(boxes) == 0:
            return [], []

        boxes = self._order_boxes(boxes)
        crops = [self._crop_rotated(img, b.copy()) for b in boxes]

        rec_res, _ = self.rec(crops)

        keep_boxes, keep_txt = [], []
        for box, (txt, score) in zip(boxes, rec_res):
            if score >= self.drop_score:
                keep_boxes.append(box)
                keep_txt.append((txt, float(score)))
        return keep_boxes, keep_txt


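# Note (illustrative, not part of the upstream file): TextSystem wires detection and
# recognition end to end, as the __main__ block below also demonstrates. A minimal
# usage sketch (the Config attributes are assumed to be defined earlier in this file):
#
#     cfg = Config()
#     ocr = TextSystem(cfg)
#     boxes, texts = ocr(cv2.imread("sample.jpg"))
#     for txt, score in texts:
#         print(txt, score)
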
if __name__ == "__main__":
    config = Config()
    text_system = TextSystem(config)
    # Test with a sample image from model directory if available
    img_path = os.path.join(config.model_cache_dir, "1.jpg")
    if not os.path.exists(img_path):
        print("No test image found. Please provide an image path for testing.")
        sys.exit(1)

    img = cv2.imread(img_path)
    if img is None:
        print(f"Error: Could not read image at {img_path}")
        sys.exit(1)

    boxes, txts = text_system(img)
    print(f"Detected {len(boxes)} boxes")
    print(f"Recognized text: {txts}")