edsl 0.1.55__py3-none-any.whl → 0.1.57__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/base/data_transfer_models.py +15 -4
- edsl/coop/coop.py +3 -3
- edsl/dataset/dataset_operations_mixin.py +216 -180
- edsl/inference_services/services/google_service.py +5 -2
- edsl/interviews/request_token_estimator.py +8 -0
- edsl/invigilators/invigilators.py +26 -13
- edsl/jobs/jobs_pricing_estimation.py +176 -113
- edsl/language_models/language_model.py +24 -6
- edsl/language_models/price_manager.py +171 -36
- edsl/results/result.py +52 -30
- edsl/scenarios/file_store.py +60 -30
- {edsl-0.1.55.dist-info → edsl-0.1.57.dist-info}/METADATA +2 -2
- {edsl-0.1.55.dist-info → edsl-0.1.57.dist-info}/RECORD +17 -17
- {edsl-0.1.55.dist-info → edsl-0.1.57.dist-info}/LICENSE +0 -0
- {edsl-0.1.55.dist-info → edsl-0.1.57.dist-info}/WHEEL +0 -0
- {edsl-0.1.55.dist-info → edsl-0.1.57.dist-info}/entry_points.txt +0 -0

edsl/inference_services/services/google_service.py

@@ -7,11 +7,13 @@ from google.api_core.exceptions import InvalidArgument
 
 # from ...exceptions.general import MissingAPIKeyError
 from ..inference_service_abc import InferenceServiceABC
+
 # Use TYPE_CHECKING to avoid circular imports at runtime
 if TYPE_CHECKING:
     from ...language_models import LanguageModel
     from ....scenarios.file_store import FileStore as Files
-#from ...coop import Coop
+# from ...coop import Coop
+import asyncio

 safety_settings = [
     {

@@ -61,7 +63,7 @@ class GoogleService(InferenceServiceABC):
     @classmethod
     def create_model(
         cls, model_name: str = "gemini-pro", model_class_name=None
-    ) ->
+    ) -> "LanguageModel":
         if model_class_name is None:
             model_class_name = cls.to_class_name(model_name)


@@ -138,6 +140,7 @@ class GoogleService(InferenceServiceABC):
                 gen_ai_file = google.generativeai.types.file_types.File(
                     file.external_locations["google"]
                 )
+
                 combined_prompt.append(gen_ai_file)

             try:

edsl/interviews/request_token_estimator.py

@@ -124,6 +124,14 @@ class RequestTokenEstimator:
                     width, height = file.get_image_dimensions()
                     token_usage = estimate_tokens(model_name, width, height)
                     file_tokens += token_usage
+                if file.is_video():
+                    model_name = self.interview.model.model
+                    duration = file.get_video_metadata()["simplified"][
+                        "duration_seconds"
+                    ]
+                    file_tokens += (
+                        duration * 295
+                    )  # (295 tokens per second for video + audio)
                 else:
                     file_tokens += file.size * 0.25
             else:
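
The added branch above prices video files at a flat 295 tokens per second of footage (covering both the video and audio tracks). A minimal sketch of that arithmetic, with a made-up 90-second clip (not taken from the package):

    # Illustrative only: mirrors the duration-based heuristic in the hunk above.
    def estimate_video_tokens(duration_seconds: float) -> float:
        # Video + audio are estimated at roughly 295 tokens per second.
        return duration_seconds * 295

    print(estimate_video_tokens(90))  # 26550 tokens for a hypothetical 90-second clip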

edsl/invigilators/invigilators.py

@@ -1,4 +1,5 @@
 """Module for creating Invigilators, which are objects to administer a question to an Agent."""
+
 from abc import ABC, abstractmethod
 import asyncio
 from typing import Coroutine, Dict, Any, Optional, TYPE_CHECKING

@@ -395,17 +396,21 @@ class InvigilatorAI(InvigilatorBase):

         if agent_response_dict.model_outputs.cache_used and False:
             data = {
-                "answer":
-
-
-
-
-
-
-
-
-
-
+                "answer": (
+                    agent_response_dict.edsl_dict.answer
+                    if type(agent_response_dict.edsl_dict.answer) is str
+                    or type(agent_response_dict.edsl_dict.answer) is dict
+                    or type(agent_response_dict.edsl_dict.answer) is list
+                    or type(agent_response_dict.edsl_dict.answer) is int
+                    or type(agent_response_dict.edsl_dict.answer) is float
+                    or type(agent_response_dict.edsl_dict.answer) is bool
+                    else ""
+                ),
+                "comment": (
+                    agent_response_dict.edsl_dict.comment
+                    if agent_response_dict.edsl_dict.comment
+                    else ""
+                ),
                 "generated_tokens": agent_response_dict.edsl_dict.generated_tokens,
                 "question_name": self.question.question_name,
                 "prompts": self.get_prompts(),

@@ -415,7 +420,11 @@ class InvigilatorAI(InvigilatorBase):
                 "cache_key": agent_response_dict.model_outputs.cache_key,
                 "validated": True,
                 "exception_occurred": exception_occurred,
-                "
+                "input_tokens": agent_response_dict.model_outputs.input_tokens,
+                "output_tokens": agent_response_dict.model_outputs.output_tokens,
+                "input_price_per_million_tokens": agent_response_dict.model_outputs.input_price_per_million_tokens,
+                "output_price_per_million_tokens": agent_response_dict.model_outputs.output_price_per_million_tokens,
+                "total_cost": agent_response_dict.model_outputs.total_cost,
             }

             result = EDSLResultObjectInput(**data)

@@ -480,7 +489,11 @@ class InvigilatorAI(InvigilatorBase):
            "cache_key": agent_response_dict.model_outputs.cache_key,
            "validated": validated,
            "exception_occurred": exception_occurred,
-           "
+           "input_tokens": agent_response_dict.model_outputs.input_tokens,
+           "output_tokens": agent_response_dict.model_outputs.output_tokens,
+           "input_price_per_million_tokens": agent_response_dict.model_outputs.input_price_per_million_tokens,
+           "output_price_per_million_tokens": agent_response_dict.model_outputs.output_price_per_million_tokens,
+           "total_cost": agent_response_dict.model_outputs.total_cost,
        }
        result = EDSLResultObjectInput(**data)
        return result

edsl/jobs/jobs_pricing_estimation.py

@@ -1,7 +1,8 @@
 import logging
 import math

-from typing import List, TYPE_CHECKING, Union, Literal
+from typing import List, TYPE_CHECKING, Union, Literal, Dict
+from collections import namedtuple

 if TYPE_CHECKING:
     from .jobs import Jobs

@@ -20,8 +21,8 @@ logger = logging.getLogger(__name__)

 class PromptCostEstimator:

-
-
+    DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS = 1.0
+    DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS = 1.0
     CHARS_PER_TOKEN = 4
     OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75
     PIPING_MULTIPLIER = 2

@@ -48,81 +49,90 @@ class PromptCostEstimator:
             return PromptCostEstimator.PIPING_MULTIPLIER
         return 1

-
-
-
-
-
-    def relevant_prices(self):
-        try:
-            return self.price_lookup[self.key]
-        except KeyError:
-            return {}
-
-    def _get_highest_price_for_service(self, price_type: str) -> Union[float, None]:
-        """Returns the highest price per token for a given service and price type (input/output).
+    def _get_fallback_price(self, inference_service: str) -> Dict:
+        """
+        Get fallback prices for a service.
+        - First fallback: The highest input and output prices for that service from the price lookup.
+        - Second fallback: $1.00 per million tokens (for both input and output).

         Args:
-
+            inference_service (str): The inference service name

         Returns:
-
+            Dict: Price information
         """
-
-
-
+        PriceEntry = namedtuple("PriceEntry", ["tokens_per_usd", "price_info"])
+
+        service_prices = [
+            prices
             for (service, _), prices in self.price_lookup.items()
-            if service ==
+            if service == inference_service
         ]
-        return max(prices_for_service) if prices_for_service else None

-
-
-
-
-
-
-
-
-
-
-        if highest_price is not None:
-            import warnings
-
-            warnings.warn(
-                f"Price data not found for {self.key}. Using highest available input price for {self.inference_service}: ${highest_price:.6f} per token"
-            )
-            return highest_price, "highest_price_for_service"
-        import warnings
+        default_input_price_info = {
+            "one_usd_buys": 1_000_000,
+            "service_stated_token_qty": 1_000_000,
+            "service_stated_token_price": self.DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS,
+        }
+        default_output_price_info = {
+            "one_usd_buys": 1_000_000,
+            "service_stated_token_qty": 1_000_000,
+            "service_stated_token_price": self.DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS,
+        }

-
-
-
-        return self.DEFAULT_INPUT_PRICE_PER_TOKEN, "default"
+        # Find the most expensive price entries (lowest tokens per USD)
+        input_price_info = default_input_price_info
+        output_price_info = default_output_price_info

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        input_prices = [
+            PriceEntry(float(p["input"]["one_usd_buys"]), p["input"])
+            for p in service_prices
+            if "input" in p
+        ]
+        if input_prices:
+            input_price_info = min(
+                input_prices, key=lambda price: price.tokens_per_usd
+            ).price_info
+
+        output_prices = [
+            PriceEntry(float(p["output"]["one_usd_buys"]), p["output"])
+            for p in service_prices
+            if "output" in p
+        ]
+        if output_prices:
+            output_price_info = min(
+                output_prices, key=lambda price: price.tokens_per_usd
+            ).price_info

-
-
-
-
+        return {
+            "input": input_price_info,
+            "output": output_price_info,
+        }
+
+    def get_price(self, inference_service: str, model: str) -> Dict:
+        """Get the price information for a specific service and model."""
+        key = (inference_service, model)
+        return self.price_lookup.get(key) or self._get_fallback_price(inference_service)
+
+    def get_price_per_million_tokens(
+        self,
+        relevant_prices: Dict,
+        token_type: Literal["input", "output"],
+    ) -> Dict:
+        """
+        Get the price per million tokens for a specific service, model, and token type.
+        """
+        service_price = relevant_prices[token_type]["service_stated_token_price"]
+        service_qty = relevant_prices[token_type]["service_stated_token_qty"]
+
+        if service_qty == 1_000_000:
+            price_per_million_tokens = service_price
+        elif service_qty == 1_000:
+            price_per_million_tokens = service_price * 1_000
+        else:
+            price_per_token = service_price / service_qty
+            price_per_million_tokens = round(price_per_token * 1_000_000, 10)
+        return price_per_million_tokens

     def __call__(self):
         user_prompt_chars = len(str(self.user_prompt)) * self.get_piping_multiplier(
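
To make the fallback rule in _get_fallback_price concrete: among all priced models for a service, the entry where one US dollar buys the fewest tokens (i.e. the most expensive price) becomes the fallback; if the service has no entries at all, the $1.00-per-million default applies. A standalone sketch with invented service and model names:

    # Hypothetical price_lookup entries keyed by (service, model), as in the hunk above.
    price_lookup = {
        ("acme", "model-small"): {"input": {"one_usd_buys": 4_000_000}},
        ("acme", "model-large"): {"input": {"one_usd_buys": 500_000}},
    }

    service_prices = [
        prices for (service, _), prices in price_lookup.items() if service == "acme"
    ]
    # Lowest one_usd_buys == most expensive price; that entry wins the fallback.
    fallback_input = min(
        (p["input"] for p in service_prices if "input" in p),
        key=lambda info: float(info["one_usd_buys"]),
    )
    print(fallback_input)  # {'one_usd_buys': 500000} -> model-large's pricing is used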

@@ -135,20 +145,28 @@
         input_tokens = (user_prompt_chars + system_prompt_chars) // self.CHARS_PER_TOKEN
         output_tokens = math.ceil(self.OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)

-
-        output_price_per_token, output_price_source = self.output_price_per_token()
+        relevant_prices = self.get_price(self.inference_service, self.model)

-
-
-
+        input_price_per_million_tokens = self.get_price_per_million_tokens(
+            relevant_prices, "input"
+        )
+        output_price_per_million_tokens = self.get_price_per_million_tokens(
+            relevant_prices, "output"
         )
+
+        input_price_per_token = input_price_per_million_tokens / 1_000_000
+        output_price_per_token = output_price_per_million_tokens / 1_000_000
+
+        input_cost = input_tokens * input_price_per_token
+        output_cost = output_tokens * output_price_per_token
+        cost = input_cost + output_cost
         return {
-            "
-            "
+            "input_price_per_million_tokens": input_price_per_million_tokens,
+            "output_price_per_million_tokens": output_price_per_million_tokens,
             "input_tokens": input_tokens,
-            "output_price_source": output_price_source,
             "output_tokens": output_tokens,
-            "
+            "input_cost_usd": input_cost,
+            "output_cost_usd": output_cost,
             "cost_usd": cost,
         }

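
The normalization in get_price_per_million_tokens plus the cost arithmetic in __call__ reduce to a few lines. A worked example with invented numbers, for a service that states $0.003 per 1,000 tokens:

    service_stated_token_price = 0.003  # invented example price
    service_stated_token_qty = 1_000    # price is stated per 1,000 tokens

    # Per the elif branch above: a stated quantity of 1,000 means multiply by 1,000.
    price_per_million_tokens = service_stated_token_price * 1_000  # 3.0 USD per million

    input_tokens = 2_000
    input_cost = input_tokens * (price_per_million_tokens / 1_000_000)
    print(price_per_million_tokens, input_cost)  # 3.0 0.006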

@@ -328,6 +346,40 @@ class JobsPrompts:
             "model": model,
         }

+    def process_token_type(self, item: dict, token_type: str) -> tuple:
+        """
+        Helper function to process a single token type (input or output) for price estimation.
+        """
+        price = item[f"estimated_{token_type}_price_per_million_tokens"]
+        tokens = item[f"estimated_{token_type}_tokens"]
+        cost = item[f"estimated_{token_type}_cost_usd"]
+
+        return (
+            (item["inference_service"], item["model"], token_type, price),
+            {
+                "inference_service": item["inference_service"],
+                "model": item["model"],
+                "token_type": token_type,
+                "price_per_million_tokens": price,
+                "tokens": tokens,
+                "cost_usd": cost,
+            },
+        )
+
+    @staticmethod
+    def usd_to_credits(usd: float) -> float:
+        """Converts USD to credits."""
+        cents = usd * 100
+        credits_per_cent = 1
+        credits = cents * credits_per_cent
+
+        # Round up to the nearest hundredth of a credit
+        minicredits = math.ceil(credits * 100)
+
+        # Convert back to credits
+        credits = round(minicredits / 100, 2)
+        return credits
+
     def estimate_job_cost_from_external_prices(
         self, price_lookup: dict, iterations: int = 1
     ) -> dict:
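
The new usd_to_credits helper treats one cent as one credit and always rounds up to the nearest hundredth of a credit. A quick check of that rounding with an invented cost:

    import math

    def usd_to_credits(usd: float) -> float:
        # 1 cent = 1 credit; round *up* to the nearest hundredth of a credit.
        credits = usd * 100
        minicredits = math.ceil(credits * 100)
        return round(minicredits / 100, 2)

    print(usd_to_credits(0.012301))  # 1.24 (1.2301 credits, rounded up)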

@@ -341,9 +393,9 @@ class JobsPrompts:
         - 1 token = 4 characters.
         - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
         """
-
+        # Collect all prompt data
         data = []
-        for interview in interviews:
+        for interview in self.interviews:
             invigilators = [
                 FetchInvigilator(interview)(question)
                 for question in self.survey.questions

@@ -354,59 +406,70 @@ class JobsPrompts:
                    **prompt_details, price_lookup=price_lookup
                )
                price_estimates = {
+                    "estimated_input_price_per_million_tokens": prompt_cost[
+                        "input_price_per_million_tokens"
+                    ],
+                    "estimated_output_price_per_million_tokens": prompt_cost[
+                        "output_price_per_million_tokens"
+                    ],
                    "estimated_input_tokens": prompt_cost["input_tokens"],
                    "estimated_output_tokens": prompt_cost["output_tokens"],
+                    "estimated_input_cost_usd": prompt_cost["input_cost_usd"],
+                    "estimated_output_cost_usd": prompt_cost["output_cost_usd"],
                    "estimated_cost_usd": prompt_cost["cost_usd"],
                }
-                data.append(
+                data.append(
+                    {
+                        **prompt_details,
+                        **price_estimates,
+                    }
+                )

-
+        # Group by service, model, token type, and price
+        detailed_groups = {}
        for item in data:
-
-
-
-
-
-            "
-            "
-
-
-
-
-
-
-
-
-
-
-        ]
-
-        # Apply iterations and convert to list
-        estimated_costs_by_model = []
-        for group_data in model_groups.values():
-            group_data["estimated_cost_usd"] *= iterations
-            group_data["estimated_input_tokens"] *= iterations
-            group_data["estimated_output_tokens"] *= iterations
-            estimated_costs_by_model.append(group_data)
+            for token_type in ["input", "output"]:
+                key, group_data = self.process_token_type(item, token_type)
+                if key not in detailed_groups:
+                    detailed_groups[key] = group_data
+                else:
+                    detailed_groups[key]["tokens"] += group_data["tokens"]
+                    detailed_groups[key]["cost_usd"] += group_data["cost_usd"]
+
+        # Apply iterations and prepare final output
+        detailed_costs = []
+        for group in detailed_groups.values():
+            group["tokens"] *= iterations
+            group["cost_usd"] *= iterations
+            detailed_costs.append(group)
+
+        # Convert to credits
+        for group in detailed_costs:
+            group["credits_hold"] = self.usd_to_credits(group["cost_usd"])

        # Calculate totals
-
-
+        estimated_total_cost_usd = sum(group["cost_usd"] for group in detailed_costs)
+        total_credits_hold = sum(
+            group["credits_hold"] for group in detailed_costs
        )
        estimated_total_input_tokens = sum(
-
+            group["tokens"]
+            for group in detailed_costs
+            if group["token_type"] == "input"
        )
        estimated_total_output_tokens = sum(
-
+            group["tokens"]
+            for group in detailed_costs
+            if group["token_type"] == "output"
        )

        output = {
-            "estimated_total_cost_usd":
+            "estimated_total_cost_usd": estimated_total_cost_usd,
+            "total_credits_hold": total_credits_hold,
            "estimated_total_input_tokens": estimated_total_input_tokens,
            "estimated_total_output_tokens": estimated_total_output_tokens,
-            "
+            "detailed_costs": detailed_costs,
        }
-
        return output

    def estimate_job_cost(self, iterations: int = 1) -> dict:
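
After this refactor the estimator groups costs by (inference_service, model, token_type, price) and returns per-group detail plus totals and a credits hold. A hypothetical return value, with every name and number invented purely to show the shape:

    example_output = {
        "estimated_total_cost_usd": 0.03,
        "total_credits_hold": 3.0,
        "estimated_total_input_tokens": 3_000,
        "estimated_total_output_tokens": 2_250,
        "detailed_costs": [
            {
                "inference_service": "acme",   # invented service name
                "model": "model-large",        # invented model name
                "token_type": "input",
                "price_per_million_tokens": 2.5,
                "tokens": 3_000,
                "cost_usd": 0.0075,
                "credits_hold": 0.75,
            },
            # ...one entry per (service, model, token_type, price) group
        ],
    }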

edsl/language_models/language_model.py

@@ -49,6 +49,7 @@ from ..data_transfer_models import (
 )

 if TYPE_CHECKING:
+    from .price_manager import ResponseCost
     from ..caching import Cache
     from ..scenarios import FileStore
     from ..questions import QuestionBase

@@ -782,13 +783,18 @@ class LanguageModel(
         # Calculate cost for the response
         cost = self.cost(response)
         # Return a structured response with metadata
-
+        response = ModelResponse(
             response=response,
             cache_used=cache_used,
             cache_key=cache_key,
             cached_response=cached_response,
-
+            input_tokens=cost.input_tokens,
+            output_tokens=cost.output_tokens,
+            input_price_per_million_tokens=cost.input_price_per_million_tokens,
+            output_price_per_million_tokens=cost.output_price_per_million_tokens,
+            total_cost=cost.total_cost,
         )
+        return response

     _get_intended_model_call_outcome = sync_wrapper(
         _async_get_intended_model_call_outcome

@@ -881,7 +887,7 @@ class LanguageModel(

     get_response = sync_wrapper(async_get_response)

-    def cost(self, raw_response: dict[str, Any]) ->
+    def cost(self, raw_response: dict[str, Any]) -> ResponseCost:
         """Calculate the monetary cost of a model API call.

         This method extracts token usage information from the response and

@@ -892,7 +898,7 @@ class LanguageModel(
             raw_response: The complete response dictionary from the model API

         Returns:
-
+            ResponseCost: Object containing token counts and total cost
         """
         # Extract token usage data from the response
         usage = self.get_usage_dict(raw_response)

@@ -1147,13 +1153,25 @@ class LanguageModel(
            }
            cached_response, cache_key = cache.fetch(**cache_call_params)
            response = json.loads(cached_response)
-
+
+            try:
+                usage = self.get_usage_dict(response)
+                input_tokens = int(usage[self.input_token_name])
+                output_tokens = int(usage[self.output_token_name])
+            except Exception as e:
+                print(f"Could not fetch tokens from model response: {e}")
+                input_tokens = None
+                output_tokens = None
            return ModelResponse(
                response=response,
                cache_used=True,
                cache_key=cache_key,
                cached_response=cached_response,
-
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+                input_price_per_million_tokens=0,
+                output_price_per_million_tokens=0,
+                total_cost=0,
            )

        # Bind the new method to the copied instance