isa-model 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_manager.py +69 -4
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +427 -236
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +149 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +158 -86
- isa_model/inference/services/llm/llm_adapter.py +414 -0
- isa_model/inference/services/llm/ollama_llm_service.py +252 -63
- isa_model/inference/services/llm/openai_llm_service.py +231 -93
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +151 -17
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.3.1.dist-info/METADATA +465 -0
- isa_model-0.3.1.dist-info/RECORD +91 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.2.0.dist-info/METADATA +0 -327
- isa_model-0.2.0.dist-info/RECORD +0 -92
- isa_model-0.2.0.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
2
|
from typing import Dict, Any, List, Union, AsyncGenerator, TypeVar, Optional
|
3
3
|
from isa_model.inference.providers.base_provider import BaseProvider
|
4
|
+
from isa_model.inference.billing_tracker import track_usage, ServiceType, Provider
|
4
5
|
|
5
6
|
T = TypeVar('T') # Generic type for responses
|
6
7
|
|
@@ -10,7 +11,41 @@ class BaseService(ABC):
|
|
10
11
|
def __init__(self, provider: 'BaseProvider', model_name: str):
|
11
12
|
self.provider = provider
|
12
13
|
self.model_name = model_name
|
13
|
-
self.config = provider.
|
14
|
+
self.config = provider.get_full_config()
|
15
|
+
|
16
|
+
def _track_usage(
    self,
    service_type: Union[str, ServiceType],
    operation: str,
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
    input_units: Optional[float] = None,
    output_units: Optional[float] = None,
    metadata: Optional[Dict[str, Any]] = None
):
    """
    Forward one usage record to the billing tracker.

    The record is sent through ``track_usage`` together with this service's
    ``model_name`` and a provider label resolved from the provider object.
    Any exception raised while resolving the label or recording the usage
    is logged as a warning and suppressed, so billing problems never
    surface to the caller.

    Args:
        service_type: Service category, as a string or ``ServiceType``.
        operation: Name of the operation being billed.
        input_tokens: Optional input token count.
        output_tokens: Optional output token count.
        input_units: Optional input unit count.
        output_units: Optional output unit count.
        metadata: Optional extra data passed through unchanged.
    """
    try:
        # Label resolution order: explicit ``name``, then ``provider_name``,
        # then the lower-cased class name with "provider" stripped out, and
        # finally the literal 'unknown' when everything else is empty.
        provider_label = (
            getattr(self.provider, 'name', None)
            or getattr(self.provider, 'provider_name', None)
            or getattr(self.provider, '__class__', type(None)).__name__.lower().replace('provider', '')
            or 'unknown'
        )

        usage_record = dict(
            provider=provider_label,
            service_type=service_type,
            model_name=self.model_name,
            operation=operation,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_units=input_units,
            output_units=output_units,
            metadata=metadata,
        )
        track_usage(**usage_record)
    except Exception as e:
        # Billing is best-effort: report the failure and carry on.
        import logging
        logging.getLogger(__name__).warning(f"Failed to track usage: {e}")
|
14
49
|
|
15
50
|
def __await__(self):
|
16
51
|
"""Make the service awaitable"""
|
@@ -0,0 +1,112 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import Dict, Any, List, Union, Optional
|
3
|
+
from isa_model.inference.services.base_service import BaseService
|
4
|
+
|
5
|
+
class BaseEmbedService(BaseService):
    """Abstract interface that every embedding service implements."""

    @abstractmethod
    async def create_text_embedding(self, text: str) -> List[float]:
        """
        Embed a single piece of text.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector as a list of floats.
        """
        ...

    @abstractmethod
    async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        Embed several texts in one call.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding vector per input text.
        """
        ...

    @abstractmethod
    async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
        """
        Split a text into chunks and embed each chunk.

        Args:
            text: The text to split and embed.
            metadata: Optional metadata to attach to every chunk.

        Returns:
            A list of dictionaries, each holding:
                - text: The chunk text
                - embedding: The chunk's embedding vector
                - metadata: The associated metadata
                - start_index: Start position in the original text
                - end_index: End position in the original text
        """
        ...

    @abstractmethod
    async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """
        Score how similar two embeddings are.

        Args:
            embedding1: The first embedding vector.
            embedding2: The second embedding vector.

        Returns:
            A similarity score (typically cosine similarity, in the range
            -1 to 1).
        """
        ...

    @abstractmethod
    async def find_similar_texts(
        self,
        query_embedding: List[float],
        candidate_embeddings: List[List[float]],
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Rank candidate embeddings by similarity to a query embedding.

        Args:
            query_embedding: The embedding to compare against.
            candidate_embeddings: The embeddings to rank.
            top_k: How many of the best matches to return.

        Returns:
            A list of dictionaries, each holding:
                - index: Position of the candidate in candidate_embeddings
                - similarity: The candidate's similarity score
        """
        ...

    @abstractmethod
    def get_embedding_dimension(self) -> int:
        """
        Report the size of the vectors this service produces.

        Returns:
            The embedding dimension as an integer.
        """
        ...

    @abstractmethod
    def get_max_input_length(self) -> int:
        """
        Report the longest input this service accepts.

        Returns:
            The maximum number of characters/tokens supported.
        """
        ...

    @abstractmethod
    async def close(self):
        """Release any resources held by the service."""
        ...
|
@@ -4,12 +4,12 @@ import asyncio
|
|
4
4
|
from typing import List, Dict, Any, Optional
|
5
5
|
|
6
6
|
# Keep the imports and framework structure you specified
|
7
|
-
from isa_model.inference.services.
|
7
|
+
from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
|
8
8
|
from isa_model.inference.providers.base_provider import BaseProvider
|
9
9
|
|
10
10
|
logger = logging.getLogger(__name__)
|
11
11
|
|
12
|
-
class OllamaEmbedService(
|
12
|
+
class OllamaEmbedService(BaseEmbedService):
|
13
13
|
"""
|
14
14
|
Ollama embedding service.
|
15
15
|
此类遵循基础服务架构,但使用其自己的 HTTP 客户端与 Ollama API 通信,
|
@@ -91,6 +91,32 @@ class OllamaEmbedService(BaseEmbeddingService):
|
|
91
91
|
|
92
92
|
return dot_product / (norm1 * norm2)
|
93
93
|
|
94
|
+
async def find_similar_texts(
    self,
    query_embedding: List[float],
    candidate_embeddings: List[List[float]],
    top_k: int = 5
) -> List[Dict[str, Any]]:
    """
    Rank candidate embeddings by similarity to the query embedding.

    Each candidate is scored with ``self.compute_similarity``; the scored
    entries are ordered from highest to lowest similarity and the first
    ``top_k`` are returned.

    Args:
        query_embedding: The embedding to compare against.
        candidate_embeddings: The embeddings to rank.
        top_k: Number of leading entries to keep after sorting.

    Returns:
        Dictionaries with ``index`` (position in ``candidate_embeddings``)
        and ``similarity`` (the computed score).
    """
    scored = [
        {
            "index": position,
            "similarity": await self.compute_similarity(query_embedding, vector),
        }
        for position, vector in enumerate(candidate_embeddings)
    ]

    # Highest score first; ties keep their original candidate order.
    ranked = sorted(scored, key=lambda entry: entry["similarity"], reverse=True)
    return ranked[:top_k]
|
109
|
+
|
110
|
+
def get_embedding_dimension(self) -> int:
    """
    Report the size of the embedding vectors for this service.

    Returns:
        Always 1024, the vector size the original comment attributes to
        BGE-M3.
    """
    bge_m3_dimension = 1024
    return bge_m3_dimension
|
114
|
+
|
115
|
+
def get_max_input_length(self) -> int:
    """
    Report the longest input this service accepts.

    Returns:
        Always 8192, the token limit the original comment attributes to
        BGE-M3.
    """
    bge_m3_token_limit = 8192
    return bge_m3_token_limit
|
119
|
+
|
94
120
|
async def close(self):
|
95
121
|
"""关闭内置的 HTTP 客户端"""
|
96
122
|
await self.client.aclose()
|
@@ -0,0 +1,223 @@
|
|
1
|
+
import logging
|
2
|
+
import asyncio
|
3
|
+
from typing import List, Dict, Any, Optional
|
4
|
+
from openai import AsyncOpenAI
|
5
|
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
6
|
+
|
7
|
+
from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
|
8
|
+
from isa_model.inference.providers.base_provider import BaseProvider
|
9
|
+
from isa_model.inference.billing_tracker import ServiceType
|
10
|
+
|
11
|
+
logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
class OpenAIEmbedService(BaseEmbedService):
    """
    Embedding service backed by the OpenAI embeddings API.

    Uses ``text-embedding-3-small`` unless another model name is given.
    The async client is built from the provider's full configuration
    (``api_key``, optional ``base_url`` and ``organization``). Optional
    ``dimensions`` and ``encoding_format`` entries from the same
    configuration are applied to every embeddings request.
    """

    def __init__(self, provider: 'BaseProvider', model_name: str = "text-embedding-3-small"):
        """
        Build the async OpenAI client from the provider configuration.

        Args:
            provider: Provider whose ``get_full_config()`` supplies the
                credentials and request options.
            model_name: Embedding model to request.

        Raises:
            ValueError: If the API key is missing or the client cannot be
                constructed.
        """
        super().__init__(provider, model_name)

        # Full configuration, including sensitive values such as the API key.
        provider_config = provider.get_full_config()

        try:
            if not provider_config.get("api_key"):
                raise ValueError("OpenAI API key not found in provider configuration")

            self.client = AsyncOpenAI(
                api_key=provider_config["api_key"],
                base_url=provider_config.get("base_url", "https://api.openai.com/v1"),
                organization=provider_config.get("organization")
            )

            logger.info(f"Initialized OpenAIEmbedService with model '{self.model_name}'")

        except Exception as e:
            logger.error(f"Failed to initialize OpenAI client: {e}")
            raise ValueError(f"Failed to initialize OpenAI client. Check your API key configuration: {e}") from e

        # Request options: optional dimension reduction and output encoding.
        self.dimensions = provider_config.get('dimensions', None)
        self.encoding_format = provider_config.get('encoding_format', 'float')

    def _embedding_request_kwargs(self, payload: Any) -> Dict[str, Any]:
        """Assemble the keyword arguments for one embeddings request."""
        kwargs = {
            "model": self.model_name,
            "input": payload,
            "encoding_format": self.encoding_format
        }

        # ``dimensions`` is only sent for text-embedding-3 model names.
        if self.dimensions and "text-embedding-3" in self.model_name:
            kwargs["dimensions"] = self.dimensions

        return kwargs

    def _record_embedding_usage(self, response: Any, operation: str, metadata: Dict[str, Any]) -> None:
        """Report the response's token usage for billing, if it carries any."""
        usage = getattr(response, 'usage', None)
        if usage:
            total_tokens = getattr(usage, 'total_tokens', 0)
            self._track_usage(
                service_type=ServiceType.EMBEDDING,
                operation=operation,
                input_tokens=total_tokens,
                output_tokens=0,
                metadata=metadata
            )

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    async def create_text_embedding(self, text: str) -> List[float]:
        """
        Embed a single text.

        The call is attempted up to three times with exponential backoff;
        the last exception is re-raised if every attempt fails.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector of the first item in the API response.
        """
        try:
            response = await self.client.embeddings.create(
                **self._embedding_request_kwargs(text)
            )

            self._record_embedding_usage(
                response,
                "create_text_embedding",
                {
                    "model": self.model_name,
                    "dimensions": self.dimensions,
                    "text_length": len(text)
                }
            )

            return response.data[0].embedding

        except Exception as e:
            logger.error(f"Error creating text embedding: {e}")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        Embed several texts with one API request.

        The call is attempted up to three times with exponential backoff;
        the last exception is re-raised if every attempt fails.

        Args:
            texts: The texts to embed. An empty list returns ``[]`` without
                contacting the API.

        Returns:
            The embedding vectors in the order the API returned them.
        """
        if not texts:
            return []

        try:
            response = await self.client.embeddings.create(
                **self._embedding_request_kwargs(texts)
            )

            self._record_embedding_usage(
                response,
                "create_text_embeddings",
                {
                    "model": self.model_name,
                    "dimensions": self.dimensions,
                    "batch_size": len(texts),
                    "total_text_length": sum(len(t) for t in texts)
                }
            )

            return [data.embedding for data in response.data]

        except Exception as e:
            logger.error(f"Error creating text embeddings: {e}")
            raise

    async def create_chunks(
        self,
        text: str,
        metadata: Optional[Dict] = None,
        *,
        chunk_size: int = 400,
        overlap: int = 50
    ) -> List[Dict]:
        """
        Split text into overlapping word windows and embed each window.

        The text is split on whitespace. Consecutive windows start
        ``chunk_size - overlap`` words apart. Windowing stops as soon as a
        window reaches the last word, so no chunk is wholly contained in
        the one before it.

        Args:
            text: The text to split and embed.
            metadata: Optional metadata; each chunk receives its own
                shallow copy so chunks do not share one mutable dict.
            chunk_size: Maximum number of words per chunk.
            overlap: Number of words shared by neighbouring chunks.

        Returns:
            A list of dictionaries with ``text``, ``start_index`` and
            ``end_index`` (word offsets into the whitespace-split text, not
            character positions), ``metadata`` and ``embedding``.

        Raises:
            ValueError: If ``chunk_size`` is not positive or ``overlap`` is
                outside ``0 <= overlap < chunk_size``.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

        words = text.split()
        if not words:
            return []

        total_words = len(words)
        stride = chunk_size - overlap

        chunks = []
        for start in range(0, total_words, stride):
            end = min(start + chunk_size, total_words)
            chunks.append({
                "text": " ".join(words[start:end]),
                "start_index": start,
                "end_index": end,
                "metadata": dict(metadata) if metadata else {}
            })
            # Once a window covers the final word, any further window would
            # only repeat the tail of this one.
            if end == total_words:
                break

        # One batched request for all chunk texts.
        embeddings = await self.create_text_embeddings(
            [chunk["text"] for chunk in chunks]
        )

        for chunk, embedding in zip(chunks, embeddings):
            chunk["embedding"] = embedding

        return chunks

    async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """
        Compute the cosine similarity of two embeddings.

        The dot product pairs components with ``zip``, so it stops at the
        shorter vector; each norm uses its full vector.

        Args:
            embedding1: The first embedding vector.
            embedding2: The second embedding vector.

        Returns:
            The cosine similarity, or ``0.0`` when either vector has zero
            norm.
        """
        import math

        dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
        norm1 = math.sqrt(sum(a * a for a in embedding1))
        norm2 = math.sqrt(sum(b * b for b in embedding2))

        # Avoid dividing by zero for all-zero (or empty) vectors.
        if norm1 * norm2 == 0:
            return 0.0

        return dot_product / (norm1 * norm2)

    async def find_similar_texts(
        self,
        query_embedding: List[float],
        candidate_embeddings: List[List[float]],
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Rank candidate embeddings by cosine similarity to the query.

        Args:
            query_embedding: The embedding to compare against.
            candidate_embeddings: The embeddings to rank.
            top_k: Number of best matches to return; values of zero or
                less yield an empty list.

        Returns:
            Dictionaries with ``index`` (position in
            ``candidate_embeddings``) and ``similarity``, ordered from
            highest to lowest similarity.
        """
        # A negative slice bound would drop entries from the end instead of
        # limiting the result, so treat non-positive requests as "nothing".
        if top_k <= 0:
            return []

        similarities = []

        for i, candidate in enumerate(candidate_embeddings):
            similarity = await self.compute_similarity(query_embedding, candidate)
            similarities.append({
                "index": i,
                "similarity": similarity
            })

        # Highest similarity first; ties keep their original candidate order.
        similarities.sort(key=lambda x: x["similarity"], reverse=True)
        return similarities[:top_k]

    def get_embedding_dimension(self) -> int:
        """
        Report the size of the vectors this service produces.

        Returns:
            The configured ``dimensions`` when set; otherwise the default
            for the model name, falling back to 1536 for names not listed.
        """
        if self.dimensions:
            return self.dimensions

        # Default dimensions for OpenAI models
        model_dimensions = {
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 3072,
            "text-embedding-ada-002": 1536
        }

        return model_dimensions.get(self.model_name, 1536)

    def get_max_input_length(self) -> int:
        """
        Report the longest input this service accepts.

        Returns:
            Always 8192, the token limit the original comment states for
            OpenAI embedding models.
        """
        return 8192

    async def close(self):
        """Close the underlying OpenAI client and log that it was closed."""
        await self.client.close()
        logger.info("OpenAIEmbedService client has been closed.")
|
@@ -5,8 +5,10 @@ LLM Services - Business logic services for Language Models
|
|
5
5
|
# Import LLM services here when created
|
6
6
|
from .ollama_llm_service import OllamaLLMService
|
7
7
|
from .openai_llm_service import OpenAILLMService
|
8
|
+
from .triton_llm_service import TritonLLMService
|
8
9
|
|
9
10
|
__all__ = [
|
10
11
|
"OllamaLLMService",
|
11
12
|
"OpenAILLMService",
|
13
|
+
"TritonLLMService"
|
12
14
|
]
|