nexaai 1.0.20__cp310-cp310-win_amd64.whl → 1.0.21__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. Click here for more details.

Files changed (53)
  1. nexaai/__init__.py +12 -0
  2. nexaai/_stub.cp310-win_amd64.pyd +0 -0
  3. nexaai/_version.py +1 -1
  4. nexaai/asr.py +10 -6
  5. nexaai/asr_impl/pybind_asr_impl.py +98 -15
  6. nexaai/binds/__init__.py +2 -0
  7. nexaai/binds/asr_bind.cp310-win_amd64.pyd +0 -0
  8. nexaai/binds/common_bind.cp310-win_amd64.pyd +0 -0
  9. nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
  10. nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
  11. nexaai/binds/cpu_gpu/ggml-cuda.dll +0 -0
  12. nexaai/binds/cpu_gpu/ggml-vulkan.dll +0 -0
  13. nexaai/binds/cpu_gpu/ggml.dll +0 -0
  14. nexaai/binds/cpu_gpu/mtmd.dll +0 -0
  15. nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
  16. nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
  17. nexaai/binds/embedder_bind.cp310-win_amd64.pyd +0 -0
  18. nexaai/binds/llm_bind.cp310-win_amd64.pyd +0 -0
  19. nexaai/binds/nexa_bridge.dll +0 -0
  20. nexaai/binds/nexaml/ggml-base.dll +0 -0
  21. nexaai/binds/nexaml/ggml-cpu.dll +0 -0
  22. nexaai/binds/nexaml/ggml-cuda.dll +0 -0
  23. nexaai/binds/nexaml/ggml-vulkan.dll +0 -0
  24. nexaai/binds/nexaml/ggml.dll +0 -0
  25. nexaai/binds/nexaml/nexa_plugin.dll +0 -0
  26. nexaai/binds/nexaml/nexaproc.dll +0 -0
  27. nexaai/binds/nexaml/qwen3-vl.dll +0 -0
  28. nexaai/binds/rerank_bind.cp310-win_amd64.pyd +0 -0
  29. nexaai/binds/vlm_bind.cp310-win_amd64.pyd +0 -0
  30. nexaai/common.py +1 -0
  31. nexaai/cv.py +2 -1
  32. nexaai/embedder.py +4 -3
  33. nexaai/embedder_impl/mlx_embedder_impl.py +3 -1
  34. nexaai/embedder_impl/pybind_embedder_impl.py +3 -2
  35. nexaai/image_gen.py +2 -1
  36. nexaai/llm.py +5 -3
  37. nexaai/llm_impl/mlx_llm_impl.py +2 -0
  38. nexaai/llm_impl/pybind_llm_impl.py +2 -0
  39. nexaai/rerank.py +5 -3
  40. nexaai/rerank_impl/mlx_rerank_impl.py +2 -0
  41. nexaai/rerank_impl/pybind_rerank_impl.py +109 -16
  42. nexaai/runtime_error.py +24 -0
  43. nexaai/tts.py +2 -1
  44. nexaai/utils/manifest_utils.py +10 -6
  45. nexaai/utils/model_manager.py +139 -8
  46. nexaai/vlm.py +4 -2
  47. nexaai/vlm_impl/mlx_vlm_impl.py +3 -2
  48. nexaai/vlm_impl/pybind_vlm_impl.py +33 -7
  49. {nexaai-1.0.20.dist-info → nexaai-1.0.21.dist-info}/METADATA +1 -2
  50. nexaai-1.0.21.dist-info/RECORD +79 -0
  51. nexaai-1.0.20.dist-info/RECORD +0 -76
  52. {nexaai-1.0.20.dist-info → nexaai-1.0.21.dist-info}/WHEEL +0 -0
  53. {nexaai-1.0.20.dist-info → nexaai-1.0.21.dist-info}/top_level.txt +0 -0
nexaai/__init__.py CHANGED
@@ -24,6 +24,13 @@ from .common import ModelConfig, GenerationConfig, ChatMessage, SamplerConfig, P
24
24
  # Import logging functionality
25
25
  from .log import set_logger, get_error_message
26
26
 
27
+ # Import runtime errors
28
+ from .runtime_error import (
29
+ NexaRuntimeError,
30
+ ContextLengthExceededError,
31
+ GenerationError
32
+ )
33
+
27
34
  # Create alias for PluginID to be accessible as plugin_id
28
35
  plugin_id = PluginID
29
36
 
@@ -52,6 +59,11 @@ __all__ = [
52
59
  # Logging functionality
53
60
  "set_logger",
54
61
  "get_error_message",
62
+
63
+ # Runtime errors
64
+ "NexaRuntimeError",
65
+ "ContextLengthExceededError",
66
+ "GenerationError",
55
67
 
56
68
  "LLM",
57
69
  "Embedder",
Binary file
nexaai/_version.py CHANGED
@@ -1,4 +1,4 @@
1
1
  # This file is generated by CMake from _version.py.in
2
2
  # Do not modify this file manually - it will be overwritten
3
3
 
4
- __version__ = "1.0.20"
4
+ __version__ = "1.0.21"
nexaai/asr.py CHANGED
@@ -3,7 +3,7 @@ from abc import abstractmethod
3
3
  from dataclasses import dataclass
4
4
 
5
5
  from nexaai.base import BaseModel
6
- from nexaai.common import PluginID
6
+ from nexaai.common import PluginID, ModelConfig
7
7
 
8
8
 
9
9
  @dataclass
@@ -25,17 +25,20 @@ class ASRResult:
25
25
  class ASR(BaseModel):
26
26
  """Abstract base class for Automatic Speech Recognition models."""
27
27
 
28
- def __init__(self):
28
+ def __init__(self, m_cfg: ModelConfig = ModelConfig()):
29
29
  """Initialize base ASR class."""
30
- pass
30
+ self._m_cfg = m_cfg
31
31
 
32
32
  @classmethod
33
33
  def _load_from(cls,
34
34
  model_path: str,
35
+ model_name: Optional[str] = None,
35
36
  tokenizer_path: Optional[str] = None,
36
37
  language: Optional[str] = None,
38
+ m_cfg: ModelConfig = ModelConfig(),
37
39
  plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
38
- device_id: Optional[str] = None
40
+ device_id: Optional[str] = None,
41
+ **kwargs
39
42
  ) -> 'ASR':
40
43
  """Load ASR model from local path, routing to appropriate implementation."""
41
44
  # Check plugin_id value for routing - handle both enum and string
@@ -43,10 +46,11 @@ class ASR(BaseModel):
43
46
 
44
47
  if plugin_value == "mlx":
45
48
  from nexaai.asr_impl.mlx_asr_impl import MLXASRImpl
46
- return MLXASRImpl._load_from(model_path, tokenizer_path, language, plugin_id, device_id)
49
+ return MLXASRImpl._load_from(model_path, model_name, tokenizer_path, language, m_cfg, plugin_id, device_id)
47
50
  else:
48
51
  from nexaai.asr_impl.pybind_asr_impl import PyBindASRImpl
49
- return PyBindASRImpl._load_from(model_path, tokenizer_path, language, plugin_id, device_id)
52
+ return PyBindASRImpl._load_from(model_path, model_name, tokenizer_path, language, m_cfg, plugin_id, device_id)
53
+
50
54
 
51
55
  @abstractmethod
52
56
  def transcribe(
nexaai/asr_impl/pybind_asr_impl.py CHANGED
@@ -1,32 +1,78 @@
1
1
  from typing import List, Optional, Union
2
2
 
3
- from nexaai.common import PluginID
3
+ from nexaai.common import PluginID, ModelConfig
4
4
  from nexaai.asr import ASR, ASRConfig, ASRResult
5
+ from nexaai.binds import asr_bind, common_bind
6
+ from nexaai.runtime import _ensure_runtime
5
7
 
6
8
 
7
9
  class PyBindASRImpl(ASR):
8
- def __init__(self):
9
- """Initialize PyBind ASR implementation."""
10
- super().__init__()
11
- # TODO: Add PyBind-specific initialization
10
+ def __init__(self, handle: any, m_cfg: ModelConfig = ModelConfig()):
11
+ """Private constructor, should not be called directly."""
12
+ super().__init__(m_cfg)
13
+ self._handle = handle # This is a py::capsule
14
+ self._model_config = None
12
15
 
13
16
  @classmethod
14
17
  def _load_from(cls,
15
18
  model_path: str,
19
+ model_name: Optional[str] = None,
16
20
  tokenizer_path: Optional[str] = None,
17
21
  language: Optional[str] = None,
22
+ m_cfg: ModelConfig = ModelConfig(),
18
23
  plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
19
24
  device_id: Optional[str] = None
20
25
  ) -> 'PyBindASRImpl':
21
26
  """Load ASR model from local path using PyBind backend."""
22
- # TODO: Implement PyBind ASR loading
23
- instance = cls()
24
- return instance
27
+ _ensure_runtime()
28
+
29
+ # Create model config
30
+ config = common_bind.ModelConfig()
31
+
32
+ config.n_ctx = m_cfg.n_ctx
33
+ if m_cfg.n_threads is not None:
34
+ config.n_threads = m_cfg.n_threads
35
+ if m_cfg.n_threads_batch is not None:
36
+ config.n_threads_batch = m_cfg.n_threads_batch
37
+ if m_cfg.n_batch is not None:
38
+ config.n_batch = m_cfg.n_batch
39
+ if m_cfg.n_ubatch is not None:
40
+ config.n_ubatch = m_cfg.n_ubatch
41
+ if m_cfg.n_seq_max is not None:
42
+ config.n_seq_max = m_cfg.n_seq_max
43
+ config.n_gpu_layers = m_cfg.n_gpu_layers
44
+
45
+ # handle chat template strings
46
+ if m_cfg.chat_template_path:
47
+ config.chat_template_path = m_cfg.chat_template_path
48
+
49
+ if m_cfg.chat_template_content:
50
+ config.chat_template_content = m_cfg.chat_template_content
51
+
52
+ # Convert plugin_id to string
53
+ plugin_id_str = plugin_id.value if isinstance(plugin_id, PluginID) else str(plugin_id)
54
+
55
+ # Create ASR handle using the binding
56
+ handle = asr_bind.ml_asr_create(
57
+ model_path=model_path,
58
+ model_name=model_name,
59
+ tokenizer_path=tokenizer_path,
60
+ model_config=config,
61
+ language=language,
62
+ plugin_id=plugin_id_str,
63
+ device_id=device_id,
64
+ license_id=None, # Optional
65
+ license_key=None # Optional
66
+ )
67
+
68
+ return cls(handle, m_cfg)
25
69
 
26
70
  def eject(self):
27
- """Destroy the model and free resources."""
28
- # TODO: Implement PyBind ASR cleanup
29
- pass
71
+ """Release the model from memory."""
72
+ # py::capsule handles cleanup automatically
73
+ if hasattr(self, '_handle') and self._handle is not None:
74
+ del self._handle
75
+ self._handle = None
30
76
 
31
77
  def transcribe(
32
78
  self,
@@ -35,10 +81,47 @@ class PyBindASRImpl(ASR):
35
81
  config: Optional[ASRConfig] = None,
36
82
  ) -> ASRResult:
37
83
  """Transcribe audio file to text."""
38
- # TODO: Implement PyBind ASR transcription
39
- raise NotImplementedError("PyBind ASR transcription not yet implemented")
84
+ if self._handle is None:
85
+ raise RuntimeError("ASR model not loaded. Call _load_from first.")
86
+
87
+ # Convert ASRConfig to binding format if provided
88
+ asr_config = None
89
+ if config:
90
+ asr_config = asr_bind.ASRConfig()
91
+ asr_config.timestamps = config.timestamps
92
+ asr_config.beam_size = config.beam_size
93
+ asr_config.stream = config.stream
94
+
95
+ # Perform transcription using the binding
96
+ result_dict = asr_bind.ml_asr_transcribe(
97
+ handle=self._handle,
98
+ audio_path=audio_path,
99
+ language=language,
100
+ config=asr_config
101
+ )
102
+
103
+ # Convert result to ASRResult
104
+ transcript = result_dict.get("transcript", "")
105
+ confidence_scores = result_dict.get("confidence_scores")
106
+ timestamps = result_dict.get("timestamps")
107
+
108
+ # Convert timestamps to the expected format
109
+ timestamp_pairs = []
110
+ if timestamps:
111
+ for start, end in timestamps:
112
+ timestamp_pairs.append((float(start), float(end)))
113
+
114
+ return ASRResult(
115
+ transcript=transcript,
116
+ confidence_scores=confidence_scores or [],
117
+ timestamps=timestamp_pairs
118
+ )
40
119
 
41
120
  def list_supported_languages(self) -> List[str]:
42
121
  """List supported languages."""
43
- # TODO: Implement PyBind ASR language listing
44
- raise NotImplementedError("PyBind ASR language listing not yet implemented")
122
+ if self._handle is None:
123
+ raise RuntimeError("ASR model not loaded. Call _load_from first.")
124
+
125
+ # Get supported languages using the binding
126
+ languages = asr_bind.ml_asr_list_supported_languages(handle=self._handle)
127
+ return languages
nexaai/binds/__init__.py CHANGED
@@ -2,3 +2,5 @@ from .common_bind import *
2
2
  from .llm_bind import *
3
3
  from .embedder_bind import *
4
4
  from .vlm_bind import *
5
+ from .rerank_bind import *
6
+ from .asr_bind import *
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
nexaai/common.py CHANGED
@@ -8,6 +8,7 @@ class PluginID(str, Enum):
8
8
  MLX = "mlx"
9
9
  LLAMA_CPP = "llama_cpp"
10
10
  NEXAML = "nexaml"
11
+ NPU = "npu"
11
12
 
12
13
 
13
14
  class ChatMessage(TypedDict):
nexaai/cv.py CHANGED
@@ -73,7 +73,8 @@ class CVModel(BaseModel):
73
73
  _: str, # TODO: remove this argument, this is a hack to make api design happy
74
74
  config: CVModelConfig,
75
75
  plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
76
- device_id: Optional[str] = None
76
+ device_id: Optional[str] = None,
77
+ **kwargs
77
78
  ) -> 'CVModel':
78
79
  """Load CV model from configuration, routing to appropriate implementation."""
79
80
  # Check plugin_id value for routing - handle both enum and string
nexaai/embedder.py CHANGED
@@ -22,12 +22,13 @@ class Embedder(BaseModel):
22
22
  pass
23
23
 
24
24
  @classmethod
25
- def _load_from(cls, model_path: str, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP):
25
+ def _load_from(cls, model_path: str, model_name: str = None, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP, **kwargs):
26
26
  """
27
27
  Load an embedder from model files, routing to appropriate implementation.
28
28
 
29
29
  Args:
30
30
  model_path: Path to the model file
31
+ model_name: Name of the model
31
32
  tokenizer_file: Path to the tokenizer file (default: "tokenizer.json")
32
33
  plugin_id: Plugin ID to use for the model (default: PluginID.LLAMA_CPP)
33
34
 
@@ -39,10 +40,10 @@ class Embedder(BaseModel):
39
40
 
40
41
  if plugin_value == "mlx":
41
42
  from nexaai.embedder_impl.mlx_embedder_impl import MLXEmbedderImpl
42
- return MLXEmbedderImpl._load_from(model_path, tokenizer_file, plugin_id)
43
+ return MLXEmbedderImpl._load_from(model_path, model_name, tokenizer_file, plugin_id)
43
44
  else:
44
45
  from nexaai.embedder_impl.pybind_embedder_impl import PyBindEmbedderImpl
45
- return PyBindEmbedderImpl._load_from(model_path, tokenizer_file, plugin_id)
46
+ return PyBindEmbedderImpl._load_from(model_path, model_name, tokenizer_file, plugin_id)
46
47
 
47
48
  @abstractmethod
48
49
  def generate(self, texts: Union[List[str], str] = None, config: EmbeddingConfig = EmbeddingConfig(), input_ids: Union[List[int], List[List[int]]] = None) -> np.ndarray:
nexaai/embedder_impl/mlx_embedder_impl.py CHANGED
@@ -14,12 +14,13 @@ class MLXEmbedderImpl(Embedder):
14
14
  self._mlx_embedder = None
15
15
 
16
16
  @classmethod
17
- def _load_from(cls, model_path: str, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.MLX):
17
+ def _load_from(cls, model_path: str, model_name: str = None, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.MLX):
18
18
  """
19
19
  Load an embedder from model files using MLX backend.
20
20
 
21
21
  Args:
22
22
  model_path: Path to the model file
23
+ model_name: Name of the model
23
24
  tokenizer_file: Path to the tokenizer file (default: "tokenizer.json")
24
25
  plugin_id: Plugin ID to use for the model (default: PluginID.MLX)
25
26
 
@@ -34,6 +35,7 @@ class MLXEmbedderImpl(Embedder):
34
35
  # This will automatically detect if it's JinaV2 or generic model and route correctly
35
36
  instance._mlx_embedder = create_embedder(
36
37
  model_path=model_path,
38
+ # model_name=model_name, # FIXME: For MLX Embedder, model_name is not used
37
39
  tokenizer_path=tokenizer_file
38
40
  )
39
41
 
nexaai/embedder_impl/pybind_embedder_impl.py CHANGED
@@ -16,12 +16,13 @@ class PyBindEmbedderImpl(Embedder):
16
16
  self._handle = _handle_ptr
17
17
 
18
18
  @classmethod
19
- def _load_from(cls, model_path: str, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP):
19
+ def _load_from(cls, model_path: str, model_name: str = None, tokenizer_file: str = "tokenizer.json", plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP):
20
20
  """
21
21
  Load an embedder from model files
22
22
 
23
23
  Args:
24
24
  model_path: Path to the model file
25
+ model_name: Name of the model
25
26
  tokenizer_file: Path to the tokenizer file (default: "tokenizer.json")
26
27
  plugin_id: Plugin ID to use for the model (default: PluginID.LLAMA_CPP)
27
28
 
@@ -32,7 +33,7 @@ class PyBindEmbedderImpl(Embedder):
32
33
  # Convert enum to string for C++ binding
33
34
  plugin_id_str = plugin_id.value if isinstance(plugin_id, PluginID) else plugin_id
34
35
  # New parameter order: model_path, plugin_id, tokenizer_path (optional)
35
- handle = embedder_bind.ml_embedder_create(model_path, plugin_id_str, tokenizer_file)
36
+ handle = embedder_bind.ml_embedder_create(model_path, model_name, plugin_id_str, tokenizer_file)
36
37
  return cls(handle)
37
38
 
38
39
  def eject(self):
nexaai/image_gen.py CHANGED
@@ -71,7 +71,8 @@ class ImageGen(BaseModel):
71
71
  plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
72
72
  device_id: Optional[str] = None,
73
73
  float16: bool = True,
74
- quantize: bool = False
74
+ quantize: bool = False,
75
+ **kwargs
75
76
  ) -> 'ImageGen':
76
77
  """Load image generation model from local path, routing to appropriate implementation."""
77
78
  # Check plugin_id value for routing - handle both enum and string
nexaai/llm.py CHANGED
@@ -15,10 +15,12 @@ class LLM(BaseModel):
15
15
  @classmethod
16
16
  def _load_from(cls,
17
17
  local_path: str,
18
+ model_name: Optional[str] = None,
18
19
  tokenizer_path: Optional[str] = None,
19
20
  m_cfg: ModelConfig = ModelConfig(),
20
21
  plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
21
- device_id: Optional[str] = None
22
+ device_id: Optional[str] = None,
23
+ **kwargs
22
24
  ) -> 'LLM':
23
25
  """Load model from local path, routing to appropriate implementation."""
24
26
  # Check plugin_id value for routing - handle both enum and string
@@ -26,10 +28,10 @@ class LLM(BaseModel):
26
28
 
27
29
  if plugin_value == "mlx":
28
30
  from nexaai.llm_impl.mlx_llm_impl import MLXLLMImpl
29
- return MLXLLMImpl._load_from(local_path, tokenizer_path, m_cfg, plugin_id, device_id)
31
+ return MLXLLMImpl._load_from(local_path, model_name, tokenizer_path, m_cfg, plugin_id, device_id)
30
32
  else:
31
33
  from nexaai.llm_impl.pybind_llm_impl import PyBindLLMImpl
32
- return PyBindLLMImpl._load_from(local_path, tokenizer_path, m_cfg, plugin_id, device_id)
34
+ return PyBindLLMImpl._load_from(local_path, model_name, tokenizer_path, m_cfg, plugin_id, device_id)
33
35
 
34
36
  def cancel_generation(self):
35
37
  """Signal to cancel any ongoing stream generation."""
nexaai/llm_impl/mlx_llm_impl.py CHANGED
@@ -16,6 +16,7 @@ class MLXLLMImpl(LLM):
16
16
  @classmethod
17
17
  def _load_from(cls,
18
18
  local_path: str,
19
+ model_name: Optional[str] = None,
19
20
  tokenizer_path: Optional[str] = None,
20
21
  m_cfg: ModelConfig = ModelConfig(),
21
22
  plugin_id: Union[PluginID, str] = PluginID.MLX,
@@ -40,6 +41,7 @@ class MLXLLMImpl(LLM):
40
41
  instance = cls(m_cfg)
41
42
  instance._mlx_llm = MLXLLMInterface(
42
43
  model_path=local_path,
44
+ # model_name=model_name, # FIXME: For MLX LLM, model_name is not used
43
45
  tokenizer_path=tokenizer_path or local_path,
44
46
  config=mlx_config,
45
47
  device=device_id
nexaai/llm_impl/pybind_llm_impl.py CHANGED
@@ -19,6 +19,7 @@ class PyBindLLMImpl(LLM):
19
19
  @classmethod
20
20
  def _load_from(cls,
21
21
  local_path: str,
22
+ model_name: Optional[str] = None,
22
23
  tokenizer_path: Optional[str] = None,
23
24
  m_cfg: ModelConfig = ModelConfig(),
24
25
  plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
@@ -55,6 +56,7 @@ class PyBindLLMImpl(LLM):
55
56
  plugin_id_str = plugin_id.value if isinstance(plugin_id, PluginID) else plugin_id
56
57
  handle = llm_bind.ml_llm_create(
57
58
  model_path=local_path,
59
+ model_name=model_name,
58
60
  tokenizer_path=tokenizer_path,
59
61
  model_config=config,
60
62
  plugin_id=plugin_id_str,
nexaai/rerank.py CHANGED
@@ -24,9 +24,11 @@ class Reranker(BaseModel):
24
24
  @classmethod
25
25
  def _load_from(cls,
26
26
  model_path: str,
27
+ model_name: str = None,
27
28
  tokenizer_file: str = "tokenizer.json",
28
29
  plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
29
- device_id: Optional[str] = None
30
+ device_id: Optional[str] = None,
31
+ **kwargs
30
32
  ) -> 'Reranker':
31
33
  """Load reranker model from local path, routing to appropriate implementation."""
32
34
  # Check plugin_id value for routing - handle both enum and string
@@ -34,10 +36,10 @@ class Reranker(BaseModel):
34
36
 
35
37
  if plugin_value == "mlx":
36
38
  from nexaai.rerank_impl.mlx_rerank_impl import MLXRerankImpl
37
- return MLXRerankImpl._load_from(model_path, tokenizer_file, plugin_id, device_id)
39
+ return MLXRerankImpl._load_from(model_path, model_name, tokenizer_file, plugin_id, device_id)
38
40
  else:
39
41
  from nexaai.rerank_impl.pybind_rerank_impl import PyBindRerankImpl
40
- return PyBindRerankImpl._load_from(model_path, tokenizer_file, plugin_id, device_id)
42
+ return PyBindRerankImpl._load_from(model_path, model_name, tokenizer_file, plugin_id, device_id)
41
43
 
42
44
  @abstractmethod
43
45
  def load_model(self, model_path: str, extra_data: Optional[str] = None) -> bool:
nexaai/rerank_impl/mlx_rerank_impl.py CHANGED
@@ -17,6 +17,7 @@ class MLXRerankImpl(Reranker):
17
17
  @classmethod
18
18
  def _load_from(cls,
19
19
  model_path: str,
20
+ model_name: str = None,
20
21
  tokenizer_file: str = "tokenizer.json",
21
22
  plugin_id: Union[PluginID, str] = PluginID.MLX,
22
23
  device_id: Optional[str] = None
@@ -29,6 +30,7 @@ class MLXRerankImpl(Reranker):
29
30
  instance = cls()
30
31
  instance._mlx_reranker = create_reranker(
31
32
  model_path=model_path,
33
+ # model_name=model_name, # FIXME: For MLX Reranker, model_name is not used
32
34
  tokenizer_path=tokenizer_file,
33
35
  device=device_id
34
36
  )
nexaai/rerank_impl/pybind_rerank_impl.py CHANGED
@@ -1,36 +1,89 @@
1
1
  from typing import List, Optional, Sequence, Union
2
+ import numpy as np
2
3
 
3
4
  from nexaai.common import PluginID
4
5
  from nexaai.rerank import Reranker, RerankConfig
6
+ from nexaai.binds import rerank_bind, common_bind
7
+ from nexaai.runtime import _ensure_runtime
5
8
 
6
9
 
7
10
  class PyBindRerankImpl(Reranker):
8
- def __init__(self):
9
- """Initialize PyBind Rerank implementation."""
11
+ def __init__(self, _handle_ptr):
12
+ """
13
+ Internal initializer
14
+
15
+ Args:
16
+ _handle_ptr: Capsule handle to the C++ reranker object
17
+ """
10
18
  super().__init__()
11
- # TODO: Add PyBind-specific initialization
19
+ self._handle = _handle_ptr
12
20
 
13
21
  @classmethod
14
22
  def _load_from(cls,
15
23
  model_path: str,
24
+ model_name: str = None,
16
25
  tokenizer_file: str = "tokenizer.json",
17
26
  plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
18
27
  device_id: Optional[str] = None
19
28
  ) -> 'PyBindRerankImpl':
20
- """Load reranker model from local path using PyBind backend."""
21
- # TODO: Implement PyBind reranker loading
22
- instance = cls()
23
- return instance
29
+ """
30
+ Load reranker model from local path using PyBind backend.
31
+
32
+ Args:
33
+ model_path: Path to the model file
34
+ model_name: Name of the model (optional)
35
+ tokenizer_file: Path to the tokenizer file (default: "tokenizer.json")
36
+ plugin_id: Plugin ID to use for the model (default: PluginID.LLAMA_CPP)
37
+ device_id: Device ID to use for the model (optional)
38
+
39
+ Returns:
40
+ PyBindRerankImpl instance
41
+ """
42
+ _ensure_runtime()
43
+
44
+ # Convert enum to string for C++ binding
45
+ plugin_id_str = plugin_id.value if isinstance(plugin_id, PluginID) else plugin_id
46
+
47
+ # Create model config
48
+ model_config = common_bind.ModelConfig()
49
+
50
+ # Create reranker handle with new API signature
51
+ handle = rerank_bind.ml_reranker_create(
52
+ model_path,
53
+ model_name,
54
+ tokenizer_file,
55
+ model_config,
56
+ plugin_id_str,
57
+ device_id
58
+ )
59
+
60
+ return cls(handle)
24
61
 
25
62
  def eject(self):
26
- """Destroy the model and free resources."""
27
- # TODO: Implement PyBind reranker cleanup
28
- pass
63
+ """
64
+ Clean up resources and destroy the reranker
65
+ """
66
+ # Destructor of the handle will unload the model correctly
67
+ if hasattr(self, '_handle') and self._handle is not None:
68
+ del self._handle
69
+ self._handle = None
29
70
 
30
71
  def load_model(self, model_path: str, extra_data: Optional[str] = None) -> bool:
31
- """Load model from path."""
32
- # TODO: Implement PyBind reranker model loading
33
- raise NotImplementedError("PyBind reranker model loading not yet implemented")
72
+ """
73
+ Load model from path.
74
+
75
+ Note: This method is not typically used directly. Use _load_from instead.
76
+
77
+ Args:
78
+ model_path: Path to the model file
79
+ extra_data: Additional data (unused)
80
+
81
+ Returns:
82
+ True if successful
83
+ """
84
+ # This method is part of the BaseModel interface but typically not used
85
+ # directly for PyBind implementations since _load_from handles creation
86
+ raise NotImplementedError("Use _load_from class method to load models")
34
87
 
35
88
  def rerank(
36
89
  self,
@@ -38,6 +91,46 @@ class PyBindRerankImpl(Reranker):
38
91
  documents: Sequence[str],
39
92
  config: Optional[RerankConfig] = None,
40
93
  ) -> List[float]:
41
- """Rerank documents given a query."""
42
- # TODO: Implement PyBind reranking
43
- raise NotImplementedError("PyBind reranking not yet implemented")
94
+ """
95
+ Rerank documents given a query.
96
+
97
+ Args:
98
+ query: Query text as UTF-8 string
99
+ documents: List of document texts to rerank
100
+ config: Optional reranking configuration
101
+
102
+ Returns:
103
+ List of ranking scores (one per document)
104
+ """
105
+ if self._handle is None:
106
+ raise RuntimeError("Reranker handle is None. Model may have been ejected.")
107
+
108
+ # Use default config if not provided
109
+ if config is None:
110
+ config = RerankConfig()
111
+
112
+ # Create bind config
113
+ bind_config = rerank_bind.RerankConfig()
114
+ bind_config.batch_size = config.batch_size
115
+ bind_config.normalize = config.normalize
116
+ bind_config.normalize_method = config.normalize_method
117
+
118
+ # Convert documents to list if needed
119
+ documents_list = list(documents)
120
+
121
+ # Call the binding which returns a dict with scores and profile_data
122
+ result = rerank_bind.ml_reranker_rerank(
123
+ self._handle,
124
+ query,
125
+ documents_list,
126
+ bind_config
127
+ )
128
+
129
+ # Extract scores from result dict
130
+ scores_array = result.get("scores", np.array([]))
131
+
132
+ # Convert numpy array to list of floats
133
+ if isinstance(scores_array, np.ndarray):
134
+ return scores_array.tolist()
135
+ else:
136
+ return []