isa-model 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/mlflow_gateway/__init__.py +8 -0
- isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
- isa_model/deployment/unified_multimodal_client.py +341 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/triton_adapter.py +453 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +354 -0
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
- isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
- isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
- isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
- isa_model/inference/backends/__init__.py +53 -0
- isa_model/inference/backends/base_backend_client.py +26 -0
- isa_model/inference/backends/container_services.py +104 -0
- isa_model/inference/backends/local_services.py +72 -0
- isa_model/inference/backends/openai_client.py +130 -0
- isa_model/inference/backends/replicate_client.py +197 -0
- isa_model/inference/backends/third_party_services.py +239 -0
- isa_model/inference/backends/triton_client.py +97 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/client_sdk/__init__.py +0 -0
- isa_model/inference/client_sdk/client.py +134 -0
- isa_model/inference/client_sdk/client_data_std.py +34 -0
- isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +87 -0
- isa_model/inference/providers/replicate_provider.py +94 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +83 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/fish_speech/handler.py +215 -0
- isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
- isa_model/inference/services/audio/triton_speech_service.py +138 -0
- isa_model/inference/services/audio/whisper_service.py +186 -0
- isa_model/inference/services/audio/yyds_audio_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/base_tts_service.py +66 -0
- isa_model/inference/services/embedding/bge_service.py +183 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
- isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
- isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
- isa_model/inference/services/llm/__init__.py +16 -0
- isa_model/inference/services/llm/gemma_service.py +143 -0
- isa_model/inference/services/llm/llama_service.py +143 -0
- isa_model/inference/services/llm/ollama_llm_service.py +108 -0
- isa_model/inference/services/llm/openai_llm_service.py +129 -0
- isa_model/inference/services/llm/replicate_llm_service.py +179 -0
- isa_model/inference/services/llm/triton_llm_service.py +230 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/replicate_vision_service.py +241 -0
- isa_model/inference/services/vision/triton_vision_service.py +199 -0
- isa_model/inference/services/vision/yyds_vision_service.py +80 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.1.0.dist-info/METADATA +116 -0
- isa_model-0.1.0.dist-info/RECORD +117 -0
- isa_model-0.1.0.dist-info/WHEEL +5 -0
- isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
- isa_model-0.1.0.dist-info/top_level.txt +1 -0
isa_model/inference/services/llm/openai_llm_service.py
@@ -0,0 +1,129 @@
+import logging
+from typing import Dict, Any, List, Union, AsyncGenerator, Optional
+from isa_model.inference.services.base_service import BaseLLMService
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.backends.openai_client import OpenAIBackendClient
+
+logger = logging.getLogger(__name__)
+
+class OpenAILLMService(BaseLLMService):
+    """OpenAI LLM service implementation"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-3.5-turbo", backend: Optional[OpenAIBackendClient] = None):
+        super().__init__(provider, model_name)
+
+        # Use provided backend or create new one
+        if backend:
+            self.backend = backend
+        else:
+            api_key = self.config.get("api_key", "")
+            api_base = self.config.get("api_base", "https://api.openai.com/v1")
+            self.backend = OpenAIBackendClient(api_key, api_base)
+
+        self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+        logger.info(f"Initialized OpenAILLMService with model {model_name}")
+
+    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
+        """Universal invocation method"""
+        if isinstance(prompt, str):
+            return await self.acompletion(prompt)
+        elif isinstance(prompt, list):
+            return await self.achat(prompt)
+        else:
+            raise ValueError("Prompt must be string or list of messages")
+
+    async def achat(self, messages: List[Dict[str, str]]):
+        """Chat completion method"""
+        try:
+            temperature = self.config.get("temperature", 0.7)
+            max_tokens = self.config.get("max_tokens", 1024)
+
+            payload = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens
+            }
+            response = await self.backend.post("/chat/completions", payload)
+
+            # Update token usage
+            self.last_token_usage = response.get("usage", {
+                "prompt_tokens": 0,
+                "completion_tokens": 0,
+                "total_tokens": 0
+            })
+
+            return response["choices"][0]["message"]["content"]
+
+        except Exception as e:
+            logger.error(f"Error in chat completion: {e}")
+            raise
+
+    async def acompletion(self, prompt: str):
+        """Text completion method (using chat API since completions is deprecated)"""
+        try:
+            messages = [{"role": "user", "content": prompt}]
+            return await self.achat(messages)
+        except Exception as e:
+            logger.error(f"Error in text completion: {e}")
+            raise
+
+    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
+        """Generate multiple completions"""
+        try:
+            temperature = self.config.get("temperature", 0.7)
+            max_tokens = self.config.get("max_tokens", 1024)
+
+            payload = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "n": n
+            }
+            response = await self.backend.post("/chat/completions", payload)
+
+            # Update token usage
+            self.last_token_usage = response.get("usage", {
+                "prompt_tokens": 0,
+                "completion_tokens": 0,
+                "total_tokens": 0
+            })
+
+            return [choice["message"]["content"] for choice in response["choices"]]
+        except Exception as e:
+            logger.error(f"Error in generate: {e}")
+            raise
+
+    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+        """Stream chat responses"""
+        try:
+            temperature = self.config.get("temperature", 0.7)
+            max_tokens = self.config.get("max_tokens", 1024)
+
+            payload = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "stream": True
+            }
+
+            async for chunk in self.backend.stream_chat(payload):
+                yield chunk
+
+        except Exception as e:
+            logger.error(f"Error in stream chat: {e}")
+            raise
+
+    def get_token_usage(self):
+        """Get total token usage statistics"""
+        return self.last_token_usage
+
+    def get_last_token_usage(self) -> Dict[str, int]:
+        """Get token usage from last request"""
+        return self.last_token_usage
+
+    async def close(self):
+        """Close the backend client"""
+        await self.backend.close()
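For orientation, a minimal usage sketch of the service above. It assumes an OpenAIProvider that accepts a config dict and exposes it to BaseLLMService as self.config (that constructor signature is not shown in this diff); the method calls follow openai_llm_service.py as added here.

import asyncio

from isa_model.inference.providers.openai_provider import OpenAIProvider
from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService

async def main():
    # Assumed provider wiring: the config dict becomes self.config in the service.
    provider = OpenAIProvider(config={"api_key": "sk-...", "temperature": 0.2, "max_tokens": 256})
    service = OpenAILLMService(provider, model_name="gpt-3.5-turbo")

    # String prompts are routed through acompletion(), which wraps the chat API.
    print(await service.ainvoke("Summarize retrieval-augmented generation in one sentence."))

    # Message lists go through achat(); astream_chat() yields chunks as they arrive.
    messages = [{"role": "user", "content": "Write a haiku about GPUs."}]
    async for chunk in service.astream_chat(messages):
        print(chunk, end="", flush=True)

    print(service.get_last_token_usage())
    await service.close()

asyncio.run(main())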
isa_model/inference/services/llm/replicate_llm_service.py
@@ -0,0 +1,179 @@
+import logging
+from typing import Dict, Any, List, Union, AsyncGenerator, Optional
+from isa_model.inference.services.base_service import BaseLLMService
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.backends.replicate_client import ReplicateBackendClient
+
+logger = logging.getLogger(__name__)
+
+class ReplicateLLMService(BaseLLMService):
+    """Replicate LLM service implementation"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str = "meta/llama-3-8b-instruct", backend: Optional[ReplicateBackendClient] = None):
+        super().__init__(provider, model_name)
+
+        # Use provided backend or create new one
+        if backend:
+            self.backend = backend
+        else:
+            api_token = self.config.get("api_token", "")
+            self.backend = ReplicateBackendClient(api_token)
+
+        # Parse model name for Replicate format (owner/model)
+        if "/" not in model_name:
+            logger.warning(f"Model name {model_name} is not in Replicate format (owner/model). Using as-is.")
+
+        # Store version separately if provided
+        self.model_version = None
+        if ":" in model_name:
+            self.model_name, self.model_version = model_name.split(":", 1)
+
+        self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+        logger.info(f"Initialized ReplicateLLMService with model {model_name}")
+
+    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
+        """Universal invocation method"""
+        if isinstance(prompt, str):
+            return await self.acompletion(prompt)
+        elif isinstance(prompt, list):
+            return await self.achat(prompt)
+        else:
+            raise ValueError("Prompt must be string or list of messages")
+
+    async def achat(self, messages: List[Dict[str, str]]):
+        """Chat completion method"""
+        try:
+            temperature = self.config.get("temperature", 0.7)
+            max_tokens = self.config.get("max_tokens", 1024)
+
+            # Convert to Replicate format
+            prompt = self._convert_messages_to_prompt(messages)
+
+            # Prepare input data
+            input_data = {
+                "prompt": prompt,
+                "temperature": temperature,
+                "max_new_tokens": max_tokens,
+                "system_prompt": self._extract_system_prompt(messages)
+            }
+
+            # Call Replicate API
+            prediction = await self.backend.create_prediction(
+                self.model_name,
+                self.model_version,
+                input_data
+            )
+
+            # Get output - could be a list of strings or a single string
+            output = prediction.get("output", "")
+            if isinstance(output, list):
+                output = "".join(output)
+
+            # Approximate token usage - Replicate doesn't provide token counts
+            approx_prompt_tokens = len(prompt) // 4  # Very rough approximation
+            approx_completion_tokens = len(output) // 4
+            self.last_token_usage = {
+                "prompt_tokens": approx_prompt_tokens,
+                "completion_tokens": approx_completion_tokens,
+                "total_tokens": approx_prompt_tokens + approx_completion_tokens
+            }
+
+            return output
+
+        except Exception as e:
+            logger.error(f"Error in chat completion: {e}")
+            raise
+
+    def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
+        """Convert chat messages to Replicate prompt format"""
+        prompt = ""
+
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+
+            # Skip system prompts - handled separately
+            if role == "system":
+                continue
+
+            if role == "user":
+                prompt += f"Human: {content}\n\n"
+            elif role == "assistant":
+                prompt += f"Assistant: {content}\n\n"
+            else:
+                # Default to user for unknown roles
+                prompt += f"Human: {content}\n\n"
+
+        # Add final assistant prefix for the model to continue
+        prompt += "Assistant: "
+
+        return prompt
+
+    def _extract_system_prompt(self, messages: List[Dict[str, str]]) -> str:
+        """Extract system prompt from messages"""
+        for msg in messages:
+            if msg.get("role") == "system":
+                return msg.get("content", "")
+        return ""
+
+    async def acompletion(self, prompt: str):
+        """Text completion method"""
+        try:
+            # For simple completion, use chat format with a single user message
+            messages = [{"role": "user", "content": prompt}]
+            return await self.achat(messages)
+        except Exception as e:
+            logger.error(f"Error in text completion: {e}")
+            raise
+
+    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
+        """Generate multiple completions"""
+        # Replicate doesn't support multiple outputs in one call,
+        # so we make multiple calls
+        results = []
+        for _ in range(n):
+            result = await self.achat(messages)
+            results.append(result)
+        return results
+
+    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+        """Stream chat responses"""
+        try:
+            temperature = self.config.get("temperature", 0.7)
+            max_tokens = self.config.get("max_tokens", 1024)
+
+            # Convert to Replicate format
+            prompt = self._convert_messages_to_prompt(messages)
+
+            # Prepare input data
+            input_data = {
+                "prompt": prompt,
+                "temperature": temperature,
+                "max_new_tokens": max_tokens,
+                "system_prompt": self._extract_system_prompt(messages),
+                "stream": True
+            }
+
+            # Call Replicate API with streaming
+            async for chunk in self.backend.stream_prediction(
+                self.model_name,
+                self.model_version,
+                input_data
+            ):
+                yield chunk
+
+        except Exception as e:
+            logger.error(f"Error in stream chat: {e}")
+            raise
+
+    def get_token_usage(self):
+        """Get total token usage statistics"""
+        return self.last_token_usage
+
+    def get_last_token_usage(self) -> Dict[str, int]:
+        """Get token usage from last request"""
+        return self.last_token_usage
+
+    async def close(self):
+        """Close the backend client"""
+        await self.backend.close()
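Because Replicate models take a single prompt string rather than an OpenAI-style message list, the helpers above flatten the conversation into a Human:/Assistant: transcript and pass any system message separately. A small illustration of the strings those helpers produce:

# Input messages
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello."},
    {"role": "user", "content": "Name one GPU."},
]

# _convert_messages_to_prompt(messages) skips the system turn and yields:
#   "Human: Hi\n\nAssistant: Hello.\n\nHuman: Name one GPU.\n\nAssistant: "
#
# _extract_system_prompt(messages) returns "You are terse.", which is sent
# as the separate "system_prompt" input field.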
isa_model/inference/services/llm/triton_llm_service.py
@@ -0,0 +1,230 @@
+import json
+import logging
+import asyncio
+from typing import Dict, List, Any, AsyncGenerator, Optional, Union
+
+import numpy as np
+
+from isa_model.inference.services.base_service import BaseLLMService
+from isa_model.inference.providers.triton_provider import TritonProvider
+
+logger = logging.getLogger(__name__)
+
+
+class TritonLLMService(BaseLLMService):
+    """
+    LLM service that uses Triton Inference Server to run inference.
+    """
+
+    def __init__(self, provider: TritonProvider, model_name: str):
+        """
+        Initialize the Triton LLM service.
+
+        Args:
+            provider: The Triton provider
+            model_name: Name of the model in Triton (e.g., "Llama3-8B")
+        """
+        super().__init__(provider, model_name)
+        self.client = None
+        self.tokenizer = None
+        self.token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+        self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+
+    async def _initialize_client(self):
+        """Initialize the Triton client"""
+        if self.client is None:
+            self.client = self.provider.create_client()
+
+            # Check if model is ready
+            if not self.provider.is_model_ready(self.model_name):
+                logger.error(f"Model {self.model_name} is not ready on Triton server")
+                raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
+
+            logger.info(f"Initialized Triton client for model: {self.model_name}")
+
+    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> str:
+        """
+        Universal invocation method.
+
+        Args:
+            prompt: Text prompt or chat messages
+
+        Returns:
+            Generated text
+        """
+        if isinstance(prompt, str):
+            return await self.acompletion(prompt)
+        elif isinstance(prompt, list) and all(isinstance(m, dict) for m in prompt):
+            return await self.achat(prompt)
+        else:
+            raise ValueError("Prompt must be either a string or a list of message dictionaries")
+
+    async def achat(self, messages: List[Dict[str, str]]) -> str:
+        """
+        Chat completion method.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content'
+
+        Returns:
+            Generated chat response
+        """
+        # Format chat messages into a single prompt
+        formatted_prompt = self._format_chat_messages(messages)
+        return await self.acompletion(formatted_prompt)
+
+    async def acompletion(self, prompt: str) -> str:
+        """
+        Text completion method.
+
+        Args:
+            prompt: Text prompt
+
+        Returns:
+            Generated text completion
+        """
+        await self._initialize_client()
+
+        try:
+            import tritonclient.http as httpclient
+
+            # Create input tensors
+            input_text = np.array([prompt], dtype=np.object_)
+            inputs = [httpclient.InferInput("TEXT", input_text.shape, "BYTES")]
+            inputs[0].set_data_from_numpy(input_text)
+
+            # Default parameters
+            generation_params = {
+                "max_new_tokens": self.config.get("max_new_tokens", 512),
+                "temperature": self.config.get("temperature", 0.7),
+                "top_p": self.config.get("top_p", 0.9),
+                "do_sample": self.config.get("temperature", 0.7) > 0
+            }
+
+            # Add parameters as input tensor
+            param_json = json.dumps(generation_params)
+            param_data = np.array([param_json], dtype=np.object_)
+            param_input = httpclient.InferInput("PARAMETERS", param_data.shape, "BYTES")
+            param_input.set_data_from_numpy(param_data)
+            inputs.append(param_input)
+
+            # Create output tensor
+            outputs = [httpclient.InferRequestedOutput("TEXT")]
+
+            # Send the request
+            response = await asyncio.to_thread(
+                self.client.infer,
+                self.model_name,
+                inputs,
+                outputs=outputs
+            )
+
+            # Process the response
+            output = response.as_numpy("TEXT")
+            response_text = output[0].decode('utf-8')
+
+            # Update token usage (estimated since we don't have actual token counts)
+            prompt_tokens = len(prompt) // 4  # Rough estimate
+            completion_tokens = len(response_text) // 4  # Rough estimate
+            total_tokens = prompt_tokens + completion_tokens
+
+            self.last_token_usage = {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens
+            }
+
+            # Update total token usage
+            self.token_usage["prompt_tokens"] += prompt_tokens
+            self.token_usage["completion_tokens"] += completion_tokens
+            self.token_usage["total_tokens"] += total_tokens
+
+            return response_text
+
+        except Exception as e:
+            logger.error(f"Error during Triton inference: {str(e)}")
+            raise
+
+    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
+        """
+        Generate multiple completions.
+
+        Args:
+            messages: List of message dictionaries
+            n: Number of completions to generate
+
+        Returns:
+            List of generated completions
+        """
+        results = []
+        for _ in range(n):
+            result = await self.achat(messages)
+            results.append(result)
+        return results
+
+    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+        """
+        Stream chat responses.
+
+        Args:
+            messages: List of message dictionaries
+
+        Yields:
+            Generated text chunks
+        """
+        # For Triton, we don't have true streaming, so we generate the full response
+        # and then simulate streaming
+        full_response = await self.achat(messages)
+
+        # Simulate streaming by yielding words
+        words = full_response.split()
+        for i in range(len(words)):
+            chunk = ' '.join(words[:i+1])
+            yield chunk
+            await asyncio.sleep(0.05)  # Small delay to simulate streaming
+
+    def get_token_usage(self) -> Dict[str, int]:
+        """
+        Get total token usage statistics.
+
+        Returns:
+            Dictionary with token usage statistics
+        """
+        return self.token_usage
+
+    def get_last_token_usage(self) -> Dict[str, int]:
+        """
+        Get token usage from last request.
+
+        Returns:
+            Dictionary with token usage statistics from last request
+        """
+        return self.last_token_usage
+
+    def _format_chat_messages(self, messages: List[Dict[str, str]]) -> str:
+        """
+        Format chat messages into a single prompt for models that don't support chat natively.
+
+        Args:
+            messages: List of message dictionaries
+
+        Returns:
+            Formatted prompt
+        """
+        formatted_prompt = ""
+
+        for message in messages:
+            role = message.get("role", "user").lower()
+            content = message.get("content", "")
+
+            if role == "system":
+                formatted_prompt += f"System: {content}\n\n"
+            elif role == "user":
+                formatted_prompt += f"User: {content}\n\n"
+            elif role == "assistant":
+                formatted_prompt += f"Assistant: {content}\n\n"
+            else:
+                formatted_prompt += f"{role.capitalize()}: {content}\n\n"
+
+        formatted_prompt += "Assistant: "
+        return formatted_prompt
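Note that astream_chat() above simulates streaming: each yielded chunk is the cumulative text so far (words[:i+1] joined), not a delta. A small consumer sketch that prints only the newly added suffix:

async def print_stream(service, messages):
    # Works with TritonLLMService.astream_chat, which yields growing prefixes.
    printed = ""
    async for chunk in service.astream_chat(messages):
        print(chunk[len(printed):], end="", flush=True)
        printed = chunk
    print()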
isa_model/inference/services/others/table_transformer_service.py
@@ -0,0 +1,61 @@
+from typing import Dict, Any, List, Union, Optional
+from ...base_service import BaseService
+from ...base_provider import BaseProvider
+from transformers import TableTransformerForObjectDetection, DetrImageProcessor
+import torch
+from PIL import Image
+import numpy as np
+
+class TableTransformerService(BaseService):
+    """Table detection service using Microsoft's Table Transformer"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str = "microsoft/table-transformer-detection"):
+        super().__init__(provider, model_name)
+        self.processor = DetrImageProcessor.from_pretrained(model_name)
+        self.model = TableTransformerForObjectDetection.from_pretrained(model_name)
+        if torch.cuda.is_available():
+            self.model = self.model.cuda()
+
+    async def detect_tables(self, image_path: str) -> Dict[str, Any]:
+        """Detect tables in image"""
+        try:
+            # Load and process image
+            image = Image.open(image_path)
+            inputs = self.processor(images=image, return_tensors="pt")
+
+            if torch.cuda.is_available():
+                inputs = {k: v.cuda() for k, v in inputs.items()}
+
+            # Run inference
+            outputs = self.model(**inputs)
+
+            # Convert outputs to image size
+            target_sizes = torch.tensor([image.size[::-1]])
+            results = self.processor.post_process_object_detection(
+                outputs, target_sizes=target_sizes, threshold=0.7
+            )[0]
+
+            tables = []
+            for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+                if label == 1:  # Table class
+                    tables.append({
+                        "confidence": score.item(),
+                        "bbox": box.tolist(),
+                        "type": "table"
+                    })
+
+            return {
+                "tables": tables,
+                "image_size": image.size
+            }
+
+        except Exception as e:
+            raise RuntimeError(f"Table detection failed: {e}")
+
+    async def close(self):
+        """Cleanup resources"""
+        if hasattr(self, 'model'):
+            del self.model
+        if hasattr(self, 'processor'):
+            del self.processor
+        torch.cuda.empty_cache()
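A minimal sketch of calling the service above. The concrete provider object to pass is not shown in this diff, so it is left as an assumption; the Hugging Face weights are downloaded on first instantiation.

import asyncio

from isa_model.inference.services.others.table_transformer_service import TableTransformerService

async def find_tables(provider, image_path: str):
    # `provider` is assumed to satisfy whatever BaseService expects.
    service = TableTransformerService(provider)
    result = await service.detect_tables(image_path)
    for table in result["tables"]:
        print(f"table at {table['bbox']} (confidence {table['confidence']:.2f})")
    await service.close()

# asyncio.run(find_tables(my_provider, "invoice_page.png"))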
isa_model/inference/services/vision/__init__.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Vision services package
+Contains all vision-related service modules
+"""
+
+# Export ReplicateVisionService
+from isa_model.inference.services.vision.replicate_vision_service import ReplicateVisionService
+
+__all__ = ["ReplicateVisionService"]
isa_model/inference/services/vision/helpers/image_utils.py
@@ -0,0 +1,58 @@
+from io import BytesIO
+from PIL import Image
+from typing import Union
+import base64
+from app.config.config_manager import config_manager
+
+logger = config_manager.get_logger(__name__)
+
+def compress_image(image_data: Union[bytes, BytesIO], max_size: int = 1024) -> bytes:
+    """Compress an image to reduce its size
+
+    Args:
+        image_data: Image data, either bytes or BytesIO
+        max_size: Maximum dimension in pixels
+
+    Returns:
+        bytes: Compressed image data
+    """
+    try:
+        # If the input is bytes, convert it to BytesIO
+        if isinstance(image_data, bytes):
+            image_data = BytesIO(image_data)
+
+        img = Image.open(image_data)
+
+        # Convert to RGB mode (if needed)
+        if img.mode in ('RGBA', 'P'):
+            img = img.convert('RGB')
+
+        # Compute the new size, preserving the aspect ratio
+        ratio = max_size / max(img.size)
+        if ratio < 1:
+            new_size = tuple(int(dim * ratio) for dim in img.size)
+            img = img.resize(new_size, Image.Resampling.LANCZOS)
+
+        # Save the compressed image
+        output = BytesIO()
+        img.save(output, format='JPEG', quality=85, optimize=True)
+        return output.getvalue()
+
+    except Exception as e:
+        logger.error(f"Error compressing image: {e}")
+        raise
+
+def encode_image_to_base64(image_data: bytes) -> str:
+    """Encode image data as a base64 string
+
+    Args:
+        image_data: Raw image bytes
+
+    Returns:
+        str: base64-encoded string
+    """
+    try:
+        return base64.b64encode(image_data).decode('utf-8')
+    except Exception as e:
+        logger.error(f"Error encoding image to base64: {e}")
+        raise
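A minimal sketch of the helper pipeline above: shrink an image to at most 1024 px on its longest side, re-encode it as JPEG, then base64-encode it for a vision request payload. Note that the module pulls its logger from app.config.config_manager, so that import must resolve in the host application; the input file name below is hypothetical.

from isa_model.inference.services.vision.helpers.image_utils import (
    compress_image,
    encode_image_to_base64,
)

with open("photo.png", "rb") as f:  # hypothetical input file
    raw = f.read()

jpeg_bytes = compress_image(raw, max_size=1024)
b64 = encode_image_to_base64(jpeg_bytes)
print(f"compressed {len(raw)} -> {len(jpeg_bytes)} bytes; base64 length {len(b64)}")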