edsl 0.1.56__py3-none-any.whl → 0.1.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/coop/coop.py +174 -37
- edsl/coop/utils.py +63 -0
- edsl/interviews/exception_tracking.py +26 -0
- edsl/jobs/html_table_job_logger.py +377 -48
- edsl/jobs/jobs_pricing_estimation.py +19 -96
- edsl/jobs/jobs_remote_inference_logger.py +27 -0
- edsl/jobs/jobs_runner_status.py +52 -21
- edsl/jobs/remote_inference.py +187 -30
- edsl/language_models/language_model.py +1 -1
- edsl/language_models/price_manager.py +91 -57
- {edsl-0.1.56.dist-info → edsl-0.1.58.dist-info}/METADATA +2 -2
- {edsl-0.1.56.dist-info → edsl-0.1.58.dist-info}/RECORD +16 -17
- edsl/language_models/compute_cost.py +0 -78
- {edsl-0.1.56.dist-info → edsl-0.1.58.dist-info}/LICENSE +0 -0
- {edsl-0.1.56.dist-info → edsl-0.1.58.dist-info}/WHEEL +0 -0
- {edsl-0.1.56.dist-info → edsl-0.1.58.dist-info}/entry_points.txt +0 -0
edsl/jobs/jobs_pricing_estimation.py
CHANGED
@@ -13,16 +13,15 @@ if TYPE_CHECKING:
     from ..invigilators.invigilator_base import Invigilator
 
 from .fetch_invigilator import FetchInvigilator
+from ..coop.utils import CostConverter
 from ..caching import CacheEntry
 from ..dataset import Dataset
+from ..language_models.price_manager import PriceRetriever
 
 logger = logging.getLogger(__name__)
 
 
 class PromptCostEstimator:
-
-    DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS = 1.0
-    DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS = 1.0
     CHARS_PER_TOKEN = 4
     OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75
     PIPING_MULTIPLIER = 2
@@ -37,7 +36,7 @@ class PromptCostEstimator:
     ):
         self.system_prompt = system_prompt
         self.user_prompt = user_prompt
-        self.price_lookup = price_lookup
+        self.price_retriever = PriceRetriever(price_lookup)
         self.inference_service = inference_service
         self.model = model
 
@@ -49,91 +48,6 @@ class PromptCostEstimator:
             return PromptCostEstimator.PIPING_MULTIPLIER
         return 1
 
-    def _get_fallback_price(self, inference_service: str) -> Dict:
-        """
-        Get fallback prices for a service.
-        - First fallback: The highest input and output prices for that service from the price lookup.
-        - Second fallback: $1.00 per million tokens (for both input and output).
-
-        Args:
-            inference_service (str): The inference service name
-
-        Returns:
-            Dict: Price information
-        """
-        PriceEntry = namedtuple("PriceEntry", ["tokens_per_usd", "price_info"])
-
-        service_prices = [
-            prices
-            for (service, _), prices in self.price_lookup.items()
-            if service == inference_service
-        ]
-
-        default_input_price_info = {
-            "one_usd_buys": 1_000_000,
-            "service_stated_token_qty": 1_000_000,
-            "service_stated_token_price": self.DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS,
-        }
-        default_output_price_info = {
-            "one_usd_buys": 1_000_000,
-            "service_stated_token_qty": 1_000_000,
-            "service_stated_token_price": self.DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS,
-        }
-
-        # Find the most expensive price entries (lowest tokens per USD)
-        input_price_info = default_input_price_info
-        output_price_info = default_output_price_info
-
-        input_prices = [
-            PriceEntry(float(p["input"]["one_usd_buys"]), p["input"])
-            for p in service_prices
-            if "input" in p
-        ]
-        if input_prices:
-            input_price_info = min(
-                input_prices, key=lambda price: price.tokens_per_usd
-            ).price_info
-
-        output_prices = [
-            PriceEntry(float(p["output"]["one_usd_buys"]), p["output"])
-            for p in service_prices
-            if "output" in p
-        ]
-        if output_prices:
-            output_price_info = min(
-                output_prices, key=lambda price: price.tokens_per_usd
-            ).price_info
-
-        return {
-            "input": input_price_info,
-            "output": output_price_info,
-        }
-
-    def get_price(self, inference_service: str, model: str) -> Dict:
-        """Get the price information for a specific service and model."""
-        key = (inference_service, model)
-        return self.price_lookup.get(key) or self._get_fallback_price(inference_service)
-
-    def get_price_per_million_tokens(
-        self,
-        relevant_prices: Dict,
-        token_type: Literal["input", "output"],
-    ) -> Dict:
-        """
-        Get the price per million tokens for a specific service, model, and token type.
-        """
-        service_price = relevant_prices[token_type]["service_stated_token_price"]
-        service_qty = relevant_prices[token_type]["service_stated_token_qty"]
-
-        if service_qty == 1_000_000:
-            price_per_million_tokens = service_price
-        elif service_qty == 1_000:
-            price_per_million_tokens = service_price * 1_000
-        else:
-            price_per_token = service_price / service_qty
-            price_per_million_tokens = round(price_per_token * 1_000_000, 10)
-        return price_per_million_tokens
-
     def __call__(self):
         user_prompt_chars = len(str(self.user_prompt)) * self.get_piping_multiplier(
             str(self.user_prompt)
@@ -145,13 +59,15 @@ class PromptCostEstimator:
         input_tokens = (user_prompt_chars + system_prompt_chars) // self.CHARS_PER_TOKEN
         output_tokens = math.ceil(self.OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)
 
-        relevant_prices = self.get_price(self.inference_service, self.model)
+        relevant_prices = self.price_retriever.get_price(
+            self.inference_service, self.model
+        )
 
-        input_price_per_million_tokens = self.get_price_per_million_tokens(
-            relevant_prices, "input"
+        input_price_per_million_tokens = (
+            self.price_retriever.get_price_per_million_tokens(relevant_prices, "input")
         )
-        output_price_per_million_tokens = self.get_price_per_million_tokens(
-            relevant_prices, "output"
+        output_price_per_million_tokens = (
+            self.price_retriever.get_price_per_million_tokens(relevant_prices, "output")
         )
 
         input_price_per_token = input_price_per_million_tokens / 1_000_000
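Note: the fallback-price and per-million-token logic deleted above is not gone; it moves into the new PriceRetriever in edsl/language_models/price_manager.py (+91/-57 in the file list), which PromptCostEstimator now calls. A minimal sketch of the normalization the deleted get_price_per_million_tokens performed, assuming PriceRetriever keeps the same semantics (the unchanged call sites suggest it does):

    # Price entries state a token price for a stated token quantity;
    # normalize everything to USD per million tokens.
    def price_per_million_tokens(entry: dict) -> float:
        price = entry["service_stated_token_price"]
        qty = entry["service_stated_token_qty"]
        if qty == 1_000_000:      # already quoted per million
            return price
        if qty == 1_000:          # quoted per thousand -> scale up
            return price * 1_000
        return round(price / qty * 1_000_000, 10)

    # e.g. a service quoting $0.0025 per 1K input tokens
    assert price_per_million_tokens(
        {"service_stated_token_price": 0.0025, "service_stated_token_qty": 1_000}
    ) == 2.5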
@@ -429,8 +345,14 @@ class JobsPrompts:
             group["cost_usd"] *= iterations
             detailed_costs.append(group)
 
+        # Convert to credits
+        converter = CostConverter()
+        for group in detailed_costs:
+            group["credits_hold"] = converter.usd_to_credits(group["cost_usd"])
+
         # Calculate totals
-
+        estimated_total_cost_usd = sum(group["cost_usd"] for group in detailed_costs)
+        total_credits_hold = sum(group["credits_hold"] for group in detailed_costs)
         estimated_total_input_tokens = sum(
             group["tokens"]
             for group in detailed_costs
@@ -443,7 +365,8 @@ class JobsPrompts:
         )
 
         output = {
-            "estimated_total_cost_usd":
+            "estimated_total_cost_usd": estimated_total_cost_usd,
+            "total_credits_hold": total_credits_hold,
             "estimated_total_input_tokens": estimated_total_input_tokens,
             "estimated_total_output_tokens": estimated_total_output_tokens,
             "detailed_costs": detailed_costs,
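The estimate now carries a per-group credits_hold and a total_credits_hold alongside the USD figures. CostConverter comes from the new edsl/coop/utils.py (+63 above); its exchange rate and rounding are not shown in this diff, so the stand-in below is hypothetical — it assumes 1 credit = $0.01, rounded up so the hold always covers the USD cost:

    import math

    USD_PER_CREDIT = 0.01  # assumed rate; the real one lives in edsl/coop/utils.py

    def usd_to_credits(usd: float) -> int:
        # Round up so the credit hold never undershoots the USD cost.
        return math.ceil(usd / USD_PER_CREDIT)

    def credits_to_usd(credits: int) -> float:
        return credits * USD_PER_CREDIT

    detailed_costs = [{"cost_usd": 0.0132}, {"cost_usd": 0.045}]
    for group in detailed_costs:
        group["credits_hold"] = usd_to_credits(group["cost_usd"])  # 2, then 5
    total_credits_hold = sum(g["credits_hold"] for g in detailed_costs)  # 7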
edsl/jobs/jobs_remote_inference_logger.py
CHANGED
@@ -23,15 +23,38 @@ class LogMessage:
     status: JobsStatus
 
 
+@dataclass
+class JobRunExceptionCounter:
+    exception_type: str = None
+    inference_service: str = None
+    model: str = None
+    question_name: str = None
+    exception_count: int = None
+
+
+@dataclass
+class ModelCost:
+    service: str = None
+    model: str = None
+    input_tokens: int = None
+    input_cost_usd: float = None
+    output_tokens: int = None
+    output_cost_usd: float = None
+
+
 @dataclass
 class JobsInfo:
     job_uuid: str = None
     progress_bar_url: str = None
     error_report_url: str = None
+    remote_inference_url: str = None
+    remote_cache_url: str = None
     results_uuid: str = None
     results_url: str = None
     completed_interviews: int = None
     failed_interviews: int = None
+    exception_summary: list[JobRunExceptionCounter] = None
+    model_costs: list[ModelCost] = None
 
     pretty_names = {
         "job_uuid": "Job UUID",
@@ -39,6 +62,8 @@ class JobsInfo:
         "error_report_url": "Exceptions Report URL",
         "results_uuid": "Results UUID",
         "results_url": "Results URL",
+        "remote_inference_url": "Remote Jobs",
+        "remote_cache_url": "Remote Cache",
     }
 
 
@@ -57,6 +82,8 @@ class JobLogger(ABC):
             "results_url",
             "completed_interviews",
             "failed_interviews",
+            "model_costs",
+            "exception_summary",
        ],
        value: str,
    ):
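Both new dataclasses feed the richer HTML job report (html_table_job_logger.py, +377/-48 above). A small usage sketch with hypothetical values; the import path and field names come straight from this file:

    from edsl.jobs.jobs_remote_inference_logger import JobRunExceptionCounter, ModelCost

    cost = ModelCost(
        service="openai",           # hypothetical example values
        model="gpt-4o",
        input_tokens=12_000,
        input_cost_usd=0.03,
        output_tokens=9_000,
        output_cost_usd=0.09,
    )
    failures = JobRunExceptionCounter(
        exception_type="Timeout",   # hypothetical
        inference_service="openai",
        model="gpt-4o",
        question_name="q1",
        exception_count=2,
    )
    # A concrete JobLogger then receives them through the newly whitelisted keys:
    # logger.add_info("model_costs", [cost])
    # logger.add_info("exception_summary", [failures])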
edsl/jobs/jobs_runner_status.py
CHANGED
@@ -11,15 +11,7 @@ from uuid import UUID
 
 if TYPE_CHECKING:
     from .jobs import Jobs
-
-
-@dataclass
-class ModelInfo:
-    model_name: str
-    TPM_limit_k: float
-    RPM_limit_k: float
-    num_tasks_waiting: int
-    token_usage_info: dict
+    from ..interviews import Interview
 
 
 class StatisticsTracker:
@@ -29,16 +21,33 @@ class StatisticsTracker:
         self.completed_count = 0
         self.completed_by_model = defaultdict(int)
         self.distinct_models = distinct_models
+        self.interviews_with_exceptions = 0
         self.total_exceptions = 0
         self.unfixed_exceptions = 0
+        self.exceptions_counter = defaultdict(int)
 
     def add_completed_interview(
-        self, model: str, num_exceptions: int = 0, num_unfixed: int = 0
+        self,
+        model: str,
+        exceptions: list[dict],
+        num_exceptions: int = 0,
+        num_unfixed: int = 0,
     ):
         self.completed_count += 1
         self.completed_by_model[model] += 1
         self.total_exceptions += num_exceptions
         self.unfixed_exceptions += num_unfixed
+        if num_exceptions > 0:
+            self.interviews_with_exceptions += 1
+
+        for exception in exceptions:
+            key = (
+                exception["exception_type"],
+                exception["inference_service"],
+                exception["model"],
+                exception["question_name"],
+            )
+            self.exceptions_counter[key] += 1
 
     def get_elapsed_time(self) -> float:
         return time.time() - self.start_time
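Each exception is tallied under a (type, service, model, question) key, so repeats of the same failure collapse into one counter. The aggregation in isolation:

    from collections import defaultdict

    # Same keying as StatisticsTracker.exceptions_counter above; the exception
    # dicts are whatever interview.exceptions.list() yields.
    exceptions_counter = defaultdict(int)
    exc = {"exception_type": "Timeout", "inference_service": "openai",
           "model": "gpt-4o", "question_name": "q1"}   # hypothetical values
    for exception in [exc, exc]:
        key = (
            exception["exception_type"],
            exception["inference_service"],
            exception["model"],
            exception["question_name"],
        )
        exceptions_counter[key] += 1
    assert exceptions_counter[("Timeout", "openai", "gpt-4o", "q1")] == 2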
@@ -88,9 +97,7 @@ class JobsRunnerStatusBase(ABC):
         ]
         self.num_total_interviews = n * len(self.jobs)
 
-        self.distinct_models = list(
-            set(model.model for model in self.jobs.models)
-        )
+        self.distinct_models = list(set(model.model for model in self.jobs.models))
 
         self.stats_tracker = StatisticsTracker(
             total_interviews=self.num_total_interviews,
@@ -130,6 +137,7 @@ class JobsRunnerStatusBase(ABC):
         status_dict = {
             "overall_progress": {
                 "completed": self.stats_tracker.completed_count,
+                "has_exceptions": self.stats_tracker.interviews_with_exceptions,
                 "total": self.num_total_interviews,
                 "percent": (
                     (
@@ -148,16 +156,36 @@ class JobsRunnerStatusBase(ABC):
                 if self.stats_tracker.completed_count >= self.num_total_interviews
                 else "running"
             ),
+            "exceptions_counter": [
+                {
+                    "exception_type": exception_type,
+                    "inference_service": inference_service,
+                    "model": model,
+                    "question_name": question_name,
+                    "count": count,
+                }
+                for (
+                    exception_type,
+                    inference_service,
+                    model,
+                    question_name,
+                ), count in self.stats_tracker.exceptions_counter.items()
+            ],
         }
 
         model_queues = {}
         # Check if bucket collection exists and is not empty
-        if (
-            hasattr(self.jobs
-            hasattr(self.jobs.run_config
-            self.jobs.run_config.environment
-
-
+        if (
+            hasattr(self.jobs, "run_config")
+            and hasattr(self.jobs.run_config, "environment")
+            and hasattr(self.jobs.run_config.environment, "bucket_collection")
+            and self.jobs.run_config.environment.bucket_collection
+        ):
+
+            for (
+                model,
+                bucket,
+            ) in self.jobs.run_config.environment.bucket_collection.items():
             model_name = model.model
             model_queues[model_name] = {
                 "language_model_name": model_name,
@@ -166,7 +194,9 @@ class JobsRunnerStatusBase(ABC):
                     "requested": bucket.requests_bucket.num_requests,
                     "tokens_returned": bucket.requests_bucket.tokens_returned,
                     "target_rate": round(bucket.requests_bucket.target_rate, 1),
-                    "current_rate": round(
+                    "current_rate": round(
+                        bucket.requests_bucket.get_throughput(), 1
+                    ),
                 },
                 "tokens_bucket": {
                     "completed": bucket.tokens_bucket.num_released,
@@ -179,10 +209,11 @@ class JobsRunnerStatusBase(ABC):
         status_dict["language_model_queues"] = model_queues
         return status_dict
 
-    def add_completed_interview(self, interview):
+    def add_completed_interview(self, interview: "Interview"):
         """Records a completed interview without storing the full interview data."""
         self.stats_tracker.add_completed_interview(
             model=interview.model.model,
+            exceptions=interview.exceptions.list(),
             num_exceptions=interview.exceptions.num_exceptions(),
             num_unfixed=interview.exceptions.num_unfixed_exceptions(),
         )
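Taken together, the serialized status payload gains a has_exceptions count and a flattened exceptions_counter list; roughly (hypothetical values):

    status_dict = {
        "overall_progress": {
            "completed": 8,
            "has_exceptions": 1,   # interviews that raised at least one exception
            "total": 10,
            "percent": 80.0,
        },
        "exceptions_counter": [
            {"exception_type": "Timeout", "inference_service": "openai",
             "model": "gpt-4o", "question_name": "q1", "count": 2},
        ],
    }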
edsl/jobs/remote_inference.py
CHANGED
@@ -1,11 +1,12 @@
 import re
+import math
 from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
 from dataclasses import dataclass
 from ..coop import CoopServerResponseError
-from ..coop.utils import VisibilityType
+from ..coop.utils import VisibilityType, CostConverter
 from ..coop.coop import RemoteInferenceResponse, RemoteInferenceCreationInfo
 from .jobs_status_enums import JobsStatus
-from .jobs_remote_inference_logger import JobLogger
+from .jobs_remote_inference_logger import JobLogger, JobRunExceptionCounter, ModelCost
 from .exceptions import RemoteInferenceError
 
 
@@ -94,6 +95,12 @@ class JobsRemoteInferenceHandler:
             "Remote inference activated. Sending job to server...",
             status=JobsStatus.QUEUED,
         )
+        logger.add_info(
+            "remote_inference_url", f"{self.expected_parrot_url}/home/remote-inference"
+        )
+        logger.add_info(
+            "remote_cache_url", f"{self.expected_parrot_url}/home/remote-cache"
+        )
         remote_job_creation_data = coop.remote_inference_create(
             self.jobs,
             description=remote_inference_description,
@@ -183,7 +190,9 @@ class JobsRemoteInferenceHandler:
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
         "Handles a failed job by logging the error and updating the job status."
-
+        error_report_url = remote_job_data.get("latest_job_run_details", {}).get(
+            "error_report_url"
+        )
 
         reason = remote_job_data.get("reason")
 
@@ -193,8 +202,8 @@ class JobsRemoteInferenceHandler:
             status=JobsStatus.FAILED,
         )
 
-        if
-        job_info.logger.add_info("error_report_url",
+        if error_report_url:
+            job_info.logger.add_info("error_report_url", error_report_url)
 
         job_info.logger.update("Job failed.", status=JobsStatus.FAILED)
         job_info.logger.update(
@@ -206,39 +215,46 @@ class JobsRemoteInferenceHandler:
             status=JobsStatus.FAILED,
         )
 
-    def
+    def _update_interview_details(
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
-        "
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        "Updates the interview details in the job info."
+        latest_job_run_details = remote_job_data.get("latest_job_run_details", {})
+        interview_details = latest_job_run_details.get("interview_details", {}) or {}
+        completed_interviews = interview_details.get("completed_interviews", 0)
+        interviews_with_exceptions = interview_details.get(
+            "interviews_with_exceptions", 0
+        )
+        interviews_without_exceptions = (
+            completed_interviews - interviews_with_exceptions
+        )
+        job_info.logger.add_info("completed_interviews", interviews_without_exceptions)
+        job_info.logger.add_info("failed_interviews", interviews_with_exceptions)
+
+        exception_summary = interview_details.get("exception_summary", []) or []
+        if exception_summary:
+            job_run_exception_counters = []
+            for exception in exception_summary:
+                exception_counter = JobRunExceptionCounter(
+                    exception_type=exception.get("exception_type"),
+                    inference_service=exception.get("inference_service"),
+                    model=exception.get("model"),
+                    question_name=exception.get("question_name"),
+                    exception_count=exception.get("exception_count"),
+                )
+                job_run_exception_counters.append(exception_counter)
+            job_info.logger.add_info("exception_summary", job_run_exception_counters)
 
     def _handle_partially_failed_job(
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
         "Handles a partially failed job by logging the error and updating the job status."
-
-
-
+        error_report_url = remote_job_data.get("latest_job_run_details", {}).get(
+            "error_report_url"
+        )
 
-        if
-        job_info.logger.add_info("error_report_url",
+        if error_report_url:
+            job_info.logger.add_info("error_report_url", error_report_url)
 
         job_info.logger.update(
             "Job completed with partial results.", status=JobsStatus.PARTIALLY_FAILED
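_update_interview_details reads everything defensively with .get(), so older server responses simply produce zeros. The payload it expects would look roughly like this — key names are taken from the calls above, values are hypothetical:

    remote_job_data = {
        "latest_job_run_details": {
            "error_report_url": "https://example.com/error-report",  # hypothetical
            "interview_details": {
                "completed_interviews": 10,
                "interviews_with_exceptions": 2,
                "exception_summary": [
                    {"exception_type": "Timeout", "inference_service": "openai",
                     "model": "gpt-4o", "question_name": "q1", "exception_count": 2},
                ],
            },
        },
    }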
@@ -263,6 +279,128 @@ class JobsRemoteInferenceHandler:
             )
             time.sleep(self.poll_interval)
 
+    def _get_expenses_from_results(
+        self, results: "Results", include_cached_responses_in_cost: bool = False
+    ) -> dict:
+        """
+        Calculates expenses from Results object.
+
+        Args:
+            results: Results object containing model responses
+            include_cached_responses_in_cost: Whether to include cached responses in cost calculation
+
+        Returns:
+            Dictionary mapping ExpenseKey to TokenExpense information
+        """
+        expenses = {}
+
+        for result in results:
+            raw_response = result["raw_model_response"]
+
+            # Process each cost field in the response
+            for key in raw_response:
+                if not key.endswith("_cost"):
+                    continue
+
+                result_cost = raw_response[key]
+                if not isinstance(result_cost, (int, float)):
+                    continue
+
+                question_name = key.removesuffix("_cost")
+                cache_used = result["cache_used_dict"][question_name]
+
+                # Skip if we're excluding cached responses and this was cached
+                if not include_cached_responses_in_cost and cache_used:
+                    continue
+
+                # Get expense keys for input and output tokens
+                input_key = (
+                    result["model"]._inference_service_,
+                    result["model"].model,
+                    "input",
+                    raw_response[f"{question_name}_input_price_per_million_tokens"],
+                )
+                output_key = (
+                    result["model"]._inference_service_,
+                    result["model"].model,
+                    "output",
+                    raw_response[f"{question_name}_output_price_per_million_tokens"],
+                )
+
+                # Update input token expenses
+                if input_key not in expenses:
+                    expenses[input_key] = {
+                        "tokens": 0,
+                        "cost_usd": 0,
+                    }
+
+                input_price_per_million_tokens = input_key[3]
+                input_tokens = raw_response[f"{question_name}_input_tokens"]
+                input_cost = (input_price_per_million_tokens / 1_000_000) * input_tokens
+
+                expenses[input_key]["tokens"] += input_tokens
+                expenses[input_key]["cost_usd"] += input_cost
+
+                # Update output token expenses
+                if output_key not in expenses:
+                    expenses[output_key] = {
+                        "tokens": 0,
+                        "cost_usd": 0,
+                    }
+
+                output_price_per_million_tokens = output_key[3]
+                output_tokens = raw_response[f"{question_name}_output_tokens"]
+                output_cost = (
+                    output_price_per_million_tokens / 1_000_000
+                ) * output_tokens
+
+                expenses[output_key]["tokens"] += output_tokens
+                expenses[output_key]["cost_usd"] += output_cost
+
+        expenses_by_model = {}
+        for expense_key, expense_usage in expenses.items():
+            service, model, token_type, _ = expense_key
+            model_key = (service, model)
+
+            if model_key not in expenses_by_model:
+                expenses_by_model[model_key] = {
+                    "service": service,
+                    "model": model,
+                    "input_tokens": 0,
+                    "input_cost_usd": 0,
+                    "output_tokens": 0,
+                    "output_cost_usd": 0,
+                }
+
+            if token_type == "input":
+                expenses_by_model[model_key]["input_tokens"] += expense_usage["tokens"]
+                expenses_by_model[model_key]["input_cost_usd"] += expense_usage[
+                    "cost_usd"
+                ]
+            elif token_type == "output":
+                expenses_by_model[model_key]["output_tokens"] += expense_usage["tokens"]
+                expenses_by_model[model_key]["output_cost_usd"] += expense_usage[
+                    "cost_usd"
+                ]
+
+        converter = CostConverter()
+        for model_key, model_cost_dict in expenses_by_model.items():
+            input_cost = model_cost_dict["input_cost_usd"]
+            output_cost = model_cost_dict["output_cost_usd"]
+            model_cost_dict["input_cost_credits"] = converter.usd_to_credits(input_cost)
+            model_cost_dict["output_cost_credits"] = converter.usd_to_credits(
+                output_cost
+            )
+            # Convert back to USD (to get the rounded value)
+            model_cost_dict["input_cost_usd"] = converter.credits_to_usd(
+                model_cost_dict["input_cost_credits"]
+            )
+            model_cost_dict["output_cost_usd"] = converter.credits_to_usd(
+                model_cost_dict["output_cost_credits"]
+            )
+
+        return list(expenses_by_model.values())
+
     def _fetch_results_and_log(
         self,
         job_info: RemoteJobInfo,
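_get_expenses_from_results keys raw expenses by (service, model, token type, price-per-million), folds them into per-model totals, then rounds through credits so the reported USD matches what the credit ledger would charge. The core arithmetic on one question's response, with hypothetical numbers:

    # Fields read from raw_model_response per question (values hypothetical)
    input_tokens, input_price_per_million = 2_000, 2.5
    output_tokens, output_price_per_million = 1_500, 10.0

    input_cost = (input_price_per_million / 1_000_000) * input_tokens     # 0.005 USD
    output_cost = (output_price_per_million / 1_000_000) * output_tokens  # 0.015 USD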
@@ -274,12 +412,30 @@ class JobsRemoteInferenceHandler:
         "Fetches the results object and logs the results URL."
         job_info.logger.add_info("results_uuid", results_uuid)
         results = object_fetcher(results_uuid, expected_object_type="results")
+
+        model_cost_dicts = self._get_expenses_from_results(results)
+
+        model_costs = [
+            ModelCost(
+                service=model_cost_dict.get("service"),
+                model=model_cost_dict.get("model"),
+                input_tokens=model_cost_dict.get("input_tokens"),
+                input_cost_usd=model_cost_dict.get("input_cost_usd"),
+                output_tokens=model_cost_dict.get("output_tokens"),
+                output_cost_usd=model_cost_dict.get("output_cost_usd"),
+            )
+            for model_cost_dict in model_cost_dicts
+        ]
+        job_info.logger.add_info("model_costs", model_costs)
+
         results_url = remote_job_data.get("results_url")
         if "localhost" in results_url:
             results_url = results_url.replace("8000", "1234")
         job_info.logger.add_info("results_url", results_url)
 
         if job_status == "completed":
+            job_info.logger.add_info("completed_interviews", len(results))
+            job_info.logger.add_info("failed_interviews", 0)
             job_info.logger.update(
                 f"Job completed and Results stored on Coop. [View Results]({results_url})",
                 status=JobsStatus.COMPLETED,
@@ -302,6 +458,7 @@ class JobsRemoteInferenceHandler:
     ) -> Union[None, "Results", Literal["continue"]]:
         """Makes one attempt to fetch and process a remote job's status and results."""
         remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
+        self._update_interview_details(job_info, remote_job_data)
         status = remote_job_data.get("status")
         reason = remote_job_data.get("reason")
         if status == "cancelled":
edsl/language_models/language_model.py
CHANGED
@@ -1019,7 +1019,7 @@ class LanguageModel(
 
         # Combine model name and parameters
         return (
-            f"Model(model_name = '{self.model}'"
+            f"Model(model_name = '{self.model}', service_name = '{self._inference_service_}'"
             + (f", {param_string}" if param_string else "")
             + ")"
         )
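With the repr change, a printed model now reports its service as well, e.g. (hypothetical model and parameters):

    Model(model_name = 'gpt-4o', service_name = 'openai', temperature = 0.5)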