edsl 0.1.54__py3-none-any.whl → 0.1.55__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. edsl/__init__.py +8 -1
  2. edsl/__init__original.py +134 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/agent.py +29 -0
  5. edsl/agents/agent_list.py +36 -1
  6. edsl/base/base_class.py +281 -151
  7. edsl/buckets/__init__.py +8 -3
  8. edsl/buckets/bucket_collection.py +9 -3
  9. edsl/buckets/model_buckets.py +4 -2
  10. edsl/buckets/token_bucket.py +2 -2
  11. edsl/buckets/token_bucket_client.py +5 -3
  12. edsl/caching/cache.py +131 -62
  13. edsl/caching/cache_entry.py +70 -58
  14. edsl/caching/sql_dict.py +17 -0
  15. edsl/cli.py +99 -0
  16. edsl/config/config_class.py +16 -0
  17. edsl/conversation/__init__.py +31 -0
  18. edsl/coop/coop.py +276 -242
  19. edsl/coop/coop_jobs_objects.py +59 -0
  20. edsl/coop/coop_objects.py +29 -0
  21. edsl/coop/coop_regular_objects.py +26 -0
  22. edsl/coop/utils.py +24 -19
  23. edsl/dataset/dataset.py +338 -101
  24. edsl/db_list/sqlite_list.py +349 -0
  25. edsl/inference_services/__init__.py +40 -5
  26. edsl/inference_services/exceptions.py +11 -0
  27. edsl/inference_services/services/anthropic_service.py +5 -2
  28. edsl/inference_services/services/aws_bedrock.py +6 -2
  29. edsl/inference_services/services/azure_ai.py +6 -2
  30. edsl/inference_services/services/google_service.py +3 -2
  31. edsl/inference_services/services/mistral_ai_service.py +6 -2
  32. edsl/inference_services/services/open_ai_service.py +6 -2
  33. edsl/inference_services/services/perplexity_service.py +6 -2
  34. edsl/inference_services/services/test_service.py +94 -5
  35. edsl/interviews/answering_function.py +167 -59
  36. edsl/interviews/interview.py +124 -72
  37. edsl/interviews/interview_task_manager.py +10 -0
  38. edsl/invigilators/invigilators.py +9 -0
  39. edsl/jobs/async_interview_runner.py +146 -104
  40. edsl/jobs/data_structures.py +6 -4
  41. edsl/jobs/decorators.py +61 -0
  42. edsl/jobs/fetch_invigilator.py +61 -18
  43. edsl/jobs/html_table_job_logger.py +14 -2
  44. edsl/jobs/jobs.py +180 -104
  45. edsl/jobs/jobs_component_constructor.py +2 -2
  46. edsl/jobs/jobs_interview_constructor.py +2 -0
  47. edsl/jobs/jobs_remote_inference_logger.py +4 -0
  48. edsl/jobs/jobs_runner_status.py +30 -25
  49. edsl/jobs/progress_bar_manager.py +79 -0
  50. edsl/jobs/remote_inference.py +35 -1
  51. edsl/key_management/key_lookup_builder.py +6 -1
  52. edsl/language_models/language_model.py +86 -6
  53. edsl/language_models/model.py +10 -3
  54. edsl/language_models/price_manager.py +45 -75
  55. edsl/language_models/registry.py +5 -0
  56. edsl/notebooks/notebook.py +77 -10
  57. edsl/questions/VALIDATION_README.md +134 -0
  58. edsl/questions/__init__.py +24 -1
  59. edsl/questions/exceptions.py +21 -0
  60. edsl/questions/question_dict.py +201 -16
  61. edsl/questions/question_multiple_choice_with_other.py +624 -0
  62. edsl/questions/question_registry.py +2 -1
  63. edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
  64. edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
  65. edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
  66. edsl/questions/validation_analysis.py +185 -0
  67. edsl/questions/validation_cli.py +131 -0
  68. edsl/questions/validation_html_report.py +404 -0
  69. edsl/questions/validation_logger.py +136 -0
  70. edsl/results/result.py +63 -16
  71. edsl/results/results.py +702 -171
  72. edsl/scenarios/construct_download_link.py +16 -3
  73. edsl/scenarios/directory_scanner.py +226 -226
  74. edsl/scenarios/file_methods.py +5 -0
  75. edsl/scenarios/file_store.py +117 -6
  76. edsl/scenarios/handlers/__init__.py +5 -1
  77. edsl/scenarios/handlers/mp4_file_store.py +104 -0
  78. edsl/scenarios/handlers/webm_file_store.py +104 -0
  79. edsl/scenarios/scenario.py +120 -101
  80. edsl/scenarios/scenario_list.py +800 -727
  81. edsl/scenarios/scenario_list_gc_test.py +146 -0
  82. edsl/scenarios/scenario_list_memory_test.py +214 -0
  83. edsl/scenarios/scenario_list_source_refactor.md +35 -0
  84. edsl/scenarios/scenario_selector.py +5 -4
  85. edsl/scenarios/scenario_source.py +1990 -0
  86. edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
  87. edsl/surveys/survey.py +22 -0
  88. edsl/tasks/__init__.py +4 -2
  89. edsl/tasks/task_history.py +198 -36
  90. edsl/tests/scenarios/test_ScenarioSource.py +51 -0
  91. edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
  92. edsl/utilities/__init__.py +2 -1
  93. edsl/utilities/decorators.py +121 -0
  94. edsl/utilities/memory_debugger.py +1010 -0
  95. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/METADATA +51 -76
  96. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/RECORD +99 -75
  97. edsl/jobs/jobs_runner_asyncio.py +0 -281
  98. edsl/language_models/unused/fake_openai_service.py +0 -60
  99. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/LICENSE +0 -0
  100. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/WHEEL +0 -0
  101. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/entry_points.txt +0 -0
@@ -1,3 +1,4 @@
1
+ import re
1
2
  from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
2
3
  from dataclasses import dataclass
3
4
  from ..coop import CoopServerResponseError
@@ -112,13 +113,18 @@ class JobsRemoteInferenceHandler:
112
113
  )
113
114
  logger.add_info("job_uuid", job_uuid)
114
115
 
116
+ remote_inference_url = self.remote_inference_url
117
+ if "localhost" in remote_inference_url:
118
+ remote_inference_url = remote_inference_url.replace("8000", "1234")
115
119
  logger.update(
116
- f"Job details are available at your Coop account. [Go to Remote Inference page]({self.remote_inference_url})",
120
+ f"Job details are available at your Coop account. [Go to Remote Inference page]({remote_inference_url})",
117
121
  status=JobsStatus.RUNNING,
118
122
  )
119
123
  progress_bar_url = (
120
124
  f"{self.expected_parrot_url}/home/remote-job-progress/{job_uuid}"
121
125
  )
126
+ if "localhost" in progress_bar_url:
127
+ progress_bar_url = progress_bar_url.replace("8000", "1234")
122
128
  logger.add_info("progress_bar_url", progress_bar_url)
123
129
  logger.update(
124
130
  f"View job progress [here]({progress_bar_url})", status=JobsStatus.RUNNING
@@ -200,10 +206,35 @@ class JobsRemoteInferenceHandler:
200
206
  status=JobsStatus.FAILED,
201
207
  )
202
208
 
209
+ def _handle_partially_failed_job_interview_details(
210
+ self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
211
+ ) -> None:
212
+ "Extracts the interview details from the remote job data."
213
+ try:
214
+ # Job details is a string of the form "64 out of 1,758 interviews failed"
215
+ job_details = remote_job_data.get("latest_failure_description")
216
+
217
+ text_without_commas = job_details.replace(",", "")
218
+
219
+ # Find all numbers in the text
220
+ numbers = [int(num) for num in re.findall(r"\d+", text_without_commas)]
221
+
222
+ failed = numbers[0]
223
+ total = numbers[1]
224
+ completed = total - failed
225
+
226
+ job_info.logger.add_info("completed_interviews", completed)
227
+ job_info.logger.add_info("failed_interviews", failed)
228
+ # This is mainly helpful metadata, and any errors here should not stop the code
229
+ except:
230
+ pass
231
+
203
232
  def _handle_partially_failed_job(
204
233
  self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
205
234
  ) -> None:
206
235
  "Handles a partially failed job by logging the error and updating the job status."
236
+ self._handle_partially_failed_job_interview_details(job_info, remote_job_data)
237
+
207
238
  latest_error_report_url = remote_job_data.get("latest_error_report_url")
208
239
 
209
240
  if latest_error_report_url:
@@ -244,6 +275,8 @@ class JobsRemoteInferenceHandler:
244
275
  job_info.logger.add_info("results_uuid", results_uuid)
245
276
  results = object_fetcher(results_uuid, expected_object_type="results")
246
277
  results_url = remote_job_data.get("results_url")
278
+ if "localhost" in results_url:
279
+ results_url = results_url.replace("8000", "1234")
247
280
  job_info.logger.add_info("results_url", results_url)
248
281
 
249
282
  if job_status == "completed":
@@ -256,6 +289,7 @@ class JobsRemoteInferenceHandler:
256
289
  f"View partial results [here]({results_url})",
257
290
  status=JobsStatus.PARTIALLY_FAILED,
258
291
  )
292
+
259
293
  results.job_uuid = job_info.job_uuid
260
294
  results.results_uuid = results_uuid
261
295
  return results
@@ -2,6 +2,7 @@ from typing import Optional, TYPE_CHECKING
2
2
  import os
3
3
  from functools import lru_cache
4
4
  import textwrap
5
+ import requests
5
6
 
6
7
  if TYPE_CHECKING:
7
8
  from ..coop import Coop
@@ -255,7 +256,11 @@ class KeyLookupBuilder:
255
256
  return dict(list(os.environ.items()))
256
257
 
257
258
  def _coop_key_value_pairs(self):
258
- return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
259
+ try:
260
+ return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
261
+ except requests.ConnectionError:
262
+ # If connection fails, return empty dict instead of raising error
263
+ return {}
259
264
 
260
265
  def _config_key_value_pairs(self):
261
266
  from ..config import CONFIG
@@ -365,6 +365,59 @@ class LanguageModel(
365
365
  self._api_token = info.api_token
366
366
  return self._api_token
367
367
 
368
def copy(self) -> "LanguageModel":
    """Build a new model instance that mirrors this one.

    A fresh instance of the same class is constructed from ``self.parameters``
    and every instance attribute is then assigned onto it, except
    ``_api_token`` and names starting with ``__``. Attribute values are
    assigned as-is (the same objects), not duplicated.

    If that fails with any ``Exception`` and ``_inference_service_`` is
    ``"test"``, the model class is re-created through the default test service
    and populated the same way. For any other model the last-resort fallback
    returns a plain test model created with no arguments.

    Returns:
        LanguageModel: A new language model instance.

    Examples:
        >>> m1 = LanguageModel.example()
        >>> m2 = m1.copy()
        >>> m1 == m2 # Functionally equivalent
        True
        >>> id(m1) == id(m2) # But different objects
        False
    """

    def _mirror_attributes(target):
        # Carry over instance state, leaving out the API token and any
        # double-underscore names.
        for attr_name, attr_value in self.__dict__.items():
            if attr_name == "_api_token" or attr_name.startswith("__"):
                continue
            setattr(target, attr_name, attr_value)
        return target

    try:
        # Usual path: the class can be instantiated from the saved parameters.
        return _mirror_attributes(self.__class__(**self.parameters))
    except Exception:
        # Fallback for dynamically created classes like TestServiceLanguageModel
        from ..inference_services import default

        service_name = getattr(self, "_inference_service_", "")
        if service_name == "test":
            test_class = default.get_service("test").create_model("test")
            return _mirror_attributes(test_class(**self.parameters))

        # Last resort: hand back a simple test model.
        from ..inference_services import get_service

        return get_service("test").create_model("test")()
420
+
368
421
  def __getitem__(self, key):
369
422
  """Allow dictionary-style access to model attributes.
370
423
 
@@ -679,9 +732,12 @@ class LanguageModel(
679
732
  user_prompt_with_hashes = user_prompt
680
733
 
681
734
  # Prepare parameters for cache lookup
735
+ cache_parameters = self.parameters.copy()
736
+ if self.model == "test":
737
+ cache_parameters.pop("canned_response", None)
682
738
  cache_call_params = {
683
739
  "model": str(self.model),
684
- "parameters": self.parameters,
740
+ "parameters": cache_parameters,
685
741
  "system_prompt": system_prompt,
686
742
  "user_prompt": user_prompt_with_hashes,
687
743
  "iteration": iteration,
@@ -844,7 +900,7 @@ class LanguageModel(
844
900
  # Use the price manager to calculate the actual cost
845
901
  from .price_manager import PriceManager
846
902
 
847
- price_manager = PriceManager()
903
+ price_manager = PriceManager.get_instance()
848
904
 
849
905
  return price_manager.calculate_cost(
850
906
  inference_service=self._inference_service_,
@@ -873,9 +929,15 @@ class LanguageModel(
873
929
  {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'inference_service': 'openai', 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
874
930
  """
875
931
  # Build the base dictionary with essential model information
932
+ parameters = self.parameters.copy()
933
+
934
+ # For test models, ensure canned_response is included in serialization
935
+ if self.model == "test" and hasattr(self, "canned_response"):
936
+ parameters["canned_response"] = self.canned_response
937
+
876
938
  d = {
877
939
  "model": self.model,
878
- "parameters": self.parameters,
940
+ "parameters": parameters,
879
941
  "inference_service": self._inference_service_,
880
942
  }
881
943
 
@@ -913,7 +975,25 @@ class LanguageModel(
913
975
  data["model"], service_name=data.get("inference_service", None)
914
976
  )
915
977
 
916
- # Create and return a new instance
978
+ # Handle canned_response in parameters for test models
979
+ if (
980
+ data["model"] == "test"
981
+ and "parameters" in data
982
+ and "canned_response" in data["parameters"]
983
+ ):
984
+ # Extract canned_response from parameters to set as a direct attribute
985
+ canned_response = data["parameters"]["canned_response"]
986
+ params_copy = data.copy()
987
+
988
+ # Direct attribute will be set during initialization
989
+ # Add it as a top-level parameter for model initialization
990
+ if isinstance(params_copy, dict) and "parameters" in params_copy:
991
+ params_copy["canned_response"] = canned_response
992
+
993
+ # Create the instance with canned_response as a direct parameter
994
+ return model_class(**params_copy)
995
+
996
+ # For non-test models or test models without canned_response
917
997
  return model_class(**data)
918
998
 
919
999
  def __repr__(self) -> str:
@@ -999,8 +1079,8 @@ class LanguageModel(
999
1079
 
1000
1080
  Create a test model that throws exceptions:
1001
1081
 
1002
- >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)
1003
- >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)
1082
+ >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True) # doctest: +SKIP
1083
+ >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True) # doctest: +SKIP
1004
1084
  Exception report saved to ...
1005
1085
  """
1006
1086
  from ..language_models import Model
@@ -6,8 +6,10 @@ from ..utilities import PrettyList
6
6
  from ..config import CONFIG
7
7
  from .exceptions import LanguageModelValueError
8
8
 
9
+ # Import only what's needed initially to avoid circular imports
9
10
  from ..inference_services import (InferenceServicesCollection,
10
- AvailableModels, InferenceServiceABC, InferenceServiceError, default)
11
+ AvailableModels, InferenceServiceABC, InferenceServiceError)
12
+ # The 'default' import will be imported lazily when needed
11
13
 
12
14
  from ..enums import InferenceServiceLiteral
13
15
 
@@ -20,7 +22,10 @@ def get_model_class(
20
22
  registry: Optional[InferenceServicesCollection] = None,
21
23
  service_name: Optional[InferenceServiceLiteral] = None,
22
24
  ):
23
- registry = registry or default
25
+ if registry is None:
26
+ # Import default lazily only when needed
27
+ from ..inference_services import default as inference_default
28
+ registry = inference_default
24
29
  try:
25
30
  factory = registry.create_model_factory(model_name, service_name=service_name)
26
31
  return factory
@@ -54,7 +59,9 @@ class Model(metaclass=Meta):
54
59
  def get_registry(cls) -> InferenceServicesCollection:
55
60
  """Get the current registry or initialize with default if None"""
56
61
  if cls._registry is None:
57
- cls._registry = default
62
+ # Import default lazily only when needed
63
+ from ..inference_services import default as inference_default
64
+ cls._registry = inference_default
58
65
  return cls._registry
59
66
 
60
67
  @classmethod
@@ -8,20 +8,42 @@ class PriceManager:
8
8
 
9
9
  def __new__(cls):
10
10
  if cls._instance is None:
11
- cls._instance = super(PriceManager, cls).__new__(cls)
11
+ instance = super(PriceManager, cls).__new__(cls)
12
+ instance._price_lookup = {} # Instance-specific attribute
13
+ instance._is_initialized = False
14
+ cls._instance = instance # Store the instance directly
15
+ return instance
12
16
  return cls._instance
13
17
 
14
18
  def __init__(self):
15
- # Only initialize once, even if __init__ is called multiple times
19
+ """Initialize the singleton instance only once."""
16
20
  if not self._is_initialized:
17
21
  self._is_initialized = True
18
22
  self.refresh_prices()
19
23
 
20
- def refresh_prices(self) -> None:
21
- """
22
- Fetch fresh prices from the Coop service and update the internal price lookup.
24
@classmethod
def get_instance(cls):
    """Return the shared instance, constructing it first when none is stored."""
    if cls._instance is not None:
        return cls._instance
    # Constructing the class is relied on to populate ``cls._instance``.
    cls()
    return cls._instance
30
+
31
@classmethod
def reset(cls):
    """Forget the stored singleton and reset the class-level state.

    Sets ``_price_lookup`` to a new empty dict, ``_is_initialized`` to
    ``False`` and ``_instance`` to ``None`` on the class itself.
    """
    cls._price_lookup = {}
    cls._is_initialized = False
    cls._instance = None
23
37
 
24
- """
38
+ def __del__(self):
39
+ """Ensure proper cleanup when the instance is garbage collected."""
40
+ try:
41
+ self._price_lookup = {} # Clean up resources
42
+ except:
43
+ pass # Ignore any cleanup errors
44
+
45
+ def refresh_prices(self) -> None:
46
+ """Fetch fresh prices and update the internal price lookup."""
25
47
  from edsl.coop import Coop
26
48
 
27
49
  c = Coop()
@@ -31,43 +53,18 @@ class PriceManager:
31
53
  print(f"Error fetching prices: {str(e)}")
32
54
 
33
55
  def get_price(self, inference_service: str, model: str) -> Dict:
34
- """
35
- Get the price information for a specific service and model combination.
36
- If no specific price is found, returns a fallback price.
37
-
38
- Args:
39
- inference_service (str): The name of the inference service
40
- model (str): The model identifier
41
-
42
- Returns:
43
- Dict: Price information (either actual or fallback prices)
44
- """
56
+ """Get the price information for a specific service and model."""
45
57
  key = (inference_service, model)
46
58
  return self._price_lookup.get(key) or self._get_fallback_price(
47
59
  inference_service
48
60
  )
49
61
 
50
62
  def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
51
- """
52
- Get the complete price lookup dictionary.
53
-
54
- Returns:
55
- Dict[Tuple[str, str], Dict]: The complete price lookup dictionary
56
- """
63
+ """Get the complete price lookup dictionary."""
57
64
  return self._price_lookup.copy()
58
65
 
59
66
  def _get_fallback_price(self, inference_service: str) -> Dict:
60
- """
61
- Get fallback prices for a service.
62
- - First fallback: The highest input and output prices for that service from the price lookup.
63
- - Second fallback: $1.00 per million tokens (for both input and output).
64
-
65
- Args:
66
- inference_service (str): The inference service name
67
-
68
- Returns:
69
- Dict: Price information
70
- """
67
+ """Get fallback prices for a service."""
71
68
  service_prices = [
72
69
  prices
73
70
  for (service, _), prices in self._price_lookup.items()
@@ -77,18 +74,12 @@ class PriceManager:
77
74
  input_tokens_per_usd = [
78
75
  float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
79
76
  ]
80
- if input_tokens_per_usd:
81
- min_input_tokens = min(input_tokens_per_usd)
82
- else:
83
- min_input_tokens = 1_000_000
77
+ min_input_tokens = min(input_tokens_per_usd, default=1_000_000)
84
78
 
85
79
  output_tokens_per_usd = [
86
80
  float(p["output"]["one_usd_buys"]) for p in service_prices if "output" in p
87
81
  ]
88
- if output_tokens_per_usd:
89
- min_output_tokens = min(output_tokens_per_usd)
90
- else:
91
- min_output_tokens = 1_000_000
82
+ min_output_tokens = min(output_tokens_per_usd, default=1_000_000)
92
83
 
93
84
  return {
94
85
  "input": {"one_usd_buys": min_input_tokens},
@@ -103,19 +94,7 @@ class PriceManager:
103
94
  input_token_name: str,
104
95
  output_token_name: str,
105
96
  ) -> Union[float, str]:
106
- """
107
- Calculate the total cost for a model usage based on input and output tokens.
108
-
109
- Args:
110
- inference_service (str): The inference service identifier
111
- model (str): The model identifier
112
- usage (Dict[str, Union[str, int]]): Dictionary containing token usage information
113
- input_token_name (str): Key name for input tokens in the usage dict
114
- output_token_name (str): Key name for output tokens in the usage dict
115
-
116
- Returns:
117
- Union[float, str]: Total cost if calculation successful, error message string if not
118
- """
97
+ """Calculate the total cost for a model usage."""
119
98
  relevant_prices = self.get_price(inference_service, model)
120
99
 
121
100
  # Extract token counts
@@ -137,31 +116,22 @@ class PriceManager:
137
116
  return f"Could not fetch prices from {relevant_prices} - {e}"
138
117
 
139
118
  # Calculate input cost
140
- if inverse_input_price == "infinity":
141
- input_cost = 0
142
- else:
143
- try:
144
- input_cost = input_tokens / float(inverse_input_price)
145
- except Exception as e:
146
- return f"Could not compute input price - {e}."
119
+ input_cost = (
120
+ 0
121
+ if inverse_input_price == "infinity"
122
+ else input_tokens / float(inverse_input_price)
123
+ )
147
124
 
148
125
  # Calculate output cost
149
- if inverse_output_price == "infinity":
150
- output_cost = 0
151
- else:
152
- try:
153
- output_cost = output_tokens / float(inverse_output_price)
154
- except Exception as e:
155
- return f"Could not compute output price - {e}"
126
+ output_cost = (
127
+ 0
128
+ if inverse_output_price == "infinity"
129
+ else output_tokens / float(inverse_output_price)
130
+ )
156
131
 
157
132
  return input_cost + output_cost
158
133
 
159
134
  @property
160
135
  def is_initialized(self) -> bool:
161
- """
162
- Check if the PriceManager has been initialized.
163
-
164
- Returns:
165
- bool: True if initialized, False otherwise
166
- """
136
+ """Check if the PriceManager has been initialized."""
167
137
  return self._is_initialized
@@ -18,6 +18,11 @@ class RegisterLanguageModelsMeta(ABCMeta):
18
18
 
19
19
  _registry = {} # Initialize the registry as a dictionary
20
20
  REQUIRED_CLASS_ATTRIBUTES = ["_model_", "_parameters_", "_inference_service_"]
21
+
22
@classmethod
def clear_registry(cls):
    """Rebind ``_registry`` to a new, empty dictionary.

    The previous registry object is not mutated; it is simply no longer
    referenced from ``cls._registry``.
    """
    cls._registry = dict()
21
26
 
22
27
  def __init__(cls, name, bases, dct):
23
28
  """Register the class in the registry if it has a _model_ attribute."""
@@ -2,6 +2,10 @@
2
2
 
3
3
  from __future__ import annotations
4
4
  import json
5
+ import subprocess
6
+ import tempfile
7
+ import os
8
+ import shutil
5
9
  from typing import Dict, List, Optional, TYPE_CHECKING
6
10
 
7
11
  if TYPE_CHECKING:
@@ -17,12 +21,56 @@ class Notebook(Base):
17
21
  """
18
22
 
19
23
  default_name = "notebook"
24
+
25
+ @staticmethod
26
+ def _lint_code(code: str) -> str:
27
+ """
28
+ Lint Python code using ruff.
29
+
30
+ :param code: The Python code to lint
31
+ :return: The linted code
32
+ """
33
+ try:
34
+ # Check if ruff is installed
35
+ if shutil.which("ruff") is None:
36
+ # If ruff is not installed, return original code
37
+ return code
38
+
39
+ with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as temp_file:
40
+ temp_file.write(code)
41
+ temp_file_path = temp_file.name
42
+
43
+ # Run ruff to format the code
44
+ try:
45
+ result = subprocess.run(
46
+ ["ruff", "format", temp_file_path],
47
+ check=True,
48
+ stdout=subprocess.PIPE,
49
+ stderr=subprocess.PIPE
50
+ )
51
+
52
+ # Read the formatted code
53
+ with open(temp_file_path, 'r') as f:
54
+ linted_code = f.read()
55
+
56
+ return linted_code
57
+ except subprocess.CalledProcessError:
58
+ # If ruff fails, return the original code
59
+ return code
60
+ except FileNotFoundError:
61
+ # If ruff is not installed, return the original code
62
+ return code
63
+ finally:
64
+ # Clean up temporary file
65
+ if 'temp_file_path' in locals() and os.path.exists(temp_file_path):
66
+ os.unlink(temp_file_path)
20
67
 
21
68
  def __init__(
22
69
  self,
23
70
  path: Optional[str] = None,
24
71
  data: Optional[Dict] = None,
25
72
  name: Optional[str] = None,
73
+ lint: bool = True,
26
74
  ):
27
75
  """
28
76
  Initialize a new Notebook.
@@ -32,6 +80,7 @@ class Notebook(Base):
32
80
  :param path: A filepath from which to load the notebook.
33
81
  If no path is provided, assume this code is run in a notebook and try to load the current notebook from file.
34
82
  :param name: A name for the Notebook.
83
+ :param lint: Whether to lint Python code cells using ruff. Defaults to True.
35
84
  """
36
85
  import nbformat
37
86
 
@@ -54,6 +103,16 @@ class Notebook(Base):
54
103
  raise NotebookEnvironmentError(
55
104
  "Cannot create a notebook from within itself in this development environment"
56
105
  )
106
+
107
+ # Store the lint parameter
108
+ self.lint = lint
109
+
110
+ # Apply linting to code cells if enabled
111
+ if self.lint and self.data and "cells" in self.data:
112
+ for cell in self.data["cells"]:
113
+ if cell.get("cell_type") == "code" and "source" in cell:
114
+ # Only lint Python code cells
115
+ cell["source"] = self._lint_code(cell["source"])
57
116
 
58
117
  # TODO: perhaps add sanity check function
59
118
  # 1. could check if the notebook is a valid notebook
@@ -63,7 +122,7 @@ class Notebook(Base):
63
122
  self.name = name or self.default_name
64
123
 
65
124
  @classmethod
66
- def from_script(cls, path: str, name: Optional[str] = None) -> "Notebook":
125
+ def from_script(cls, path: str, name: Optional[str] = None, lint: bool = True) -> "Notebook":
67
126
  import nbformat
68
127
 
69
128
  # Read the script file
@@ -78,12 +137,12 @@ class Notebook(Base):
78
137
  nb.cells.append(first_cell)
79
138
 
80
139
  # Create a Notebook instance with the notebook data
81
- notebook_instance = cls(nb)
140
+ notebook_instance = cls(data=nb, name=name, lint=lint)
82
141
 
83
142
  return notebook_instance
84
143
 
85
144
  @classmethod
86
- def from_current_script(cls) -> "Notebook":
145
+ def from_current_script(cls, lint: bool = True) -> "Notebook":
87
146
  import inspect
88
147
  import os
89
148
 
@@ -93,7 +152,7 @@ class Notebook(Base):
93
152
  current_file_path = os.path.abspath(caller_frame[1].filename)
94
153
 
95
154
  # Use from_script to create the notebook
96
- return cls.from_script(current_file_path)
155
+ return cls.from_script(current_file_path, lint=lint)
97
156
 
98
157
  def __eq__(self, other):
99
158
  """
@@ -114,7 +173,7 @@ class Notebook(Base):
114
173
  """
115
174
  Serialize to a dictionary.
116
175
  """
117
- d = {"name": self.name, "data": self.data}
176
+ d = {"name": self.name, "data": self.data, "lint": self.lint}
118
177
  if add_edsl_version:
119
178
  from .. import __version__
120
179
 
@@ -124,11 +183,17 @@ class Notebook(Base):
124
183
 
125
184
  @classmethod
126
185
  @remove_edsl_version
127
- def from_dict(cls, d: Dict) -> "Notebook":
186
+ def from_dict(cls, d: Dict, lint: bool = None) -> "Notebook":
128
187
  """
129
188
  Convert a dictionary representation of a Notebook to a Notebook object.
189
+
190
+ :param d: Dictionary containing notebook data and name
191
+ :param lint: Whether to lint Python code cells. If None, uses the value from the dictionary or defaults to True.
192
+ :return: A new Notebook instance
130
193
  """
131
- return cls(data=d["data"], name=d["name"])
194
+ # Use the lint parameter from the dictionary if none is provided, otherwise default to True
195
+ notebook_lint = lint if lint is not None else d.get("lint", True)
196
+ return cls(data=d["data"], name=d["name"], lint=notebook_lint)
132
197
 
133
198
  def to_file(self, path: str):
134
199
  """
@@ -205,11 +270,13 @@ class Notebook(Base):
205
270
  return table
206
271
 
207
272
  @classmethod
208
- def example(cls, randomize: bool = False) -> Notebook:
273
+ def example(cls, randomize: bool = False, lint: bool = True) -> Notebook:
209
274
  """
210
275
  Returns an example Notebook instance.
211
276
 
212
277
  :param randomize: If True, adds a random string one of the cells' output.
278
+ :param lint: Whether to lint Python code cells. Defaults to True.
279
+ :return: An example Notebook instance
213
280
  """
214
281
  addition = "" if not randomize else str(uuid4())
215
282
  cells = [
@@ -238,7 +305,7 @@ class Notebook(Base):
238
305
  "nbformat_minor": 4,
239
306
  "cells": cells,
240
307
  }
241
- return cls(data=data)
308
+ return cls(data=data, lint=lint)
242
309
 
243
310
  def code(self) -> List[str]:
244
311
  """
@@ -246,7 +313,7 @@ class Notebook(Base):
246
313
  """
247
314
  lines = []
248
315
  lines.append("from edsl import Notebook") # Keep as absolute for code generation
249
- lines.append(f'nb = Notebook(data={self.data}, name="""{self.name}""")')
316
+ lines.append(f'nb = Notebook(data={self.data}, name="""{self.name}""", lint={self.lint})')
250
317
  return lines
251
318
 
252
319
  def to_latex(self, filename: str):