edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. edsl/Base.py +28 -0
  2. edsl/__init__.py +1 -1
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +8 -16
  5. edsl/agents/Invigilator.py +13 -14
  6. edsl/agents/InvigilatorBase.py +4 -1
  7. edsl/agents/PromptConstructor.py +42 -22
  8. edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
  9. edsl/auto/AutoStudy.py +18 -5
  10. edsl/auto/StageBase.py +53 -40
  11. edsl/auto/StageQuestions.py +2 -1
  12. edsl/auto/utilities.py +0 -6
  13. edsl/coop/coop.py +21 -5
  14. edsl/data/Cache.py +29 -18
  15. edsl/data/CacheHandler.py +0 -2
  16. edsl/data/RemoteCacheSync.py +154 -46
  17. edsl/data/hack.py +10 -0
  18. edsl/enums.py +7 -0
  19. edsl/inference_services/AnthropicService.py +38 -16
  20. edsl/inference_services/AvailableModelFetcher.py +7 -1
  21. edsl/inference_services/GoogleService.py +5 -1
  22. edsl/inference_services/InferenceServicesCollection.py +18 -2
  23. edsl/inference_services/OpenAIService.py +46 -31
  24. edsl/inference_services/TestService.py +1 -3
  25. edsl/inference_services/TogetherAIService.py +5 -3
  26. edsl/inference_services/data_structures.py +74 -2
  27. edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
  28. edsl/jobs/FetchInvigilator.py +10 -3
  29. edsl/jobs/InterviewsConstructor.py +6 -4
  30. edsl/jobs/Jobs.py +299 -233
  31. edsl/jobs/JobsChecks.py +2 -2
  32. edsl/jobs/JobsPrompts.py +1 -1
  33. edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
  34. edsl/jobs/async_interview_runner.py +138 -0
  35. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  36. edsl/jobs/data_structures.py +120 -0
  37. edsl/jobs/interviews/Interview.py +80 -42
  38. edsl/jobs/results_exceptions_handler.py +98 -0
  39. edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
  40. edsl/jobs/runners/JobsRunnerStatus.py +131 -164
  41. edsl/jobs/tasks/TaskHistory.py +24 -3
  42. edsl/language_models/LanguageModel.py +59 -4
  43. edsl/language_models/ModelList.py +19 -8
  44. edsl/language_models/__init__.py +1 -1
  45. edsl/language_models/model.py +256 -0
  46. edsl/language_models/repair.py +1 -1
  47. edsl/questions/QuestionBase.py +35 -26
  48. edsl/questions/QuestionBasePromptsMixin.py +1 -1
  49. edsl/questions/QuestionBudget.py +1 -1
  50. edsl/questions/QuestionCheckBox.py +2 -2
  51. edsl/questions/QuestionExtract.py +5 -7
  52. edsl/questions/QuestionFreeText.py +1 -1
  53. edsl/questions/QuestionList.py +9 -15
  54. edsl/questions/QuestionMatrix.py +1 -1
  55. edsl/questions/QuestionMultipleChoice.py +1 -1
  56. edsl/questions/QuestionNumerical.py +1 -1
  57. edsl/questions/QuestionRank.py +1 -1
  58. edsl/questions/SimpleAskMixin.py +1 -1
  59. edsl/questions/__init__.py +1 -1
  60. edsl/questions/data_structures.py +20 -0
  61. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
  62. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
  63. edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
  64. edsl/results/DatasetExportMixin.py +60 -119
  65. edsl/results/Result.py +109 -3
  66. edsl/results/Results.py +50 -39
  67. edsl/results/file_exports.py +252 -0
  68. edsl/scenarios/ScenarioList.py +35 -7
  69. edsl/surveys/Survey.py +71 -20
  70. edsl/test_h +1 -0
  71. edsl/utilities/gcp_bucket/example.py +50 -0
  72. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
  73. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
  74. edsl/language_models/registry.py +0 -180
  75. /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
  76. /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
  77. /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
  78. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  79. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  80. /edsl/results/{Selector.py → results_selector.py} +0 -0
  81. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  82. /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
  83. /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
  84. /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
  85. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
  86. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
edsl/results/Results.py CHANGED
@@ -9,6 +9,8 @@ import random
9
9
  from collections import UserList, defaultdict
10
10
  from typing import Optional, Callable, Any, Type, Union, List, TYPE_CHECKING
11
11
 
12
+ from bisect import bisect_left
13
+
12
14
  from edsl.Base import Base
13
15
  from edsl.exceptions.results import (
14
16
  ResultsError,
@@ -24,7 +26,7 @@ if TYPE_CHECKING:
24
26
  from edsl.surveys.Survey import Survey
25
27
  from edsl.data.Cache import Cache
26
28
  from edsl.agents.AgentList import AgentList
27
- from edsl.language_models.registry import Model
29
+ from edsl.language_models.model import Model
28
30
  from edsl.scenarios.ScenarioList import ScenarioList
29
31
  from edsl.results.Result import Result
30
32
  from edsl.jobs.tasks.TaskHistory import TaskHistory
@@ -33,7 +35,7 @@ if TYPE_CHECKING:
33
35
 
34
36
  from edsl.results.ResultsExportMixin import ResultsExportMixin
35
37
  from edsl.results.ResultsGGMixin import ResultsGGMixin
36
- from edsl.results.ResultsFetchMixin import ResultsFetchMixin
38
+ from edsl.results.results_fetch_mixin import ResultsFetchMixin
37
39
  from edsl.utilities.remove_edsl_version import remove_edsl_version
38
40
 
39
41
 
@@ -136,7 +138,33 @@ class Results(UserList, Mixins, Base):
136
138
  }
137
139
  return d
138
140
 
139
- def compute_job_cost(self, include_cached_responses_in_cost=False) -> float:
141
def insert(self, item):
    """Insert *item*, keeping items that carry an ``order`` sorted by it.

    Items whose ``order`` attribute is missing or ``None`` are appended to
    the end. Otherwise the item is placed before the first existing ordered
    item whose ``order`` is >= ``item.order`` (``bisect_left`` semantics, so
    it lands ahead of existing items with an equal order). If no such item
    exists, it goes immediately after the last ordered item, or at index 0
    when there are no ordered items yet. Unordered items keep their
    relative positions.

    NOTE(review): this shadows ``UserList.insert(i, item)`` with a different
    signature, so ``results.insert(0, x)`` raises ``TypeError``.
    """
    item_order = getattr(item, "order", None)
    if item_order is None:
        self.data.append(item)
        return

    # Record where each ordered item lives in self.data together with its
    # order, so the bisect rank (computed over orders only) can be mapped
    # back to a real index even when unordered items are interleaved.
    positions = []
    orders = []
    for position, existing in enumerate(self.data):
        existing_order = getattr(existing, "order", None)
        if existing_order is not None:
            positions.append(position)
            orders.append(existing_order)

    rank = bisect_left(orders, item_order)
    if rank < len(positions):
        index = positions[rank]
    elif positions:
        index = positions[-1] + 1
    else:
        index = 0
    self.data.insert(index, item)
159
+
160
def append(self, item):
    """Add *item*; routed through ``insert`` so its ``order`` is honoured."""
    return self.insert(item)
162
+
163
def extend(self, other):
    """Insert every element of *other* through ``insert``.

    *other* is copied into a list before any insertion. Without the copy,
    extending a collection with itself (``results.extend(results)``) would
    iterate over a list that keeps growing and never terminate.
    """
    for item in list(other):
        self.insert(item)
166
+
167
+ def compute_job_cost(self, include_cached_responses_in_cost: bool = False) -> float:
140
168
  """
141
169
  Computes the cost of a completed job in USD.
142
170
  """
@@ -250,24 +278,6 @@ class Results(UserList, Mixins, Base):
250
278
 
251
279
  raise TypeError("Invalid argument type")
252
280
 
253
- # def _update_results(self) -> None:
254
- # from edsl import Agent, Scenario
255
- # from edsl.language_models import LanguageModel
256
- # from edsl.results import Result
257
-
258
- # if self._job_uuid and len(self.data) < self._total_results:
259
- # results = [
260
- # Result(
261
- # agent=Agent.from_dict(json.loads(r.agent)),
262
- # scenario=Scenario.from_dict(json.loads(r.scenario)),
263
- # model=LanguageModel.from_dict(json.loads(r.model)),
264
- # iteration=1,
265
- # answer=json.loads(r.answer),
266
- # )
267
- # for r in CRUD.read_results(self._job_uuid)
268
- # ]
269
- # self.data = results
270
-
271
281
  def __add__(self, other: Results) -> Results:
272
282
  """Add two Results objects together.
273
283
  They must have the same survey and created columns.
@@ -295,13 +305,10 @@ class Results(UserList, Mixins, Base):
295
305
  )
296
306
 
297
307
  def __repr__(self) -> str:
298
- # import reprlib
299
-
300
308
  return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
301
309
 
302
310
  def table(
303
311
  self,
304
- # selector_string: Optional[str] = "*.*",
305
312
  *fields,
306
313
  tablefmt: Optional[str] = None,
307
314
  pretty_labels: Optional[dict] = None,
@@ -340,11 +347,11 @@ class Results(UserList, Mixins, Base):
340
347
 
341
348
  def to_dict(
342
349
  self,
343
- sort=False,
344
- add_edsl_version=False,
345
- include_cache=False,
346
- include_task_history=False,
347
- include_cache_info=True,
350
+ sort: bool = False,
351
+ add_edsl_version: bool = False,
352
+ include_cache: bool = False,
353
+ include_task_history: bool = False,
354
+ include_cache_info: bool = True,
348
355
  ) -> dict[str, Any]:
349
356
  from edsl.data.Cache import Cache
350
357
 
@@ -386,7 +393,7 @@ class Results(UserList, Mixins, Base):
386
393
 
387
394
  return d
388
395
 
389
- def compare(self, other_results):
396
+ def compare(self, other_results: Results) -> dict:
390
397
  """
391
398
  Compare two Results objects and return the differences.
392
399
  """
@@ -404,7 +411,7 @@ class Results(UserList, Mixins, Base):
404
411
  }
405
412
 
406
413
  @property
407
- def has_unfixed_exceptions(self):
414
+ def has_unfixed_exceptions(self) -> bool:
408
415
  return self.task_history.has_unfixed_exceptions
409
416
 
410
417
  def __hash__(self) -> int:
@@ -487,10 +494,6 @@ class Results(UserList, Mixins, Base):
487
494
  raise ResultsDeserializationError(f"Error in Results.from_dict: {e}")
488
495
  return results
489
496
 
490
- ######################
491
- ## Convenience methods
492
- ## & Report methods
493
- ######################
494
497
  @property
495
498
  def _key_to_data_type(self) -> dict[str, str]:
496
499
  """
@@ -689,13 +692,19 @@ class Results(UserList, Mixins, Base):
689
692
  """
690
693
  return self.data[0]
691
694
 
692
- def answer_truncate(self, column: str, top_n=5, new_var_name=None) -> Results:
695
+ def answer_truncate(
696
+ self, column: str, top_n: int = 5, new_var_name: str = None
697
+ ) -> Results:
693
698
  """Create a new variable that truncates the answers to the top_n.
694
699
 
695
700
  :param column: The column to truncate.
696
701
  :param top_n: The number of top answers to keep.
697
702
  :param new_var_name: The name of the new variable. If None, it is the original name + '_truncated'.
698
703
 
704
+ Example:
705
+ >>> r = Results.example()
706
+ >>> r.answer_truncate('how_feeling', top_n = 2).select('how_feeling', 'how_feeling_truncated')
707
+ Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling_truncated': ['Other', 'Other', 'Other', 'Other']}])
699
708
 
700
709
 
701
710
  """
@@ -916,7 +925,7 @@ class Results(UserList, Mixins, Base):
916
925
  n: Optional[int] = None,
917
926
  frac: Optional[float] = None,
918
927
  with_replacement: bool = True,
919
- seed: Optional[str] = "edsl",
928
+ seed: Optional[str] = None,
920
929
  ) -> Results:
921
930
  """Sample the results.
922
931
 
@@ -931,7 +940,7 @@ class Results(UserList, Mixins, Base):
931
940
  >>> len(r.sample(2))
932
941
  2
933
942
  """
934
- if seed != "edsl":
943
+ if seed:
935
944
  random.seed(seed)
936
945
 
937
946
  if n is None and frac is None:
@@ -969,7 +978,7 @@ class Results(UserList, Mixins, Base):
969
978
  Dataset([{'answer.how_feeling_yesterday': ['Great', 'Good', 'OK', 'Terrible']}])
970
979
  """
971
980
 
972
- from edsl.results.Selector import Selector
981
+ from edsl.results.results_selector import Selector
973
982
 
974
983
  if len(self) == 0:
975
984
  raise Exception("No data to select from---the Results object is empty.")
@@ -984,6 +993,7 @@ class Results(UserList, Mixins, Base):
984
993
  return selector.select(*columns)
985
994
 
986
995
  def sort_by(self, *columns: str, reverse: bool = False) -> Results:
996
+ """Sort the results by one or more columns."""
987
997
  import warnings
988
998
 
989
999
  warnings.warn(
@@ -992,6 +1002,7 @@ class Results(UserList, Mixins, Base):
992
1002
  return self.order_by(*columns, reverse=reverse)
993
1003
 
994
1004
  def _parse_column(self, column: str) -> tuple[str, str]:
1005
+ """Parse a column name into a data type and key."""
995
1006
  if "." in column:
996
1007
  return column.split(".")
997
1008
  return self._key_to_data_type[column], column
edsl/results/file_exports.py ADDED
@@ -0,0 +1,252 @@
1
+ from abc import ABC, abstractmethod
2
+ import io
3
+ import csv
4
+ import base64
5
+ from typing import Optional, Union, Tuple, List, Any, Dict
6
+ from openpyxl import Workbook
7
+
8
+ from edsl.scenarios.FileStore import FileStore
9
+
10
+
11
class FileExport(ABC):
    """Abstract base for exporters that serialise ``data`` into a FileStore.

    Concrete subclasses set the class attributes ``mime_type``, ``suffix``
    and ``is_binary`` and implement :meth:`format_data`.
    """

    def __init__(
        self,
        data: Any,
        filename: Optional[str] = None,
        remove_prefix: bool = False,
        pretty_labels: Optional[Dict[str, str]] = None,
    ):
        self.data = data
        # The default filename is only substituted inside _create_filestore;
        # keeping None here is what signals "do not write to disk".
        self.filename = filename
        self.remove_prefix = remove_prefix
        self.pretty_labels = pretty_labels

    @property
    def mime_type(self) -> str:
        """MIME type declared by the concrete export class."""
        return type(self).mime_type

    @property
    def suffix(self) -> str:
        """File extension (without the dot) declared by the export class."""
        return type(self).suffix

    @property
    def is_binary(self) -> bool:
        """True when the exported payload is bytes rather than text."""
        return type(self).is_binary

    def _get_default_filename(self) -> str:
        """Fallback path used when no filename was supplied: ``results.<suffix>``."""
        return "results.{}".format(self.suffix)

    def _create_filestore(self, data: Union[str, bytes]) -> "FileStore":
        """Wrap *data* in a FileStore, writing it to disk if a filename was given.

        Text is encoded with ``str.encode()`` before base64 encoding. Returns
        the FileStore when ``self.filename`` is None; otherwise writes it to
        ``self.filename`` and returns None.
        """
        raw = data.encode() if isinstance(data, str) else data
        base64_string = base64.b64encode(raw).decode()

        from edsl.scenarios.FileStore import FileStore

        store = FileStore(
            path=self.filename or self._get_default_filename(),
            mime_type=self.mime_type,
            binary=self.is_binary,
            suffix=self.suffix,
            base64_string=base64_string,
        )

        if self.filename is None:
            return store
        store.write(self.filename)
        return None

    @abstractmethod
    def format_data(self) -> Union[str, bytes]:
        """Render ``self.data`` in the target format (text or bytes)."""

    def export(self) -> Optional["FileStore"]:
        """Format the data and hand the result to :meth:`_create_filestore`."""
        return self._create_filestore(self.format_data())
76
+
77
+
78
class JSONLExport(FileExport):
    """Export each entry of ``data`` as one JSON object per line (JSON Lines)."""

    mime_type = "application/jsonl"
    suffix = "jsonl"
    is_binary = False

    def format_data(self) -> str:
        """Serialise the first key/value pair of every entry as a JSON line.

        Presumably each entry is a single-column dict such as
        ``{"answer.how_feeling": [...]}`` (TODO confirm against Dataset). Values
        are encoded with ``json.dumps`` so strings, ``None`` and booleans come
        out as valid JSON and keys are escaped. Objects json cannot encode
        natively fall back to ``str()``.
        """
        import json

        lines = []
        for entry in self.data:
            key, values = list(entry.items())[0]
            lines.append(json.dumps({key: values}, ensure_ascii=False, default=str))
            lines.append("\n")
        return "".join(lines)
89
+
90
+
91
class TabularExport(FileExport, ABC):
    """Shared base for exports built from a header row plus data rows."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # FileExport.__init__ has already stored remove_prefix/pretty_labels;
        # the data object turns itself into a (header, rows) pair.
        tabular = self.data._get_tabular_data(
            remove_prefix=self.remove_prefix,
            pretty_labels=self.pretty_labels,
        )
        self.header, self.rows = tabular
99
+
100
+
101
class CSVExport(TabularExport):
    """Export tabular data as comma-separated text."""

    mime_type = "text/csv"
    suffix = "csv"
    is_binary = False

    def format_data(self) -> str:
        """Return the header row followed by every data row, in CSV form."""
        sink = io.StringIO()
        csv_writer = csv.writer(sink)
        csv_writer.writerow(self.header)
        for row in self.rows:
            csv_writer.writerow(row)
        return sink.getvalue()
112
+
113
+
114
class ExcelExport(TabularExport):
    """Export tabular data as an .xlsx workbook with a single worksheet."""

    mime_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    suffix = "xlsx"
    is_binary = True

    def __init__(self, *args, sheet_name: Optional[str] = None, **kwargs):
        super().__init__(*args, **kwargs)
        # Any falsy sheet name (None or "") falls back to "Results".
        self.sheet_name = sheet_name or "Results"

    def format_data(self) -> bytes:
        """Build the workbook in memory and return the serialised .xlsx bytes."""
        workbook = Workbook()
        sheet = workbook.active
        sheet.title = self.sheet_name

        # openpyxl cells are 1-based: the header fills row 1, data starts at row 2.
        table = [self.header, *self.rows]
        for row_number, values in enumerate(table, start=1):
            for column_number, value in enumerate(values, start=1):
                sheet.cell(row=row_number, column=column_number, value=value)

        stream = io.BytesIO()
        workbook.save(stream)
        return stream.getvalue()
142
+
143
+
144
+ import sqlite3
145
+ from typing import Any
146
+
147
+
148
class SQLiteExport(TabularExport):
    """Export tabular data as a single-table SQLite database file."""

    mime_type = "application/x-sqlite3"
    suffix = "db"
    is_binary = True

    def __init__(
        self, *args, table_name: str = "results", if_exists: str = "replace", **kwargs
    ):
        """
        Initialize SQLite export.

        Args:
            table_name: Name of the table to create; only alphanumeric
                characters and underscores are accepted.
            if_exists: How to handle existing table ('fail', 'replace', or 'append').
                NOTE: format_data always builds a fresh in-memory database, so
                in practice the table never pre-exists.

        Raises:
            ValueError: if ``table_name`` or ``if_exists`` is invalid.
        """
        super().__init__(*args, **kwargs)
        self.table_name = table_name
        self.if_exists = if_exists
        self._validate_params()

    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Quote *name* as an SQL identifier, doubling any embedded double quotes."""
        return '"' + str(name).replace('"', '""') + '"'

    @staticmethod
    def _to_sql_value(value: Any) -> Any:
        """Return *value* unchanged if sqlite3 can bind it, else its ``str()``.

        Lists or dicts (e.g. multi-select answers) would otherwise make
        ``executemany`` fail with an unsupported-type binding error.
        """
        if value is None or isinstance(
            value, (int, float, str, bytes, bytearray, memoryview)
        ):
            return value
        return str(value)

    def _get_column_types(self) -> list[tuple[str, str]]:
        """Infer SQL column types from the Python types in the first data row.

        bool -> BOOLEAN, int -> INTEGER, float -> REAL, anything else -> TEXT.
        With no rows, every column is TEXT.
        """
        column_types = []

        if self.rows:
            first_row = self.rows[0]
            for header, value in zip(self.header, first_row):
                # bool is checked first because it is a subclass of int.
                if isinstance(value, bool):
                    sql_type = "BOOLEAN"
                elif isinstance(value, int):
                    sql_type = "INTEGER"
                elif isinstance(value, float):
                    sql_type = "REAL"
                else:
                    sql_type = "TEXT"
                column_types.append((header, sql_type))
        else:
            column_types = [(header, "TEXT") for header in self.header]

        return column_types

    def _create_table(self, cursor: sqlite3.Cursor) -> None:
        """Create the target table, honouring ``if_exists``.

        Raises:
            ValueError: if ``if_exists == "fail"`` and the table already exists.
        """
        column_types = self._get_column_types()
        table = self._quote_identifier(self.table_name)

        if self.if_exists == "replace":
            cursor.execute(f"DROP TABLE IF EXISTS {table}")
        elif self.if_exists == "fail":
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (self.table_name,),
            )
            if cursor.fetchone():
                raise ValueError(f"Table {self.table_name} already exists")

        columns = ", ".join(
            f"{self._quote_identifier(col)} {dtype}" for col, dtype in column_types
        )
        cursor.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns})")

    @staticmethod
    def _database_bytes(conn: sqlite3.Connection) -> bytes:
        """Return the full contents of *conn*'s main database as file bytes.

        Uses ``Connection.serialize`` when available (Python 3.11+ with SQLite
        serialize support). Otherwise the database is copied with
        ``Connection.backup`` into a temporary file whose bytes are read back.
        ``sqlite3.connect`` only accepts paths, not buffers like ``io.BytesIO``.
        """
        serialize = getattr(conn, "serialize", None)
        if serialize is not None:
            return bytes(serialize())

        import os
        import tempfile

        fd, path = tempfile.mkstemp(suffix=".db")
        os.close(fd)
        try:
            target = sqlite3.connect(path)
            try:
                conn.backup(target)
            finally:
                target.close()
            with open(path, "rb") as handle:
                return handle.read()
        finally:
            os.remove(path)

    def format_data(self) -> bytes:
        """Build the table in an in-memory database and return its file bytes."""
        conn = sqlite3.connect(":memory:")
        try:
            cursor = conn.cursor()
            self._create_table(cursor)

            # Column names are quoted: headers such as "answer.how_feeling"
            # are not valid bare SQL identifiers.
            placeholders = ", ".join("?" for _ in self.header)
            column_list = ", ".join(self._quote_identifier(col) for col in self.header)
            insert_sql = (
                f"INSERT INTO {self._quote_identifier(self.table_name)} "
                f"({column_list}) VALUES ({placeholders})"
            )
            cursor.executemany(
                insert_sql,
                ([self._to_sql_value(value) for value in row] for row in self.rows),
            )
            conn.commit()
            return self._database_bytes(conn)
        finally:
            conn.close()

    def _validate_params(self) -> None:
        """Validate ``if_exists`` and ``table_name``.

        Raises:
            ValueError: if ``if_exists`` is not 'fail', 'replace' or 'append',
                or ``table_name`` is not a non-empty string of alphanumerics
                and underscores.
        """
        valid_if_exists = {"fail", "replace", "append"}
        if self.if_exists not in valid_if_exists:
            raise ValueError(
                f"if_exists must be one of {valid_if_exists}, got {self.if_exists}"
            )

        # Check each character: a name like "my_table" is neither purely
        # alphanumeric nor purely underscores, yet it is valid.
        name = self.table_name
        if (
            not isinstance(name, str)
            or not name
            or not all(ch.isalnum() or ch == "_" for ch in name)
        ):
            raise ValueError(
                f"Invalid table name: {self.table_name}. Must contain only alphanumeric characters and underscores."
            )
edsl/scenarios/ScenarioList.py CHANGED
@@ -45,7 +45,7 @@ from edsl.utilities.naming_utilities import sanitize_string
45
45
  from edsl.utilities.is_valid_variable_name import is_valid_variable_name
46
46
  from edsl.exceptions.scenarios import ScenarioError
47
47
 
48
- from edsl.scenarios.DirectoryScanner import DirectoryScanner
48
+ from edsl.scenarios.directory_scanner import DirectoryScanner
49
49
 
50
50
 
51
51
  class ScenarioListMixin(ScenarioListPdfMixin, ScenarioListExportMixin):
@@ -661,7 +661,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
661
661
  >>> s.select('a')
662
662
  ScenarioList([Scenario({'a': 1}), Scenario({'a': 1})])
663
663
  """
664
- from edsl.scenarios.ScenarioSelector import ScenarioSelector
664
+ from edsl.scenarios.scenario_selector import ScenarioSelector
665
665
 
666
666
  return ScenarioSelector(self).select(*fields)
667
667
 
@@ -840,10 +840,25 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
840
840
  ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
841
841
  """
842
842
  sl = self.duplicate()
843
+ if len(values) != len(sl):
844
+ raise ScenarioError(
845
+ f"Length of values ({len(values)}) does not match length of ScenarioList ({len(sl)})"
846
+ )
843
847
  for i, value in enumerate(values):
844
848
  sl[i][name] = value
845
849
  return sl
846
850
 
851
@classmethod
def create_empty_scenario_list(cls, n: int) -> ScenarioList:
    """Create an empty ScenarioList with n scenarios.

    Builds an instance of ``cls`` (so subclasses get their own type back)
    holding *n* empty ``Scenario`` objects.

    Example:

    >>> ScenarioList.create_empty_scenario_list(3)
    ScenarioList([Scenario({}), Scenario({}), Scenario({})])
    """
    return cls([Scenario({}) for _ in range(n)])
861
+
847
862
  def add_value(self, name: str, value: Any) -> ScenarioList:
848
863
  """Add a value to all scenarios in a ScenarioList.
849
864
 
@@ -1222,7 +1237,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
1222
1237
  >>> s3 == ScenarioList([Scenario({'age': 30, 'location': 'New York', 'name': 'Alice'}), Scenario({'age': 25, 'location': None, 'name': 'Bob'})])
1223
1238
  True
1224
1239
  """
1225
- from edsl.scenarios.ScenarioJoin import ScenarioJoin
1240
+ from edsl.scenarios.scenario_join import ScenarioJoin
1226
1241
 
1227
1242
  sj = ScenarioJoin(self, other)
1228
1243
  return sj.left_join(by)
@@ -1244,6 +1259,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
1244
1259
  else:
1245
1260
  data = self
1246
1261
  d = {"scenarios": [s.to_dict(add_edsl_version=add_edsl_version) for s in data]}
1262
+
1247
1263
  if add_edsl_version:
1248
1264
  from edsl import __version__
1249
1265
 
@@ -1296,10 +1312,22 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
1296
1312
 
1297
1313
  @classmethod
1298
1314
  def from_nested_dict(cls, data: dict) -> ScenarioList:
1299
- """Create a `ScenarioList` from a nested dictionary."""
1300
- s = ScenarioList()
1301
- for key, value in data.items():
1302
- s.add_list(key, value)
1315
+ """Create a `ScenarioList` from a nested dictionary.
1316
+
1317
+ >>> data = {"headline": ["Armistice Signed, War Over: Celebrations Erupt Across City"], "date": ["1918-11-11"], "author": ["Jane Smith"]}
1318
+ >>> ScenarioList.from_nested_dict(data)
1319
+ ScenarioList([Scenario({'headline': 'Armistice Signed, War Over: Celebrations Erupt Across City', 'date': '1918-11-11', 'author': 'Jane Smith'})])
1320
+
1321
+ """
1322
+ length_of_first_list = len(next(iter(data.values())))
1323
+ s = ScenarioList.create_empty_scenario_list(n=length_of_first_list)
1324
+
1325
+ if any(len(v) != length_of_first_list for v in data.values()):
1326
+ raise ValueError(
1327
+ "All lists in the dictionary must be of the same length.",
1328
+ )
1329
+ for key, list_of_values in data.items():
1330
+ s = s.add_list(key, list_of_values)
1303
1331
  return s
1304
1332
 
1305
1333
  def code(self) -> str: