isa-model 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/core/model_manager.py +69 -4
- isa_model/inference/ai_factory.py +335 -46
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +48 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/base_llm_service.py +88 -192
- isa_model/inference/services/llm/llm_adapter.py +459 -0
- isa_model/inference/services/llm/ollama_llm_service.py +111 -185
- isa_model/inference/services/llm/openai_llm_service.py +115 -360
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +11 -3
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +233 -205
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/METADATA +1 -1
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/RECORD +26 -21
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/WHEEL +0 -0
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,223 @@
+import logging
+import asyncio
+from typing import List, Dict, Any, Optional
+from openai import AsyncOpenAI
+from tenacity import retry, stop_after_attempt, wait_exponential
+
+from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.billing_tracker import ServiceType
+
+logger = logging.getLogger(__name__)
+
+class OpenAIEmbedService(BaseEmbedService):
+    """
+    OpenAI embedding service using text-embedding-3-small as default.
+    Provides high-quality embeddings for production use.
+    """
+
+    def __init__(self, provider: 'BaseProvider', model_name: str = "text-embedding-3-small"):
+        super().__init__(provider, model_name)
+
+        # Get full configuration from provider (including sensitive data)
+        provider_config = provider.get_full_config()
+
+        # Initialize AsyncOpenAI client with provider configuration
+        try:
+            if not provider_config.get("api_key"):
+                raise ValueError("OpenAI API key not found in provider configuration")
+
+            self.client = AsyncOpenAI(
+                api_key=provider_config["api_key"],
+                base_url=provider_config.get("base_url", "https://api.openai.com/v1"),
+                organization=provider_config.get("organization")
+            )
+
+            logger.info(f"Initialized OpenAIEmbedService with model '{self.model_name}'")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize OpenAI client: {e}")
+            raise ValueError(f"Failed to initialize OpenAI client. Check your API key configuration: {e}") from e
+
+        # Model-specific configurations
+        self.dimensions = provider_config.get('dimensions', None)  # Optional dimension reduction
+        self.encoding_format = provider_config.get('encoding_format', 'float')
+
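The constructor reads everything it needs from provider.get_full_config(); only api_key is mandatory, the rest fall back to defaults. A sketch of the keys it recognizes (values here are illustrative, not from the package):

    provider_config = {
        "api_key": "sk-...",                       # required; __init__ raises ValueError without it
        "base_url": "https://api.openai.com/v1",   # optional override for proxies or compatible endpoints
        "organization": None,                      # optional OpenAI organization id
        "dimensions": 256,                         # optional output size, text-embedding-3-* models only
        "encoding_format": "float",                # default encoding
    }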
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        reraise=True
+    )
+    async def create_text_embedding(self, text: str) -> List[float]:
+        """Create embedding for single text"""
+        try:
+            kwargs = {
+                "model": self.model_name,
+                "input": text,
+                "encoding_format": self.encoding_format
+            }
+
+            # Add dimensions parameter if specified (for text-embedding-3-small/large)
+            if self.dimensions and "text-embedding-3" in self.model_name:
+                kwargs["dimensions"] = self.dimensions
+
+            response = await self.client.embeddings.create(**kwargs)
+
+            # Track usage for billing
+            usage = getattr(response, 'usage', None)
+            if usage:
+                total_tokens = getattr(usage, 'total_tokens', 0)
+                self._track_usage(
+                    service_type=ServiceType.EMBEDDING,
+                    operation="create_text_embedding",
+                    input_tokens=total_tokens,
+                    output_tokens=0,
+                    metadata={
+                        "model": self.model_name,
+                        "dimensions": self.dimensions,
+                        "text_length": len(text)
+                    }
+                )
+
+            return response.data[0].embedding
+
+        except Exception as e:
+            logger.error(f"Error creating text embedding: {e}")
+            raise
+
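Both embedding methods share this tenacity policy: at most three attempts, exponential backoff clamped to a 4-10 second window, and reraise=True so the caller sees the original exception rather than a RetryError. A minimal self-contained sketch of the same policy:

    import asyncio
    from tenacity import retry, stop_after_attempt, wait_exponential

    calls = {"n": 0}

    @retry(
        stop=stop_after_attempt(3),                          # give up after the third attempt
        wait=wait_exponential(multiplier=1, min=4, max=10),  # backoff clamped between 4s and 10s
        reraise=True,                                        # surface the original error
    )
    async def flaky() -> str:
        calls["n"] += 1
        if calls["n"] < 3:
            raise ConnectionError("transient failure")
        return "ok"

    print(asyncio.run(flaky()))  # succeeds on the third attempt, after two backoff sleeps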
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        reraise=True
+    )
+    async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Create embeddings for multiple texts"""
+        if not texts:
+            return []
+
+        try:
+            kwargs = {
+                "model": self.model_name,
+                "input": texts,
+                "encoding_format": self.encoding_format
+            }
+
+            # Add dimensions parameter if specified
+            if self.dimensions and "text-embedding-3" in self.model_name:
+                kwargs["dimensions"] = self.dimensions
+
+            response = await self.client.embeddings.create(**kwargs)
+
+            # Track usage for billing
+            usage = getattr(response, 'usage', None)
+            if usage:
+                total_tokens = getattr(usage, 'total_tokens', 0)
+                self._track_usage(
+                    service_type=ServiceType.EMBEDDING,
+                    operation="create_text_embeddings",
+                    input_tokens=total_tokens,
+                    output_tokens=0,
+                    metadata={
+                        "model": self.model_name,
+                        "dimensions": self.dimensions,
+                        "batch_size": len(texts),
+                        "total_text_length": sum(len(t) for t in texts)
+                    }
+                )
+
+            return [data.embedding for data in response.data]
+
+        except Exception as e:
+            logger.error(f"Error creating text embeddings: {e}")
+            raise
+
+    async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+        """Create text chunks with embeddings"""
+        # Chunk size optimized for OpenAI models (roughly 512 tokens)
+        chunk_size = 400  # words
+        overlap = 50      # word overlap between chunks
+
+        words = text.split()
+        if not words:
+            return []
+
+        chunks = []
+        chunk_texts = []
+
+        for i in range(0, len(words), chunk_size - overlap):
+            chunk_words = words[i:i + chunk_size]
+            chunk_text = " ".join(chunk_words)
+            chunk_texts.append(chunk_text)
+
+            chunks.append({
+                "text": chunk_text,
+                "start_index": i,
+                "end_index": min(i + chunk_size, len(words)),
+                "metadata": metadata or {}
+            })
+
+        # Get embeddings for all chunks
+        embeddings = await self.create_text_embeddings(chunk_texts)
+
+        # Add embeddings to chunks
+        for chunk, embedding in zip(chunks, embeddings):
+            chunk["embedding"] = embedding
+
+        return chunks
+
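With chunk_size = 400 and overlap = 50, the loop advances in strides of 350 words, so consecutive chunks share a 50-word overlap. A quick sketch of the window indices this produces for a hypothetical 1000-word document:

    chunk_size, overlap = 400, 50          # words, as in create_chunks
    stride = chunk_size - overlap          # the loop advances 350 words per chunk
    n_words = 1000                         # hypothetical document length

    windows = [(i, min(i + chunk_size, n_words)) for i in range(0, n_words, stride)]
    print(windows)  # [(0, 400), (350, 750), (700, 1000)] -- each adjacent pair overlaps by 50 words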
+    async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
+        """Compute cosine similarity between two embeddings"""
+        import math
+
+        dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
+        norm1 = math.sqrt(sum(a * a for a in embedding1))
+        norm2 = math.sqrt(sum(b * b for b in embedding2))
+
+        if norm1 * norm2 == 0:
+            return 0.0
+
+        return dot_product / (norm1 * norm2)
+
+    async def find_similar_texts(
+        self,
+        query_embedding: List[float],
+        candidate_embeddings: List[List[float]],
+        top_k: int = 5
+    ) -> List[Dict[str, Any]]:
+        """Find most similar texts based on embeddings"""
+        similarities = []
+
+        for i, candidate in enumerate(candidate_embeddings):
+            similarity = await self.compute_similarity(query_embedding, candidate)
+            similarities.append({
+                "index": i,
+                "similarity": similarity
+            })
+
+        # Sort by similarity in descending order and return top_k
+        similarities.sort(key=lambda x: x["similarity"], reverse=True)
+        return similarities[:top_k]
+
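compute_similarity is plain-Python cosine similarity, which is adequate for the occasional pairwise comparisons find_similar_texts performs. A worked example under the same formula:

    import math

    v1, v2 = [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]
    dot = sum(a * b for a, b in zip(v1, v2))      # 1.0
    norm1 = math.sqrt(sum(a * a for a in v1))     # sqrt(2)
    norm2 = math.sqrt(sum(b * b for b in v2))     # sqrt(2)
    print(dot / (norm1 * norm2))                  # 0.5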
+    def get_embedding_dimension(self) -> int:
+        """Get the dimension of embeddings produced by this service"""
+        if self.dimensions:
+            return self.dimensions
+
+        # Default dimensions for OpenAI models
+        model_dimensions = {
+            "text-embedding-3-small": 1536,
+            "text-embedding-3-large": 3072,
+            "text-embedding-ada-002": 1536
+        }
+
+        return model_dimensions.get(self.model_name, 1536)
+
+    def get_max_input_length(self) -> int:
+        """Get maximum input text length supported"""
+        # OpenAI embedding models support up to 8192 tokens
+        return 8192
+
+    async def close(self):
+        """Cleanup resources"""
+        await self.client.close()
+        logger.info("OpenAIEmbedService client has been closed.")
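A minimal sketch of exercising the new service directly. The StubProvider below is a hypothetical test double, assuming the service needs nothing from the provider beyond get_full_config(); in practice you would obtain a configured provider (or the service itself) through the factory in ai_factory.py, also touched in this release:

    import asyncio
    from isa_model.inference.services.embedding.openai_embed_service import OpenAIEmbedService

    class StubProvider:
        """Hypothetical test double, not part of isa_model."""
        def get_full_config(self):
            return {"api_key": "sk-...", "dimensions": 256}

    async def main():
        service = OpenAIEmbedService(StubProvider(), "text-embedding-3-small")
        vector = await service.create_text_embedding("hello world")
        print(len(vector))  # 256 with the dimension reduction configured above
        await service.close()

    asyncio.run(main())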
@@ -1,244 +1,140 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Any, List, Union, Optional, AsyncGenerator, TypeVar, Callable
+from typing import Dict, Any, List, Union, Optional, AsyncGenerator, Callable
 from isa_model.inference.services.base_service import BaseService
-
-T = TypeVar('T')  # Generic type for responses
+from isa_model.inference.services.llm.llm_adapter import AdapterManager
 
 class BaseLLMService(BaseService):
-    """Base class for Large Language Model services"""
+    """Base class for Large Language Model services with unified invoke interface"""
 
-
-
-
-
+    def __init__(self, provider, model_name: str):
+        super().__init__(provider, model_name)
+        self._bound_tools: List[Any] = []  # store the original tool objects
+        self._tool_mappings: Dict[str, tuple] = {}  # map tool name -> (tool, adapter)
 
-
-
-
-        Returns:
-            Model response in the appropriate format
-        """
-        pass
-
-    @abstractmethod
-    async def achat(self, messages: List[Dict[str, str]]) -> T:
-        """
-        Chat completion method using message format
+        # Initialize the adapter manager
+        self.adapter_manager = AdapterManager()
 
-
-
-        Example: [{"role": "user", "content": "Hello"}]
-
-        Returns:
-            Chat completion response
-        """
-        pass
+        # Get streaming config from provider config
+        self.streaming = self.config.get("streaming", False)
 
-
-    async def acompletion(self, prompt: str) -> T:
+    def bind_tools(self, tools: List[Any], **kwargs) -> 'BaseLLMService':
         """
-
+        Bind tools to this LLM service for function calling
 
         Args:
-
+            tools: List of tools to bind (functions, LangChain tools, etc.)
+            **kwargs: Additional tool binding parameters
 
         Returns:
-
+            Self for method chaining
         """
-
+        self._bound_tools = tools
+        return self
 
-
-
-
-
+    async def _prepare_tools_for_request(self) -> List[Dict[str, Any]]:
+        """Prepare tools for the request"""
+        if not self._bound_tools:
+            return []
 
-
-
-            n: Number of completions to generate
-
-        Returns:
-            List of completion responses
-        """
-        pass
+        schemas, self._tool_mappings = await self.adapter_manager.convert_tools_to_schemas(self._bound_tools)
+        return schemas
 
-
-
-
-
-
-
-
-
-
-
-
-        pass
+    def _prepare_messages(self, input_data: Union[str, List[Dict[str, str]], Any]) -> List[Dict[str, str]]:
+        """Convert the message format using the adapter manager"""
+        return self.adapter_manager.convert_messages(input_data)
+
+    def _format_response(self, response: str, original_input: Any) -> Union[str, Any]:
+        """Format the response using the adapter manager"""
+        return self.adapter_manager.format_response(response, original_input)
+
+    async def _execute_tool_call(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
+        """Execute a tool call using the adapter manager"""
+        return await self.adapter_manager.execute_tool(tool_name, arguments, self._tool_mappings)
 
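Because bind_tools now stores the raw tool objects and defers schema conversion to the AdapterManager at request time, plain Python functions (or LangChain tools) can be bound directly. A sketch, where the llm argument stands for any concrete BaseLLMService subclass instance from this release:

    from isa_model.inference.services.llm.base_llm_service import BaseLLMService

    def get_weather(city: str) -> str:
        """Return the current weather for a city (stub)."""
        return f"Sunny in {city}"

    async def demo(llm: BaseLLMService) -> None:
        # bind_tools returns self, so binding chains straight into ainvoke
        reply = await llm.bind_tools([get_weather]).ainvoke("What's the weather in Paris?")
        print(reply)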
     @abstractmethod
-    async def
+    async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
         """
-
+        Universal async invocation method that handles different input types
 
         Args:
-
+            input_data: Can be:
+                - str: Simple text prompt
+                - list: Message history like [{"role": "user", "content": "hello"}]
+                - Any: LangChain message objects or other formats
 
-
-
+        Returns:
+            Model response (string for simple cases, object for complex cases)
         """
         pass
 
-    def
+    def invoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
         """
-
+        Synchronous wrapper for ainvoke
 
         Args:
-
-            - Dictionary with tool schema
-            - Callable functions (will be converted to schema)
-            **kwargs: Additional tool binding parameters
+            input_data: Same as ainvoke
 
         Returns:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        return bound_service
-
-    def _create_bound_copy(self) -> 'BaseLLMService':
-        """Create a copy of this service for tool binding"""
-        # Default implementation - subclasses should override if needed
-        bound_service = self.__class__(self.provider, self.model_name)
-        bound_service.config = self.config.copy()
-        return bound_service
-
-    def _convert_tools_to_schema(self, tools: List[Union[Dict[str, Any], Callable]]) -> List[Dict[str, Any]]:
-        """Convert tools to OpenAI function calling schema"""
-        schemas = []
-        for tool in tools:
-            if callable(tool):
-                schema = self._function_to_schema(tool)
-            elif isinstance(tool, dict):
-                schema = tool
-            else:
-                raise ValueError(f"Tool must be callable or dict, got {type(tool)}")
-            schemas.append(schema)
-        return schemas
-
-    def _function_to_schema(self, func: Callable) -> Dict[str, Any]:
-        """Convert a Python function to OpenAI function schema"""
-        import inspect
-        import json
-        from typing import get_type_hints
-
-        sig = inspect.signature(func)
-        type_hints = get_type_hints(func)
-
-        properties = {}
-        required = []
-
-        for param_name, param in sig.parameters.items():
-            param_type = type_hints.get(param_name, str)
-
-            # Convert Python types to JSON schema types
-            if param_type == str:
-                prop_type = "string"
-            elif param_type == int:
-                prop_type = "integer"
-            elif param_type == float:
-                prop_type = "number"
-            elif param_type == bool:
-                prop_type = "boolean"
-            elif param_type == list:
-                prop_type = "array"
-            elif param_type == dict:
-                prop_type = "object"
-            else:
-                prop_type = "string"  # Default fallback
-
-            properties[param_name] = {"type": prop_type}
-
-            # Add parameter to required if it has no default value
-            if param.default == inspect.Parameter.empty:
-                required.append(param_name)
-
-        return {
-            "type": "function",
-            "function": {
-                "name": func.__name__,
-                "description": func.__doc__ or f"Function {func.__name__}",
-                "parameters": {
-                    "type": "object",
-                    "properties": properties,
-                    "required": required
-                }
-            }
-        }
+            Model response
+        """
+        import asyncio
+        try:
+            # Try to get current event loop
+            loop = asyncio.get_running_loop()
+            # If we're in an event loop, create a new thread
+            import concurrent.futures
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                future = executor.submit(asyncio.run, self.ainvoke(input_data))
+                return future.result()
+        except RuntimeError:
+            # No event loop running, create a new one
+            return asyncio.run(self.ainvoke(input_data))
 
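The sync wrapper probes for a running loop: inside one, it runs ainvoke in a fresh loop on a worker thread (blocking the caller until it finishes); otherwise it simply owns a new loop via asyncio.run. The same dispatch pattern in isolation:

    import asyncio
    import concurrent.futures

    async def work() -> str:
        return "done"

    def run_sync(make_coro):
        try:
            asyncio.get_running_loop()          # raises RuntimeError if no loop is running
        except RuntimeError:
            return asyncio.run(make_coro())     # plain synchronous caller: own a new loop
        # Already inside a loop: run the coroutine in a fresh loop on a worker thread
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return executor.submit(asyncio.run, make_coro()).result()

    print(run_sync(work))  # "done"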
     def _has_bound_tools(self) -> bool:
         """Check if this service has bound tools"""
-        return
+        return bool(self._bound_tools)
 
-    def _get_bound_tools(self) -> List[
-        """Get the bound tools
-        return
-
-    def _execute_tool_call(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
-        """Execute a tool call by name with arguments"""
-        # This is a placeholder - subclasses should implement actual tool execution
-        raise NotImplementedError("Tool execution not implemented in base class")
+    def _get_bound_tools(self) -> List[Any]:
+        """Get the bound tools"""
+        return self._bound_tools
 
     @abstractmethod
     def get_token_usage(self) -> Dict[str, Any]:
-        """
-        Get cumulative token usage statistics for this service instance
-
-        Returns:
-            Dict containing token usage information:
-            - total_tokens: Total tokens used
-            - prompt_tokens: Tokens used for prompts
-            - completion_tokens: Tokens used for completions
-            - requests_count: Number of requests made
-        """
+        """Get cumulative token usage statistics"""
         pass
 
     @abstractmethod
     def get_last_token_usage(self) -> Dict[str, int]:
-        """
-        Get token usage from the last request
-
-        Returns:
-            Dict containing last request token usage:
-            - prompt_tokens: Tokens in last prompt
-            - completion_tokens: Tokens in last completion
-            - total_tokens: Total tokens in last request
-        """
+        """Get token usage from the last request"""
         pass
 
     @abstractmethod
     def get_model_info(self) -> Dict[str, Any]:
-        """
-        Get information about the current model
-
-        Returns:
-            Dict containing model information:
-            - name: Model name
-            - max_tokens: Maximum context length
-            - supports_streaming: Whether streaming is supported
-            - supports_functions: Whether function calling is supported
-        """
+        """Get information about the current model"""
        pass
 
     @abstractmethod
     async def close(self):
         """Cleanup resources and close connections"""
         pass
+
+    def get_last_usage_with_cost(self) -> Dict[str, Any]:
+        """Get last request usage with cost information"""
+        usage = self.get_last_token_usage()
+
+        # Calculate cost using provider
+        if hasattr(self.provider, 'calculate_cost'):
+            cost = getattr(self.provider, 'calculate_cost')(
+                self.model_name,
+                usage["prompt_tokens"],
+                usage["completion_tokens"]
+            )
+        else:
+            cost = 0.0
+
+        return {
+            **usage,
+            "cost_usd": cost,
+            "model": self.model_name,
+            "provider": getattr(self.provider, 'name', 'unknown')
+        }
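For a provider that implements calculate_cost, the merged dict combines the last request's token counts with the computed cost. Assuming a hypothetical request of 12 prompt and 30 completion tokens, the result looks roughly like:

    {
        "prompt_tokens": 12,
        "completion_tokens": 30,
        "total_tokens": 42,
        "cost_usd": 0.000051,   # illustrative value; depends on the provider's pricing table
        "model": "gpt-4o-mini", # illustrative model name
        "provider": "openai"
    }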