isa-model 0.1.0-py3-none-any.whl → 0.1.1-py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
Files changed (76)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_registry.py +273 -46
  3. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
  4. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
  5. isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
  6. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
  7. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
  8. isa_model/eval/__init__.py +56 -0
  9. isa_model/eval/benchmarks.py +469 -0
  10. isa_model/eval/factory.py +582 -0
  11. isa_model/eval/metrics.py +628 -0
  12. isa_model/inference/ai_factory.py +98 -93
  13. isa_model/inference/providers/openai_provider.py +21 -7
  14. isa_model/inference/providers/replicate_provider.py +18 -5
  15. isa_model/inference/providers/triton_provider.py +1 -1
  16. isa_model/inference/services/audio/base_stt_service.py +91 -0
  17. isa_model/inference/services/audio/base_tts_service.py +136 -0
  18. isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
  19. isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
  20. isa_model/inference/services/llm/__init__.py +0 -4
  21. isa_model/inference/services/llm/base_llm_service.py +134 -0
  22. isa_model/inference/services/llm/ollama_llm_service.py +1 -10
  23. isa_model/inference/services/llm/openai_llm_service.py +70 -61
  24. isa_model/inference/services/vision/__init__.py +1 -1
  25. isa_model/inference/services/vision/ollama_vision_service.py +4 -4
  26. isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
  27. isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
  28. isa_model/training/__init__.py +44 -0
  29. isa_model/training/factory.py +393 -0
  30. isa_model-0.1.1.dist-info/METADATA +327 -0
  31. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/RECORD +35 -60
  32. isa_model/deployment/mlflow_gateway/__init__.py +0 -8
  33. isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
  34. isa_model/deployment/unified_multimodal_client.py +0 -341
  35. isa_model/inference/adapter/triton_adapter.py +0 -453
  36. isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
  37. isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
  38. isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
  39. isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
  40. isa_model/inference/backends/__init__.py +0 -53
  41. isa_model/inference/backends/base_backend_client.py +0 -26
  42. isa_model/inference/backends/container_services.py +0 -104
  43. isa_model/inference/backends/local_services.py +0 -72
  44. isa_model/inference/backends/openai_client.py +0 -130
  45. isa_model/inference/backends/replicate_client.py +0 -197
  46. isa_model/inference/backends/third_party_services.py +0 -239
  47. isa_model/inference/backends/triton_client.py +0 -97
  48. isa_model/inference/client_sdk/client.py +0 -134
  49. isa_model/inference/client_sdk/client_data_std.py +0 -34
  50. isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
  51. isa_model/inference/client_sdk/exceptions.py +0 -0
  52. isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
  53. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
  54. isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
  55. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
  56. isa_model/inference/providers/vllm_provider.py +0 -0
  57. isa_model/inference/providers/yyds_provider.py +0 -83
  58. isa_model/inference/services/audio/fish_speech/handler.py +0 -215
  59. isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
  60. isa_model/inference/services/audio/triton_speech_service.py +0 -138
  61. isa_model/inference/services/audio/whisper_service.py +0 -186
  62. isa_model/inference/services/base_tts_service.py +0 -66
  63. isa_model/inference/services/embedding/bge_service.py +0 -183
  64. isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
  65. isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
  66. isa_model/inference/services/llm/gemma_service.py +0 -143
  67. isa_model/inference/services/llm/llama_service.py +0 -143
  68. isa_model/inference/services/llm/replicate_llm_service.py +0 -179
  69. isa_model/inference/services/llm/triton_llm_service.py +0 -230
  70. isa_model/inference/services/vision/replicate_vision_service.py +0 -241
  71. isa_model/inference/services/vision/triton_vision_service.py +0 -199
  72. isa_model-0.1.0.dist-info/METADATA +0 -116
  73. isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
  74. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/WHEEL +0 -0
  75. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/licenses/LICENSE +0 -0
  76. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/top_level.txt +0 -0
--- a/isa_model/inference/services/llm/gemma_service.py
+++ /dev/null
@@ -1,143 +0,0 @@
- import json
- import logging
- import asyncio
- from typing import Dict, List, Any, Optional, Union
-
- from isa_model.inference.services.base_service import BaseService
- from isa_model.inference.backends.triton_client import TritonClient
-
- logger = logging.getLogger(__name__)
-
-
- class GemmaService(BaseService):
-     """
-     Service for Gemma LLM using Triton Inference Server.
-     """
-
-     def __init__(self, triton_url: str = "localhost:8001", model_name: str = "gemma"):
-         """
-         Initialize the Gemma service.
-
-         Args:
-             triton_url: URL of the Triton Inference Server
-             model_name: Name of the model in Triton
-         """
-         super().__init__()
-         self.triton_url = triton_url
-         self.model_name = model_name
-         self.client = None
-
-         # Default generation config
-         self.default_config = {
-             "max_new_tokens": 512,
-             "temperature": 0.7,
-             "top_p": 0.9,
-             "top_k": 50,
-             "repetition_penalty": 1.1,
-             "do_sample": True
-         }
-
-         self.logger = logger
-
-     async def load(self) -> None:
-         """
-         Load the client connection to Triton.
-         """
-         if self.is_loaded():
-             return
-
-         try:
-             # Create Triton client
-             self.logger.info(f"Connecting to Triton server at {self.triton_url}")
-             self.client = TritonClient(self.triton_url)
-
-             # Check if model is ready
-             if not await self.client.is_model_ready(self.model_name):
-                 self.logger.error(f"Model {self.model_name} is not ready on Triton server")
-                 raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
-
-             self._loaded = True
-             self.logger.info(f"Connected to Triton for model {self.model_name}")
-
-         except Exception as e:
-             self.logger.error(f"Failed to connect to Triton: {str(e)}")
-             raise
-
-     async def unload(self) -> None:
-         """
-         Unload the client connection.
-         """
-         if not self.is_loaded():
-             return
-
-         self.client = None
-         self._loaded = False
-         self.logger.info("Triton client connection closed")
-
-     async def generate(self,
-                        prompt: str,
-                        system_prompt: Optional[str] = None,
-                        generation_config: Optional[Dict[str, Any]] = None) -> str:
-         """
-         Generate text from a prompt using Triton.
-
-         Args:
-             prompt: User prompt
-             system_prompt: System prompt to control model behavior
-             generation_config: Configuration for text generation
-
-         Returns:
-             Generated text
-         """
-         if not self.is_loaded():
-             await self.load()
-
-         # Get configuration
-         merged_config = self.default_config.copy()
-         if generation_config:
-             merged_config.update(generation_config)
-
-         try:
-             # Prepare inputs
-             inputs = {
-                 "prompt": [prompt],
-             }
-
-             # Add optional inputs
-             if system_prompt:
-                 inputs["system_prompt"] = [system_prompt]
-
-             if merged_config:
-                 inputs["generation_config"] = [json.dumps(merged_config)]
-
-             # Run inference
-             result = await self.client.infer(
-                 model_name=self.model_name,
-                 inputs=inputs,
-                 outputs=["text_output"]
-             )
-
-             # Extract generated text
-             generated_text = result["text_output"][0].decode('utf-8')
-
-             return generated_text
-
-         except Exception as e:
-             self.logger.error(f"Error during text generation: {str(e)}")
-             raise
-
-     def get_model_info(self) -> Dict[str, Any]:
-         """
-         Get information about the model.
-
-         Returns:
-             Dictionary containing model information
-         """
-         return {
-             "name": self.model_name,
-             "type": "llm",
-             "backend": "triton",
-             "url": self.triton_url,
-             "loaded": self.is_loaded(),
-             "config": self.default_config
-         }
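For reference, a minimal sketch of how this removed service was driven. The constructor, load(), generate(), and unload() calls come straight from the class above; the event-loop entry point and the prompt strings are illustrative assumptions.

import asyncio

from isa_model.inference.services.llm.gemma_service import GemmaService  # removed in 0.1.1

async def main():
    service = GemmaService(triton_url="localhost:8001", model_name="gemma")
    await service.load()  # generate() would also lazy-load if this were skipped
    text = await service.generate(
        prompt="Summarize Triton Inference Server in one sentence.",
        system_prompt="You are a concise assistant.",
        generation_config={"max_new_tokens": 128},  # merged over default_config
    )
    print(text)
    await service.unload()

asyncio.run(main())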
--- a/isa_model/inference/services/llm/llama_service.py
+++ /dev/null
@@ -1,143 +0,0 @@
- import json
- import logging
- import asyncio
- from typing import Dict, List, Any, Optional, Union
-
- from isa_model.inference.services.base_service import BaseService
- from isa_model.inference.backends.triton_client import TritonClient
-
- logger = logging.getLogger(__name__)
-
-
- class LlamaService(BaseService):
-     """
-     Service for Llama LLM using Triton Inference Server.
-     """
-
-     def __init__(self, triton_url: str = "localhost:8001", model_name: str = "llama"):
-         """
-         Initialize the Llama service.
-
-         Args:
-             triton_url: URL of the Triton Inference Server
-             model_name: Name of the model in Triton
-         """
-         super().__init__()
-         self.triton_url = triton_url
-         self.model_name = model_name
-         self.client = None
-
-         # Default generation config
-         self.default_config = {
-             "max_new_tokens": 512,
-             "temperature": 0.7,
-             "top_p": 0.9,
-             "top_k": 50,
-             "repetition_penalty": 1.1,
-             "do_sample": True
-         }
-
-         self.logger = logger
-
-     async def load(self) -> None:
-         """
-         Load the client connection to Triton.
-         """
-         if self.is_loaded():
-             return
-
-         try:
-             # Create Triton client
-             self.logger.info(f"Connecting to Triton server at {self.triton_url}")
-             self.client = TritonClient(self.triton_url)
-
-             # Check if model is ready
-             if not await self.client.is_model_ready(self.model_name):
-                 self.logger.error(f"Model {self.model_name} is not ready on Triton server")
-                 raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
-
-             self._loaded = True
-             self.logger.info(f"Connected to Triton for model {self.model_name}")
-
-         except Exception as e:
-             self.logger.error(f"Failed to connect to Triton: {str(e)}")
-             raise
-
-     async def unload(self) -> None:
-         """
-         Unload the client connection.
-         """
-         if not self.is_loaded():
-             return
-
-         self.client = None
-         self._loaded = False
-         self.logger.info("Triton client connection closed")
-
-     async def generate(self,
-                        prompt: str,
-                        system_prompt: Optional[str] = None,
-                        generation_config: Optional[Dict[str, Any]] = None) -> str:
-         """
-         Generate text from a prompt using Triton.
-
-         Args:
-             prompt: User prompt
-             system_prompt: System prompt to control model behavior
-             generation_config: Configuration for text generation
-
-         Returns:
-             Generated text
-         """
-         if not self.is_loaded():
-             await self.load()
-
-         # Get configuration
-         merged_config = self.default_config.copy()
-         if generation_config:
-             merged_config.update(generation_config)
-
-         try:
-             # Prepare inputs
-             inputs = {
-                 "prompt": [prompt],
-             }
-
-             # Add optional inputs
-             if system_prompt:
-                 inputs["system_prompt"] = [system_prompt]
-
-             if merged_config:
-                 inputs["generation_config"] = [json.dumps(merged_config)]
-
-             # Run inference
-             result = await self.client.infer(
-                 model_name=self.model_name,
-                 inputs=inputs,
-                 outputs=["text_output"]
-             )
-
-             # Extract generated text
-             generated_text = result["text_output"][0].decode('utf-8')
-
-             return generated_text
-
-         except Exception as e:
-             self.logger.error(f"Error during text generation: {str(e)}")
-             raise
-
-     def get_model_info(self) -> Dict[str, Any]:
-         """
-         Get information about the model.
-
-         Returns:
-             Dictionary containing model information
-         """
-         return {
-             "name": self.model_name,
-             "type": "llm",
-             "backend": "triton",
-             "url": self.triton_url,
-             "loaded": self.is_loaded(),
-             "config": self.default_config
-         }
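Note: apart from the class name, docstrings, and the "llama" defaults, this removed llama_service.py is line-for-line identical to the gemma_service.py hunk above, so the same usage sketch applies.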
--- a/isa_model/inference/services/llm/replicate_llm_service.py
+++ /dev/null
@@ -1,179 +0,0 @@
- import logging
- from typing import Dict, Any, List, Union, AsyncGenerator, Optional
- from isa_model.inference.services.base_service import BaseLLMService
- from isa_model.inference.providers.base_provider import BaseProvider
- from isa_model.inference.backends.replicate_client import ReplicateBackendClient
-
- logger = logging.getLogger(__name__)
-
- class ReplicateLLMService(BaseLLMService):
-     """Replicate LLM service implementation"""
-
-     def __init__(self, provider: 'BaseProvider', model_name: str = "meta/llama-3-8b-instruct", backend: Optional[ReplicateBackendClient] = None):
-         super().__init__(provider, model_name)
-
-         # Use provided backend or create new one
-         if backend:
-             self.backend = backend
-         else:
-             api_token = self.config.get("api_token", "")
-             self.backend = ReplicateBackendClient(api_token)
-
-         # Parse model name for Replicate format (owner/model)
-         if "/" not in model_name:
-             logger.warning(f"Model name {model_name} is not in Replicate format (owner/model). Using as-is.")
-
-         # Store version separately if provided
-         self.model_version = None
-         if ":" in model_name:
-             self.model_name, self.model_version = model_name.split(":", 1)
-
-         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
-         logger.info(f"Initialized ReplicateLLMService with model {model_name}")
-
-     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
-         """Universal invocation method"""
-         if isinstance(prompt, str):
-             return await self.acompletion(prompt)
-         elif isinstance(prompt, list):
-             return await self.achat(prompt)
-         else:
-             raise ValueError("Prompt must be string or list of messages")
-
-     async def achat(self, messages: List[Dict[str, str]]):
-         """Chat completion method"""
-         try:
-             temperature = self.config.get("temperature", 0.7)
-             max_tokens = self.config.get("max_tokens", 1024)
-
-             # Convert to Replicate format
-             prompt = self._convert_messages_to_prompt(messages)
-
-             # Prepare input data
-             input_data = {
-                 "prompt": prompt,
-                 "temperature": temperature,
-                 "max_new_tokens": max_tokens,
-                 "system_prompt": self._extract_system_prompt(messages)
-             }
-
-             # Call Replicate API
-             prediction = await self.backend.create_prediction(
-                 self.model_name,
-                 self.model_version,
-                 input_data
-             )
-
-             # Get output - could be a list of strings or a single string
-             output = prediction.get("output", "")
-             if isinstance(output, list):
-                 output = "".join(output)
-
-             # Approximate token usage - Replicate doesn't provide token counts
-             approx_prompt_tokens = len(prompt) // 4  # Very rough approximation
-             approx_completion_tokens = len(output) // 4
-             self.last_token_usage = {
-                 "prompt_tokens": approx_prompt_tokens,
-                 "completion_tokens": approx_completion_tokens,
-                 "total_tokens": approx_prompt_tokens + approx_completion_tokens
-             }
-
-             return output
-
-         except Exception as e:
-             logger.error(f"Error in chat completion: {e}")
-             raise
-
-     def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
-         """Convert chat messages to Replicate prompt format"""
-         prompt = ""
-
-         for msg in messages:
-             role = msg.get("role", "user")
-             content = msg.get("content", "")
-
-             # Skip system prompts - handled separately
-             if role == "system":
-                 continue
-
-             if role == "user":
-                 prompt += f"Human: {content}\n\n"
-             elif role == "assistant":
-                 prompt += f"Assistant: {content}\n\n"
-             else:
-                 # Default to user for unknown roles
-                 prompt += f"Human: {content}\n\n"
-
-         # Add final assistant prefix for the model to continue
-         prompt += "Assistant: "
-
-         return prompt
-
-     def _extract_system_prompt(self, messages: List[Dict[str, str]]) -> str:
-         """Extract system prompt from messages"""
-         for msg in messages:
-             if msg.get("role") == "system":
-                 return msg.get("content", "")
-         return ""
-
-     async def acompletion(self, prompt: str):
-         """Text completion method"""
-         try:
-             # For simple completion, use chat format with a single user message
-             messages = [{"role": "user", "content": prompt}]
-             return await self.achat(messages)
-         except Exception as e:
-             logger.error(f"Error in text completion: {e}")
-             raise
-
-     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
-         """Generate multiple completions"""
-         # Replicate doesn't support multiple outputs in one call,
-         # so we make multiple calls
-         results = []
-         for _ in range(n):
-             result = await self.achat(messages)
-             results.append(result)
-         return results
-
-     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
-         """Stream chat responses"""
-         try:
-             temperature = self.config.get("temperature", 0.7)
-             max_tokens = self.config.get("max_tokens", 1024)
-
-             # Convert to Replicate format
-             prompt = self._convert_messages_to_prompt(messages)
-
-             # Prepare input data
-             input_data = {
-                 "prompt": prompt,
-                 "temperature": temperature,
-                 "max_new_tokens": max_tokens,
-                 "system_prompt": self._extract_system_prompt(messages),
-                 "stream": True
-             }
-
-             # Call Replicate API with streaming
-             async for chunk in self.backend.stream_prediction(
-                 self.model_name,
-                 self.model_version,
-                 input_data
-             ):
-                 yield chunk
-
-         except Exception as e:
-             logger.error(f"Error in stream chat: {e}")
-             raise
-
-     def get_token_usage(self):
-         """Get total token usage statistics"""
-         return self.last_token_usage
-
-     def get_last_token_usage(self) -> Dict[str, int]:
-         """Get token usage from last request"""
-         return self.last_token_usage
-
-     async def close(self):
-         """Close the backend client"""
-         await self.backend.close()
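To make the prompt contract concrete, this is what the removed _convert_messages_to_prompt and _extract_system_prompt produce for a typical message list (a sketch; the message contents are invented for illustration):

messages = [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is Replicate?"},
]

# _convert_messages_to_prompt(messages) skips the system message and returns:
#   "Human: Hi\n\nAssistant: Hello!\n\nHuman: What is Replicate?\n\nAssistant: "
# _extract_system_prompt(messages) returns "Be terse.", which achat() sends
# separately as the "system_prompt" field of input_data.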
--- a/isa_model/inference/services/llm/triton_llm_service.py
+++ /dev/null
@@ -1,230 +0,0 @@
- import json
- import logging
- import asyncio
- from typing import Dict, List, Any, AsyncGenerator, Optional, Union
-
- import numpy as np
-
- from isa_model.inference.services.base_service import BaseLLMService
- from isa_model.inference.providers.triton_provider import TritonProvider
-
- logger = logging.getLogger(__name__)
-
-
- class TritonLLMService(BaseLLMService):
-     """
-     LLM service that uses Triton Inference Server to run inference.
-     """
-
-     def __init__(self, provider: TritonProvider, model_name: str):
-         """
-         Initialize the Triton LLM service.
-
-         Args:
-             provider: The Triton provider
-             model_name: Name of the model in Triton (e.g., "Llama3-8B")
-         """
-         super().__init__(provider, model_name)
-         self.client = None
-         self.tokenizer = None
-         self.token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
-         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
-
-     async def _initialize_client(self):
-         """Initialize the Triton client"""
-         if self.client is None:
-             self.client = self.provider.create_client()
-
-             # Check if model is ready
-             if not self.provider.is_model_ready(self.model_name):
-                 logger.error(f"Model {self.model_name} is not ready on Triton server")
-                 raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
-
-             logger.info(f"Initialized Triton client for model: {self.model_name}")
-
-     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> str:
-         """
-         Universal invocation method.
-
-         Args:
-             prompt: Text prompt or chat messages
-
-         Returns:
-             Generated text
-         """
-         if isinstance(prompt, str):
-             return await self.acompletion(prompt)
-         elif isinstance(prompt, list) and all(isinstance(m, dict) for m in prompt):
-             return await self.achat(prompt)
-         else:
-             raise ValueError("Prompt must be either a string or a list of message dictionaries")
-
-     async def achat(self, messages: List[Dict[str, str]]) -> str:
-         """
-         Chat completion method.
-
-         Args:
-             messages: List of message dictionaries with 'role' and 'content'
-
-         Returns:
-             Generated chat response
-         """
-         # Format chat messages into a single prompt
-         formatted_prompt = self._format_chat_messages(messages)
-         return await self.acompletion(formatted_prompt)
-
-     async def acompletion(self, prompt: str) -> str:
-         """
-         Text completion method.
-
-         Args:
-             prompt: Text prompt
-
-         Returns:
-             Generated text completion
-         """
-         await self._initialize_client()
-
-         try:
-             import tritonclient.http as httpclient
-
-             # Create input tensors
-             input_text = np.array([prompt], dtype=np.object_)
-             inputs = [httpclient.InferInput("TEXT", input_text.shape, "BYTES")]
-             inputs[0].set_data_from_numpy(input_text)
-
-             # Default parameters
-             generation_params = {
-                 "max_new_tokens": self.config.get("max_new_tokens", 512),
-                 "temperature": self.config.get("temperature", 0.7),
-                 "top_p": self.config.get("top_p", 0.9),
-                 "do_sample": self.config.get("temperature", 0.7) > 0
-             }
-
-             # Add parameters as input tensor
-             param_json = json.dumps(generation_params)
-             param_data = np.array([param_json], dtype=np.object_)
-             param_input = httpclient.InferInput("PARAMETERS", param_data.shape, "BYTES")
-             param_input.set_data_from_numpy(param_data)
-             inputs.append(param_input)
-
-             # Create output tensor
-             outputs = [httpclient.InferRequestedOutput("TEXT")]
-
-             # Send the request
-             response = await asyncio.to_thread(
-                 self.client.infer,
-                 self.model_name,
-                 inputs,
-                 outputs=outputs
-             )
-
-             # Process the response
-             output = response.as_numpy("TEXT")
-             response_text = output[0].decode('utf-8')
-
-             # Update token usage (estimated since we don't have actual token counts)
-             prompt_tokens = len(prompt) // 4  # Rough estimate
-             completion_tokens = len(response_text) // 4  # Rough estimate
-             total_tokens = prompt_tokens + completion_tokens
-
-             self.last_token_usage = {
-                 "prompt_tokens": prompt_tokens,
-                 "completion_tokens": completion_tokens,
-                 "total_tokens": total_tokens
-             }
-
-             # Update total token usage
-             self.token_usage["prompt_tokens"] += prompt_tokens
-             self.token_usage["completion_tokens"] += completion_tokens
-             self.token_usage["total_tokens"] += total_tokens
-
-             return response_text
-
-         except Exception as e:
-             logger.error(f"Error during Triton inference: {str(e)}")
-             raise
-
-     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
-         """
-         Generate multiple completions.
-
-         Args:
-             messages: List of message dictionaries
-             n: Number of completions to generate
-
-         Returns:
-             List of generated completions
-         """
-         results = []
-         for _ in range(n):
-             result = await self.achat(messages)
-             results.append(result)
-         return results
-
-     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
-         """
-         Stream chat responses.
-
-         Args:
-             messages: List of message dictionaries
-
-         Yields:
-             Generated text chunks
-         """
-         # For Triton, we don't have true streaming, so we generate the full response
-         # and then simulate streaming
-         full_response = await self.achat(messages)
-
-         # Simulate streaming by yielding words
-         words = full_response.split()
-         for i in range(len(words)):
-             chunk = ' '.join(words[:i+1])
-             yield chunk
-             await asyncio.sleep(0.05)  # Small delay to simulate streaming
-
-     def get_token_usage(self) -> Dict[str, int]:
-         """
-         Get total token usage statistics.
-
-         Returns:
-             Dictionary with token usage statistics
-         """
-         return self.token_usage
-
-     def get_last_token_usage(self) -> Dict[str, int]:
-         """
-         Get token usage from last request.
-
-         Returns:
-             Dictionary with token usage statistics from last request
-         """
-         return self.last_token_usage
-
-     def _format_chat_messages(self, messages: List[Dict[str, str]]) -> str:
-         """
-         Format chat messages into a single prompt for models that don't support chat natively.
-
-         Args:
-             messages: List of message dictionaries
-
-         Returns:
-             Formatted prompt
-         """
-         formatted_prompt = ""
-
-         for message in messages:
-             role = message.get("role", "user").lower()
-             content = message.get("content", "")
-
-             if role == "system":
-                 formatted_prompt += f"System: {content}\n\n"
-             elif role == "user":
-                 formatted_prompt += f"User: {content}\n\n"
-             elif role == "assistant":
-                 formatted_prompt += f"Assistant: {content}\n\n"
-             else:
-                 formatted_prompt += f"{role.capitalize()}: {content}\n\n"
-
-         formatted_prompt += "Assistant: "
-         return formatted_prompt
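Two behaviors of this removed service are easy to miss: token counts are rough character-based estimates (len(text) // 4), and astream_chat() yields cumulative prefixes of the response rather than per-token deltas. A consumer therefore has to diff successive chunks itself; a minimal sketch (the consume helper is an illustrative assumption, not part of the package):

async def consume(service, messages):
    previous = ""
    async for chunk in service.astream_chat(messages):
        # Each chunk is the whole response so far ("Hello", "Hello world", ...),
        # so emit only the newly appended suffix.
        print(chunk[len(previous):], end="", flush=True)
        previous = chunk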