edsl 0.1.54__py3-none-any.whl → 0.1.56__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +8 -1
- edsl/__init__original.py +134 -0
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +29 -0
- edsl/agents/agent_list.py +36 -1
- edsl/base/base_class.py +281 -151
- edsl/base/data_transfer_models.py +15 -4
- edsl/buckets/__init__.py +8 -3
- edsl/buckets/bucket_collection.py +9 -3
- edsl/buckets/model_buckets.py +4 -2
- edsl/buckets/token_bucket.py +2 -2
- edsl/buckets/token_bucket_client.py +5 -3
- edsl/caching/cache.py +131 -62
- edsl/caching/cache_entry.py +70 -58
- edsl/caching/sql_dict.py +17 -0
- edsl/cli.py +99 -0
- edsl/config/config_class.py +16 -0
- edsl/conversation/__init__.py +31 -0
- edsl/coop/coop.py +276 -242
- edsl/coop/coop_jobs_objects.py +59 -0
- edsl/coop/coop_objects.py +29 -0
- edsl/coop/coop_regular_objects.py +26 -0
- edsl/coop/utils.py +24 -19
- edsl/dataset/dataset.py +338 -101
- edsl/dataset/dataset_operations_mixin.py +216 -180
- edsl/db_list/sqlite_list.py +349 -0
- edsl/inference_services/__init__.py +40 -5
- edsl/inference_services/exceptions.py +11 -0
- edsl/inference_services/services/anthropic_service.py +5 -2
- edsl/inference_services/services/aws_bedrock.py +6 -2
- edsl/inference_services/services/azure_ai.py +6 -2
- edsl/inference_services/services/google_service.py +7 -3
- edsl/inference_services/services/mistral_ai_service.py +6 -2
- edsl/inference_services/services/open_ai_service.py +6 -2
- edsl/inference_services/services/perplexity_service.py +6 -2
- edsl/inference_services/services/test_service.py +94 -5
- edsl/interviews/answering_function.py +167 -59
- edsl/interviews/interview.py +124 -72
- edsl/interviews/interview_task_manager.py +10 -0
- edsl/interviews/request_token_estimator.py +8 -0
- edsl/invigilators/invigilators.py +35 -13
- edsl/jobs/async_interview_runner.py +146 -104
- edsl/jobs/data_structures.py +6 -4
- edsl/jobs/decorators.py +61 -0
- edsl/jobs/fetch_invigilator.py +61 -18
- edsl/jobs/html_table_job_logger.py +14 -2
- edsl/jobs/jobs.py +180 -104
- edsl/jobs/jobs_component_constructor.py +2 -2
- edsl/jobs/jobs_interview_constructor.py +2 -0
- edsl/jobs/jobs_pricing_estimation.py +154 -113
- edsl/jobs/jobs_remote_inference_logger.py +4 -0
- edsl/jobs/jobs_runner_status.py +30 -25
- edsl/jobs/progress_bar_manager.py +79 -0
- edsl/jobs/remote_inference.py +35 -1
- edsl/key_management/key_lookup_builder.py +6 -1
- edsl/language_models/language_model.py +110 -12
- edsl/language_models/model.py +10 -3
- edsl/language_models/price_manager.py +176 -71
- edsl/language_models/registry.py +5 -0
- edsl/notebooks/notebook.py +77 -10
- edsl/questions/VALIDATION_README.md +134 -0
- edsl/questions/__init__.py +24 -1
- edsl/questions/exceptions.py +21 -0
- edsl/questions/question_dict.py +201 -16
- edsl/questions/question_multiple_choice_with_other.py +624 -0
- edsl/questions/question_registry.py +2 -1
- edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
- edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
- edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
- edsl/questions/validation_analysis.py +185 -0
- edsl/questions/validation_cli.py +131 -0
- edsl/questions/validation_html_report.py +404 -0
- edsl/questions/validation_logger.py +136 -0
- edsl/results/result.py +115 -46
- edsl/results/results.py +702 -171
- edsl/scenarios/construct_download_link.py +16 -3
- edsl/scenarios/directory_scanner.py +226 -226
- edsl/scenarios/file_methods.py +5 -0
- edsl/scenarios/file_store.py +150 -9
- edsl/scenarios/handlers/__init__.py +5 -1
- edsl/scenarios/handlers/mp4_file_store.py +104 -0
- edsl/scenarios/handlers/webm_file_store.py +104 -0
- edsl/scenarios/scenario.py +120 -101
- edsl/scenarios/scenario_list.py +800 -727
- edsl/scenarios/scenario_list_gc_test.py +146 -0
- edsl/scenarios/scenario_list_memory_test.py +214 -0
- edsl/scenarios/scenario_list_source_refactor.md +35 -0
- edsl/scenarios/scenario_selector.py +5 -4
- edsl/scenarios/scenario_source.py +1990 -0
- edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
- edsl/surveys/survey.py +22 -0
- edsl/tasks/__init__.py +4 -2
- edsl/tasks/task_history.py +198 -36
- edsl/tests/scenarios/test_ScenarioSource.py +51 -0
- edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
- edsl/utilities/__init__.py +2 -1
- edsl/utilities/decorators.py +121 -0
- edsl/utilities/memory_debugger.py +1010 -0
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/METADATA +51 -76
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/RECORD +103 -79
- edsl/jobs/jobs_runner_asyncio.py +0 -281
- edsl/language_models/unused/fake_openai_service.py +0 -60
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/LICENSE +0 -0
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/WHEEL +0 -0
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/entry_points.txt +0 -0
edsl/jobs/jobs_pricing_estimation.py CHANGED

```diff
@@ -1,7 +1,8 @@
 import logging
 import math
 
-from typing import List, TYPE_CHECKING, Union, Literal
+from typing import List, TYPE_CHECKING, Union, Literal, Dict
+from collections import namedtuple
 
 if TYPE_CHECKING:
     from .jobs import Jobs
```
```diff
@@ -20,8 +21,8 @@ logger = logging.getLogger(__name__)
 
 class PromptCostEstimator:
 
-
-
+    DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS = 1.0
+    DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS = 1.0
     CHARS_PER_TOKEN = 4
     OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75
     PIPING_MULTIPLIER = 2
```
```diff
@@ -48,81 +49,90 @@ class PromptCostEstimator:
         return PromptCostEstimator.PIPING_MULTIPLIER
         return 1
 
-
-
-
-
-
-    def relevant_prices(self):
-        try:
-            return self.price_lookup[self.key]
-        except KeyError:
-            return {}
-
-    def _get_highest_price_for_service(self, price_type: str) -> Union[float, None]:
-        """Returns the highest price per token for a given service and price type (input/output).
+    def _get_fallback_price(self, inference_service: str) -> Dict:
+        """
+        Get fallback prices for a service.
+        - First fallback: The highest input and output prices for that service from the price lookup.
+        - Second fallback: $1.00 per million tokens (for both input and output).
 
         Args:
-
+            inference_service (str): The inference service name
 
         Returns:
-
+            Dict: Price information
         """
-
-
-
+        PriceEntry = namedtuple("PriceEntry", ["tokens_per_usd", "price_info"])
+
+        service_prices = [
+            prices
             for (service, _), prices in self.price_lookup.items()
-            if service ==
+            if service == inference_service
         ]
-        return max(prices_for_service) if prices_for_service else None
 
-
-
-
-
-
-
-
-
-
-
-        if highest_price is not None:
-            import warnings
-
-            warnings.warn(
-                f"Price data not found for {self.key}. Using highest available input price for {self.inference_service}: ${highest_price:.6f} per token"
-            )
-            return highest_price, "highest_price_for_service"
-        import warnings
+        default_input_price_info = {
+            "one_usd_buys": 1_000_000,
+            "service_stated_token_qty": 1_000_000,
+            "service_stated_token_price": self.DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS,
+        }
+        default_output_price_info = {
+            "one_usd_buys": 1_000_000,
+            "service_stated_token_qty": 1_000_000,
+            "service_stated_token_price": self.DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS,
+        }
 
-
-
-
-        return self.DEFAULT_INPUT_PRICE_PER_TOKEN, "default"
+        # Find the most expensive price entries (lowest tokens per USD)
+        input_price_info = default_input_price_info
+        output_price_info = default_output_price_info
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        input_prices = [
+            PriceEntry(float(p["input"]["one_usd_buys"]), p["input"])
+            for p in service_prices
+            if "input" in p
+        ]
+        if input_prices:
+            input_price_info = min(
+                input_prices, key=lambda price: price.tokens_per_usd
+            ).price_info
+
+        output_prices = [
+            PriceEntry(float(p["output"]["one_usd_buys"]), p["output"])
+            for p in service_prices
+            if "output" in p
+        ]
+        if output_prices:
+            output_price_info = min(
+                output_prices, key=lambda price: price.tokens_per_usd
+            ).price_info
 
-
-
-
-
+        return {
+            "input": input_price_info,
+            "output": output_price_info,
+        }
+
+    def get_price(self, inference_service: str, model: str) -> Dict:
+        """Get the price information for a specific service and model."""
+        key = (inference_service, model)
+        return self.price_lookup.get(key) or self._get_fallback_price(inference_service)
+
+    def get_price_per_million_tokens(
+        self,
+        relevant_prices: Dict,
+        token_type: Literal["input", "output"],
+    ) -> Dict:
+        """
+        Get the price per million tokens for a specific service, model, and token type.
+        """
+        service_price = relevant_prices[token_type]["service_stated_token_price"]
+        service_qty = relevant_prices[token_type]["service_stated_token_qty"]
+
+        if service_qty == 1_000_000:
+            price_per_million_tokens = service_price
+        elif service_qty == 1_000:
+            price_per_million_tokens = service_price * 1_000
+        else:
+            price_per_token = service_price / service_qty
+            price_per_million_tokens = round(price_per_token * 1_000_000, 10)
+        return price_per_million_tokens
 
     def __call__(self):
         user_prompt_chars = len(str(self.user_prompt)) * self.get_piping_multiplier(
```
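For readers skimming the diff, the unit normalization in `get_price_per_million_tokens` is worth restating: the lookup states a price per some token quantity, and the method rescales it to a per-million figure. A minimal standalone sketch of the same arithmetic, with hypothetical prices:

```python
# Standalone restatement of the normalization in get_price_per_million_tokens.
# The function name and the sample prices below are hypothetical.
def price_per_million(service_stated_token_price: float, service_stated_token_qty: int) -> float:
    if service_stated_token_qty == 1_000_000:
        return service_stated_token_price          # already stated per million
    elif service_stated_token_qty == 1_000:
        return service_stated_token_price * 1_000  # per-thousand -> per-million
    # General case: derive the per-token price, then scale up.
    return round(service_stated_token_price / service_stated_token_qty * 1_000_000, 10)

assert price_per_million(0.25, 1_000_000) == 0.25  # $0.25 per 1M tokens stays as-is
assert price_per_million(0.5, 1_000) == 500.0      # $0.50 per 1K tokens -> $500 per 1M
assert price_per_million(0.01, 2_000) == 5.0       # odd quantity falls through to the general case
```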
```diff
@@ -135,20 +145,28 @@ class PromptCostEstimator:
         input_tokens = (user_prompt_chars + system_prompt_chars) // self.CHARS_PER_TOKEN
         output_tokens = math.ceil(self.OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)
 
-
-        output_price_per_token, output_price_source = self.output_price_per_token()
+        relevant_prices = self.get_price(self.inference_service, self.model)
 
-
-
-            + output_tokens * output_price_per_token
+        input_price_per_million_tokens = self.get_price_per_million_tokens(
+            relevant_prices, "input"
         )
+        output_price_per_million_tokens = self.get_price_per_million_tokens(
+            relevant_prices, "output"
+        )
+
+        input_price_per_token = input_price_per_million_tokens / 1_000_000
+        output_price_per_token = output_price_per_million_tokens / 1_000_000
+
+        input_cost = input_tokens * input_price_per_token
+        output_cost = output_tokens * output_price_per_token
+        cost = input_cost + output_cost
         return {
-            "
-            "
+            "input_price_per_million_tokens": input_price_per_million_tokens,
+            "output_price_per_million_tokens": output_price_per_million_tokens,
             "input_tokens": input_tokens,
-            "output_price_source": output_price_source,
             "output_tokens": output_tokens,
-            "
+            "input_cost_usd": input_cost,
+            "output_cost_usd": output_cost,
             "cost_usd": cost,
         }
 
```
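As a sanity check on the rewritten `__call__`, the heuristic can be traced by hand. A worked example with made-up prompt sizes (no piping, so the multiplier is 1) under the $1.00-per-million fallback price:

```python
import math

# Hypothetical prompt sizes; the constants 4 and 0.75 are CHARS_PER_TOKEN and
# OUTPUT_TOKENS_PER_INPUT_TOKEN from the class above.
user_prompt_chars = 400
system_prompt_chars = 100

input_tokens = (user_prompt_chars + system_prompt_chars) // 4   # 125
output_tokens = math.ceil(0.75 * input_tokens)                  # ceil(93.75) = 94

# With the fallback price of $1.00 per million tokens on both sides:
input_cost = input_tokens * (1.0 / 1_000_000)
output_cost = output_tokens * (1.0 / 1_000_000)
print(input_tokens, output_tokens, round(input_cost + output_cost, 9))  # 125 94 0.000219
```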
```diff
@@ -328,6 +346,26 @@ class JobsPrompts:
             "model": model,
         }
 
+    def process_token_type(self, item: dict, token_type: str) -> tuple:
+        """
+        Helper function to process a single token type (input or output) for price estimation.
+        """
+        price = item[f"estimated_{token_type}_price_per_million_tokens"]
+        tokens = item[f"estimated_{token_type}_tokens"]
+        cost = item[f"estimated_{token_type}_cost_usd"]
+
+        return (
+            (item["inference_service"], item["model"], token_type, price),
+            {
+                "inference_service": item["inference_service"],
+                "model": item["model"],
+                "token_type": token_type,
+                "price_per_million_tokens": price,
+                "tokens": tokens,
+                "cost_usd": cost,
+            },
+        )
+
     def estimate_job_cost_from_external_prices(
         self, price_lookup: dict, iterations: int = 1
     ) -> dict:
```
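The new `process_token_type` helper turns one prompt row into a grouping key plus a summable record, so rows that agree on service, model, token type, and per-million price collapse into one line of the cost report. A sketch of the key it builds, with hypothetical sample values:

```python
# Hypothetical prompt row; field names follow the price_estimates dict in the diff.
item = {
    "inference_service": "openai",
    "model": "gpt-4o",
    "estimated_input_price_per_million_tokens": 2.5,
    "estimated_input_tokens": 120,
    "estimated_input_cost_usd": 0.0003,  # 120 tokens at $2.50 per million
}

# The grouping key for the "input" side of this row:
key = (item["inference_service"], item["model"], "input",
       item["estimated_input_price_per_million_tokens"])
print(key)  # ('openai', 'gpt-4o', 'input', 2.5)
```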
```diff
@@ -341,9 +379,9 @@ class JobsPrompts:
         - 1 token = 4 characters.
         - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
         """
-
+        # Collect all prompt data
         data = []
-        for interview in interviews:
+        for interview in self.interviews:
             invigilators = [
                 FetchInvigilator(interview)(question)
                 for question in self.survey.questions
```
```diff
@@ -354,59 +392,62 @@ class JobsPrompts:
                     **prompt_details, price_lookup=price_lookup
                 )
                 price_estimates = {
+                    "estimated_input_price_per_million_tokens": prompt_cost[
+                        "input_price_per_million_tokens"
+                    ],
+                    "estimated_output_price_per_million_tokens": prompt_cost[
+                        "output_price_per_million_tokens"
+                    ],
                     "estimated_input_tokens": prompt_cost["input_tokens"],
                     "estimated_output_tokens": prompt_cost["output_tokens"],
+                    "estimated_input_cost_usd": prompt_cost["input_cost_usd"],
+                    "estimated_output_cost_usd": prompt_cost["output_cost_usd"],
                     "estimated_cost_usd": prompt_cost["cost_usd"],
                 }
-                data.append(
+                data.append(
+                    {
+                        **prompt_details,
+                        **price_estimates,
+                    }
+                )
 
-
+        # Group by service, model, token type, and price
+        detailed_groups = {}
         for item in data:
-
-
-
-
-
-                "
-                "
-
-
-
-
-
-
-
-                ]
-                model_groups[key]["estimated_output_tokens"] += item[
-                    "estimated_output_tokens"
-                ]
-
-        # Apply iterations and convert to list
-        estimated_costs_by_model = []
-        for group_data in model_groups.values():
-            group_data["estimated_cost_usd"] *= iterations
-            group_data["estimated_input_tokens"] *= iterations
-            group_data["estimated_output_tokens"] *= iterations
-            estimated_costs_by_model.append(group_data)
+            for token_type in ["input", "output"]:
+                key, group_data = self.process_token_type(item, token_type)
+                if key not in detailed_groups:
+                    detailed_groups[key] = group_data
+                else:
+                    detailed_groups[key]["tokens"] += group_data["tokens"]
+                    detailed_groups[key]["cost_usd"] += group_data["cost_usd"]
+
+        # Apply iterations and prepare final output
+        detailed_costs = []
+        for group in detailed_groups.values():
+            group["tokens"] *= iterations
+            group["cost_usd"] *= iterations
+            detailed_costs.append(group)
 
         # Calculate totals
-        estimated_total_cost = sum(
-            model["estimated_cost_usd"] for model in estimated_costs_by_model
-        )
+        estimated_total_cost = sum(group["cost_usd"] for group in detailed_costs)
         estimated_total_input_tokens = sum(
-
+            group["tokens"]
+            for group in detailed_costs
+            if group["token_type"] == "input"
         )
         estimated_total_output_tokens = sum(
-
+            group["tokens"]
+            for group in detailed_costs
+            if group["token_type"] == "output"
         )
 
         output = {
             "estimated_total_cost_usd": estimated_total_cost,
             "estimated_total_input_tokens": estimated_total_input_tokens,
             "estimated_total_output_tokens": estimated_total_output_tokens,
-            "
+            "detailed_costs": detailed_costs,
         }
-
         return output
 
     def estimate_job_cost(self, iterations: int = 1) -> dict:
```
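The grouping loop above merges records that share a key by summing their token and cost fields. A self-contained sketch of that merge with two hypothetical rows for the same service, model, and price:

```python
# Two (key, group_data) pairs as process_token_type would produce them;
# the service, model, and price values are made up.
rows = [
    (("openai", "gpt-4o", "input", 2.5),
     {"token_type": "input", "tokens": 100, "cost_usd": 0.00025}),
    (("openai", "gpt-4o", "input", 2.5),
     {"token_type": "input", "tokens": 50, "cost_usd": 0.000125}),
]

detailed_groups = {}
for key, group_data in rows:
    if key not in detailed_groups:
        detailed_groups[key] = dict(group_data)
    else:
        detailed_groups[key]["tokens"] += group_data["tokens"]
        detailed_groups[key]["cost_usd"] += group_data["cost_usd"]

total = sum(group["cost_usd"] for group in detailed_groups.values())
print(len(detailed_groups), round(total, 9))  # 1 0.000375 — one merged group, not two rows
```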
edsl/jobs/jobs_remote_inference_logger.py CHANGED

```diff
@@ -30,6 +30,8 @@ class JobsInfo:
     error_report_url: str = None
     results_uuid: str = None
     results_url: str = None
+    completed_interviews: int = None
+    failed_interviews: int = None
 
     pretty_names = {
         "job_uuid": "Job UUID",
```
```diff
@@ -53,6 +55,8 @@ class JobLogger(ABC):
             "error_report_url",
             "results_uuid",
             "results_url",
+            "completed_interviews",
+            "failed_interviews",
         ],
         value: str,
     ):
```
edsl/jobs/jobs_runner_status.py CHANGED
```diff
@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, TYPE_CHECKING
 from uuid import UUID
 
 if TYPE_CHECKING:
-    from .
+    from .jobs import Jobs
 
 
 @dataclass
```
```diff
@@ -65,14 +65,14 @@ class StatisticsTracker:
 class JobsRunnerStatusBase(ABC):
     def __init__(
         self,
-
+        jobs: "Jobs",
         n: int,
         refresh_rate: float = 1,
         endpoint_url: Optional[str] = "http://localhost:8000",
         job_uuid: Optional[UUID] = None,
         api_key: str = None,
     ):
-        self.
+        self.jobs = jobs
         self.job_uuid = job_uuid
         self.base_url = f"{endpoint_url}"
         self.refresh_rate = refresh_rate
```
```diff
@@ -86,10 +86,10 @@ class JobsRunnerStatusBase(ABC):
             "unfixed_exceptions",
             "throughput",
         ]
-        self.num_total_interviews = n * len(self.
+        self.num_total_interviews = n * len(self.jobs)
 
         self.distinct_models = list(
-            set(model.model for model in self.
+            set(model.model for model in self.jobs.models)
         )
 
         self.stats_tracker = StatisticsTracker(
```
```diff
@@ -151,26 +151,31 @@ class JobsRunnerStatusBase(ABC):
         }
 
         model_queues = {}
-        #
-
-
-
-
-
-
-
-
-                "
-                "
-
-
-
-
-
-
-                "
-
-
+        # Check if bucket collection exists and is not empty
+        if (hasattr(self.jobs, 'run_config') and
+            hasattr(self.jobs.run_config, 'environment') and
+            hasattr(self.jobs.run_config.environment, 'bucket_collection') and
+            self.jobs.run_config.environment.bucket_collection):
+
+            for model, bucket in self.jobs.run_config.environment.bucket_collection.items():
+                model_name = model.model
+                model_queues[model_name] = {
+                    "language_model_name": model_name,
+                    "requests_bucket": {
+                        "completed": bucket.requests_bucket.num_released,
+                        "requested": bucket.requests_bucket.num_requests,
+                        "tokens_returned": bucket.requests_bucket.tokens_returned,
+                        "target_rate": round(bucket.requests_bucket.target_rate, 1),
+                        "current_rate": round(bucket.requests_bucket.get_throughput(), 1),
+                    },
+                    "tokens_bucket": {
+                        "completed": bucket.tokens_bucket.num_released,
+                        "requested": bucket.tokens_bucket.num_requests,
+                        "tokens_returned": bucket.tokens_bucket.tokens_returned,
+                        "target_rate": round(bucket.tokens_bucket.target_rate, 1),
+                        "current_rate": round(bucket.tokens_bucket.get_throughput(), 1),
+                    },
+                }
         status_dict["language_model_queues"] = model_queues
         return status_dict
 
```
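The chained `hasattr` guard above is a null-safe traversal of `jobs.run_config.environment.bucket_collection`. An equivalent `getattr`-based sketch, using a stand-in object graph rather than real edsl objects:

```python
from types import SimpleNamespace

# Stand-in for the jobs object; the real attribute chain is the one named in the diff.
jobs = SimpleNamespace(
    run_config=SimpleNamespace(environment=SimpleNamespace(bucket_collection={}))
)

# getattr with a None default tolerates any missing link in the chain.
bucket_collection = getattr(
    getattr(getattr(jobs, "run_config", None), "environment", None),
    "bucket_collection",
    None,
)
if bucket_collection:
    print("report per-model queues")
else:
    print("no buckets to report")  # an empty dict is falsy, matching the guard above
```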
edsl/jobs/progress_bar_manager.py ADDED

```diff
@@ -0,0 +1,79 @@
+"""
+Progress bar management for asynchronous job execution.
+
+This module provides a context manager for handling progress bar setup and thread
+management during job execution. It coordinates the display and updating of progress
+bars, particularly for remote tracking via the Expected Parrot API.
+"""
+
+import threading
+import warnings
+
+from ..coop import Coop
+from .jobs_runner_status import JobsRunnerStatus
+
+
+class ProgressBarManager:
+    """Context manager for handling progress bar setup and thread management.
+
+    This class manages the progress bar display and updating during job execution,
+    particularly for remote tracking via the Expected Parrot API.
+
+    It handles:
+    1. Setting up a status tracking object
+    2. Creating and managing a background thread for progress updates
+    3. Properly cleaning up resources when execution completes
+    """
+
+    def __init__(self, jobs, run_config, parameters):
+        self.parameters = parameters
+        self.jobs = jobs
+
+        # Set up progress tracking
+        coop = Coop()
+        endpoint_url = coop.get_progress_bar_url()
+
+        # Set up jobs status object
+        params = {
+            "jobs": jobs,
+            "n": parameters.n,
+            "endpoint_url": endpoint_url,
+            "job_uuid": parameters.job_uuid,
+        }
+
+        # If the jobs_runner_status is already set, use it directly
+        if run_config.environment.jobs_runner_status is not None:
+            self.jobs_runner_status = run_config.environment.jobs_runner_status
+        else:
+            # Otherwise create a new one
+            self.jobs_runner_status = JobsRunnerStatus(**params)
+
+        # Store on run_config for use by other components
+        run_config.environment.jobs_runner_status = self.jobs_runner_status
+
+        self.progress_thread = None
+        self.stop_event = threading.Event()
+
+    def __enter__(self):
+        if self.parameters.progress_bar and self.jobs_runner_status.has_ep_api_key():
+            self.jobs_runner_status.setup()
+            self.progress_thread = threading.Thread(
+                target=self._run_progress_bar,
+                args=(self.stop_event, self.jobs_runner_status)
+            )
+            self.progress_thread.start()
+        elif self.parameters.progress_bar:
+            warnings.warn(
+                "You need an Expected Parrot API key to view job progress bars."
+            )
+        return self.stop_event
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop_event.set()
+        if self.progress_thread is not None:
+            self.progress_thread.join()
+
+    @staticmethod
+    def _run_progress_bar(stop_event, jobs_runner_status):
+        """Runs the progress bar in a separate thread."""
+        jobs_runner_status.update_progress(stop_event)
```
edsl/jobs/remote_inference.py CHANGED
```diff
@@ -1,3 +1,4 @@
+import re
 from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
 from dataclasses import dataclass
 from ..coop import CoopServerResponseError
```
```diff
@@ -112,13 +113,18 @@ class JobsRemoteInferenceHandler:
         )
         logger.add_info("job_uuid", job_uuid)
 
+        remote_inference_url = self.remote_inference_url
+        if "localhost" in remote_inference_url:
+            remote_inference_url = remote_inference_url.replace("8000", "1234")
         logger.update(
-            f"Job details are available at your Coop account. [Go to Remote Inference page]({
+            f"Job details are available at your Coop account. [Go to Remote Inference page]({remote_inference_url})",
             status=JobsStatus.RUNNING,
         )
         progress_bar_url = (
             f"{self.expected_parrot_url}/home/remote-job-progress/{job_uuid}"
         )
+        if "localhost" in progress_bar_url:
+            progress_bar_url = progress_bar_url.replace("8000", "1234")
         logger.add_info("progress_bar_url", progress_bar_url)
         logger.update(
             f"View job progress [here]({progress_bar_url})", status=JobsStatus.RUNNING
```
```diff
@@ -200,10 +206,35 @@ class JobsRemoteInferenceHandler:
             status=JobsStatus.FAILED,
         )
 
+    def _handle_partially_failed_job_interview_details(
+        self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
+    ) -> None:
+        "Extracts the interview details from the remote job data."
+        try:
+            # Job details is a string of the form "64 out of 1,758 interviews failed"
+            job_details = remote_job_data.get("latest_failure_description")
+
+            text_without_commas = job_details.replace(",", "")
+
+            # Find all numbers in the text
+            numbers = [int(num) for num in re.findall(r"\d+", text_without_commas)]
+
+            failed = numbers[0]
+            total = numbers[1]
+            completed = total - failed
+
+            job_info.logger.add_info("completed_interviews", completed)
+            job_info.logger.add_info("failed_interviews", failed)
+        # This is mainly helpful metadata, and any errors here should not stop the code
+        except:
+            pass
+
     def _handle_partially_failed_job(
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
    ) -> None:
         "Handles a partially failed job by logging the error and updating the job status."
+        self._handle_partially_failed_job_interview_details(job_info, remote_job_data)
+
         latest_error_report_url = remote_job_data.get("latest_error_report_url")
 
         if latest_error_report_url:
```
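The parsing in `_handle_partially_failed_job_interview_details` depends on the fixed wording of the failure description. A worked example using the sample string quoted in the code's own comment:

```python
import re

job_details = "64 out of 1,758 interviews failed"   # sample format from the code comment
text_without_commas = job_details.replace(",", "")  # "64 out of 1758 interviews failed"

# re.findall pulls the two integers in order: failed count, then total.
numbers = [int(num) for num in re.findall(r"\d+", text_without_commas)]
failed, total = numbers[0], numbers[1]
print(failed, total - failed)  # 64 1694
```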
```diff
@@ -244,6 +275,8 @@ class JobsRemoteInferenceHandler:
         job_info.logger.add_info("results_uuid", results_uuid)
         results = object_fetcher(results_uuid, expected_object_type="results")
         results_url = remote_job_data.get("results_url")
+        if "localhost" in results_url:
+            results_url = results_url.replace("8000", "1234")
         job_info.logger.add_info("results_url", results_url)
 
         if job_status == "completed":
```
```diff
@@ -256,6 +289,7 @@ class JobsRemoteInferenceHandler:
             f"View partial results [here]({results_url})",
             status=JobsStatus.PARTIALLY_FAILED,
         )
+
         results.job_uuid = job_info.job_uuid
         results.results_uuid = results_uuid
         return results
```
edsl/key_management/key_lookup_builder.py CHANGED

```diff
@@ -2,6 +2,7 @@ from typing import Optional, TYPE_CHECKING
 import os
 from functools import lru_cache
 import textwrap
+import requests
 
 if TYPE_CHECKING:
     from ..coop import Coop
```
```diff
@@ -255,7 +256,11 @@ class KeyLookupBuilder:
         return dict(list(os.environ.items()))
 
     def _coop_key_value_pairs(self):
-
+        try:
+            return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
+        except requests.ConnectionError:
+            # If connection fails, return empty dict instead of raising error
+            return {}
 
     def _config_key_value_pairs(self):
         from ..config import CONFIG
```