edsl 0.1.56__py3-none-any.whl → 0.1.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,16 +13,15 @@ if TYPE_CHECKING:
     from ..invigilators.invigilator_base import Invigilator
 
 from .fetch_invigilator import FetchInvigilator
+from ..coop.utils import CostConverter
 from ..caching import CacheEntry
 from ..dataset import Dataset
+from ..language_models.price_manager import PriceRetriever
 
 logger = logging.getLogger(__name__)
 
 
 class PromptCostEstimator:
-
-    DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS = 1.0
-    DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS = 1.0
     CHARS_PER_TOKEN = 4
     OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75
     PIPING_MULTIPLIER = 2
@@ -37,7 +36,7 @@ class PromptCostEstimator:
     ):
         self.system_prompt = system_prompt
         self.user_prompt = user_prompt
-        self.price_lookup = price_lookup
+        self.price_retriever = PriceRetriever(price_lookup)
         self.inference_service = inference_service
         self.model = model
 
@@ -49,91 +48,6 @@ class PromptCostEstimator:
             return PromptCostEstimator.PIPING_MULTIPLIER
         return 1
 
-    def _get_fallback_price(self, inference_service: str) -> Dict:
-        """
-        Get fallback prices for a service.
-        - First fallback: The highest input and output prices for that service from the price lookup.
-        - Second fallback: $1.00 per million tokens (for both input and output).
-
-        Args:
-            inference_service (str): The inference service name
-
-        Returns:
-            Dict: Price information
-        """
-        PriceEntry = namedtuple("PriceEntry", ["tokens_per_usd", "price_info"])
-
-        service_prices = [
-            prices
-            for (service, _), prices in self.price_lookup.items()
-            if service == inference_service
-        ]
-
-        default_input_price_info = {
-            "one_usd_buys": 1_000_000,
-            "service_stated_token_qty": 1_000_000,
-            "service_stated_token_price": self.DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS,
-        }
-        default_output_price_info = {
-            "one_usd_buys": 1_000_000,
-            "service_stated_token_qty": 1_000_000,
-            "service_stated_token_price": self.DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS,
-        }
-
-        # Find the most expensive price entries (lowest tokens per USD)
-        input_price_info = default_input_price_info
-        output_price_info = default_output_price_info
-
-        input_prices = [
-            PriceEntry(float(p["input"]["one_usd_buys"]), p["input"])
-            for p in service_prices
-            if "input" in p
-        ]
-        if input_prices:
-            input_price_info = min(
-                input_prices, key=lambda price: price.tokens_per_usd
-            ).price_info
-
-        output_prices = [
-            PriceEntry(float(p["output"]["one_usd_buys"]), p["output"])
-            for p in service_prices
-            if "output" in p
-        ]
-        if output_prices:
-            output_price_info = min(
-                output_prices, key=lambda price: price.tokens_per_usd
-            ).price_info
-
-        return {
-            "input": input_price_info,
-            "output": output_price_info,
-        }
-
-    def get_price(self, inference_service: str, model: str) -> Dict:
-        """Get the price information for a specific service and model."""
-        key = (inference_service, model)
-        return self.price_lookup.get(key) or self._get_fallback_price(inference_service)
-
-    def get_price_per_million_tokens(
-        self,
-        relevant_prices: Dict,
-        token_type: Literal["input", "output"],
-    ) -> Dict:
-        """
-        Get the price per million tokens for a specific service, model, and token type.
-        """
-        service_price = relevant_prices[token_type]["service_stated_token_price"]
-        service_qty = relevant_prices[token_type]["service_stated_token_qty"]
-
-        if service_qty == 1_000_000:
-            price_per_million_tokens = service_price
-        elif service_qty == 1_000:
-            price_per_million_tokens = service_price * 1_000
-        else:
-            price_per_token = service_price / service_qty
-            price_per_million_tokens = round(price_per_token * 1_000_000, 10)
-        return price_per_million_tokens
-
     def __call__(self):
         user_prompt_chars = len(str(self.user_prompt)) * self.get_piping_multiplier(
             str(self.user_prompt)
@@ -145,13 +59,15 @@ class PromptCostEstimator:
         input_tokens = (user_prompt_chars + system_prompt_chars) // self.CHARS_PER_TOKEN
         output_tokens = math.ceil(self.OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)
 
-        relevant_prices = self.get_price(self.inference_service, self.model)
+        relevant_prices = self.price_retriever.get_price(
+            self.inference_service, self.model
+        )
 
-        input_price_per_million_tokens = self.get_price_per_million_tokens(
-            relevant_prices, "input"
+        input_price_per_million_tokens = (
+            self.price_retriever.get_price_per_million_tokens(relevant_prices, "input")
         )
-        output_price_per_million_tokens = self.get_price_per_million_tokens(
-            relevant_prices, "output"
+        output_price_per_million_tokens = (
+            self.price_retriever.get_price_per_million_tokens(relevant_prices, "output")
         )
 
         input_price_per_token = input_price_per_million_tokens / 1_000_000
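For context on the hunks above: `PromptCostEstimator` never calls the model; it approximates cost from character counts using the class constants shown (4 characters per token, 0.75 output tokens per input token, a 2x multiplier for prompts with Jinja piping braces). A minimal sketch of that arithmetic, where the per-million-token prices are made-up stand-ins for whatever `PriceRetriever` actually resolves:

```python
# Sketch of PromptCostEstimator's estimation arithmetic (constants from the
# diff above; the $2.50/$10.00 prices are illustrative assumptions, not
# edsl's real price data).
import math

CHARS_PER_TOKEN = 4
OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75
PIPING_MULTIPLIER = 2  # applied when a prompt contains Jinja "{{ ... }}" braces

user_prompt = "What is your opinion of {{ topic }}?"
system_prompt = "You are answering a survey."

# Prompts with piping placeholders are counted twice, since the substituted
# text is unknown at estimation time.
piping = PIPING_MULTIPLIER if "{{" in user_prompt and "}}" in user_prompt else 1
user_chars = len(user_prompt) * piping
system_chars = len(system_prompt)

input_tokens = (user_chars + system_chars) // CHARS_PER_TOKEN
output_tokens = math.ceil(OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)

input_price_per_million, output_price_per_million = 2.50, 10.00  # assumed
cost_usd = (
    input_tokens * input_price_per_million / 1_000_000
    + output_tokens * output_price_per_million / 1_000_000
)
print(input_tokens, output_tokens, cost_usd)
```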
@@ -429,8 +345,14 @@ class JobsPrompts:
             group["cost_usd"] *= iterations
             detailed_costs.append(group)
 
+        # Convert to credits
+        converter = CostConverter()
+        for group in detailed_costs:
+            group["credits_hold"] = converter.usd_to_credits(group["cost_usd"])
+
         # Calculate totals
-        estimated_total_cost = sum(group["cost_usd"] for group in detailed_costs)
+        estimated_total_cost_usd = sum(group["cost_usd"] for group in detailed_costs)
+        total_credits_hold = sum(group["credits_hold"] for group in detailed_costs)
         estimated_total_input_tokens = sum(
             group["tokens"]
             for group in detailed_costs
@@ -443,7 +365,8 @@ class JobsPrompts:
         )
 
         output = {
-            "estimated_total_cost_usd": estimated_total_cost,
+            "estimated_total_cost_usd": estimated_total_cost_usd,
+            "total_credits_hold": total_credits_hold,
             "estimated_total_input_tokens": estimated_total_input_tokens,
             "estimated_total_output_tokens": estimated_total_output_tokens,
             "detailed_costs": detailed_costs,
@@ -23,15 +23,38 @@ class LogMessage:
     status: JobsStatus
 
 
+@dataclass
+class JobRunExceptionCounter:
+    exception_type: str = None
+    inference_service: str = None
+    model: str = None
+    question_name: str = None
+    exception_count: int = None
+
+
+@dataclass
+class ModelCost:
+    service: str = None
+    model: str = None
+    input_tokens: int = None
+    input_cost_usd: float = None
+    output_tokens: int = None
+    output_cost_usd: float = None
+
+
 @dataclass
 class JobsInfo:
     job_uuid: str = None
     progress_bar_url: str = None
     error_report_url: str = None
+    remote_inference_url: str = None
+    remote_cache_url: str = None
     results_uuid: str = None
     results_url: str = None
     completed_interviews: int = None
     failed_interviews: int = None
+    exception_summary: list[JobRunExceptionCounter] = None
+    model_costs: list[ModelCost] = None
 
     pretty_names = {
         "job_uuid": "Job UUID",
@@ -39,6 +62,8 @@ class JobsInfo:
         "error_report_url": "Exceptions Report URL",
         "results_uuid": "Results UUID",
         "results_url": "Results URL",
+        "remote_inference_url": "Remote Jobs",
+        "remote_cache_url": "Remote Cache",
     }
 
 
@@ -57,6 +82,8 @@ class JobLogger(ABC):
             "results_url",
             "completed_interviews",
             "failed_interviews",
+            "model_costs",
+            "exception_summary",
         ],
         value: str,
     ):
@@ -11,15 +11,7 @@ from uuid import UUID
 
 if TYPE_CHECKING:
     from .jobs import Jobs
-
-
-@dataclass
-class ModelInfo:
-    model_name: str
-    TPM_limit_k: float
-    RPM_limit_k: float
-    num_tasks_waiting: int
-    token_usage_info: dict
+    from ..interviews import Interview
 
 
 class StatisticsTracker:
@@ -29,16 +21,33 @@ class StatisticsTracker:
         self.completed_count = 0
         self.completed_by_model = defaultdict(int)
         self.distinct_models = distinct_models
+        self.interviews_with_exceptions = 0
         self.total_exceptions = 0
         self.unfixed_exceptions = 0
+        self.exceptions_counter = defaultdict(int)
 
     def add_completed_interview(
-        self, model: str, num_exceptions: int = 0, num_unfixed: int = 0
+        self,
+        model: str,
+        exceptions: list[dict],
+        num_exceptions: int = 0,
+        num_unfixed: int = 0,
     ):
         self.completed_count += 1
         self.completed_by_model[model] += 1
         self.total_exceptions += num_exceptions
         self.unfixed_exceptions += num_unfixed
+        if num_exceptions > 0:
+            self.interviews_with_exceptions += 1
+
+        for exception in exceptions:
+            key = (
+                exception["exception_type"],
+                exception["inference_service"],
+                exception["model"],
+                exception["question_name"],
+            )
+            self.exceptions_counter[key] += 1
 
     def get_elapsed_time(self) -> float:
         return time.time() - self.start_time
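The tracker now buckets exceptions by a four-part key rather than only counting them. A small self-contained sketch of that aggregation, assuming each exception dict carries the four keys the tracker indexes on (matching what `interview.exceptions.list()` supplies later in this diff):

```python
# Sketch of the new StatisticsTracker bookkeeping; the exception dicts are
# assumed to carry exactly the four keys the code above reads.
from collections import defaultdict

exceptions_counter = defaultdict(int)
exceptions = [
    {"exception_type": "Timeout", "inference_service": "openai",
     "model": "gpt-4o", "question_name": "q1"},
    {"exception_type": "Timeout", "inference_service": "openai",
     "model": "gpt-4o", "question_name": "q1"},
]
for exception in exceptions:
    key = (
        exception["exception_type"],
        exception["inference_service"],
        exception["model"],
        exception["question_name"],
    )
    exceptions_counter[key] += 1

# One row per (type, service, model, question), ready for the status payload:
print(dict(exceptions_counter))
# {('Timeout', 'openai', 'gpt-4o', 'q1'): 2}
```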
@@ -88,9 +97,7 @@ class JobsRunnerStatusBase(ABC):
         ]
         self.num_total_interviews = n * len(self.jobs)
 
-        self.distinct_models = list(
-            set(model.model for model in self.jobs.models)
-        )
+        self.distinct_models = list(set(model.model for model in self.jobs.models))
 
         self.stats_tracker = StatisticsTracker(
             total_interviews=self.num_total_interviews,
@@ -130,6 +137,7 @@ class JobsRunnerStatusBase(ABC):
         status_dict = {
             "overall_progress": {
                 "completed": self.stats_tracker.completed_count,
+                "has_exceptions": self.stats_tracker.interviews_with_exceptions,
                 "total": self.num_total_interviews,
                 "percent": (
                     (
@@ -148,16 +156,36 @@ class JobsRunnerStatusBase(ABC):
                 if self.stats_tracker.completed_count >= self.num_total_interviews
                 else "running"
             ),
+            "exceptions_counter": [
+                {
+                    "exception_type": exception_type,
+                    "inference_service": inference_service,
+                    "model": model,
+                    "question_name": question_name,
+                    "count": count,
+                }
+                for (
+                    exception_type,
+                    inference_service,
+                    model,
+                    question_name,
+                ), count in self.stats_tracker.exceptions_counter.items()
+            ],
         }
 
         model_queues = {}
         # Check if bucket collection exists and is not empty
-        if (hasattr(self.jobs, 'run_config') and
-            hasattr(self.jobs.run_config, 'environment') and
-            hasattr(self.jobs.run_config.environment, 'bucket_collection') and
-            self.jobs.run_config.environment.bucket_collection):
-
-            for model, bucket in self.jobs.run_config.environment.bucket_collection.items():
+        if (
+            hasattr(self.jobs, "run_config")
+            and hasattr(self.jobs.run_config, "environment")
+            and hasattr(self.jobs.run_config.environment, "bucket_collection")
+            and self.jobs.run_config.environment.bucket_collection
+        ):
+
+            for (
+                model,
+                bucket,
+            ) in self.jobs.run_config.environment.bucket_collection.items():
                 model_name = model.model
                 model_queues[model_name] = {
                     "language_model_name": model_name,
@@ -166,7 +194,9 @@ class JobsRunnerStatusBase(ABC):
                         "requested": bucket.requests_bucket.num_requests,
                         "tokens_returned": bucket.requests_bucket.tokens_returned,
                         "target_rate": round(bucket.requests_bucket.target_rate, 1),
-                        "current_rate": round(bucket.requests_bucket.get_throughput(), 1),
+                        "current_rate": round(
+                            bucket.requests_bucket.get_throughput(), 1
+                        ),
                     },
                     "tokens_bucket": {
                         "completed": bucket.tokens_bucket.num_released,
@@ -179,10 +209,11 @@ class JobsRunnerStatusBase(ABC):
         status_dict["language_model_queues"] = model_queues
         return status_dict
 
-    def add_completed_interview(self, interview):
+    def add_completed_interview(self, interview: "Interview"):
         """Records a completed interview without storing the full interview data."""
         self.stats_tracker.add_completed_interview(
             model=interview.model.model,
+            exceptions=interview.exceptions.list(),
             num_exceptions=interview.exceptions.num_exceptions(),
             num_unfixed=interview.exceptions.num_unfixed_exceptions(),
         )
@@ -1,11 +1,12 @@
 import re
+import math
 from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
 from dataclasses import dataclass
 from ..coop import CoopServerResponseError
-from ..coop.utils import VisibilityType
+from ..coop.utils import VisibilityType, CostConverter
 from ..coop.coop import RemoteInferenceResponse, RemoteInferenceCreationInfo
 from .jobs_status_enums import JobsStatus
-from .jobs_remote_inference_logger import JobLogger
+from .jobs_remote_inference_logger import JobLogger, JobRunExceptionCounter, ModelCost
 from .exceptions import RemoteInferenceError
 
 
@@ -94,6 +95,12 @@ class JobsRemoteInferenceHandler:
             "Remote inference activated. Sending job to server...",
             status=JobsStatus.QUEUED,
         )
+        logger.add_info(
+            "remote_inference_url", f"{self.expected_parrot_url}/home/remote-inference"
+        )
+        logger.add_info(
+            "remote_cache_url", f"{self.expected_parrot_url}/home/remote-cache"
+        )
         remote_job_creation_data = coop.remote_inference_create(
             self.jobs,
             description=remote_inference_description,
@@ -183,7 +190,9 @@ class JobsRemoteInferenceHandler:
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
         "Handles a failed job by logging the error and updating the job status."
-        latest_error_report_url = remote_job_data.get("latest_error_report_url")
+        error_report_url = remote_job_data.get("latest_job_run_details", {}).get(
+            "error_report_url"
+        )
 
         reason = remote_job_data.get("reason")
 
@@ -193,8 +202,8 @@ class JobsRemoteInferenceHandler:
             status=JobsStatus.FAILED,
         )
 
-        if latest_error_report_url:
-            job_info.logger.add_info("error_report_url", latest_error_report_url)
+        if error_report_url:
+            job_info.logger.add_info("error_report_url", error_report_url)
 
         job_info.logger.update("Job failed.", status=JobsStatus.FAILED)
         job_info.logger.update(
@@ -206,39 +215,46 @@ class JobsRemoteInferenceHandler:
             status=JobsStatus.FAILED,
         )
 
-    def _handle_partially_failed_job_interview_details(
+    def _update_interview_details(
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
-        "Extracts the interview details from the remote job data."
-        try:
-            # Job details is a string of the form "64 out of 1,758 interviews failed"
-            job_details = remote_job_data.get("latest_failure_description")
-
-            text_without_commas = job_details.replace(",", "")
-
-            # Find all numbers in the text
-            numbers = [int(num) for num in re.findall(r"\d+", text_without_commas)]
-
-            failed = numbers[0]
-            total = numbers[1]
-            completed = total - failed
-
-            job_info.logger.add_info("completed_interviews", completed)
-            job_info.logger.add_info("failed_interviews", failed)
-        # This is mainly helpful metadata, and any errors here should not stop the code
-        except:
-            pass
+        "Updates the interview details in the job info."
+        latest_job_run_details = remote_job_data.get("latest_job_run_details", {})
+        interview_details = latest_job_run_details.get("interview_details", {}) or {}
+        completed_interviews = interview_details.get("completed_interviews", 0)
+        interviews_with_exceptions = interview_details.get(
+            "interviews_with_exceptions", 0
+        )
+        interviews_without_exceptions = (
+            completed_interviews - interviews_with_exceptions
+        )
+        job_info.logger.add_info("completed_interviews", interviews_without_exceptions)
+        job_info.logger.add_info("failed_interviews", interviews_with_exceptions)
+
+        exception_summary = interview_details.get("exception_summary", []) or []
+        if exception_summary:
+            job_run_exception_counters = []
+            for exception in exception_summary:
+                exception_counter = JobRunExceptionCounter(
+                    exception_type=exception.get("exception_type"),
+                    inference_service=exception.get("inference_service"),
+                    model=exception.get("model"),
+                    question_name=exception.get("question_name"),
+                    exception_count=exception.get("exception_count"),
+                )
+                job_run_exception_counters.append(exception_counter)
+            job_info.logger.add_info("exception_summary", job_run_exception_counters)
 
     def _handle_partially_failed_job(
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
         "Handles a partially failed job by logging the error and updating the job status."
-        self._handle_partially_failed_job_interview_details(job_info, remote_job_data)
-
-        latest_error_report_url = remote_job_data.get("latest_error_report_url")
+        error_report_url = remote_job_data.get("latest_job_run_details", {}).get(
+            "error_report_url"
+        )
 
-        if latest_error_report_url:
-            job_info.logger.add_info("error_report_url", latest_error_report_url)
+        if error_report_url:
+            job_info.logger.add_info("error_report_url", error_report_url)
 
         job_info.logger.update(
             "Job completed with partial results.", status=JobsStatus.PARTIALLY_FAILED
@@ -263,6 +279,128 @@ class JobsRemoteInferenceHandler:
         )
         time.sleep(self.poll_interval)
 
+    def _get_expenses_from_results(
+        self, results: "Results", include_cached_responses_in_cost: bool = False
+    ) -> list:
+        """
+        Calculates expenses from a Results object.
+
+        Args:
+            results: Results object containing model responses
+            include_cached_responses_in_cost: Whether to include cached responses in cost calculation
+
+        Returns:
+            List of per-model expense dicts with token counts and USD/credit costs
+        """
+        expenses = {}
+
+        for result in results:
+            raw_response = result["raw_model_response"]
+
+            # Process each cost field in the response
+            for key in raw_response:
+                if not key.endswith("_cost"):
+                    continue
+
+                result_cost = raw_response[key]
+                if not isinstance(result_cost, (int, float)):
+                    continue
+
+                question_name = key.removesuffix("_cost")
+                cache_used = result["cache_used_dict"][question_name]
+
+                # Skip if we're excluding cached responses and this was cached
+                if not include_cached_responses_in_cost and cache_used:
+                    continue
+
+                # Get expense keys for input and output tokens
+                input_key = (
+                    result["model"]._inference_service_,
+                    result["model"].model,
+                    "input",
+                    raw_response[f"{question_name}_input_price_per_million_tokens"],
+                )
+                output_key = (
+                    result["model"]._inference_service_,
+                    result["model"].model,
+                    "output",
+                    raw_response[f"{question_name}_output_price_per_million_tokens"],
+                )
+
+                # Update input token expenses
+                if input_key not in expenses:
+                    expenses[input_key] = {
+                        "tokens": 0,
+                        "cost_usd": 0,
+                    }
+
+                input_price_per_million_tokens = input_key[3]
+                input_tokens = raw_response[f"{question_name}_input_tokens"]
+                input_cost = (input_price_per_million_tokens / 1_000_000) * input_tokens
+
+                expenses[input_key]["tokens"] += input_tokens
+                expenses[input_key]["cost_usd"] += input_cost
+
+                # Update output token expenses
+                if output_key not in expenses:
+                    expenses[output_key] = {
+                        "tokens": 0,
+                        "cost_usd": 0,
+                    }
+
+                output_price_per_million_tokens = output_key[3]
+                output_tokens = raw_response[f"{question_name}_output_tokens"]
+                output_cost = (
+                    output_price_per_million_tokens / 1_000_000
+                ) * output_tokens
+
+                expenses[output_key]["tokens"] += output_tokens
+                expenses[output_key]["cost_usd"] += output_cost
+
+        expenses_by_model = {}
+        for expense_key, expense_usage in expenses.items():
+            service, model, token_type, _ = expense_key
+            model_key = (service, model)
+
+            if model_key not in expenses_by_model:
+                expenses_by_model[model_key] = {
+                    "service": service,
+                    "model": model,
+                    "input_tokens": 0,
+                    "input_cost_usd": 0,
+                    "output_tokens": 0,
+                    "output_cost_usd": 0,
+                }
+
+            if token_type == "input":
+                expenses_by_model[model_key]["input_tokens"] += expense_usage["tokens"]
+                expenses_by_model[model_key]["input_cost_usd"] += expense_usage[
+                    "cost_usd"
+                ]
+            elif token_type == "output":
+                expenses_by_model[model_key]["output_tokens"] += expense_usage["tokens"]
+                expenses_by_model[model_key]["output_cost_usd"] += expense_usage[
+                    "cost_usd"
+                ]
+
+        converter = CostConverter()
+        for model_key, model_cost_dict in expenses_by_model.items():
+            input_cost = model_cost_dict["input_cost_usd"]
+            output_cost = model_cost_dict["output_cost_usd"]
+            model_cost_dict["input_cost_credits"] = converter.usd_to_credits(input_cost)
+            model_cost_dict["output_cost_credits"] = converter.usd_to_credits(
+                output_cost
+            )
+            # Convert back to USD (to get the rounded value)
+            model_cost_dict["input_cost_usd"] = converter.credits_to_usd(
+                model_cost_dict["input_cost_credits"]
+            )
+            model_cost_dict["output_cost_usd"] = converter.credits_to_usd(
+                model_cost_dict["output_cost_credits"]
+            )
+
+        return list(expenses_by_model.values())
+
     def _fetch_results_and_log(
         self,
         job_info: RemoteJobInfo,
@@ -274,12 +412,30 @@ class JobsRemoteInferenceHandler:
         "Fetches the results object and logs the results URL."
         job_info.logger.add_info("results_uuid", results_uuid)
         results = object_fetcher(results_uuid, expected_object_type="results")
+
+        model_cost_dicts = self._get_expenses_from_results(results)
+
+        model_costs = [
+            ModelCost(
+                service=model_cost_dict.get("service"),
+                model=model_cost_dict.get("model"),
+                input_tokens=model_cost_dict.get("input_tokens"),
+                input_cost_usd=model_cost_dict.get("input_cost_usd"),
+                output_tokens=model_cost_dict.get("output_tokens"),
+                output_cost_usd=model_cost_dict.get("output_cost_usd"),
+            )
+            for model_cost_dict in model_cost_dicts
+        ]
+        job_info.logger.add_info("model_costs", model_costs)
+
         results_url = remote_job_data.get("results_url")
         if "localhost" in results_url:
             results_url = results_url.replace("8000", "1234")
         job_info.logger.add_info("results_url", results_url)
 
         if job_status == "completed":
+            job_info.logger.add_info("completed_interviews", len(results))
+            job_info.logger.add_info("failed_interviews", 0)
             job_info.logger.update(
                 f"Job completed and Results stored on Coop. [View Results]({results_url})",
                 status=JobsStatus.COMPLETED,
@@ -302,6 +458,7 @@ class JobsRemoteInferenceHandler:
     ) -> Union[None, "Results", Literal["continue"]]:
         """Makes one attempt to fetch and process a remote job's status and results."""
         remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
+        self._update_interview_details(job_info, remote_job_data)
         status = remote_job_data.get("status")
         reason = remote_job_data.get("reason")
         if status == "cancelled":
@@ -1019,7 +1019,7 @@ class LanguageModel(
 
         # Combine model name and parameters
         return (
-            f"Model(model_name = '{self.model}'"
+            f"Model(model_name = '{self.model}', service_name = '{self._inference_service_}'"
             + (f", {param_string}" if param_string else "")
             + ")"
         )
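Net effect of this last hunk: the `LanguageModel` repr now names the inference service alongside the model. Illustratively (model name, service, and parameter formatting are example values, not output reproduced from the package):

```python
# Before (0.1.56):  Model(model_name = 'gpt-4o', temperature = 0.5)
# After  (0.1.58):  Model(model_name = 'gpt-4o', service_name = 'openai', temperature = 0.5)
```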