nexaai 1.0.21rc16__cp312-cp312-win_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/__init__.py +95 -0
- nexaai/_stub.cp312-win_arm64.pyd +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +68 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +92 -0
- nexaai/asr_impl/pybind_asr_impl.py +127 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +6 -0
- nexaai/binds/asr_bind.cp312-win_arm64.pyd +0 -0
- nexaai/binds/common_bind.cp312-win_arm64.pyd +0 -0
- nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-opencl.dll +0 -0
- nexaai/binds/cpu_gpu/ggml.dll +0 -0
- nexaai/binds/cpu_gpu/libomp140.aarch64.dll +0 -0
- nexaai/binds/cpu_gpu/mtmd.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
- nexaai/binds/embedder_bind.cp312-win_arm64.pyd +0 -0
- nexaai/binds/libcrypto-3-arm64.dll +0 -0
- nexaai/binds/libssl-3-arm64.dll +0 -0
- nexaai/binds/llm_bind.cp312-win_arm64.pyd +0 -0
- nexaai/binds/nexa_bridge.dll +0 -0
- nexaai/binds/npu/FLAC.dll +0 -0
- nexaai/binds/npu/convnext-sdk.dll +0 -0
- nexaai/binds/npu/embed-gemma-sdk.dll +0 -0
- nexaai/binds/npu/fftw3.dll +0 -0
- nexaai/binds/npu/fftw3f.dll +0 -0
- nexaai/binds/npu/ggml-base.dll +0 -0
- nexaai/binds/npu/ggml-cpu.dll +0 -0
- nexaai/binds/npu/ggml-opencl.dll +0 -0
- nexaai/binds/npu/ggml.dll +0 -0
- nexaai/binds/npu/granite-nano-sdk.dll +0 -0
- nexaai/binds/npu/granite4-sdk.dll +0 -0
- nexaai/binds/npu/htp-files/Genie.dll +0 -0
- nexaai/binds/npu/htp-files/PlatformValidatorShared.dll +0 -0
- nexaai/binds/npu/htp-files/QnnChrometraceProfilingReader.dll +0 -0
- nexaai/binds/npu/htp-files/QnnCpu.dll +0 -0
- nexaai/binds/npu/htp-files/QnnCpuNetRunExtensions.dll +0 -0
- nexaai/binds/npu/htp-files/QnnDsp.dll +0 -0
- nexaai/binds/npu/htp-files/QnnDspNetRunExtensions.dll +0 -0
- nexaai/binds/npu/htp-files/QnnDspV66CalculatorStub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnDspV66Stub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGenAiTransformer.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGenAiTransformerCpuOpPkg.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGenAiTransformerModel.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGpu.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGpuNetRunExtensions.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGpuProfilingReader.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtp.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpNetRunExtensions.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpOptraceProfilingReader.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpPrepare.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpProfilingReader.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpV68CalculatorStub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpV68Stub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpV73CalculatorStub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpV73Stub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnIr.dll +0 -0
- nexaai/binds/npu/htp-files/QnnJsonProfilingReader.dll +0 -0
- nexaai/binds/npu/htp-files/QnnModelDlc.dll +0 -0
- nexaai/binds/npu/htp-files/QnnSaver.dll +0 -0
- nexaai/binds/npu/htp-files/QnnSystem.dll +0 -0
- nexaai/binds/npu/htp-files/SNPE.dll +0 -0
- nexaai/binds/npu/htp-files/SnpeDspV66Stub.dll +0 -0
- nexaai/binds/npu/htp-files/SnpeHtpPrepare.dll +0 -0
- nexaai/binds/npu/htp-files/SnpeHtpV68Stub.dll +0 -0
- nexaai/binds/npu/htp-files/SnpeHtpV73Stub.dll +0 -0
- nexaai/binds/npu/htp-files/calculator.dll +0 -0
- nexaai/binds/npu/htp-files/calculator_htp.dll +0 -0
- nexaai/binds/npu/htp-files/libCalculator_skel.so +0 -0
- nexaai/binds/npu/htp-files/libQnnHtpV73.so +0 -0
- nexaai/binds/npu/htp-files/libQnnHtpV73QemuDriver.so +0 -0
- nexaai/binds/npu/htp-files/libQnnHtpV73Skel.so +0 -0
- nexaai/binds/npu/htp-files/libQnnSaver.so +0 -0
- nexaai/binds/npu/htp-files/libQnnSystem.so +0 -0
- nexaai/binds/npu/htp-files/libSnpeHtpV73Skel.so +0 -0
- nexaai/binds/npu/htp-files/libqnnhtpv73.cat +0 -0
- nexaai/binds/npu/htp-files/libsnpehtpv73.cat +0 -0
- nexaai/binds/npu/jina-rerank-sdk.dll +0 -0
- nexaai/binds/npu/libcrypto-3-arm64.dll +0 -0
- nexaai/binds/npu/libmp3lame.DLL +0 -0
- nexaai/binds/npu/libomp140.aarch64.dll +0 -0
- nexaai/binds/npu/libssl-3-arm64.dll +0 -0
- nexaai/binds/npu/liquid-sdk.dll +0 -0
- nexaai/binds/npu/llama3-3b-sdk.dll +0 -0
- nexaai/binds/npu/mpg123.dll +0 -0
- nexaai/binds/npu/nexa-mm-process.dll +0 -0
- nexaai/binds/npu/nexa-sampling.dll +0 -0
- nexaai/binds/npu/nexa_plugin.dll +0 -0
- nexaai/binds/npu/nexaproc.dll +0 -0
- nexaai/binds/npu/ogg.dll +0 -0
- nexaai/binds/npu/omni-neural-sdk.dll +0 -0
- nexaai/binds/npu/openblas.dll +0 -0
- nexaai/binds/npu/opus.dll +0 -0
- nexaai/binds/npu/paddle-ocr-proc-lib.dll +0 -0
- nexaai/binds/npu/paddleocr-sdk.dll +0 -0
- nexaai/binds/npu/parakeet-sdk.dll +0 -0
- nexaai/binds/npu/phi3-5-sdk.dll +0 -0
- nexaai/binds/npu/phi4-sdk.dll +0 -0
- nexaai/binds/npu/pyannote-sdk.dll +0 -0
- nexaai/binds/npu/qwen3-4b-sdk.dll +0 -0
- nexaai/binds/npu/qwen3vl-sdk.dll +0 -0
- nexaai/binds/npu/qwen3vl-vision.dll +0 -0
- nexaai/binds/npu/rtaudio.dll +0 -0
- nexaai/binds/npu/vorbis.dll +0 -0
- nexaai/binds/npu/vorbisenc.dll +0 -0
- nexaai/binds/npu/yolov12-sdk.dll +0 -0
- nexaai/binds/npu/zlib1.dll +0 -0
- nexaai/binds/rerank_bind.cp312-win_arm64.pyd +0 -0
- nexaai/binds/vlm_bind.cp312-win_arm64.pyd +0 -0
- nexaai/common.py +105 -0
- nexaai/cv.py +93 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +89 -0
- nexaai/cv_impl/pybind_cv_impl.py +32 -0
- nexaai/embedder.py +73 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
- nexaai/image_gen.py +141 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
- nexaai/llm.py +98 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +271 -0
- nexaai/llm_impl/pybind_llm_impl.py +220 -0
- nexaai/log.py +92 -0
- nexaai/rerank.py +57 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
- nexaai/runtime.py +68 -0
- nexaai/runtime_error.py +24 -0
- nexaai/tts.py +75 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +94 -0
- nexaai/tts_impl/pybind_tts_impl.py +43 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/manifest_utils.py +531 -0
- nexaai/utils/model_manager.py +1562 -0
- nexaai/utils/model_types.py +49 -0
- nexaai/utils/progress_tracker.py +385 -0
- nexaai/utils/quantization_utils.py +245 -0
- nexaai/vlm.py +130 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +256 -0
- nexaai-1.0.21rc16.dist-info/METADATA +31 -0
- nexaai-1.0.21rc16.dist-info/RECORD +154 -0
- nexaai-1.0.21rc16.dist-info/WHEEL +5 -0
- nexaai-1.0.21rc16.dist-info/top_level.txt +1 -0
nexaai/__init__.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""
|
|
2
|
+
NexaAI Python bindings for NexaSDK C-lib backend.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import sys
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
# The optional MLX backend ships as an "mlx_backend" directory next to this
# file. When it is present, put it on sys.path (only once) so its modules can
# also be imported as top-level modules. Windows builds do not ship it, so on
# Windows nothing is added.
_current_dir = os.path.dirname(os.path.abspath(__file__))
_mlx_backend_path = os.path.join(_current_dir, "mlx_backend")
if _mlx_backend_path not in sys.path and os.path.exists(_mlx_backend_path):
    sys.path.insert(0, _mlx_backend_path)

# _version.py is generated at build time. In a development checkout where it
# has not been generated yet, report a placeholder version instead.
try:
    from ._version import __version__
except ImportError:
    __version__ = "0.0.1"
|
|
20
|
+
|
|
21
|
+
# Import common configuration classes first (no external dependencies)
|
|
22
|
+
from .common import ModelConfig, GenerationConfig, ChatMessage, SamplerConfig, PluginID
|
|
23
|
+
|
|
24
|
+
# Import logging functionality
|
|
25
|
+
from .log import set_logger, get_error_message
|
|
26
|
+
|
|
27
|
+
# Import runtime errors
|
|
28
|
+
from .runtime_error import (
|
|
29
|
+
NexaRuntimeError,
|
|
30
|
+
ContextLengthExceededError,
|
|
31
|
+
GenerationError
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Lower-case alias so PluginID is also reachable as ``nexaai.plugin_id``.
plugin_id = PluginID

# Import the public feature classes. Their base modules are meant to carry no
# external dependencies; backend-specific implementations are imported lazily
# when a model is loaded (e.g. ASR imports its MLX / PyBind implementation
# inside ``_load_from``). Keep this order: package import order can matter.
from .llm import LLM
from .embedder import Embedder, EmbeddingConfig
from .vlm import VLM
from .asr import ASR, ASRConfig, ASRResult
from .cv import CVModel, CVModelConfig, CVResult, CVResults, CVCapabilities, BoundingBox
from .rerank import Reranker, RerankConfig
from .image_gen import ImageGen, ImageGenerationConfig, ImageSamplerConfig, SchedulerConfig, Image
from .tts import TTS, TTSConfig, TTSSamplerConfig, TTSResult

# Public API exported by ``from nexaai import *``. This list is static and
# maintained by hand: when a new public name is imported above, add it here too.
__all__ = [
    "__version__",
    # Common configurations (always available)
    "ModelConfig",
    "GenerationConfig",
    "ChatMessage",
    "SamplerConfig",
    "EmbeddingConfig",
    "PluginID",
    "plugin_id",

    # Logging functionality
    "set_logger",
    "get_error_message",

    # Runtime errors
    "NexaRuntimeError",
    "ContextLengthExceededError",
    "GenerationError",

    # Model classes
    "LLM",
    "Embedder",
    "VLM",
    "ASR",
    "CVModel",
    "Reranker",
    "ImageGen",
    "TTS",

    # Per-model configuration and result types
    "ASRConfig",
    "ASRResult",
    "CVModelConfig",
    "CVResult",
    "CVResults",
    "CVCapabilities",
    "BoundingBox",
    "RerankConfig",
    "ImageGenerationConfig",
    "ImageSamplerConfig",
    "SchedulerConfig",
    "Image",
    "TTSConfig",
    "TTSSamplerConfig",
    "TTSResult",
]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
|
|
Binary file
|
nexaai/_version.py
ADDED
nexaai/asr.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from typing import List, Optional, Sequence, Tuple, Union
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from nexaai.base import BaseModel
|
|
6
|
+
from nexaai.common import PluginID, ModelConfig
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class ASRConfig:
    """Per-call options for ``ASR.transcribe``.

    The ASR implementations copy each field, one to one, onto the backend's
    own config object before transcribing.
    """

    # Timestamp granularity; the values listed in the source are
    # "none", "segment" and "word".
    timestamps: str = "none"
    # Beam size forwarded to the backend decoder.
    beam_size: int = 5
    # Whether streaming output is requested from the backend.
    stream: bool = False
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ASRResult:
    """Outcome of a single transcription.

    Backends may report empty sequences for ``confidence_scores`` and
    ``timestamps`` when they have nothing to return.
    """

    # Recognised text.
    transcript: str
    # Confidence values reported by the backend; granularity is backend-defined.
    confidence_scores: Sequence[float]
    # (start, end) pairs; units are backend-defined (presumably seconds — confirm).
    timestamps: Sequence[Tuple[float, float]]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ASR(BaseModel):
    """Abstract base class for Automatic Speech Recognition models.

    ``_load_from`` routes to a concrete backend based on ``plugin_id``:
    ``"mlx"`` selects ``MLXASRImpl``; anything else selects ``PyBindASRImpl``.
    Subclasses implement ``transcribe`` and ``list_supported_languages``
    (plus ``eject`` from ``BaseModel``).
    """

    def __init__(self, m_cfg: Optional[ModelConfig] = None):
        """Initialize base ASR class.

        Args:
            m_cfg: Model configuration stored on the instance. ``None`` (the
                default) creates a fresh ``ModelConfig()``, so instances never
                share one default config object.
        """
        self._m_cfg = ModelConfig() if m_cfg is None else m_cfg

    @classmethod
    def _load_from(cls,
                   model_path: str,
                   model_name: Optional[str] = None,
                   tokenizer_path: Optional[str] = None,
                   language: Optional[str] = None,
                   m_cfg: Optional[ModelConfig] = None,
                   plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
                   device_id: Optional[str] = None,
                   **kwargs
                   ) -> 'ASR':
        """Load an ASR model from a local path, routing to the right backend.

        Args:
            model_path: Local path of the model.
            model_name: Optional model name (used by the PyBind backend only).
            tokenizer_path: Optional tokenizer path.
            language: Optional language passed to the backend at load time.
            m_cfg: Model configuration (PyBind backend only). ``None`` means a
                fresh default ``ModelConfig()``.
            plugin_id: ``PluginID`` member or its string value; ``"mlx"``
                selects the MLX backend, anything else the PyBind backend.
            device_id: Optional device identifier forwarded to the backend.
            **kwargs: Accepted for compatibility and ignored.

        Returns:
            A loaded ``MLXASRImpl`` or ``PyBindASRImpl`` instance.
        """
        if m_cfg is None:
            m_cfg = ModelConfig()

        # Accept both the enum member and its raw string value.
        plugin_value = plugin_id.value if isinstance(plugin_id, PluginID) else plugin_id

        if plugin_value == "mlx":
            from nexaai.asr_impl.mlx_asr_impl import MLXASRImpl
            # The MLX loader has a different (shorter) signature than this
            # router, so pass by keyword: positional forwarding would hand it
            # too many arguments and raise TypeError.
            return MLXASRImpl._load_from(
                model_path,
                tokenizer_path=tokenizer_path,
                language=language,
                plugin_id=plugin_id,
                device_id=device_id,
            )

        from nexaai.asr_impl.pybind_asr_impl import PyBindASRImpl
        return PyBindASRImpl._load_from(
            model_path,
            model_name=model_name,
            tokenizer_path=tokenizer_path,
            language=language,
            m_cfg=m_cfg,
            plugin_id=plugin_id,
            device_id=device_id,
        )

    @abstractmethod
    def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        config: Optional[ASRConfig] = None,
    ) -> ASRResult:
        """Transcribe an audio file to text.

        Args:
            audio_path: Path of the audio file to transcribe.
            language: Optional language override for this call.
            config: Optional per-call ``ASRConfig``; ``None`` leaves the
                backend defaults in place.

        Returns:
            The transcription as an ``ASRResult``.
        """
        pass

    @abstractmethod
    def list_supported_languages(self) -> List[str]:
        """Return the languages supported by the loaded model."""
        pass
|
|
File without changes
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# Note: This code is generated by Cursor, not tested yet.
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional, Union
|
|
4
|
+
|
|
5
|
+
from nexaai.common import PluginID
|
|
6
|
+
from nexaai.asr import ASR, ASRConfig, ASRResult
|
|
7
|
+
from nexaai.mlx_backend.asr.interface import MlxAsr as MLXASRInterface
|
|
8
|
+
from nexaai.mlx_backend.ml import ModelConfig as MLXModelConfig, SamplerConfig as MLXSamplerConfig, GenerationConfig as MLXGenerationConfig, EmbeddingConfig
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MLXASRImpl(ASR):
    """ASR implementation that delegates to the MLX backend's ``MlxAsr``."""

    def __init__(self):
        """Create an empty wrapper; ``_load_from`` attaches the MLX model."""
        super().__init__()
        self._mlx_asr = None

    @classmethod
    def _load_from(cls,
                   model_path: str,
                   tokenizer_path: Optional[str] = None,
                   language: Optional[str] = None,
                   plugin_id: Union[PluginID, str] = PluginID.MLX,
                   device_id: Optional[str] = None,
                   *,
                   model_name: Optional[str] = None,
                   m_cfg: Optional[object] = None,
                   **kwargs
                   ) -> 'MLXASRImpl':
        """Load an ASR model from a local path using the MLX backend.

        Args:
            model_path: Local path of the model.
            tokenizer_path: Optional tokenizer path.
            language: Optional language passed to ``MlxAsr``.
            plugin_id: Accepted for interface parity; not forwarded to MLX.
            device_id: Passed to ``MlxAsr`` as ``device``.
            model_name: Keyword-only; accepted so the generic ASR loading
                keywords do not raise ``TypeError``. Unused by MLX.
            m_cfg: Keyword-only; accepted for the same reason. Unused by MLX.
            **kwargs: Accepted and ignored.

        Returns:
            A loaded ``MLXASRImpl``.

        Raises:
            RuntimeError: If the MLX interface cannot be created. The original
                exception is chained as ``__cause__``.
        """
        try:
            instance = cls()
            instance._mlx_asr = MLXASRInterface(
                model_path=model_path,
                tokenizer_path=tokenizer_path,
                language=language,
                device=device_id
            )
            return instance
        except Exception as e:
            raise RuntimeError(f"Failed to load MLX ASR: {str(e)}") from e

    def eject(self):
        """Destroy the model and free resources. Safe to call more than once."""
        # Detach before destroying so a destroy() that raises is not retried
        # by a later eject() (BaseModel.__del__ also calls eject). getattr
        # covers instances whose __init__ never completed.
        mlx_asr = getattr(self, "_mlx_asr", None)
        self._mlx_asr = None
        if mlx_asr:
            mlx_asr.destroy()

    def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        config: Optional[ASRConfig] = None,
    ) -> ASRResult:
        """Transcribe an audio file to text with the MLX backend.

        Args:
            audio_path: Path of the audio file.
            language: Optional language override for this call.
            config: Optional ``ASRConfig``, copied field-by-field into the MLX
                backend's own ``ASRConfig``.

        Returns:
            An ``ASRResult`` built from the MLX result's ``transcript``,
            ``confidence_scores`` and ``timestamps``.

        Raises:
            RuntimeError: If no model is loaded, or if MLX transcription fails
                (the original exception is chained as ``__cause__``).
        """
        if not self._mlx_asr:
            raise RuntimeError("MLX ASR not loaded")

        try:
            mlx_config = None
            if config is not None:
                from nexaai.mlx_backend.ml import ASRConfig as MLXASRConfig

                mlx_config = MLXASRConfig()
                mlx_config.timestamps = config.timestamps
                mlx_config.beam_size = config.beam_size
                mlx_config.stream = config.stream

            result = self._mlx_asr.transcribe(audio_path, language, mlx_config)

            return ASRResult(
                transcript=result.transcript,
                confidence_scores=result.confidence_scores,
                timestamps=result.timestamps
            )
        except Exception as e:
            raise RuntimeError(f"Failed to transcribe audio: {str(e)}") from e

    def list_supported_languages(self) -> List[str]:
        """Return the languages reported by the MLX model.

        Raises:
            RuntimeError: If no model is loaded, or the MLX call fails (the
                original exception is chained as ``__cause__``).
        """
        if not self._mlx_asr:
            raise RuntimeError("MLX ASR not loaded")

        try:
            return self._mlx_asr.list_supported_languages()
        except Exception as e:
            raise RuntimeError(f"Failed to list supported languages: {str(e)}") from e
|
|
91
|
+
|
|
92
|
+
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
from typing import List, Optional, Union
|
|
2
|
+
|
|
3
|
+
from nexaai.common import PluginID, ModelConfig
|
|
4
|
+
from nexaai.asr import ASR, ASRConfig, ASRResult
|
|
5
|
+
from nexaai.binds import asr_bind, common_bind
|
|
6
|
+
from nexaai.runtime import _ensure_runtime
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class PyBindASRImpl(ASR):
    """ASR implementation that drives the native ``asr_bind`` extension."""

    def __init__(self, handle: object, m_cfg: Optional[ModelConfig] = None):
        """Private constructor, should not be called directly.

        Use ``PyBindASRImpl._load_from`` (or ``ASR.from_``) instead.

        Args:
            handle: Native handle returned by ``asr_bind.ml_asr_create``
                (a ``py::capsule``).
            m_cfg: Model configuration the handle was created with. ``None``
                creates a fresh ``ModelConfig()`` instead of sharing a single
                default instance between all models.
        """
        super().__init__(ModelConfig() if m_cfg is None else m_cfg)
        self._handle = handle  # This is a py::capsule
        self._model_config = None

    @staticmethod
    def _build_model_config(m_cfg: ModelConfig):
        """Translate a Python ``ModelConfig`` into a ``common_bind.ModelConfig``."""
        config = common_bind.ModelConfig()
        config.n_ctx = m_cfg.n_ctx
        # Optional knobs are only copied when set; otherwise they keep whatever
        # common_bind.ModelConfig() initialised them to.
        for field_name in ("n_threads", "n_threads_batch", "n_batch", "n_ubatch", "n_seq_max"):
            value = getattr(m_cfg, field_name)
            if value is not None:
                setattr(config, field_name, value)
        config.n_gpu_layers = m_cfg.n_gpu_layers

        # Chat template path / content are only forwarded when non-empty.
        if m_cfg.chat_template_path:
            config.chat_template_path = m_cfg.chat_template_path
        if m_cfg.chat_template_content:
            config.chat_template_content = m_cfg.chat_template_content
        return config

    @classmethod
    def _load_from(cls,
                   model_path: str,
                   model_name: Optional[str] = None,
                   tokenizer_path: Optional[str] = None,
                   language: Optional[str] = None,
                   m_cfg: Optional[ModelConfig] = None,
                   plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
                   device_id: Optional[str] = None
                   ) -> 'PyBindASRImpl':
        """Load an ASR model from a local path through the native binding.

        Args:
            model_path: Local path of the model.
            model_name: Optional model name forwarded to the binding.
            tokenizer_path: Optional tokenizer path.
            language: Optional language forwarded at load time.
            m_cfg: Model configuration; ``None`` means a fresh ``ModelConfig()``.
            plugin_id: ``PluginID`` member or string, sent to the binding as a string.
            device_id: Optional device identifier.

        Returns:
            A ``PyBindASRImpl`` wrapping the native handle.
        """
        # Presumably initialises the native runtime before any binding call.
        _ensure_runtime()

        if m_cfg is None:
            m_cfg = ModelConfig()
        config = cls._build_model_config(m_cfg)

        # The binding expects the plugin identifier as a plain string.
        plugin_id_str = plugin_id.value if isinstance(plugin_id, PluginID) else str(plugin_id)

        handle = asr_bind.ml_asr_create(
            model_path=model_path,
            model_name=model_name,
            tokenizer_path=tokenizer_path,
            model_config=config,
            language=language,
            plugin_id=plugin_id_str,
            device_id=device_id,
            license_id=None,  # Optional
            license_key=None,  # Optional
        )

        return cls(handle, m_cfg)

    def eject(self):
        """Release the model from memory. Safe to call more than once."""
        # Dropping the reference to the py::capsule lets it clean up the native
        # model automatically. getattr covers instances whose __init__ never ran.
        if getattr(self, "_handle", None) is not None:
            self._handle = None

    def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        config: Optional[ASRConfig] = None,
    ) -> ASRResult:
        """Transcribe an audio file to text through the native binding.

        Args:
            audio_path: Path of the audio file.
            language: Optional language override for this call.
            config: Optional ``ASRConfig``, copied into ``asr_bind.ASRConfig``.

        Returns:
            An ``ASRResult``. Missing or ``None`` fields in the binding's
            result become ``""`` / empty lists; timestamps are coerced to
            ``(float, float)`` pairs.

        Raises:
            RuntimeError: If the model has been ejected or never loaded.
        """
        if self._handle is None:
            raise RuntimeError("ASR model not loaded. Call _load_from first.")

        asr_config = None
        if config is not None:
            asr_config = asr_bind.ASRConfig()
            asr_config.timestamps = config.timestamps
            asr_config.beam_size = config.beam_size
            asr_config.stream = config.stream

        result_dict = asr_bind.ml_asr_transcribe(
            handle=self._handle,
            audio_path=audio_path,
            language=language,
            config=asr_config,
        )

        raw_timestamps = result_dict.get("timestamps") or ()
        return ASRResult(
            transcript=result_dict.get("transcript") or "",
            confidence_scores=result_dict.get("confidence_scores") or [],
            timestamps=[(float(start), float(end)) for start, end in raw_timestamps],
        )

    def list_supported_languages(self) -> List[str]:
        """Return the languages reported by the native binding.

        Raises:
            RuntimeError: If the model has been ejected or never loaded.
        """
        if self._handle is None:
            raise RuntimeError("ASR model not loaded. Call _load_from first.")

        return asr_bind.ml_asr_list_supported_languages(handle=self._handle)
|
nexaai/base.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from nexaai.common import ProfilingData
|
|
3
|
+
from nexaai.utils.model_manager import auto_download_model
|
|
4
|
+
|
|
5
|
+
class BaseModel(ABC):
    """Common lifecycle shared by the NexaAI model wrappers.

    Subclasses implement ``_load_from``, which builds a loaded instance from a
    local path, and ``eject``, which releases the underlying model. Instances
    work as context managers and also call ``eject`` when garbage-collected.
    ``eject`` can therefore run more than once and must tolerate that.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Release the model on leaving a ``with`` block. Returning None lets
        # any in-flight exception propagate.
        self.eject()

    def __del__(self):
        # Finalizer: release the model if the caller forgot to. Exceptions
        # cannot propagate out of __del__; Python only prints "Exception
        # ignored in ...". Examples are a subclass whose __init__ raised before
        # its handle attribute was set, or module globals already torn down at
        # interpreter shutdown. Such errors are therefore suppressed on purpose.
        try:
            self.eject()
        except Exception:
            pass

    @classmethod
    @auto_download_model
    def from_(cls, name_or_path: str, **kwargs) -> "BaseModel":
        """Create a loaded model from a model name or a local path.

        Resolution of ``name_or_path`` is done by the ``auto_download_model``
        decorator. Per the original docs it tries (1) HF, and if not found
        (2) a local path. This body then forwards to ``cls._load_from``.
        """
        return cls._load_from(name_or_path, **kwargs)

    @classmethod
    @abstractmethod
    def _load_from(cls, name_or_path: str, **kwargs) -> "BaseModel":
        """
        Model-specific loading logic. Must be implemented by each model type.
        Called after model is available locally.
        """
        pass

    @abstractmethod
    def eject(self):
        """Release the underlying model. Must be safe to call repeatedly."""
        pass

    def get_profiling_data(self) -> ProfilingData:
        """Return profiling data collected by the backend, if any.

        The base implementation returns ``None``; subclasses that collect
        profiling data override it.
        """
        return None
|
nexaai/binds/__init__.py
ADDED
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|