edsl 0.1.54__py3-none-any.whl → 0.1.56__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- edsl/__init__.py +8 -1
- edsl/__init__original.py +134 -0
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +29 -0
- edsl/agents/agent_list.py +36 -1
- edsl/base/base_class.py +281 -151
- edsl/base/data_transfer_models.py +15 -4
- edsl/buckets/__init__.py +8 -3
- edsl/buckets/bucket_collection.py +9 -3
- edsl/buckets/model_buckets.py +4 -2
- edsl/buckets/token_bucket.py +2 -2
- edsl/buckets/token_bucket_client.py +5 -3
- edsl/caching/cache.py +131 -62
- edsl/caching/cache_entry.py +70 -58
- edsl/caching/sql_dict.py +17 -0
- edsl/cli.py +99 -0
- edsl/config/config_class.py +16 -0
- edsl/conversation/__init__.py +31 -0
- edsl/coop/coop.py +276 -242
- edsl/coop/coop_jobs_objects.py +59 -0
- edsl/coop/coop_objects.py +29 -0
- edsl/coop/coop_regular_objects.py +26 -0
- edsl/coop/utils.py +24 -19
- edsl/dataset/dataset.py +338 -101
- edsl/dataset/dataset_operations_mixin.py +216 -180
- edsl/db_list/sqlite_list.py +349 -0
- edsl/inference_services/__init__.py +40 -5
- edsl/inference_services/exceptions.py +11 -0
- edsl/inference_services/services/anthropic_service.py +5 -2
- edsl/inference_services/services/aws_bedrock.py +6 -2
- edsl/inference_services/services/azure_ai.py +6 -2
- edsl/inference_services/services/google_service.py +7 -3
- edsl/inference_services/services/mistral_ai_service.py +6 -2
- edsl/inference_services/services/open_ai_service.py +6 -2
- edsl/inference_services/services/perplexity_service.py +6 -2
- edsl/inference_services/services/test_service.py +94 -5
- edsl/interviews/answering_function.py +167 -59
- edsl/interviews/interview.py +124 -72
- edsl/interviews/interview_task_manager.py +10 -0
- edsl/interviews/request_token_estimator.py +8 -0
- edsl/invigilators/invigilators.py +35 -13
- edsl/jobs/async_interview_runner.py +146 -104
- edsl/jobs/data_structures.py +6 -4
- edsl/jobs/decorators.py +61 -0
- edsl/jobs/fetch_invigilator.py +61 -18
- edsl/jobs/html_table_job_logger.py +14 -2
- edsl/jobs/jobs.py +180 -104
- edsl/jobs/jobs_component_constructor.py +2 -2
- edsl/jobs/jobs_interview_constructor.py +2 -0
- edsl/jobs/jobs_pricing_estimation.py +154 -113
- edsl/jobs/jobs_remote_inference_logger.py +4 -0
- edsl/jobs/jobs_runner_status.py +30 -25
- edsl/jobs/progress_bar_manager.py +79 -0
- edsl/jobs/remote_inference.py +35 -1
- edsl/key_management/key_lookup_builder.py +6 -1
- edsl/language_models/language_model.py +110 -12
- edsl/language_models/model.py +10 -3
- edsl/language_models/price_manager.py +176 -71
- edsl/language_models/registry.py +5 -0
- edsl/notebooks/notebook.py +77 -10
- edsl/questions/VALIDATION_README.md +134 -0
- edsl/questions/__init__.py +24 -1
- edsl/questions/exceptions.py +21 -0
- edsl/questions/question_dict.py +201 -16
- edsl/questions/question_multiple_choice_with_other.py +624 -0
- edsl/questions/question_registry.py +2 -1
- edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
- edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
- edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
- edsl/questions/validation_analysis.py +185 -0
- edsl/questions/validation_cli.py +131 -0
- edsl/questions/validation_html_report.py +404 -0
- edsl/questions/validation_logger.py +136 -0
- edsl/results/result.py +115 -46
- edsl/results/results.py +702 -171
- edsl/scenarios/construct_download_link.py +16 -3
- edsl/scenarios/directory_scanner.py +226 -226
- edsl/scenarios/file_methods.py +5 -0
- edsl/scenarios/file_store.py +150 -9
- edsl/scenarios/handlers/__init__.py +5 -1
- edsl/scenarios/handlers/mp4_file_store.py +104 -0
- edsl/scenarios/handlers/webm_file_store.py +104 -0
- edsl/scenarios/scenario.py +120 -101
- edsl/scenarios/scenario_list.py +800 -727
- edsl/scenarios/scenario_list_gc_test.py +146 -0
- edsl/scenarios/scenario_list_memory_test.py +214 -0
- edsl/scenarios/scenario_list_source_refactor.md +35 -0
- edsl/scenarios/scenario_selector.py +5 -4
- edsl/scenarios/scenario_source.py +1990 -0
- edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
- edsl/surveys/survey.py +22 -0
- edsl/tasks/__init__.py +4 -2
- edsl/tasks/task_history.py +198 -36
- edsl/tests/scenarios/test_ScenarioSource.py +51 -0
- edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
- edsl/utilities/__init__.py +2 -1
- edsl/utilities/decorators.py +121 -0
- edsl/utilities/memory_debugger.py +1010 -0
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/METADATA +51 -76
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/RECORD +103 -79
- edsl/jobs/jobs_runner_asyncio.py +0 -281
- edsl/language_models/unused/fake_openai_service.py +0 -60
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/LICENSE +0 -0
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/WHEEL +0 -0
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/entry_points.txt +0 -0
edsl/language_models/language_model.py
CHANGED

```diff
@@ -49,6 +49,7 @@ from ..data_transfer_models import (
 )
 
 if TYPE_CHECKING:
+    from .price_manager import ResponseCost
     from ..caching import Cache
     from ..scenarios import FileStore
     from ..questions import QuestionBase
@@ -365,6 +366,59 @@ class LanguageModel(
         self._api_token = info.api_token
         return self._api_token
 
+    def copy(self) -> "LanguageModel":
+        """Create a deep copy of this language model instance.
+
+        This method creates a completely independent copy of the language model
+        by creating a new instance with the same parameters and copying relevant attributes.
+
+        Returns:
+            LanguageModel: A new language model instance that is functionally identical to this one
+
+        Examples:
+            >>> m1 = LanguageModel.example()
+            >>> m2 = m1.copy()
+            >>> m1 == m2  # Functionally equivalent
+            True
+            >>> id(m1) == id(m2)  # But different objects
+            False
+        """
+        # Create a new instance of the same class with the same parameters
+        try:
+            # For most models, we can instantiate with the saved parameters
+            new_model = self.__class__(**self.parameters)
+
+            # Copy all important instance attributes
+            for key, value in self.__dict__.items():
+                if key not in ("_api_token",) and not key.startswith("__"):
+                    setattr(new_model, key, value)
+
+            return new_model
+        except Exception:
+            # Fallback for dynamically created classes like TestServiceLanguageModel
+            from ..inference_services import default
+
+            # If this is a test model, create a new test model instance
+            if getattr(self, "_inference_service_", "") == "test":
+                service = default.get_service("test")
+                model_class = service.create_model("test")
+                new_model = model_class(**self.parameters)
+
+                # Copy attributes
+                for key, value in self.__dict__.items():
+                    if key not in ("_api_token",) and not key.startswith("__"):
+                        setattr(new_model, key, value)
+
+                return new_model
+
+            # If we can't create the model directly, just return a simple test model
+            # This is a last resort fallback
+            from ..inference_services import get_service
+
+            service = get_service("test")
+            model_class = service.create_model("test")
+            return model_class()
+
     def __getitem__(self, key):
         """Allow dictionary-style access to model attributes.
 
@@ -679,9 +733,12 @@ class LanguageModel(
         user_prompt_with_hashes = user_prompt
 
         # Prepare parameters for cache lookup
+        cache_parameters = self.parameters.copy()
+        if self.model == "test":
+            cache_parameters.pop("canned_response", None)
         cache_call_params = {
             "model": str(self.model),
-            "parameters": self.parameters,
+            "parameters": cache_parameters,
             "system_prompt": system_prompt,
             "user_prompt": user_prompt_with_hashes,
             "iteration": iteration,
```
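In isolation, the new cache-key handling means that for test models the `canned_response` parameter no longer participates in the cache lookup, so two otherwise-identical test runs with different canned answers resolve to the same cache entry. A minimal standalone sketch (parameter values are made up):

```python
parameters = {"temperature": 0.5, "canned_response": "Hello, world"}
model = "test"

# Mirror of the hunk above: copy the parameters, then strip the canned
# response for test models so it never becomes part of the cache key.
cache_parameters = parameters.copy()
if model == "test":
    cache_parameters.pop("canned_response", None)

cache_call_params = {
    "model": model,
    "parameters": cache_parameters,
    "system_prompt": "You are a helpful agent.",
    "user_prompt": "What is 2 + 2?",
    "iteration": 0,
}
assert "canned_response" not in cache_call_params["parameters"]
```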
```diff
@@ -726,13 +783,18 @@ class LanguageModel(
         # Calculate cost for the response
         cost = self.cost(response)
         # Return a structured response with metadata
-        return ModelResponse(
+        response = ModelResponse(
             response=response,
             cache_used=cache_used,
             cache_key=cache_key,
             cached_response=cached_response,
-            cost=cost,
+            input_tokens=cost.input_tokens,
+            output_tokens=cost.output_tokens,
+            input_price_per_million_tokens=cost.input_price_per_million_tokens,
+            output_price_per_million_tokens=cost.output_price_per_million_tokens,
+            total_cost=cost.total_cost,
         )
+        return response
 
     _get_intended_model_call_outcome = sync_wrapper(
         _async_get_intended_model_call_outcome
@@ -825,7 +887,7 @@ class LanguageModel(
 
     get_response = sync_wrapper(async_get_response)
 
-    def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
+    def cost(self, raw_response: dict[str, Any]) -> ResponseCost:
         """Calculate the monetary cost of a model API call.
 
         This method extracts token usage information from the response and
@@ -836,7 +898,7 @@ class LanguageModel(
             raw_response: The complete response dictionary from the model API
 
         Returns:
-
+            ResponseCost: Object containing token counts and total cost
         """
         # Extract token usage data from the response
         usage = self.get_usage_dict(raw_response)
@@ -844,7 +906,7 @@ class LanguageModel(
         # Use the price manager to calculate the actual cost
         from .price_manager import PriceManager
 
-        price_manager = PriceManager()
+        price_manager = PriceManager.get_instance()
 
         return price_manager.calculate_cost(
             inference_service=self._inference_service_,
@@ -873,9 +935,15 @@ class LanguageModel(
         {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'inference_service': 'openai', 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
         """
         # Build the base dictionary with essential model information
+        parameters = self.parameters.copy()
+
+        # For test models, ensure canned_response is included in serialization
+        if self.model == "test" and hasattr(self, "canned_response"):
+            parameters["canned_response"] = self.canned_response
+
         d = {
             "model": self.model,
-            "parameters": self.parameters,
+            "parameters": parameters,
             "inference_service": self._inference_service_,
         }
 
@@ -913,7 +981,25 @@ class LanguageModel(
             data["model"], service_name=data.get("inference_service", None)
         )
 
-        #
+        # Handle canned_response in parameters for test models
+        if (
+            data["model"] == "test"
+            and "parameters" in data
+            and "canned_response" in data["parameters"]
+        ):
+            # Extract canned_response from parameters to set as a direct attribute
+            canned_response = data["parameters"]["canned_response"]
+            params_copy = data.copy()
+
+            # Direct attribute will be set during initialization
+            # Add it as a top-level parameter for model initialization
+            if isinstance(params_copy, dict) and "parameters" in params_copy:
+                params_copy["canned_response"] = canned_response
+
+            # Create the instance with canned_response as a direct parameter
+            return model_class(**params_copy)
+
+        # For non-test models or test models without canned_response
         return model_class(**data)
 
     def __repr__(self) -> str:
```
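Taken together, the `to_dict()` and `from_dict()` changes let `canned_response` survive a serialization round trip for test models. A hedged sketch of the intended behavior (the `example()` call is taken from this package's own doctests; the final assertion assumes the deserialized model re-serializes the same way):

```python
from edsl.language_models import LanguageModel

m = LanguageModel.example(test_model=True, canned_response="WOWZA!")
d = m.to_dict()
# to_dict() now injects the canned response into the serialized parameters.
assert d["parameters"]["canned_response"] == "WOWZA!"

# from_dict() promotes it back to a top-level constructor argument,
# so the round trip should preserve it.
m2 = LanguageModel.from_dict(d)
assert m2.to_dict()["parameters"]["canned_response"] == "WOWZA!"
```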
```diff
@@ -999,8 +1085,8 @@ class LanguageModel(
 
         Create a test model that throws exceptions:
 
-        >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)
-        >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)
+        >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)  # doctest: +SKIP
+        >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)  # doctest: +SKIP
         Exception report saved to ...
         """
         from ..language_models import Model
@@ -1067,13 +1153,25 @@ class LanguageModel(
         }
         cached_response, cache_key = cache.fetch(**cache_call_params)
         response = json.loads(cached_response)
-
+
+        try:
+            usage = self.get_usage_dict(response)
+            input_tokens = int(usage[self.input_token_name])
+            output_tokens = int(usage[self.output_token_name])
+        except Exception as e:
+            print(f"Could not fetch tokens from model response: {e}")
+            input_tokens = None
+            output_tokens = None
         return ModelResponse(
             response=response,
             cache_used=True,
             cache_key=cache_key,
             cached_response=cached_response,
-            cost=0,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            input_price_per_million_tokens=0,
+            output_price_per_million_tokens=0,
+            total_cost=0,
         )
 
         # Bind the new method to the copied instance
```
edsl/language_models/model.py
CHANGED
```diff
@@ -6,8 +6,10 @@ from ..utilities import PrettyList
 from ..config import CONFIG
 from .exceptions import LanguageModelValueError
 
+# Import only what's needed initially to avoid circular imports
 from ..inference_services import (InferenceServicesCollection,
-                                  AvailableModels, InferenceServiceABC, InferenceServiceError, default)
+                                  AvailableModels, InferenceServiceABC, InferenceServiceError)
+# The 'default' import will be imported lazily when needed
 
 from ..enums import InferenceServiceLiteral
 
@@ -20,7 +22,10 @@ def get_model_class(
     registry: Optional[InferenceServicesCollection] = None,
     service_name: Optional[InferenceServiceLiteral] = None,
 ):
-    registry = registry or default
+    if registry is None:
+        # Import default lazily only when needed
+        from ..inference_services import default as inference_default
+        registry = inference_default
     try:
         factory = registry.create_model_factory(model_name, service_name=service_name)
         return factory
@@ -54,7 +59,9 @@ class Model(metaclass=Meta):
     def get_registry(cls) -> InferenceServicesCollection:
         """Get the current registry or initialize with default if None"""
         if cls._registry is None:
-            cls._registry = default
+            # Import default lazily only when needed
+            from ..inference_services import default as inference_default
+            cls._registry = inference_default
         return cls._registry
 
     @classmethod
```
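All three hunks apply the same fix: the `default` inference-services registry is no longer imported at module load time, which breaks the circular import between this module and `..inference_services`. The pattern in isolation (using `sys` as a stand-in for the real module so the sketch actually runs):

```python
def get_registry(registry=None):
    if registry is None:
        # The potentially circular import happens here, on first call,
        # rather than at module import time.
        import sys as inference_default  # stand-in for ..inference_services.default
        registry = inference_default
    return registry

assert get_registry().__name__ == "sys"
```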
edsl/language_models/price_manager.py
CHANGED

```diff
@@ -1,4 +1,22 @@
-from typing import Dict, Tuple, Union
+from dataclasses import dataclass
+from typing import Dict, Literal, Tuple, Union
+from collections import namedtuple
+
+
+@dataclass
+class ResponseCost:
+    """
+    Class for storing the cost and token usage of a language model response.
+
+    If an error occurs when computing the cost, the total_cost will contain a string with the error message.
+    All other fields will be None.
+    """
+
+    input_tokens: Union[int, None] = None
+    output_tokens: Union[int, None] = None
+    input_price_per_million_tokens: Union[float, None] = None
+    output_price_per_million_tokens: Union[float, None] = None
+    total_cost: Union[float, str, None] = None
 
 
 class PriceManager:
```
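The error convention is worth noting: when cost computation fails, `total_cost` carries the error message as a string and every other field stays `None`, so callers branch on the type instead of catching exceptions. A small sketch with illustrative numbers (module path taken from the file list above):

```python
from edsl.language_models.price_manager import ResponseCost

cost = ResponseCost(
    input_tokens=120,
    output_tokens=30,
    input_price_per_million_tokens=0.15,
    output_price_per_million_tokens=0.60,
    total_cost=(120 * 0.15 + 30 * 0.60) / 1_000_000,
)

if isinstance(cost.total_cost, str):
    print(f"Cost unavailable: {cost.total_cost}")
else:
    print(f"${cost.total_cost:.8f} for {cost.input_tokens}+{cost.output_tokens} tokens")
```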
```diff
@@ -8,20 +26,42 @@ class PriceManager:
 
     def __new__(cls):
         if cls._instance is None:
-            cls._instance = super(PriceManager, cls).__new__(cls)
+            instance = super(PriceManager, cls).__new__(cls)
+            instance._price_lookup = {}  # Instance-specific attribute
+            instance._is_initialized = False
+            cls._instance = instance  # Store the instance directly
+            return instance
         return cls._instance
 
     def __init__(self):
-
+        """Initialize the singleton instance only once."""
         if not self._is_initialized:
             self._is_initialized = True
             self.refresh_prices()
 
-    def refresh_prices(self) -> None:
-        """
-        Fetch fresh prices and update the internal price lookup.
+    @classmethod
+    def get_instance(cls):
+        """Get the singleton instance, creating it if necessary."""
+        if cls._instance is None:
+            cls()  # Create the instance if it doesn't exist
+        return cls._instance
 
-        """
+    @classmethod
+    def reset(cls):
+        """Reset the singleton instance to clean up resources."""
+        cls._instance = None
+        cls._is_initialized = False
+        cls._price_lookup = {}
+
+    def __del__(self):
+        """Ensure proper cleanup when the instance is garbage collected."""
+        try:
+            self._price_lookup = {}  # Clean up resources
+        except:
+            pass  # Ignore any cleanup errors
+
+    def refresh_prices(self) -> None:
+        """Fetch fresh prices and update the internal price lookup."""
         from edsl.coop import Coop
 
         c = Coop()
```
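A sketch of the singleton contract these methods establish. Note that constructing the instance triggers `refresh_prices()`, which fetches prices over the network via `Coop`, so this is not free:

```python
from edsl.language_models.price_manager import PriceManager

# Both spellings yield the same object; get_instance() is the explicit
# accessor that LanguageModel.cost() now uses.
a = PriceManager()
b = PriceManager.get_instance()
assert a is b

# reset() drops the cached instance and its price lookup (useful between
# test runs); the next access rebuilds it and re-fetches prices.
PriceManager.reset()
assert PriceManager.get_instance() is not a
```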
```diff
@@ -31,29 +71,14 @@ class PriceManager:
             print(f"Error fetching prices: {str(e)}")
 
     def get_price(self, inference_service: str, model: str) -> Dict:
-        """
-        Get the price information for a specific service and model combination.
-        If no specific price is found, returns a fallback price.
-
-        Args:
-            inference_service (str): The name of the inference service
-            model (str): The model identifier
-
-        Returns:
-            Dict: Price information (either actual or fallback prices)
-        """
+        """Get the price information for a specific service and model."""
         key = (inference_service, model)
         return self._price_lookup.get(key) or self._get_fallback_price(
             inference_service
         )
 
     def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
-        """
-        Get the complete price lookup dictionary.
-
-        Returns:
-            Dict[Tuple[str, str], Dict]: The complete price lookup dictionary
-        """
+        """Get the complete price lookup dictionary."""
         return self._price_lookup.copy()
 
     def _get_fallback_price(self, inference_service: str) -> Dict:
```
```diff
@@ -68,73 +93,95 @@ class PriceManager:
         Returns:
             Dict: Price information
         """
+        PriceEntry = namedtuple("PriceEntry", ["tokens_per_usd", "price_info"])
+
         service_prices = [
             prices
             for (service, _), prices in self._price_lookup.items()
             if service == inference_service
         ]
 
-        input_prices = [
-            float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
+        default_price_info = {
+            "one_usd_buys": 1_000_000,
+            "service_stated_token_qty": 1_000_000,
+            "service_stated_token_price": 1.0,
+        }
+
+        # Find the most expensive price entries (lowest tokens per USD)
+        input_price_info = default_price_info
+        output_price_info = default_price_info
+
+        input_prices = [
+            PriceEntry(float(p["input"]["one_usd_buys"]), p["input"])
+            for p in service_prices
+            if "input" in p
         ]
-        if
-
-
-
+        if input_prices:
+            input_price_info = min(
+                input_prices, key=lambda price: price.tokens_per_usd
+            ).price_info
 
-
-            float(p["output"]["one_usd_buys"])
+        output_prices = [
+            PriceEntry(float(p["output"]["one_usd_buys"]), p["output"])
+            for p in service_prices
+            if "output" in p
         ]
-        if
-
-
-
+        if output_prices:
+            output_price_info = min(
+                output_prices, key=lambda price: price.tokens_per_usd
+            ).price_info
 
         return {
-            "input":
-            "output":
+            "input": input_price_info,
+            "output": output_price_info,
         }
 
-    def calculate_cost(
+    def get_price_per_million_tokens(
         self,
-        inference_service: str,
-        model: str,
-        usage: Dict[str, Union[str, int]],
-        input_token_name: str,
-        output_token_name: str,
-    ) -> Union[float, str]:
+        relevant_prices: Dict,
+        token_type: Literal["input", "output"],
+    ) -> Dict:
         """
-
+        Get the price per million tokens for a specific service, model, and token type.
+        """
+        service_price = relevant_prices[token_type]["service_stated_token_price"]
+        service_qty = relevant_prices[token_type]["service_stated_token_qty"]
 
-
-
-
-
-
-
+        if service_qty == 1_000_000:
+            price_per_million_tokens = service_price
+        elif service_qty == 1_000:
+            price_per_million_tokens = service_price * 1_000
+        else:
+            price_per_token = service_price / service_qty
+            price_per_million_tokens = round(price_per_token * 1_000_000, 10)
+        return price_per_million_tokens
 
-
-
+    def _calculate_total_cost(
+        self,
+        relevant_prices: Dict,
+        input_tokens: int,
+        output_tokens: int,
+    ) -> float:
         """
-
-
-        # Extract token counts
-        try:
-            input_tokens = int(usage[input_token_name])
-            output_tokens = int(usage[output_token_name])
-        except Exception as e:
-            return f"Could not fetch tokens from model response: {e}"
+        Calculate the total cost for a model usage based on input and output tokens.
 
+        Returns:
+            float: Total cost
+        """
         # Extract price information
         try:
             inverse_output_price = relevant_prices["output"]["one_usd_buys"]
             inverse_input_price = relevant_prices["input"]["one_usd_buys"]
         except Exception as e:
             if "output" not in relevant_prices:
-
+                raise KeyError(
+                    f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
+                )
             if "input" not in relevant_prices:
-
-
+                raise KeyError(
+                    f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
+                )
+            raise Exception(f"Could not fetch prices from {relevant_prices} - {e}")
 
         # Calculate input cost
         if inverse_input_price == "infinity":
```
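The quantity normalization in `get_price_per_million_tokens` is easiest to check with numbers. A worked example with illustrative values:

```python
# A service quoting $0.002 per 1,000 tokens charges $2.00 per million tokens.
service_stated_token_price = 0.002
service_stated_token_qty = 1_000

if service_stated_token_qty == 1_000_000:
    price_per_million_tokens = service_stated_token_price
elif service_stated_token_qty == 1_000:
    price_per_million_tokens = service_stated_token_price * 1_000
else:
    # Generic path: per-token price scaled to a million, rounded to 10 places.
    price_per_million_tokens = round(
        service_stated_token_price / service_stated_token_qty * 1_000_000, 10
    )

assert price_per_million_tokens == 2.0
```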
```diff
@@ -143,7 +190,7 @@ class PriceManager:
         try:
             input_cost = input_tokens / float(inverse_input_price)
         except Exception as e:
-            return f"Could not compute input price - {e}"
+            raise Exception(f"Could not compute input price - {e}")
 
         # Calculate output cost
         if inverse_output_price == "infinity":
```
```diff
@@ -152,16 +199,74 @@ class PriceManager:
         try:
             output_cost = output_tokens / float(inverse_output_price)
         except Exception as e:
-            return f"Could not compute output price - {e}"
+            raise Exception(f"Could not compute output price - {e}")
 
         return input_cost + output_cost
 
-    @property
-    def is_initialized(self) -> bool:
+    def calculate_cost(
+        self,
+        inference_service: str,
+        model: str,
+        usage: Dict[str, Union[str, int]],
+        input_token_name: str,
+        output_token_name: str,
+    ) -> ResponseCost:
         """
-        Check if the PriceManager has been initialized.
+        Calculate the cost and token usage for a model response.
+
+        Args:
+            inference_service (str): The inference service identifier
+            model (str): The model identifier
+            usage (Dict[str, Union[str, int]]): Dictionary containing token usage information
+            input_token_name (str): Key name for input tokens in the usage dict
+            output_token_name (str): Key name for output tokens in the usage dict
 
         Returns:
-
+            ResponseCost: Object containing token counts and total cost
         """
+        try:
+            input_tokens = int(usage[input_token_name])
+            output_tokens = int(usage[output_token_name])
+        except Exception as e:
+            return ResponseCost(
+                total_cost=f"Could not fetch tokens from model response: {e}",
+            )
+
+        try:
+            relevant_prices = self.get_price(inference_service, model)
+        except Exception as e:
+            return ResponseCost(
+                total_cost=f"Could not fetch prices from {inference_service} - {model}: {e}",
+            )
+
+        try:
+            input_price_per_million_tokens = self.get_price_per_million_tokens(
+                relevant_prices, "input"
+            )
+            output_price_per_million_tokens = self.get_price_per_million_tokens(
+                relevant_prices, "output"
+            )
+        except Exception as e:
+            return ResponseCost(
+                total_cost=f"Could not compute price per million tokens: {e}",
+            )
+
+        try:
+            total_cost = self._calculate_total_cost(
+                relevant_prices, input_tokens, output_tokens
+            )
+        except Exception as e:
+            return ResponseCost(total_cost=f"{e}")
+
+        return ResponseCost(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            input_price_per_million_tokens=input_price_per_million_tokens,
+            output_price_per_million_tokens=output_price_per_million_tokens,
+            total_cost=total_cost,
+        )
+
+    @property
+    def is_initialized(self) -> bool:
+        """Check if the PriceManager has been initialized."""
         return self._is_initialized
```
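A sketch of the new end-to-end cost path. The service, model, and usage-key names below are illustrative; the real keys depend on what the provider returns:

```python
from edsl.language_models.price_manager import PriceManager

pm = PriceManager.get_instance()
rc = pm.calculate_cost(
    inference_service="openai",   # illustrative
    model="gpt-4o",               # illustrative
    usage={"prompt_tokens": 1200, "completion_tokens": 300},
    input_token_name="prompt_tokens",
    output_token_name="completion_tokens",
)

# Every failure mode is folded into the return value: total_cost is a float
# on success and an explanatory string on failure, never an exception.
print(rc.input_price_per_million_tokens, rc.output_price_per_million_tokens)
print(rc.total_cost)
```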
edsl/language_models/registry.py
CHANGED
```diff
@@ -18,6 +18,11 @@ class RegisterLanguageModelsMeta(ABCMeta):
 
     _registry = {}  # Initialize the registry as a dictionary
     REQUIRED_CLASS_ATTRIBUTES = ["_model_", "_parameters_", "_inference_service_"]
+
+    @classmethod
+    def clear_registry(cls):
+        """Clear the registry to prevent memory leaks."""
+        cls._registry = {}
 
     def __init__(cls, name, bases, dct):
         """Register the class in the registry if it has a _model_ attribute."""
```