edsl 0.1.53__py3-none-any.whl → 0.1.55__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (104)
  1. edsl/__init__.py +8 -1
  2. edsl/__init__original.py +134 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/agent.py +29 -0
  5. edsl/agents/agent_list.py +36 -1
  6. edsl/base/base_class.py +281 -151
  7. edsl/buckets/__init__.py +8 -3
  8. edsl/buckets/bucket_collection.py +9 -3
  9. edsl/buckets/model_buckets.py +4 -2
  10. edsl/buckets/token_bucket.py +2 -2
  11. edsl/buckets/token_bucket_client.py +5 -3
  12. edsl/caching/cache.py +131 -62
  13. edsl/caching/cache_entry.py +70 -58
  14. edsl/caching/sql_dict.py +17 -0
  15. edsl/cli.py +99 -0
  16. edsl/config/config_class.py +16 -0
  17. edsl/conversation/__init__.py +31 -0
  18. edsl/coop/coop.py +276 -242
  19. edsl/coop/coop_jobs_objects.py +59 -0
  20. edsl/coop/coop_objects.py +29 -0
  21. edsl/coop/coop_regular_objects.py +26 -0
  22. edsl/coop/utils.py +24 -19
  23. edsl/dataset/dataset.py +338 -101
  24. edsl/db_list/sqlite_list.py +349 -0
  25. edsl/inference_services/__init__.py +40 -5
  26. edsl/inference_services/exceptions.py +11 -0
  27. edsl/inference_services/services/anthropic_service.py +5 -2
  28. edsl/inference_services/services/aws_bedrock.py +6 -2
  29. edsl/inference_services/services/azure_ai.py +6 -2
  30. edsl/inference_services/services/google_service.py +3 -2
  31. edsl/inference_services/services/mistral_ai_service.py +6 -2
  32. edsl/inference_services/services/open_ai_service.py +6 -2
  33. edsl/inference_services/services/perplexity_service.py +6 -2
  34. edsl/inference_services/services/test_service.py +105 -7
  35. edsl/interviews/answering_function.py +167 -59
  36. edsl/interviews/interview.py +124 -72
  37. edsl/interviews/interview_task_manager.py +10 -0
  38. edsl/invigilators/invigilators.py +10 -1
  39. edsl/jobs/async_interview_runner.py +146 -104
  40. edsl/jobs/data_structures.py +6 -4
  41. edsl/jobs/decorators.py +61 -0
  42. edsl/jobs/fetch_invigilator.py +61 -18
  43. edsl/jobs/html_table_job_logger.py +14 -2
  44. edsl/jobs/jobs.py +180 -104
  45. edsl/jobs/jobs_component_constructor.py +2 -2
  46. edsl/jobs/jobs_interview_constructor.py +2 -0
  47. edsl/jobs/jobs_pricing_estimation.py +127 -46
  48. edsl/jobs/jobs_remote_inference_logger.py +4 -0
  49. edsl/jobs/jobs_runner_status.py +30 -25
  50. edsl/jobs/progress_bar_manager.py +79 -0
  51. edsl/jobs/remote_inference.py +35 -1
  52. edsl/key_management/key_lookup_builder.py +6 -1
  53. edsl/language_models/language_model.py +102 -12
  54. edsl/language_models/model.py +10 -3
  55. edsl/language_models/price_manager.py +45 -75
  56. edsl/language_models/registry.py +5 -0
  57. edsl/language_models/utilities.py +2 -1
  58. edsl/notebooks/notebook.py +77 -10
  59. edsl/questions/VALIDATION_README.md +134 -0
  60. edsl/questions/__init__.py +24 -1
  61. edsl/questions/exceptions.py +21 -0
  62. edsl/questions/question_check_box.py +171 -149
  63. edsl/questions/question_dict.py +243 -51
  64. edsl/questions/question_multiple_choice_with_other.py +624 -0
  65. edsl/questions/question_registry.py +2 -1
  66. edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
  67. edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
  68. edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
  69. edsl/questions/validation_analysis.py +185 -0
  70. edsl/questions/validation_cli.py +131 -0
  71. edsl/questions/validation_html_report.py +404 -0
  72. edsl/questions/validation_logger.py +136 -0
  73. edsl/results/result.py +63 -16
  74. edsl/results/results.py +702 -171
  75. edsl/scenarios/construct_download_link.py +16 -3
  76. edsl/scenarios/directory_scanner.py +226 -226
  77. edsl/scenarios/file_methods.py +5 -0
  78. edsl/scenarios/file_store.py +117 -6
  79. edsl/scenarios/handlers/__init__.py +5 -1
  80. edsl/scenarios/handlers/mp4_file_store.py +104 -0
  81. edsl/scenarios/handlers/webm_file_store.py +104 -0
  82. edsl/scenarios/scenario.py +120 -101
  83. edsl/scenarios/scenario_list.py +800 -727
  84. edsl/scenarios/scenario_list_gc_test.py +146 -0
  85. edsl/scenarios/scenario_list_memory_test.py +214 -0
  86. edsl/scenarios/scenario_list_source_refactor.md +35 -0
  87. edsl/scenarios/scenario_selector.py +5 -4
  88. edsl/scenarios/scenario_source.py +1990 -0
  89. edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
  90. edsl/surveys/survey.py +22 -0
  91. edsl/tasks/__init__.py +4 -2
  92. edsl/tasks/task_history.py +198 -36
  93. edsl/tests/scenarios/test_ScenarioSource.py +51 -0
  94. edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
  95. edsl/utilities/__init__.py +2 -1
  96. edsl/utilities/decorators.py +121 -0
  97. edsl/utilities/memory_debugger.py +1010 -0
  98. {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/METADATA +52 -76
  99. {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/RECORD +102 -78
  100. edsl/jobs/jobs_runner_asyncio.py +0 -281
  101. edsl/language_models/unused/fake_openai_service.py +0 -60
  102. {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/LICENSE +0 -0
  103. {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/WHEEL +0 -0
  104. {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/entry_points.txt +0 -0
--- a/edsl/language_models/language_model.py
+++ b/edsl/language_models/language_model.py
@@ -365,6 +365,59 @@ class LanguageModel(
         self._api_token = info.api_token
         return self._api_token
 
+    def copy(self) -> "LanguageModel":
+        """Create a deep copy of this language model instance.
+
+        This method creates a completely independent copy of the language model
+        by creating a new instance with the same parameters and copying relevant attributes.
+
+        Returns:
+            LanguageModel: A new language model instance that is functionally identical to this one
+
+        Examples:
+            >>> m1 = LanguageModel.example()
+            >>> m2 = m1.copy()
+            >>> m1 == m2  # Functionally equivalent
+            True
+            >>> id(m1) == id(m2)  # But different objects
+            False
+        """
+        # Create a new instance of the same class with the same parameters
+        try:
+            # For most models, we can instantiate with the saved parameters
+            new_model = self.__class__(**self.parameters)
+
+            # Copy all important instance attributes
+            for key, value in self.__dict__.items():
+                if key not in ("_api_token",) and not key.startswith("__"):
+                    setattr(new_model, key, value)
+
+            return new_model
+        except Exception:
+            # Fallback for dynamically created classes like TestServiceLanguageModel
+            from ..inference_services import default
+
+            # If this is a test model, create a new test model instance
+            if getattr(self, "_inference_service_", "") == "test":
+                service = default.get_service("test")
+                model_class = service.create_model("test")
+                new_model = model_class(**self.parameters)
+
+                # Copy attributes
+                for key, value in self.__dict__.items():
+                    if key not in ("_api_token",) and not key.startswith("__"):
+                        setattr(new_model, key, value)
+
+                return new_model
+
+            # If we can't create the model directly, just return a simple test model
+            # This is a last resort fallback
+            from ..inference_services import get_service
+
+            service = get_service("test")
+            model_class = service.create_model("test")
+            return model_class()
+
     def __getitem__(self, key):
         """Allow dictionary-style access to model attributes.
 
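A minimal sketch of the semantics the new copy() method guarantees, taken from its doctest (the import path is assumed from the package layout):

    from edsl.language_models import LanguageModel

    m1 = LanguageModel.example()
    m2 = m1.copy()

    assert m1 == m2        # same parameters, functionally equivalent
    assert m1 is not m2    # but a distinct object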
@@ -509,7 +562,9 @@ class LanguageModel(
         return self.execute_model_call(user_prompt, system_prompt)
 
     @abstractmethod
-    async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
+    async def async_execute_model_call(
+        self, user_prompt: str, system_prompt: str, question_name: Optional[str] = None
+    ):
         """Execute the model call asynchronously.
 
         This abstract method must be implemented by all model subclasses
@@ -518,6 +573,7 @@ class LanguageModel(
         Args:
             user_prompt: The user message or input prompt
             system_prompt: The system message or context
+            question_name: Optional name of the question being asked (primarily used for test models)
 
         Returns:
             Coroutine that resolves to the model response
@@ -529,7 +585,7 @@ class LanguageModel(
         pass
 
     async def remote_async_execute_model_call(
-        self, user_prompt: str, system_prompt: str
+        self, user_prompt: str, system_prompt: str, question_name: Optional[str] = None
     ):
         """Execute the model call remotely through the EDSL Coop service.
 
@@ -540,6 +596,7 @@ class LanguageModel(
         Args:
             user_prompt: The user message or input prompt
             system_prompt: The system message or context
+            question_name: Optional name of the question being asked (primarily used for test models)
 
         Returns:
             Coroutine that resolves to the model response from the remote service
@@ -563,6 +620,7 @@ class LanguageModel(
         Args:
             *args: Positional arguments to pass to async_execute_model_call
             **kwargs: Keyword arguments to pass to async_execute_model_call
+                Can include question_name for test models
 
         Returns:
             The model response
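The new keyword threads through the whole call chain. A minimal sketch of a direct call, assuming an async context and a concrete model instance; the question name itself is hypothetical:

    response = await model.async_execute_model_call(
        user_prompt="What is 2 + 2?",
        system_prompt="You are a helpful assistant.",
        question_name="arithmetic_check",  # hypothetical; only the test service consumes it
    )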
@@ -674,9 +732,12 @@ class LanguageModel(
         user_prompt_with_hashes = user_prompt
 
         # Prepare parameters for cache lookup
+        cache_parameters = self.parameters.copy()
+        if self.model == "test":
+            cache_parameters.pop("canned_response", None)
         cache_call_params = {
             "model": str(self.model),
-            "parameters": self.parameters,
+            "parameters": cache_parameters,
             "system_prompt": system_prompt,
             "user_prompt": user_prompt_with_hashes,
             "iteration": iteration,
@@ -702,7 +763,9 @@ class LanguageModel(
             "system_prompt": system_prompt,
             "files_list": files_list,
         }
-
+        # Add question_name parameter for test models
+        if self.model == "test" and invigilator:
+            params["question_name"] = invigilator.question.question_name
         # Get timeout from configuration
         from ..config import CONFIG
 
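Dropping canned_response from the cache-key parameters means test models that differ only in their canned answer still produce the same cache key. A toy illustration of the normalization step above:

    parameters = {"temperature": 0.5, "canned_response": "Hello!"}

    cache_parameters = parameters.copy()
    cache_parameters.pop("canned_response", None)  # identical for any canned answer
    print(cache_parameters)  # {'temperature': 0.5}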
@@ -710,7 +773,6 @@ class LanguageModel(
 
         # Execute the model call with timeout
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
-
         # Store the response in the cache
         new_cache_key = cache.store(
             **cache_call_params, response=response, service=self._inference_service_
@@ -801,7 +863,6 @@ class LanguageModel(
 
         # Create structured input record
         model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
-
         # Get model response (using cache if available)
         model_outputs: ModelResponse = (
             await self._async_get_intended_model_call_outcome(**params)
@@ -839,7 +900,7 @@ class LanguageModel(
         # Use the price manager to calculate the actual cost
         from .price_manager import PriceManager
 
-        price_manager = PriceManager()
+        price_manager = PriceManager.get_instance()
 
         return price_manager.calculate_cost(
             inference_service=self._inference_service_,
@@ -868,9 +929,15 @@ class LanguageModel(
         {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'inference_service': 'openai', 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
         """
         # Build the base dictionary with essential model information
+        parameters = self.parameters.copy()
+
+        # For test models, ensure canned_response is included in serialization
+        if self.model == "test" and hasattr(self, "canned_response"):
+            parameters["canned_response"] = self.canned_response
+
         d = {
             "model": self.model,
-            "parameters": self.parameters,
+            "parameters": parameters,
             "inference_service": self._inference_service_,
         }
 
@@ -908,7 +975,25 @@ class LanguageModel(
             data["model"], service_name=data.get("inference_service", None)
         )
 
-        # Create and return a new instance
+        # Handle canned_response in parameters for test models
+        if (
+            data["model"] == "test"
+            and "parameters" in data
+            and "canned_response" in data["parameters"]
+        ):
+            # Extract canned_response from parameters to set as a direct attribute
+            canned_response = data["parameters"]["canned_response"]
+            params_copy = data.copy()
+
+            # Direct attribute will be set during initialization
+            # Add it as a top-level parameter for model initialization
+            if isinstance(params_copy, dict) and "parameters" in params_copy:
+                params_copy["canned_response"] = canned_response
+
+            # Create the instance with canned_response as a direct parameter
+            return model_class(**params_copy)
+
+        # For non-test models or test models without canned_response
         return model_class(**data)
 
     def __repr__(self) -> str:
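Together with the to_dict() change above, this lets test models round-trip their canned response through serialization. A sketch; attribute access on the deserialized model is an assumption based on the hasattr check in to_dict:

    from edsl.language_models import LanguageModel

    m = LanguageModel.example(test_model=True, canned_response="Hello!")
    d = m.to_dict()                  # canned_response lands in d["parameters"]
    m2 = LanguageModel.from_dict(d)  # hoisted back out as a constructor argument
    assert m2.canned_response == "Hello!"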
@@ -994,8 +1079,8 @@ class LanguageModel(
 
         Create a test model that throws exceptions:
 
-        >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)
-        >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)
+        >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)  # doctest: +SKIP
+        >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)  # doctest: +SKIP
         Exception report saved to ...
         """
         from ..language_models import Model
@@ -1046,7 +1131,12 @@ class LanguageModel(
         ]
 
         # Define a new async_execute_model_call that only reads from cache
-        async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
+        async def async_execute_model_call(
+            self,
+            user_prompt: str,
+            system_prompt: str,
+            question_name: Optional[str] = None,
+        ):
             """Only use cached responses, never making new API calls."""
             cache_call_params = {
                 "model": str(self.model),
--- a/edsl/language_models/model.py
+++ b/edsl/language_models/model.py
@@ -6,8 +6,10 @@ from ..utilities import PrettyList
 from ..config import CONFIG
 from .exceptions import LanguageModelValueError
 
+# Import only what's needed initially to avoid circular imports
 from ..inference_services import (InferenceServicesCollection,
-    AvailableModels, InferenceServiceABC, InferenceServiceError, default)
+    AvailableModels, InferenceServiceABC, InferenceServiceError)
+# The 'default' import will be imported lazily when needed
 
 from ..enums import InferenceServiceLiteral
 
@@ -20,7 +22,10 @@ def get_model_class(
     registry: Optional[InferenceServicesCollection] = None,
     service_name: Optional[InferenceServiceLiteral] = None,
 ):
-    registry = registry or default
+    if registry is None:
+        # Import default lazily only when needed
+        from ..inference_services import default as inference_default
+        registry = inference_default
     try:
         factory = registry.create_model_factory(model_name, service_name=service_name)
         return factory
@@ -54,7 +59,9 @@ class Model(metaclass=Meta):
     def get_registry(cls) -> InferenceServicesCollection:
         """Get the current registry or initialize with default if None"""
         if cls._registry is None:
-            cls._registry = default
+            # Import default lazily only when needed
+            from ..inference_services import default as inference_default
+            cls._registry = inference_default
         return cls._registry
 
     @classmethod
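Both call sites use the same lazy-import pattern: the module-level import of default is deferred into the function body, so importing this module no longer participates in the circular import with edsl.inference_services. In isolation the pattern looks like this (illustrative sketch only):

    def get_default_registry():
        # Deferred import: resolved at call time, after both modules exist
        from edsl.inference_services import default as inference_default
        return inference_default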
--- a/edsl/language_models/price_manager.py
+++ b/edsl/language_models/price_manager.py
@@ -8,20 +8,42 @@ class PriceManager:
 
     def __new__(cls):
         if cls._instance is None:
-            cls._instance = super(PriceManager, cls).__new__(cls)
+            instance = super(PriceManager, cls).__new__(cls)
+            instance._price_lookup = {}  # Instance-specific attribute
+            instance._is_initialized = False
+            cls._instance = instance  # Store the instance directly
+            return instance
         return cls._instance
 
     def __init__(self):
-        # Only initialize once, even if __init__ is called multiple times
+        """Initialize the singleton instance only once."""
         if not self._is_initialized:
             self._is_initialized = True
             self.refresh_prices()
 
-    def refresh_prices(self) -> None:
-        """
-        Fetch fresh prices from the Coop service and update the internal price lookup.
+    @classmethod
+    def get_instance(cls):
+        """Get the singleton instance, creating it if necessary."""
+        if cls._instance is None:
+            cls()  # Create the instance if it doesn't exist
+        return cls._instance
+
+    @classmethod
+    def reset(cls):
+        """Reset the singleton instance to clean up resources."""
+        cls._instance = None
+        cls._is_initialized = False
+        cls._price_lookup = {}
 
-        """
+    def __del__(self):
+        """Ensure proper cleanup when the instance is garbage collected."""
+        try:
+            self._price_lookup = {}  # Clean up resources
+        except:
+            pass  # Ignore any cleanup errors
+
+    def refresh_prices(self) -> None:
+        """Fetch fresh prices and update the internal price lookup."""
         from edsl.coop import Coop
 
         c = Coop()
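A sketch of the access pattern this enables; note that the first construction calls refresh_prices(), which fetches prices from Coop:

    from edsl.language_models.price_manager import PriceManager

    a = PriceManager.get_instance()  # creates and initializes the singleton
    b = PriceManager.get_instance()  # returns the same object
    assert a is b

    PriceManager.reset()             # drops the instance and its price lookup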
@@ -31,43 +53,18 @@ class PriceManager:
             print(f"Error fetching prices: {str(e)}")
 
     def get_price(self, inference_service: str, model: str) -> Dict:
-        """
-        Get the price information for a specific service and model combination.
-        If no specific price is found, returns a fallback price.
-
-        Args:
-            inference_service (str): The name of the inference service
-            model (str): The model identifier
-
-        Returns:
-            Dict: Price information (either actual or fallback prices)
-        """
+        """Get the price information for a specific service and model."""
         key = (inference_service, model)
         return self._price_lookup.get(key) or self._get_fallback_price(
             inference_service
         )
 
     def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
-        """
-        Get the complete price lookup dictionary.
-
-        Returns:
-            Dict[Tuple[str, str], Dict]: The complete price lookup dictionary
-        """
+        """Get the complete price lookup dictionary."""
         return self._price_lookup.copy()
 
     def _get_fallback_price(self, inference_service: str) -> Dict:
-        """
-        Get fallback prices for a service.
-        - First fallback: The highest input and output prices for that service from the price lookup.
-        - Second fallback: $1.00 per million tokens (for both input and output).
-
-        Args:
-            inference_service (str): The inference service name
-
-        Returns:
-            Dict: Price information
-        """
+        """Get fallback prices for a service."""
         service_prices = [
             prices
             for (service, _), prices in self._price_lookup.items()
@@ -77,18 +74,12 @@ class PriceManager:
         input_tokens_per_usd = [
             float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
         ]
-        if input_tokens_per_usd:
-            min_input_tokens = min(input_tokens_per_usd)
-        else:
-            min_input_tokens = 1_000_000
+        min_input_tokens = min(input_tokens_per_usd, default=1_000_000)
 
         output_tokens_per_usd = [
            float(p["output"]["one_usd_buys"]) for p in service_prices if "output" in p
         ]
-        if output_tokens_per_usd:
-            min_output_tokens = min(output_tokens_per_usd)
-        else:
-            min_output_tokens = 1_000_000
+        min_output_tokens = min(output_tokens_per_usd, default=1_000_000)
 
         return {
             "input": {"one_usd_buys": min_input_tokens},
@@ -103,19 +94,7 @@ class PriceManager:
         input_token_name: str,
         output_token_name: str,
     ) -> Union[float, str]:
-        """
-        Calculate the total cost for a model usage based on input and output tokens.
-
-        Args:
-            inference_service (str): The inference service identifier
-            model (str): The model identifier
-            usage (Dict[str, Union[str, int]]): Dictionary containing token usage information
-            input_token_name (str): Key name for input tokens in the usage dict
-            output_token_name (str): Key name for output tokens in the usage dict
-
-        Returns:
-            Union[float, str]: Total cost if calculation successful, error message string if not
-        """
+        """Calculate the total cost for a model usage."""
         relevant_prices = self.get_price(inference_service, model)
 
         # Extract token counts
@@ -137,31 +116,22 @@ class PriceManager:
             return f"Could not fetch prices from {relevant_prices} - {e}"
 
         # Calculate input cost
-        if inverse_input_price == "infinity":
-            input_cost = 0
-        else:
-            try:
-                input_cost = input_tokens / float(inverse_input_price)
-            except Exception as e:
-                return f"Could not compute input price - {e}."
+        input_cost = (
+            0
+            if inverse_input_price == "infinity"
+            else input_tokens / float(inverse_input_price)
+        )
 
         # Calculate output cost
-        if inverse_output_price == "infinity":
-            output_cost = 0
-        else:
-            try:
-                output_cost = output_tokens / float(inverse_output_price)
-            except Exception as e:
-                return f"Could not compute output price - {e}"
+        output_cost = (
+            0
+            if inverse_output_price == "infinity"
+            else output_tokens / float(inverse_output_price)
+        )
 
         return input_cost + output_cost
 
     @property
     def is_initialized(self) -> bool:
-        """
-        Check if the PriceManager has been initialized.
-
-        Returns:
-            bool: True if initialized, False otherwise
-        """
+        """Check if the PriceManager has been initialized."""
         return self._is_initialized
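A worked example of the cost arithmetic, with hypothetical prices: one_usd_buys is the number of tokens one USD buys, so cost = tokens / rate.

    input_tokens, output_tokens = 1_000, 2_000
    inverse_input_price, inverse_output_price = 1_000_000.0, 500_000.0

    input_cost = input_tokens / inverse_input_price     # 0.001 USD
    output_cost = output_tokens / inverse_output_price  # 0.004 USD
    print(input_cost + output_cost)                     # 0.005 USD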
--- a/edsl/language_models/registry.py
+++ b/edsl/language_models/registry.py
@@ -18,6 +18,11 @@ class RegisterLanguageModelsMeta(ABCMeta):
 
     _registry = {}  # Initialize the registry as a dictionary
     REQUIRED_CLASS_ATTRIBUTES = ["_model_", "_parameters_", "_inference_service_"]
+
+    @classmethod
+    def clear_registry(cls):
+        """Clear the registry to prevent memory leaks."""
+        cls._registry = {}
 
     def __init__(cls, name, bases, dct):
         """Register the class in the registry if it has a _model_ attribute."""
--- a/edsl/language_models/utilities.py
+++ b/edsl/language_models/utilities.py
@@ -5,6 +5,7 @@ from ..surveys import Survey
 
 from .language_model import LanguageModel
 
+
 def create_survey(num_questions: int, chained: bool = True, take_scenario=False):
     from ..questions import QuestionFreeText
 
@@ -28,7 +29,6 @@ def create_language_model(
 def create_language_model(
     exception: Exception, fail_at_number: int, never_ending=False
 ):
-
     class LanguageModelFromUtilities(LanguageModel):
         _model_ = "test"
         _parameters_ = {"temperature": 0.5}
@@ -45,6 +45,7 @@ def create_language_model(
             user_prompt: str,
             system_prompt: str,
             files_list: Optional[List[Any]] = None,
+            question_name: Optional[str] = None,
         ) -> dict[str, Any]:
             question_number = int(
                 user_prompt.split("XX")[1]
--- a/edsl/notebooks/notebook.py
+++ b/edsl/notebooks/notebook.py
@@ -2,6 +2,10 @@
 
 from __future__ import annotations
 import json
+import subprocess
+import tempfile
+import os
+import shutil
 from typing import Dict, List, Optional, TYPE_CHECKING
 
 if TYPE_CHECKING:
@@ -17,12 +21,56 @@ class Notebook(Base):
     """
 
     default_name = "notebook"
+
+    @staticmethod
+    def _lint_code(code: str) -> str:
+        """
+        Lint Python code using ruff.
+
+        :param code: The Python code to lint
+        :return: The linted code
+        """
+        try:
+            # Check if ruff is installed
+            if shutil.which("ruff") is None:
+                # If ruff is not installed, return original code
+                return code
+
+            with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as temp_file:
+                temp_file.write(code)
+                temp_file_path = temp_file.name
+
+            # Run ruff to format the code
+            try:
+                result = subprocess.run(
+                    ["ruff", "format", temp_file_path],
+                    check=True,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE
+                )
+
+                # Read the formatted code
+                with open(temp_file_path, 'r') as f:
+                    linted_code = f.read()
+
+                return linted_code
+            except subprocess.CalledProcessError:
+                # If ruff fails, return the original code
+                return code
+        except FileNotFoundError:
+            # If ruff is not installed, return the original code
+            return code
+        finally:
+            # Clean up temporary file
+            if 'temp_file_path' in locals() and os.path.exists(temp_file_path):
+                os.unlink(temp_file_path)
 
     def __init__(
         self,
         path: Optional[str] = None,
         data: Optional[Dict] = None,
         name: Optional[str] = None,
+        lint: bool = True,
     ):
         """
         Initialize a new Notebook.
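A sketch of what _lint_code does when ruff is on PATH; the output shown is approximate (ruff format applies Black-style formatting), and the code is returned unchanged when ruff is missing or fails:

    messy = "x=1;y=2\nprint( x+y )\n"
    clean = Notebook._lint_code(messy)  # writes a temp file, runs `ruff format`
    print(clean)                        # roughly: "x = 1\ny = 2\nprint(x + y)\n"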
@@ -32,6 +80,7 @@ class Notebook(Base):
         :param path: A filepath from which to load the notebook.
             If no path is provided, assume this code is run in a notebook and try to load the current notebook from file.
         :param name: A name for the Notebook.
+        :param lint: Whether to lint Python code cells using ruff. Defaults to True.
         """
         import nbformat
 
@@ -54,6 +103,16 @@ class Notebook(Base):
             raise NotebookEnvironmentError(
                 "Cannot create a notebook from within itself in this development environment"
             )
+
+        # Store the lint parameter
+        self.lint = lint
+
+        # Apply linting to code cells if enabled
+        if self.lint and self.data and "cells" in self.data:
+            for cell in self.data["cells"]:
+                if cell.get("cell_type") == "code" and "source" in cell:
+                    # Only lint Python code cells
+                    cell["source"] = self._lint_code(cell["source"])
 
         # TODO: perhaps add sanity check function
         # 1. could check if the notebook is a valid notebook
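A sketch of the new flag at a construction site that exposes it (example() is shown further below; lint=True is the default everywhere):

    from edsl import Notebook

    nb = Notebook.example(lint=True)    # code cells pass through ruff format
    raw = Notebook.example(lint=False)  # cells kept exactly as authored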
@@ -63,7 +122,7 @@ class Notebook(Base):
         self.name = name or self.default_name
 
     @classmethod
-    def from_script(cls, path: str, name: Optional[str] = None) -> "Notebook":
+    def from_script(cls, path: str, name: Optional[str] = None, lint: bool = True) -> "Notebook":
         import nbformat
 
         # Read the script file
@@ -78,12 +137,12 @@ class Notebook(Base):
         nb.cells.append(first_cell)
 
         # Create a Notebook instance with the notebook data
-        notebook_instance = cls(nb)
+        notebook_instance = cls(data=nb, name=name, lint=lint)
 
         return notebook_instance
 
     @classmethod
-    def from_current_script(cls) -> "Notebook":
+    def from_current_script(cls, lint: bool = True) -> "Notebook":
         import inspect
         import os
 
@@ -93,7 +152,7 @@ class Notebook(Base):
         current_file_path = os.path.abspath(caller_frame[1].filename)
 
         # Use from_script to create the notebook
-        return cls.from_script(current_file_path)
+        return cls.from_script(current_file_path, lint=lint)
 
     def __eq__(self, other):
         """
@@ -114,7 +173,7 @@ class Notebook(Base):
         """
         Serialize to a dictionary.
         """
-        d = {"name": self.name, "data": self.data}
+        d = {"name": self.name, "data": self.data, "lint": self.lint}
         if add_edsl_version:
             from .. import __version__
 
@@ -124,11 +183,17 @@ class Notebook(Base):
 
     @classmethod
     @remove_edsl_version
-    def from_dict(cls, d: Dict) -> "Notebook":
+    def from_dict(cls, d: Dict, lint: bool = None) -> "Notebook":
         """
         Convert a dictionary representation of a Notebook to a Notebook object.
+
+        :param d: Dictionary containing notebook data and name
+        :param lint: Whether to lint Python code cells. If None, uses the value from the dictionary or defaults to True.
+        :return: A new Notebook instance
         """
-        return cls(data=d["data"], name=d["name"])
+        # Use the lint parameter from the dictionary if none is provided, otherwise default to True
+        notebook_lint = lint if lint is not None else d.get("lint", True)
+        return cls(data=d["data"], name=d["name"], lint=notebook_lint)
 
     def to_file(self, path: str):
         """
@@ -205,11 +270,13 @@ class Notebook(Base):
         return table
 
     @classmethod
-    def example(cls, randomize: bool = False) -> Notebook:
+    def example(cls, randomize: bool = False, lint: bool = True) -> Notebook:
         """
         Returns an example Notebook instance.
 
         :param randomize: If True, adds a random string one of the cells' output.
+        :param lint: Whether to lint Python code cells. Defaults to True.
+        :return: An example Notebook instance
         """
         addition = "" if not randomize else str(uuid4())
         cells = [
@@ -238,7 +305,7 @@ class Notebook(Base):
             "nbformat_minor": 4,
             "cells": cells,
         }
-        return cls(data=data)
+        return cls(data=data, lint=lint)
 
     def code(self) -> List[str]:
         """
@@ -246,7 +313,7 @@ class Notebook(Base):
         """
         lines = []
         lines.append("from edsl import Notebook")  # Keep as absolute for code generation
-        lines.append(f'nb = Notebook(data={self.data}, name="""{self.name}""")')
+        lines.append(f'nb = Notebook(data={self.data}, name="""{self.name}""", lint={self.lint})')
         return lines
 
     def to_latex(self, filename: str):