nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → nexaai-1.0.19rc9-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +14 -31
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +15 -32
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +7 -23
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +8 -24
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/RECORD +11 -200
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/top_level.txt +0 -0
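Net effect: rc9 drops the entire bundled nexaai/binds/nexa_mlx/py-lib source tree (asr, common, cv, embedding, image_gen, llm, rerank, sd, tts, vlm) and trims the four mlx_backend qwen3_vl files, which is why RECORD shrinks by roughly 190 entries (+11 -200). The hunks below reproduce four of the deleted files; matched to the list above by their line counts, they are llm/main.py (-68), rerank/generate.py (-174), rerank/interface.py (-287), and rerank/main.py (-127).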
nexaai/binds/nexa_mlx/py-lib/llm/main.py
@@ -1,68 +0,0 @@
-# Copyright © 2024 Apple Inc.
-
-from mlx_lm import generate, load
-
-
-def test_llm_generate_stream(model_path):
-    # Load the corresponding model and tokenizer
-    model, tokenizer = load(path_or_hf_repo=model_path)
-
-    # Conversation history to maintain context
-    conversation = []
-
-    # Specify the maximum number of tokens
-    max_tokens = 1_000
-
-    # Specify if tokens and timing information will be printed
-    verbose = True
-
-    print("Multi-round conversation started. Type 'quit' or 'exit' to end.")
-    print("=" * 50)
-
-    while True:
-        # Get user input
-        user_input = input("\nUser: ").strip()
-
-        # Check for exit commands
-        if user_input.lower() in ["quit", "exit", "q"]:
-            print("Goodbye!")
-            break
-
-        if not user_input:
-            continue
-
-        # Add user input to conversation history
-        conversation.append({"role": "user", "content": user_input})
-
-        # Transform the conversation into the chat template
-        prompt = tokenizer.apply_chat_template(
-            conversation=conversation, add_generation_prompt=True
-        )
-
-        # Generate response
-        print("Assistant: ", end="", flush=True)
-
-        # Generate text, already handled KV cache
-        response = generate(
-            model=model,
-            tokenizer=tokenizer,
-            prompt=prompt,
-            max_tokens=max_tokens,
-            verbose=verbose,
-        )
-
-        # Extract the generated text (response includes the prompt)
-        generated_text = response.strip()
-
-        # Add assistant response to conversation history
-        conversation.append({"role": "assistant", "content": generated_text})
-
-        print() # New line after response
-
-
-if __name__ == "__main__":
-    import argparse
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default="mlx-community/Qwen3-1.7B-4bit-DWQ")
-    args = parser.parse_args()
-    test_llm_generate_stream(args.model_path)
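For context, the deleted llm/main.py above is a self-contained multi-round chat demo. Its core reduces to the minimal sketch below, which simply mirrors the calls shown in the hunk and assumes the mlx_lm package is installed:

    from mlx_lm import generate, load

    model, tokenizer = load(path_or_hf_repo="mlx-community/Qwen3-1.7B-4bit-DWQ")
    conversation = [{"role": "user", "content": "Hello!"}]
    # Render the running chat history into a model prompt.
    prompt = tokenizer.apply_chat_template(conversation=conversation, add_generation_prompt=True)
    # generate() handles prompt processing and decoding (including the KV cache).
    reply = generate(model=model, tokenizer=tokenizer, prompt=prompt, max_tokens=1_000)
    conversation.append({"role": "assistant", "content": reply.strip()})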
File without changes
nexaai/binds/nexa_mlx/py-lib/rerank/generate.py
@@ -1,174 +0,0 @@
-# Copyright © Nexa AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import sys
-import os
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-import time
-
-from transformers import AutoTokenizer
-from huggingface_hub import snapshot_download
-from .modeling.nexa_jina_rerank import Model, ModelArgs
-
-
-def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
-    """Create position ids from input ids, accounting for padding tokens"""
-    mask = (input_ids != padding_idx).astype(mx.int32)
-    incremental_indices = (mx.cumsum(mask, axis=1) + past_key_values_length) * mask
-    return incremental_indices.astype(mx.int32) + padding_idx
-
-
-def prepare_inputs(query, documents, tokenizer, max_length=1024):
-    """Prepare inputs for the model - match torch exactly"""
-    sentence_pairs = [[query, doc] for doc in documents]
-    inputs = tokenizer(
-        sentence_pairs,
-        padding="max_length",
-        truncation=True,
-        return_tensors="np",
-        max_length=max_length,
-    )
-
-    input_ids = mx.array(inputs["input_ids"]).astype(mx.int32)
-    seqlen = input_ids.shape[1]
-    attention_mask = mx.array(inputs["attention_mask"]).astype(mx.float32)
-
-    # Create token_type_ids as 1D tensor like torch, then broadcast for each batch item
-    token_type_ids_1d = mx.zeros(seqlen, dtype=mx.int32)
-    batch_size = input_ids.shape[0]
-    token_type_ids = mx.broadcast_to(
-        mx.expand_dims(token_type_ids_1d, axis=0), (batch_size, seqlen)
-    )
-
-    # Create position ids for each sequence in the batch
-    position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=1)
-
-    return input_ids, attention_mask, token_type_ids, position_ids
-
-
-def load_model(model_id):
-    """Initialize and load the Jina V2 rerank model."""
-    curr_dir = os.path.dirname(os.path.abspath(__file__))
-    model_dir = f"{curr_dir}/modelfiles/nexaml_jina_v2_rerank_mlx"
-
-    # Download model if not exists
-    if not os.path.exists(model_dir):
-        print(f"Downloading model {model_id}...")
-
-        os.makedirs(model_dir, exist_ok=True)
-
-        try:
-            snapshot_download(
-                repo_id=model_id,
-                allow_patterns=["*.safetensors", "config.json", "tokenizer*"],
-                local_dir=model_dir,
-                local_dir_use_symlinks=False
-            )
-            print("Model download completed!")
-        except Exception as e:
-            print(f"Failed to download model: {e}")
-            print("Try: huggingface-cli login (if authentication required)")
-            raise
-
-    # Create model config
-    config = ModelArgs()
-    model = Model(config)
-
-    # Load weights
-    weight_file = os.path.join(model_dir, "model.safetensors")
-    if not os.path.exists(weight_file):
-        # Try alternative naming patterns
-        safetensors_files = [f for f in os.listdir(model_dir) if f.endswith('.safetensors')]
-        if safetensors_files:
-            weight_file = os.path.join(model_dir, safetensors_files[0])
-        else:
-            raise FileNotFoundError(f"No .safetensors file found in {model_dir}")
-
-    print(f"Loading weights from: {weight_file}")
-    model.load_weights(weight_file, strict=True)
-    model.eval()
-
-    return model, model_dir
-
-
-def load_tokenizer(model_path):
-    """Load and configure the tokenizer."""
-    return AutoTokenizer.from_pretrained(model_path)
-
-
-def rerank_documents(model, tokenizer, query, documents, max_length=1024):
-    """Rerank documents based on query relevance."""
-    # Prepare inputs
-    input_ids, attention_mask, token_type_ids, position_ids = prepare_inputs(
-        query, documents, tokenizer, max_length
-    )
-
-    # Run inference
-    start_time = time.time()
-    scores = model.nexa_forward(input_ids, attention_mask, token_type_ids, position_ids)
-    scores = mx.squeeze(scores, axis=-1)
-    end_time = time.time()
-
-    # Apply sigmoid to get probabilities
-    scores_sigmoid = mx.sigmoid(scores)
-
-    inference_time = (end_time - start_time) * 1000 # Convert to ms
-
-    return scores, scores_sigmoid, inference_time
-
-
-def main(model_id):
-    """Main function to handle reranking demonstration."""
-
-    # Load model and tokenizer
-    model, model_path = load_model(model_id)
-    tokenizer = load_tokenizer(model_path)
-
-    # Example query and documents
-    query = "What are the health benefits of green tea?"
-    documents = [
-        "Green tea is rich in antioxidants and may improve brain function.",
-        "Coffee contains caffeine and can boost energy levels.",
-        "Das Trinken von grünem Tee kann das Risiko für Herzkrankheiten senken.",
-        "Black tea is another popular beverage with its own health benefits.",
-    ]
-
-    # Perform reranking
-    scores, scores_sigmoid, inference_time = rerank_documents(
-        model, tokenizer, query, documents
-    )
-
-    # Display results
-    print("=" * 70)
-    print("Reranking Results:")
-    print("=" * 70)
-    print(f"Query: {query}")
-    print()
-
-    for i, (doc, score, prob) in enumerate(zip(documents, scores.tolist(), scores_sigmoid.tolist())):
-        print(f"Document {i+1}:")
-        print(f"  Text: {doc}")
-        print(f"  Score: {score:.4f}")
-        print(f"  Probability: {prob:.4f}")
-        print()
-
-    print(f"Inference time: {inference_time:.1f}ms")
-    print(f"Throughput: {len(documents)/inference_time*1000:.1f} docs/s")
-
-
-if __name__ == "__main__":
-    model_id = "nexaml/jina-v2-rerank-mlx"
-    main(model_id)
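Note on the deleted generate.py above: create_position_ids_from_input_ids reproduces the RoBERTa-style position scheme, in which real tokens are numbered upward from padding_idx + 1 while pad slots stay pinned at padding_idx. A small worked example (input values invented for illustration; past_key_values_length is 0 here, as in the call above):

    import mlx.core as mx

    padding_idx = 1                                     # pad token id, as in the file above
    input_ids = mx.array([[0, 54, 87, 1, 1]])           # last two slots are padding
    mask = (input_ids != padding_idx).astype(mx.int32)  # [[1, 1, 1, 0, 0]]
    position_ids = mx.cumsum(mask, axis=1) * mask + padding_idx
    print(position_ids.tolist())                        # [[2, 3, 4, 1, 1]]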
nexaai/binds/nexa_mlx/py-lib/rerank/interface.py
@@ -1,287 +0,0 @@
-# Copyright © Nexa AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.s
-
-from __future__ import annotations
-
-import os
-import json
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-import time
-from pathlib import Path
-from typing import Any, List, Optional, Sequence
-from dataclasses import dataclass
-from abc import ABC, abstractmethod
-
-# Import necessary modules
-from transformers import AutoTokenizer
-
-# Import from ml.py for API alignment (assuming similar structure)
-try:
-    from ml import (
-        Reranker as BaseReranker,
-        Path as PathType,
-    )
-except ImportError:
-    # Fallback to local definitions if ml.py not available
-    PathType = Path
-    BaseReranker = ABC
-
-# Import profiling module
-from profiling import ProfilingMixin, ProfilingData, StopReason
-
-# Import the model implementation
-from .modeling.nexa_jina_rerank import Model, ModelArgs
-
-
-@dataclass
-class RerankConfig:
-    """Configuration for reranking."""
-    batch_size: int = 1
-    normalize: bool = True
-    normalize_method: str = "softmax" # "softmax" | "min-max" | "none"
-
-    def __init__(
-        self,
-        batch_size: int = 1,
-        normalize: bool = True,
-        normalize_method: str = "softmax",
-    ) -> None:
-        self.batch_size = batch_size
-        self.normalize = normalize
-        self.normalize_method = normalize_method
-
-
-class Reranker(BaseReranker, ProfilingMixin):
-    """
-    Reranker interface for MLX reranking models.
-    API aligned with ml.py Reranker abstract base class.
-    """
-
-    def __init__(
-        self,
-        model_path: PathType,
-        tokenizer_path: PathType,
-        device: Optional[str] = None,
-    ) -> None:
-        """Initialize the Reranker model."""
-        # Initialize profiling mixin
-        ProfilingMixin.__init__(self)
-
-        # Store paths
-        if (os.path.isfile(model_path)):
-            model_path = os.path.dirname(model_path)
-
-        # Call parent constructor if inheriting from ml.py
-        if hasattr(super(), '__init__'):
-            super().__init__(model_path, tokenizer_path, device)
-
-        # Store paths and device
-        self.model_path = model_path
-        self.tokenizer_path = tokenizer_path
-        self.device = device if device is not None else "cpu"
-
-        # Initialize model and tokenizer as None
-        self.model = None
-        self.tokenizer = None
-        self.config = None
-
-    def destroy(self) -> None:
-        """Destroy the model and free resources."""
-        self.model = None
-        self.tokenizer = None
-        self.config = None
-
-    def load_model(self, model_path: PathType, extra_data: Any = None) -> bool:
-        """Load model from path."""
-        try:
-            # Use the provided model_path or fall back to instance path
-            if model_path:
-                # Apply same file-to-directory conversion as in __init__
-                if os.path.isfile(model_path):
-                    model_path = os.path.dirname(model_path)
-                self.model_path = model_path
-
-            # Load the model using internal implementation
-            self.model = self._load_jina_model(self.model_path)
-            self.tokenizer = self._load_tokenizer()
-
-            return True
-        except Exception as e:
-            print(f"Failed to load model: {e}")
-            return False
-
-    def close(self) -> None:
-        """Close the model."""
-        self.destroy()
-
-    def rerank(
-        self,
-        query: str,
-        documents: Sequence[str],
-        config: Optional[RerankConfig] = None,
-        clear_cache: bool = True,
-    ) -> mx.array:
-        """Rerank documents given a query."""
-        if self.model is None or self.tokenizer is None:
-            raise RuntimeError("Model not loaded. Call load_model() first.")
-
-        if config is None:
-            config = RerankConfig()
-
-        # Start profiling
-        self._start_profiling()
-        self._prompt_start()
-
-        all_scores = []
-
-        # Process documents in batches
-        batch_size = config.batch_size
-        for i in range(0, len(documents), batch_size):
-            batch_docs = documents[i:i + batch_size]
-            batch_scores = self._rerank_batch(query, batch_docs, config)
-            all_scores.append(batch_scores)
-
-        if clear_cache:
-            mx.clear_cache()
-
-        # End prompt processing, start decode
-        self._prompt_end()
-        self._decode_start()
-
-        # Concatenate all batch scores into a single array
-        res = mx.concatenate(all_scores, axis=0) if len(all_scores) > 1 else all_scores[0]
-
-        # End decode and profiling
-        self._decode_end()
-        self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
-        self._end_profiling()
-
-        return res
-
-    def _load_jina_model(self, model_dir: str) -> Model:
-        """Initialize and load the Jina V2 rerank model."""
-
-        # Validate that model path exists
-        if not os.path.exists(model_dir):
-            raise ValueError(f"Model path does not exist: {model_dir}")
-
-        # Store model directory for tokenizer loading
-        self._model_dir = model_dir
-
-        # Create model config
-        config = ModelArgs()
-        model = Model(config)
-
-        # Load weights
-        weight_file = os.path.join(model_dir, "model.safetensors")
-        if not os.path.exists(weight_file):
-            # Try alternative naming patterns
-            safetensors_files = [f for f in os.listdir(model_dir) if f.endswith('.safetensors')]
-            if safetensors_files:
-                weight_file = os.path.join(model_dir, safetensors_files[0])
-            else:
-                raise FileNotFoundError(f"No .safetensors file found in {model_dir}")
-
-        model.load_weights(weight_file, strict=True)
-        model.eval()
-
-        return model
-
-    def _load_tokenizer(self) -> AutoTokenizer:
-        """Load and configure the tokenizer."""
-        return AutoTokenizer.from_pretrained(self._model_dir)
-
-    def _rerank_batch(self, query: str, documents: List[str], config: RerankConfig) -> mx.array:
-        """Rerank a batch of documents and return their scores."""
-        # Prepare inputs
-        input_ids, attention_mask, token_type_ids, position_ids = self._prepare_inputs(
-            query, documents, self.tokenizer, max_length=1024
-        )
-
-        # Run inference
-        scores = self.model.nexa_forward(input_ids, attention_mask, token_type_ids, position_ids)
-        scores = mx.squeeze(scores, axis=-1)
-
-        # Apply normalization if requested
-        if config.normalize:
-            scores = self._normalize_scores(scores, config.normalize_method)
-
-        return scores
-
-    def _create_position_ids_from_input_ids(self, input_ids, padding_idx, past_key_values_length=0):
-        """Create position ids from input ids, accounting for padding tokens"""
-        mask = (input_ids != padding_idx).astype(mx.int32)
-        incremental_indices = (mx.cumsum(mask, axis=1) + past_key_values_length) * mask
-        return incremental_indices.astype(mx.int32) + padding_idx
-
-    def _prepare_inputs(self, query, documents, tokenizer, max_length=1024):
-        """Prepare inputs for the model - match torch exactly"""
-        sentence_pairs = [[query, doc] for doc in documents]
-        inputs = tokenizer(
-            sentence_pairs,
-            padding="max_length",
-            truncation=True,
-            return_tensors="np",
-            max_length=max_length,
-        )
-
-        input_ids = mx.array(inputs["input_ids"]).astype(mx.int32)
-        seqlen = input_ids.shape[1]
-        attention_mask = mx.array(inputs["attention_mask"]).astype(mx.float32)
-
-        # Create token_type_ids as 1D tensor like torch, then broadcast for each batch item
-        token_type_ids_1d = mx.zeros(seqlen, dtype=mx.int32)
-        batch_size = input_ids.shape[0]
-        token_type_ids = mx.broadcast_to(
-            mx.expand_dims(token_type_ids_1d, axis=0), (batch_size, seqlen)
-        )
-
-        # Create position ids for each sequence in the batch
-        position_ids = self._create_position_ids_from_input_ids(input_ids, padding_idx=1)
-
-        return input_ids, attention_mask, token_type_ids, position_ids
-
-    def _normalize_scores(self, scores: mx.array, method: str) -> mx.array:
-        """Normalize scores using specified method."""
-        if method == "none":
-            return scores
-        elif method == "softmax":
-            # For 1D arrays, use axis=0; for higher dims, use axis=-1
-            if len(scores.shape) == 1:
-                return mx.softmax(scores, axis=0)
-            else:
-                return mx.softmax(scores, axis=-1)
-        elif method == "min-max":
-            min_val = mx.min(scores)
-            max_val = mx.max(scores)
-            if max_val > min_val:
-                return (scores - min_val) / (max_val - min_val)
-            return scores
-        else:
-            return scores
-
-
-# Factory function for creating reranker instances
-def create_reranker(
-    model_path: PathType,
-    tokenizer_path: Optional[PathType] = None,
-    device: Optional[str] = None,
-) -> Reranker:
-    """Create and return a Reranker instance."""
-    if tokenizer_path is None:
-        tokenizer_path = model_path
-
-    return Reranker(model_path, tokenizer_path, device)
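For reference, the three normalize_method branches of _normalize_scores in the deleted interface.py behave as follows on a toy 1D score vector (numbers invented for illustration, not model output):

    import mlx.core as mx

    scores = mx.array([2.0, -1.0, 0.5])        # hypothetical raw logits
    softmaxed = mx.softmax(scores, axis=0)     # "softmax": relative weights summing to 1
    span = mx.max(scores) - mx.min(scores)
    minmaxed = (scores - mx.min(scores)) / span  # "min-max": rescaled into [0, 1]
    # "none" returns the raw logits unchanged; the demo scripts then apply
    # mx.sigmoid(scores) themselves to get per-document probabilities.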
nexaai/binds/nexa_mlx/py-lib/rerank/main.py
@@ -1,127 +0,0 @@
-# Copyright © Nexa AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import time
-import mlx.core as mx
-from .interface import create_reranker, RerankConfig
-
-
-def test_reranking():
-    """Test reranking model functionality."""
-    # Create reranker instance
-    model_path = "nexaml/jina-v2-rerank-mlx"
-    reranker = create_reranker(model_path=model_path)
-
-    # Load the model
-    print("Loading reranking model...")
-    success = reranker.load_model(model_path, extra_data="nexaml/jina-v2-rerank-mlx")
-
-    if not success:
-        print("Failed to load model!")
-        return
-
-    print("✅ Model loaded successfully!")
-
-    # Test query and documents (same as generate.py)
-    query = "What are the health benefits of green tea?"
-    documents = [
-        "Green tea is rich in antioxidants and may improve brain function.",
-        "Coffee contains caffeine and can boost energy levels.",
-        "Das Trinken von grünem Tee kann das Risiko für Herzkrankheiten senken.",
-        "Black tea is another popular beverage with its own health benefits.",
-    ]
-
-    # Configure reranking with no normalization to get raw scores
-    config = RerankConfig(
-        batch_size=len(documents),
-        normalize=False,
-        normalize_method="none"
-    )
-
-    # Generate reranking scores
-    start_time = time.time()
-    scores = reranker.rerank(query, documents, config)
-    end_time = time.time()
-
-    # Calculate sigmoid probabilities manually
-    scores_sigmoid = mx.sigmoid(scores).tolist()
-
-    inference_time = (end_time - start_time) * 1000 # Convert to ms
-
-    print("=" * 70)
-    print("Reranking Results:")
-    print("=" * 70)
-    print(f"Query: {query}")
-    print()
-
-    for i, (doc, score, prob) in enumerate(zip(documents, scores.tolist(), scores_sigmoid)):
-        print(f"Document {i+1}:")
-        print(f"  Text: {doc}")
-        print(f"  Score: {score:.4f}")
-        print(f"  Probability: {prob:.4f}")
-        print()
-
-    print(f"Inference time: {inference_time:.1f}ms")
-    print(f"Throughput: {len(documents)/inference_time*1000:.1f} docs/s")
-
-    # Cleanup
-    reranker.close()
-
-
-def main(model_id):
-    """Main function to handle reranking demonstration - aligned with embedding generate.py format."""
-    # Create reranker instance
-    reranker = create_reranker(model_path=model_id)
-
-    # Load the model
-    success = reranker.load_model(model_id, extra_data=model_id)
-
-    if not success:
-        print("Failed to load model!")
-        return
-
-    # Simple test like embedding generate.py
-    query = "What are the health benefits of green tea?"
-    documents = [
-        "Green tea is rich in antioxidants and may improve brain function.",
-        "Coffee contains caffeine and can boost energy levels.",
-    ]
-
-    # Get raw scores
-    config = RerankConfig(normalize=False)
-    scores = reranker.rerank(query, documents, config)
-
-    # Calculate statistics on raw MLX array
-    scores_sigmoid = mx.sigmoid(scores)
-
-    print(f"Scores shape: {scores.shape}")
-    print(f"Score sample values: {scores.tolist()}")
-    print(f"Scores min: {scores.min():.4f}, Max: {scores.max():.4f}, Mean: {scores.mean():.4f}, Std: {scores.std():.4f}")
-    print(f"Sigmoid probabilities: {scores_sigmoid.tolist()}")
-
-    # Cleanup
-    reranker.close()
-
-
-if __name__ == "__main__":
-    import argparse
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default="nexaml/jina-v2-rerank-mlx")
-    args = parser.parse_args()
-
-    # Use test_reranking for comprehensive test, main for simple format like generate.py
-    if hasattr(args, 'simple') and args.simple:
-        main(args.model_path)
-    else:
-        test_reranking()
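Neither demo above actually reorders the documents; they only print per-document scores. If a ranking is wanted, a minimal follow-up sketch (argsort on the negated scores gives descending relevance; values invented for illustration):

    import mlx.core as mx

    documents = ["doc A", "doc B", "doc C"]
    scores = mx.array([0.12, 0.91, 0.47])  # e.g. the sigmoid probabilities printed above
    order = mx.argsort(-scores)            # indices from most to least relevant
    ranked = [documents[int(i)] for i in order.tolist()]
    print(ranked)                          # ['doc B', 'doc C', 'doc A']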
File without changes