edsl 0.1.54__py3-none-any.whl → 0.1.56__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (105)
  1. edsl/__init__.py +8 -1
  2. edsl/__init__original.py +134 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/agent.py +29 -0
  5. edsl/agents/agent_list.py +36 -1
  6. edsl/base/base_class.py +281 -151
  7. edsl/base/data_transfer_models.py +15 -4
  8. edsl/buckets/__init__.py +8 -3
  9. edsl/buckets/bucket_collection.py +9 -3
  10. edsl/buckets/model_buckets.py +4 -2
  11. edsl/buckets/token_bucket.py +2 -2
  12. edsl/buckets/token_bucket_client.py +5 -3
  13. edsl/caching/cache.py +131 -62
  14. edsl/caching/cache_entry.py +70 -58
  15. edsl/caching/sql_dict.py +17 -0
  16. edsl/cli.py +99 -0
  17. edsl/config/config_class.py +16 -0
  18. edsl/conversation/__init__.py +31 -0
  19. edsl/coop/coop.py +276 -242
  20. edsl/coop/coop_jobs_objects.py +59 -0
  21. edsl/coop/coop_objects.py +29 -0
  22. edsl/coop/coop_regular_objects.py +26 -0
  23. edsl/coop/utils.py +24 -19
  24. edsl/dataset/dataset.py +338 -101
  25. edsl/dataset/dataset_operations_mixin.py +216 -180
  26. edsl/db_list/sqlite_list.py +349 -0
  27. edsl/inference_services/__init__.py +40 -5
  28. edsl/inference_services/exceptions.py +11 -0
  29. edsl/inference_services/services/anthropic_service.py +5 -2
  30. edsl/inference_services/services/aws_bedrock.py +6 -2
  31. edsl/inference_services/services/azure_ai.py +6 -2
  32. edsl/inference_services/services/google_service.py +7 -3
  33. edsl/inference_services/services/mistral_ai_service.py +6 -2
  34. edsl/inference_services/services/open_ai_service.py +6 -2
  35. edsl/inference_services/services/perplexity_service.py +6 -2
  36. edsl/inference_services/services/test_service.py +94 -5
  37. edsl/interviews/answering_function.py +167 -59
  38. edsl/interviews/interview.py +124 -72
  39. edsl/interviews/interview_task_manager.py +10 -0
  40. edsl/interviews/request_token_estimator.py +8 -0
  41. edsl/invigilators/invigilators.py +35 -13
  42. edsl/jobs/async_interview_runner.py +146 -104
  43. edsl/jobs/data_structures.py +6 -4
  44. edsl/jobs/decorators.py +61 -0
  45. edsl/jobs/fetch_invigilator.py +61 -18
  46. edsl/jobs/html_table_job_logger.py +14 -2
  47. edsl/jobs/jobs.py +180 -104
  48. edsl/jobs/jobs_component_constructor.py +2 -2
  49. edsl/jobs/jobs_interview_constructor.py +2 -0
  50. edsl/jobs/jobs_pricing_estimation.py +154 -113
  51. edsl/jobs/jobs_remote_inference_logger.py +4 -0
  52. edsl/jobs/jobs_runner_status.py +30 -25
  53. edsl/jobs/progress_bar_manager.py +79 -0
  54. edsl/jobs/remote_inference.py +35 -1
  55. edsl/key_management/key_lookup_builder.py +6 -1
  56. edsl/language_models/language_model.py +110 -12
  57. edsl/language_models/model.py +10 -3
  58. edsl/language_models/price_manager.py +176 -71
  59. edsl/language_models/registry.py +5 -0
  60. edsl/notebooks/notebook.py +77 -10
  61. edsl/questions/VALIDATION_README.md +134 -0
  62. edsl/questions/__init__.py +24 -1
  63. edsl/questions/exceptions.py +21 -0
  64. edsl/questions/question_dict.py +201 -16
  65. edsl/questions/question_multiple_choice_with_other.py +624 -0
  66. edsl/questions/question_registry.py +2 -1
  67. edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
  68. edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
  69. edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
  70. edsl/questions/validation_analysis.py +185 -0
  71. edsl/questions/validation_cli.py +131 -0
  72. edsl/questions/validation_html_report.py +404 -0
  73. edsl/questions/validation_logger.py +136 -0
  74. edsl/results/result.py +115 -46
  75. edsl/results/results.py +702 -171
  76. edsl/scenarios/construct_download_link.py +16 -3
  77. edsl/scenarios/directory_scanner.py +226 -226
  78. edsl/scenarios/file_methods.py +5 -0
  79. edsl/scenarios/file_store.py +150 -9
  80. edsl/scenarios/handlers/__init__.py +5 -1
  81. edsl/scenarios/handlers/mp4_file_store.py +104 -0
  82. edsl/scenarios/handlers/webm_file_store.py +104 -0
  83. edsl/scenarios/scenario.py +120 -101
  84. edsl/scenarios/scenario_list.py +800 -727
  85. edsl/scenarios/scenario_list_gc_test.py +146 -0
  86. edsl/scenarios/scenario_list_memory_test.py +214 -0
  87. edsl/scenarios/scenario_list_source_refactor.md +35 -0
  88. edsl/scenarios/scenario_selector.py +5 -4
  89. edsl/scenarios/scenario_source.py +1990 -0
  90. edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
  91. edsl/surveys/survey.py +22 -0
  92. edsl/tasks/__init__.py +4 -2
  93. edsl/tasks/task_history.py +198 -36
  94. edsl/tests/scenarios/test_ScenarioSource.py +51 -0
  95. edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
  96. edsl/utilities/__init__.py +2 -1
  97. edsl/utilities/decorators.py +121 -0
  98. edsl/utilities/memory_debugger.py +1010 -0
  99. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/METADATA +51 -76
  100. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/RECORD +103 -79
  101. edsl/jobs/jobs_runner_asyncio.py +0 -281
  102. edsl/language_models/unused/fake_openai_service.py +0 -60
  103. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/LICENSE +0 -0
  104. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/WHEEL +0 -0
  105. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/entry_points.txt +0 -0
--- a/edsl/jobs/jobs_pricing_estimation.py
+++ b/edsl/jobs/jobs_pricing_estimation.py
@@ -1,7 +1,8 @@
 import logging
 import math
 
-from typing import List, TYPE_CHECKING, Union, Literal
+from typing import List, TYPE_CHECKING, Union, Literal, Dict
+from collections import namedtuple
 
 if TYPE_CHECKING:
     from .jobs import Jobs
@@ -20,8 +21,8 @@ logger = logging.getLogger(__name__)
 
 class PromptCostEstimator:
 
-    DEFAULT_INPUT_PRICE_PER_TOKEN = 0.000001
-    DEFAULT_OUTPUT_PRICE_PER_TOKEN = 0.000001
+    DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS = 1.0
+    DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS = 1.0
     CHARS_PER_TOKEN = 4
     OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75
     PIPING_MULTIPLIER = 2
@@ -48,81 +49,90 @@ class PromptCostEstimator:
             return PromptCostEstimator.PIPING_MULTIPLIER
         return 1
 
-    @property
-    def key(self):
-        return (self.inference_service, self.model)
-
-    @property
-    def relevant_prices(self):
-        try:
-            return self.price_lookup[self.key]
-        except KeyError:
-            return {}
-
-    def _get_highest_price_for_service(self, price_type: str) -> Union[float, None]:
-        """Returns the highest price per token for a given service and price type (input/output).
+    def _get_fallback_price(self, inference_service: str) -> Dict:
+        """
+        Get fallback prices for a service.
+        - First fallback: The highest input and output prices for that service from the price lookup.
+        - Second fallback: $1.00 per million tokens (for both input and output).
 
         Args:
-            price_type: Either "input" or "output"
+            inference_service (str): The inference service name
 
         Returns:
-            float | None: The highest price per token for the service, or None if not found
+            Dict: Price information
         """
-        prices_for_service = [
-            prices[price_type]["service_stated_token_price"]
-            / prices[price_type]["service_stated_token_qty"]
+        PriceEntry = namedtuple("PriceEntry", ["tokens_per_usd", "price_info"])
+
+        service_prices = [
+            prices
             for (service, _), prices in self.price_lookup.items()
-            if service == self.inference_service and price_type in prices
+            if service == inference_service
         ]
-        return max(prices_for_service) if prices_for_service else None
 
-    def input_price_per_token(
-        self,
-    ) -> tuple[float, Literal["price_lookup", "highest_price_for_service", "default"]]:
-        try:
-            return (
-                self.relevant_prices["input"]["service_stated_token_price"]
-                / self.relevant_prices["input"]["service_stated_token_qty"]
-            ), "price_lookup"
-        except KeyError:
-            highest_price = self._get_highest_price_for_service("input")
-            if highest_price is not None:
-                import warnings
-
-                warnings.warn(
-                    f"Price data not found for {self.key}. Using highest available input price for {self.inference_service}: ${highest_price:.6f} per token"
-                )
-                return highest_price, "highest_price_for_service"
-            import warnings
+        default_input_price_info = {
+            "one_usd_buys": 1_000_000,
+            "service_stated_token_qty": 1_000_000,
+            "service_stated_token_price": self.DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS,
+        }
+        default_output_price_info = {
+            "one_usd_buys": 1_000_000,
+            "service_stated_token_qty": 1_000_000,
+            "service_stated_token_price": self.DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS,
+        }
 
-            warnings.warn(
-                f"Price data not found for {self.inference_service}. Using default estimate for input token price: $1.00 / 1M tokens"
-            )
-            return self.DEFAULT_INPUT_PRICE_PER_TOKEN, "default"
+        # Find the most expensive price entries (lowest tokens per USD)
+        input_price_info = default_input_price_info
+        output_price_info = default_output_price_info
 
-    def output_price_per_token(
-        self,
-    ) -> tuple[float, Literal["price_lookup", "highest_price_for_service", "default"]]:
-        try:
-            return (
-                self.relevant_prices["output"]["service_stated_token_price"]
-                / self.relevant_prices["output"]["service_stated_token_qty"]
-            ), "price_lookup"
-        except KeyError:
-            highest_price = self._get_highest_price_for_service("output")
-            if highest_price is not None:
-                import warnings
-
-                warnings.warn(
-                    f"Price data not found for {self.key}. Using highest available output price for {self.inference_service}: ${highest_price:.6f} per token"
-                )
-                return highest_price, "highest_price_for_service"
-            import warnings
+        input_prices = [
+            PriceEntry(float(p["input"]["one_usd_buys"]), p["input"])
+            for p in service_prices
+            if "input" in p
+        ]
+        if input_prices:
+            input_price_info = min(
+                input_prices, key=lambda price: price.tokens_per_usd
+            ).price_info
+
+        output_prices = [
+            PriceEntry(float(p["output"]["one_usd_buys"]), p["output"])
+            for p in service_prices
+            if "output" in p
+        ]
+        if output_prices:
+            output_price_info = min(
+                output_prices, key=lambda price: price.tokens_per_usd
+            ).price_info
 
-            warnings.warn(
-                f"Price data not found for {self.inference_service}. Using default estimate for output token price: $1.00 / 1M tokens"
-            )
-            return self.DEFAULT_OUTPUT_PRICE_PER_TOKEN, "default"
+        return {
+            "input": input_price_info,
+            "output": output_price_info,
+        }
+
+    def get_price(self, inference_service: str, model: str) -> Dict:
+        """Get the price information for a specific service and model."""
+        key = (inference_service, model)
+        return self.price_lookup.get(key) or self._get_fallback_price(inference_service)
+
+    def get_price_per_million_tokens(
+        self,
+        relevant_prices: Dict,
+        token_type: Literal["input", "output"],
+    ) -> float:
+        """
+        Get the price per million tokens for a specific service, model, and token type.
+        """
+        service_price = relevant_prices[token_type]["service_stated_token_price"]
+        service_qty = relevant_prices[token_type]["service_stated_token_qty"]
+
+        if service_qty == 1_000_000:
+            price_per_million_tokens = service_price
+        elif service_qty == 1_000:
+            price_per_million_tokens = service_price * 1_000
+        else:
+            price_per_token = service_price / service_qty
+            price_per_million_tokens = round(price_per_token * 1_000_000, 10)
+        return price_per_million_tokens
 
     def __call__(self):
         user_prompt_chars = len(str(self.user_prompt)) * self.get_piping_multiplier(
@@ -135,20 +145,28 @@
         input_tokens = (user_prompt_chars + system_prompt_chars) // self.CHARS_PER_TOKEN
         output_tokens = math.ceil(self.OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)
 
-        input_price_per_token, input_price_source = self.input_price_per_token()
-        output_price_per_token, output_price_source = self.output_price_per_token()
+        relevant_prices = self.get_price(self.inference_service, self.model)
 
-        cost = (
-            input_tokens * input_price_per_token
-            + output_tokens * output_price_per_token
+        input_price_per_million_tokens = self.get_price_per_million_tokens(
+            relevant_prices, "input"
         )
+        output_price_per_million_tokens = self.get_price_per_million_tokens(
+            relevant_prices, "output"
+        )
+
+        input_price_per_token = input_price_per_million_tokens / 1_000_000
+        output_price_per_token = output_price_per_million_tokens / 1_000_000
+
+        input_cost = input_tokens * input_price_per_token
+        output_cost = output_tokens * output_price_per_token
+        cost = input_cost + output_cost
         return {
-            "input_price_source": input_price_source,
-            "input_price_per_token": input_price_per_token,
+            "input_price_per_million_tokens": input_price_per_million_tokens,
+            "output_price_per_million_tokens": output_price_per_million_tokens,
             "input_tokens": input_tokens,
-            "output_price_source": output_price_source,
             "output_tokens": output_tokens,
-            "output_price_per_token": output_price_per_token,
+            "input_cost_usd": input_cost,
+            "output_cost_usd": output_cost,
             "cost_usd": cost,
         }
 
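The new flow normalizes every price to dollars per million tokens before computing costs. A standalone sketch of the arithmetic (plain Python, not the edsl API; the 400-character prompt and the $1.00/1M fallback prices are assumed for illustration):

```python
import math

# Assumed inputs: a 400-character prompt priced at the $1.00-per-million-token
# default fallback for both input and output.
CHARS_PER_TOKEN = 4
OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75

prompt_chars = 400
input_tokens = prompt_chars // CHARS_PER_TOKEN                           # 100
output_tokens = math.ceil(OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)  # 75

input_price_per_million = 1.0
output_price_per_million = 1.0

input_cost = input_tokens * (input_price_per_million / 1_000_000)     # 0.0001
output_cost = output_tokens * (output_price_per_million / 1_000_000)  # 0.000075
print(round(input_cost + output_cost, 9))  # 0.000175
```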
@@ -328,6 +346,26 @@ class JobsPrompts:
             "model": model,
         }
 
+    def process_token_type(self, item: dict, token_type: str) -> tuple:
+        """
+        Helper function to process a single token type (input or output) for price estimation.
+        """
+        price = item[f"estimated_{token_type}_price_per_million_tokens"]
+        tokens = item[f"estimated_{token_type}_tokens"]
+        cost = item[f"estimated_{token_type}_cost_usd"]
+
+        return (
+            (item["inference_service"], item["model"], token_type, price),
+            {
+                "inference_service": item["inference_service"],
+                "model": item["model"],
+                "token_type": token_type,
+                "price_per_million_tokens": price,
+                "tokens": tokens,
+                "cost_usd": cost,
+            },
+        )
+
     def estimate_job_cost_from_external_prices(
         self, price_lookup: dict, iterations: int = 1
     ) -> dict:
@@ -341,9 +379,9 @@
         - 1 token = 4 characters.
         - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
         """
-        interviews = self.interviews
+        # Collect all prompt data
         data = []
-        for interview in interviews:
+        for interview in self.interviews:
             invigilators = [
                 FetchInvigilator(interview)(question)
                 for question in self.survey.questions
@@ -354,59 +392,62 @@
                 **prompt_details, price_lookup=price_lookup
             )
             price_estimates = {
+                "estimated_input_price_per_million_tokens": prompt_cost[
+                    "input_price_per_million_tokens"
+                ],
+                "estimated_output_price_per_million_tokens": prompt_cost[
+                    "output_price_per_million_tokens"
+                ],
                 "estimated_input_tokens": prompt_cost["input_tokens"],
                 "estimated_output_tokens": prompt_cost["output_tokens"],
+                "estimated_input_cost_usd": prompt_cost["input_cost_usd"],
+                "estimated_output_cost_usd": prompt_cost["output_cost_usd"],
                 "estimated_cost_usd": prompt_cost["cost_usd"],
             }
-            data.append({**price_estimates, **prompt_details})
+            data.append(
+                {
+                    **prompt_details,
+                    **price_estimates,
+                }
+            )
 
-        model_groups = {}
+        # Group by service, model, token type, and price
+        detailed_groups = {}
         for item in data:
-            key = (item["inference_service"], item["model"])
-            if key not in model_groups:
-                model_groups[key] = {
-                    "inference_service": item["inference_service"],
-                    "model": item["model"],
-                    "estimated_cost_usd": 0,
-                    "estimated_input_tokens": 0,
-                    "estimated_output_tokens": 0,
-                }
-
-            # Accumulate values
-            model_groups[key]["estimated_cost_usd"] += item["estimated_cost_usd"]
-            model_groups[key]["estimated_input_tokens"] += item[
-                "estimated_input_tokens"
-            ]
-            model_groups[key]["estimated_output_tokens"] += item[
-                "estimated_output_tokens"
-            ]
-
-        # Apply iterations and convert to list
-        estimated_costs_by_model = []
-        for group_data in model_groups.values():
-            group_data["estimated_cost_usd"] *= iterations
-            group_data["estimated_input_tokens"] *= iterations
-            group_data["estimated_output_tokens"] *= iterations
-            estimated_costs_by_model.append(group_data)
+            for token_type in ["input", "output"]:
+                key, group_data = self.process_token_type(item, token_type)
+                if key not in detailed_groups:
+                    detailed_groups[key] = group_data
+                else:
+                    detailed_groups[key]["tokens"] += group_data["tokens"]
+                    detailed_groups[key]["cost_usd"] += group_data["cost_usd"]
+
+        # Apply iterations and prepare final output
+        detailed_costs = []
+        for group in detailed_groups.values():
+            group["tokens"] *= iterations
+            group["cost_usd"] *= iterations
+            detailed_costs.append(group)
 
         # Calculate totals
-        estimated_total_cost = sum(
-            model["estimated_cost_usd"] for model in estimated_costs_by_model
-        )
+        estimated_total_cost = sum(group["cost_usd"] for group in detailed_costs)
         estimated_total_input_tokens = sum(
-            model["estimated_input_tokens"] for model in estimated_costs_by_model
+            group["tokens"]
+            for group in detailed_costs
+            if group["token_type"] == "input"
         )
         estimated_total_output_tokens = sum(
-            model["estimated_output_tokens"] for model in estimated_costs_by_model
+            group["tokens"]
+            for group in detailed_costs
+            if group["token_type"] == "output"
         )
 
         output = {
             "estimated_total_cost_usd": estimated_total_cost,
             "estimated_total_input_tokens": estimated_total_input_tokens,
             "estimated_total_output_tokens": estimated_total_output_tokens,
-            "model_costs": estimated_costs_by_model,
+            "detailed_costs": detailed_costs,
         }
-
         return output
 
     def estimate_job_cost(self, iterations: int = 1) -> dict:
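Cost aggregation now groups per token type rather than per model. A hypothetical standalone sketch of the grouping (the service, model, and price values are made up for illustration, not taken from the diff):

```python
# Two input-side estimates for the same (service, model, token_type, price)
# key collapse into one group whose tokens and cost are summed; iterations
# would then scale both totals.
data = [
    {"inference_service": "openai", "model": "gpt-4o", "token_type": "input",
     "price_per_million_tokens": 2.5, "tokens": 100, "cost_usd": 0.00025},
    {"inference_service": "openai", "model": "gpt-4o", "token_type": "input",
     "price_per_million_tokens": 2.5, "tokens": 40, "cost_usd": 0.00010},
]

detailed_groups = {}
for item in data:
    # Key mirrors process_token_type: (service, model, token_type, price)
    key = (item["inference_service"], item["model"],
           item["token_type"], item["price_per_million_tokens"])
    if key not in detailed_groups:
        detailed_groups[key] = dict(item)
    else:
        detailed_groups[key]["tokens"] += item["tokens"]
        detailed_groups[key]["cost_usd"] += item["cost_usd"]

group = next(iter(detailed_groups.values()))
print(group["tokens"], round(group["cost_usd"], 6))  # 140 0.00035
```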
--- a/edsl/jobs/jobs_remote_inference_logger.py
+++ b/edsl/jobs/jobs_remote_inference_logger.py
@@ -30,6 +30,8 @@ class JobsInfo:
     error_report_url: str = None
     results_uuid: str = None
     results_url: str = None
+    completed_interviews: int = None
+    failed_interviews: int = None
 
     pretty_names = {
         "job_uuid": "Job UUID",
@@ -53,6 +55,8 @@ class JobLogger(ABC):
             "error_report_url",
             "results_uuid",
             "results_url",
+            "completed_interviews",
+            "failed_interviews",
         ],
         value: str,
     ):
--- a/edsl/jobs/jobs_runner_status.py
+++ b/edsl/jobs/jobs_runner_status.py
@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, TYPE_CHECKING
 from uuid import UUID
 
 if TYPE_CHECKING:
-    from .jobs_runner_asyncio import JobsRunnerAsyncio
+    from .jobs import Jobs
 
 
 @dataclass
@@ -65,14 +65,14 @@ class StatisticsTracker:
 class JobsRunnerStatusBase(ABC):
     def __init__(
         self,
-        jobs_runner: "JobsRunnerAsyncio",
+        jobs: "Jobs",
         n: int,
         refresh_rate: float = 1,
         endpoint_url: Optional[str] = "http://localhost:8000",
         job_uuid: Optional[UUID] = None,
        api_key: str = None,
    ):
-        self.jobs_runner = jobs_runner
+        self.jobs = jobs
        self.job_uuid = job_uuid
        self.base_url = f"{endpoint_url}"
        self.refresh_rate = refresh_rate
@@ -86,10 +86,10 @@ class JobsRunnerStatusBase(ABC):
            "unfixed_exceptions",
            "throughput",
        ]
-        self.num_total_interviews = n * len(self.jobs_runner)
+        self.num_total_interviews = n * len(self.jobs)
 
        self.distinct_models = list(
-            set(model.model for model in self.jobs_runner.jobs.models)
+            set(model.model for model in self.jobs.models)
        )
 
        self.stats_tracker = StatisticsTracker(
@@ -151,26 +151,31 @@ class JobsRunnerStatusBase(ABC):
        }
 
        model_queues = {}
-        # for model, bucket in self.jobs_runner.bucket_collection.items():
-        for model, bucket in self.jobs_runner.environment.bucket_collection.items():
-            model_name = model.model
-            model_queues[model_name] = {
-                "language_model_name": model_name,
-                "requests_bucket": {
-                    "completed": bucket.requests_bucket.num_released,
-                    "requested": bucket.requests_bucket.num_requests,
-                    "tokens_returned": bucket.requests_bucket.tokens_returned,
-                    "target_rate": round(bucket.requests_bucket.target_rate, 1),
-                    "current_rate": round(bucket.requests_bucket.get_throughput(), 1),
-                },
-                "tokens_bucket": {
-                    "completed": bucket.tokens_bucket.num_released,
-                    "requested": bucket.tokens_bucket.num_requests,
-                    "tokens_returned": bucket.tokens_bucket.tokens_returned,
-                    "target_rate": round(bucket.tokens_bucket.target_rate, 1),
-                    "current_rate": round(bucket.tokens_bucket.get_throughput(), 1),
-                },
-            }
+        # Check if bucket collection exists and is not empty
+        if (hasattr(self.jobs, 'run_config') and
+            hasattr(self.jobs.run_config, 'environment') and
+            hasattr(self.jobs.run_config.environment, 'bucket_collection') and
+            self.jobs.run_config.environment.bucket_collection):
+
+            for model, bucket in self.jobs.run_config.environment.bucket_collection.items():
+                model_name = model.model
+                model_queues[model_name] = {
+                    "language_model_name": model_name,
+                    "requests_bucket": {
+                        "completed": bucket.requests_bucket.num_released,
+                        "requested": bucket.requests_bucket.num_requests,
+                        "tokens_returned": bucket.requests_bucket.tokens_returned,
+                        "target_rate": round(bucket.requests_bucket.target_rate, 1),
+                        "current_rate": round(bucket.requests_bucket.get_throughput(), 1),
+                    },
+                    "tokens_bucket": {
+                        "completed": bucket.tokens_bucket.num_released,
+                        "requested": bucket.tokens_bucket.num_requests,
+                        "tokens_returned": bucket.tokens_bucket.tokens_returned,
+                        "target_rate": round(bucket.tokens_bucket.target_rate, 1),
+                        "current_rate": round(bucket.tokens_bucket.get_throughput(), 1),
+                    },
+                }
         status_dict["language_model_queues"] = model_queues
         return status_dict
 
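The new guard walks `jobs.run_config.environment.bucket_collection` attribute by attribute, so a missing or empty collection simply yields no queue stats. The same check can be sketched with `getattr` (a hypothetical helper, not part of the diff):

```python
def get_bucket_collection(jobs):
    # Returns the bucket collection, or {} when any link in the chain is
    # missing or the collection itself is empty (hypothetical alternative).
    run_config = getattr(jobs, "run_config", None)
    environment = getattr(run_config, "environment", None)
    return getattr(environment, "bucket_collection", None) or {}
```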
--- /dev/null
+++ b/edsl/jobs/progress_bar_manager.py
@@ -0,0 +1,79 @@
+"""
+Progress bar management for asynchronous job execution.
+
+This module provides a context manager for handling progress bar setup and thread
+management during job execution. It coordinates the display and updating of progress
+bars, particularly for remote tracking via the Expected Parrot API.
+"""
+
+import threading
+import warnings
+
+from ..coop import Coop
+from .jobs_runner_status import JobsRunnerStatus
+
+
+class ProgressBarManager:
+    """Context manager for handling progress bar setup and thread management.
+
+    This class manages the progress bar display and updating during job execution,
+    particularly for remote tracking via the Expected Parrot API.
+
+    It handles:
+    1. Setting up a status tracking object
+    2. Creating and managing a background thread for progress updates
+    3. Properly cleaning up resources when execution completes
+    """
+
+    def __init__(self, jobs, run_config, parameters):
+        self.parameters = parameters
+        self.jobs = jobs
+
+        # Set up progress tracking
+        coop = Coop()
+        endpoint_url = coop.get_progress_bar_url()
+
+        # Set up jobs status object
+        params = {
+            "jobs": jobs,
+            "n": parameters.n,
+            "endpoint_url": endpoint_url,
+            "job_uuid": parameters.job_uuid,
+        }
+
+        # If the jobs_runner_status is already set, use it directly
+        if run_config.environment.jobs_runner_status is not None:
+            self.jobs_runner_status = run_config.environment.jobs_runner_status
+        else:
+            # Otherwise create a new one
+            self.jobs_runner_status = JobsRunnerStatus(**params)
+
+        # Store on run_config for use by other components
+        run_config.environment.jobs_runner_status = self.jobs_runner_status
+
+        self.progress_thread = None
+        self.stop_event = threading.Event()
+
+    def __enter__(self):
+        if self.parameters.progress_bar and self.jobs_runner_status.has_ep_api_key():
+            self.jobs_runner_status.setup()
+            self.progress_thread = threading.Thread(
+                target=self._run_progress_bar,
+                args=(self.stop_event, self.jobs_runner_status)
+            )
+            self.progress_thread.start()
+        elif self.parameters.progress_bar:
+            warnings.warn(
+                "You need an Expected Parrot API key to view job progress bars."
+            )
+        return self.stop_event
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop_event.set()
+        if self.progress_thread is not None:
+            self.progress_thread.join()
+
+    @staticmethod
+    def _run_progress_bar(stop_event, jobs_runner_status):
+        """Runs the progress bar in a separate thread."""
+        jobs_runner_status.update_progress(stop_event)
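A minimal usage sketch for the new context manager, assuming caller code like the following (`execute` stands in for the actual interview loop, which this diff does not show):

```python
def run_with_progress(jobs, run_config, parameters, execute):
    # __enter__ starts the background update thread when a progress bar is
    # requested and an Expected Parrot API key is available; __exit__ sets
    # the stop event and joins the thread, even if execute() raises.
    with ProgressBarManager(jobs, run_config, parameters) as stop_event:
        return execute(stop_event)
```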
--- a/edsl/jobs/remote_inference.py
+++ b/edsl/jobs/remote_inference.py
@@ -1,3 +1,4 @@
+import re
 from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
 from dataclasses import dataclass
 from ..coop import CoopServerResponseError
@@ -112,13 +113,18 @@ class JobsRemoteInferenceHandler:
         )
         logger.add_info("job_uuid", job_uuid)
 
+        remote_inference_url = self.remote_inference_url
+        if "localhost" in remote_inference_url:
+            remote_inference_url = remote_inference_url.replace("8000", "1234")
         logger.update(
-            f"Job details are available at your Coop account. [Go to Remote Inference page]({self.remote_inference_url})",
+            f"Job details are available at your Coop account. [Go to Remote Inference page]({remote_inference_url})",
             status=JobsStatus.RUNNING,
         )
         progress_bar_url = (
             f"{self.expected_parrot_url}/home/remote-job-progress/{job_uuid}"
         )
+        if "localhost" in progress_bar_url:
+            progress_bar_url = progress_bar_url.replace("8000", "1234")
         logger.add_info("progress_bar_url", progress_bar_url)
         logger.update(
             f"View job progress [here]({progress_bar_url})", status=JobsStatus.RUNNING
@@ -200,10 +206,35 @@
             status=JobsStatus.FAILED,
         )
 
+    def _handle_partially_failed_job_interview_details(
+        self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
+    ) -> None:
+        "Extracts the interview details from the remote job data."
+        try:
+            # Job details is a string of the form "64 out of 1,758 interviews failed"
+            job_details = remote_job_data.get("latest_failure_description")
+
+            text_without_commas = job_details.replace(",", "")
+
+            # Find all numbers in the text
+            numbers = [int(num) for num in re.findall(r"\d+", text_without_commas)]
+
+            failed = numbers[0]
+            total = numbers[1]
+            completed = total - failed
+
+            job_info.logger.add_info("completed_interviews", completed)
+            job_info.logger.add_info("failed_interviews", failed)
+        # This is mainly helpful metadata, and any errors here should not stop the code
+        except:
+            pass
+
     def _handle_partially_failed_job(
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
         "Handles a partially failed job by logging the error and updating the job status."
+        self._handle_partially_failed_job_interview_details(job_info, remote_job_data)
+
         latest_error_report_url = remote_job_data.get("latest_error_report_url")
 
         if latest_error_report_url:
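A worked example of the parsing in `_handle_partially_failed_job_interview_details`, using the message format documented in the comment above:

```python
import re

description = "64 out of 1,758 interviews failed"
numbers = [int(num) for num in re.findall(r"\d+", description.replace(",", ""))]
failed, total = numbers[0], numbers[1]
print(failed, total - failed)  # failed=64, completed=1694
```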
@@ -244,6 +275,8 @@
         job_info.logger.add_info("results_uuid", results_uuid)
         results = object_fetcher(results_uuid, expected_object_type="results")
         results_url = remote_job_data.get("results_url")
+        if "localhost" in results_url:
+            results_url = results_url.replace("8000", "1234")
         job_info.logger.add_info("results_url", results_url)
 
         if job_status == "completed":
@@ -256,6 +289,7 @@
             f"View partial results [here]({results_url})",
             status=JobsStatus.PARTIALLY_FAILED,
         )
+
         results.job_uuid = job_info.job_uuid
         results.results_uuid = results_uuid
         return results
--- a/edsl/key_management/key_lookup_builder.py
+++ b/edsl/key_management/key_lookup_builder.py
@@ -2,6 +2,7 @@ from typing import Optional, TYPE_CHECKING
 import os
 from functools import lru_cache
 import textwrap
+import requests
 
 if TYPE_CHECKING:
     from ..coop import Coop
@@ -255,7 +256,11 @@ class KeyLookupBuilder:
         return dict(list(os.environ.items()))
 
     def _coop_key_value_pairs(self):
-        return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
+        try:
+            return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
+        except requests.ConnectionError:
+            # If connection fails, return empty dict instead of raising error
+            return {}
 
     def _config_key_value_pairs(self):
         from ..config import CONFIG