nexaai 1.0.21rc5-cp313-cp313-win_arm64.whl → 1.0.21rc14-cp313-cp313-win_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic.
- nexaai/__init__.py +95 -95
- nexaai/_stub.cp313-win_arm64.pyd +0 -0
- nexaai/_version.py +4 -1
- nexaai/asr.py +68 -65
- nexaai/asr_impl/mlx_asr_impl.py +92 -92
- nexaai/asr_impl/pybind_asr_impl.py +127 -44
- nexaai/base.py +39 -39
- nexaai/binds/__init__.py +6 -5
- nexaai/binds/asr_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/common_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-opencl.dll +0 -0
- nexaai/binds/cpu_gpu/ggml.dll +0 -0
- nexaai/binds/cpu_gpu/mtmd.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
- nexaai/binds/embedder_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/libcrypto-3-arm64.dll +0 -0
- nexaai/binds/libssl-3-arm64.dll +0 -0
- nexaai/binds/llm_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/nexa_bridge.dll +0 -0
- nexaai/binds/npu/convnext-sdk.dll +0 -0
- nexaai/binds/npu/embed-gemma-sdk.dll +0 -0
- nexaai/binds/npu/ggml-base.dll +0 -0
- nexaai/binds/npu/ggml-cpu.dll +0 -0
- nexaai/binds/npu/ggml-opencl.dll +0 -0
- nexaai/binds/npu/ggml.dll +0 -0
- nexaai/binds/npu/granite-nano-sdk.dll +0 -0
- nexaai/binds/npu/granite4-sdk.dll +0 -0
- nexaai/binds/npu/jina-rerank-sdk.dll +0 -0
- nexaai/binds/npu/liquid-sdk.dll +0 -0
- nexaai/binds/npu/llama3-3b-sdk.dll +0 -0
- nexaai/binds/npu/nexa-mm-process.dll +0 -0
- nexaai/binds/npu/nexa-sampling.dll +0 -0
- nexaai/binds/npu/nexa_plugin.dll +0 -0
- nexaai/binds/npu/omni-neural-sdk.dll +0 -0
- nexaai/binds/npu/openblas.dll +0 -0
- nexaai/binds/npu/paddleocr-sdk.dll +0 -0
- nexaai/binds/npu/parakeet-sdk.dll +0 -0
- nexaai/binds/npu/phi3-5-sdk.dll +0 -0
- nexaai/binds/npu/phi4-sdk.dll +0 -0
- nexaai/binds/npu/pyannote-sdk.dll +0 -0
- nexaai/binds/npu/qwen3-4b-sdk.dll +0 -0
- nexaai/binds/npu/qwen3vl-sdk.dll +0 -0
- nexaai/binds/npu/qwen3vl-vision.dll +0 -0
- nexaai/binds/npu/yolov12-sdk.dll +0 -0
- nexaai/binds/npu/zlib1.dll +0 -0
- nexaai/binds/rerank_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/vlm_bind.cp313-win_arm64.pyd +0 -0
- nexaai/common.py +105 -105
- nexaai/cv.py +93 -93
- nexaai/cv_impl/mlx_cv_impl.py +89 -89
- nexaai/cv_impl/pybind_cv_impl.py +32 -32
- nexaai/embedder.py +73 -73
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -118
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -96
- nexaai/image_gen.py +141 -141
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -292
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -85
- nexaai/llm.py +98 -98
- nexaai/llm_impl/mlx_llm_impl.py +271 -271
- nexaai/llm_impl/pybind_llm_impl.py +220 -220
- nexaai/log.py +92 -92
- nexaai/rerank.py +57 -57
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -94
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -136
- nexaai/runtime.py +68 -68
- nexaai/runtime_error.py +24 -24
- nexaai/tts.py +75 -75
- nexaai/tts_impl/mlx_tts_impl.py +94 -94
- nexaai/tts_impl/pybind_tts_impl.py +43 -43
- nexaai/utils/decode.py +17 -17
- nexaai/utils/manifest_utils.py +531 -531
- nexaai/utils/model_manager.py +1562 -1562
- nexaai/utils/model_types.py +49 -49
- nexaai/utils/progress_tracker.py +384 -384
- nexaai/utils/quantization_utils.py +245 -245
- nexaai/vlm.py +129 -129
- nexaai/vlm_impl/mlx_vlm_impl.py +258 -258
- nexaai/vlm_impl/pybind_vlm_impl.py +256 -256
- {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/METADATA +1 -1
- nexaai-1.0.21rc14.dist-info/RECORD +154 -0
- nexaai/binds/nexaml/FLAC.dll +0 -0
- nexaai/binds/nexaml/fftw3.dll +0 -0
- nexaai/binds/nexaml/fftw3f.dll +0 -0
- nexaai/binds/nexaml/ggml-base.dll +0 -0
- nexaai/binds/nexaml/ggml-cpu.dll +0 -0
- nexaai/binds/nexaml/ggml-opencl.dll +0 -0
- nexaai/binds/nexaml/ggml.dll +0 -0
- nexaai/binds/nexaml/libmp3lame.DLL +0 -0
- nexaai/binds/nexaml/mpg123.dll +0 -0
- nexaai/binds/nexaml/nexa-mm-process.dll +0 -0
- nexaai/binds/nexaml/nexa-sampling.dll +0 -0
- nexaai/binds/nexaml/nexa_plugin.dll +0 -0
- nexaai/binds/nexaml/nexaproc.dll +0 -0
- nexaai/binds/nexaml/ogg.dll +0 -0
- nexaai/binds/nexaml/opus.dll +0 -0
- nexaai/binds/nexaml/qwen3-vl.dll +0 -0
- nexaai/binds/nexaml/qwen3vl-vision.dll +0 -0
- nexaai/binds/nexaml/vorbis.dll +0 -0
- nexaai/binds/nexaml/vorbisenc.dll +0 -0
- nexaai-1.0.21rc5.dist-info/RECORD +0 -162
- {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/WHEEL +0 -0
- {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/top_level.txt +0 -0
nexaai/rerank.py
CHANGED
@@ -1,57 +1,57 @@
Every line of the file is removed and re-added with identical text; the file reads:

from typing import List, Optional, Sequence, Union
from abc import abstractmethod
from dataclasses import dataclass

from nexaai.base import BaseModel
from nexaai.common import PluginID


@dataclass
class RerankConfig:
    """Configuration for reranking."""
    batch_size: int = 1
    normalize: bool = True
    normalize_method: str = "softmax"  # "softmax" | "min-max" | "none"


class Reranker(BaseModel):
    """Abstract base class for reranker models."""

    def __init__(self):
        """Initialize base Reranker class."""
        pass

    @classmethod
    def _load_from(cls,
                   model_path: str,
                   model_name: str = None,
                   tokenizer_file: str = "tokenizer.json",
                   plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
                   device_id: Optional[str] = None,
                   **kwargs
                   ) -> 'Reranker':
        """Load reranker model from local path, routing to appropriate implementation."""
        # Check plugin_id value for routing - handle both enum and string
        plugin_value = plugin_id.value if isinstance(plugin_id, PluginID) else plugin_id

        if plugin_value == "mlx":
            from nexaai.rerank_impl.mlx_rerank_impl import MLXRerankImpl
            return MLXRerankImpl._load_from(model_path, model_name, tokenizer_file, plugin_id, device_id)
        else:
            from nexaai.rerank_impl.pybind_rerank_impl import PyBindRerankImpl
            return PyBindRerankImpl._load_from(model_path, model_name, tokenizer_file, plugin_id, device_id)

    @abstractmethod
    def load_model(self, model_path: str, extra_data: Optional[str] = None) -> bool:
        """Load model from path."""
        pass

    @abstractmethod
    def rerank(
        self,
        query: str,
        documents: Sequence[str],
        config: Optional[RerankConfig] = None,
    ) -> List[float]:
        """Rerank documents given a query."""
        pass
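For reference, a minimal usage sketch of the routing above. It assumes a local reranker model directory (the path and plugin choice below are placeholders) and calls the private `_load_from` entry point shown in the file directly:

from nexaai.rerank import Reranker, RerankConfig
from nexaai.common import PluginID

# Hypothetical local model path; any non-"mlx" plugin routes to PyBindRerankImpl.
reranker = Reranker._load_from(
    model_path="models/my-reranker",
    plugin_id=PluginID.LLAMA_CPP,
)

scores = reranker.rerank(
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    config=RerankConfig(batch_size=2, normalize=True, normalize_method="softmax"),
)
print(scores)     # one float per document
reranker.eject()  # release backend resources (both implementations define eject)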
nexaai/rerank_impl/mlx_rerank_impl.py
CHANGED
@@ -1,94 +1,94 @@
Every line of the file is removed and re-added with identical text; the file reads:

# Note: This code is generated by Cursor, not tested yet.

from typing import List, Optional, Sequence, Union
import os

from nexaai.common import PluginID
from nexaai.rerank import Reranker, RerankConfig
from nexaai.mlx_backend.rerank.interface import Reranker as MLXRerankInterface, create_reranker


class MLXRerankImpl(Reranker):
    def __init__(self):
        """Initialize MLX Rerank implementation."""
        super().__init__()
        self._mlx_reranker = None

    @classmethod
    def _load_from(cls,
                   model_path: str,
                   model_name: str = None,
                   tokenizer_file: str = "tokenizer.json",
                   plugin_id: Union[PluginID, str] = PluginID.MLX,
                   device_id: Optional[str] = None
                   ) -> 'MLXRerankImpl':
        """Load reranker model from local path using MLX backend."""
        try:
            # MLX Rerank interfaces are already imported

            # Create instance and load MLX reranker
            instance = cls()
            instance._mlx_reranker = create_reranker(
                model_path=model_path,
                # model_name=model_name,  # FIXME: For MLX Reranker, model_name is not used
                tokenizer_path=tokenizer_file,
                device=device_id
            )

            # Load the model
            success = instance._mlx_reranker.load_model(model_path)
            if not success:
                raise RuntimeError("Failed to load MLX reranker model")

            return instance
        except Exception as e:
            raise RuntimeError(f"Failed to load MLX Reranker: {str(e)}")

    def eject(self):
        """Destroy the model and free resources."""
        if self._mlx_reranker:
            self._mlx_reranker.destroy()
            self._mlx_reranker = None

    def load_model(self, model_path: str, extra_data: Optional[str] = None) -> bool:
        """Load model from path."""
        if not self._mlx_reranker:
            raise RuntimeError("MLX Reranker not initialized")

        try:
            return self._mlx_reranker.load_model(model_path, extra_data)
        except Exception as e:
            raise RuntimeError(f"Failed to load reranker model: {str(e)}")

    def rerank(
        self,
        query: str,
        documents: Sequence[str],
        config: Optional[RerankConfig] = None,
    ) -> List[float]:
        """Rerank documents given a query."""
        if not self._mlx_reranker:
            raise RuntimeError("MLX Reranker not loaded")

        try:
            # Convert our config to MLX format if provided
            mlx_config = None
            if config:
                from nexaai.mlx_backend.rerank.interface import RerankConfig as MLXRerankConfig

                mlx_config = MLXRerankConfig(
                    batch_size=config.batch_size,
                    normalize=config.normalize,
                    normalize_method=config.normalize_method
                )

            # Use MLX reranking
            scores = self._mlx_reranker.rerank(query, documents, mlx_config)

            # Convert mx.array to Python list of floats
            return scores.tolist()

        except Exception as e:
            raise RuntimeError(f"Failed to rerank documents: {str(e)}")
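The `normalize_method` value declared in `RerankConfig` ("softmax" | "min-max" | "none") is passed through to the backends unchanged; the normalization itself happens in native code, not in this wheel's Python layer. Purely as an illustration of what those three options conventionally mean for a list of raw relevance scores (this is not the package's implementation):

import math
from typing import List

def normalize_scores(raw: List[float], method: str = "softmax") -> List[float]:
    """Illustrative only: conventional softmax / min-max normalization of scores."""
    if method == "none" or not raw:
        return list(raw)
    if method == "softmax":
        peak = max(raw)
        exps = [math.exp(s - peak) for s in raw]  # subtract max for numerical stability
        total = sum(exps)
        return [e / total for e in exps]
    if method == "min-max":
        lo, hi = min(raw), max(raw)
        if hi == lo:
            return [0.0 for _ in raw]
        return [(s - lo) / (hi - lo) for s in raw]
    raise ValueError(f"Unknown normalize_method: {method}")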
nexaai/rerank_impl/pybind_rerank_impl.py
CHANGED
@@ -1,136 +1,136 @@
Every line of the file is removed and re-added with identical text; the file reads:

from typing import List, Optional, Sequence, Union
import numpy as np

from nexaai.common import PluginID
from nexaai.rerank import Reranker, RerankConfig
from nexaai.binds import rerank_bind, common_bind
from nexaai.runtime import _ensure_runtime


class PyBindRerankImpl(Reranker):
    def __init__(self, _handle_ptr):
        """
        Internal initializer

        Args:
            _handle_ptr: Capsule handle to the C++ reranker object
        """
        super().__init__()
        self._handle = _handle_ptr

    @classmethod
    def _load_from(cls,
                   model_path: str,
                   model_name: str = None,
                   tokenizer_file: str = "tokenizer.json",
                   plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
                   device_id: Optional[str] = None
                   ) -> 'PyBindRerankImpl':
        """
        Load reranker model from local path using PyBind backend.

        Args:
            model_path: Path to the model file
            model_name: Name of the model (optional)
            tokenizer_file: Path to the tokenizer file (default: "tokenizer.json")
            plugin_id: Plugin ID to use for the model (default: PluginID.LLAMA_CPP)
            device_id: Device ID to use for the model (optional)

        Returns:
            PyBindRerankImpl instance
        """
        _ensure_runtime()

        # Convert enum to string for C++ binding
        plugin_id_str = plugin_id.value if isinstance(plugin_id, PluginID) else plugin_id

        # Create model config
        model_config = common_bind.ModelConfig()

        # Create reranker handle with new API signature
        handle = rerank_bind.ml_reranker_create(
            model_path,
            model_name,
            tokenizer_file,
            model_config,
            plugin_id_str,
            device_id
        )

        return cls(handle)

    def eject(self):
        """
        Clean up resources and destroy the reranker
        """
        # Destructor of the handle will unload the model correctly
        if hasattr(self, '_handle') and self._handle is not None:
            del self._handle
            self._handle = None

    def load_model(self, model_path: str, extra_data: Optional[str] = None) -> bool:
        """
        Load model from path.

        Note: This method is not typically used directly. Use _load_from instead.

        Args:
            model_path: Path to the model file
            extra_data: Additional data (unused)

        Returns:
            True if successful
        """
        # This method is part of the BaseModel interface but typically not used
        # directly for PyBind implementations since _load_from handles creation
        raise NotImplementedError("Use _load_from class method to load models")

    def rerank(
        self,
        query: str,
        documents: Sequence[str],
        config: Optional[RerankConfig] = None,
    ) -> List[float]:
        """
        Rerank documents given a query.

        Args:
            query: Query text as UTF-8 string
            documents: List of document texts to rerank
            config: Optional reranking configuration

        Returns:
            List of ranking scores (one per document)
        """
        if self._handle is None:
            raise RuntimeError("Reranker handle is None. Model may have been ejected.")

        # Use default config if not provided
        if config is None:
            config = RerankConfig()

        # Create bind config
        bind_config = rerank_bind.RerankConfig()
        bind_config.batch_size = config.batch_size
        bind_config.normalize = config.normalize
        bind_config.normalize_method = config.normalize_method

        # Convert documents to list if needed
        documents_list = list(documents)

        # Call the binding which returns a dict with scores and profile_data
        result = rerank_bind.ml_reranker_rerank(
            self._handle,
            query,
            documents_list,
            bind_config
        )

        # Extract scores from result dict
        scores_array = result.get("scores", np.array([]))

        # Convert numpy array to list of floats
        if isinstance(scores_array, np.ndarray):
            return scores_array.tolist()
        else:
            return []
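Both backends return one ranking score per input document. Ordering by relevance is left to the caller; a small, hypothetical helper (application code, not part of this package) that pairs documents with their scores and sorts best-first:

from typing import List, Sequence, Tuple

def rank_documents(documents: Sequence[str], scores: List[float]) -> List[Tuple[str, float]]:
    """Pair each document with its score and sort highest-score first."""
    if len(documents) != len(scores):
        raise ValueError("expected exactly one score per document")
    return sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)

# e.g. ranked = rank_documents(docs, reranker.rerank(query, docs))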