indoxrouter 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- indoxRouter/__init__.py +83 -0
- indoxRouter/client.py +564 -218
- indoxRouter/client_resourses/__init__.py +20 -0
- indoxRouter/client_resourses/base.py +67 -0
- indoxRouter/client_resourses/chat.py +144 -0
- indoxRouter/client_resourses/completion.py +138 -0
- indoxRouter/client_resourses/embedding.py +83 -0
- indoxRouter/client_resourses/image.py +116 -0
- indoxRouter/client_resourses/models.py +114 -0
- indoxRouter/config.py +151 -0
- indoxRouter/constants/__init__.py +81 -0
- indoxRouter/exceptions/__init__.py +70 -0
- indoxRouter/models/__init__.py +111 -0
- indoxRouter/providers/__init__.py +50 -50
- indoxRouter/providers/ai21labs.json +128 -0
- indoxRouter/providers/base_provider.py +62 -30
- indoxRouter/providers/claude.json +164 -0
- indoxRouter/providers/cohere.json +116 -0
- indoxRouter/providers/databricks.json +110 -0
- indoxRouter/providers/deepseek.json +110 -0
- indoxRouter/providers/google.json +128 -0
- indoxRouter/providers/meta.json +128 -0
- indoxRouter/providers/mistral.json +146 -0
- indoxRouter/providers/nvidia.json +110 -0
- indoxRouter/providers/openai.json +308 -0
- indoxRouter/providers/openai.py +471 -72
- indoxRouter/providers/qwen.json +110 -0
- indoxRouter/utils/__init__.py +240 -0
- indoxrouter-0.1.2.dist-info/LICENSE +21 -0
- indoxrouter-0.1.2.dist-info/METADATA +259 -0
- indoxrouter-0.1.2.dist-info/RECORD +33 -0
- indoxRouter/api_endpoints.py +0 -336
- indoxRouter/client_package.py +0 -138
- indoxRouter/init_db.py +0 -71
- indoxRouter/main.py +0 -711
- indoxRouter/migrations/__init__.py +0 -1
- indoxRouter/migrations/env.py +0 -98
- indoxRouter/migrations/versions/__init__.py +0 -1
- indoxRouter/migrations/versions/initial_schema.py +0 -84
- indoxRouter/providers/ai21.py +0 -268
- indoxRouter/providers/claude.py +0 -177
- indoxRouter/providers/cohere.py +0 -171
- indoxRouter/providers/databricks.py +0 -166
- indoxRouter/providers/deepseek.py +0 -166
- indoxRouter/providers/google.py +0 -216
- indoxRouter/providers/llama.py +0 -164
- indoxRouter/providers/meta.py +0 -227
- indoxRouter/providers/mistral.py +0 -182
- indoxRouter/providers/nvidia.py +0 -164
- indoxrouter-0.1.0.dist-info/METADATA +0 -179
- indoxrouter-0.1.0.dist-info/RECORD +0 -27
- {indoxrouter-0.1.0.dist-info → indoxrouter-0.1.2.dist-info}/WHEEL +0 -0
- {indoxrouter-0.1.0.dist-info → indoxrouter-0.1.2.dist-info}/top_level.txt +0 -0
indoxRouter/providers/openai.py
CHANGED
@@ -1,122 +1,521 @@
New file contents in 0.1.2:

```python
"""
OpenAI provider for indoxRouter.
"""

import os
from typing import Dict, List, Any, Optional, Union

import openai
from openai import OpenAI
from datetime import datetime
from .base_provider import BaseProvider
from ..exceptions import AuthenticationError, RequestError, RateLimitError
from ..utils import calculate_cost, get_model_info
from ..models import ChatMessage


class Provider(BaseProvider):
    """OpenAI provider implementation."""

    def __init__(self, api_key: str, model_name: str):
        """
        Initialize the OpenAI provider.

        Args:
            api_key: The API key for OpenAI.
            model_name: The name of the model to use.
        """
        super().__init__(api_key, model_name)
        self.client = OpenAI(api_key=api_key)
        self.model_info = get_model_info("openai", model_name)

    def chat(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """
        Send a chat request to OpenAI.

        Args:
            messages: A list of message dictionaries with 'role' and 'content' keys.
            **kwargs: Additional parameters to pass to the OpenAI API.

        Returns:
            A dictionary containing the response from OpenAI.
            If stream=True and return_generator=True, returns a generator that yields chunks of the response.

        Raises:
            AuthenticationError: If the API key is invalid.
            RequestError: If the request fails.
            RateLimitError: If the rate limit is exceeded.
        """
        try:
            # Check if streaming is requested
            stream = kwargs.pop("stream", False)
            # Check if we should return a generator
            return_generator = kwargs.pop("return_generator", False)

            # If streaming is requested, we need to handle it differently
            if stream:
                # Remove stream from kwargs to avoid passing it twice
                openai_messages = []
                for msg in messages:
                    if isinstance(msg, ChatMessage):
                        openai_messages.append(
                            {"role": msg.role, "content": msg.content}
                        )
                    else:
                        openai_messages.append(msg)

                # Create the streaming response
                stream_response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=openai_messages,
                    stream=True,
                    **kwargs,
                )

                # If return_generator is True, return a generator that yields chunks
                if return_generator:
                    # Create a streaming generator with usage tracking
                    return StreamingGenerator(
                        stream_response=stream_response,
                        model_name=self.model_name,
                        messages=messages,
                    )

                # Otherwise, collect the full response content from the stream
                content = ""
                for chunk in stream_response:
                    if hasattr(chunk, "choices") and len(chunk.choices) > 0:
                        delta = chunk.choices[0].delta
                        if hasattr(delta, "content") and delta.content is not None:
                            content += delta.content

                # For streaming responses, we don't have usage information directly
                # We'll provide a minimal response with the content
                return {
                    "data": content,
                    "model": self.model_name,
                    "provider": "openai",
                    "success": True,
                    "message": "Successfully completed streaming chat request",
                    "cost": 0.0,  # We don't have cost information for streaming responses
                    "timestamp": datetime.now().isoformat(),
                    "usage": {
                        "tokens_prompt": 0,  # We don't have token information for streaming responses
                        "tokens_completion": 0,
                        "tokens_total": 0,
                    },
                    "finish_reason": "stop",  # Default finish reason
                    "raw_response": None,  # We don't have the raw response for streaming
                }

            # Handle non-streaming responses as before
            openai_messages = []
            for msg in messages:
                if isinstance(msg, ChatMessage):
                    openai_messages.append({"role": msg.role, "content": msg.content})
                else:
                    openai_messages.append(msg)
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=openai_messages,
                **kwargs,
            )
            # Extract the response content
            content = response.choices[0].message.content

            # Extract usage information from the response
            prompt_tokens = (
                response.usage.prompt_tokens
                if hasattr(response.usage, "prompt_tokens")
                else 0
            )
            completion_tokens = (
                response.usage.completion_tokens
                if hasattr(response.usage, "completion_tokens")
                else 0
            )
            total_tokens = (
                response.usage.total_tokens
                if hasattr(response.usage, "total_tokens")
                else 0
            )

            cost = calculate_cost(
                f"openai/{self.model_name}",
                input_tokens=prompt_tokens,
                output_tokens=completion_tokens,
            )

            # Create a response dictionary with the extracted information
            return {
                "data": content,
                "model": self.model_name,
                "provider": "openai",
                "success": True,
                "message": "Successfully completed chat request",
                "cost": cost,
                "timestamp": datetime.now().isoformat(),
                # Add usage as dict with consistent field names
                "usage": {
                    "tokens_prompt": prompt_tokens,
                    "tokens_completion": completion_tokens,
                    "tokens_total": total_tokens,
                },
                # Optional fields
                "finish_reason": response.choices[0].finish_reason,
                "raw_response": response.model_dump(),
            }

        except openai.AuthenticationError:
            raise AuthenticationError("Invalid OpenAI API key.")
        except openai.RateLimitError:
            raise RateLimitError("OpenAI rate limit exceeded.")
        except Exception as e:
            raise RequestError(f"OpenAI request failed: {str(e)}")

    def complete(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Send a completion request to OpenAI.

        Args:
            prompt: The prompt to complete.
            **kwargs: Additional parameters to pass to the OpenAI API.

        Returns:
            A dictionary containing the response from OpenAI.
            If stream=True and return_generator=True, returns a generator that yields chunks of the response.

        Raises:
            AuthenticationError: If the API key is invalid.
            RequestError: If the request fails.
            RateLimitError: If the rate limit is exceeded.
        """
        # Check if streaming is requested
        stream = kwargs.pop("stream", False)
        return_generator = kwargs.pop("return_generator", False)

        # For OpenAI, we'll use the chat API for completions as well
        messages = [{"role": "user", "content": prompt}]

        # If streaming is requested, handle it through the chat method
        if stream:
            return self.chat(
                messages, stream=True, return_generator=return_generator, **kwargs
            )

        # Otherwise, use the regular chat method
        return self.chat(messages, **kwargs)

    def embed(self, text: Union[str, List[str]], **kwargs) -> Dict[str, Any]:
        """
        Send an embedding request to OpenAI.

        Args:
            text: The text to embed. Can be a single string or a list of strings.
            **kwargs: Additional parameters to pass to the OpenAI API.

        Returns:
            A dictionary containing the embeddings from OpenAI.

        Raises:
            AuthenticationError: If the API key is invalid.
            RequestError: If the request fails.
            RateLimitError: If the rate limit is exceeded.
        """
        try:
            # Ensure text is a list
            if isinstance(text, str):
                text = [text]

            # Use the embedding model
            response = self.client.embeddings.create(
                model=self.model_name, input=text, **kwargs
            )

            # Extract embeddings
            embeddings = [item.embedding for item in response.data]

            # Create a list of embedding objects with the expected structure
            embedding_objects = []
            for i, embedding in enumerate(embeddings):
                embedding_objects.append(
                    {
                        "embedding": embedding,
                        "index": i,
                        "text": text[i] if i < len(text) else "",
                    }
                )

            # Extract usage information from the response
            prompt_tokens = (
                response.usage.prompt_tokens
                if hasattr(response.usage, "prompt_tokens")
                else 0
            )
            total_tokens = (
                response.usage.total_tokens
                if hasattr(response.usage, "total_tokens")
                else 0
            )

            embedding_price_per_1k = get_model_info("openai", self.model_name).get(
                "inputPricePer1KTokens"
            )

            # Calculate the cost
            cost = (prompt_tokens / 1000) * embedding_price_per_1k

            # Create usage information
            usage = {
                "tokens_prompt": prompt_tokens,
                "tokens_completion": 0,
                "tokens_total": total_tokens,
                "cost": cost,
                "latency": 0.0,  # We don't have latency information from the API
                "timestamp": datetime.now().isoformat(),
            }

            return {
                "data": embedding_objects,
                "model": self.model_name,
                "provider": "openai",
                "success": True,
                "message": "Successfully generated embeddings",
                "usage": usage,
                "raw_response": response.model_dump(),
            }
        except openai.AuthenticationError:
            raise AuthenticationError("Invalid OpenAI API key.")
        except openai.RateLimitError:
            raise RateLimitError("OpenAI rate limit exceeded.")
        except Exception as e:
            raise RequestError(f"OpenAI embedding request failed: {str(e)}")

    def generate_image(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate an image from a prompt using OpenAI.

        Args:
            prompt: The prompt to generate an image from.
            **kwargs: Additional parameters to pass to the OpenAI API.

        Returns:
            A dictionary containing the image URL or data.

        Raises:
            AuthenticationError: If the API key is invalid.
            RequestError: If the request fails.
            RateLimitError: If the rate limit is exceeded.
        """
        try:
            # Use DALL-E model
            model = kwargs.get("model", "dall-e-3")
            size = kwargs.get("size", "1024x1024")
            quality = kwargs.get("quality", "standard")
            n = kwargs.get("n", 1)

            response = self.client.images.generate(
                model=model, prompt=prompt, size=size, quality=quality, n=n
            )

            # Extract image URLs
            images = [item.url for item in response.data]

            # For image generation, we don't have token usage, so we'll estimate cost
            # based on the model and parameters
            cost = calculate_cost(
                f"openai/{model}",  # e.g., "openai/dall-e-3"
                input_tokens=n,  # Number of images
                output_tokens=0,
            )

            # Create usage information
            usage = {
                "tokens_prompt": 0,  # We don't have token information for images
                "tokens_completion": 0,
                "tokens_total": 0,
                "cost": cost,
                "latency": 0.0,
                "timestamp": datetime.now().isoformat(),
            }

            return {
                "data": images,
                "model": model,
                "provider": "openai",
                "success": True,
                "message": "Successfully generated images",
                "usage": usage,
                "sizes": [size] * n,
                "formats": ["url"] * n,
                "raw_response": response.model_dump(),
            }

        except openai.AuthenticationError:
            raise AuthenticationError("Invalid OpenAI API key.")
        except openai.RateLimitError:
            raise RateLimitError("OpenAI rate limit exceeded.")
        except Exception as e:
            raise RequestError(f"OpenAI image generation request failed: {str(e)}")

    def get_token_count(self, text: str) -> int:
        """
        Get the number of tokens in a text using OpenAI's tokenizer.

        Args:
            text: The text to count tokens for.

        Returns:
            The number of tokens in the text.
        """
        try:
            # Use tiktoken for token counting
            import tiktoken

            encoding = tiktoken.encoding_for_model(self.model_name)
            return len(encoding.encode(text))
        except ImportError:
            # Fallback to a simple approximation if tiktoken is not available
            return len(text.split()) * 1.3  # Rough approximation

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the model.

        Returns:
            A dictionary containing information about the model.
        """
        return self.model_info


class StreamingGenerator:
    """
    A generator class that yields chunks of text from a streaming response
    and provides methods to get usage information at any point.
    """

    def __init__(self, stream_response, model_name, messages):
        """
        Initialize the streaming generator.

        Args:
            stream_response: The streaming response from the provider.
            model_name: The name of the model being used.
            messages: The messages sent to the provider.
        """
        self.stream_response = stream_response
        self.model_name = model_name
        self.messages = messages
        self.full_content = ""
        self.finish_reason = None
        self.is_finished = False

        # Try to initialize tiktoken for token counting
        try:
            import tiktoken

            self.encoding = tiktoken.encoding_for_model(model_name)
            self.has_tiktoken = True
        except (ImportError, Exception):
            self.has_tiktoken = False

        # Estimate prompt tokens
        self.prompt_tokens = self._count_prompt_tokens()

    def _count_prompt_tokens(self):
        """Count tokens in the prompt messages."""
        if self.has_tiktoken:
            # Use tiktoken for accurate token counting
            prompt_text = " ".join(
                [
                    msg.get("content", "") if isinstance(msg, dict) else msg.content
                    for msg in self.messages
                ]
            )
            return len(self.encoding.encode(prompt_text))
        else:
            # Fallback to character-based estimation
            prompt_text = " ".join(
                [
                    msg.get("content", "") if isinstance(msg, dict) else msg.content
                    for msg in self.messages
                ]
            )
            return len(prompt_text) // 4  # Rough estimate: 4 chars per token

    def _count_completion_tokens(self):
        """Count tokens in the completion text."""
        if self.has_tiktoken:
            # Use tiktoken for accurate token counting
            return len(self.encoding.encode(self.full_content))
        else:
            # Fallback to character-based estimation
            return len(self.full_content) // 4  # Rough estimate: 4 chars per token

    def get_usage_info(self):
        """
        Get usage information based on the current state.

        Returns:
            A dictionary with usage information.
        """
        completion_tokens = self._count_completion_tokens()
        total_tokens = self.prompt_tokens + completion_tokens

        # Calculate cost
        cost = calculate_cost(
            f"openai/{self.model_name}",
            input_tokens=self.prompt_tokens,
            output_tokens=completion_tokens,
        )

        return {
            "usage": {
                "tokens_prompt": self.prompt_tokens,
                "tokens_completion": completion_tokens,
                "tokens_total": total_tokens,
            },
            "cost": cost,
            "model": self.model_name,
            "provider": "openai",
            "finish_reason": self.finish_reason,
            "is_finished": self.is_finished,
        }

    def __iter__(self):
        return self

    def __next__(self):
        """Get the next chunk from the stream."""
        if self.is_finished:
            raise StopIteration

        try:
            chunk = next(self.stream_response)

            if hasattr(chunk, "choices") and len(chunk.choices) > 0:
                # Check for finish reason
                if (
                    hasattr(chunk.choices[0], "finish_reason")
                    and chunk.choices[0].finish_reason
                ):
                    self.finish_reason = chunk.choices[0].finish_reason

                # Get content delta
                delta = chunk.choices[0].delta
                if hasattr(delta, "content") and delta.content is not None:
                    content_chunk = delta.content
                    self.full_content += content_chunk
                    return content_chunk

                # If we got a chunk with no content but with finish_reason, we're done
                if self.finish_reason:
                    self.is_finished = True
                    raise StopIteration

                # If we got here, try the next chunk
                return next(self)

        except StopIteration:
            self.is_finished = True
            raise
```
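Below is a minimal, hypothetical usage sketch of the provider interface shown above. The API key and model name are placeholders, and it assumes the chosen model appears in the package's bundled `openai.json` data consumed by `get_model_info` and `calculate_cost`.

```python
# Hypothetical usage sketch of the 0.1.2 Provider class shown above.
# The API key and model name are placeholders, not values from the package.
from indoxRouter.providers.openai import Provider

provider = Provider(api_key="sk-PLACEHOLDER", model_name="gpt-4o-mini")

# Non-streaming chat: returns the standardized response dictionary.
response = provider.chat([{"role": "user", "content": "Hello!"}])
print(response["data"])
print(response["usage"]["tokens_total"], response["cost"])

# Streaming chat with return_generator=True yields a StreamingGenerator:
# iterate over text chunks, then read estimated usage once the stream ends.
stream = provider.chat(
    [{"role": "user", "content": "Tell me a short story."}],
    stream=True,
    return_generator=True,
)
for chunk in stream:
    print(chunk, end="", flush=True)
print()
print(stream.get_usage_info())
```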