edsl 0.1.54__py3-none-any.whl → 0.1.56__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (105)
  1. edsl/__init__.py +8 -1
  2. edsl/__init__original.py +134 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/agent.py +29 -0
  5. edsl/agents/agent_list.py +36 -1
  6. edsl/base/base_class.py +281 -151
  7. edsl/base/data_transfer_models.py +15 -4
  8. edsl/buckets/__init__.py +8 -3
  9. edsl/buckets/bucket_collection.py +9 -3
  10. edsl/buckets/model_buckets.py +4 -2
  11. edsl/buckets/token_bucket.py +2 -2
  12. edsl/buckets/token_bucket_client.py +5 -3
  13. edsl/caching/cache.py +131 -62
  14. edsl/caching/cache_entry.py +70 -58
  15. edsl/caching/sql_dict.py +17 -0
  16. edsl/cli.py +99 -0
  17. edsl/config/config_class.py +16 -0
  18. edsl/conversation/__init__.py +31 -0
  19. edsl/coop/coop.py +276 -242
  20. edsl/coop/coop_jobs_objects.py +59 -0
  21. edsl/coop/coop_objects.py +29 -0
  22. edsl/coop/coop_regular_objects.py +26 -0
  23. edsl/coop/utils.py +24 -19
  24. edsl/dataset/dataset.py +338 -101
  25. edsl/dataset/dataset_operations_mixin.py +216 -180
  26. edsl/db_list/sqlite_list.py +349 -0
  27. edsl/inference_services/__init__.py +40 -5
  28. edsl/inference_services/exceptions.py +11 -0
  29. edsl/inference_services/services/anthropic_service.py +5 -2
  30. edsl/inference_services/services/aws_bedrock.py +6 -2
  31. edsl/inference_services/services/azure_ai.py +6 -2
  32. edsl/inference_services/services/google_service.py +7 -3
  33. edsl/inference_services/services/mistral_ai_service.py +6 -2
  34. edsl/inference_services/services/open_ai_service.py +6 -2
  35. edsl/inference_services/services/perplexity_service.py +6 -2
  36. edsl/inference_services/services/test_service.py +94 -5
  37. edsl/interviews/answering_function.py +167 -59
  38. edsl/interviews/interview.py +124 -72
  39. edsl/interviews/interview_task_manager.py +10 -0
  40. edsl/interviews/request_token_estimator.py +8 -0
  41. edsl/invigilators/invigilators.py +35 -13
  42. edsl/jobs/async_interview_runner.py +146 -104
  43. edsl/jobs/data_structures.py +6 -4
  44. edsl/jobs/decorators.py +61 -0
  45. edsl/jobs/fetch_invigilator.py +61 -18
  46. edsl/jobs/html_table_job_logger.py +14 -2
  47. edsl/jobs/jobs.py +180 -104
  48. edsl/jobs/jobs_component_constructor.py +2 -2
  49. edsl/jobs/jobs_interview_constructor.py +2 -0
  50. edsl/jobs/jobs_pricing_estimation.py +154 -113
  51. edsl/jobs/jobs_remote_inference_logger.py +4 -0
  52. edsl/jobs/jobs_runner_status.py +30 -25
  53. edsl/jobs/progress_bar_manager.py +79 -0
  54. edsl/jobs/remote_inference.py +35 -1
  55. edsl/key_management/key_lookup_builder.py +6 -1
  56. edsl/language_models/language_model.py +110 -12
  57. edsl/language_models/model.py +10 -3
  58. edsl/language_models/price_manager.py +176 -71
  59. edsl/language_models/registry.py +5 -0
  60. edsl/notebooks/notebook.py +77 -10
  61. edsl/questions/VALIDATION_README.md +134 -0
  62. edsl/questions/__init__.py +24 -1
  63. edsl/questions/exceptions.py +21 -0
  64. edsl/questions/question_dict.py +201 -16
  65. edsl/questions/question_multiple_choice_with_other.py +624 -0
  66. edsl/questions/question_registry.py +2 -1
  67. edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
  68. edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
  69. edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
  70. edsl/questions/validation_analysis.py +185 -0
  71. edsl/questions/validation_cli.py +131 -0
  72. edsl/questions/validation_html_report.py +404 -0
  73. edsl/questions/validation_logger.py +136 -0
  74. edsl/results/result.py +115 -46
  75. edsl/results/results.py +702 -171
  76. edsl/scenarios/construct_download_link.py +16 -3
  77. edsl/scenarios/directory_scanner.py +226 -226
  78. edsl/scenarios/file_methods.py +5 -0
  79. edsl/scenarios/file_store.py +150 -9
  80. edsl/scenarios/handlers/__init__.py +5 -1
  81. edsl/scenarios/handlers/mp4_file_store.py +104 -0
  82. edsl/scenarios/handlers/webm_file_store.py +104 -0
  83. edsl/scenarios/scenario.py +120 -101
  84. edsl/scenarios/scenario_list.py +800 -727
  85. edsl/scenarios/scenario_list_gc_test.py +146 -0
  86. edsl/scenarios/scenario_list_memory_test.py +214 -0
  87. edsl/scenarios/scenario_list_source_refactor.md +35 -0
  88. edsl/scenarios/scenario_selector.py +5 -4
  89. edsl/scenarios/scenario_source.py +1990 -0
  90. edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
  91. edsl/surveys/survey.py +22 -0
  92. edsl/tasks/__init__.py +4 -2
  93. edsl/tasks/task_history.py +198 -36
  94. edsl/tests/scenarios/test_ScenarioSource.py +51 -0
  95. edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
  96. edsl/utilities/__init__.py +2 -1
  97. edsl/utilities/decorators.py +121 -0
  98. edsl/utilities/memory_debugger.py +1010 -0
  99. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/METADATA +51 -76
  100. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/RECORD +103 -79
  101. edsl/jobs/jobs_runner_asyncio.py +0 -281
  102. edsl/language_models/unused/fake_openai_service.py +0 -60
  103. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/LICENSE +0 -0
  104. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/WHEEL +0 -0
  105. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/entry_points.txt +0 -0
edsl/language_models/language_model.py
@@ -49,6 +49,7 @@ from ..data_transfer_models import (
 )
 
 if TYPE_CHECKING:
+    from .price_manager import ResponseCost
     from ..caching import Cache
     from ..scenarios import FileStore
     from ..questions import QuestionBase
@@ -365,6 +366,59 @@ class LanguageModel(
         self._api_token = info.api_token
         return self._api_token
 
+    def copy(self) -> "LanguageModel":
+        """Create a deep copy of this language model instance.
+
+        This method creates a completely independent copy of the language model
+        by creating a new instance with the same parameters and copying relevant attributes.
+
+        Returns:
+            LanguageModel: A new language model instance that is functionally identical to this one
+
+        Examples:
+            >>> m1 = LanguageModel.example()
+            >>> m2 = m1.copy()
+            >>> m1 == m2  # Functionally equivalent
+            True
+            >>> id(m1) == id(m2)  # But different objects
+            False
+        """
+        # Create a new instance of the same class with the same parameters
+        try:
+            # For most models, we can instantiate with the saved parameters
+            new_model = self.__class__(**self.parameters)
+
+            # Copy all important instance attributes
+            for key, value in self.__dict__.items():
+                if key not in ("_api_token",) and not key.startswith("__"):
+                    setattr(new_model, key, value)
+
+            return new_model
+        except Exception:
+            # Fallback for dynamically created classes like TestServiceLanguageModel
+            from ..inference_services import default
+
+            # If this is a test model, create a new test model instance
+            if getattr(self, "_inference_service_", "") == "test":
+                service = default.get_service("test")
+                model_class = service.create_model("test")
+                new_model = model_class(**self.parameters)
+
+                # Copy attributes
+                for key, value in self.__dict__.items():
+                    if key not in ("_api_token",) and not key.startswith("__"):
+                        setattr(new_model, key, value)
+
+                return new_model
+
+            # If we can't create the model directly, just return a simple test model
+            # This is a last resort fallback
+            from ..inference_services import get_service
+
+            service = get_service("test")
+            model_class = service.create_model("test")
+            return model_class()
+
     def __getitem__(self, key):
         """Allow dictionary-style access to model attributes.
 
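Reviewer's note on the new `copy()`: despite the docstring's "deep copy" wording, the attribute loop rebinds the same objects onto the new instance, so mutable attributes end up shared between the two copies. A minimal usage sketch, assuming `LanguageModel` is importable from the module this hunk patches:

```python
# Sketch only; the import path follows the file's location in this package.
from edsl.language_models.language_model import LanguageModel

m1 = LanguageModel.example()
m2 = m1.copy()

assert m1 == m2           # functionally equivalent, as in the doctest above
assert id(m1) != id(m2)   # but a distinct object
```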
@@ -679,9 +733,12 @@ class LanguageModel(
         user_prompt_with_hashes = user_prompt
 
         # Prepare parameters for cache lookup
+        cache_parameters = self.parameters.copy()
+        if self.model == "test":
+            cache_parameters.pop("canned_response", None)
         cache_call_params = {
             "model": str(self.model),
-            "parameters": self.parameters,
+            "parameters": cache_parameters,
             "system_prompt": system_prompt,
             "user_prompt": user_prompt_with_hashes,
             "iteration": iteration,
@@ -726,13 +783,18 @@ class LanguageModel(
         # Calculate cost for the response
         cost = self.cost(response)
         # Return a structured response with metadata
-        return ModelResponse(
+        response = ModelResponse(
             response=response,
             cache_used=cache_used,
             cache_key=cache_key,
             cached_response=cached_response,
-            cost=cost,
+            input_tokens=cost.input_tokens,
+            output_tokens=cost.output_tokens,
+            input_price_per_million_tokens=cost.input_price_per_million_tokens,
+            output_price_per_million_tokens=cost.output_price_per_million_tokens,
+            total_cost=cost.total_cost,
         )
+        return response
 
     _get_intended_model_call_outcome = sync_wrapper(
         _async_get_intended_model_call_outcome
@@ -825,7 +887,7 @@ class LanguageModel(
 
     get_response = sync_wrapper(async_get_response)
 
-    def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
+    def cost(self, raw_response: dict[str, Any]) -> ResponseCost:
         """Calculate the monetary cost of a model API call.
 
         This method extracts token usage information from the response and
@@ -836,7 +898,7 @@ class LanguageModel(
             raw_response: The complete response dictionary from the model API
 
         Returns:
-            Union[float, str]: The calculated cost in dollars, or an error message
+            ResponseCost: Object containing token counts and total cost
         """
         # Extract token usage data from the response
         usage = self.get_usage_dict(raw_response)
@@ -844,7 +906,7 @@ class LanguageModel(
         # Use the price manager to calculate the actual cost
         from .price_manager import PriceManager
 
-        price_manager = PriceManager()
+        price_manager = PriceManager.get_instance()
 
         return price_manager.calculate_cost(
             inference_service=self._inference_service_,
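With `cost()` now returning a `ResponseCost` through the shared singleton, the underlying call looks roughly like this (a sketch; the usage-dict keys vary by provider, and the module path follows the file list above):

```python
# Hypothetical direct use of the price manager, mirroring the call in cost().
from edsl.language_models.price_manager import PriceManager

price_manager = PriceManager.get_instance()
response_cost = price_manager.calculate_cost(
    inference_service="openai",
    model="gpt-4o",
    usage={"prompt_tokens": 120, "completion_tokens": 80},
    input_token_name="prompt_tokens",
    output_token_name="completion_tokens",
)
print(response_cost.total_cost)  # a float, or an error string on failure
```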
@@ -873,9 +935,15 @@ class LanguageModel(
             {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'inference_service': 'openai', 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
         """
         # Build the base dictionary with essential model information
+        parameters = self.parameters.copy()
+
+        # For test models, ensure canned_response is included in serialization
+        if self.model == "test" and hasattr(self, "canned_response"):
+            parameters["canned_response"] = self.canned_response
+
         d = {
             "model": self.model,
-            "parameters": self.parameters,
+            "parameters": parameters,
             "inference_service": self._inference_service_,
         }
 
@@ -913,7 +981,25 @@ class LanguageModel(
             data["model"], service_name=data.get("inference_service", None)
         )
 
-        # Create and return a new instance
+        # Handle canned_response in parameters for test models
+        if (
+            data["model"] == "test"
+            and "parameters" in data
+            and "canned_response" in data["parameters"]
+        ):
+            # Extract canned_response from parameters to set as a direct attribute
+            canned_response = data["parameters"]["canned_response"]
+            params_copy = data.copy()
+
+            # Direct attribute will be set during initialization
+            # Add it as a top-level parameter for model initialization
+            if isinstance(params_copy, dict) and "parameters" in params_copy:
+                params_copy["canned_response"] = canned_response
+
+            # Create the instance with canned_response as a direct parameter
+            return model_class(**params_copy)
+
+        # For non-test models or test models without canned_response
         return model_class(**data)
 
     def __repr__(self) -> str:
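Together with the `to_dict` change above, this gives test models a `canned_response` round-trip: written into `parameters` on serialization, hoisted back to a constructor argument on deserialization. A hedged sketch:

```python
# Round-trip sketch; the example() arguments follow the doctests in this file.
m = LanguageModel.example(test_model=True, canned_response="WOWZA!")

d = m.to_dict()
assert d["parameters"]["canned_response"] == "WOWZA!"  # now serialized

m2 = LanguageModel.from_dict(d)  # canned_response re-applied on construction
```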
@@ -999,8 +1085,8 @@ class LanguageModel(
 
         Create a test model that throws exceptions:
 
-        >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)
-        >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)
+        >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)  # doctest: +SKIP
+        >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)  # doctest: +SKIP
         Exception report saved to ...
         """
         from ..language_models import Model
@@ -1067,13 +1153,25 @@ class LanguageModel(
         }
         cached_response, cache_key = cache.fetch(**cache_call_params)
         response = json.loads(cached_response)
-        cost = 0
+
+        try:
+            usage = self.get_usage_dict(response)
+            input_tokens = int(usage[self.input_token_name])
+            output_tokens = int(usage[self.output_token_name])
+        except Exception as e:
+            print(f"Could not fetch tokens from model response: {e}")
+            input_tokens = None
+            output_tokens = None
         return ModelResponse(
             response=response,
             cache_used=True,
             cache_key=cache_key,
             cached_response=cached_response,
-            cost=cost,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            input_price_per_million_tokens=0,
+            output_price_per_million_tokens=0,
+            total_cost=0,
         )
 
     # Bind the new method to the copied instance
edsl/language_models/model.py
@@ -6,8 +6,10 @@ from ..utilities import PrettyList
 from ..config import CONFIG
 from .exceptions import LanguageModelValueError
 
+# Import only what's needed initially to avoid circular imports
 from ..inference_services import (InferenceServicesCollection,
-    AvailableModels, InferenceServiceABC, InferenceServiceError, default)
+    AvailableModels, InferenceServiceABC, InferenceServiceError)
+# The 'default' import will be imported lazily when needed
 
 from ..enums import InferenceServiceLiteral
 
@@ -20,7 +22,10 @@ def get_model_class(
     registry: Optional[InferenceServicesCollection] = None,
     service_name: Optional[InferenceServiceLiteral] = None,
 ):
-    registry = registry or default
+    if registry is None:
+        # Import default lazily only when needed
+        from ..inference_services import default as inference_default
+        registry = inference_default
     try:
         factory = registry.create_model_factory(model_name, service_name=service_name)
         return factory
@@ -54,7 +59,9 @@ class Model(metaclass=Meta):
     def get_registry(cls) -> InferenceServicesCollection:
         """Get the current registry or initialize with default if None"""
         if cls._registry is None:
-            cls._registry = default
+            # Import default lazily only when needed
+            from ..inference_services import default as inference_default
+            cls._registry = inference_default
         return cls._registry
 
     @classmethod
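All three `model.py` hunks apply the same fix: `default`, which builds the inference-services collection on import, is now imported inside the functions that need it, breaking the import cycle between `edsl.language_models` and `edsl.inference_services`. A self-contained illustration of the pattern, with `json` standing in for the cyclic dependency:

```python
# Generic lazy-import pattern: defer a module-level import to first use, so
# importing this module never triggers the heavy or cyclic import.
def get_registry(registry=None):
    if registry is None:
        import json  # stand-in for `from ..inference_services import default`
        registry = json
    return registry

print(get_registry().__name__)  # 'json'
```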
edsl/language_models/price_manager.py
@@ -1,4 +1,22 @@
-from typing import Dict, Tuple, Union
+from dataclasses import dataclass
+from typing import Dict, Literal, Tuple, Union
+from collections import namedtuple
+
+
+@dataclass
+class ResponseCost:
+    """
+    Class for storing the cost and token usage of a language model response.
+
+    If an error occurs when computing the cost, the total_cost will contain a string with the error message.
+    All other fields will be None.
+    """
+
+    input_tokens: Union[int, None] = None
+    output_tokens: Union[int, None] = None
+    input_price_per_million_tokens: Union[float, None] = None
+    output_price_per_million_tokens: Union[float, None] = None
+    total_cost: Union[float, str, None] = None
 
 
 class PriceManager:
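`ResponseCost` doubles as success value and error carrier: on failure, `total_cost` holds the error string and every other field stays `None`, so callers should type-check before summing costs. A usage sketch against the dataclass as defined above:

```python
# Assumes ResponseCost from the hunk above is in scope.
ok = ResponseCost(
    input_tokens=120,
    output_tokens=80,
    input_price_per_million_tokens=2.5,
    output_price_per_million_tokens=10.0,
    total_cost=(120 * 2.5 + 80 * 10.0) / 1_000_000,  # $0.0011
)

err = ResponseCost(total_cost="Could not fetch tokens from model response: KeyError")
assert isinstance(err.total_cost, str) and err.input_tokens is None
```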
@@ -8,20 +26,42 @@ class PriceManager:
 
     def __new__(cls):
         if cls._instance is None:
-            cls._instance = super(PriceManager, cls).__new__(cls)
+            instance = super(PriceManager, cls).__new__(cls)
+            instance._price_lookup = {}  # Instance-specific attribute
+            instance._is_initialized = False
+            cls._instance = instance  # Store the instance directly
+            return instance
         return cls._instance
 
     def __init__(self):
-        # Only initialize once, even if __init__ is called multiple times
+        """Initialize the singleton instance only once."""
         if not self._is_initialized:
             self._is_initialized = True
             self.refresh_prices()
 
-    def refresh_prices(self) -> None:
-        """
-        Fetch fresh prices from the Coop service and update the internal price lookup.
+    @classmethod
+    def get_instance(cls):
+        """Get the singleton instance, creating it if necessary."""
+        if cls._instance is None:
+            cls()  # Create the instance if it doesn't exist
+        return cls._instance
 
-        """
+    @classmethod
+    def reset(cls):
+        """Reset the singleton instance to clean up resources."""
+        cls._instance = None
+        cls._is_initialized = False
+        cls._price_lookup = {}
+
+    def __del__(self):
+        """Ensure proper cleanup when the instance is garbage collected."""
+        try:
+            self._price_lookup = {}  # Clean up resources
+        except:
+            pass  # Ignore any cleanup errors
+
+    def refresh_prices(self) -> None:
+        """Fetch fresh prices and update the internal price lookup."""
         from edsl.coop import Coop
 
         c = Coop()
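A behavior sketch for the reworked singleton (illustrative rather than runnable offline, since construction calls `refresh_prices()`, which contacts Coop):

```python
# Assumes PriceManager as defined above.
a = PriceManager.get_instance()
b = PriceManager()             # __new__ hands back the same instance
assert a is b

PriceManager.reset()           # drops the cached instance...
c = PriceManager.get_instance()
assert c is not a              # ...so the next access builds a fresh one
```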
@@ -31,29 +71,14 @@ class PriceManager:
             print(f"Error fetching prices: {str(e)}")
 
     def get_price(self, inference_service: str, model: str) -> Dict:
-        """
-        Get the price information for a specific service and model combination.
-        If no specific price is found, returns a fallback price.
-
-        Args:
-            inference_service (str): The name of the inference service
-            model (str): The model identifier
-
-        Returns:
-            Dict: Price information (either actual or fallback prices)
-        """
+        """Get the price information for a specific service and model."""
         key = (inference_service, model)
         return self._price_lookup.get(key) or self._get_fallback_price(
             inference_service
         )
 
     def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
-        """
-        Get the complete price lookup dictionary.
-
-        Returns:
-            Dict[Tuple[str, str], Dict]: The complete price lookup dictionary
-        """
+        """Get the complete price lookup dictionary."""
         return self._price_lookup.copy()
 
     def _get_fallback_price(self, inference_service: str) -> Dict:
@@ -68,73 +93,95 @@ class PriceManager:
         Returns:
             Dict: Price information
         """
+        PriceEntry = namedtuple("PriceEntry", ["tokens_per_usd", "price_info"])
+
         service_prices = [
             prices
             for (service, _), prices in self._price_lookup.items()
             if service == inference_service
         ]
 
-        input_tokens_per_usd = [
-            float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
+        default_price_info = {
+            "one_usd_buys": 1_000_000,
+            "service_stated_token_qty": 1_000_000,
+            "service_stated_token_price": 1.0,
+        }
+
+        # Find the most expensive price entries (lowest tokens per USD)
+        input_price_info = default_price_info
+        output_price_info = default_price_info
+
+        input_prices = [
+            PriceEntry(float(p["input"]["one_usd_buys"]), p["input"])
+            for p in service_prices
+            if "input" in p
         ]
-        if input_tokens_per_usd:
-            min_input_tokens = min(input_tokens_per_usd)
-        else:
-            min_input_tokens = 1_000_000
+        if input_prices:
+            input_price_info = min(
+                input_prices, key=lambda price: price.tokens_per_usd
+            ).price_info
 
-        output_tokens_per_usd = [
-            float(p["output"]["one_usd_buys"]) for p in service_prices if "output" in p
+        output_prices = [
+            PriceEntry(float(p["output"]["one_usd_buys"]), p["output"])
+            for p in service_prices
+            if "output" in p
         ]
-        if output_tokens_per_usd:
-            min_output_tokens = min(output_tokens_per_usd)
-        else:
-            min_output_tokens = 1_000_000
+        if output_prices:
+            output_price_info = min(
+                output_prices, key=lambda price: price.tokens_per_usd
+            ).price_info
 
         return {
-            "input": {"one_usd_buys": min_input_tokens},
-            "output": {"one_usd_buys": min_output_tokens},
+            "input": input_price_info,
+            "output": output_price_info,
         }
 
-    def calculate_cost(
+    def get_price_per_million_tokens(
         self,
-        inference_service: str,
-        model: str,
-        usage: Dict[str, Union[str, int]],
-        input_token_name: str,
-        output_token_name: str,
-    ) -> Union[float, str]:
+        relevant_prices: Dict,
+        token_type: Literal["input", "output"],
+    ) -> Dict:
         """
-        Calculate the total cost for a model usage based on input and output tokens.
+        Get the price per million tokens for a specific service, model, and token type.
+        """
+        service_price = relevant_prices[token_type]["service_stated_token_price"]
+        service_qty = relevant_prices[token_type]["service_stated_token_qty"]
 
-        Args:
-            inference_service (str): The inference service identifier
-            model (str): The model identifier
-            usage (Dict[str, Union[str, int]]): Dictionary containing token usage information
-            input_token_name (str): Key name for input tokens in the usage dict
-            output_token_name (str): Key name for output tokens in the usage dict
+        if service_qty == 1_000_000:
+            price_per_million_tokens = service_price
+        elif service_qty == 1_000:
+            price_per_million_tokens = service_price * 1_000
+        else:
+            price_per_token = service_price / service_qty
+            price_per_million_tokens = round(price_per_token * 1_000_000, 10)
+        return price_per_million_tokens
 
-        Returns:
-            Union[float, str]: Total cost if calculation successful, error message string if not
+    def _calculate_total_cost(
+        self,
+        relevant_prices: Dict,
+        input_tokens: int,
+        output_tokens: int,
+    ) -> float:
         """
-        relevant_prices = self.get_price(inference_service, model)
-
-        # Extract token counts
-        try:
-            input_tokens = int(usage[input_token_name])
-            output_tokens = int(usage[output_token_name])
-        except Exception as e:
-            return f"Could not fetch tokens from model response: {e}"
+        Calculate the total cost for a model usage based on input and output tokens.
 
+        Returns:
+            float: Total cost
+        """
         # Extract price information
         try:
             inverse_output_price = relevant_prices["output"]["one_usd_buys"]
             inverse_input_price = relevant_prices["input"]["one_usd_buys"]
         except Exception as e:
             if "output" not in relevant_prices:
-                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
+                raise KeyError(
+                    f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
+                )
             if "input" not in relevant_prices:
-                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
-            return f"Could not fetch prices from {relevant_prices} - {e}"
+                raise KeyError(
+                    f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
+                )
+            raise Exception(f"Could not fetch prices from {relevant_prices} - {e}")
 
         # Calculate input cost
         if inverse_input_price == "infinity":
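The per-million conversion in `get_price_per_million_tokens` has three branches; a worked example under an assumed price table:

```python
# Illustrative price entries in the shape the code above expects.
relevant_prices = {
    "input":  {"service_stated_token_price": 2.50, "service_stated_token_qty": 1_000_000},
    "output": {"service_stated_token_price": 0.01, "service_stated_token_qty": 1_000},
}
# qty == 1_000_000 -> already per million:  2.50
# qty == 1_000     -> scale by 1_000:       0.01 * 1_000 = 10.0
# any other qty    -> price_per_token * 1_000_000, rounded to 10 decimal places
```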
@@ -143,7 +190,7 @@ class PriceManager:
         try:
             input_cost = input_tokens / float(inverse_input_price)
         except Exception as e:
-            return f"Could not compute input price - {e}."
+            raise Exception(f"Could not compute input price - {e}")
 
         # Calculate output cost
         if inverse_output_price == "infinity":
@@ -152,16 +199,74 @@ class PriceManager:
         try:
             output_cost = output_tokens / float(inverse_output_price)
         except Exception as e:
-            return f"Could not compute output price - {e}"
+            raise Exception(f"Could not compute output price - {e}")
 
         return input_cost + output_cost
 
-    @property
-    def is_initialized(self) -> bool:
+    def calculate_cost(
+        self,
+        inference_service: str,
+        model: str,
+        usage: Dict[str, Union[str, int]],
+        input_token_name: str,
+        output_token_name: str,
+    ) -> ResponseCost:
         """
-        Check if the PriceManager has been initialized.
+        Calculate the cost and token usage for a model response.
+
+        Args:
+            inference_service (str): The inference service identifier
+            model (str): The model identifier
+            usage (Dict[str, Union[str, int]]): Dictionary containing token usage information
+            input_token_name (str): Key name for input tokens in the usage dict
+            output_token_name (str): Key name for output tokens in the usage dict
 
         Returns:
-            bool: True if initialized, False otherwise
+            ResponseCost: Object containing token counts and total cost
         """
+        try:
+            input_tokens = int(usage[input_token_name])
+            output_tokens = int(usage[output_token_name])
+        except Exception as e:
+            return ResponseCost(
+                total_cost=f"Could not fetch tokens from model response: {e}",
+            )
+
+        try:
+            relevant_prices = self.get_price(inference_service, model)
+        except Exception as e:
+            return ResponseCost(
+                total_cost=f"Could not fetch prices from {inference_service} - {model}: {e}",
+            )
+
+        try:
+            input_price_per_million_tokens = self.get_price_per_million_tokens(
+                relevant_prices, "input"
+            )
+            output_price_per_million_tokens = self.get_price_per_million_tokens(
+                relevant_prices, "output"
+            )
+        except Exception as e:
+            return ResponseCost(
+                total_cost=f"Could not compute price per million tokens: {e}",
+            )
+
+        try:
+            total_cost = self._calculate_total_cost(
+                relevant_prices, input_tokens, output_tokens
+            )
+        except Exception as e:
+            return ResponseCost(total_cost=f"{e}")
+
+        return ResponseCost(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            input_price_per_million_tokens=input_price_per_million_tokens,
+            output_price_per_million_tokens=output_price_per_million_tokens,
+            total_cost=total_cost,
+        )
+
+    @property
+    def is_initialized(self) -> bool:
+        """Check if the PriceManager has been initialized."""
         return self._is_initialized
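Each stage of the rewritten `calculate_cost` short-circuits into a `ResponseCost` whose `total_cost` carries the error message instead of raising to the caller. A sketch of that contract:

```python
# Sketch: a usage dict missing the expected token keys yields an error-shaped
# ResponseCost rather than an exception.
pm = PriceManager.get_instance()
rc = pm.calculate_cost(
    inference_service="openai",
    model="gpt-4o",
    usage={},  # no token counts, on purpose
    input_token_name="prompt_tokens",
    output_token_name="completion_tokens",
)
assert isinstance(rc.total_cost, str)  # "Could not fetch tokens from model response: ..."
assert rc.input_tokens is None
```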
edsl/language_models/registry.py
@@ -18,6 +18,11 @@ class RegisterLanguageModelsMeta(ABCMeta):
 
     _registry = {}  # Initialize the registry as a dictionary
     REQUIRED_CLASS_ATTRIBUTES = ["_model_", "_parameters_", "_inference_service_"]
+
+    @classmethod
+    def clear_registry(cls):
+        """Clear the registry to prevent memory leaks."""
+        cls._registry = {}
 
     def __init__(cls, name, bases, dct):
         """Register the class in the registry if it has a _model_ attribute."""