edsl 0.1.53__py3-none-any.whl → 0.1.55__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +8 -1
- edsl/__init__original.py +134 -0
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +29 -0
- edsl/agents/agent_list.py +36 -1
- edsl/base/base_class.py +281 -151
- edsl/buckets/__init__.py +8 -3
- edsl/buckets/bucket_collection.py +9 -3
- edsl/buckets/model_buckets.py +4 -2
- edsl/buckets/token_bucket.py +2 -2
- edsl/buckets/token_bucket_client.py +5 -3
- edsl/caching/cache.py +131 -62
- edsl/caching/cache_entry.py +70 -58
- edsl/caching/sql_dict.py +17 -0
- edsl/cli.py +99 -0
- edsl/config/config_class.py +16 -0
- edsl/conversation/__init__.py +31 -0
- edsl/coop/coop.py +276 -242
- edsl/coop/coop_jobs_objects.py +59 -0
- edsl/coop/coop_objects.py +29 -0
- edsl/coop/coop_regular_objects.py +26 -0
- edsl/coop/utils.py +24 -19
- edsl/dataset/dataset.py +338 -101
- edsl/db_list/sqlite_list.py +349 -0
- edsl/inference_services/__init__.py +40 -5
- edsl/inference_services/exceptions.py +11 -0
- edsl/inference_services/services/anthropic_service.py +5 -2
- edsl/inference_services/services/aws_bedrock.py +6 -2
- edsl/inference_services/services/azure_ai.py +6 -2
- edsl/inference_services/services/google_service.py +3 -2
- edsl/inference_services/services/mistral_ai_service.py +6 -2
- edsl/inference_services/services/open_ai_service.py +6 -2
- edsl/inference_services/services/perplexity_service.py +6 -2
- edsl/inference_services/services/test_service.py +105 -7
- edsl/interviews/answering_function.py +167 -59
- edsl/interviews/interview.py +124 -72
- edsl/interviews/interview_task_manager.py +10 -0
- edsl/invigilators/invigilators.py +10 -1
- edsl/jobs/async_interview_runner.py +146 -104
- edsl/jobs/data_structures.py +6 -4
- edsl/jobs/decorators.py +61 -0
- edsl/jobs/fetch_invigilator.py +61 -18
- edsl/jobs/html_table_job_logger.py +14 -2
- edsl/jobs/jobs.py +180 -104
- edsl/jobs/jobs_component_constructor.py +2 -2
- edsl/jobs/jobs_interview_constructor.py +2 -0
- edsl/jobs/jobs_pricing_estimation.py +127 -46
- edsl/jobs/jobs_remote_inference_logger.py +4 -0
- edsl/jobs/jobs_runner_status.py +30 -25
- edsl/jobs/progress_bar_manager.py +79 -0
- edsl/jobs/remote_inference.py +35 -1
- edsl/key_management/key_lookup_builder.py +6 -1
- edsl/language_models/language_model.py +102 -12
- edsl/language_models/model.py +10 -3
- edsl/language_models/price_manager.py +45 -75
- edsl/language_models/registry.py +5 -0
- edsl/language_models/utilities.py +2 -1
- edsl/notebooks/notebook.py +77 -10
- edsl/questions/VALIDATION_README.md +134 -0
- edsl/questions/__init__.py +24 -1
- edsl/questions/exceptions.py +21 -0
- edsl/questions/question_check_box.py +171 -149
- edsl/questions/question_dict.py +243 -51
- edsl/questions/question_multiple_choice_with_other.py +624 -0
- edsl/questions/question_registry.py +2 -1
- edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
- edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
- edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
- edsl/questions/validation_analysis.py +185 -0
- edsl/questions/validation_cli.py +131 -0
- edsl/questions/validation_html_report.py +404 -0
- edsl/questions/validation_logger.py +136 -0
- edsl/results/result.py +63 -16
- edsl/results/results.py +702 -171
- edsl/scenarios/construct_download_link.py +16 -3
- edsl/scenarios/directory_scanner.py +226 -226
- edsl/scenarios/file_methods.py +5 -0
- edsl/scenarios/file_store.py +117 -6
- edsl/scenarios/handlers/__init__.py +5 -1
- edsl/scenarios/handlers/mp4_file_store.py +104 -0
- edsl/scenarios/handlers/webm_file_store.py +104 -0
- edsl/scenarios/scenario.py +120 -101
- edsl/scenarios/scenario_list.py +800 -727
- edsl/scenarios/scenario_list_gc_test.py +146 -0
- edsl/scenarios/scenario_list_memory_test.py +214 -0
- edsl/scenarios/scenario_list_source_refactor.md +35 -0
- edsl/scenarios/scenario_selector.py +5 -4
- edsl/scenarios/scenario_source.py +1990 -0
- edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
- edsl/surveys/survey.py +22 -0
- edsl/tasks/__init__.py +4 -2
- edsl/tasks/task_history.py +198 -36
- edsl/tests/scenarios/test_ScenarioSource.py +51 -0
- edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
- edsl/utilities/__init__.py +2 -1
- edsl/utilities/decorators.py +121 -0
- edsl/utilities/memory_debugger.py +1010 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/METADATA +52 -76
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/RECORD +102 -78
- edsl/jobs/jobs_runner_asyncio.py +0 -281
- edsl/language_models/unused/fake_openai_service.py +0 -60
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/LICENSE +0 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/WHEEL +0 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/entry_points.txt +0 -0
edsl/jobs/jobs_pricing_estimation.py
CHANGED

@@ -1,7 +1,7 @@
 import logging
 import math
 
-from typing import List, TYPE_CHECKING
+from typing import List, TYPE_CHECKING, Union, Literal
 
 if TYPE_CHECKING:
     from .jobs import Jobs
@@ -26,53 +26,104 @@ class PromptCostEstimator:
     OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75
     PIPING_MULTIPLIER = 2
 
-    def __init__(
+    def __init__(
+        self,
         system_prompt: str,
         user_prompt: str,
         price_lookup: dict,
         inference_service: str,
-        model: str
+        model: str,
+    ):
         self.system_prompt = system_prompt
         self.user_prompt = user_prompt
         self.price_lookup = price_lookup
         self.inference_service = inference_service
         self.model = model
 
-    @staticmethod
+    @staticmethod
     def get_piping_multiplier(prompt: str):
         """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
 
         if "{{" in prompt and "}}" in prompt:
             return PromptCostEstimator.PIPING_MULTIPLIER
         return 1
-
+
     @property
     def key(self):
         return (self.inference_service, self.model)
-
+
     @property
     def relevant_prices(self):
         try:
             return self.price_lookup[self.key]
         except KeyError:
             return {}
-
-    def
+
+    def _get_highest_price_for_service(self, price_type: str) -> Union[float, None]:
+        """Returns the highest price per token for a given service and price type (input/output).
+
+        Args:
+            price_type: Either "input" or "output"
+
+        Returns:
+            float | None: The highest price per token for the service, or None if not found
+        """
+        prices_for_service = [
+            prices[price_type]["service_stated_token_price"]
+            / prices[price_type]["service_stated_token_qty"]
+            for (service, _), prices in self.price_lookup.items()
+            if service == self.inference_service and price_type in prices
+        ]
+        return max(prices_for_service) if prices_for_service else None
+
+    def input_price_per_token(
+        self,
+    ) -> tuple[float, Literal["price_lookup", "highest_price_for_service", "default"]]:
         try:
-            return
+            return (
+                self.relevant_prices["input"]["service_stated_token_price"]
+                / self.relevant_prices["input"]["service_stated_token_qty"]
+            ), "price_lookup"
         except KeyError:
+            highest_price = self._get_highest_price_for_service("input")
+            if highest_price is not None:
+                import warnings
+
+                warnings.warn(
+                    f"Price data not found for {self.key}. Using highest available input price for {self.inference_service}: ${highest_price:.6f} per token"
+                )
+                return highest_price, "highest_price_for_service"
             import warnings
+
             warnings.warn(
-                "Price data
+                f"Price data not found for {self.inference_service}. Using default estimate for input token price: $1.00 / 1M tokens"
             )
-            return self.DEFAULT_INPUT_PRICE_PER_TOKEN
+            return self.DEFAULT_INPUT_PRICE_PER_TOKEN, "default"
 
-    def output_price_per_token(
+    def output_price_per_token(
+        self,
+    ) -> tuple[float, Literal["price_lookup", "highest_price_for_service", "default"]]:
         try:
-            return
+            return (
+                self.relevant_prices["output"]["service_stated_token_price"]
+                / self.relevant_prices["output"]["service_stated_token_qty"]
+            ), "price_lookup"
         except KeyError:
-
-
+            highest_price = self._get_highest_price_for_service("output")
+            if highest_price is not None:
+                import warnings
+
+                warnings.warn(
+                    f"Price data not found for {self.key}. Using highest available output price for {self.inference_service}: ${highest_price:.6f} per token"
+                )
+                return highest_price, "highest_price_for_service"
+            import warnings
+
+            warnings.warn(
+                f"Price data not found for {self.inference_service}. Using default estimate for output token price: $1.00 / 1M tokens"
+            )
+            return self.DEFAULT_OUTPUT_PRICE_PER_TOKEN, "default"
+
     def __call__(self):
         user_prompt_chars = len(str(self.user_prompt)) * self.get_piping_multiplier(
             str(self.user_prompt)
@@ -84,20 +135,37 @@ class PromptCostEstimator:
         input_tokens = (user_prompt_chars + system_prompt_chars) // self.CHARS_PER_TOKEN
         output_tokens = math.ceil(self.OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)
 
+        input_price_per_token, input_price_source = self.input_price_per_token()
+        output_price_per_token, output_price_source = self.output_price_per_token()
+
         cost = (
-            input_tokens *
-            + output_tokens *
+            input_tokens * input_price_per_token
+            + output_tokens * output_price_per_token
         )
         return {
+            "input_price_source": input_price_source,
+            "input_price_per_token": input_price_per_token,
             "input_tokens": input_tokens,
+            "output_price_source": output_price_source,
             "output_tokens": output_tokens,
+            "output_price_per_token": output_price_per_token,
             "cost_usd": cost,
         }
 
 
 class JobsPrompts:
 
-    relevant_keys = [
+    relevant_keys = [
+        "user_prompt",
+        "system_prompt",
+        "interview_index",
+        "question_name",
+        "scenario_index",
+        "agent_index",
+        "model",
+        "estimated_cost",
+        "cache_keys",
+    ]
 
     """This generates the prompts for a job for price estimation purposes.
 
@@ -105,7 +173,6 @@ class JobsPrompts:
     So assumptions are made about expansion of Jinja braces, etc.
     """
 
-
    @classmethod
     def from_jobs(cls, jobs: "Jobs"):
         """Construct a JobsPrompts object from a Jobs object."""
@@ -114,13 +181,16 @@ class JobsPrompts:
         scenarios = jobs.scenarios
         survey = jobs.survey
         return cls(
-            interviews=interviews,
-            agents=agents,
-            scenarios=scenarios,
-            survey=survey
+            interviews=interviews, agents=agents, scenarios=scenarios, survey=survey
         )
-
-    def __init__(
+
+    def __init__(
+        self,
+        interviews: List["Interview"],
+        agents: "AgentList",
+        scenarios: "ScenarioList",
+        survey: "Survey",
+    ):
         """Initialize with extracted components rather than a Jobs object."""
         self.interviews = interviews
         self.agents = agents
@@ -143,17 +213,19 @@ class JobsPrompts:
         self._price_lookup = c.fetch_prices()
         return self._price_lookup
 
-    def _process_one_invigilator(
+    def _process_one_invigilator(
+        self, invigilator: "Invigilator", interview_index: int, iterations: int = 1
+    ) -> dict:
         """Process a single invigilator and return a dictionary with all needed data fields."""
         prompts = invigilator.get_prompts()
         user_prompt = prompts["user_prompt"]
         system_prompt = prompts["system_prompt"]
-
+
         agent_index = self._agent_lookup[invigilator.agent]
         scenario_index = self._scenario_lookup[invigilator.scenario]
         model = invigilator.model.model
         question_name = invigilator.question.question_name
-
+
         # Calculate prompt cost
         prompt_cost = self.estimate_prompt_cost(
             system_prompt=system_prompt,
@@ -163,7 +235,7 @@ class JobsPrompts:
             model=model,
         )
         cost = prompt_cost["cost_usd"]
-
+
         # Generate cache keys for each iteration
         cache_keys = []
         for iteration in range(iterations):
@@ -175,7 +247,7 @@ class JobsPrompts:
                 iteration=iteration,
             )
             cache_keys.append(cache_key)
-
+
         d = {
             "user_prompt": user_prompt,
             "system_prompt": system_prompt,
@@ -200,7 +272,7 @@ class JobsPrompts:
         dataset_of_prompts = {k: [] for k in self.relevant_keys}
 
         interviews = self.interviews
-
+
         # Process each interview and invigilator
         for interview_index, interview in enumerate(interviews):
             invigilators = [
@@ -210,11 +282,13 @@ class JobsPrompts:
 
             for invigilator in invigilators:
                 # Process the invigilator and get all data as a dictionary
-                data = self._process_one_invigilator(
+                data = self._process_one_invigilator(
+                    invigilator, interview_index, iterations
+                )
                 for k in self.relevant_keys:
                     dataset_of_prompts[k].append(data[k])
-
-        return Dataset([{k:dataset_of_prompts[k]} for k in self.relevant_keys])
+
+        return Dataset([{k: dataset_of_prompts[k]} for k in self.relevant_keys])
 
     @staticmethod
     def estimate_prompt_cost(
@@ -230,13 +304,13 @@ class JobsPrompts:
             user_prompt=user_prompt,
             price_lookup=price_lookup,
             inference_service=inference_service,
-            model=model
+            model=model,
        )()
-
+
     @staticmethod
     def _extract_prompt_details(invigilator: FetchInvigilator) -> dict:
         """Extracts the prompt details from the invigilator.
-
+
         >>> from edsl.invigilators import InvigilatorAI
         >>> invigilator = InvigilatorAI.example()
         >>> JobsPrompts._extract_prompt_details(invigilator)
@@ -276,11 +350,13 @@ class JobsPrompts:
         ]
         for invigilator in invigilators:
             prompt_details = self._extract_prompt_details(invigilator)
-            prompt_cost = self.estimate_prompt_cost(
+            prompt_cost = self.estimate_prompt_cost(
+                **prompt_details, price_lookup=price_lookup
+            )
             price_estimates = {
-
-
-
+                "estimated_input_tokens": prompt_cost["input_tokens"],
+                "estimated_output_tokens": prompt_cost["output_tokens"],
+                "estimated_cost_usd": prompt_cost["cost_usd"],
             }
             data.append({**price_estimates, **prompt_details})
 
@@ -293,14 +369,18 @@ class JobsPrompts:
                     "model": item["model"],
                     "estimated_cost_usd": 0,
                     "estimated_input_tokens": 0,
-                    "estimated_output_tokens": 0
+                    "estimated_output_tokens": 0,
                 }
-
+
             # Accumulate values
             model_groups[key]["estimated_cost_usd"] += item["estimated_cost_usd"]
-            model_groups[key]["estimated_input_tokens"] += item[
-
-
+            model_groups[key]["estimated_input_tokens"] += item[
+                "estimated_input_tokens"
+            ]
+            model_groups[key]["estimated_output_tokens"] += item[
+                "estimated_output_tokens"
+            ]
+
         # Apply iterations and convert to list
         estimated_costs_by_model = []
         for group_data in model_groups.values():
@@ -345,4 +425,5 @@ class JobsPrompts:
 
 if __name__ == "__main__":
     import doctest
+
     doctest.testmod(optionflags=doctest.ELLIPSIS)
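The practical effect of the `PromptCostEstimator` changes is a three-tier price resolution: an exact `(service, model)` lookup, then the highest price listed for the same service, then a hard default. A minimal standalone sketch of that chain (the `price_lookup` shape follows the diff above; the services, models, and prices here are invented for illustration, and real data comes from Coop's `fetch_prices()`):

```python
from typing import Literal

DEFAULT_PRICE_PER_TOKEN = 1.0 / 1_000_000  # $1.00 / 1M tokens, matching the warning text

# Invented entries for illustration only
price_lookup = {
    ("openai", "gpt-4o"): {
        "input": {"service_stated_token_price": 2.50, "service_stated_token_qty": 1_000_000}
    },
    ("openai", "gpt-4o-mini"): {
        "input": {"service_stated_token_price": 0.15, "service_stated_token_qty": 1_000_000}
    },
}

def resolve_input_price(
    service: str, model: str
) -> tuple[float, Literal["price_lookup", "highest_price_for_service", "default"]]:
    # Tier 1: exact (service, model) match
    try:
        p = price_lookup[(service, model)]["input"]
        return p["service_stated_token_price"] / p["service_stated_token_qty"], "price_lookup"
    except KeyError:
        pass
    # Tier 2: highest known price for the same service (a conservative overestimate)
    candidates = [
        p["input"]["service_stated_token_price"] / p["input"]["service_stated_token_qty"]
        for (svc, _), p in price_lookup.items()
        if svc == service and "input" in p
    ]
    if candidates:
        return max(candidates), "highest_price_for_service"
    # Tier 3: hard default
    return DEFAULT_PRICE_PER_TOKEN, "default"

print(resolve_input_price("openai", "gpt-4o"))          # exact hit -> "price_lookup"
print(resolve_input_price("openai", "some-new-model"))  # -> "highest_price_for_service"
print(resolve_input_price("anthropic", "claude"))       # unknown service -> "default"
```

The source tag from the second tuple element is what `__call__` now surfaces as `input_price_source` and `output_price_source`, so downstream consumers can tell an exact price from an estimate.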
edsl/jobs/jobs_remote_inference_logger.py
CHANGED

@@ -30,6 +30,8 @@ class JobsInfo:
     error_report_url: str = None
     results_uuid: str = None
     results_url: str = None
+    completed_interviews: int = None
+    failed_interviews: int = None
 
     pretty_names = {
         "job_uuid": "Job UUID",
@@ -53,6 +55,8 @@ class JobLogger(ABC):
             "error_report_url",
             "results_uuid",
             "results_url",
+            "completed_interviews",
+            "failed_interviews",
         ],
         value: str,
     ):
edsl/jobs/jobs_runner_status.py
CHANGED

@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, TYPE_CHECKING
 from uuid import UUID
 
 if TYPE_CHECKING:
-    from .
+    from .jobs import Jobs
 
 
 @dataclass
@@ -65,14 +65,14 @@ class StatisticsTracker:
 class JobsRunnerStatusBase(ABC):
     def __init__(
         self,
-
+        jobs: "Jobs",
         n: int,
         refresh_rate: float = 1,
         endpoint_url: Optional[str] = "http://localhost:8000",
         job_uuid: Optional[UUID] = None,
         api_key: str = None,
     ):
-        self.
+        self.jobs = jobs
         self.job_uuid = job_uuid
         self.base_url = f"{endpoint_url}"
         self.refresh_rate = refresh_rate
@@ -86,10 +86,10 @@ class JobsRunnerStatusBase(ABC):
             "unfixed_exceptions",
             "throughput",
         ]
-        self.num_total_interviews = n * len(self.
+        self.num_total_interviews = n * len(self.jobs)
 
         self.distinct_models = list(
-            set(model.model for model in self.
+            set(model.model for model in self.jobs.models)
         )
 
         self.stats_tracker = StatisticsTracker(
@@ -151,26 +151,31 @@ class JobsRunnerStatusBase(ABC):
         }
 
         model_queues = {}
-        #
-
-
-
-
-
-
-
-
-        "
-        "
-
-
-
-
-
-
-        "
-
-
+        # Check if bucket collection exists and is not empty
+        if (hasattr(self.jobs, 'run_config') and
+            hasattr(self.jobs.run_config, 'environment') and
+            hasattr(self.jobs.run_config.environment, 'bucket_collection') and
+            self.jobs.run_config.environment.bucket_collection):
+
+            for model, bucket in self.jobs.run_config.environment.bucket_collection.items():
+                model_name = model.model
+                model_queues[model_name] = {
+                    "language_model_name": model_name,
+                    "requests_bucket": {
+                        "completed": bucket.requests_bucket.num_released,
+                        "requested": bucket.requests_bucket.num_requests,
+                        "tokens_returned": bucket.requests_bucket.tokens_returned,
+                        "target_rate": round(bucket.requests_bucket.target_rate, 1),
+                        "current_rate": round(bucket.requests_bucket.get_throughput(), 1),
+                    },
+                    "tokens_bucket": {
+                        "completed": bucket.tokens_bucket.num_released,
+                        "requested": bucket.tokens_bucket.num_requests,
+                        "tokens_returned": bucket.tokens_bucket.tokens_returned,
+                        "target_rate": round(bucket.tokens_bucket.target_rate, 1),
+                        "current_rate": round(bucket.tokens_bucket.get_throughput(), 1),
+                    },
+                }
         status_dict["language_model_queues"] = model_queues
         return status_dict
 
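The new `model_queues` block guards a deep attribute chain with `hasattr` before touching `bucket_collection`. A compact equivalent of that guard, shown only to illustrate the pattern (the attribute names follow the diff; this helper is hypothetical, not part of edsl):

```python
def bucket_collection_of(jobs):
    # Walk jobs.run_config.environment.bucket_collection, tolerating any
    # missing link in the chain, and fall back to an empty mapping.
    run_config = getattr(jobs, "run_config", None)
    environment = getattr(run_config, "environment", None)
    return getattr(environment, "bucket_collection", None) or {}
```

Each bucket then contributes the same five per-queue fields shown above (completed, requested, tokens_returned, target_rate, current_rate) for both its request and token buckets.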
edsl/jobs/progress_bar_manager.py
ADDED

@@ -0,0 +1,79 @@
+"""
+Progress bar management for asynchronous job execution.
+
+This module provides a context manager for handling progress bar setup and thread
+management during job execution. It coordinates the display and updating of progress
+bars, particularly for remote tracking via the Expected Parrot API.
+"""
+
+import threading
+import warnings
+
+from ..coop import Coop
+from .jobs_runner_status import JobsRunnerStatus
+
+
+class ProgressBarManager:
+    """Context manager for handling progress bar setup and thread management.
+
+    This class manages the progress bar display and updating during job execution,
+    particularly for remote tracking via the Expected Parrot API.
+
+    It handles:
+    1. Setting up a status tracking object
+    2. Creating and managing a background thread for progress updates
+    3. Properly cleaning up resources when execution completes
+    """
+
+    def __init__(self, jobs, run_config, parameters):
+        self.parameters = parameters
+        self.jobs = jobs
+
+        # Set up progress tracking
+        coop = Coop()
+        endpoint_url = coop.get_progress_bar_url()
+
+        # Set up jobs status object
+        params = {
+            "jobs": jobs,
+            "n": parameters.n,
+            "endpoint_url": endpoint_url,
+            "job_uuid": parameters.job_uuid,
+        }
+
+        # If the jobs_runner_status is already set, use it directly
+        if run_config.environment.jobs_runner_status is not None:
+            self.jobs_runner_status = run_config.environment.jobs_runner_status
+        else:
+            # Otherwise create a new one
+            self.jobs_runner_status = JobsRunnerStatus(**params)
+
+        # Store on run_config for use by other components
+        run_config.environment.jobs_runner_status = self.jobs_runner_status
+
+        self.progress_thread = None
+        self.stop_event = threading.Event()
+
+    def __enter__(self):
+        if self.parameters.progress_bar and self.jobs_runner_status.has_ep_api_key():
+            self.jobs_runner_status.setup()
+            self.progress_thread = threading.Thread(
+                target=self._run_progress_bar,
+                args=(self.stop_event, self.jobs_runner_status)
+            )
+            self.progress_thread.start()
+        elif self.parameters.progress_bar:
+            warnings.warn(
+                "You need an Expected Parrot API key to view job progress bars."
+            )
+        return self.stop_event
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop_event.set()
+        if self.progress_thread is not None:
+            self.progress_thread.join()
+
+    @staticmethod
+    def _run_progress_bar(stop_event, jobs_runner_status):
+        """Runs the progress bar in a separate thread."""
+        jobs_runner_status.update_progress(stop_event)
edsl/jobs/remote_inference.py
CHANGED

@@ -1,3 +1,4 @@
+import re
 from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
 from dataclasses import dataclass
 from ..coop import CoopServerResponseError
@@ -112,13 +113,18 @@ class JobsRemoteInferenceHandler:
         )
         logger.add_info("job_uuid", job_uuid)
 
+        remote_inference_url = self.remote_inference_url
+        if "localhost" in remote_inference_url:
+            remote_inference_url = remote_inference_url.replace("8000", "1234")
         logger.update(
-            f"Job details are available at your Coop account. [Go to Remote Inference page]({
+            f"Job details are available at your Coop account. [Go to Remote Inference page]({remote_inference_url})",
             status=JobsStatus.RUNNING,
         )
         progress_bar_url = (
             f"{self.expected_parrot_url}/home/remote-job-progress/{job_uuid}"
         )
+        if "localhost" in progress_bar_url:
+            progress_bar_url = progress_bar_url.replace("8000", "1234")
         logger.add_info("progress_bar_url", progress_bar_url)
         logger.update(
             f"View job progress [here]({progress_bar_url})", status=JobsStatus.RUNNING
@@ -200,10 +206,35 @@ class JobsRemoteInferenceHandler:
             status=JobsStatus.FAILED,
         )
 
+    def _handle_partially_failed_job_interview_details(
+        self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
+    ) -> None:
+        "Extracts the interview details from the remote job data."
+        try:
+            # Job details is a string of the form "64 out of 1,758 interviews failed"
+            job_details = remote_job_data.get("latest_failure_description")
+
+            text_without_commas = job_details.replace(",", "")
+
+            # Find all numbers in the text
+            numbers = [int(num) for num in re.findall(r"\d+", text_without_commas)]
+
+            failed = numbers[0]
+            total = numbers[1]
+            completed = total - failed
+
+            job_info.logger.add_info("completed_interviews", completed)
+            job_info.logger.add_info("failed_interviews", failed)
+        # This is mainly helpful metadata, and any errors here should not stop the code
+        except:
+            pass
+
     def _handle_partially_failed_job(
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
         "Handles a partially failed job by logging the error and updating the job status."
+        self._handle_partially_failed_job_interview_details(job_info, remote_job_data)
+
         latest_error_report_url = remote_job_data.get("latest_error_report_url")
 
         if latest_error_report_url:
@@ -244,6 +275,8 @@ class JobsRemoteInferenceHandler:
         job_info.logger.add_info("results_uuid", results_uuid)
         results = object_fetcher(results_uuid, expected_object_type="results")
         results_url = remote_job_data.get("results_url")
+        if "localhost" in results_url:
+            results_url = results_url.replace("8000", "1234")
         job_info.logger.add_info("results_url", results_url)
 
         if job_status == "completed":
@@ -256,6 +289,7 @@ class JobsRemoteInferenceHandler:
             f"View partial results [here]({results_url})",
             status=JobsStatus.PARTIALLY_FAILED,
         )
+
         results.job_uuid = job_info.job_uuid
         results.results_uuid = results_uuid
         return results
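The new `_handle_partially_failed_job_interview_details` helper leans on the server's failure description being a fixed-format sentence. Rendering the same extraction standalone (the sample string is the format quoted in the diff's own comment):

```python
import re

description = "64 out of 1,758 interviews failed"

# Strip thousands separators, then pull every integer out in order.
numbers = [int(n) for n in re.findall(r"\d+", description.replace(",", ""))]
failed, total = numbers[0], numbers[1]
completed = total - failed

print(failed, total, completed)  # 64 1758 1694
```

The counts feed the two new `JobsInfo` fields above, and per the inline comment the broad `except: pass` is deliberate: the numbers are advisory metadata, so a format change on the server should never abort result fetching.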
edsl/key_management/key_lookup_builder.py
CHANGED

@@ -2,6 +2,7 @@ from typing import Optional, TYPE_CHECKING
 import os
 from functools import lru_cache
 import textwrap
+import requests
 
 if TYPE_CHECKING:
     from ..coop import Coop
@@ -255,7 +256,11 @@ class KeyLookupBuilder:
         return dict(list(os.environ.items()))
 
     def _coop_key_value_pairs(self):
-
+        try:
+            return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
+        except requests.ConnectionError:
+            # If connection fails, return empty dict instead of raising error
+            return {}
 
     def _config_key_value_pairs(self):
         from ..config import CONFIG