isa-model 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_registry.py +273 -46
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/eval/__init__.py +56 -0
- isa_model/eval/benchmarks.py +469 -0
- isa_model/eval/factory.py +582 -0
- isa_model/eval/metrics.py +628 -0
- isa_model/inference/ai_factory.py +98 -93
- isa_model/inference/providers/openai_provider.py +21 -7
- isa_model/inference/providers/replicate_provider.py +18 -5
- isa_model/inference/providers/triton_provider.py +1 -1
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
- isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
- isa_model/inference/services/llm/__init__.py +0 -4
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +1 -10
- isa_model/inference/services/llm/openai_llm_service.py +70 -61
- isa_model/inference/services/vision/__init__.py +1 -1
- isa_model/inference/services/vision/ollama_vision_service.py +4 -4
- isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/training/__init__.py +44 -0
- isa_model/training/factory.py +393 -0
- isa_model-0.1.1.dist-info/METADATA +327 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/RECORD +35 -60
- isa_model/deployment/mlflow_gateway/__init__.py +0 -8
- isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
- isa_model/deployment/unified_multimodal_client.py +0 -341
- isa_model/inference/adapter/triton_adapter.py +0 -453
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
- isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
- isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
- isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
- isa_model/inference/backends/__init__.py +0 -53
- isa_model/inference/backends/base_backend_client.py +0 -26
- isa_model/inference/backends/container_services.py +0 -104
- isa_model/inference/backends/local_services.py +0 -72
- isa_model/inference/backends/openai_client.py +0 -130
- isa_model/inference/backends/replicate_client.py +0 -197
- isa_model/inference/backends/third_party_services.py +0 -239
- isa_model/inference/backends/triton_client.py +0 -97
- isa_model/inference/client_sdk/client.py +0 -134
- isa_model/inference/client_sdk/client_data_std.py +0 -34
- isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +0 -83
- isa_model/inference/services/audio/fish_speech/handler.py +0 -215
- isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
- isa_model/inference/services/audio/triton_speech_service.py +0 -138
- isa_model/inference/services/audio/whisper_service.py +0 -186
- isa_model/inference/services/base_tts_service.py +0 -66
- isa_model/inference/services/embedding/bge_service.py +0 -183
- isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
- isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
- isa_model/inference/services/llm/gemma_service.py +0 -143
- isa_model/inference/services/llm/llama_service.py +0 -143
- isa_model/inference/services/llm/replicate_llm_service.py +0 -179
- isa_model/inference/services/llm/triton_llm_service.py +0 -230
- isa_model/inference/services/vision/replicate_vision_service.py +0 -241
- isa_model/inference/services/vision/triton_vision_service.py +0 -199
- isa_model-0.1.0.dist-info/METADATA +0 -116
- /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/WHEEL +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/top_level.txt +0 -0
isa_model/inference/services/llm/gemma_service.py
@@ -1,143 +0,0 @@
-import json
-import logging
-import asyncio
-from typing import Dict, List, Any, Optional, Union
-
-from isa_model.inference.services.base_service import BaseService
-from isa_model.inference.backends.triton_client import TritonClient
-
-logger = logging.getLogger(__name__)
-
-
-class GemmaService(BaseService):
-    """
-    Service for Gemma LLM using Triton Inference Server.
-    """
-
-    def __init__(self, triton_url: str = "localhost:8001", model_name: str = "gemma"):
-        """
-        Initialize the Gemma service.
-
-        Args:
-            triton_url: URL of the Triton Inference Server
-            model_name: Name of the model in Triton
-        """
-        super().__init__()
-        self.triton_url = triton_url
-        self.model_name = model_name
-        self.client = None
-
-        # Default generation config
-        self.default_config = {
-            "max_new_tokens": 512,
-            "temperature": 0.7,
-            "top_p": 0.9,
-            "top_k": 50,
-            "repetition_penalty": 1.1,
-            "do_sample": True
-        }
-
-        self.logger = logger
-
-    async def load(self) -> None:
-        """
-        Load the client connection to Triton.
-        """
-        if self.is_loaded():
-            return
-
-        try:
-            # Create Triton client
-            self.logger.info(f"Connecting to Triton server at {self.triton_url}")
-            self.client = TritonClient(self.triton_url)
-
-            # Check if model is ready
-            if not await self.client.is_model_ready(self.model_name):
-                self.logger.error(f"Model {self.model_name} is not ready on Triton server")
-                raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
-
-            self._loaded = True
-            self.logger.info(f"Connected to Triton for model {self.model_name}")
-
-        except Exception as e:
-            self.logger.error(f"Failed to connect to Triton: {str(e)}")
-            raise
-
-    async def unload(self) -> None:
-        """
-        Unload the client connection.
-        """
-        if not self.is_loaded():
-            return
-
-        self.client = None
-        self._loaded = False
-        self.logger.info("Triton client connection closed")
-
-    async def generate(self,
-                       prompt: str,
-                       system_prompt: Optional[str] = None,
-                       generation_config: Optional[Dict[str, Any]] = None) -> str:
-        """
-        Generate text from a prompt using Triton.
-
-        Args:
-            prompt: User prompt
-            system_prompt: System prompt to control model behavior
-            generation_config: Configuration for text generation
-
-        Returns:
-            Generated text
-        """
-        if not self.is_loaded():
-            await self.load()
-
-        # Get configuration
-        merged_config = self.default_config.copy()
-        if generation_config:
-            merged_config.update(generation_config)
-
-        try:
-            # Prepare inputs
-            inputs = {
-                "prompt": [prompt],
-            }
-
-            # Add optional inputs
-            if system_prompt:
-                inputs["system_prompt"] = [system_prompt]
-
-            if merged_config:
-                inputs["generation_config"] = [json.dumps(merged_config)]
-
-            # Run inference
-            result = await self.client.infer(
-                model_name=self.model_name,
-                inputs=inputs,
-                outputs=["text_output"]
-            )
-
-            # Extract generated text
-            generated_text = result["text_output"][0].decode('utf-8')
-
-            return generated_text
-
-        except Exception as e:
-            self.logger.error(f"Error during text generation: {str(e)}")
-            raise
-
-    def get_model_info(self) -> Dict[str, Any]:
-        """
-        Get information about the model.
-
-        Returns:
-            Dictionary containing model information
-        """
-        return {
-            "name": self.model_name,
-            "type": "llm",
-            "backend": "triton",
-            "url": self.triton_url,
-            "loaded": self.is_loaded(),
-            "config": self.default_config
-        }
isa_model/inference/services/llm/llama_service.py
@@ -1,143 +0,0 @@
-import json
-import logging
-import asyncio
-from typing import Dict, List, Any, Optional, Union
-
-from isa_model.inference.services.base_service import BaseService
-from isa_model.inference.backends.triton_client import TritonClient
-
-logger = logging.getLogger(__name__)
-
-
-class LlamaService(BaseService):
-    """
-    Service for Llama LLM using Triton Inference Server.
-    """
-
-    def __init__(self, triton_url: str = "localhost:8001", model_name: str = "llama"):
-        """
-        Initialize the Llama service.
-
-        Args:
-            triton_url: URL of the Triton Inference Server
-            model_name: Name of the model in Triton
-        """
-        super().__init__()
-        self.triton_url = triton_url
-        self.model_name = model_name
-        self.client = None
-
-        # Default generation config
-        self.default_config = {
-            "max_new_tokens": 512,
-            "temperature": 0.7,
-            "top_p": 0.9,
-            "top_k": 50,
-            "repetition_penalty": 1.1,
-            "do_sample": True
-        }
-
-        self.logger = logger
-
-    async def load(self) -> None:
-        """
-        Load the client connection to Triton.
-        """
-        if self.is_loaded():
-            return
-
-        try:
-            # Create Triton client
-            self.logger.info(f"Connecting to Triton server at {self.triton_url}")
-            self.client = TritonClient(self.triton_url)
-
-            # Check if model is ready
-            if not await self.client.is_model_ready(self.model_name):
-                self.logger.error(f"Model {self.model_name} is not ready on Triton server")
-                raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
-
-            self._loaded = True
-            self.logger.info(f"Connected to Triton for model {self.model_name}")
-
-        except Exception as e:
-            self.logger.error(f"Failed to connect to Triton: {str(e)}")
-            raise
-
-    async def unload(self) -> None:
-        """
-        Unload the client connection.
-        """
-        if not self.is_loaded():
-            return
-
-        self.client = None
-        self._loaded = False
-        self.logger.info("Triton client connection closed")
-
-    async def generate(self,
-                       prompt: str,
-                       system_prompt: Optional[str] = None,
-                       generation_config: Optional[Dict[str, Any]] = None) -> str:
-        """
-        Generate text from a prompt using Triton.
-
-        Args:
-            prompt: User prompt
-            system_prompt: System prompt to control model behavior
-            generation_config: Configuration for text generation
-
-        Returns:
-            Generated text
-        """
-        if not self.is_loaded():
-            await self.load()
-
-        # Get configuration
-        merged_config = self.default_config.copy()
-        if generation_config:
-            merged_config.update(generation_config)
-
-        try:
-            # Prepare inputs
-            inputs = {
-                "prompt": [prompt],
-            }
-
-            # Add optional inputs
-            if system_prompt:
-                inputs["system_prompt"] = [system_prompt]
-
-            if merged_config:
-                inputs["generation_config"] = [json.dumps(merged_config)]
-
-            # Run inference
-            result = await self.client.infer(
-                model_name=self.model_name,
-                inputs=inputs,
-                outputs=["text_output"]
-            )
-
-            # Extract generated text
-            generated_text = result["text_output"][0].decode('utf-8')
-
-            return generated_text
-
-        except Exception as e:
-            self.logger.error(f"Error during text generation: {str(e)}")
-            raise
-
-    def get_model_info(self) -> Dict[str, Any]:
-        """
-        Get information about the model.
-
-        Returns:
-            Dictionary containing model information
-        """
-        return {
-            "name": self.model_name,
-            "type": "llm",
-            "backend": "triton",
-            "url": self.triton_url,
-            "loaded": self.is_loaded(),
-            "config": self.default_config
-        }
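For orientation, a minimal sketch of how the two Triton-backed services deleted above were driven (GemmaService and LlamaService expose the same interface). It only exercises the public methods visible in the removed code; the prompt text and generation overrides are illustrative, and the import path no longer resolves against the 0.1.1 wheel.

# Illustrative only: GemmaService/LlamaService were removed in 0.1.1.
import asyncio
from isa_model.inference.services.llm.gemma_service import GemmaService  # removed module

async def main() -> None:
    svc = GemmaService(triton_url="localhost:8001", model_name="gemma")
    await svc.load()  # connects to Triton and verifies the model is ready
    text = await svc.generate(
        prompt="Explain Triton Inference Server in one sentence.",   # example prompt
        system_prompt="You are a concise assistant.",                # optional
        generation_config={"max_new_tokens": 128},                   # merged into default_config
    )
    print(text)
    await svc.unload()  # drops the client connection

asyncio.run(main())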
isa_model/inference/services/llm/replicate_llm_service.py
@@ -1,179 +0,0 @@
-import logging
-from typing import Dict, Any, List, Union, AsyncGenerator, Optional
-from isa_model.inference.services.base_service import BaseLLMService
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.backends.replicate_client import ReplicateBackendClient
-
-logger = logging.getLogger(__name__)
-
-class ReplicateLLMService(BaseLLMService):
-    """Replicate LLM service implementation"""
-
-    def __init__(self, provider: 'BaseProvider', model_name: str = "meta/llama-3-8b-instruct", backend: Optional[ReplicateBackendClient] = None):
-        super().__init__(provider, model_name)
-
-        # Use provided backend or create new one
-        if backend:
-            self.backend = backend
-        else:
-            api_token = self.config.get("api_token", "")
-            self.backend = ReplicateBackendClient(api_token)
-
-        # Parse model name for Replicate format (owner/model)
-        if "/" not in model_name:
-            logger.warning(f"Model name {model_name} is not in Replicate format (owner/model). Using as-is.")
-
-        # Store version separately if provided
-        self.model_version = None
-        if ":" in model_name:
-            self.model_name, self.model_version = model_name.split(":", 1)
-
-        self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
-        logger.info(f"Initialized ReplicateLLMService with model {model_name}")
-
-    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
-        """Universal invocation method"""
-        if isinstance(prompt, str):
-            return await self.acompletion(prompt)
-        elif isinstance(prompt, list):
-            return await self.achat(prompt)
-        else:
-            raise ValueError("Prompt must be string or list of messages")
-
-    async def achat(self, messages: List[Dict[str, str]]):
-        """Chat completion method"""
-        try:
-            temperature = self.config.get("temperature", 0.7)
-            max_tokens = self.config.get("max_tokens", 1024)
-
-            # Convert to Replicate format
-            prompt = self._convert_messages_to_prompt(messages)
-
-            # Prepare input data
-            input_data = {
-                "prompt": prompt,
-                "temperature": temperature,
-                "max_new_tokens": max_tokens,
-                "system_prompt": self._extract_system_prompt(messages)
-            }
-
-            # Call Replicate API
-            prediction = await self.backend.create_prediction(
-                self.model_name,
-                self.model_version,
-                input_data
-            )
-
-            # Get output - could be a list of strings or a single string
-            output = prediction.get("output", "")
-            if isinstance(output, list):
-                output = "".join(output)
-
-            # Approximate token usage - Replicate doesn't provide token counts
-            approx_prompt_tokens = len(prompt) // 4  # Very rough approximation
-            approx_completion_tokens = len(output) // 4
-            self.last_token_usage = {
-                "prompt_tokens": approx_prompt_tokens,
-                "completion_tokens": approx_completion_tokens,
-                "total_tokens": approx_prompt_tokens + approx_completion_tokens
-            }
-
-            return output
-
-        except Exception as e:
-            logger.error(f"Error in chat completion: {e}")
-            raise
-
-    def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
-        """Convert chat messages to Replicate prompt format"""
-        prompt = ""
-
-        for msg in messages:
-            role = msg.get("role", "user")
-            content = msg.get("content", "")
-
-            # Skip system prompts - handled separately
-            if role == "system":
-                continue
-
-            if role == "user":
-                prompt += f"Human: {content}\n\n"
-            elif role == "assistant":
-                prompt += f"Assistant: {content}\n\n"
-            else:
-                # Default to user for unknown roles
-                prompt += f"Human: {content}\n\n"
-
-        # Add final assistant prefix for the model to continue
-        prompt += "Assistant: "
-
-        return prompt
-
-    def _extract_system_prompt(self, messages: List[Dict[str, str]]) -> str:
-        """Extract system prompt from messages"""
-        for msg in messages:
-            if msg.get("role") == "system":
-                return msg.get("content", "")
-        return ""
-
-    async def acompletion(self, prompt: str):
-        """Text completion method"""
-        try:
-            # For simple completion, use chat format with a single user message
-            messages = [{"role": "user", "content": prompt}]
-            return await self.achat(messages)
-        except Exception as e:
-            logger.error(f"Error in text completion: {e}")
-            raise
-
-    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
-        """Generate multiple completions"""
-        # Replicate doesn't support multiple outputs in one call,
-        # so we make multiple calls
-        results = []
-        for _ in range(n):
-            result = await self.achat(messages)
-            results.append(result)
-        return results
-
-    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
-        """Stream chat responses"""
-        try:
-            temperature = self.config.get("temperature", 0.7)
-            max_tokens = self.config.get("max_tokens", 1024)
-
-            # Convert to Replicate format
-            prompt = self._convert_messages_to_prompt(messages)
-
-            # Prepare input data
-            input_data = {
-                "prompt": prompt,
-                "temperature": temperature,
-                "max_new_tokens": max_tokens,
-                "system_prompt": self._extract_system_prompt(messages),
-                "stream": True
-            }
-
-            # Call Replicate API with streaming
-            async for chunk in self.backend.stream_prediction(
-                self.model_name,
-                self.model_version,
-                input_data
-            ):
-                yield chunk
-
-        except Exception as e:
-            logger.error(f"Error in stream chat: {e}")
-            raise
-
-    def get_token_usage(self):
-        """Get total token usage statistics"""
-        return self.last_token_usage
-
-    def get_last_token_usage(self) -> Dict[str, int]:
-        """Get token usage from last request"""
-        return self.last_token_usage
-
-    async def close(self):
-        """Close the backend client"""
-        await self.backend.close()
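The removed ReplicateLLMService flattened chat messages into a Human/Assistant prompt, passed any system message separately, and approximated token usage as len(text) // 4 because Replicate predictions do not report token counts. A standalone sketch of that conversion, assuming the same formatting rules as the deleted code above (the example messages are illustrative):

from typing import Dict, List

def to_replicate_prompt(messages: List[Dict[str, str]]) -> str:
    # Mirrors _convert_messages_to_prompt above: system turns are skipped
    # (sent separately as system_prompt); assistant turns keep "Assistant:",
    # user and unknown roles become "Human:".
    prompt = ""
    for msg in messages:
        role, content = msg.get("role", "user"), msg.get("content", "")
        if role == "system":
            continue
        prefix = "Assistant" if role == "assistant" else "Human"
        prompt += f"{prefix}: {content}\n\n"
    return prompt + "Assistant: "

messages = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Name one GPU inference server."},
]
prompt = to_replicate_prompt(messages)
print(repr(prompt))      # 'Human: Name one GPU inference server.\n\nAssistant: '
print(len(prompt) // 4)  # rough prompt-token estimate used by the removed service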
isa_model/inference/services/llm/triton_llm_service.py
@@ -1,230 +0,0 @@
-import json
-import logging
-import asyncio
-from typing import Dict, List, Any, AsyncGenerator, Optional, Union
-
-import numpy as np
-
-from isa_model.inference.services.base_service import BaseLLMService
-from isa_model.inference.providers.triton_provider import TritonProvider
-
-logger = logging.getLogger(__name__)
-
-
-class TritonLLMService(BaseLLMService):
-    """
-    LLM service that uses Triton Inference Server to run inference.
-    """
-
-    def __init__(self, provider: TritonProvider, model_name: str):
-        """
-        Initialize the Triton LLM service.
-
-        Args:
-            provider: The Triton provider
-            model_name: Name of the model in Triton (e.g., "Llama3-8B")
-        """
-        super().__init__(provider, model_name)
-        self.client = None
-        self.tokenizer = None
-        self.token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
-        self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
-
-    async def _initialize_client(self):
-        """Initialize the Triton client"""
-        if self.client is None:
-            self.client = self.provider.create_client()
-
-            # Check if model is ready
-            if not self.provider.is_model_ready(self.model_name):
-                logger.error(f"Model {self.model_name} is not ready on Triton server")
-                raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
-
-            logger.info(f"Initialized Triton client for model: {self.model_name}")
-
-    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> str:
-        """
-        Universal invocation method.
-
-        Args:
-            prompt: Text prompt or chat messages
-
-        Returns:
-            Generated text
-        """
-        if isinstance(prompt, str):
-            return await self.acompletion(prompt)
-        elif isinstance(prompt, list) and all(isinstance(m, dict) for m in prompt):
-            return await self.achat(prompt)
-        else:
-            raise ValueError("Prompt must be either a string or a list of message dictionaries")
-
-    async def achat(self, messages: List[Dict[str, str]]) -> str:
-        """
-        Chat completion method.
-
-        Args:
-            messages: List of message dictionaries with 'role' and 'content'
-
-        Returns:
-            Generated chat response
-        """
-        # Format chat messages into a single prompt
-        formatted_prompt = self._format_chat_messages(messages)
-        return await self.acompletion(formatted_prompt)
-
-    async def acompletion(self, prompt: str) -> str:
-        """
-        Text completion method.
-
-        Args:
-            prompt: Text prompt
-
-        Returns:
-            Generated text completion
-        """
-        await self._initialize_client()
-
-        try:
-            import tritonclient.http as httpclient
-
-            # Create input tensors
-            input_text = np.array([prompt], dtype=np.object_)
-            inputs = [httpclient.InferInput("TEXT", input_text.shape, "BYTES")]
-            inputs[0].set_data_from_numpy(input_text)
-
-            # Default parameters
-            generation_params = {
-                "max_new_tokens": self.config.get("max_new_tokens", 512),
-                "temperature": self.config.get("temperature", 0.7),
-                "top_p": self.config.get("top_p", 0.9),
-                "do_sample": self.config.get("temperature", 0.7) > 0
-            }
-
-            # Add parameters as input tensor
-            param_json = json.dumps(generation_params)
-            param_data = np.array([param_json], dtype=np.object_)
-            param_input = httpclient.InferInput("PARAMETERS", param_data.shape, "BYTES")
-            param_input.set_data_from_numpy(param_data)
-            inputs.append(param_input)
-
-            # Create output tensor
-            outputs = [httpclient.InferRequestedOutput("TEXT")]
-
-            # Send the request
-            response = await asyncio.to_thread(
-                self.client.infer,
-                self.model_name,
-                inputs,
-                outputs=outputs
-            )
-
-            # Process the response
-            output = response.as_numpy("TEXT")
-            response_text = output[0].decode('utf-8')
-
-            # Update token usage (estimated since we don't have actual token counts)
-            prompt_tokens = len(prompt) // 4  # Rough estimate
-            completion_tokens = len(response_text) // 4  # Rough estimate
-            total_tokens = prompt_tokens + completion_tokens
-
-            self.last_token_usage = {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": total_tokens
-            }
-
-            # Update total token usage
-            self.token_usage["prompt_tokens"] += prompt_tokens
-            self.token_usage["completion_tokens"] += completion_tokens
-            self.token_usage["total_tokens"] += total_tokens
-
-            return response_text
-
-        except Exception as e:
-            logger.error(f"Error during Triton inference: {str(e)}")
-            raise
-
-    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
-        """
-        Generate multiple completions.
-
-        Args:
-            messages: List of message dictionaries
-            n: Number of completions to generate
-
-        Returns:
-            List of generated completions
-        """
-        results = []
-        for _ in range(n):
-            result = await self.achat(messages)
-            results.append(result)
-        return results
-
-    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
-        """
-        Stream chat responses.
-
-        Args:
-            messages: List of message dictionaries
-
-        Yields:
-            Generated text chunks
-        """
-        # For Triton, we don't have true streaming, so we generate the full response
-        # and then simulate streaming
-        full_response = await self.achat(messages)
-
-        # Simulate streaming by yielding words
-        words = full_response.split()
-        for i in range(len(words)):
-            chunk = ' '.join(words[:i+1])
-            yield chunk
-            await asyncio.sleep(0.05)  # Small delay to simulate streaming
-
-    def get_token_usage(self) -> Dict[str, int]:
-        """
-        Get total token usage statistics.
-
-        Returns:
-            Dictionary with token usage statistics
-        """
-        return self.token_usage
-
-    def get_last_token_usage(self) -> Dict[str, int]:
-        """
-        Get token usage from last request.
-
-        Returns:
-            Dictionary with token usage statistics from last request
-        """
-        return self.last_token_usage
-
-    def _format_chat_messages(self, messages: List[Dict[str, str]]) -> str:
-        """
-        Format chat messages into a single prompt for models that don't support chat natively.
-
-        Args:
-            messages: List of message dictionaries
-
-        Returns:
-            Formatted prompt
-        """
-        formatted_prompt = ""
-
-        for message in messages:
-            role = message.get("role", "user").lower()
-            content = message.get("content", "")
-
-            if role == "system":
-                formatted_prompt += f"System: {content}\n\n"
-            elif role == "user":
-                formatted_prompt += f"User: {content}\n\n"
-            elif role == "assistant":
-                formatted_prompt += f"Assistant: {content}\n\n"
-            else:
-                formatted_prompt += f"{role.capitalize()}: {content}\n\n"
-
-        formatted_prompt += "Assistant: "
-        return formatted_prompt
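For reference, the removed TritonLLMService ultimately issued a plain tritonclient.http request: a TEXT tensor carrying the prompt, a PARAMETERS tensor carrying the generation config as JSON, and a TEXT output. A minimal synchronous sketch of that call pattern, assuming a Triton model that exposes the same tensor names; the server URL is a placeholder and the model name is the example from the removed docstring:

import json
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")  # placeholder URL

# Prompt goes in as a single-element BYTES tensor named TEXT.
prompt = "User: Hello\n\nAssistant: "
text_in = httpclient.InferInput("TEXT", [1], "BYTES")
text_in.set_data_from_numpy(np.array([prompt], dtype=np.object_))

# Generation config is JSON-encoded into a BYTES tensor named PARAMETERS.
params = {"max_new_tokens": 512, "temperature": 0.7, "top_p": 0.9, "do_sample": True}
param_in = httpclient.InferInput("PARAMETERS", [1], "BYTES")
param_in.set_data_from_numpy(np.array([json.dumps(params)], dtype=np.object_))

response = client.infer(
    "Llama3-8B",  # placeholder model name, as in the removed docstring
    [text_in, param_in],
    outputs=[httpclient.InferRequestedOutput("TEXT")],
)
print(response.as_numpy("TEXT")[0].decode("utf-8"))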