nexaai 1.0.19rc19__cp310-cp310-win_amd64.whl → 1.0.21__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/__init__.py +12 -0
- nexaai/_stub.cp310-win_amd64.pyd +0 -0
- nexaai/_version.py +1 -1
- nexaai/asr.py +10 -6
- nexaai/asr_impl/pybind_asr_impl.py +98 -15
- nexaai/binds/__init__.py +2 -0
- nexaai/binds/asr_bind.cp310-win_amd64.pyd +0 -0
- nexaai/binds/common_bind.cp310-win_amd64.pyd +0 -0
- nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-cuda.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-vulkan.dll +0 -0
- nexaai/binds/cpu_gpu/ggml.dll +0 -0
- nexaai/binds/cpu_gpu/mtmd.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
- nexaai/binds/embedder_bind.cp310-win_amd64.pyd +0 -0
- nexaai/binds/llm_bind.cp310-win_amd64.pyd +0 -0
- nexaai/binds/nexa_bridge.dll +0 -0
- nexaai/binds/nexaml/ggml-base.dll +0 -0
- nexaai/binds/nexaml/ggml-cpu.dll +0 -0
- nexaai/binds/nexaml/ggml-cuda.dll +0 -0
- nexaai/binds/nexaml/ggml-vulkan.dll +0 -0
- nexaai/binds/nexaml/ggml.dll +0 -0
- nexaai/binds/nexaml/nexa_plugin.dll +0 -0
- nexaai/binds/nexaml/nexaproc.dll +0 -0
- nexaai/binds/nexaml/qwen3-vl.dll +0 -0
- nexaai/binds/rerank_bind.cp310-win_amd64.pyd +0 -0
- nexaai/binds/vlm_bind.cp310-win_amd64.pyd +0 -0
- nexaai/common.py +1 -0
- nexaai/cv.py +2 -1
- nexaai/embedder.py +4 -3
- nexaai/embedder_impl/mlx_embedder_impl.py +3 -1
- nexaai/embedder_impl/pybind_embedder_impl.py +3 -2
- nexaai/image_gen.py +2 -1
- nexaai/llm.py +5 -3
- nexaai/llm_impl/mlx_llm_impl.py +2 -0
- nexaai/llm_impl/pybind_llm_impl.py +2 -0
- nexaai/rerank.py +5 -3
- nexaai/rerank_impl/mlx_rerank_impl.py +2 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +109 -16
- nexaai/runtime_error.py +24 -0
- nexaai/tts.py +2 -1
- nexaai/utils/manifest_utils.py +10 -6
- nexaai/utils/model_manager.py +139 -8
- nexaai/vlm.py +4 -2
- nexaai/vlm_impl/mlx_vlm_impl.py +3 -2
- nexaai/vlm_impl/pybind_vlm_impl.py +33 -7
- {nexaai-1.0.19rc19.dist-info → nexaai-1.0.21.dist-info}/METADATA +2 -3
- nexaai-1.0.21.dist-info/RECORD +79 -0
- nexaai-1.0.19rc19.dist-info/RECORD +0 -76
- {nexaai-1.0.19rc19.dist-info → nexaai-1.0.21.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc19.dist-info → nexaai-1.0.21.dist-info}/top_level.txt +0 -0
nexaai/__init__.py
CHANGED
|
@@ -24,6 +24,13 @@ from .common import ModelConfig, GenerationConfig, ChatMessage, SamplerConfig, P
|
|
|
24
24
|
# Import logging functionality
|
|
25
25
|
from .log import set_logger, get_error_message
|
|
26
26
|
|
|
27
|
+
# Import runtime errors
|
|
28
|
+
from .runtime_error import (
|
|
29
|
+
NexaRuntimeError,
|
|
30
|
+
ContextLengthExceededError,
|
|
31
|
+
GenerationError
|
|
32
|
+
)
|
|
33
|
+
|
|
27
34
|
# Create alias for PluginID to be accessible as plugin_id
|
|
28
35
|
plugin_id = PluginID
|
|
29
36
|
|
|
@@ -52,6 +59,11 @@ __all__ = [
|
|
|
52
59
|
# Logging functionality
|
|
53
60
|
"set_logger",
|
|
54
61
|
"get_error_message",
|
|
62
|
+
|
|
63
|
+
# Runtime errors
|
|
64
|
+
"NexaRuntimeError",
|
|
65
|
+
"ContextLengthExceededError",
|
|
66
|
+
"GenerationError",
|
|
55
67
|
|
|
56
68
|
"LLM",
|
|
57
69
|
"Embedder",
|
nexaai/_stub.cp310-win_amd64.pyd
CHANGED
|
Binary file
|
nexaai/_version.py
CHANGED
nexaai/asr.py
CHANGED
|
@@ -3,7 +3,7 @@ from abc import abstractmethod
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
|
|
5
5
|
from nexaai.base import BaseModel
|
|
6
|
-
from nexaai.common import PluginID
|
|
6
|
+
from nexaai.common import PluginID, ModelConfig
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
@dataclass
|
|
@@ -25,17 +25,20 @@ class ASRResult:
|
|
|
25
25
|
class ASR(BaseModel):
|
|
26
26
|
"""Abstract base class for Automatic Speech Recognition models."""
|
|
27
27
|
|
|
28
|
-
def __init__(self):
|
|
28
|
+
def __init__(self, m_cfg: ModelConfig = ModelConfig()):
|
|
29
29
|
"""Initialize base ASR class."""
|
|
30
|
-
|
|
30
|
+
self._m_cfg = m_cfg
|
|
31
31
|
|
|
32
32
|
@classmethod
|
|
33
33
|
def _load_from(cls,
|
|
34
34
|
model_path: str,
|
|
35
|
+
model_name: Optional[str] = None,
|
|
35
36
|
tokenizer_path: Optional[str] = None,
|
|
36
37
|
language: Optional[str] = None,
|
|
38
|
+
m_cfg: ModelConfig = ModelConfig(),
|
|
37
39
|
plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
|
|
38
|
-
device_id: Optional[str] = None
|
|
40
|
+
device_id: Optional[str] = None,
|
|
41
|
+
**kwargs
|
|
39
42
|
) -> 'ASR':
|
|
40
43
|
"""Load ASR model from local path, routing to appropriate implementation."""
|
|
41
44
|
# Check plugin_id value for routing - handle both enum and string
|
|
@@ -43,10 +46,11 @@ class ASR(BaseModel):
|
|
|
43
46
|
|
|
44
47
|
if plugin_value == "mlx":
|
|
45
48
|
from nexaai.asr_impl.mlx_asr_impl import MLXASRImpl
|
|
46
|
-
return MLXASRImpl._load_from(model_path, tokenizer_path, language, plugin_id, device_id)
|
|
49
|
+
return MLXASRImpl._load_from(model_path, model_name, tokenizer_path, language, m_cfg, plugin_id, device_id)
|
|
47
50
|
else:
|
|
48
51
|
from nexaai.asr_impl.pybind_asr_impl import PyBindASRImpl
|
|
49
|
-
return PyBindASRImpl._load_from(model_path, tokenizer_path, language, plugin_id, device_id)
|
|
52
|
+
return PyBindASRImpl._load_from(model_path, model_name, tokenizer_path, language, m_cfg, plugin_id, device_id)
|
|
53
|
+
|
|
50
54
|
|
|
51
55
|
@abstractmethod
|
|
52
56
|
def transcribe(
|
|
@@ -1,32 +1,78 @@
|
|
|
1
1
|
from typing import List, Optional, Union
|
|
2
2
|
|
|
3
|
-
from nexaai.common import PluginID
|
|
3
|
+
from nexaai.common import PluginID, ModelConfig
|
|
4
4
|
from nexaai.asr import ASR, ASRConfig, ASRResult
|
|
5
|
+
from nexaai.binds import asr_bind, common_bind
|
|
6
|
+
from nexaai.runtime import _ensure_runtime
|
|
5
7
|
|
|
6
8
|
|
|
7
9
|
class PyBindASRImpl(ASR):
|
|
8
|
-
def __init__(self):
|
|
9
|
-
"""
|
|
10
|
-
super().__init__()
|
|
11
|
-
#
|
|
10
|
+
def __init__(self, handle: any, m_cfg: ModelConfig = ModelConfig()):
|
|
11
|
+
"""Private constructor, should not be called directly."""
|
|
12
|
+
super().__init__(m_cfg)
|
|
13
|
+
self._handle = handle # This is a py::capsule
|
|
14
|
+
self._model_config = None
|
|
12
15
|
|
|
13
16
|
@classmethod
|
|
14
17
|
def _load_from(cls,
|
|
15
18
|
model_path: str,
|
|
19
|
+
model_name: Optional[str] = None,
|
|
16
20
|
tokenizer_path: Optional[str] = None,
|
|
17
21
|
language: Optional[str] = None,
|
|
22
|
+
m_cfg: ModelConfig = ModelConfig(),
|
|
18
23
|
plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
|
|
19
24
|
device_id: Optional[str] = None
|
|
20
25
|
) -> 'PyBindASRImpl':
|
|
21
26
|
"""Load ASR model from local path using PyBind backend."""
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
27
|
+
_ensure_runtime()
|
|
28
|
+
|
|
29
|
+
# Create model config
|
|
30
|
+
config = common_bind.ModelConfig()
|
|
31
|
+
|
|
32
|
+
config.n_ctx = m_cfg.n_ctx
|
|
33
|
+
if m_cfg.n_threads is not None:
|
|
34
|
+
config.n_threads = m_cfg.n_threads
|
|
35
|
+
if m_cfg.n_threads_batch is not None:
|
|
36
|
+
config.n_threads_batch = m_cfg.n_threads_batch
|
|
37
|
+
if m_cfg.n_batch is not None:
|
|
38
|
+
config.n_batch = m_cfg.n_batch
|
|
39
|
+
if m_cfg.n_ubatch is not None:
|
|
40
|
+
config.n_ubatch = m_cfg.n_ubatch
|
|
41
|
+
if m_cfg.n_seq_max is not None:
|
|
42
|
+
config.n_seq_max = m_cfg.n_seq_max
|
|
43
|
+
config.n_gpu_layers = m_cfg.n_gpu_layers
|
|
44
|
+
|
|
45
|
+
# handle chat template strings
|
|
46
|
+
if m_cfg.chat_template_path:
|
|
47
|
+
config.chat_template_path = m_cfg.chat_template_path
|
|
48
|
+
|
|
49
|
+
if m_cfg.chat_template_content:
|
|
50
|
+
config.chat_template_content = m_cfg.chat_template_content
|
|
51
|
+
|
|
52
|
+
# Convert plugin_id to string
|
|
53
|
+
plugin_id_str = plugin_id.value if isinstance(plugin_id, PluginID) else str(plugin_id)
|
|
54
|
+
|
|
55
|
+
# Create ASR handle using the binding
|
|
56
|
+
handle = asr_bind.ml_asr_create(
|
|
57
|
+
model_path=model_path,
|
|
58
|
+
model_name=model_name,
|
|
59
|
+
tokenizer_path=tokenizer_path,
|
|
60
|
+
model_config=config,
|
|
61
|
+
language=language,
|
|
62
|
+
plugin_id=plugin_id_str,
|
|
63
|
+
device_id=device_id,
|
|
64
|
+
license_id=None, # Optional
|
|
65
|
+
license_key=None # Optional
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
return cls(handle, m_cfg)
|
|
25
69
|
|
|
26
70
|
def eject(self):
|
|
27
|
-
"""
|
|
28
|
-
#
|
|
29
|
-
|
|
71
|
+
"""Release the model from memory."""
|
|
72
|
+
# py::capsule handles cleanup automatically
|
|
73
|
+
if hasattr(self, '_handle') and self._handle is not None:
|
|
74
|
+
del self._handle
|
|
75
|
+
self._handle = None
|
|
30
76
|
|
|
31
77
|
def transcribe(
|
|
32
78
|
self,
|
|
@@ -35,10 +81,47 @@ class PyBindASRImpl(ASR):
|
|
|
35
81
|
config: Optional[ASRConfig] = None,
|
|
36
82
|
) -> ASRResult:
|
|
37
83
|
"""Transcribe audio file to text."""
|
|
38
|
-
|
|
39
|
-
|
|
84
|
+
if self._handle is None:
|
|
85
|
+
raise RuntimeError("ASR model not loaded. Call _load_from first.")
|
|
86
|
+
|
|
87
|
+
# Convert ASRConfig to binding format if provided
|
|
88
|
+
asr_config = None
|
|
89
|
+
if config:
|
|
90
|
+
asr_config = asr_bind.ASRConfig()
|
|
91
|
+
asr_config.timestamps = config.timestamps
|
|
92
|
+
asr_config.beam_size = config.beam_size
|
|
93
|
+
asr_config.stream = config.stream
|
|
94
|
+
|
|
95
|
+
# Perform transcription using the binding
|
|
96
|
+
result_dict = asr_bind.ml_asr_transcribe(
|
|
97
|
+
handle=self._handle,
|
|
98
|
+
audio_path=audio_path,
|
|
99
|
+
language=language,
|
|
100
|
+
config=asr_config
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
# Convert result to ASRResult
|
|
104
|
+
transcript = result_dict.get("transcript", "")
|
|
105
|
+
confidence_scores = result_dict.get("confidence_scores")
|
|
106
|
+
timestamps = result_dict.get("timestamps")
|
|
107
|
+
|
|
108
|
+
# Convert timestamps to the expected format
|
|
109
|
+
timestamp_pairs = []
|
|
110
|
+
if timestamps:
|
|
111
|
+
for start, end in timestamps:
|
|
112
|
+
timestamp_pairs.append((float(start), float(end)))
|
|
113
|
+
|
|
114
|
+
return ASRResult(
|
|
115
|
+
transcript=transcript,
|
|
116
|
+
confidence_scores=confidence_scores or [],
|
|
117
|
+
timestamps=timestamp_pairs
|
|
118
|
+
)
|
|
40
119
|
|
|
41
120
|
def list_supported_languages(self) -> List[str]:
|
|
42
121
|
"""List supported languages."""
|
|
43
|
-
|
|
44
|
-
|
|
122
|
+
if self._handle is None:
|
|
123
|
+
raise RuntimeError("ASR model not loaded. Call _load_from first.")
|
|
124
|
+
|
|
125
|
+
# Get supported languages using the binding
|
|
126
|
+
languages = asr_bind.ml_asr_list_supported_languages(handle=self._handle)
|
|
127
|
+
return languages
|
nexaai/binds/__init__.py
CHANGED
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
nexaai/binds/cpu_gpu/ggml.dll
CHANGED
|
Binary file
|
nexaai/binds/cpu_gpu/mtmd.dll
CHANGED
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
nexaai/binds/nexa_bridge.dll
CHANGED
|
Binary file
|
|
Binary file
|
nexaai/binds/nexaml/ggml-cpu.dll
CHANGED
|
Binary file
|
|
Binary file
|
|
Binary file
|
nexaai/binds/nexaml/ggml.dll
CHANGED
|
Binary file
|
|
Binary file
|
nexaai/binds/nexaml/nexaproc.dll
CHANGED
|
Binary file
|
nexaai/binds/nexaml/qwen3-vl.dll
CHANGED
|
Binary file
|
|
Binary file
|
|
Binary file
|
nexaai/common.py
CHANGED
nexaai/cv.py
CHANGED
|
@@ -73,7 +73,8 @@ class CVModel(BaseModel):
|
|
|
73
73
|
_: str, # TODO: remove this argument, this is a hack to make api design happy
|
|
74
74
|
config: CVModelConfig,
|
|
75
75
|
plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
|
|
76
|
-
device_id: Optional[str] = None
|
|
76
|
+
device_id: Optional[str] = None,
|
|
77
|
+
**kwargs
|
|
77
78
|
) -> 'CVModel':
|
|
78
79
|
"""Load CV model from configuration, routing to appropriate implementation."""
|
|
79
80
|
# Check plugin_id value for routing - handle both enum and string
|
nexaai/embedder.py
CHANGED
|
@@ -22,12 +22,13 @@ class Embedder(BaseModel):
|
|
|
22
22
|
pass
|
|
23
23
|
|
|
24
24
|
@classmethod
|
|
25
|
-
def _load_from(cls, model_path: str, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP):
|
|
25
|
+
def _load_from(cls, model_path: str, model_name: str = None, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP, **kwargs):
|
|
26
26
|
"""
|
|
27
27
|
Load an embedder from model files, routing to appropriate implementation.
|
|
28
28
|
|
|
29
29
|
Args:
|
|
30
30
|
model_path: Path to the model file
|
|
31
|
+
model_name: Name of the model
|
|
31
32
|
tokenizer_file: Path to the tokenizer file (default: "tokenizer.json")
|
|
32
33
|
plugin_id: Plugin ID to use for the model (default: PluginID.LLAMA_CPP)
|
|
33
34
|
|
|
@@ -39,10 +40,10 @@ class Embedder(BaseModel):
|
|
|
39
40
|
|
|
40
41
|
if plugin_value == "mlx":
|
|
41
42
|
from nexaai.embedder_impl.mlx_embedder_impl import MLXEmbedderImpl
|
|
42
|
-
return MLXEmbedderImpl._load_from(model_path, tokenizer_file, plugin_id)
|
|
43
|
+
return MLXEmbedderImpl._load_from(model_path, model_name, tokenizer_file, plugin_id)
|
|
43
44
|
else:
|
|
44
45
|
from nexaai.embedder_impl.pybind_embedder_impl import PyBindEmbedderImpl
|
|
45
|
-
return PyBindEmbedderImpl._load_from(model_path, tokenizer_file, plugin_id)
|
|
46
|
+
return PyBindEmbedderImpl._load_from(model_path, model_name, tokenizer_file, plugin_id)
|
|
46
47
|
|
|
47
48
|
@abstractmethod
|
|
48
49
|
def generate(self, texts: Union[List[str], str] = None, config: EmbeddingConfig = EmbeddingConfig(), input_ids: Union[List[int], List[List[int]]] = None) -> np.ndarray:
|
|
@@ -14,12 +14,13 @@ class MLXEmbedderImpl(Embedder):
|
|
|
14
14
|
self._mlx_embedder = None
|
|
15
15
|
|
|
16
16
|
@classmethod
|
|
17
|
-
def _load_from(cls, model_path: str, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.MLX):
|
|
17
|
+
def _load_from(cls, model_path: str, model_name: str = None, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.MLX):
|
|
18
18
|
"""
|
|
19
19
|
Load an embedder from model files using MLX backend.
|
|
20
20
|
|
|
21
21
|
Args:
|
|
22
22
|
model_path: Path to the model file
|
|
23
|
+
model_name: Name of the model
|
|
23
24
|
tokenizer_file: Path to the tokenizer file (default: "tokenizer.json")
|
|
24
25
|
plugin_id: Plugin ID to use for the model (default: PluginID.MLX)
|
|
25
26
|
|
|
@@ -34,6 +35,7 @@ class MLXEmbedderImpl(Embedder):
|
|
|
34
35
|
# This will automatically detect if it's JinaV2 or generic model and route correctly
|
|
35
36
|
instance._mlx_embedder = create_embedder(
|
|
36
37
|
model_path=model_path,
|
|
38
|
+
# model_name=model_name, # FIXME: For MLX Embedder, model_name is not used
|
|
37
39
|
tokenizer_path=tokenizer_file
|
|
38
40
|
)
|
|
39
41
|
|
|
@@ -16,12 +16,13 @@ class PyBindEmbedderImpl(Embedder):
|
|
|
16
16
|
self._handle = _handle_ptr
|
|
17
17
|
|
|
18
18
|
@classmethod
|
|
19
|
-
def _load_from(cls, model_path: str, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP):
|
|
19
|
+
def _load_from(cls, model_path: str, model_name: str = None, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP):
|
|
20
20
|
"""
|
|
21
21
|
Load an embedder from model files
|
|
22
22
|
|
|
23
23
|
Args:
|
|
24
24
|
model_path: Path to the model file
|
|
25
|
+
model_name: Name of the model
|
|
25
26
|
tokenizer_file: Path to the tokenizer file (default: "tokenizer.json")
|
|
26
27
|
plugin_id: Plugin ID to use for the model (default: PluginID.LLAMA_CPP)
|
|
27
28
|
|
|
@@ -32,7 +33,7 @@ class PyBindEmbedderImpl(Embedder):
|
|
|
32
33
|
# Convert enum to string for C++ binding
|
|
33
34
|
plugin_id_str = plugin_id.value if isinstance(plugin_id, PluginID) else plugin_id
|
|
34
35
|
# New parameter order: model_path, plugin_id, tokenizer_path (optional)
|
|
35
|
-
handle = embedder_bind.ml_embedder_create(model_path, plugin_id_str, tokenizer_file)
|
|
36
|
+
handle = embedder_bind.ml_embedder_create(model_path, model_name, plugin_id_str, tokenizer_file)
|
|
36
37
|
return cls(handle)
|
|
37
38
|
|
|
38
39
|
def eject(self):
|
nexaai/image_gen.py
CHANGED
|
@@ -71,7 +71,8 @@ class ImageGen(BaseModel):
|
|
|
71
71
|
plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
|
|
72
72
|
device_id: Optional[str] = None,
|
|
73
73
|
float16: bool = True,
|
|
74
|
-
quantize: bool = False
|
|
74
|
+
quantize: bool = False,
|
|
75
|
+
**kwargs
|
|
75
76
|
) -> 'ImageGen':
|
|
76
77
|
"""Load image generation model from local path, routing to appropriate implementation."""
|
|
77
78
|
# Check plugin_id value for routing - handle both enum and string
|
nexaai/llm.py
CHANGED
|
@@ -15,10 +15,12 @@ class LLM(BaseModel):
|
|
|
15
15
|
@classmethod
|
|
16
16
|
def _load_from(cls,
|
|
17
17
|
local_path: str,
|
|
18
|
+
model_name: Optional[str] = None,
|
|
18
19
|
tokenizer_path: Optional[str] = None,
|
|
19
20
|
m_cfg: ModelConfig = ModelConfig(),
|
|
20
21
|
plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
|
|
21
|
-
device_id: Optional[str] = None
|
|
22
|
+
device_id: Optional[str] = None,
|
|
23
|
+
**kwargs
|
|
22
24
|
) -> 'LLM':
|
|
23
25
|
"""Load model from local path, routing to appropriate implementation."""
|
|
24
26
|
# Check plugin_id value for routing - handle both enum and string
|
|
@@ -26,10 +28,10 @@ class LLM(BaseModel):
|
|
|
26
28
|
|
|
27
29
|
if plugin_value == "mlx":
|
|
28
30
|
from nexaai.llm_impl.mlx_llm_impl import MLXLLMImpl
|
|
29
|
-
return MLXLLMImpl._load_from(local_path, tokenizer_path, m_cfg, plugin_id, device_id)
|
|
31
|
+
return MLXLLMImpl._load_from(local_path, model_name, tokenizer_path, m_cfg, plugin_id, device_id)
|
|
30
32
|
else:
|
|
31
33
|
from nexaai.llm_impl.pybind_llm_impl import PyBindLLMImpl
|
|
32
|
-
return PyBindLLMImpl._load_from(local_path, tokenizer_path, m_cfg, plugin_id, device_id)
|
|
34
|
+
return PyBindLLMImpl._load_from(local_path, model_name, tokenizer_path, m_cfg, plugin_id, device_id)
|
|
33
35
|
|
|
34
36
|
def cancel_generation(self):
|
|
35
37
|
"""Signal to cancel any ongoing stream generation."""
|
nexaai/llm_impl/mlx_llm_impl.py
CHANGED
|
@@ -16,6 +16,7 @@ class MLXLLMImpl(LLM):
|
|
|
16
16
|
@classmethod
|
|
17
17
|
def _load_from(cls,
|
|
18
18
|
local_path: str,
|
|
19
|
+
model_name: Optional[str] = None,
|
|
19
20
|
tokenizer_path: Optional[str] = None,
|
|
20
21
|
m_cfg: ModelConfig = ModelConfig(),
|
|
21
22
|
plugin_id: Union[PluginID, str] = PluginID.MLX,
|
|
@@ -40,6 +41,7 @@ class MLXLLMImpl(LLM):
|
|
|
40
41
|
instance = cls(m_cfg)
|
|
41
42
|
instance._mlx_llm = MLXLLMInterface(
|
|
42
43
|
model_path=local_path,
|
|
44
|
+
# model_name=model_name, # FIXME: For MLX LLM, model_name is not used
|
|
43
45
|
tokenizer_path=tokenizer_path or local_path,
|
|
44
46
|
config=mlx_config,
|
|
45
47
|
device=device_id
|
|
@@ -19,6 +19,7 @@ class PyBindLLMImpl(LLM):
|
|
|
19
19
|
@classmethod
|
|
20
20
|
def _load_from(cls,
|
|
21
21
|
local_path: str,
|
|
22
|
+
model_name: Optional[str] = None,
|
|
22
23
|
tokenizer_path: Optional[str] = None,
|
|
23
24
|
m_cfg: ModelConfig = ModelConfig(),
|
|
24
25
|
plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
|
|
@@ -55,6 +56,7 @@ class PyBindLLMImpl(LLM):
|
|
|
55
56
|
plugin_id_str = plugin_id.value if isinstance(plugin_id, PluginID) else plugin_id
|
|
56
57
|
handle = llm_bind.ml_llm_create(
|
|
57
58
|
model_path=local_path,
|
|
59
|
+
model_name=model_name,
|
|
58
60
|
tokenizer_path=tokenizer_path,
|
|
59
61
|
model_config=config,
|
|
60
62
|
plugin_id=plugin_id_str,
|
nexaai/rerank.py
CHANGED
|
@@ -24,9 +24,11 @@ class Reranker(BaseModel):
|
|
|
24
24
|
@classmethod
|
|
25
25
|
def _load_from(cls,
|
|
26
26
|
model_path: str,
|
|
27
|
+
model_name: str = None,
|
|
27
28
|
tokenizer_file: str = "tokenizer.json",
|
|
28
29
|
plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
|
|
29
|
-
device_id: Optional[str] = None
|
|
30
|
+
device_id: Optional[str] = None,
|
|
31
|
+
**kwargs
|
|
30
32
|
) -> 'Reranker':
|
|
31
33
|
"""Load reranker model from local path, routing to appropriate implementation."""
|
|
32
34
|
# Check plugin_id value for routing - handle both enum and string
|
|
@@ -34,10 +36,10 @@ class Reranker(BaseModel):
|
|
|
34
36
|
|
|
35
37
|
if plugin_value == "mlx":
|
|
36
38
|
from nexaai.rerank_impl.mlx_rerank_impl import MLXRerankImpl
|
|
37
|
-
return MLXRerankImpl._load_from(model_path, tokenizer_file, plugin_id, device_id)
|
|
39
|
+
return MLXRerankImpl._load_from(model_path, model_name, tokenizer_file, plugin_id, device_id)
|
|
38
40
|
else:
|
|
39
41
|
from nexaai.rerank_impl.pybind_rerank_impl import PyBindRerankImpl
|
|
40
|
-
return PyBindRerankImpl._load_from(model_path, tokenizer_file, plugin_id, device_id)
|
|
42
|
+
return PyBindRerankImpl._load_from(model_path, model_name, tokenizer_file, plugin_id, device_id)
|
|
41
43
|
|
|
42
44
|
@abstractmethod
|
|
43
45
|
def load_model(self, model_path: str, extra_data: Optional[str] = None) -> bool:
|
|
@@ -17,6 +17,7 @@ class MLXRerankImpl(Reranker):
|
|
|
17
17
|
@classmethod
|
|
18
18
|
def _load_from(cls,
|
|
19
19
|
model_path: str,
|
|
20
|
+
model_name: str = None,
|
|
20
21
|
tokenizer_file: str = "tokenizer.json",
|
|
21
22
|
plugin_id: Union[PluginID, str] = PluginID.MLX,
|
|
22
23
|
device_id: Optional[str] = None
|
|
@@ -29,6 +30,7 @@ class MLXRerankImpl(Reranker):
|
|
|
29
30
|
instance = cls()
|
|
30
31
|
instance._mlx_reranker = create_reranker(
|
|
31
32
|
model_path=model_path,
|
|
33
|
+
# model_name=model_name, # FIXME: For MLX Reranker, model_name is not used
|
|
32
34
|
tokenizer_path=tokenizer_file,
|
|
33
35
|
device=device_id
|
|
34
36
|
)
|
|
@@ -1,36 +1,89 @@
|
|
|
1
1
|
from typing import List, Optional, Sequence, Union
|
|
2
|
+
import numpy as np
|
|
2
3
|
|
|
3
4
|
from nexaai.common import PluginID
|
|
4
5
|
from nexaai.rerank import Reranker, RerankConfig
|
|
6
|
+
from nexaai.binds import rerank_bind, common_bind
|
|
7
|
+
from nexaai.runtime import _ensure_runtime
|
|
5
8
|
|
|
6
9
|
|
|
7
10
|
class PyBindRerankImpl(Reranker):
|
|
8
|
-
def __init__(self):
|
|
9
|
-
"""
|
|
11
|
+
def __init__(self, _handle_ptr):
|
|
12
|
+
"""
|
|
13
|
+
Internal initializer
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
_handle_ptr: Capsule handle to the C++ reranker object
|
|
17
|
+
"""
|
|
10
18
|
super().__init__()
|
|
11
|
-
|
|
19
|
+
self._handle = _handle_ptr
|
|
12
20
|
|
|
13
21
|
@classmethod
|
|
14
22
|
def _load_from(cls,
|
|
15
23
|
model_path: str,
|
|
24
|
+
model_name: str = None,
|
|
16
25
|
tokenizer_file: str = "tokenizer.json",
|
|
17
26
|
plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
|
|
18
27
|
device_id: Optional[str] = None
|
|
19
28
|
) -> 'PyBindRerankImpl':
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
29
|
+
"""
|
|
30
|
+
Load reranker model from local path using PyBind backend.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
model_path: Path to the model file
|
|
34
|
+
model_name: Name of the model (optional)
|
|
35
|
+
tokenizer_file: Path to the tokenizer file (default: "tokenizer.json")
|
|
36
|
+
plugin_id: Plugin ID to use for the model (default: PluginID.LLAMA_CPP)
|
|
37
|
+
device_id: Device ID to use for the model (optional)
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
PyBindRerankImpl instance
|
|
41
|
+
"""
|
|
42
|
+
_ensure_runtime()
|
|
43
|
+
|
|
44
|
+
# Convert enum to string for C++ binding
|
|
45
|
+
plugin_id_str = plugin_id.value if isinstance(plugin_id, PluginID) else plugin_id
|
|
46
|
+
|
|
47
|
+
# Create model config
|
|
48
|
+
model_config = common_bind.ModelConfig()
|
|
49
|
+
|
|
50
|
+
# Create reranker handle with new API signature
|
|
51
|
+
handle = rerank_bind.ml_reranker_create(
|
|
52
|
+
model_path,
|
|
53
|
+
model_name,
|
|
54
|
+
tokenizer_file,
|
|
55
|
+
model_config,
|
|
56
|
+
plugin_id_str,
|
|
57
|
+
device_id
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
return cls(handle)
|
|
24
61
|
|
|
25
62
|
def eject(self):
|
|
26
|
-
"""
|
|
27
|
-
|
|
28
|
-
|
|
63
|
+
"""
|
|
64
|
+
Clean up resources and destroy the reranker
|
|
65
|
+
"""
|
|
66
|
+
# Destructor of the handle will unload the model correctly
|
|
67
|
+
if hasattr(self, '_handle') and self._handle is not None:
|
|
68
|
+
del self._handle
|
|
69
|
+
self._handle = None
|
|
29
70
|
|
|
30
71
|
def load_model(self, model_path: str, extra_data: Optional[str] = None) -> bool:
|
|
31
|
-
"""
|
|
32
|
-
|
|
33
|
-
|
|
72
|
+
"""
|
|
73
|
+
Load model from path.
|
|
74
|
+
|
|
75
|
+
Note: This method is not typically used directly. Use _load_from instead.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
model_path: Path to the model file
|
|
79
|
+
extra_data: Additional data (unused)
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
True if successful
|
|
83
|
+
"""
|
|
84
|
+
# This method is part of the BaseModel interface but typically not used
|
|
85
|
+
# directly for PyBind implementations since _load_from handles creation
|
|
86
|
+
raise NotImplementedError("Use _load_from class method to load models")
|
|
34
87
|
|
|
35
88
|
def rerank(
|
|
36
89
|
self,
|
|
@@ -38,6 +91,46 @@ class PyBindRerankImpl(Reranker):
|
|
|
38
91
|
documents: Sequence[str],
|
|
39
92
|
config: Optional[RerankConfig] = None,
|
|
40
93
|
) -> List[float]:
|
|
41
|
-
"""
|
|
42
|
-
|
|
43
|
-
|
|
94
|
+
"""
|
|
95
|
+
Rerank documents given a query.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
query: Query text as UTF-8 string
|
|
99
|
+
documents: List of document texts to rerank
|
|
100
|
+
config: Optional reranking configuration
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
List of ranking scores (one per document)
|
|
104
|
+
"""
|
|
105
|
+
if self._handle is None:
|
|
106
|
+
raise RuntimeError("Reranker handle is None. Model may have been ejected.")
|
|
107
|
+
|
|
108
|
+
# Use default config if not provided
|
|
109
|
+
if config is None:
|
|
110
|
+
config = RerankConfig()
|
|
111
|
+
|
|
112
|
+
# Create bind config
|
|
113
|
+
bind_config = rerank_bind.RerankConfig()
|
|
114
|
+
bind_config.batch_size = config.batch_size
|
|
115
|
+
bind_config.normalize = config.normalize
|
|
116
|
+
bind_config.normalize_method = config.normalize_method
|
|
117
|
+
|
|
118
|
+
# Convert documents to list if needed
|
|
119
|
+
documents_list = list(documents)
|
|
120
|
+
|
|
121
|
+
# Call the binding which returns a dict with scores and profile_data
|
|
122
|
+
result = rerank_bind.ml_reranker_rerank(
|
|
123
|
+
self._handle,
|
|
124
|
+
query,
|
|
125
|
+
documents_list,
|
|
126
|
+
bind_config
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
# Extract scores from result dict
|
|
130
|
+
scores_array = result.get("scores", np.array([]))
|
|
131
|
+
|
|
132
|
+
# Convert numpy array to list of floats
|
|
133
|
+
if isinstance(scores_array, np.ndarray):
|
|
134
|
+
return scores_array.tolist()
|
|
135
|
+
else:
|
|
136
|
+
return []
|