edsl 0.1.45__py3-none-any.whl → 0.1.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +7 -3
- edsl/__version__.py +1 -1
- edsl/agents/PromptConstructor.py +26 -79
- edsl/agents/QuestionInstructionPromptBuilder.py +70 -32
- edsl/agents/QuestionTemplateReplacementsBuilder.py +12 -2
- edsl/coop/coop.py +155 -94
- edsl/data/RemoteCacheSync.py +10 -9
- edsl/inference_services/AvailableModelFetcher.py +1 -1
- edsl/jobs/AnswerQuestionFunctionConstructor.py +12 -1
- edsl/jobs/Jobs.py +15 -17
- edsl/jobs/JobsPrompts.py +49 -26
- edsl/jobs/JobsRemoteInferenceHandler.py +4 -5
- edsl/jobs/data_structures.py +3 -0
- edsl/jobs/interviews/Interview.py +6 -3
- edsl/language_models/LanguageModel.py +6 -0
- edsl/questions/question_base_gen_mixin.py +2 -0
- edsl/results/DatasetExportMixin.py +25 -4
- edsl/scenarios/ScenarioList.py +153 -21
- {edsl-0.1.45.dist-info → edsl-0.1.46.dist-info}/METADATA +2 -2
- {edsl-0.1.45.dist-info → edsl-0.1.46.dist-info}/RECORD +22 -22
- {edsl-0.1.45.dist-info → edsl-0.1.46.dist-info}/LICENSE +0 -0
- {edsl-0.1.45.dist-info → edsl-0.1.46.dist-info}/WHEEL +0 -0
edsl/jobs/JobsPrompts.py
CHANGED
@@ -18,6 +18,7 @@ from edsl.data.CacheEntry import CacheEntry
|
|
18
18
|
|
19
19
|
logger = logging.getLogger(__name__)
|
20
20
|
|
21
|
+
|
21
22
|
class JobsPrompts:
|
22
23
|
def __init__(self, jobs: "Jobs"):
|
23
24
|
self.interviews = jobs.interviews()
|
@@ -26,7 +27,9 @@ class JobsPrompts:
|
|
26
27
|
self.survey = jobs.survey
|
27
28
|
self._price_lookup = None
|
28
29
|
self._agent_lookup = {agent: idx for idx, agent in enumerate(self.agents)}
|
29
|
-
self._scenario_lookup = {
|
30
|
+
self._scenario_lookup = {
|
31
|
+
scenario: idx for idx, scenario in enumerate(self.scenarios)
|
32
|
+
}
|
30
33
|
|
31
34
|
@property
|
32
35
|
def price_lookup(self):
|
@@ -37,7 +40,7 @@ class JobsPrompts:
|
|
37
40
|
self._price_lookup = c.fetch_prices()
|
38
41
|
return self._price_lookup
|
39
42
|
|
40
|
-
def prompts(self) -> "Dataset":
|
43
|
+
def prompts(self, iterations=1) -> "Dataset":
|
41
44
|
"""Return a Dataset of prompts that will be used.
|
42
45
|
|
43
46
|
>>> from edsl.jobs import Jobs
|
@@ -54,11 +57,11 @@ class JobsPrompts:
|
|
54
57
|
models = []
|
55
58
|
costs = []
|
56
59
|
cache_keys = []
|
57
|
-
|
60
|
+
|
58
61
|
for interview_index, interview in enumerate(interviews):
|
59
62
|
logger.info(f"Processing interview {interview_index} of {len(interviews)}")
|
60
63
|
interview_start = time.time()
|
61
|
-
|
64
|
+
|
62
65
|
# Fetch invigilators timing
|
63
66
|
invig_start = time.time()
|
64
67
|
invigilators = [
|
@@ -66,8 +69,10 @@ class JobsPrompts:
|
|
66
69
|
for question in interview.survey.questions
|
67
70
|
]
|
68
71
|
invig_end = time.time()
|
69
|
-
logger.debug(
|
70
|
-
|
72
|
+
logger.debug(
|
73
|
+
f"Time taken to fetch invigilators: {invig_end - invig_start:.4f}s"
|
74
|
+
)
|
75
|
+
|
71
76
|
# Process prompts timing
|
72
77
|
prompts_start = time.time()
|
73
78
|
for _, invigilator in enumerate(invigilators):
|
@@ -75,13 +80,15 @@ class JobsPrompts:
|
|
75
80
|
get_prompts_start = time.time()
|
76
81
|
prompts = invigilator.get_prompts()
|
77
82
|
get_prompts_end = time.time()
|
78
|
-
logger.debug(
|
79
|
-
|
83
|
+
logger.debug(
|
84
|
+
f"Time taken to get prompts: {get_prompts_end - get_prompts_start:.4f}s"
|
85
|
+
)
|
86
|
+
|
80
87
|
user_prompt = prompts["user_prompt"]
|
81
88
|
system_prompt = prompts["system_prompt"]
|
82
89
|
user_prompts.append(user_prompt)
|
83
90
|
system_prompts.append(system_prompt)
|
84
|
-
|
91
|
+
|
85
92
|
# Index lookups timing
|
86
93
|
index_start = time.time()
|
87
94
|
agent_index = self._agent_lookup[invigilator.agent]
|
@@ -90,14 +97,18 @@ class JobsPrompts:
|
|
90
97
|
scenario_index = self._scenario_lookup[invigilator.scenario]
|
91
98
|
scenario_indices.append(scenario_index)
|
92
99
|
index_end = time.time()
|
93
|
-
logger.debug(
|
94
|
-
|
100
|
+
logger.debug(
|
101
|
+
f"Time taken for index lookups: {index_end - index_start:.4f}s"
|
102
|
+
)
|
103
|
+
|
95
104
|
# Model and question name assignment timing
|
96
105
|
assign_start = time.time()
|
97
106
|
models.append(invigilator.model.model)
|
98
107
|
question_names.append(invigilator.question.question_name)
|
99
108
|
assign_end = time.time()
|
100
|
-
logger.debug(
|
109
|
+
logger.debug(
|
110
|
+
f"Time taken for assignments: {assign_end - assign_start:.4f}s"
|
111
|
+
)
|
101
112
|
|
102
113
|
# Cost estimation timing
|
103
114
|
cost_start = time.time()
|
@@ -109,32 +120,44 @@ class JobsPrompts:
|
|
109
120
|
model=invigilator.model.model,
|
110
121
|
)
|
111
122
|
cost_end = time.time()
|
112
|
-
logger.debug(
|
123
|
+
logger.debug(
|
124
|
+
f"Time taken to estimate prompt cost: {cost_end - cost_start:.4f}s"
|
125
|
+
)
|
113
126
|
costs.append(prompt_cost["cost_usd"])
|
114
127
|
|
115
128
|
# Cache key generation timing
|
116
129
|
cache_key_gen_start = time.time()
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
130
|
+
for iteration in range(iterations):
|
131
|
+
cache_key = CacheEntry.gen_key(
|
132
|
+
model=invigilator.model.model,
|
133
|
+
parameters=invigilator.model.parameters,
|
134
|
+
system_prompt=system_prompt,
|
135
|
+
user_prompt=user_prompt,
|
136
|
+
iteration=iteration,
|
137
|
+
)
|
138
|
+
cache_keys.append(cache_key)
|
139
|
+
|
124
140
|
cache_key_gen_end = time.time()
|
125
|
-
|
126
|
-
|
141
|
+
logger.debug(
|
142
|
+
f"Time taken to generate cache key: {cache_key_gen_end - cache_key_gen_start:.4f}s"
|
143
|
+
)
|
127
144
|
logger.debug("-" * 50) # Separator between iterations
|
128
145
|
|
129
146
|
prompts_end = time.time()
|
130
|
-
logger.info(
|
131
|
-
|
147
|
+
logger.info(
|
148
|
+
f"Time taken to process prompts: {prompts_end - prompts_start:.4f}s"
|
149
|
+
)
|
150
|
+
|
132
151
|
interview_end = time.time()
|
133
|
-
logger.info(
|
152
|
+
logger.info(
|
153
|
+
f"Overall time taken for interview: {interview_end - interview_start:.4f}s"
|
154
|
+
)
|
134
155
|
logger.info("Time breakdown:")
|
135
156
|
logger.info(f" Invigilators: {invig_end - invig_start:.4f}s")
|
136
157
|
logger.info(f" Prompts processing: {prompts_end - prompts_start:.4f}s")
|
137
|
-
logger.info(
|
158
|
+
logger.info(
|
159
|
+
f" Other overhead: {(interview_end - interview_start) - ((invig_end - invig_start) + (prompts_end - prompts_start)):.4f}s"
|
160
|
+
)
|
138
161
|
|
139
162
|
d = Dataset(
|
140
163
|
[
|
@@ -24,7 +24,7 @@ from edsl.jobs.JobsRemoteInferenceLogger import JobLogger
|
|
24
24
|
class RemoteJobConstants:
|
25
25
|
"""Constants for remote job handling."""
|
26
26
|
|
27
|
-
REMOTE_JOB_POLL_INTERVAL =
|
27
|
+
REMOTE_JOB_POLL_INTERVAL = 4
|
28
28
|
REMOTE_JOB_VERBOSE = False
|
29
29
|
DISCORD_URL = "https://discord.com/invite/mxAYkjfy9m"
|
30
30
|
|
@@ -88,8 +88,8 @@ class JobsRemoteInferenceHandler:
|
|
88
88
|
iterations: int = 1,
|
89
89
|
remote_inference_description: Optional[str] = None,
|
90
90
|
remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
|
91
|
+
fresh: Optional[bool] = False,
|
91
92
|
) -> RemoteJobInfo:
|
92
|
-
|
93
93
|
from edsl.config import CONFIG
|
94
94
|
from edsl.coop.coop import Coop
|
95
95
|
|
@@ -106,6 +106,7 @@ class JobsRemoteInferenceHandler:
|
|
106
106
|
status="queued",
|
107
107
|
iterations=iterations,
|
108
108
|
initial_results_visibility=remote_inference_results_visibility,
|
109
|
+
fresh=fresh,
|
109
110
|
)
|
110
111
|
logger.update(
|
111
112
|
"Your survey is running at the Expected Parrot server...",
|
@@ -277,9 +278,7 @@ class JobsRemoteInferenceHandler:
|
|
277
278
|
job_in_queue = True
|
278
279
|
while job_in_queue:
|
279
280
|
result = self._attempt_fetch_job(
|
280
|
-
job_info,
|
281
|
-
remote_job_data_fetcher,
|
282
|
-
object_fetcher
|
281
|
+
job_info, remote_job_data_fetcher, object_fetcher
|
283
282
|
)
|
284
283
|
if result != "continue":
|
285
284
|
return result
|
edsl/jobs/data_structures.py
CHANGED
@@ -36,6 +36,9 @@ class RunParameters(Base):
|
|
36
36
|
disable_remote_cache: bool = False
|
37
37
|
disable_remote_inference: bool = False
|
38
38
|
job_uuid: Optional[str] = None
|
39
|
+
fresh: Optional[
|
40
|
+
bool
|
41
|
+
] = False # if True, will not use cache and will save new results to cache
|
39
42
|
|
40
43
|
def to_dict(self, add_edsl_version=False) -> dict:
|
41
44
|
d = asdict(self)
|
@@ -238,9 +238,6 @@ class Interview:
|
|
238
238
|
>>> run_config = RunConfig(parameters = RunParameters(), environment = RunEnvironment())
|
239
239
|
>>> run_config.parameters.stop_on_exception = True
|
240
240
|
>>> result, _ = asyncio.run(i.async_conduct_interview(run_config))
|
241
|
-
Traceback (most recent call last):
|
242
|
-
...
|
243
|
-
asyncio.exceptions.CancelledError
|
244
241
|
"""
|
245
242
|
from edsl.jobs.Jobs import RunConfig, RunParameters, RunEnvironment
|
246
243
|
|
@@ -262,6 +259,8 @@ class Interview:
|
|
262
259
|
if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
|
263
260
|
model_buckets = ModelBuckets.infinity_bucket()
|
264
261
|
|
262
|
+
self.skip_flags = {q.question_name: False for q in self.survey.questions}
|
263
|
+
|
265
264
|
# was "self.tasks" - is that necessary?
|
266
265
|
self.tasks = self.task_manager.build_question_tasks(
|
267
266
|
answer_func=AnswerQuestionFunctionConstructor(
|
@@ -310,6 +309,10 @@ class Interview:
|
|
310
309
|
def handle_task(task, invigilator):
|
311
310
|
try:
|
312
311
|
result: Answers = task.result()
|
312
|
+
if result == "skipped":
|
313
|
+
result = invigilator.get_failed_task_result(
|
314
|
+
failure_reason="Task was skipped."
|
315
|
+
)
|
313
316
|
except asyncio.CancelledError as e: # task was cancelled
|
314
317
|
result = invigilator.get_failed_task_result(
|
315
318
|
failure_reason="Task was cancelled."
|
@@ -379,8 +379,10 @@ class LanguageModel(
|
|
379
379
|
cached_response, cache_key = cache.fetch(**cache_call_params)
|
380
380
|
|
381
381
|
if cache_used := cached_response is not None:
|
382
|
+
# print("cache used")
|
382
383
|
response = json.loads(cached_response)
|
383
384
|
else:
|
385
|
+
# print("cache not used")
|
384
386
|
f = (
|
385
387
|
self.remote_async_execute_model_call
|
386
388
|
if hasattr(self, "remote") and self.remote
|
@@ -400,7 +402,10 @@ class LanguageModel(
|
|
400
402
|
) # store the response in the cache
|
401
403
|
assert new_cache_key == cache_key # should be the same
|
402
404
|
|
405
|
+
#breakpoint()
|
406
|
+
|
403
407
|
cost = self.cost(response)
|
408
|
+
#breakpoint()
|
404
409
|
return ModelResponse(
|
405
410
|
response=response,
|
406
411
|
cache_used=cache_used,
|
@@ -465,6 +470,7 @@ class LanguageModel(
|
|
465
470
|
model_outputs=model_outputs,
|
466
471
|
edsl_dict=edsl_dict,
|
467
472
|
)
|
473
|
+
#breakpoint()
|
468
474
|
return agent_response_dict
|
469
475
|
|
470
476
|
get_response = sync_wrapper(async_get_response)
|
@@ -140,6 +140,8 @@ class QuestionBaseGenMixin:
|
|
140
140
|
k: v for k, v in replacement_dict.items() if not isinstance(v, Scenario)
|
141
141
|
}
|
142
142
|
|
143
|
+
strings_only_replacement_dict['scenario'] = strings_only_replacement_dict
|
144
|
+
|
143
145
|
def _has_unrendered_variables(template_str: str, env: Environment) -> bool:
|
144
146
|
"""Check if the template string has any unrendered variables."""
|
145
147
|
if not isinstance(template_str, str):
|
@@ -735,11 +735,14 @@ class DatasetExportMixin:
|
|
735
735
|
"""
|
736
736
|
Flatten a field containing a list of dictionaries into separate fields.
|
737
737
|
|
738
|
-
|
739
|
-
[{'
|
738
|
+
>>> from edsl.results.Dataset import Dataset
|
739
|
+
>>> Dataset([{'a': [{'a': 1, 'b': 2}]}, {'c': [5] }]).flatten('a')
|
740
|
+
Dataset([{'c': [5]}, {'a.a': [1]}, {'a.b': [2]}])
|
741
|
+
|
742
|
+
|
743
|
+
>>> Dataset([{'answer.example': [{'a': 1, 'b': 2}]}, {'c': [5] }]).flatten('answer.example')
|
744
|
+
Dataset([{'c': [5]}, {'answer.example.a': [1]}, {'answer.example.b': [2]}])
|
740
745
|
|
741
|
-
After d.flatten('data'), it should become:
|
742
|
-
[{'other': ['x', 'y'], 'data.a': [1, None], 'data.b': [None, 2]}]
|
743
746
|
|
744
747
|
Args:
|
745
748
|
field: The field to flatten
|
@@ -753,6 +756,24 @@ class DatasetExportMixin:
|
|
753
756
|
# Ensure the dataset isn't empty
|
754
757
|
if not self.data:
|
755
758
|
return self.copy()
|
759
|
+
|
760
|
+
# Find all columns that contain the field
|
761
|
+
matching_entries = []
|
762
|
+
for entry in self.data:
|
763
|
+
col_name = next(iter(entry.keys()))
|
764
|
+
if field == col_name or (
|
765
|
+
'.' in col_name and
|
766
|
+
(col_name.endswith('.' + field) or col_name.startswith(field + '.'))
|
767
|
+
):
|
768
|
+
matching_entries.append(entry)
|
769
|
+
|
770
|
+
# Check if the field is ambiguous
|
771
|
+
if len(matching_entries) > 1:
|
772
|
+
matching_cols = [next(iter(entry.keys())) for entry in matching_entries]
|
773
|
+
raise ValueError(
|
774
|
+
f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
|
775
|
+
f"Please specify the full column name to flatten."
|
776
|
+
)
|
756
777
|
|
757
778
|
# Get the number of observations
|
758
779
|
num_observations = self.num_observations()
|
edsl/scenarios/ScenarioList.py
CHANGED
@@ -436,35 +436,98 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
436
436
|
new_scenarios.append(new_scenario)
|
437
437
|
return ScenarioList(new_scenarios)
|
438
438
|
|
439
|
-
def
|
440
|
-
"""
|
441
|
-
|
439
|
+
def _concatenate(self, fields: List[str], output_type: str = "string", separator: str = ";") -> ScenarioList:
|
440
|
+
"""Private method to handle concatenation logic for different output types.
|
441
|
+
|
442
442
|
:param fields: The fields to concatenate.
|
443
|
-
:param
|
444
|
-
|
443
|
+
:param output_type: The type of output ("string", "list", or "set").
|
444
|
+
:param separator: The separator to use for string concatenation.
|
445
|
+
|
445
446
|
Returns:
|
446
447
|
ScenarioList: A new ScenarioList with concatenated fields.
|
447
|
-
|
448
|
-
Example:
|
449
|
-
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2, 'c': 3}), Scenario({'a': 4, 'b': 5, 'c': 6})])
|
450
|
-
>>> s.concatenate(['a', 'b', 'c'])
|
451
|
-
ScenarioList([Scenario({'concat_a_b_c': '1;2;3'}), Scenario({'concat_a_b_c': '4;5;6'})])
|
452
448
|
"""
|
449
|
+
# Check if fields is a string and raise an exception
|
450
|
+
if isinstance(fields, str):
|
451
|
+
raise ScenarioError(
|
452
|
+
f"The 'fields' parameter must be a list of field names, not a string. Got '{fields}'."
|
453
|
+
)
|
454
|
+
|
453
455
|
new_scenarios = []
|
454
456
|
for scenario in self:
|
455
457
|
new_scenario = scenario.copy()
|
456
|
-
|
458
|
+
values = []
|
457
459
|
for field in fields:
|
458
460
|
if field in new_scenario:
|
459
|
-
|
461
|
+
values.append(new_scenario[field])
|
460
462
|
del new_scenario[field]
|
461
463
|
|
462
464
|
new_field_name = f"concat_{'_'.join(fields)}"
|
463
|
-
|
465
|
+
|
466
|
+
if output_type == "string":
|
467
|
+
# Convert all values to strings and join with separator
|
468
|
+
new_scenario[new_field_name] = separator.join(str(v) for v in values)
|
469
|
+
elif output_type == "list":
|
470
|
+
# Keep as a list
|
471
|
+
new_scenario[new_field_name] = values
|
472
|
+
elif output_type == "set":
|
473
|
+
# Convert to a set (removes duplicates)
|
474
|
+
new_scenario[new_field_name] = set(values)
|
475
|
+
else:
|
476
|
+
raise ValueError(f"Invalid output_type: {output_type}. Must be 'string', 'list', or 'set'.")
|
477
|
+
|
464
478
|
new_scenarios.append(new_scenario)
|
465
479
|
|
466
480
|
return ScenarioList(new_scenarios)
|
467
481
|
|
482
|
+
def concatenate(self, fields: List[str], separator: str = ";") -> ScenarioList:
|
483
|
+
"""Concatenate specified fields into a single string field.
|
484
|
+
|
485
|
+
:param fields: The fields to concatenate.
|
486
|
+
:param separator: The separator to use.
|
487
|
+
|
488
|
+
Returns:
|
489
|
+
ScenarioList: A new ScenarioList with concatenated fields.
|
490
|
+
|
491
|
+
Example:
|
492
|
+
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2, 'c': 3}), Scenario({'a': 4, 'b': 5, 'c': 6})])
|
493
|
+
>>> s.concatenate(['a', 'b', 'c'])
|
494
|
+
ScenarioList([Scenario({'concat_a_b_c': '1;2;3'}), Scenario({'concat_a_b_c': '4;5;6'})])
|
495
|
+
"""
|
496
|
+
return self._concatenate(fields, output_type="string", separator=separator)
|
497
|
+
|
498
|
+
def concatenate_to_list(self, fields: List[str]) -> ScenarioList:
|
499
|
+
"""Concatenate specified fields into a single list field.
|
500
|
+
|
501
|
+
:param fields: The fields to concatenate.
|
502
|
+
|
503
|
+
Returns:
|
504
|
+
ScenarioList: A new ScenarioList with fields concatenated into a list.
|
505
|
+
|
506
|
+
Example:
|
507
|
+
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2, 'c': 3}), Scenario({'a': 4, 'b': 5, 'c': 6})])
|
508
|
+
>>> s.concatenate_to_list(['a', 'b', 'c'])
|
509
|
+
ScenarioList([Scenario({'concat_a_b_c': [1, 2, 3]}), Scenario({'concat_a_b_c': [4, 5, 6]})])
|
510
|
+
"""
|
511
|
+
return self._concatenate(fields, output_type="list")
|
512
|
+
|
513
|
+
def concatenate_to_set(self, fields: List[str]) -> ScenarioList:
|
514
|
+
"""Concatenate specified fields into a single set field.
|
515
|
+
|
516
|
+
:param fields: The fields to concatenate.
|
517
|
+
|
518
|
+
Returns:
|
519
|
+
ScenarioList: A new ScenarioList with fields concatenated into a set.
|
520
|
+
|
521
|
+
Example:
|
522
|
+
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2, 'c': 3}), Scenario({'a': 4, 'b': 5, 'c': 6})])
|
523
|
+
>>> s.concatenate_to_set(['a', 'b', 'c'])
|
524
|
+
ScenarioList([Scenario({'concat_a_b_c': {1, 2, 3}}), Scenario({'concat_a_b_c': {4, 5, 6}})])
|
525
|
+
>>> s = ScenarioList([Scenario({'a': 1, 'b': 1, 'c': 3})])
|
526
|
+
>>> s.concatenate_to_set(['a', 'b', 'c'])
|
527
|
+
ScenarioList([Scenario({'concat_a_b_c': {1, 3}})])
|
528
|
+
"""
|
529
|
+
return self._concatenate(fields, output_type="set")
|
530
|
+
|
468
531
|
def unpack_dict(
|
469
532
|
self, field: str, prefix: Optional[str] = None, drop_field: bool = False
|
470
533
|
) -> ScenarioList:
|
@@ -937,16 +1000,42 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
937
1000
|
# return new_list
|
938
1001
|
|
939
1002
|
@classmethod
|
940
|
-
def from_sqlite(cls, filepath: str, table: str):
|
941
|
-
"""Create a ScenarioList from a SQLite database.
|
1003
|
+
def from_sqlite(cls, filepath: str, table: Optional[str] = None, sql_query: Optional[str] = None):
|
1004
|
+
"""Create a ScenarioList from a SQLite database.
|
1005
|
+
|
1006
|
+
Args:
|
1007
|
+
filepath (str): Path to the SQLite database file
|
1008
|
+
table (Optional[str]): Name of table to query. If None, sql_query must be provided.
|
1009
|
+
sql_query (Optional[str]): SQL query to execute. Used if table is None.
|
1010
|
+
|
1011
|
+
Returns:
|
1012
|
+
ScenarioList: List of scenarios created from database rows
|
1013
|
+
|
1014
|
+
Raises:
|
1015
|
+
ValueError: If both table and sql_query are None
|
1016
|
+
sqlite3.Error: If there is an error executing the database query
|
1017
|
+
"""
|
942
1018
|
import sqlite3
|
943
1019
|
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
|
1020
|
+
if table is None and sql_query is None:
|
1021
|
+
raise ValueError("Either table or sql_query must be provided")
|
1022
|
+
|
1023
|
+
try:
|
1024
|
+
with sqlite3.connect(filepath) as conn:
|
1025
|
+
cursor = conn.cursor()
|
1026
|
+
|
1027
|
+
if table is not None:
|
1028
|
+
cursor.execute(f"SELECT * FROM {table}")
|
1029
|
+
else:
|
1030
|
+
cursor.execute(sql_query)
|
1031
|
+
|
1032
|
+
columns = [description[0] for description in cursor.description]
|
1033
|
+
data = cursor.fetchall()
|
1034
|
+
|
1035
|
+
return cls([Scenario(dict(zip(columns, row))) for row in data])
|
1036
|
+
|
1037
|
+
except sqlite3.Error as e:
|
1038
|
+
raise sqlite3.Error(f"Database error occurred: {str(e)}")
|
950
1039
|
|
951
1040
|
@classmethod
|
952
1041
|
def from_latex(cls, tex_file_path: str):
|
@@ -1540,6 +1629,49 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1540
1629
|
new_scenarios.extend(replacement_scenarios)
|
1541
1630
|
return ScenarioList(new_scenarios)
|
1542
1631
|
|
1632
|
+
def collapse(self, field: str) -> ScenarioList:
|
1633
|
+
"""Collapse a ScenarioList by grouping on all fields except the specified one,
|
1634
|
+
collecting the values of the specified field into a list.
|
1635
|
+
|
1636
|
+
Args:
|
1637
|
+
field: The field to collapse (whose values will be collected into lists)
|
1638
|
+
|
1639
|
+
Returns:
|
1640
|
+
ScenarioList: A new ScenarioList with the specified field collapsed into lists
|
1641
|
+
|
1642
|
+
Example:
|
1643
|
+
>>> s = ScenarioList([
|
1644
|
+
... Scenario({'category': 'fruit', 'color': 'red', 'item': 'apple'}),
|
1645
|
+
... Scenario({'category': 'fruit', 'color': 'yellow', 'item': 'banana'}),
|
1646
|
+
... Scenario({'category': 'fruit', 'color': 'red', 'item': 'cherry'}),
|
1647
|
+
... Scenario({'category': 'vegetable', 'color': 'green', 'item': 'spinach'})
|
1648
|
+
... ])
|
1649
|
+
>>> s.collapse('item')
|
1650
|
+
ScenarioList([Scenario({'category': 'fruit', 'color': 'red', 'item': ['apple', 'cherry']}), Scenario({'category': 'fruit', 'color': 'yellow', 'item': ['banana']}), Scenario({'category': 'vegetable', 'color': 'green', 'item': ['spinach']})])
|
1651
|
+
"""
|
1652
|
+
if not self:
|
1653
|
+
return ScenarioList([])
|
1654
|
+
|
1655
|
+
# Determine all fields except the one to collapse
|
1656
|
+
id_vars = [key for key in self[0].keys() if key != field]
|
1657
|
+
|
1658
|
+
# Group the scenarios
|
1659
|
+
grouped = defaultdict(list)
|
1660
|
+
for scenario in self:
|
1661
|
+
# Create a tuple of the values of all fields except the one to collapse
|
1662
|
+
key = tuple(scenario[id_var] for id_var in id_vars)
|
1663
|
+
# Add the value of the field to collapse to the list for this key
|
1664
|
+
grouped[key].append(scenario[field])
|
1665
|
+
|
1666
|
+
# Create a new ScenarioList with the collapsed field
|
1667
|
+
result = []
|
1668
|
+
for key, values in grouped.items():
|
1669
|
+
new_scenario = dict(zip(id_vars, key))
|
1670
|
+
new_scenario[field] = values
|
1671
|
+
result.append(Scenario(new_scenario))
|
1672
|
+
|
1673
|
+
return ScenarioList(result)
|
1674
|
+
|
1543
1675
|
|
1544
1676
|
if __name__ == "__main__":
|
1545
1677
|
import doctest
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: edsl
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.46
|
4
4
|
Summary: Create and analyze LLM-based surveys
|
5
5
|
Home-page: https://www.expectedparrot.com/
|
6
6
|
License: MIT
|
@@ -242,5 +242,5 @@ An integrated platform for running experiments, sharing workflows and launching
|
|
242
242
|
- <a href="https://blog.expectedparrot.com" target="_blank" rel="noopener noreferrer">Blog</a>
|
243
243
|
|
244
244
|
## Contact
|
245
|
-
- <a href="mailto:info@expectedparrot.com" target="_blank" rel="noopener noreferrer">Email</a
|
245
|
+
- <a href="mailto:info@expectedparrot.com" target="_blank" rel="noopener noreferrer">Email</a>
|
246
246
|
|