edsl 0.1.57__py3-none-any.whl → 0.1.59__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,11 +1,12 @@
  import re
+ import math
  from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
  from dataclasses import dataclass
  from ..coop import CoopServerResponseError
- from ..coop.utils import VisibilityType
+ from ..coop.utils import VisibilityType, CostConverter
  from ..coop.coop import RemoteInferenceResponse, RemoteInferenceCreationInfo
  from .jobs_status_enums import JobsStatus
- from .jobs_remote_inference_logger import JobLogger
+ from .jobs_remote_inference_logger import JobLogger, JobRunExceptionCounter, ModelCost
  from .exceptions import RemoteInferenceError
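The three new names imported from `.jobs_remote_inference_logger` are consumed further down in this diff. As a reading aid, here is a minimal sketch of what the two payload types plausibly look like — the field names are taken from the constructor calls below, but the dataclass form, types, and defaults are assumptions, not the package source:

```python
# Editor's sketch only: fields inferred from the JobRunExceptionCounter(...)
# and ModelCost(...) calls later in this diff; types and defaults are assumed.
from dataclasses import dataclass
from typing import Optional


@dataclass
class JobRunExceptionCounter:
    exception_type: Optional[str] = None
    inference_service: Optional[str] = None
    model: Optional[str] = None
    question_name: Optional[str] = None
    exception_count: Optional[int] = None


@dataclass
class ModelCost:
    service: Optional[str] = None
    model: Optional[str] = None
    input_tokens: Optional[int] = None
    input_cost_usd: Optional[float] = None
    output_tokens: Optional[int] = None
    output_cost_usd: Optional[float] = None
    input_cost_credits_with_cache: Optional[float] = None
    output_cost_credits_with_cache: Optional[float] = None
```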
@@ -94,6 +95,12 @@ class JobsRemoteInferenceHandler:
  "Remote inference activated. Sending job to server...",
  status=JobsStatus.QUEUED,
  )
+ logger.add_info(
+ "remote_inference_url", f"{self.expected_parrot_url}/home/remote-inference"
+ )
+ logger.add_info(
+ "remote_cache_url", f"{self.expected_parrot_url}/home/remote-cache"
+ )
  remote_job_creation_data = coop.remote_inference_create(
  self.jobs,
  description=remote_inference_description,
@@ -183,7 +190,9 @@ class JobsRemoteInferenceHandler:
  self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
  ) -> None:
  "Handles a failed job by logging the error and updating the job status."
- latest_error_report_url = remote_job_data.get("latest_error_report_url")
+ error_report_url = remote_job_data.get("latest_job_run_details", {}).get(
+ "error_report_url"
+ )

  reason = remote_job_data.get("reason")
@@ -193,8 +202,8 @@ class JobsRemoteInferenceHandler:
  status=JobsStatus.FAILED,
  )

- if latest_error_report_url:
- job_info.logger.add_info("error_report_url", latest_error_report_url)
+ if error_report_url:
+ job_info.logger.add_info("error_report_url", error_report_url)

  job_info.logger.update("Job failed.", status=JobsStatus.FAILED)
  job_info.logger.update(
@@ -206,39 +215,46 @@ class JobsRemoteInferenceHandler:
  status=JobsStatus.FAILED,
  )

- def _handle_partially_failed_job_interview_details(
+ def _update_interview_details(
  self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
  ) -> None:
- "Extracts the interview details from the remote job data."
- try:
- # Job details is a string of the form "64 out of 1,758 interviews failed"
- job_details = remote_job_data.get("latest_failure_description")
-
- text_without_commas = job_details.replace(",", "")
-
- # Find all numbers in the text
- numbers = [int(num) for num in re.findall(r"\d+", text_without_commas)]
-
- failed = numbers[0]
- total = numbers[1]
- completed = total - failed
-
- job_info.logger.add_info("completed_interviews", completed)
- job_info.logger.add_info("failed_interviews", failed)
- # This is mainly helpful metadata, and any errors here should not stop the code
- except:
- pass
+ "Updates the interview details in the job info."
+ latest_job_run_details = remote_job_data.get("latest_job_run_details", {})
+ interview_details = latest_job_run_details.get("interview_details", {}) or {}
+ completed_interviews = interview_details.get("completed_interviews", 0)
+ interviews_with_exceptions = interview_details.get(
+ "interviews_with_exceptions", 0
+ )
+ interviews_without_exceptions = (
+ completed_interviews - interviews_with_exceptions
+ )
+ job_info.logger.add_info("completed_interviews", interviews_without_exceptions)
+ job_info.logger.add_info("failed_interviews", interviews_with_exceptions)
+
+ exception_summary = interview_details.get("exception_summary", []) or []
+ if exception_summary:
+ job_run_exception_counters = []
+ for exception in exception_summary:
+ exception_counter = JobRunExceptionCounter(
+ exception_type=exception.get("exception_type"),
+ inference_service=exception.get("inference_service"),
+ model=exception.get("model"),
+ question_name=exception.get("question_name"),
+ exception_count=exception.get("exception_count"),
+ )
+ job_run_exception_counters.append(exception_counter)
+ job_info.logger.add_info("exception_summary", job_run_exception_counters)

  def _handle_partially_failed_job(
  self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
  ) -> None:
  "Handles a partially failed job by logging the error and updating the job status."
- self._handle_partially_failed_job_interview_details(job_info, remote_job_data)
-
- latest_error_report_url = remote_job_data.get("latest_error_report_url")
+ error_report_url = remote_job_data.get("latest_job_run_details", {}).get(
+ "error_report_url"
+ )

- if latest_error_report_url:
- job_info.logger.add_info("error_report_url", latest_error_report_url)
+ if error_report_url:
+ job_info.logger.add_info("error_report_url", error_report_url)

  job_info.logger.update(
  "Job completed with partial results.", status=JobsStatus.PARTIALLY_FAILED
@@ -263,6 +279,149 @@ class JobsRemoteInferenceHandler:
  )
  time.sleep(self.poll_interval)

+ def _get_expenses_from_results(self, results: "Results") -> list:
+ """
+ Calculates expenses from a Results object.
+
+ Args:
+ results: Results object containing model responses
+
+ Returns:
+ A list of per-model cost dictionaries, one per (service, model) pair
+ """
+ expenses = {}
+
+ for result in results:
+ raw_response = result["raw_model_response"]
+
+ # Process each cost field in the response
+ for key in raw_response:
+ if not key.endswith("_cost"):
+ continue
+
+ result_cost = raw_response[key]
+ if not isinstance(result_cost, (int, float)):
+ continue
+
+ question_name = key.removesuffix("_cost")
+ cache_used = result["cache_used_dict"][question_name]
+
+ # Get expense keys for input and output tokens
+ input_key = (
+ result["model"]._inference_service_,
+ result["model"].model,
+ "input",
+ raw_response[f"{question_name}_input_price_per_million_tokens"],
+ )
+ output_key = (
+ result["model"]._inference_service_,
+ result["model"].model,
+ "output",
+ raw_response[f"{question_name}_output_price_per_million_tokens"],
+ )
+
+ # Update input token expenses
+ if input_key not in expenses:
+ expenses[input_key] = {
+ "tokens": 0,
+ "cost_usd": 0,
+ "cost_usd_with_cache": 0,
+ }
+
+ input_price_per_million_tokens = input_key[3]
+ input_tokens = raw_response[f"{question_name}_input_tokens"]
+ input_cost = (input_price_per_million_tokens / 1_000_000) * input_tokens
+
+ expenses[input_key]["tokens"] += input_tokens
+ expenses[input_key]["cost_usd"] += input_cost
+
+ if not cache_used:
+ expenses[input_key]["cost_usd_with_cache"] += input_cost
+
+ # Update output token expenses
+ if output_key not in expenses:
+ expenses[output_key] = {
+ "tokens": 0,
+ "cost_usd": 0,
+ "cost_usd_with_cache": 0,
+ }
+
+ output_price_per_million_tokens = output_key[3]
+ output_tokens = raw_response[f"{question_name}_output_tokens"]
+ output_cost = (
+ output_price_per_million_tokens / 1_000_000
+ ) * output_tokens
+
+ expenses[output_key]["tokens"] += output_tokens
+ expenses[output_key]["cost_usd"] += output_cost
+
+ if not cache_used:
+ expenses[output_key]["cost_usd_with_cache"] += output_cost
+
+ expenses_by_model = {}
+ for expense_key, expense_usage in expenses.items():
+ service, model, token_type, _ = expense_key
+ model_key = (service, model)
+
+ if model_key not in expenses_by_model:
+ expenses_by_model[model_key] = {
+ "service": service,
+ "model": model,
+ "input_tokens": 0,
+ "input_cost_usd": 0,
+ "input_cost_usd_with_cache": 0,
+ "output_tokens": 0,
+ "output_cost_usd": 0,
+ "output_cost_usd_with_cache": 0,
+ }
+
+ if token_type == "input":
+ expenses_by_model[model_key]["input_tokens"] += expense_usage["tokens"]
+ expenses_by_model[model_key]["input_cost_usd"] += expense_usage[
+ "cost_usd"
+ ]
+ expenses_by_model[model_key][
+ "input_cost_usd_with_cache"
+ ] += expense_usage["cost_usd_with_cache"]
+ elif token_type == "output":
+ expenses_by_model[model_key]["output_tokens"] += expense_usage["tokens"]
+ expenses_by_model[model_key]["output_cost_usd"] += expense_usage[
+ "cost_usd"
+ ]
+ expenses_by_model[model_key][
+ "output_cost_usd_with_cache"
+ ] += expense_usage["cost_usd_with_cache"]
+
+ converter = CostConverter()
+ for model_key, model_cost_dict in expenses_by_model.items():
+
+ # Handle full cost (without cache)
+ input_cost = model_cost_dict["input_cost_usd"]
+ output_cost = model_cost_dict["output_cost_usd"]
+ model_cost_dict["input_cost_credits"] = converter.usd_to_credits(input_cost)
+ model_cost_dict["output_cost_credits"] = converter.usd_to_credits(
+ output_cost
+ )
+ # Convert back to USD (to get the rounded value)
+ model_cost_dict["input_cost_usd"] = converter.credits_to_usd(
+ model_cost_dict["input_cost_credits"]
+ )
+ model_cost_dict["output_cost_usd"] = converter.credits_to_usd(
+ model_cost_dict["output_cost_credits"]
+ )
+
+ # Handle cost with cache
+ input_cost_with_cache = model_cost_dict["input_cost_usd_with_cache"]
+ output_cost_with_cache = model_cost_dict["output_cost_usd_with_cache"]
+ model_cost_dict["input_cost_credits_with_cache"] = converter.usd_to_credits(
+ input_cost_with_cache
+ )
+ model_cost_dict["output_cost_credits_with_cache"] = (
+ converter.usd_to_credits(output_cost_with_cache)
+ )
+ return list(expenses_by_model.values())
+
  def _fetch_results_and_log(
  self,
  job_info: RemoteJobInfo,
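The token accounting above is linear: each response contributes `(price_per_million / 1_000_000) * tokens` to `cost_usd`, and contributes to `cost_usd_with_cache` only when the response was not served from cache. A tiny worked example of that formula, with invented numbers:

```python
import math

# Same arithmetic as _get_expenses_from_results, with invented inputs.
input_price_per_million_tokens = 2.50  # invented price
input_tokens = 12_000

input_cost = (input_price_per_million_tokens / 1_000_000) * input_tokens
assert math.isclose(input_cost, 0.03)  # 12k tokens at $2.50/M tokens

# A cached response still adds to cost_usd but not to cost_usd_with_cache,
# so cost_usd_with_cache reflects what the run actually spent given cache hits.
```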
@@ -274,12 +433,36 @@ class JobsRemoteInferenceHandler:
  "Fetches the results object and logs the results URL."
  job_info.logger.add_info("results_uuid", results_uuid)
  results = object_fetcher(results_uuid, expected_object_type="results")
+
+ model_cost_dicts = self._get_expenses_from_results(results)
+
+ model_costs = [
+ ModelCost(
+ service=model_cost_dict.get("service"),
+ model=model_cost_dict.get("model"),
+ input_tokens=model_cost_dict.get("input_tokens"),
+ input_cost_usd=model_cost_dict.get("input_cost_usd"),
+ output_tokens=model_cost_dict.get("output_tokens"),
+ output_cost_usd=model_cost_dict.get("output_cost_usd"),
+ input_cost_credits_with_cache=model_cost_dict.get(
+ "input_cost_credits_with_cache"
+ ),
+ output_cost_credits_with_cache=model_cost_dict.get(
+ "output_cost_credits_with_cache"
+ ),
+ )
+ for model_cost_dict in model_cost_dicts
+ ]
+ job_info.logger.add_info("model_costs", model_costs)
+
  results_url = remote_job_data.get("results_url")
  if "localhost" in results_url:
  results_url = results_url.replace("8000", "1234")
  job_info.logger.add_info("results_url", results_url)

  if job_status == "completed":
+ job_info.logger.add_info("completed_interviews", len(results))
+ job_info.logger.add_info("failed_interviews", 0)
  job_info.logger.update(
  f"Job completed and Results stored on Coop. [View Results]({results_url})",
  status=JobsStatus.COMPLETED,
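One subtlety in the `results_url` rewrite above: `str.replace("8000", "1234")` substitutes every occurrence of the substring, not just the port. A standalone illustration (URLs invented):

```python
# Intended case: remap the local dev server port.
url = "http://localhost:8000/results/abc"
assert url.replace("8000", "1234") == "http://localhost:1234/results/abc"

# Edge case: the same substring elsewhere in the URL is rewritten too.
url = "http://localhost:8000/results/id-8000"
assert url.replace("8000", "1234") == "http://localhost:1234/results/id-1234"
```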
@@ -302,6 +485,7 @@ class JobsRemoteInferenceHandler:
  ) -> Union[None, "Results", Literal["continue"]]:
  """Makes one attempt to fetch and process a remote job's status and results."""
  remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
+ self._update_interview_details(job_info, remote_job_data)
  status = remote_job_data.get("status")
  reason = remote_job_data.get("reason")
  if status == "cancelled":
@@ -769,8 +769,45 @@ class LanguageModel(
  params["question_name"] = invigilator.question.question_name
  # Get timeout from configuration
  from ..config import CONFIG
-
- TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
+ import logging
+
+ logger = logging.getLogger(__name__)
+ base_timeout = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+ # Adjust timeout if files are present
+ import time
+
+ start = time.time()
+ if files_list:
+ # Calculate total size of attached files in MB
+ file_sizes = []
+ for file in files_list:
+ # Try different attributes that might contain the file content
+ if hasattr(file, "base64_string") and file.base64_string:
+ file_sizes.append(len(file.base64_string) / (1024 * 1024))
+ elif hasattr(file, "content") and file.content:
+ file_sizes.append(len(file.content) / (1024 * 1024))
+ elif hasattr(file, "data") and file.data:
+ file_sizes.append(len(file.data) / (1024 * 1024))
+ else:
+ # Default minimum size if we can't determine actual size
+ file_sizes.append(1)  # Assume at least 1MB
+ total_size_mb = sum(file_sizes)
+
+ # Increase timeout proportionally to file size
+ # For each MB of file size, add 10 seconds to the timeout (adjust as needed)
+ size_adjustment = total_size_mb * 10
+
+ # Cap the maximum timeout adjustment at 5 minutes (300 seconds)
+ size_adjustment = min(size_adjustment, 300)
+
+ TIMEOUT = base_timeout + size_adjustment
+
+ logger.info(
+ f"Adjusted timeout for API call with {len(files_list)} files (total size: {total_size_mb:.2f}MB). Base timeout: {base_timeout}s, New timeout: {TIMEOUT}s"
+ )
+ else:
+ TIMEOUT = base_timeout

  # Execute the model call with timeout
  response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
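The branch above reduces to a capped linear formula: `TIMEOUT = base_timeout + min(10 * total_size_mb, 300)`. A worked example, assuming an invented 60-second `EDSL_API_TIMEOUT`:

```python
# Same logic as above: +10 s per MB of attachments, capped at 300 s extra.
base_timeout = 60.0   # invented EDSL_API_TIMEOUT value
total_size_mb = 45.0  # invented total attachment size

size_adjustment = min(total_size_mb * 10, 300)  # 450 -> capped at 300
timeout = base_timeout + size_adjustment
assert timeout == 360.0
```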
@@ -1019,7 +1056,7 @@ class LanguageModel(

  # Combine model name and parameters
  return (
- f"Model(model_name = '{self.model}'"
+ f"Model(model_name = '{self.model}', service_name = '{self._inference_service_}'"
  + (f", {param_string}" if param_string else "")
  + ")"
  )
@@ -19,56 +19,12 @@ class ResponseCost:
  total_cost: Union[float, str, None] = None


- class PriceManager:
- _instance = None
- _price_lookup: Dict[Tuple[str, str], Dict] = {}
- _is_initialized = False
-
- def __new__(cls):
- if cls._instance is None:
- instance = super(PriceManager, cls).__new__(cls)
- instance._price_lookup = {}  # Instance-specific attribute
- instance._is_initialized = False
- cls._instance = instance  # Store the instance directly
- return instance
- return cls._instance
-
- def __init__(self):
- """Initialize the singleton instance only once."""
- if not self._is_initialized:
- self._is_initialized = True
- self.refresh_prices()
-
- @classmethod
- def get_instance(cls):
- """Get the singleton instance, creating it if necessary."""
- if cls._instance is None:
- cls()  # Create the instance if it doesn't exist
- return cls._instance
-
- @classmethod
- def reset(cls):
- """Reset the singleton instance to clean up resources."""
- cls._instance = None
- cls._is_initialized = False
- cls._price_lookup = {}
-
- def __del__(self):
- """Ensure proper cleanup when the instance is garbage collected."""
- try:
- self._price_lookup = {}  # Clean up resources
- except:
- pass  # Ignore any cleanup errors
+ class PriceRetriever:
+ DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS = 1.0
+ DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS = 1.0

- def refresh_prices(self) -> None:
- """Fetch fresh prices and update the internal price lookup."""
- from edsl.coop import Coop
-
- c = Coop()
- try:
- self._price_lookup = c.fetch_prices()
- except Exception as e:
- print(f"Error fetching prices: {str(e)}")
+ def __init__(self, price_lookup: Dict[Tuple[str, str], Dict]):
+ self._price_lookup = price_lookup

  def get_price(self, inference_service: str, model: str) -> Dict:
  """Get the price information for a specific service and model."""
@@ -77,10 +33,6 @@ class PriceManager:
  inference_service
  )

- def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
- """Get the complete price lookup dictionary."""
- return self._price_lookup.copy()
-
  def _get_fallback_price(self, inference_service: str) -> Dict:
  """
  Get fallback prices for a service.
@@ -101,15 +53,21 @@ class PriceManager:
  if service == inference_service
  ]

- default_price_info = {
+ default_input_price_info = {
+ "one_usd_buys": 1_000_000,
+ "service_stated_token_qty": 1_000_000,
+ "service_stated_token_price": self.DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS,
+ }
+
+ default_output_price_info = {
  "one_usd_buys": 1_000_000,
  "service_stated_token_qty": 1_000_000,
- "service_stated_token_price": 1.0,
+ "service_stated_token_price": self.DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS,
  }

  # Find the most expensive price entries (lowest tokens per USD)
- input_price_info = default_price_info
- output_price_info = default_price_info
+ input_price_info = default_input_price_info
+ output_price_info = default_output_price_info

  input_prices = [
  PriceEntry(float(p["input"]["one_usd_buys"]), p["input"])
@@ -156,6 +114,82 @@ class PriceManager:
  price_per_million_tokens = round(price_per_token * 1_000_000, 10)
  return price_per_million_tokens

+
+ class PriceManager:
+ _instance = None
+ _price_lookup: Dict[Tuple[str, str], Dict] = {}
+ _is_initialized = False
+
+ def __new__(cls):
+ if cls._instance is None:
+ instance = super(PriceManager, cls).__new__(cls)
+ instance._price_lookup = {}  # Instance-specific attribute
+ instance._is_initialized = False
+ cls._instance = instance  # Store the instance directly
+ return instance
+ return cls._instance
+
+ def __init__(self):
+ """Initialize the singleton instance only once."""
+ if not self._is_initialized:
+ self._is_initialized = True
+ self.refresh_prices()
+
+ @classmethod
+ def get_instance(cls):
+ """Get the singleton instance, creating it if necessary."""
+ if cls._instance is None:
+ cls()  # Create the instance if it doesn't exist
+ return cls._instance
+
+ @classmethod
+ def reset(cls):
+ """Reset the singleton instance to clean up resources."""
+ cls._instance = None
+ cls._is_initialized = False
+ cls._price_lookup = {}
+
+ def __del__(self):
+ """Ensure proper cleanup when the instance is garbage collected."""
+ try:
+ self._price_lookup = {}  # Clean up resources
+ except:
+ pass  # Ignore any cleanup errors
+
+ @property
+ def price_retriever(self):
+ return PriceRetriever(self._price_lookup)
+
+ def refresh_prices(self) -> None:
+ """Fetch fresh prices and update the internal price lookup."""
+ from edsl.coop import Coop
+
+ c = Coop()
+ try:
+ self._price_lookup = c.fetch_prices()
+ except Exception as e:
+ print(f"Error fetching prices: {str(e)}")
+
+ def get_price(self, inference_service: str, model: str) -> Dict:
+ """Get the price information for a specific service and model."""
+ return self.price_retriever.get_price(inference_service, model)
+
+ def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
+ """Get the complete price lookup dictionary."""
+ return self._price_lookup.copy()
+
+ def get_price_per_million_tokens(
+ self,
+ relevant_prices: Dict,
+ token_type: Literal["input", "output"],
+ ) -> Dict:
+ """
+ Get the price per million tokens for a specific service, model, and token type.
+ """
+ return self.price_retriever.get_price_per_million_tokens(
+ relevant_prices, token_type
+ )
+
  def _calculate_total_cost(
  self,
  relevant_prices: Dict,
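The re-added `PriceManager` keeps its singleton contract while delegating lookups to the new stateless `PriceRetriever`. A minimal sketch of what that contract implies for callers, using only methods visible in this diff (note that constructing the first instance triggers `refresh_prices()`, which fetches prices over the network):

```python
# Sketch: construction and get_instance() return the same cached object.
a = PriceManager()               # first call: __init__ runs refresh_prices()
b = PriceManager.get_instance()
assert a is b

PriceManager.reset()             # drops the cached instance and price table
c = PriceManager()               # a fresh singleton is created (and refreshed)
assert c is not a
```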
edsl/prompts/prompt.py CHANGED
@@ -290,6 +290,7 @@ class Prompt(PersistenceMixin, RepresentationMixin):
  return result
  except Exception as e:
  print(f"Error rendering prompt: {e}")
+ raise e
  return self

  @staticmethod
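After this change the trailing `return self` is unreachable: `raise e` re-raises before the fallback return can run. A standalone illustration of the control flow:

```python
def render_or_raise():
    try:
        return 1 / 0                      # raises ZeroDivisionError
    except Exception as e:
        print(f"Error rendering prompt: {e}")
        raise e                           # control leaves the function here...
    return "fallback"                     # ...so this line can never execute
```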