edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +28 -0
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -16
- edsl/agents/Invigilator.py +13 -14
- edsl/agents/InvigilatorBase.py +4 -1
- edsl/agents/PromptConstructor.py +42 -22
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/coop/coop.py +21 -5
- edsl/data/Cache.py +29 -18
- edsl/data/CacheHandler.py +0 -2
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/hack.py +10 -0
- edsl/enums.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -16
- edsl/inference_services/AvailableModelFetcher.py +7 -1
- edsl/inference_services/GoogleService.py +5 -1
- edsl/inference_services/InferenceServicesCollection.py +18 -2
- edsl/inference_services/OpenAIService.py +46 -31
- edsl/inference_services/TestService.py +1 -3
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +74 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
- edsl/jobs/FetchInvigilator.py +10 -3
- edsl/jobs/InterviewsConstructor.py +6 -4
- edsl/jobs/Jobs.py +299 -233
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/interviews/Interview.py +80 -42
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
- edsl/jobs/runners/JobsRunnerStatus.py +131 -164
- edsl/jobs/tasks/TaskHistory.py +24 -3
- edsl/language_models/LanguageModel.py +59 -4
- edsl/language_models/ModelList.py +19 -8
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +35 -26
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +9 -15
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/questions/data_structures.py +20 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
- edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
- edsl/results/DatasetExportMixin.py +60 -119
- edsl/results/Result.py +109 -3
- edsl/results/Results.py +50 -39
- edsl/results/file_exports.py +252 -0
- edsl/scenarios/ScenarioList.py +35 -7
- edsl/surveys/Survey.py +71 -20
- edsl/test_h +1 -0
- edsl/utilities/gcp_bucket/example.py +50 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
- edsl/language_models/registry.py +0 -180
- /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
- /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
- /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{Selector.py → results_selector.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
- /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
- /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
edsl/results/Results.py
CHANGED
@@ -9,6 +9,8 @@ import random
|
|
9
9
|
from collections import UserList, defaultdict
|
10
10
|
from typing import Optional, Callable, Any, Type, Union, List, TYPE_CHECKING
|
11
11
|
|
12
|
+
from bisect import bisect_left
|
13
|
+
|
12
14
|
from edsl.Base import Base
|
13
15
|
from edsl.exceptions.results import (
|
14
16
|
ResultsError,
|
@@ -24,7 +26,7 @@ if TYPE_CHECKING:
|
|
24
26
|
from edsl.surveys.Survey import Survey
|
25
27
|
from edsl.data.Cache import Cache
|
26
28
|
from edsl.agents.AgentList import AgentList
|
27
|
-
from edsl.language_models.
|
29
|
+
from edsl.language_models.model import Model
|
28
30
|
from edsl.scenarios.ScenarioList import ScenarioList
|
29
31
|
from edsl.results.Result import Result
|
30
32
|
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
@@ -33,7 +35,7 @@ if TYPE_CHECKING:
|
|
33
35
|
|
34
36
|
from edsl.results.ResultsExportMixin import ResultsExportMixin
|
35
37
|
from edsl.results.ResultsGGMixin import ResultsGGMixin
|
36
|
-
from edsl.results.
|
38
|
+
from edsl.results.results_fetch_mixin import ResultsFetchMixin
|
37
39
|
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
38
40
|
|
39
41
|
|
@@ -136,7 +138,33 @@ class Results(UserList, Mixins, Base):
|
|
136
138
|
}
|
137
139
|
return d
|
138
140
|
|
139
|
-
def
|
141
|
+
def insert(self, item):
    """Insert *item*, keeping items that carry an ``order`` sorted among themselves.

    Items whose ``order`` attribute is missing or ``None`` are appended to the
    end. An ordered item is placed immediately after the last ordered item
    whose order is strictly smaller (ties therefore go before existing equal
    orders). If no such item exists, it goes to the very front. Unordered
    items already present keep their positions.

    :param item: The object to add; its optional ``order`` attribute drives placement.
    """
    item_order = getattr(item, "order", None)
    if item_order is None:
        # No order - append to end
        self.data.append(item)
        return

    # Positions (within self.data) of items that carry an order, together with
    # those orders. Every ordered insert goes through this method, so the
    # orders stay sorted and bisect is valid.
    ordered_positions = []
    ordered_values = []
    for position, existing in enumerate(self.data):
        existing_order = getattr(existing, "order", None)
        if existing_order is not None:
            ordered_positions.append(position)
            ordered_values.append(existing_order)

    # rank = number of ordered items that sort strictly before this one.
    rank = bisect_left(ordered_values, item_order)
    # Map the rank back to a real list index. Unordered items interleaved
    # before that point are accounted for by using the actual position.
    index = ordered_positions[rank - 1] + 1 if rank else 0
    self.data.insert(index, item)


def append(self, item):
    """Add *item* through :meth:`insert` so ordered items stay sorted."""
    self.insert(item)


def extend(self, other):
    """Add every item of *other*, one at a time, through :meth:`insert`."""
    for item in other:
        self.insert(item)
|
166
|
+
|
167
|
+
def compute_job_cost(self, include_cached_responses_in_cost: bool = False) -> float:
|
140
168
|
"""
|
141
169
|
Computes the cost of a completed job in USD.
|
142
170
|
"""
|
@@ -250,24 +278,6 @@ class Results(UserList, Mixins, Base):
|
|
250
278
|
|
251
279
|
raise TypeError("Invalid argument type")
|
252
280
|
|
253
|
-
# def _update_results(self) -> None:
|
254
|
-
# from edsl import Agent, Scenario
|
255
|
-
# from edsl.language_models import LanguageModel
|
256
|
-
# from edsl.results import Result
|
257
|
-
|
258
|
-
# if self._job_uuid and len(self.data) < self._total_results:
|
259
|
-
# results = [
|
260
|
-
# Result(
|
261
|
-
# agent=Agent.from_dict(json.loads(r.agent)),
|
262
|
-
# scenario=Scenario.from_dict(json.loads(r.scenario)),
|
263
|
-
# model=LanguageModel.from_dict(json.loads(r.model)),
|
264
|
-
# iteration=1,
|
265
|
-
# answer=json.loads(r.answer),
|
266
|
-
# )
|
267
|
-
# for r in CRUD.read_results(self._job_uuid)
|
268
|
-
# ]
|
269
|
-
# self.data = results
|
270
|
-
|
271
281
|
def __add__(self, other: Results) -> Results:
|
272
282
|
"""Add two Results objects together.
|
273
283
|
They must have the same survey and created columns.
|
@@ -295,13 +305,10 @@ class Results(UserList, Mixins, Base):
|
|
295
305
|
)
|
296
306
|
|
297
307
|
def __repr__(self) -> str:
    """Return a representation showing the data, the survey and the created columns."""
    return (
        "Results("
        f"data = {self.data}, "
        f"survey = {self.survey!r}, "
        f"created_columns = {self.created_columns})"
    )
|
301
309
|
|
302
310
|
def table(
|
303
311
|
self,
|
304
|
-
# selector_string: Optional[str] = "*.*",
|
305
312
|
*fields,
|
306
313
|
tablefmt: Optional[str] = None,
|
307
314
|
pretty_labels: Optional[dict] = None,
|
@@ -340,11 +347,11 @@ class Results(UserList, Mixins, Base):
|
|
340
347
|
|
341
348
|
def to_dict(
|
342
349
|
self,
|
343
|
-
sort=False,
|
344
|
-
add_edsl_version=False,
|
345
|
-
include_cache=False,
|
346
|
-
include_task_history=False,
|
347
|
-
include_cache_info=True,
|
350
|
+
sort: bool = False,
|
351
|
+
add_edsl_version: bool = False,
|
352
|
+
include_cache: bool = False,
|
353
|
+
include_task_history: bool = False,
|
354
|
+
include_cache_info: bool = True,
|
348
355
|
) -> dict[str, Any]:
|
349
356
|
from edsl.data.Cache import Cache
|
350
357
|
|
@@ -386,7 +393,7 @@ class Results(UserList, Mixins, Base):
|
|
386
393
|
|
387
394
|
return d
|
388
395
|
|
389
|
-
def compare(self, other_results):
|
396
|
+
def compare(self, other_results: Results) -> dict:
|
390
397
|
"""
|
391
398
|
Compare two Results objects and return the differences.
|
392
399
|
"""
|
@@ -404,7 +411,7 @@ class Results(UserList, Mixins, Base):
|
|
404
411
|
}
|
405
412
|
|
406
413
|
@property
def has_unfixed_exceptions(self) -> bool:
    """Report ``task_history.has_unfixed_exceptions`` for these results."""
    history = self.task_history
    return history.has_unfixed_exceptions
|
409
416
|
|
410
417
|
def __hash__(self) -> int:
|
@@ -487,10 +494,6 @@ class Results(UserList, Mixins, Base):
|
|
487
494
|
raise ResultsDeserializationError(f"Error in Results.from_dict: {e}")
|
488
495
|
return results
|
489
496
|
|
490
|
-
######################
|
491
|
-
## Convenience methods
|
492
|
-
## & Report methods
|
493
|
-
######################
|
494
497
|
@property
|
495
498
|
def _key_to_data_type(self) -> dict[str, str]:
|
496
499
|
"""
|
@@ -689,13 +692,19 @@ class Results(UserList, Mixins, Base):
|
|
689
692
|
"""
|
690
693
|
return self.data[0]
|
691
694
|
|
692
|
-
def answer_truncate(
|
695
|
+
def answer_truncate(
|
696
|
+
self, column: str, top_n: int = 5, new_var_name: str = None
|
697
|
+
) -> Results:
|
693
698
|
"""Create a new variable that truncates the answers to the top_n.
|
694
699
|
|
695
700
|
:param column: The column to truncate.
|
696
701
|
:param top_n: The number of top answers to keep.
|
697
702
|
:param new_var_name: The name of the new variable. If None, it is the original name + '_truncated'.
|
698
703
|
|
704
|
+
Example:
|
705
|
+
>>> r = Results.example()
|
706
|
+
>>> r.answer_truncate('how_feeling', top_n = 2).select('how_feeling', 'how_feeling_truncated')
|
707
|
+
Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling_truncated': ['Other', 'Other', 'Other', 'Other']}])
|
699
708
|
|
700
709
|
|
701
710
|
"""
|
@@ -916,7 +925,7 @@ class Results(UserList, Mixins, Base):
|
|
916
925
|
n: Optional[int] = None,
|
917
926
|
frac: Optional[float] = None,
|
918
927
|
with_replacement: bool = True,
|
919
|
-
seed: Optional[str] =
|
928
|
+
seed: Optional[str] = None,
|
920
929
|
) -> Results:
|
921
930
|
"""Sample the results.
|
922
931
|
|
@@ -931,7 +940,7 @@ class Results(UserList, Mixins, Base):
|
|
931
940
|
>>> len(r.sample(2))
|
932
941
|
2
|
933
942
|
"""
|
934
|
-
if seed
|
943
|
+
if seed:
|
935
944
|
random.seed(seed)
|
936
945
|
|
937
946
|
if n is None and frac is None:
|
@@ -969,7 +978,7 @@ class Results(UserList, Mixins, Base):
|
|
969
978
|
Dataset([{'answer.how_feeling_yesterday': ['Great', 'Good', 'OK', 'Terrible']}])
|
970
979
|
"""
|
971
980
|
|
972
|
-
from edsl.results.
|
981
|
+
from edsl.results.results_selector import Selector
|
973
982
|
|
974
983
|
if len(self) == 0:
|
975
984
|
raise Exception("No data to select from---the Results object is empty.")
|
@@ -984,6 +993,7 @@ class Results(UserList, Mixins, Base):
|
|
984
993
|
return selector.select(*columns)
|
985
994
|
|
986
995
|
def sort_by(self, *columns: str, reverse: bool = False) -> Results:
|
996
|
+
"""Sort the results by one or more columns."""
|
987
997
|
import warnings
|
988
998
|
|
989
999
|
warnings.warn(
|
@@ -992,6 +1002,7 @@ class Results(UserList, Mixins, Base):
|
|
992
1002
|
return self.order_by(*columns, reverse=reverse)
|
993
1003
|
|
994
1004
|
def _parse_column(self, column: str) -> tuple[str, str]:
    """Parse a column name into a ``(data_type, key)`` tuple.

    A qualified name such as ``"answer.how_feeling"`` is split at its first
    dot, giving ``("answer", "how_feeling")``. Any further dots stay in the
    key. A bare key is resolved through ``self._key_to_data_type``.

    :param column: Qualified (``"type.key"``) or bare (``"key"``) column name.
    :returns: A two-element tuple ``(data_type, key)``.
    """
    if "." in column:
        # maxsplit=1 guarantees exactly two parts, matching the annotated return type.
        data_type, key = column.split(".", 1)
        return data_type, key
    return self._key_to_data_type[column], column
|
@@ -0,0 +1,252 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
import io
|
3
|
+
import csv
|
4
|
+
import base64
|
5
|
+
from typing import Optional, Union, Tuple, List, Any, Dict
|
6
|
+
from openpyxl import Workbook
|
7
|
+
|
8
|
+
from edsl.scenarios.FileStore import FileStore
|
9
|
+
|
10
|
+
|
11
|
+
class FileExport(ABC):
    """Abstract base for exporters that turn ``data`` into a ``FileStore``.

    Concrete subclasses define the class attributes ``mime_type``, ``suffix``
    and ``is_binary`` and implement :meth:`format_data`. Those class
    attributes shadow the same-named properties below. The properties exist
    only to document the expected interface.
    """

    def __init__(
        self,
        data: Any,
        filename: Optional[str] = None,
        remove_prefix: bool = False,
        pretty_labels: Optional[Dict[str, str]] = None,
    ):
        """Store the data and the export options.

        :param data: The object to export; what it must provide depends on the subclass.
        :param filename: If given, :meth:`export` writes the file to this path and
            returns ``None`` instead of a ``FileStore``.
        :param remove_prefix: Stored for subclasses (e.g. tabular exports) to use.
        :param pretty_labels: Optional column-label mapping, stored for subclasses to use.
        """
        self.data = data
        # Left as None when omitted: _create_filestore treats None as
        # "return an in-memory FileStore named results.<suffix>".
        self.filename = filename
        self.remove_prefix = remove_prefix
        self.pretty_labels = pretty_labels

    @property
    def mime_type(self) -> str:
        """Return the MIME type for this export format."""
        return self.__class__.mime_type

    @property
    def suffix(self) -> str:
        """Return the file suffix for this format."""
        return self.__class__.suffix

    @property
    def is_binary(self) -> bool:
        """Whether the format is binary or text-based."""
        return self.__class__.is_binary

    def _get_default_filename(self) -> str:
        """Generate default filename for this format."""
        return f"results.{self.suffix}"

    def _create_filestore(self, data: Union[str, bytes]) -> Optional["FileStore"]:
        """Wrap formatted output in a ``FileStore`` and write it to disk if a filename was given.

        :param data: Text (encoded with ``str.encode()``'s default UTF-8) or raw bytes.
        :returns: The ``FileStore`` when ``self.filename`` is ``None``; otherwise
            ``None``, after writing the file to ``self.filename``.
        """
        raw = data.encode() if isinstance(data, str) else data
        base64_string = base64.b64encode(raw).decode()

        from edsl.scenarios.FileStore import FileStore

        fs = FileStore(
            path=self.filename or self._get_default_filename(),
            mime_type=self.mime_type,
            binary=self.is_binary,
            suffix=self.suffix,
            base64_string=base64_string,
        )

        if self.filename is not None:
            fs.write(self.filename)
            return None
        return fs

    @abstractmethod
    def format_data(self) -> Union[str, bytes]:
        """Convert the input data to the target format."""
        pass

    def export(self) -> Optional["FileStore"]:
        """Format the data and return a ``FileStore`` (or ``None`` if it was written to ``filename``)."""
        formatted_data = self.format_data()
        return self._create_filestore(formatted_data)
|
76
|
+
|
77
|
+
|
78
|
+
class JSONLExport(FileExport):
    """Export ``data`` as JSON Lines: one JSON object per entry."""

    mime_type = "application/jsonl"
    suffix = "jsonl"
    is_binary = False

    def format_data(self) -> str:
        """Serialise each entry as a one-line ``{key: values}`` JSON object.

        Entries are treated as dicts and only their first key/value pair is
        written (presumably they are single-key ``{column: values}`` dicts;
        verify against callers).

        :returns: The newline-terminated JSON Lines text.
        """
        import json

        output = io.StringIO()
        for entry in self.data:
            key, values = list(entry.items())[0]
            # json.dumps emits valid JSON (double quotes, escaped keys,
            # null/true/false) where an f-string would emit Python reprs.
            # default=str is a deliberate fallback so values the json module
            # cannot encode are stringified instead of aborting the export.
            output.write(json.dumps({key: values}, default=str))
            output.write("\n")
        return output.getvalue()
|
89
|
+
|
90
|
+
|
91
|
+
class TabularExport(FileExport, ABC):
    """Base class for exports that render a header row plus data rows.

    On construction, ``self.data._get_tabular_data`` is called with this
    export's ``remove_prefix`` and ``pretty_labels``. The resulting
    ``(header, rows)`` pair is stored on ``self.header`` and ``self.rows``
    for subclasses to format.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        table = self.data._get_tabular_data(
            remove_prefix=self.remove_prefix,
            pretty_labels=self.pretty_labels,
        )
        self.header, self.rows = table
|
99
|
+
|
100
|
+
|
101
|
+
class CSVExport(TabularExport):
    """Export tabular data as comma-separated text."""

    mime_type = "text/csv"
    suffix = "csv"
    is_binary = False

    def format_data(self) -> str:
        """Write the header, then each row, with the ``csv`` module's default dialect."""
        sink = io.StringIO()
        emit = csv.writer(sink)
        emit.writerow(self.header)
        for record in self.rows:
            emit.writerow(record)
        return sink.getvalue()
|
112
|
+
|
113
|
+
|
114
|
+
class ExcelExport(TabularExport):
    """Export tabular data as a single-sheet ``.xlsx`` workbook."""

    mime_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    suffix = "xlsx"
    is_binary = True

    def __init__(self, *args, sheet_name: Optional[str] = None, **kwargs):
        """Configure the export.

        :param sheet_name: Worksheet title; ``"Results"`` is used when it is falsy.
        """
        super().__init__(*args, **kwargs)
        self.sheet_name = sheet_name if sheet_name else "Results"

    def format_data(self) -> bytes:
        """Build the workbook (header in row 1, data from row 2) and return its bytes."""
        workbook = Workbook()
        sheet = workbook.active
        sheet.title = self.sheet_name

        # openpyxl rows and columns are 1-based; the header occupies row 1.
        for row_number, values in enumerate([self.header, *self.rows], start=1):
            for column_number, value in enumerate(values, start=1):
                sheet.cell(row=row_number, column=column_number, value=value)

        stream = io.BytesIO()
        workbook.save(stream)
        return stream.getvalue()
|
142
|
+
|
143
|
+
|
144
|
+
import sqlite3
|
145
|
+
from typing import Any
|
146
|
+
|
147
|
+
|
148
|
+
class SQLiteExport(TabularExport):
    """Export tabular data as the bytes of a standalone SQLite database file."""

    mime_type = "application/x-sqlite3"
    suffix = "db"
    is_binary = True

    # Python types that sqlite3 binds natively. Any other value is stored as
    # str(value), consistent with _get_column_types declaring such columns TEXT.
    _SQLITE_NATIVE_TYPES = (type(None), int, float, str, bytes)

    def __init__(
        self, *args, table_name: str = "results", if_exists: str = "replace", **kwargs
    ):
        """
        Initialize SQLite export.

        Args:
            table_name: Name of the table to create (letters, digits and underscores only)
            if_exists: How to handle existing table ('fail', 'replace', or 'append')

        Raises:
            ValueError: If ``table_name`` or ``if_exists`` is invalid.
        """
        self.table_name = table_name
        self.if_exists = if_exists
        # Fail fast, before super().__init__ builds the tabular data.
        self._validate_params()
        super().__init__(*args, **kwargs)

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier with embedded quotes escaped."""
        return '"' + str(name).replace('"', '""') + '"'

    def _get_column_types(self) -> list[tuple[str, str]]:
        """Infer SQL column types from the first data row (TEXT for everything when there are no rows)."""
        if not self.rows:
            return [(header, "TEXT") for header in self.header]

        column_types = []
        for header, value in zip(self.header, self.rows[0]):
            # bool is checked before int because bool is a subclass of int.
            if isinstance(value, bool):
                sql_type = "BOOLEAN"
            elif isinstance(value, int):
                sql_type = "INTEGER"
            elif isinstance(value, float):
                sql_type = "REAL"
            else:
                sql_type = "TEXT"
            column_types.append((header, sql_type))
        return column_types

    def _create_table(self, cursor: sqlite3.Cursor) -> None:
        """Create the target table, honouring ``if_exists``.

        Raises:
            ValueError: If ``if_exists == 'fail'`` and the table already exists.
        """
        column_types = self._get_column_types()
        table = self._quote_identifier(self.table_name)

        if self.if_exists == "replace":
            cursor.execute(f"DROP TABLE IF EXISTS {table}")
        elif self.if_exists == "fail":
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (self.table_name,),
            )
            if cursor.fetchone():
                raise ValueError(f"Table {self.table_name} already exists")

        columns = ", ".join(
            f"{self._quote_identifier(col)} {dtype}" for col, dtype in column_types
        )
        cursor.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns})")

    def _coerce_row(self, row) -> tuple:
        """Convert values sqlite3 cannot bind (lists, dicts, ...) to ``str``."""
        return tuple(
            value if isinstance(value, self._SQLITE_NATIVE_TYPES) else str(value)
            for value in row
        )

    @staticmethod
    def _serialize(conn: sqlite3.Connection) -> bytes:
        """Return the contents of *conn* as the bytes of a SQLite database file."""
        # sqlite3.connect() needs a filesystem path or URI and cannot target an
        # io.BytesIO, so serialise directly where supported (Python 3.11+).
        if hasattr(conn, "serialize"):
            return bytes(conn.serialize())

        import os
        import tempfile

        # Fallback: back the database up to a temporary file and read it back.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "export.db")
            target = sqlite3.connect(path)
            try:
                conn.backup(target)
            finally:
                target.close()
            with open(path, "rb") as fh:
                return fh.read()

    def format_data(self) -> bytes:
        """Build the table in an in-memory database and return the database file bytes."""
        table = self._quote_identifier(self.table_name)
        # Quote columns exactly as _create_table does, so headers containing
        # dots or spaces (e.g. "answer.how_feeling") are valid here too.
        column_list = ", ".join(self._quote_identifier(col) for col in self.header)
        placeholders = ", ".join("?" for _ in self.header)
        insert_sql = f"INSERT INTO {table} ({column_list}) VALUES ({placeholders})"

        conn = sqlite3.connect(":memory:")
        try:
            cursor = conn.cursor()
            self._create_table(cursor)
            cursor.executemany(insert_sql, (self._coerce_row(row) for row in self.rows))
            conn.commit()
            return self._serialize(conn)
        finally:
            conn.close()

    def _validate_params(self) -> None:
        """Validate initialization parameters.

        Raises:
            ValueError: If ``if_exists`` is not 'fail', 'replace' or 'append', or
                ``table_name`` is empty or contains characters other than
                alphanumerics and underscores.
        """
        valid_if_exists = {"fail", "replace", "append"}
        if self.if_exists not in valid_if_exists:
            raise ValueError(
                f"if_exists must be one of {valid_if_exists}, got {self.if_exists}"
            )

        # Validate table name (basic SQLite identifier validation)
        if not self.table_name or not all(
            c.isalnum() or c == "_" for c in self.table_name
        ):
            raise ValueError(
                f"Invalid table name: {self.table_name}. Must contain only alphanumeric characters and underscores."
            )
|
edsl/scenarios/ScenarioList.py
CHANGED
@@ -45,7 +45,7 @@ from edsl.utilities.naming_utilities import sanitize_string
|
|
45
45
|
from edsl.utilities.is_valid_variable_name import is_valid_variable_name
|
46
46
|
from edsl.exceptions.scenarios import ScenarioError
|
47
47
|
|
48
|
-
from edsl.scenarios.
|
48
|
+
from edsl.scenarios.directory_scanner import DirectoryScanner
|
49
49
|
|
50
50
|
|
51
51
|
class ScenarioListMixin(ScenarioListPdfMixin, ScenarioListExportMixin):
|
@@ -661,7 +661,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
661
661
|
>>> s.select('a')
|
662
662
|
ScenarioList([Scenario({'a': 1}), Scenario({'a': 1})])
|
663
663
|
"""
|
664
|
-
from edsl.scenarios.
|
664
|
+
from edsl.scenarios.scenario_selector import ScenarioSelector
|
665
665
|
|
666
666
|
return ScenarioSelector(self).select(*fields)
|
667
667
|
|
@@ -840,10 +840,25 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
840
840
|
ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
|
841
841
|
"""
|
842
842
|
sl = self.duplicate()
|
843
|
+
if len(values) != len(sl):
|
844
|
+
raise ScenarioError(
|
845
|
+
f"Length of values ({len(values)}) does not match length of ScenarioList ({len(sl)})"
|
846
|
+
)
|
843
847
|
for i, value in enumerate(values):
|
844
848
|
sl[i][name] = value
|
845
849
|
return sl
|
846
850
|
|
851
|
+
@classmethod
def create_empty_scenario_list(cls, n: int) -> ScenarioList:
    """Create a ScenarioList holding ``n`` empty scenarios.

    The result is built with ``cls`` so that subclasses get an instance of
    their own type. A non-positive ``n`` yields an empty list.

    Example:

    >>> ScenarioList.create_empty_scenario_list(3)
    ScenarioList([Scenario({}), Scenario({}), Scenario({})])
    """
    return cls([Scenario({}) for _ in range(n)])
|
861
|
+
|
847
862
|
def add_value(self, name: str, value: Any) -> ScenarioList:
|
848
863
|
"""Add a value to all scenarios in a ScenarioList.
|
849
864
|
|
@@ -1222,7 +1237,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1222
1237
|
>>> s3 == ScenarioList([Scenario({'age': 30, 'location': 'New York', 'name': 'Alice'}), Scenario({'age': 25, 'location': None, 'name': 'Bob'})])
|
1223
1238
|
True
|
1224
1239
|
"""
|
1225
|
-
from edsl.scenarios.
|
1240
|
+
from edsl.scenarios.scenario_join import ScenarioJoin
|
1226
1241
|
|
1227
1242
|
sj = ScenarioJoin(self, other)
|
1228
1243
|
return sj.left_join(by)
|
@@ -1244,6 +1259,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1244
1259
|
else:
|
1245
1260
|
data = self
|
1246
1261
|
d = {"scenarios": [s.to_dict(add_edsl_version=add_edsl_version) for s in data]}
|
1262
|
+
|
1247
1263
|
if add_edsl_version:
|
1248
1264
|
from edsl import __version__
|
1249
1265
|
|
@@ -1296,10 +1312,22 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1296
1312
|
|
1297
1313
|
@classmethod
def from_nested_dict(cls, data: dict) -> ScenarioList:
    """Create a `ScenarioList` from a dictionary of equal-length lists.

    Each key becomes a scenario field, and the i-th element of every list goes
    into the i-th scenario. An empty dictionary yields an empty ScenarioList.

    >>> data = {"headline": ["Armistice Signed, War Over: Celebrations Erupt Across City"], "date": ["1918-11-11"], "author": ["Jane Smith"]}
    >>> ScenarioList.from_nested_dict(data)
    ScenarioList([Scenario({'headline': 'Armistice Signed, War Over: Celebrations Erupt Across City', 'date': '1918-11-11', 'author': 'Jane Smith'})])

    :raises ValueError: If the lists are not all the same length.
    """
    if not data:
        # next(iter(...)) below would otherwise leak StopIteration.
        return cls.create_empty_scenario_list(n=0)

    length_of_first_list = len(next(iter(data.values())))
    if any(len(v) != length_of_first_list for v in data.values()):
        raise ValueError(
            "All lists in the dictionary must be of the same length.",
        )

    s = cls.create_empty_scenario_list(n=length_of_first_list)
    for key, list_of_values in data.items():
        s = s.add_list(key, list_of_values)
    return s
|
1304
1332
|
|
1305
1333
|
def code(self) -> str:
|