edsl-0.1.57-py3-none-any.whl → edsl-0.1.59-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +23 -4
- edsl/agents/agent_list.py +36 -6
- edsl/coop/coop.py +274 -35
- edsl/coop/utils.py +63 -0
- edsl/dataset/dataset.py +74 -0
- edsl/dataset/dataset_operations_mixin.py +67 -62
- edsl/inference_services/services/test_service.py +1 -1
- edsl/interviews/exception_tracking.py +92 -20
- edsl/invigilators/invigilators.py +5 -1
- edsl/invigilators/prompt_constructor.py +299 -136
- edsl/jobs/html_table_job_logger.py +394 -48
- edsl/jobs/jobs_pricing_estimation.py +19 -114
- edsl/jobs/jobs_remote_inference_logger.py +29 -0
- edsl/jobs/jobs_runner_status.py +52 -21
- edsl/jobs/remote_inference.py +214 -30
- edsl/language_models/language_model.py +40 -3
- edsl/language_models/price_manager.py +91 -57
- edsl/prompts/prompt.py +1 -0
- edsl/questions/question_list.py +76 -20
- edsl/results/results.py +8 -1
- edsl/scenarios/file_store.py +8 -12
- edsl/scenarios/scenario.py +50 -2
- edsl/scenarios/scenario_list.py +34 -12
- edsl/surveys/survey.py +4 -0
- edsl/tasks/task_history.py +180 -6
- edsl/utilities/wikipedia.py +194 -0
- {edsl-0.1.57.dist-info → edsl-0.1.59.dist-info}/METADATA +4 -3
- {edsl-0.1.57.dist-info → edsl-0.1.59.dist-info}/RECORD +32 -32
- edsl/language_models/compute_cost.py +0 -78
- {edsl-0.1.57.dist-info → edsl-0.1.59.dist-info}/LICENSE +0 -0
- {edsl-0.1.57.dist-info → edsl-0.1.59.dist-info}/WHEEL +0 -0
- {edsl-0.1.57.dist-info → edsl-0.1.59.dist-info}/entry_points.txt +0 -0
edsl/jobs/remote_inference.py
CHANGED
@@ -1,11 +1,12 @@
 import re
+import math
 from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
 from dataclasses import dataclass
 from ..coop import CoopServerResponseError
-from ..coop.utils import VisibilityType
+from ..coop.utils import VisibilityType, CostConverter
 from ..coop.coop import RemoteInferenceResponse, RemoteInferenceCreationInfo
 from .jobs_status_enums import JobsStatus
-from .jobs_remote_inference_logger import JobLogger
+from .jobs_remote_inference_logger import JobLogger, JobRunExceptionCounter, ModelCost
 from .exceptions import RemoteInferenceError
 
 
@@ -94,6 +95,12 @@ class JobsRemoteInferenceHandler:
             "Remote inference activated. Sending job to server...",
             status=JobsStatus.QUEUED,
         )
+        logger.add_info(
+            "remote_inference_url", f"{self.expected_parrot_url}/home/remote-inference"
+        )
+        logger.add_info(
+            "remote_cache_url", f"{self.expected_parrot_url}/home/remote-cache"
+        )
         remote_job_creation_data = coop.remote_inference_create(
             self.jobs,
             description=remote_inference_description,
@@ -183,7 +190,9 @@ class JobsRemoteInferenceHandler:
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
         "Handles a failed job by logging the error and updating the job status."
-
+        error_report_url = remote_job_data.get("latest_job_run_details", {}).get(
+            "error_report_url"
+        )
 
         reason = remote_job_data.get("reason")
 
@@ -193,8 +202,8 @@ class JobsRemoteInferenceHandler:
             status=JobsStatus.FAILED,
         )
 
-        if
-        job_info.logger.add_info("error_report_url",
+        if error_report_url:
+            job_info.logger.add_info("error_report_url", error_report_url)
 
         job_info.logger.update("Job failed.", status=JobsStatus.FAILED)
         job_info.logger.update(
@@ -206,39 +215,46 @@ class JobsRemoteInferenceHandler:
             status=JobsStatus.FAILED,
         )
 
-    def
+    def _update_interview_details(
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
-        "
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        "Updates the interview details in the job info."
+        latest_job_run_details = remote_job_data.get("latest_job_run_details", {})
+        interview_details = latest_job_run_details.get("interview_details", {}) or {}
+        completed_interviews = interview_details.get("completed_interviews", 0)
+        interviews_with_exceptions = interview_details.get(
+            "interviews_with_exceptions", 0
+        )
+        interviews_without_exceptions = (
+            completed_interviews - interviews_with_exceptions
+        )
+        job_info.logger.add_info("completed_interviews", interviews_without_exceptions)
+        job_info.logger.add_info("failed_interviews", interviews_with_exceptions)
+
+        exception_summary = interview_details.get("exception_summary", []) or []
+        if exception_summary:
+            job_run_exception_counters = []
+            for exception in exception_summary:
+                exception_counter = JobRunExceptionCounter(
+                    exception_type=exception.get("exception_type"),
+                    inference_service=exception.get("inference_service"),
+                    model=exception.get("model"),
+                    question_name=exception.get("question_name"),
+                    exception_count=exception.get("exception_count"),
+                )
+                job_run_exception_counters.append(exception_counter)
+            job_info.logger.add_info("exception_summary", job_run_exception_counters)
 
     def _handle_partially_failed_job(
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
     ) -> None:
         "Handles a partially failed job by logging the error and updating the job status."
-
-
-
+        error_report_url = remote_job_data.get("latest_job_run_details", {}).get(
+            "error_report_url"
+        )
 
-        if
-        job_info.logger.add_info("error_report_url",
+        if error_report_url:
+            job_info.logger.add_info("error_report_url", error_report_url)
 
         job_info.logger.update(
             "Job completed with partial results.", status=JobsStatus.PARTIALLY_FAILED
@@ -263,6 +279,149 @@ class JobsRemoteInferenceHandler:
             )
             time.sleep(self.poll_interval)
 
+    def _get_expenses_from_results(self, results: "Results") -> dict:
+        """
+        Calculates expenses from Results object.
+
+        Args:
+            results: Results object containing model responses
+            include_cached_responses_in_cost: Whether to include cached responses in cost calculation
+
+        Returns:
+            Dictionary mapping ExpenseKey to TokenExpense information
+        """
+        expenses = {}
+
+        for result in results:
+            raw_response = result["raw_model_response"]
+
+            # Process each cost field in the response
+            for key in raw_response:
+                if not key.endswith("_cost"):
+                    continue
+
+                result_cost = raw_response[key]
+                if not isinstance(result_cost, (int, float)):
+                    continue
+
+                question_name = key.removesuffix("_cost")
+                cache_used = result["cache_used_dict"][question_name]
+
+                # Get expense keys for input and output tokens
+                input_key = (
+                    result["model"]._inference_service_,
+                    result["model"].model,
+                    "input",
+                    raw_response[f"{question_name}_input_price_per_million_tokens"],
+                )
+                output_key = (
+                    result["model"]._inference_service_,
+                    result["model"].model,
+                    "output",
+                    raw_response[f"{question_name}_output_price_per_million_tokens"],
+                )
+
+                # Update input token expenses
+                if input_key not in expenses:
+                    expenses[input_key] = {
+                        "tokens": 0,
+                        "cost_usd": 0,
+                        "cost_usd_with_cache": 0,
+                    }
+
+                input_price_per_million_tokens = input_key[3]
+                input_tokens = raw_response[f"{question_name}_input_tokens"]
+                input_cost = (input_price_per_million_tokens / 1_000_000) * input_tokens
+
+                expenses[input_key]["tokens"] += input_tokens
+                expenses[input_key]["cost_usd"] += input_cost
+
+                if not cache_used:
+                    expenses[input_key]["cost_usd_with_cache"] += input_cost
+
+                # Update output token expenses
+                if output_key not in expenses:
+                    expenses[output_key] = {
+                        "tokens": 0,
+                        "cost_usd": 0,
+                        "cost_usd_with_cache": 0,
+                    }
+
+                output_price_per_million_tokens = output_key[3]
+                output_tokens = raw_response[f"{question_name}_output_tokens"]
+                output_cost = (
+                    output_price_per_million_tokens / 1_000_000
+                ) * output_tokens
+
+                expenses[output_key]["tokens"] += output_tokens
+                expenses[output_key]["cost_usd"] += output_cost
+
+                if not cache_used:
+                    expenses[output_key]["cost_usd_with_cache"] += output_cost
+
+        expenses_by_model = {}
+        for expense_key, expense_usage in expenses.items():
+            service, model, token_type, _ = expense_key
+            model_key = (service, model)
+
+            if model_key not in expenses_by_model:
+                expenses_by_model[model_key] = {
+                    "service": service,
+                    "model": model,
+                    "input_tokens": 0,
+                    "input_cost_usd": 0,
+                    "input_cost_usd_with_cache": 0,
+                    "output_tokens": 0,
+                    "output_cost_usd": 0,
+                    "output_cost_usd_with_cache": 0,
+                }
+
+            if token_type == "input":
+                expenses_by_model[model_key]["input_tokens"] += expense_usage["tokens"]
+                expenses_by_model[model_key]["input_cost_usd"] += expense_usage[
+                    "cost_usd"
+                ]
+                expenses_by_model[model_key][
+                    "input_cost_usd_with_cache"
+                ] += expense_usage["cost_usd_with_cache"]
+            elif token_type == "output":
+                expenses_by_model[model_key]["output_tokens"] += expense_usage["tokens"]
+                expenses_by_model[model_key]["output_cost_usd"] += expense_usage[
+                    "cost_usd"
+                ]
+                expenses_by_model[model_key][
+                    "output_cost_usd_with_cache"
+                ] += expense_usage["cost_usd_with_cache"]
+
+        converter = CostConverter()
+        for model_key, model_cost_dict in expenses_by_model.items():
+
+            # Handle full cost (without cache)
+            input_cost = model_cost_dict["input_cost_usd"]
+            output_cost = model_cost_dict["output_cost_usd"]
+            model_cost_dict["input_cost_credits"] = converter.usd_to_credits(input_cost)
+            model_cost_dict["output_cost_credits"] = converter.usd_to_credits(
+                output_cost
+            )
+            # Convert back to USD (to get the rounded value)
+            model_cost_dict["input_cost_usd"] = converter.credits_to_usd(
+                model_cost_dict["input_cost_credits"]
+            )
+            model_cost_dict["output_cost_usd"] = converter.credits_to_usd(
+                model_cost_dict["output_cost_credits"]
+            )
+
+            # Handle cost with cache
+            input_cost_with_cache = model_cost_dict["input_cost_usd_with_cache"]
+            output_cost_with_cache = model_cost_dict["output_cost_usd_with_cache"]
+            model_cost_dict["input_cost_credits_with_cache"] = converter.usd_to_credits(
+                input_cost_with_cache
+            )
+            model_cost_dict["output_cost_credits_with_cache"] = (
+                converter.usd_to_credits(output_cost_with_cache)
+            )
+        return list(expenses_by_model.values())
+
     def _fetch_results_and_log(
         self,
         job_info: RemoteJobInfo,
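
The new _get_expenses_from_results aggregates in two passes: first per (service, model, token_type, price) key, then collapsed per (service, model), tracking both a full cost and a with-cache cost (cache hits are excluded from the latter). A condensed, self-contained sketch of the core arithmetic, using invented data rather than real raw_model_response dicts:

# Condensed sketch of the two-pass cost aggregation; the data is invented.
from collections import defaultdict

responses = [
    # (service, model, token_type, price_per_million_tokens, tokens, cache_used)
    ("openai", "gpt-4o", "input", 2.5, 1_000, False),
    ("openai", "gpt-4o", "output", 10.0, 200, False),
    ("openai", "gpt-4o", "input", 2.5, 1_000, True),  # cache hit: no actual spend
]

expenses = defaultdict(
    lambda: {"tokens": 0, "cost_usd": 0.0, "cost_usd_with_cache": 0.0}
)
for service, model, token_type, price, tokens, cache_used in responses:
    key = (service, model, token_type, price)
    cost = (price / 1_000_000) * tokens  # same formula as in the diff
    expenses[key]["tokens"] += tokens
    expenses[key]["cost_usd"] += cost    # full cost, as if nothing were cached
    if not cache_used:
        expenses[key]["cost_usd_with_cache"] += cost  # actual spend

for (service, model, token_type, price), usage in expenses.items():
    print(service, model, token_type, usage)
# input:  tokens=2000, cost_usd=0.005, cost_usd_with_cache=0.0025
# output: tokens=200,  cost_usd=0.002, cost_usd_with_cache=0.002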
@@ -274,12 +433,36 @@ class JobsRemoteInferenceHandler:
         "Fetches the results object and logs the results URL."
         job_info.logger.add_info("results_uuid", results_uuid)
         results = object_fetcher(results_uuid, expected_object_type="results")
+
+        model_cost_dicts = self._get_expenses_from_results(results)
+
+        model_costs = [
+            ModelCost(
+                service=model_cost_dict.get("service"),
+                model=model_cost_dict.get("model"),
+                input_tokens=model_cost_dict.get("input_tokens"),
+                input_cost_usd=model_cost_dict.get("input_cost_usd"),
+                output_tokens=model_cost_dict.get("output_tokens"),
+                output_cost_usd=model_cost_dict.get("output_cost_usd"),
+                input_cost_credits_with_cache=model_cost_dict.get(
+                    "input_cost_credits_with_cache"
+                ),
+                output_cost_credits_with_cache=model_cost_dict.get(
+                    "output_cost_credits_with_cache"
+                ),
+            )
+            for model_cost_dict in model_cost_dicts
+        ]
+        job_info.logger.add_info("model_costs", model_costs)
+
         results_url = remote_job_data.get("results_url")
         if "localhost" in results_url:
             results_url = results_url.replace("8000", "1234")
         job_info.logger.add_info("results_url", results_url)
 
         if job_status == "completed":
+            job_info.logger.add_info("completed_interviews", len(results))
+            job_info.logger.add_info("failed_interviews", 0)
             job_info.logger.update(
                 f"Job completed and Results stored on Coop. [View Results]({results_url})",
                 status=JobsStatus.COMPLETED,
@@ -302,6 +485,7 @@ class JobsRemoteInferenceHandler:
     ) -> Union[None, "Results", Literal["continue"]]:
         """Makes one attempt to fetch and process a remote job's status and results."""
         remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
+        self._update_interview_details(job_info, remote_job_data)
         status = remote_job_data.get("status")
         reason = remote_job_data.get("reason")
         if status == "cancelled":
edsl/language_models/language_model.py
CHANGED
@@ -769,8 +769,45 @@ class LanguageModel(
         params["question_name"] = invigilator.question.question_name
         # Get timeout from configuration
         from ..config import CONFIG
-
-
+        import logging
+
+        logger = logging.getLogger(__name__)
+        base_timeout = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+        # Adjust timeout if files are present
+        import time
+
+        start = time.time()
+        if files_list:
+            # Calculate total size of attached files in MB
+            file_sizes = []
+            for file in files_list:
+                # Try different attributes that might contain the file content
+                if hasattr(file, "base64_string") and file.base64_string:
+                    file_sizes.append(len(file.base64_string) / (1024 * 1024))
+                elif hasattr(file, "content") and file.content:
+                    file_sizes.append(len(file.content) / (1024 * 1024))
+                elif hasattr(file, "data") and file.data:
+                    file_sizes.append(len(file.data) / (1024 * 1024))
+                else:
+                    # Default minimum size if we can't determine actual size
+                    file_sizes.append(1)  # Assume at least 1MB
+            total_size_mb = sum(file_sizes)
+
+            # Increase timeout proportionally to file size
+            # For each MB of file size, add 10 seconds to the timeout (adjust as needed)
+            size_adjustment = total_size_mb * 10
+
+            # Cap the maximum timeout adjustment at 5 minutes (300 seconds)
+            size_adjustment = min(size_adjustment, 300)
+
+            TIMEOUT = base_timeout + size_adjustment
+
+            logger.info(
+                f"Adjusted timeout for API call with {len(files_list)} files (total size: {total_size_mb:.2f}MB). Base timeout: {base_timeout}s, New timeout: {TIMEOUT}s"
+            )
+        else:
+            TIMEOUT = base_timeout
 
         # Execute the model call with timeout
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
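
The effective request timeout is now base + min(10 s per MB of attached files, 300 s). A quick standalone check of that formula; the 60-second base is an assumed EDSL_API_TIMEOUT value for illustration, not a documented default:

# Worked check of the file-size timeout adjustment; the base value is assumed.
def adjusted_timeout(base_timeout: float, file_sizes_mb: list[float]) -> float:
    if not file_sizes_mb:
        return base_timeout
    size_adjustment = sum(file_sizes_mb) * 10        # 10 seconds per MB
    return base_timeout + min(size_adjustment, 300)  # capped at 5 minutes

assert adjusted_timeout(60, []) == 60         # no files: base timeout unchanged
assert adjusted_timeout(60, [2.5]) == 85      # 2.5 MB -> +25 s
assert adjusted_timeout(60, [40, 40]) == 360  # 800 s uncapped -> capped at +300 s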
@@ -1019,7 +1056,7 @@ class LanguageModel(
 
         # Combine model name and parameters
         return (
-            f"Model(model_name = '{self.model}'"
+            f"Model(model_name = '{self.model}', service_name = '{self._inference_service_}'"
             + (f", {param_string}" if param_string else "")
             + ")"
         )
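
With this repr change, the inference service becomes visible whenever a model is printed. For illustration (all values invented): a model that previously printed as Model(model_name = 'gpt-4o', temperature = 0.5) would now print as Model(model_name = 'gpt-4o', service_name = 'openai', temperature = 0.5).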
edsl/language_models/price_manager.py
CHANGED
@@ -19,56 +19,12 @@ class ResponseCost:
     total_cost: Union[float, str, None] = None
 
 
-class PriceManager:
-    _instance = None
-    _price_lookup: Dict[Tuple[str, str], Dict] = {}
-    _is_initialized = False
-
-    def __new__(cls):
-        if cls._instance is None:
-            instance = super(PriceManager, cls).__new__(cls)
-            instance._price_lookup = {}  # Instance-specific attribute
-            instance._is_initialized = False
-            cls._instance = instance  # Store the instance directly
-            return instance
-        return cls._instance
-
-    def __init__(self):
-        """Initialize the singleton instance only once."""
-        if not self._is_initialized:
-            self._is_initialized = True
-            self.refresh_prices()
-
-    @classmethod
-    def get_instance(cls):
-        """Get the singleton instance, creating it if necessary."""
-        if cls._instance is None:
-            cls()  # Create the instance if it doesn't exist
-        return cls._instance
-
-    @classmethod
-    def reset(cls):
-        """Reset the singleton instance to clean up resources."""
-        cls._instance = None
-        cls._is_initialized = False
-        cls._price_lookup = {}
-
-    def __del__(self):
-        """Ensure proper cleanup when the instance is garbage collected."""
-        try:
-            self._price_lookup = {}  # Clean up resources
-        except:
-            pass  # Ignore any cleanup errors
+class PriceRetriever:
+    DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS = 1.0
+    DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS = 1.0
 
-    def refresh_prices(self) -> None:
-        """Fetch fresh prices and update the internal price lookup."""
-        from edsl.coop import Coop
-
-        c = Coop()
-        try:
-            self._price_lookup = c.fetch_prices()
-        except Exception as e:
-            print(f"Error fetching prices: {str(e)}")
+    def __init__(self, price_lookup: Dict[Tuple[str, str], Dict]):
+        self._price_lookup = price_lookup
 
     def get_price(self, inference_service: str, model: str) -> Dict:
         """Get the price information for a specific service and model."""
@@ -77,10 +33,6 @@ class PriceManager:
             inference_service
         )
 
-    def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
-        """Get the complete price lookup dictionary."""
-        return self._price_lookup.copy()
-
     def _get_fallback_price(self, inference_service: str) -> Dict:
         """
         Get fallback prices for a service.
@@ -101,15 +53,21 @@ class PriceManager:
             if service == inference_service
         ]
 
-
+        default_input_price_info = {
+            "one_usd_buys": 1_000_000,
+            "service_stated_token_qty": 1_000_000,
+            "service_stated_token_price": self.DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS,
+        }
+
+        default_output_price_info = {
             "one_usd_buys": 1_000_000,
             "service_stated_token_qty": 1_000_000,
-            "service_stated_token_price":
+            "service_stated_token_price": self.DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS,
         }
 
         # Find the most expensive price entries (lowest tokens per USD)
-        input_price_info =
-        output_price_info =
+        input_price_info = default_input_price_info
+        output_price_info = default_output_price_info
 
         input_prices = [
             PriceEntry(float(p["input"]["one_usd_buys"]), p["input"])
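
The refactor splits price lookup (PriceRetriever, a plain object over a price dict) from lifecycle and fetching (PriceManager, re-added below, which stays a singleton), so the lookup logic can be exercised without network access. A usage sketch with an invented price table; the entry schema beyond one_usd_buys / service_stated_token_* is assumed from the fragments visible in this diff:

# Sketch: constructing a PriceRetriever directly, without the singleton.
# The price table below is invented for illustration.
from edsl.language_models.price_manager import PriceRetriever

price_lookup = {
    ("openai", "gpt-4o"): {
        "input": {
            "one_usd_buys": 400_000,
            "service_stated_token_qty": 1_000_000,
            "service_stated_token_price": 2.5,
        },
        "output": {
            "one_usd_buys": 100_000,
            "service_stated_token_qty": 1_000_000,
            "service_stated_token_price": 10.0,
        },
    },
}

retriever = PriceRetriever(price_lookup)
print(retriever.get_price("openai", "gpt-4o"))
# For services with no entries, _get_fallback_price defaults both token types
# to $1.00 per million tokens (the DEFAULT_*_PRICE_PER_MILLION_TOKENS constants).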
@@ -156,6 +114,82 @@ class PriceManager:
         price_per_million_tokens = round(price_per_token * 1_000_000, 10)
         return price_per_million_tokens
 
+
+class PriceManager:
+    _instance = None
+    _price_lookup: Dict[Tuple[str, str], Dict] = {}
+    _is_initialized = False
+
+    def __new__(cls):
+        if cls._instance is None:
+            instance = super(PriceManager, cls).__new__(cls)
+            instance._price_lookup = {}  # Instance-specific attribute
+            instance._is_initialized = False
+            cls._instance = instance  # Store the instance directly
+            return instance
+        return cls._instance
+
+    def __init__(self):
+        """Initialize the singleton instance only once."""
+        if not self._is_initialized:
+            self._is_initialized = True
+            self.refresh_prices()
+
+    @classmethod
+    def get_instance(cls):
+        """Get the singleton instance, creating it if necessary."""
+        if cls._instance is None:
+            cls()  # Create the instance if it doesn't exist
+        return cls._instance
+
+    @classmethod
+    def reset(cls):
+        """Reset the singleton instance to clean up resources."""
+        cls._instance = None
+        cls._is_initialized = False
+        cls._price_lookup = {}
+
+    def __del__(self):
+        """Ensure proper cleanup when the instance is garbage collected."""
+        try:
+            self._price_lookup = {}  # Clean up resources
+        except:
+            pass  # Ignore any cleanup errors
+
+    @property
+    def price_retriever(self):
+        return PriceRetriever(self._price_lookup)
+
+    def refresh_prices(self) -> None:
+        """Fetch fresh prices and update the internal price lookup."""
+        from edsl.coop import Coop
+
+        c = Coop()
+        try:
+            self._price_lookup = c.fetch_prices()
+        except Exception as e:
+            print(f"Error fetching prices: {str(e)}")
+
+    def get_price(self, inference_service: str, model: str) -> Dict:
+        """Get the price information for a specific service and model."""
+        return self.price_retriever.get_price(inference_service, model)
+
+    def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
+        """Get the complete price lookup dictionary."""
+        return self._price_lookup.copy()
+
+    def get_price_per_million_tokens(
+        self,
+        relevant_prices: Dict,
+        token_type: Literal["input", "output"],
+    ) -> Dict:
+        """
+        Get the price per million tokens for a specific service, model, and token type.
+        """
+        return self.price_retriever.get_price_per_million_tokens(
+            relevant_prices, token_type
+        )
+
     def _calculate_total_cost(
         self,
         relevant_prices: Dict,
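
PriceManager keeps its previous singleton surface (get_instance, reset, get_all_prices); get_price and get_price_per_million_tokens now simply delegate to a PriceRetriever built over the cached lookup. A usage sketch; note that the first construction triggers refresh_prices(), which calls Coop over the network:

# Usage sketch for the refactored PriceManager.
from edsl.language_models.price_manager import PriceManager

manager = PriceManager.get_instance()  # creates the singleton; fetches prices
assert PriceManager() is manager       # __new__ always yields the same instance

prices = manager.get_price("openai", "gpt-4o")  # delegates to PriceRetriever
all_prices = manager.get_all_prices()           # defensive copy of the lookup

PriceManager.reset()  # drop the singleton and its cached prices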
|