edsl 0.1.53__py3-none-any.whl → 0.1.55__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to its public registry. It is provided for informational purposes only.
Files changed (104)
  1. edsl/__init__.py +8 -1
  2. edsl/__init__original.py +134 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/agent.py +29 -0
  5. edsl/agents/agent_list.py +36 -1
  6. edsl/base/base_class.py +281 -151
  7. edsl/buckets/__init__.py +8 -3
  8. edsl/buckets/bucket_collection.py +9 -3
  9. edsl/buckets/model_buckets.py +4 -2
  10. edsl/buckets/token_bucket.py +2 -2
  11. edsl/buckets/token_bucket_client.py +5 -3
  12. edsl/caching/cache.py +131 -62
  13. edsl/caching/cache_entry.py +70 -58
  14. edsl/caching/sql_dict.py +17 -0
  15. edsl/cli.py +99 -0
  16. edsl/config/config_class.py +16 -0
  17. edsl/conversation/__init__.py +31 -0
  18. edsl/coop/coop.py +276 -242
  19. edsl/coop/coop_jobs_objects.py +59 -0
  20. edsl/coop/coop_objects.py +29 -0
  21. edsl/coop/coop_regular_objects.py +26 -0
  22. edsl/coop/utils.py +24 -19
  23. edsl/dataset/dataset.py +338 -101
  24. edsl/db_list/sqlite_list.py +349 -0
  25. edsl/inference_services/__init__.py +40 -5
  26. edsl/inference_services/exceptions.py +11 -0
  27. edsl/inference_services/services/anthropic_service.py +5 -2
  28. edsl/inference_services/services/aws_bedrock.py +6 -2
  29. edsl/inference_services/services/azure_ai.py +6 -2
  30. edsl/inference_services/services/google_service.py +3 -2
  31. edsl/inference_services/services/mistral_ai_service.py +6 -2
  32. edsl/inference_services/services/open_ai_service.py +6 -2
  33. edsl/inference_services/services/perplexity_service.py +6 -2
  34. edsl/inference_services/services/test_service.py +105 -7
  35. edsl/interviews/answering_function.py +167 -59
  36. edsl/interviews/interview.py +124 -72
  37. edsl/interviews/interview_task_manager.py +10 -0
  38. edsl/invigilators/invigilators.py +10 -1
  39. edsl/jobs/async_interview_runner.py +146 -104
  40. edsl/jobs/data_structures.py +6 -4
  41. edsl/jobs/decorators.py +61 -0
  42. edsl/jobs/fetch_invigilator.py +61 -18
  43. edsl/jobs/html_table_job_logger.py +14 -2
  44. edsl/jobs/jobs.py +180 -104
  45. edsl/jobs/jobs_component_constructor.py +2 -2
  46. edsl/jobs/jobs_interview_constructor.py +2 -0
  47. edsl/jobs/jobs_pricing_estimation.py +127 -46
  48. edsl/jobs/jobs_remote_inference_logger.py +4 -0
  49. edsl/jobs/jobs_runner_status.py +30 -25
  50. edsl/jobs/progress_bar_manager.py +79 -0
  51. edsl/jobs/remote_inference.py +35 -1
  52. edsl/key_management/key_lookup_builder.py +6 -1
  53. edsl/language_models/language_model.py +102 -12
  54. edsl/language_models/model.py +10 -3
  55. edsl/language_models/price_manager.py +45 -75
  56. edsl/language_models/registry.py +5 -0
  57. edsl/language_models/utilities.py +2 -1
  58. edsl/notebooks/notebook.py +77 -10
  59. edsl/questions/VALIDATION_README.md +134 -0
  60. edsl/questions/__init__.py +24 -1
  61. edsl/questions/exceptions.py +21 -0
  62. edsl/questions/question_check_box.py +171 -149
  63. edsl/questions/question_dict.py +243 -51
  64. edsl/questions/question_multiple_choice_with_other.py +624 -0
  65. edsl/questions/question_registry.py +2 -1
  66. edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
  67. edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
  68. edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
  69. edsl/questions/validation_analysis.py +185 -0
  70. edsl/questions/validation_cli.py +131 -0
  71. edsl/questions/validation_html_report.py +404 -0
  72. edsl/questions/validation_logger.py +136 -0
  73. edsl/results/result.py +63 -16
  74. edsl/results/results.py +702 -171
  75. edsl/scenarios/construct_download_link.py +16 -3
  76. edsl/scenarios/directory_scanner.py +226 -226
  77. edsl/scenarios/file_methods.py +5 -0
  78. edsl/scenarios/file_store.py +117 -6
  79. edsl/scenarios/handlers/__init__.py +5 -1
  80. edsl/scenarios/handlers/mp4_file_store.py +104 -0
  81. edsl/scenarios/handlers/webm_file_store.py +104 -0
  82. edsl/scenarios/scenario.py +120 -101
  83. edsl/scenarios/scenario_list.py +800 -727
  84. edsl/scenarios/scenario_list_gc_test.py +146 -0
  85. edsl/scenarios/scenario_list_memory_test.py +214 -0
  86. edsl/scenarios/scenario_list_source_refactor.md +35 -0
  87. edsl/scenarios/scenario_selector.py +5 -4
  88. edsl/scenarios/scenario_source.py +1990 -0
  89. edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
  90. edsl/surveys/survey.py +22 -0
  91. edsl/tasks/__init__.py +4 -2
  92. edsl/tasks/task_history.py +198 -36
  93. edsl/tests/scenarios/test_ScenarioSource.py +51 -0
  94. edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
  95. edsl/utilities/__init__.py +2 -1
  96. edsl/utilities/decorators.py +121 -0
  97. edsl/utilities/memory_debugger.py +1010 -0
  98. {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/METADATA +52 -76
  99. {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/RECORD +102 -78
  100. edsl/jobs/jobs_runner_asyncio.py +0 -281
  101. edsl/language_models/unused/fake_openai_service.py +0 -60
  102. {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/LICENSE +0 -0
  103. {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/WHEEL +0 -0
  104. {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/entry_points.txt +0 -0
edsl/jobs/jobs_pricing_estimation.py

@@ -1,7 +1,7 @@
 import logging
 import math
 
-from typing import List, TYPE_CHECKING
+from typing import List, TYPE_CHECKING, Union, Literal
 
 if TYPE_CHECKING:
     from .jobs import Jobs
@@ -26,53 +26,104 @@ class PromptCostEstimator:
     OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75
     PIPING_MULTIPLIER = 2
 
-    def __init__(self,
+    def __init__(
+        self,
         system_prompt: str,
         user_prompt: str,
         price_lookup: dict,
         inference_service: str,
-        model: str):
+        model: str,
+    ):
         self.system_prompt = system_prompt
         self.user_prompt = user_prompt
         self.price_lookup = price_lookup
         self.inference_service = inference_service
         self.model = model
 
-    @staticmethod
+    @staticmethod
     def get_piping_multiplier(prompt: str):
         """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
 
         if "{{" in prompt and "}}" in prompt:
             return PromptCostEstimator.PIPING_MULTIPLIER
         return 1
-
+
     @property
     def key(self):
         return (self.inference_service, self.model)
-
+
     @property
     def relevant_prices(self):
         try:
             return self.price_lookup[self.key]
         except KeyError:
             return {}
-
-    def input_price_per_token(self):
+
+    def _get_highest_price_for_service(self, price_type: str) -> Union[float, None]:
+        """Returns the highest price per token for a given service and price type (input/output).
+
+        Args:
+            price_type: Either "input" or "output"
+
+        Returns:
+            float | None: The highest price per token for the service, or None if not found
+        """
+        prices_for_service = [
+            prices[price_type]["service_stated_token_price"]
+            / prices[price_type]["service_stated_token_qty"]
+            for (service, _), prices in self.price_lookup.items()
+            if service == self.inference_service and price_type in prices
+        ]
+        return max(prices_for_service) if prices_for_service else None
+
+    def input_price_per_token(
+        self,
+    ) -> tuple[float, Literal["price_lookup", "highest_price_for_service", "default"]]:
         try:
-            return self.relevant_prices["input"]["service_stated_token_price"] / self.relevant_prices["input"]["service_stated_token_qty"]
+            return (
+                self.relevant_prices["input"]["service_stated_token_price"]
+                / self.relevant_prices["input"]["service_stated_token_qty"]
+            ), "price_lookup"
         except KeyError:
+            highest_price = self._get_highest_price_for_service("input")
+            if highest_price is not None:
+                import warnings
+
+                warnings.warn(
+                    f"Price data not found for {self.key}. Using highest available input price for {self.inference_service}: ${highest_price:.6f} per token"
+                )
+                return highest_price, "highest_price_for_service"
             import warnings
+
             warnings.warn(
-                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $1.00 / 1M tokens; Output: $1.00 / 1M tokens"
+                f"Price data not found for {self.inference_service}. Using default estimate for input token price: $1.00 / 1M tokens"
             )
-            return self.DEFAULT_INPUT_PRICE_PER_TOKEN
+            return self.DEFAULT_INPUT_PRICE_PER_TOKEN, "default"
 
-    def output_price_per_token(self):
+    def output_price_per_token(
+        self,
+    ) -> tuple[float, Literal["price_lookup", "highest_price_for_service", "default"]]:
         try:
-            return self.relevant_prices["output"]["service_stated_token_price"] / self.relevant_prices["output"]["service_stated_token_qty"]
+            return (
+                self.relevant_prices["output"]["service_stated_token_price"]
+                / self.relevant_prices["output"]["service_stated_token_qty"]
+            ), "price_lookup"
         except KeyError:
-            return self.DEFAULT_OUTPUT_PRICE_PER_TOKEN
-
+            highest_price = self._get_highest_price_for_service("output")
+            if highest_price is not None:
+                import warnings
+
+                warnings.warn(
+                    f"Price data not found for {self.key}. Using highest available output price for {self.inference_service}: ${highest_price:.6f} per token"
+                )
+                return highest_price, "highest_price_for_service"
+            import warnings
+
+            warnings.warn(
+                f"Price data not found for {self.inference_service}. Using default estimate for output token price: $1.00 / 1M tokens"
+            )
+            return self.DEFAULT_OUTPUT_PRICE_PER_TOKEN, "default"
+
     def __call__(self):
         user_prompt_chars = len(str(self.user_prompt)) * self.get_piping_multiplier(
             str(self.user_prompt)
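
The hunk above replaces the old two-tier pricing (exact lookup, then a flat default) with a three-tier chain: an exact (inference_service, model) match, then the highest price listed for any model on the same service, then the $1.00-per-1M-token default. A standalone sketch of that resolution order, with an invented price_lookup (the dict shape mirrors what the class reads; every number is illustrative):

# Standalone sketch of the three-tier resolution above; all prices invented.
import warnings

price_lookup = {
    ("openai", "gpt-4o"): {
        "input": {"service_stated_token_price": 2.50,
                  "service_stated_token_qty": 1_000_000},
        "output": {"service_stated_token_price": 10.00,
                   "service_stated_token_qty": 1_000_000},
    },
    ("openai", "gpt-4o-mini"): {
        "input": {"service_stated_token_price": 0.15,
                  "service_stated_token_qty": 1_000_000},
    },
}

DEFAULT_PRICE_PER_TOKEN = 0.000001  # $1.00 / 1M tokens


def resolve_price(service: str, model: str, price_type: str) -> tuple:
    # Tier 1: exact (service, model) entry.
    entry = price_lookup.get((service, model), {})
    if price_type in entry:
        p = entry[price_type]
        return p["service_stated_token_price"] / p["service_stated_token_qty"], "price_lookup"
    # Tier 2: the most expensive same-service entry, to avoid underestimating.
    candidates = [
        p[price_type]["service_stated_token_price"] / p[price_type]["service_stated_token_qty"]
        for (svc, _), p in price_lookup.items()
        if svc == service and price_type in p
    ]
    if candidates:
        warnings.warn(f"No price for {(service, model)}; using highest {service} price")
        return max(candidates), "highest_price_for_service"
    # Tier 3: hard-coded default.
    return DEFAULT_PRICE_PER_TOKEN, "default"


print(resolve_price("openai", "gpt-4o", "input"))        # (2.5e-06, 'price_lookup')
print(resolve_price("openai", "o9-imaginary", "input"))  # (2.5e-06, 'highest_price_for_service')
print(resolve_price("anthropic", "claude-x", "input"))   # (1e-06, 'default')

The returned source tag is what lets the new __call__ (next hunk) report how trustworthy each estimate is.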
@@ -84,20 +135,37 @@ class PromptCostEstimator:
         input_tokens = (user_prompt_chars + system_prompt_chars) // self.CHARS_PER_TOKEN
         output_tokens = math.ceil(self.OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)
 
+        input_price_per_token, input_price_source = self.input_price_per_token()
+        output_price_per_token, output_price_source = self.output_price_per_token()
+
         cost = (
-            input_tokens * self.input_price_per_token()
-            + output_tokens * self.output_price_per_token()
+            input_tokens * input_price_per_token
+            + output_tokens * output_price_per_token
         )
         return {
+            "input_price_source": input_price_source,
+            "input_price_per_token": input_price_per_token,
             "input_tokens": input_tokens,
+            "output_price_source": output_price_source,
             "output_tokens": output_tokens,
+            "output_price_per_token": output_price_per_token,
             "cost_usd": cost,
         }
 
 
 class JobsPrompts:
 
-    relevant_keys = ["user_prompt", "system_prompt", "interview_index", "question_name", "scenario_index", "agent_index", "model", "estimated_cost", "cache_keys"]
+    relevant_keys = [
+        "user_prompt",
+        "system_prompt",
+        "interview_index",
+        "question_name",
+        "scenario_index",
+        "agent_index",
+        "model",
+        "estimated_cost",
+        "cache_keys",
+    ]
 
     """This generates the prompts for a job for price estimation purposes.
 
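
For reference, the arithmetic in __call__ works through like this. OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75 and PIPING_MULTIPLIER = 2 come from the hunks above; CHARS_PER_TOKEN is referenced but its value is not visible in this diff, so 4 (the usual chars-per-token heuristic) is assumed here:

# Worked example of the estimator arithmetic; CHARS_PER_TOKEN = 4 is assumed.
import math

CHARS_PER_TOKEN = 4
OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75
PIPING_MULTIPLIER = 2


def piping_multiplier(prompt: str) -> int:
    # Prompts containing Jinja braces may expand at render time,
    # so their character count is doubled.
    return PIPING_MULTIPLIER if "{{" in prompt and "}}" in prompt else 1


system_prompt = "You are a helpful survey respondent."
user_prompt = "Summarize {{ scenario.text }} in one sentence."  # has braces

user_chars = len(user_prompt) * piping_multiplier(user_prompt)      # doubled
system_chars = len(system_prompt) * piping_multiplier(system_prompt)
input_tokens = (user_chars + system_chars) // CHARS_PER_TOKEN
output_tokens = math.ceil(OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)

# With illustrative prices of $1.00 per 1M tokens on both sides:
price = 1.0 / 1_000_000
print(input_tokens, output_tokens,
      f"${input_tokens * price + output_tokens * price:.8f}")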
@@ -105,7 +173,6 @@ class JobsPrompts:
     So assumptions are made about expansion of Jinja braces, etc.
     """
 
-
     @classmethod
     def from_jobs(cls, jobs: "Jobs"):
         """Construct a JobsPrompts object from a Jobs object."""
@@ -114,13 +181,16 @@ class JobsPrompts:
         scenarios = jobs.scenarios
         survey = jobs.survey
         return cls(
-            interviews=interviews,
-            agents=agents,
-            scenarios=scenarios,
-            survey=survey
+            interviews=interviews, agents=agents, scenarios=scenarios, survey=survey
        )
-
-    def __init__(self, interviews: List['Interview'], agents:'AgentList', scenarios: 'ScenarioList', survey: 'Survey'):
+
+    def __init__(
+        self,
+        interviews: List["Interview"],
+        agents: "AgentList",
+        scenarios: "ScenarioList",
+        survey: "Survey",
+    ):
         """Initialize with extracted components rather than a Jobs object."""
         self.interviews = interviews
         self.agents = agents
@@ -143,17 +213,19 @@ class JobsPrompts:
             self._price_lookup = c.fetch_prices()
         return self._price_lookup
 
-    def _process_one_invigilator(self, invigilator: 'Invigilator', interview_index: int, iterations: int = 1) -> dict:
+    def _process_one_invigilator(
+        self, invigilator: "Invigilator", interview_index: int, iterations: int = 1
+    ) -> dict:
         """Process a single invigilator and return a dictionary with all needed data fields."""
         prompts = invigilator.get_prompts()
         user_prompt = prompts["user_prompt"]
         system_prompt = prompts["system_prompt"]
-
+
         agent_index = self._agent_lookup[invigilator.agent]
         scenario_index = self._scenario_lookup[invigilator.scenario]
         model = invigilator.model.model
         question_name = invigilator.question.question_name
-
+
         # Calculate prompt cost
         prompt_cost = self.estimate_prompt_cost(
             system_prompt=system_prompt,
@@ -163,7 +235,7 @@ class JobsPrompts:
             model=model,
         )
         cost = prompt_cost["cost_usd"]
-
+
         # Generate cache keys for each iteration
         cache_keys = []
         for iteration in range(iterations):
@@ -175,7 +247,7 @@ class JobsPrompts:
                 iteration=iteration,
             )
             cache_keys.append(cache_key)
-
+
         d = {
             "user_prompt": user_prompt,
             "system_prompt": system_prompt,
@@ -200,7 +272,7 @@ class JobsPrompts:
         dataset_of_prompts = {k: [] for k in self.relevant_keys}
 
         interviews = self.interviews
-
+
         # Process each interview and invigilator
         for interview_index, interview in enumerate(interviews):
             invigilators = [
@@ -210,11 +282,13 @@ class JobsPrompts:
 
             for invigilator in invigilators:
                 # Process the invigilator and get all data as a dictionary
-                data = self._process_one_invigilator(invigilator, interview_index, iterations)
+                data = self._process_one_invigilator(
+                    invigilator, interview_index, iterations
+                )
                 for k in self.relevant_keys:
                     dataset_of_prompts[k].append(data[k])
-
-        return Dataset([{k:dataset_of_prompts[k]} for k in self.relevant_keys])
+
+        return Dataset([{k: dataset_of_prompts[k]} for k in self.relevant_keys])
 
     @staticmethod
     def estimate_prompt_cost(
@@ -230,13 +304,13 @@ class JobsPrompts:
             user_prompt=user_prompt,
             price_lookup=price_lookup,
             inference_service=inference_service,
-            model=model
+            model=model,
         )()
-
+
     @staticmethod
     def _extract_prompt_details(invigilator: FetchInvigilator) -> dict:
         """Extracts the prompt details from the invigilator.
-
+
         >>> from edsl.invigilators import InvigilatorAI
         >>> invigilator = InvigilatorAI.example()
         >>> JobsPrompts._extract_prompt_details(invigilator)
@@ -276,11 +350,13 @@ class JobsPrompts:
         ]
         for invigilator in invigilators:
             prompt_details = self._extract_prompt_details(invigilator)
-            prompt_cost = self.estimate_prompt_cost(**prompt_details, price_lookup=price_lookup)
+            prompt_cost = self.estimate_prompt_cost(
+                **prompt_details, price_lookup=price_lookup
+            )
             price_estimates = {
-                'estimated_input_tokens': prompt_cost['input_tokens'],
-                'estimated_output_tokens': prompt_cost['output_tokens'],
-                'estimated_cost_usd': prompt_cost['cost_usd']
+                "estimated_input_tokens": prompt_cost["input_tokens"],
+                "estimated_output_tokens": prompt_cost["output_tokens"],
+                "estimated_cost_usd": prompt_cost["cost_usd"],
             }
             data.append({**price_estimates, **prompt_details})
 
@@ -293,14 +369,18 @@ class JobsPrompts:
                     "model": item["model"],
                     "estimated_cost_usd": 0,
                     "estimated_input_tokens": 0,
-                    "estimated_output_tokens": 0
+                    "estimated_output_tokens": 0,
                 }
-
+
             # Accumulate values
             model_groups[key]["estimated_cost_usd"] += item["estimated_cost_usd"]
-            model_groups[key]["estimated_input_tokens"] += item["estimated_input_tokens"]
-            model_groups[key]["estimated_output_tokens"] += item["estimated_output_tokens"]
-
+            model_groups[key]["estimated_input_tokens"] += item[
+                "estimated_input_tokens"
+            ]
+            model_groups[key]["estimated_output_tokens"] += item[
+                "estimated_output_tokens"
+            ]
+
         # Apply iterations and convert to list
         estimated_costs_by_model = []
         for group_data in model_groups.values():
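
The loop above groups per-prompt estimates by (inference_service, model). A self-contained sketch of the same group-and-accumulate pattern follows; the rows are invented, and the final scaling step is an assumption about the code behind the "# Apply iterations" comment, which this hunk cuts off:

# Group-and-accumulate sketch; data rows invented, scaling step assumed.
data = [
    {"inference_service": "openai", "model": "gpt-4o",
     "estimated_cost_usd": 0.002, "estimated_input_tokens": 800,
     "estimated_output_tokens": 600},
    {"inference_service": "openai", "model": "gpt-4o",
     "estimated_cost_usd": 0.001, "estimated_input_tokens": 400,
     "estimated_output_tokens": 300},
]
iterations = 2

model_groups = {}
for item in data:
    key = (item["inference_service"], item["model"])
    group = model_groups.setdefault(key, {
        "inference_service": key[0],
        "model": key[1],
        "estimated_cost_usd": 0,
        "estimated_input_tokens": 0,
        "estimated_output_tokens": 0,
    })
    # Accumulate values, as in the diff
    group["estimated_cost_usd"] += item["estimated_cost_usd"]
    group["estimated_input_tokens"] += item["estimated_input_tokens"]
    group["estimated_output_tokens"] += item["estimated_output_tokens"]

for group_data in model_groups.values():
    for field in ("estimated_cost_usd", "estimated_input_tokens",
                  "estimated_output_tokens"):
        group_data[field] *= iterations  # assumed per-iteration scaling
    print(group_data)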
@@ -345,4 +425,5 @@
 
 if __name__ == "__main__":
     import doctest
+
     doctest.testmod(optionflags=doctest.ELLIPSIS)
edsl/jobs/jobs_remote_inference_logger.py

@@ -30,6 +30,8 @@ class JobsInfo:
     error_report_url: str = None
     results_uuid: str = None
     results_url: str = None
+    completed_interviews: int = None
+    failed_interviews: int = None
 
     pretty_names = {
         "job_uuid": "Job UUID",
@@ -53,6 +55,8 @@ class JobLogger(ABC):
             "error_report_url",
             "results_uuid",
             "results_url",
+            "completed_interviews",
+            "failed_interviews",
         ],
         value: str,
     ):
edsl/jobs/jobs_runner_status.py

@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, TYPE_CHECKING
 from uuid import UUID
 
 if TYPE_CHECKING:
-    from .jobs_runner_asyncio import JobsRunnerAsyncio
+    from .jobs import Jobs
 
 
 @dataclass
@@ -65,14 +65,14 @@ class StatisticsTracker:
 class JobsRunnerStatusBase(ABC):
     def __init__(
         self,
-        jobs_runner: "JobsRunnerAsyncio",
+        jobs: "Jobs",
         n: int,
         refresh_rate: float = 1,
         endpoint_url: Optional[str] = "http://localhost:8000",
         job_uuid: Optional[UUID] = None,
         api_key: str = None,
     ):
-        self.jobs_runner = jobs_runner
+        self.jobs = jobs
         self.job_uuid = job_uuid
         self.base_url = f"{endpoint_url}"
         self.refresh_rate = refresh_rate
@@ -86,10 +86,10 @@ class JobsRunnerStatusBase(ABC):
             "unfixed_exceptions",
             "throughput",
         ]
-        self.num_total_interviews = n * len(self.jobs_runner)
+        self.num_total_interviews = n * len(self.jobs)
 
         self.distinct_models = list(
-            set(model.model for model in self.jobs_runner.jobs.models)
+            set(model.model for model in self.jobs.models)
         )
 
         self.stats_tracker = StatisticsTracker(
@@ -151,26 +151,31 @@ class JobsRunnerStatusBase(ABC):
         }
 
         model_queues = {}
-        # for model, bucket in self.jobs_runner.bucket_collection.items():
-        for model, bucket in self.jobs_runner.environment.bucket_collection.items():
-            model_name = model.model
-            model_queues[model_name] = {
-                "language_model_name": model_name,
-                "requests_bucket": {
-                    "completed": bucket.requests_bucket.num_released,
-                    "requested": bucket.requests_bucket.num_requests,
-                    "tokens_returned": bucket.requests_bucket.tokens_returned,
-                    "target_rate": round(bucket.requests_bucket.target_rate, 1),
-                    "current_rate": round(bucket.requests_bucket.get_throughput(), 1),
-                },
-                "tokens_bucket": {
-                    "completed": bucket.tokens_bucket.num_released,
-                    "requested": bucket.tokens_bucket.num_requests,
-                    "tokens_returned": bucket.tokens_bucket.tokens_returned,
-                    "target_rate": round(bucket.tokens_bucket.target_rate, 1),
-                    "current_rate": round(bucket.tokens_bucket.get_throughput(), 1),
-                },
-            }
+        # Check if bucket collection exists and is not empty
+        if (hasattr(self.jobs, 'run_config') and
+            hasattr(self.jobs.run_config, 'environment') and
+            hasattr(self.jobs.run_config.environment, 'bucket_collection') and
+            self.jobs.run_config.environment.bucket_collection):
+
+            for model, bucket in self.jobs.run_config.environment.bucket_collection.items():
+                model_name = model.model
+                model_queues[model_name] = {
+                    "language_model_name": model_name,
+                    "requests_bucket": {
+                        "completed": bucket.requests_bucket.num_released,
+                        "requested": bucket.requests_bucket.num_requests,
+                        "tokens_returned": bucket.requests_bucket.tokens_returned,
+                        "target_rate": round(bucket.requests_bucket.target_rate, 1),
+                        "current_rate": round(bucket.requests_bucket.get_throughput(), 1),
+                    },
+                    "tokens_bucket": {
+                        "completed": bucket.tokens_bucket.num_released,
+                        "requested": bucket.tokens_bucket.num_requests,
+                        "tokens_returned": bucket.tokens_bucket.tokens_returned,
+                        "target_rate": round(bucket.tokens_bucket.target_rate, 1),
+                        "current_rate": round(bucket.tokens_bucket.get_throughput(), 1),
+                    },
+                }
         status_dict["language_model_queues"] = model_queues
         return status_dict
 
181
 
@@ -0,0 +1,79 @@
1
+ """
2
+ Progress bar management for asynchronous job execution.
3
+
4
+ This module provides a context manager for handling progress bar setup and thread
5
+ management during job execution. It coordinates the display and updating of progress
6
+ bars, particularly for remote tracking via the Expected Parrot API.
7
+ """
8
+
9
+ import threading
10
+ import warnings
11
+
12
+ from ..coop import Coop
13
+ from .jobs_runner_status import JobsRunnerStatus
14
+
15
+
16
+ class ProgressBarManager:
17
+ """Context manager for handling progress bar setup and thread management.
18
+
19
+ This class manages the progress bar display and updating during job execution,
20
+ particularly for remote tracking via the Expected Parrot API.
21
+
22
+ It handles:
23
+ 1. Setting up a status tracking object
24
+ 2. Creating and managing a background thread for progress updates
25
+ 3. Properly cleaning up resources when execution completes
26
+ """
27
+
28
+ def __init__(self, jobs, run_config, parameters):
29
+ self.parameters = parameters
30
+ self.jobs = jobs
31
+
32
+ # Set up progress tracking
33
+ coop = Coop()
34
+ endpoint_url = coop.get_progress_bar_url()
35
+
36
+ # Set up jobs status object
37
+ params = {
38
+ "jobs": jobs,
39
+ "n": parameters.n,
40
+ "endpoint_url": endpoint_url,
41
+ "job_uuid": parameters.job_uuid,
42
+ }
43
+
44
+ # If the jobs_runner_status is already set, use it directly
45
+ if run_config.environment.jobs_runner_status is not None:
46
+ self.jobs_runner_status = run_config.environment.jobs_runner_status
47
+ else:
48
+ # Otherwise create a new one
49
+ self.jobs_runner_status = JobsRunnerStatus(**params)
50
+
51
+ # Store on run_config for use by other components
52
+ run_config.environment.jobs_runner_status = self.jobs_runner_status
53
+
54
+ self.progress_thread = None
55
+ self.stop_event = threading.Event()
56
+
57
+ def __enter__(self):
58
+ if self.parameters.progress_bar and self.jobs_runner_status.has_ep_api_key():
59
+ self.jobs_runner_status.setup()
60
+ self.progress_thread = threading.Thread(
61
+ target=self._run_progress_bar,
62
+ args=(self.stop_event, self.jobs_runner_status)
63
+ )
64
+ self.progress_thread.start()
65
+ elif self.parameters.progress_bar:
66
+ warnings.warn(
67
+ "You need an Expected Parrot API key to view job progress bars."
68
+ )
69
+ return self.stop_event
70
+
71
+ def __exit__(self, exc_type, exc_val, exc_tb):
72
+ self.stop_event.set()
73
+ if self.progress_thread is not None:
74
+ self.progress_thread.join()
75
+
76
+ @staticmethod
77
+ def _run_progress_bar(stop_event, jobs_runner_status):
78
+ """Runs the progress bar in a separate thread."""
79
+ jobs_runner_status.update_progress(stop_event)
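
A hypothetical caller sketch (not part of the diff) showing how a runner might drive this context manager; jobs, run_config, and parameters stand in for the real edsl objects:

from edsl.jobs.progress_bar_manager import ProgressBarManager

def run_with_progress(jobs, run_config, parameters):
    # __enter__ starts a background thread that polls job status when a
    # progress bar was requested and an Expected Parrot API key is present;
    # __exit__ sets the stop event and joins the thread, even on error.
    with ProgressBarManager(jobs, run_config, parameters) as stop_event:
        ...  # run the interviews; the bar updates until stop_event is set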
edsl/jobs/remote_inference.py

@@ -1,3 +1,4 @@
+import re
 from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
 from dataclasses import dataclass
 from ..coop import CoopServerResponseError
@@ -112,13 +113,18 @@ class JobsRemoteInferenceHandler:
         )
         logger.add_info("job_uuid", job_uuid)
 
+        remote_inference_url = self.remote_inference_url
+        if "localhost" in remote_inference_url:
+            remote_inference_url = remote_inference_url.replace("8000", "1234")
         logger.update(
-            f"Job details are available at your Coop account. [Go to Remote Inference page]({self.remote_inference_url})",
+            f"Job details are available at your Coop account. [Go to Remote Inference page]({remote_inference_url})",
             status=JobsStatus.RUNNING,
         )
         progress_bar_url = (
             f"{self.expected_parrot_url}/home/remote-job-progress/{job_uuid}"
         )
+        if "localhost" in progress_bar_url:
+            progress_bar_url = progress_bar_url.replace("8000", "1234")
         logger.add_info("progress_bar_url", progress_bar_url)
         logger.update(
             f"View job progress [here]({progress_bar_url})", status=JobsStatus.RUNNING
@@ -200,10 +206,35 @@
                 status=JobsStatus.FAILED,
             )
 
+    def _handle_partially_failed_job_interview_details(
+        self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
+    ) -> None:
+        "Extracts the interview details from the remote job data."
+        try:
+            # Job details is a string of the form "64 out of 1,758 interviews failed"
+            job_details = remote_job_data.get("latest_failure_description")
+
+            text_without_commas = job_details.replace(",", "")
+
+            # Find all numbers in the text
+            numbers = [int(num) for num in re.findall(r"\d+", text_without_commas)]
+
+            failed = numbers[0]
+            total = numbers[1]
+            completed = total - failed
+
+            job_info.logger.add_info("completed_interviews", completed)
+            job_info.logger.add_info("failed_interviews", failed)
+        # This is mainly helpful metadata, and any errors here should not stop the code
+        except:
+            pass
+
     def _handle_partially_failed_job(
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
         "Handles a partially failed job by logging the error and updating the job status."
+        self._handle_partially_failed_job_interview_details(job_info, remote_job_data)
+
         latest_error_report_url = remote_job_data.get("latest_error_report_url")
 
         if latest_error_report_url:
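
The parsing step in _handle_partially_failed_job_interview_details, in miniature, on the exact example string the inline comment documents:

import re

job_details = "64 out of 1,758 interviews failed"
# Strip thousands separators, then pull out every run of digits.
numbers = [int(n) for n in re.findall(r"\d+", job_details.replace(",", ""))]
failed, total = numbers[0], numbers[1]
print(failed, total, total - failed)  # 64 1758 1694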
@@ -244,6 +275,8 @@
         job_info.logger.add_info("results_uuid", results_uuid)
         results = object_fetcher(results_uuid, expected_object_type="results")
         results_url = remote_job_data.get("results_url")
+        if "localhost" in results_url:
+            results_url = results_url.replace("8000", "1234")
         job_info.logger.add_info("results_url", results_url)
 
         if job_status == "completed":
@@ -256,6 +289,7 @@
                 f"View partial results [here]({results_url})",
                 status=JobsStatus.PARTIALLY_FAILED,
             )
+
         results.job_uuid = job_info.job_uuid
         results.results_uuid = results_uuid
         return results
edsl/key_management/key_lookup_builder.py

@@ -2,6 +2,7 @@ from typing import Optional, TYPE_CHECKING
 import os
 from functools import lru_cache
 import textwrap
+import requests
 
 if TYPE_CHECKING:
     from ..coop import Coop
@@ -255,7 +256,11 @@ class KeyLookupBuilder:
         return dict(list(os.environ.items()))
 
     def _coop_key_value_pairs(self):
-        return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
+        try:
+            return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
+        except requests.ConnectionError:
+            # If connection fails, return empty dict instead of raising error
+            return {}
 
     def _config_key_value_pairs(self):
         from ..config import CONFIG