isa_model-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/mlflow_gateway/__init__.py +8 -0
  12. isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
  13. isa_model/deployment/unified_multimodal_client.py +341 -0
  14. isa_model/inference/__init__.py +11 -0
  15. isa_model/inference/adapter/triton_adapter.py +453 -0
  16. isa_model/inference/adapter/unified_api.py +248 -0
  17. isa_model/inference/ai_factory.py +354 -0
  18. isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
  19. isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
  20. isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
  21. isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
  22. isa_model/inference/backends/__init__.py +53 -0
  23. isa_model/inference/backends/base_backend_client.py +26 -0
  24. isa_model/inference/backends/container_services.py +104 -0
  25. isa_model/inference/backends/local_services.py +72 -0
  26. isa_model/inference/backends/openai_client.py +130 -0
  27. isa_model/inference/backends/replicate_client.py +197 -0
  28. isa_model/inference/backends/third_party_services.py +239 -0
  29. isa_model/inference/backends/triton_client.py +97 -0
  30. isa_model/inference/base.py +46 -0
  31. isa_model/inference/client_sdk/__init__.py +0 -0
  32. isa_model/inference/client_sdk/client.py +134 -0
  33. isa_model/inference/client_sdk/client_data_std.py +34 -0
  34. isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
  35. isa_model/inference/client_sdk/exceptions.py +0 -0
  36. isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
  37. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
  38. isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
  39. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
  40. isa_model/inference/providers/__init__.py +19 -0
  41. isa_model/inference/providers/base_provider.py +30 -0
  42. isa_model/inference/providers/model_cache_manager.py +341 -0
  43. isa_model/inference/providers/ollama_provider.py +73 -0
  44. isa_model/inference/providers/openai_provider.py +87 -0
  45. isa_model/inference/providers/replicate_provider.py +94 -0
  46. isa_model/inference/providers/triton_provider.py +439 -0
  47. isa_model/inference/providers/vllm_provider.py +0 -0
  48. isa_model/inference/providers/yyds_provider.py +83 -0
  49. isa_model/inference/services/__init__.py +14 -0
  50. isa_model/inference/services/audio/fish_speech/handler.py +215 -0
  51. isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
  52. isa_model/inference/services/audio/triton_speech_service.py +138 -0
  53. isa_model/inference/services/audio/whisper_service.py +186 -0
  54. isa_model/inference/services/audio/yyds_audio_service.py +71 -0
  55. isa_model/inference/services/base_service.py +106 -0
  56. isa_model/inference/services/base_tts_service.py +66 -0
  57. isa_model/inference/services/embedding/bge_service.py +183 -0
  58. isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
  59. isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
  60. isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
  61. isa_model/inference/services/llm/__init__.py +16 -0
  62. isa_model/inference/services/llm/gemma_service.py +143 -0
  63. isa_model/inference/services/llm/llama_service.py +143 -0
  64. isa_model/inference/services/llm/ollama_llm_service.py +108 -0
  65. isa_model/inference/services/llm/openai_llm_service.py +129 -0
  66. isa_model/inference/services/llm/replicate_llm_service.py +179 -0
  67. isa_model/inference/services/llm/triton_llm_service.py +230 -0
  68. isa_model/inference/services/others/table_transformer_service.py +61 -0
  69. isa_model/inference/services/vision/__init__.py +12 -0
  70. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  71. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  72. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  73. isa_model/inference/services/vision/replicate_vision_service.py +241 -0
  74. isa_model/inference/services/vision/triton_vision_service.py +199 -0
  75. isa_model/inference/services/vision/yyds_vision_service.py +80 -0
  76. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  77. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  78. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  79. isa_model/scripts/inference_tracker.py +283 -0
  80. isa_model/scripts/mlflow_manager.py +379 -0
  81. isa_model/scripts/model_registry.py +465 -0
  82. isa_model/scripts/start_mlflow.py +95 -0
  83. isa_model/scripts/training_tracker.py +257 -0
  84. isa_model/training/engine/llama_factory/__init__.py +39 -0
  85. isa_model/training/engine/llama_factory/config.py +115 -0
  86. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  87. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  88. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  89. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  90. isa_model/training/engine/llama_factory/factory.py +331 -0
  91. isa_model/training/engine/llama_factory/rl.py +254 -0
  92. isa_model/training/engine/llama_factory/trainer.py +171 -0
  93. isa_model/training/image_model/configs/create_config.py +37 -0
  94. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  95. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  96. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  97. isa_model/training/image_model/prepare_upload.py +17 -0
  98. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  99. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  100. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  101. isa_model/training/image_model/train/train.py +42 -0
  102. isa_model/training/image_model/train/train_flux.py +41 -0
  103. isa_model/training/image_model/train/train_lora.py +57 -0
  104. isa_model/training/image_model/train_main.py +25 -0
  105. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  106. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  107. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  108. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  109. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  110. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  111. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  112. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  113. isa_model-0.1.0.dist-info/METADATA +116 -0
  114. isa_model-0.1.0.dist-info/RECORD +117 -0
  115. isa_model-0.1.0.dist-info/WHEEL +5 -0
  116. isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
  117. isa_model-0.1.0.dist-info/top_level.txt +1 -0
isa_model/inference/services/audio/whisper_service.py
@@ -0,0 +1,186 @@
+ import logging
+ import asyncio
+ import io
+ import numpy as np
+ from typing import Dict, Any, Optional, Union, BinaryIO
+
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.backends.triton_client import TritonClient
+
+ logger = logging.getLogger(__name__)
+
+
+ class WhisperService(BaseService):
+     """
+     Service for Whisper speech-to-text using Triton Inference Server.
+     """
+
+     def __init__(self, triton_url: str = "localhost:8001", model_name: str = "whisper"):
+         """
+         Initialize the Whisper service.
+
+         Args:
+             triton_url: URL of the Triton Inference Server
+             model_name: Name of the model in Triton
+         """
+         super().__init__()
+         self.triton_url = triton_url
+         self.model_name = model_name
+         self.client = None
+
+         # Default configuration
+         self.default_config = {
+             "language": "en",
+             "sampling_rate": 16000
+         }
+
+         self.logger = logger
+
+     async def load(self) -> None:
+         """
+         Load the client connection to Triton.
+         """
+         if self.is_loaded():
+             return
+
+         try:
+             # Imported only to verify tritonclient is installed; the actual
+             # connection goes through the TritonClient wrapper below.
+             from tritonclient.http import InferenceServerClient  # noqa: F401
+
+             # Create Triton client
+             self.logger.info(f"Connecting to Triton server at {self.triton_url}")
+             self.client = TritonClient(self.triton_url)
+
+             # Check if model is ready
+             if not await self.client.is_model_ready(self.model_name):
+                 self.logger.error(f"Model {self.model_name} is not ready on Triton server")
+                 raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
+
+             self._loaded = True
+             self.logger.info(f"Connected to Triton for model {self.model_name}")
+
+         except Exception as e:
+             self.logger.error(f"Failed to connect to Triton: {str(e)}")
+             raise
+
+     async def unload(self) -> None:
+         """
+         Unload the client connection.
+         """
+         if not self.is_loaded():
+             return
+
+         self.client = None
+         self._loaded = False
+         self.logger.info("Triton client connection closed")
+
+     async def transcribe(self,
+                          audio: Union[str, BinaryIO, bytes, np.ndarray],
+                          language: str = "en",
+                          config: Optional[Dict[str, Any]] = None) -> str:
+         """
+         Transcribe audio to text using Triton.
+
+         Args:
+             audio: Audio input (file path, file-like object, bytes, or numpy array)
+             language: Language code (e.g., "en", "fr")
+             config: Additional configuration parameters
+
+         Returns:
+             Transcribed text
+         """
+         if not self.is_loaded():
+             await self.load()
+
+         # Process audio to get numpy array
+         audio_array = await self._process_audio_input(audio)
+
+         # Merge configuration
+         merged_config = self.default_config.copy()
+         if config:
+             merged_config.update(config)
+
+         # Override language if provided
+         if language:
+             merged_config["language"] = language
+
+         try:
+             # Prepare inputs
+             inputs = {
+                 "audio_input": audio_array,
+                 "language": np.array([merged_config["language"]], dtype=np.object_)
+             }
+
+             # Run inference
+             result = await self.client.infer(
+                 model_name=self.model_name,
+                 inputs=inputs,
+                 outputs=["text_output"]
+             )
+
+             # Extract transcription
+             transcription = result["text_output"][0].decode('utf-8')
+
+             return transcription
+
+         except Exception as e:
+             self.logger.error(f"Error during Whisper transcription: {str(e)}")
+             raise
+
+     async def _process_audio_input(self, audio: Union[str, BinaryIO, bytes, np.ndarray]) -> np.ndarray:
+         """
+         Convert supported audio inputs into a numpy array.
+
+         Args:
+             audio: Audio input (file path, file-like object, bytes, or numpy array)
+
+         Returns:
+             Numpy array of the audio
+         """
+         if isinstance(audio, np.ndarray):
+             return audio
+
+         try:
+             import librosa
+
+             if isinstance(audio, str):
+                 # File path
+                 y, sr = librosa.load(audio, sr=self.default_config["sampling_rate"])
+                 return y.astype(np.float32)
+
+             elif isinstance(audio, io.IOBase):
+                 # File-like object; typing.BinaryIO never matches real file
+                 # objects at runtime, so check io.IOBase instead
+                 audio.seek(0)
+                 y, sr = librosa.load(audio, sr=self.default_config["sampling_rate"])
+                 return y.astype(np.float32)
+
+             elif isinstance(audio, bytes):
+                 # Raw bytes
+                 with io.BytesIO(audio) as audio_bytes:
+                     y, sr = librosa.load(audio_bytes, sr=self.default_config["sampling_rate"])
+                     return y.astype(np.float32)
+
+             else:
+                 raise ValueError(f"Unsupported audio type: {type(audio)}")
+
+         except ImportError:
+             self.logger.error("librosa not installed. Please install with: pip install librosa")
+             raise
+         except Exception as e:
+             self.logger.error(f"Error processing audio: {str(e)}")
+             raise
+
+     def get_model_info(self) -> Dict[str, Any]:
+         """
+         Get information about the model.
+
+         Returns:
+             Dictionary containing model information
+         """
+         return {
+             "name": self.model_name,
+             "type": "speech",
+             "backend": "triton",
+             "url": self.triton_url,
+             "loaded": self.is_loaded(),
+             "config": self.default_config
+         }
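
For orientation, a minimal usage sketch for this service follows. It assumes a Triton server is already serving the whisper model at the URL used above; the audio file path is a placeholder.

    import asyncio
    from isa_model.inference.services.audio.whisper_service import WhisperService

    async def main():
        # Assumed local Triton deployment; adjust the URL to your setup
        service = WhisperService(triton_url="localhost:8001", model_name="whisper")
        await service.load()
        try:
            # "sample.wav" is a placeholder path
            text = await service.transcribe("sample.wav", language="en")
            print(text)
        finally:
            await service.unload()

    asyncio.run(main())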
isa_model/inference/services/audio/yyds_audio_service.py
@@ -0,0 +1,71 @@
+ from typing import Dict, Any
+ import tempfile
+ import os
+ from openai import AsyncOpenAI
+ from tenacity import retry, stop_after_attempt, wait_exponential
+ from ..base_service import BaseService
+ from ...providers.base_provider import BaseProvider
+ from app.config.config_manager import config_manager
+
+ logger = config_manager.get_logger(__name__)
+
+ class YYDSAudioService(BaseService):
+     """Audio model service wrapper for YYDS"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str):
+         super().__init__(provider, model_name)
+         # Initialize the AsyncOpenAI client
+         self._client = AsyncOpenAI(
+             api_key=self.config.get('api_key'),
+             base_url=self.config.get('base_url')
+         )
+         self.language = self.config.get('language', None)
+
+     @property
+     def client(self) -> AsyncOpenAI:
+         """Access the underlying OpenAI client"""
+         return self._client
+
+     @retry(
+         stop=stop_after_attempt(3),
+         wait=wait_exponential(multiplier=1, min=4, max=10),
+         reraise=True
+     )
+     async def transcribe(self, audio_data: bytes) -> Dict[str, Any]:
+         """Transcribe audio data.
+
+         Args:
+             audio_data: Raw audio bytes
+
+         Returns:
+             Dict[str, Any]: Dictionary containing the transcribed text
+         """
+         temp_path = None
+         try:
+             # Write the audio bytes to a temporary file
+             with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
+                 temp_file.write(audio_data)
+                 temp_path = temp_file.name
+
+             # Reopen the file in binary mode for the API request
+             with open(temp_path, 'rb') as audio_file:
+                 # Only include the language parameter when a valid ISO-639-1 code is configured
+                 params = {
+                     'model': self.model_name,
+                     'file': audio_file,
+                 }
+                 if self.language and isinstance(self.language, str):
+                     params['language'] = self.language
+
+                 response = await self._client.audio.transcriptions.create(**params)
+
+             # Return a dictionary containing the transcribed text
+             return {
+                 "text": response.text
+             }
+
+         except Exception as e:
+             logger.error(f"Error in audio transcription: {e}")
+             raise
+         finally:
+             # Clean up the temporary file even if transcription failed
+             if temp_path and os.path.exists(temp_path):
+                 os.unlink(temp_path)
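
A hedged usage sketch: the service only needs a provider whose get_config() supplies api_key and base_url for the OpenAI-compatible YYDS endpoint. The stub provider, endpoint URL, and model name below are illustrative, not part of the package, and the module additionally expects app.config.config_manager to be importable in the host application.

    import asyncio
    from isa_model.inference.services.audio.yyds_audio_service import YYDSAudioService

    class StubProvider:
        # Illustrative stand-in for a configured BaseProvider
        def get_config(self):
            return {"api_key": "sk-...", "base_url": "https://yyds.example.com/v1", "language": "en"}

    async def main():
        service = YYDSAudioService(StubProvider(), model_name="whisper-1")
        with open("sample.wav", "rb") as f:  # placeholder path
            result = await service.transcribe(f.read())
        print(result["text"])

    asyncio.run(main())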
isa_model/inference/services/base_service.py
@@ -0,0 +1,106 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict, Any, List, Union, AsyncGenerator, TypeVar, Optional
+ from isa_model.inference.providers.base_provider import BaseProvider
+
+ T = TypeVar('T')  # Generic type for responses
+
+ class BaseService(ABC):
+     """Base class for all AI services"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str):
+         self.provider = provider
+         self.model_name = model_name
+         self.config = provider.get_config()
+
+     def __await__(self):
+         """Make the service awaitable"""
+         yield
+         return self
+
+ class BaseLLMService(BaseService):
+     """Base class for LLM services"""
+
+     @abstractmethod
+     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> T:
+         """Universal invocation method"""
+         pass
+
+     @abstractmethod
+     async def achat(self, messages: List[Dict[str, str]]) -> T:
+         """Chat completion method"""
+         pass
+
+     @abstractmethod
+     async def acompletion(self, prompt: str) -> T:
+         """Text completion method"""
+         pass
+
+     @abstractmethod
+     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[T]:
+         """Generate multiple completions"""
+         pass
+
+     @abstractmethod
+     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+         """Stream chat responses"""
+         pass
+
+     @abstractmethod
+     def get_token_usage(self) -> Any:
+         """Get total token usage statistics"""
+         pass
+
+     @abstractmethod
+     def get_last_token_usage(self) -> Dict[str, int]:
+         """Get token usage from last request"""
+         pass
+
+ class BaseEmbeddingService(BaseService):
+     """Base class for embedding services"""
+
+     @abstractmethod
+     async def create_text_embedding(self, text: str) -> List[float]:
+         """Create embedding for single text"""
+         pass
+
+     @abstractmethod
+     async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+         """Create embeddings for multiple texts"""
+         pass
+
+     @abstractmethod
+     async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+         """Create text chunks with embeddings"""
+         pass
+
+     @abstractmethod
+     async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
+         """Compute similarity between two embeddings"""
+         pass
+
+     @abstractmethod
+     async def close(self):
+         """Cleanup resources"""
+         pass
+
+ class BaseRerankService(BaseService):
+     """Base class for reranking services"""
+
+     @abstractmethod
+     async def rerank(
+         self,
+         query: str,
+         documents: List[Dict],
+         top_k: int = 5
+     ) -> List[Dict]:
+         """Rerank documents based on query relevance"""
+         pass
+
+     @abstractmethod
+     async def rerank_texts(
+         self,
+         query: str,
+         texts: List[str]
+     ) -> List[Dict]:
+         """Rerank raw texts based on query relevance"""
+         pass
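
To make the contract concrete, here is a minimal sketch of a subclass satisfying BaseRerankService. The token-overlap scoring is a stand-in for a real model call, and the class name is hypothetical.

    from typing import Dict, List
    from isa_model.inference.services.base_service import BaseRerankService

    class ToyRerankService(BaseRerankService):
        # Hypothetical reranker: scores documents by word overlap with the query
        async def rerank(self, query: str, documents: List[Dict], top_k: int = 5) -> List[Dict]:
            q = set(query.lower().split())
            scored = [
                {**doc, "score": len(q & set(doc.get("text", "").lower().split()))}
                for doc in documents
            ]
            return sorted(scored, key=lambda d: d["score"], reverse=True)[:top_k]

        async def rerank_texts(self, query: str, texts: List[str]) -> List[Dict]:
            # Wrap raw strings in the {"text": ...} shape rerank() expects
            return await self.rerank(query, [{"text": t} for t in texts], top_k=len(texts))

Instantiation follows the base constructor, e.g. ToyRerankService(provider, "toy-rerank"), where the provider supplies the service config.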
isa_model/inference/services/base_tts_service.py
@@ -0,0 +1,66 @@
+ from abc import abstractmethod
+ from typing import Dict, Any, Optional, Union, BinaryIO
+ from .base_service import BaseService
+
+ class BaseTTSService(BaseService):
+     """Base class for Text-to-Speech services"""
+
+     @abstractmethod
+     async def generate_speech(
+         self,
+         text: str,
+         voice_id: Optional[str] = None,
+         language: Optional[str] = None,
+         speed: float = 1.0,
+         options: Optional[Dict[str, Any]] = None
+     ) -> bytes:
+         """
+         Generate speech from text
+
+         Args:
+             text: The text to convert to speech
+             voice_id: Optional voice identifier
+             language: Optional language code
+             speed: Speech speed factor (1.0 is normal speed)
+             options: Additional model-specific options
+
+         Returns:
+             Audio data as bytes
+         """
+         pass
+
+     @abstractmethod
+     async def save_to_file(
+         self,
+         text: str,
+         output_file: Union[str, BinaryIO],
+         voice_id: Optional[str] = None,
+         language: Optional[str] = None,
+         speed: float = 1.0,
+         options: Optional[Dict[str, Any]] = None
+     ) -> str:
+         """
+         Generate speech and save to file
+
+         Args:
+             text: The text to convert to speech
+             output_file: Path to output file or file-like object
+             voice_id: Optional voice identifier
+             language: Optional language code
+             speed: Speech speed factor (1.0 is normal speed)
+             options: Additional model-specific options
+
+         Returns:
+             Path to the saved file
+         """
+         pass
+
+     @abstractmethod
+     async def get_available_voices(self) -> Dict[str, Any]:
+         """
+         Get available voices for the TTS service
+
+         Returns:
+             Dictionary of available voices with their details
+         """
+         pass
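
Concrete TTS services can usually derive save_to_file from generate_speech. A minimal sketch of that delegation pattern, with a hypothetical class name (the real subclasses in this package may differ):

    class DelegatingTTSService(BaseTTSService):
        # Sketch: only save_to_file is shown; generate_speech and
        # get_available_voices still need real implementations.
        async def save_to_file(self, text, output_file, voice_id=None,
                               language=None, speed=1.0, options=None) -> str:
            audio = await self.generate_speech(text, voice_id, language, speed, options)
            if isinstance(output_file, str):
                with open(output_file, "wb") as f:
                    f.write(audio)
                return output_file
            output_file.write(audio)  # file-like object
            return getattr(output_file, "name", "<buffer>")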
isa_model/inference/services/embedding/bge_service.py
@@ -0,0 +1,183 @@
+ import logging
+ import asyncio
+ import numpy as np
+ from typing import Dict, List, Any, Optional, Union
+
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.backends.triton_client import TritonClient
+
+ logger = logging.getLogger(__name__)
+
+
+ class BgeEmbeddingService(BaseService):
+     """
+     Service for BGE embedding using Triton Inference Server.
+     """
+
+     def __init__(self, triton_url: str = "localhost:8001", model_name: str = "bge_embed"):
+         """
+         Initialize the BGE embedding service.
+
+         Args:
+             triton_url: URL of the Triton Inference Server
+             model_name: Name of the model in Triton
+         """
+         super().__init__()
+         self.triton_url = triton_url
+         self.model_name = model_name
+         self.client = None
+
+         # Default configuration
+         self.default_config = {
+             "normalize": True
+         }
+
+         self.logger = logger
+
+     async def load(self) -> None:
+         """
+         Load the client connection to Triton.
+         """
+         if self.is_loaded():
+             return
+
+         try:
+             # Create Triton client
+             self.logger.info(f"Connecting to Triton server at {self.triton_url}")
+             self.client = TritonClient(self.triton_url)
+
+             # Check if model is ready
+             if not await self.client.is_model_ready(self.model_name):
+                 self.logger.error(f"Model {self.model_name} is not ready on Triton server")
+                 raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
+
+             self._loaded = True
+             self.logger.info(f"Connected to Triton for model {self.model_name}")
+
+         except Exception as e:
+             self.logger.error(f"Failed to connect to Triton: {str(e)}")
+             raise
+
+     async def unload(self) -> None:
+         """
+         Unload the client connection.
+         """
+         if not self.is_loaded():
+             return
+
+         self.client = None
+         self._loaded = False
+         self.logger.info("Triton client connection closed")
+
+     async def embed(self,
+                     texts: Union[str, List[str]],
+                     normalize: Optional[bool] = None) -> np.ndarray:
+         """
+         Generate embeddings for texts using Triton.
+
+         Args:
+             texts: Single text or list of texts to embed
+             normalize: Whether to normalize embeddings (if None, use default)
+
+         Returns:
+             Numpy array of embeddings, shape [batch_size, embedding_dim]
+         """
+         if not self.is_loaded():
+             await self.load()
+
+         # Handle single text input
+         if isinstance(texts, str):
+             texts = [texts]
+
+         # Use default normalize setting if not specified
+         if normalize is None:
+             normalize = self.default_config["normalize"]
+
+         try:
+             # Prepare inputs
+             inputs = {
+                 "text_input": texts,
+                 "normalize": np.array([normalize], dtype=bool)
+             }
+
+             # Run inference
+             result = await self.client.infer(
+                 model_name=self.model_name,
+                 inputs=inputs,
+                 outputs=["embedding"]
+             )
+
+             # Extract embeddings
+             embeddings = result["embedding"]
+
+             return embeddings
+
+         except Exception as e:
+             self.logger.error(f"Error during embedding generation: {str(e)}")
+             raise
+
+     async def similarity(self,
+                          text1: str,
+                          text2: str,
+                          normalize: Optional[bool] = None) -> float:
+         """
+         Calculate the similarity between two texts.
+
+         Args:
+             text1: First text
+             text2: Second text
+             normalize: Whether to normalize embeddings (if None, use default)
+
+         Returns:
+             Cosine similarity score (float between -1 and 1)
+         """
+         # Generate embeddings for both texts
+         embeddings = await self.embed([text1, text2], normalize=normalize)
+
+         # Calculate cosine similarity
+         from sklearn.metrics.pairwise import cosine_similarity
+         similarity = cosine_similarity(embeddings[0:1], embeddings[1:2])[0][0]
+
+         return float(similarity)
+
+     async def batch_similarity(self,
+                                queries: List[str],
+                                documents: List[str],
+                                normalize: Optional[bool] = None) -> np.ndarray:
+         """
+         Calculate similarities between queries and documents.
+
+         Args:
+             queries: List of query texts
+             documents: List of document texts
+             normalize: Whether to normalize embeddings (if None, use default)
+
+         Returns:
+             Numpy array of similarity scores, shape [len(queries), len(documents)]
+         """
+         # Generate embeddings for queries and documents
+         query_embeddings = await self.embed(queries, normalize=normalize)
+         doc_embeddings = await self.embed(documents, normalize=normalize)
+
+         # Calculate cosine similarities
+         from sklearn.metrics.pairwise import cosine_similarity
+         similarities = cosine_similarity(query_embeddings, doc_embeddings)
+
+         return similarities
+
+     def get_model_info(self) -> Dict[str, Any]:
+         """
+         Get information about the model.
+
+         Returns:
+             Dictionary containing model information
+         """
+         return {
+             "name": self.model_name,
+             "type": "embedding",
+             "backend": "triton",
+             "url": self.triton_url,
+             "loaded": self.is_loaded(),
+             "embedding_dim": 1024,  # Typical for BGE models
+             "config": self.default_config
+         }
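
A usage sketch, assuming a Triton server is serving the bge_embed model locally and scikit-learn is installed for the similarity helpers:

    import asyncio
    from isa_model.inference.services.embedding.bge_service import BgeEmbeddingService

    async def main():
        service = BgeEmbeddingService(triton_url="localhost:8001", model_name="bge_embed")
        vectors = await service.embed(["hello world", "bonjour le monde"])
        print(vectors.shape)  # expected (2, 1024) for a BGE-large-style model
        score = await service.similarity("hello world", "hi there")
        print(f"cosine similarity: {score:.3f}")
        await service.unload()

    asyncio.run(main())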
isa_model/inference/services/embedding/ollama_embed_service.py
@@ -0,0 +1,85 @@
+ import logging
+ from typing import List, Dict, Any, Optional
+ from isa_model.inference.services.base_service import BaseEmbeddingService
+ from isa_model.inference.providers.base_provider import BaseProvider
+ from isa_model.inference.backends.local_services import OllamaBackendClient
+
+ logger = logging.getLogger(__name__)
+
+ class OllamaEmbedService(BaseEmbeddingService):
+     """Ollama embedding service using backend client"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str = "bge-m3", backend: Optional[OllamaBackendClient] = None):
+         super().__init__(provider, model_name)
+
+         # Use provided backend or create new one
+         if backend:
+             self.backend = backend
+         else:
+             host = self.config.get("host", "localhost")
+             port = self.config.get("port", 11434)
+             self.backend = OllamaBackendClient(host, port)
+
+         logger.info(f"Initialized OllamaEmbedService with model {model_name}")
+
+     async def create_text_embedding(self, text: str) -> List[float]:
+         """Create embedding for text"""
+         try:
+             payload = {
+                 "model": self.model_name,
+                 "prompt": text
+             }
+             response = await self.backend.post("/api/embeddings", payload)
+             return response["embedding"]
+
+         except Exception as e:
+             logger.error(f"Error creating text embedding: {e}")
+             raise
+
+     async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+         """Create embeddings for multiple texts"""
+         embeddings = []
+         for text in texts:
+             embedding = await self.create_text_embedding(text)
+             embeddings.append(embedding)
+         return embeddings
+
+     async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+         """Create text chunks with embeddings"""
+         # Simple implementation: split the text into fixed-size chunks
+         chunk_size = 200  # number of words per chunk
+         chunks = []
+
+         # Split on whitespace into words
+         words = text.split()
+
+         # Embed each fixed-size chunk of words
+         for i in range(0, len(words), chunk_size):
+             chunk_text = " ".join(words[i:i+chunk_size])
+             embedding = await self.create_text_embedding(chunk_text)
+
+             chunk = {
+                 "text": chunk_text,
+                 "embedding": embedding,
+                 "metadata": metadata or {}
+             }
+             chunks.append(chunk)
+
+         return chunks
+
+     async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
+         """Compute the cosine similarity between two embedding vectors"""
+         # Straightforward cosine similarity implementation
+         dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
+         norm1 = sum(a * a for a in embedding1) ** 0.5
+         norm2 = sum(b * b for b in embedding2) ** 0.5
+
+         if norm1 * norm2 == 0:
+             return 0.0
+
+         return dot_product / (norm1 * norm2)
+
+     async def close(self):
+         """Close the backend client"""
+         await self.backend.close()
+
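
A usage sketch, assuming a local Ollama daemon with the bge-m3 model pulled. The stub provider is illustrative; in the package, this would normally be a configured OllamaProvider (see isa_model/inference/providers/ollama_provider.py above).

    import asyncio
    from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService

    class StubProvider:
        # Illustrative: real code would pass a configured OllamaProvider
        def get_config(self):
            return {"host": "localhost", "port": 11434}

    async def main():
        service = OllamaEmbedService(StubProvider(), model_name="bge-m3")
        try:
            e1 = await service.create_text_embedding("vector databases")
            e2 = await service.create_text_embedding("embedding stores")
            print(await service.compute_similarity(e1, e2))
        finally:
            await service.close()

    asyncio.run(main())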