edsl 0.1.36.dev5__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +92 -41
- edsl/agents/AgentList.py +15 -2
- edsl/agents/InvigilatorBase.py +15 -25
- edsl/agents/PromptConstructor.py +149 -108
- edsl/agents/descriptors.py +17 -4
- edsl/conjure/AgentConstructionMixin.py +11 -3
- edsl/conversation/Conversation.py +66 -14
- edsl/conversation/chips.py +95 -0
- edsl/coop/coop.py +148 -39
- edsl/data/Cache.py +1 -1
- edsl/data/RemoteCacheSync.py +25 -12
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +7 -3
- edsl/exceptions/agents.py +17 -19
- edsl/exceptions/results.py +11 -8
- edsl/exceptions/scenarios.py +22 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AwsBedrock.py +7 -2
- edsl/inference_services/InferenceServicesCollection.py +42 -13
- edsl/inference_services/models_available_cache.py +25 -1
- edsl/jobs/Jobs.py +306 -71
- edsl/jobs/interviews/Interview.py +24 -14
- edsl/jobs/interviews/InterviewExceptionCollection.py +1 -1
- edsl/jobs/interviews/InterviewExceptionEntry.py +17 -13
- edsl/jobs/interviews/ReportErrors.py +2 -2
- edsl/jobs/runners/JobsRunnerAsyncio.py +10 -9
- edsl/jobs/tasks/TaskHistory.py +1 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +47 -59
- edsl/language_models/__init__.py +1 -0
- edsl/prompts/Prompt.py +11 -12
- edsl/questions/QuestionBase.py +53 -13
- edsl/questions/QuestionBasePromptsMixin.py +1 -33
- edsl/questions/QuestionFreeText.py +1 -0
- edsl/questions/QuestionFunctional.py +2 -2
- edsl/questions/descriptors.py +23 -28
- edsl/results/DatasetExportMixin.py +25 -1
- edsl/results/Result.py +27 -10
- edsl/results/Results.py +34 -121
- edsl/results/ResultsDBMixin.py +1 -1
- edsl/results/Selector.py +18 -1
- edsl/scenarios/FileStore.py +20 -5
- edsl/scenarios/Scenario.py +52 -13
- edsl/scenarios/ScenarioHtmlMixin.py +7 -2
- edsl/scenarios/ScenarioList.py +12 -1
- edsl/scenarios/__init__.py +2 -0
- edsl/surveys/Rule.py +10 -4
- edsl/surveys/Survey.py +100 -77
- edsl/utilities/utilities.py +18 -0
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/METADATA +1 -1
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/RECORD +55 -51
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/LICENSE +0 -0
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/WHEEL +0 -0
edsl/questions/descriptors.py
CHANGED
```diff
@@ -53,33 +53,12 @@ class BaseDescriptor(ABC):
     def __set__(self, instance, value: Any) -> None:
         """Set the value of the attribute."""
-        self.validate(value, instance)
-
-
-
-
-
-        # if value is not None:
-        #     instance.__dict__[self.name] = instructions
-        #     instance.set_instructions = True
-        # else:
-        #     potential_prompt_classes = get_classes(
-        #         question_type=instance.question_type
-        #     )
-        #     if len(potential_prompt_classes) > 0:
-        #         instructions = potential_prompt_classes[0]().text
-        #         instance.__dict__[self.name] = instructions
-        #         instance.set_instructions = False
-        #     else:
-        #         if not hasattr(instance, "default_instructions"):
-        #             raise Exception(
-        #                 "No default instructions found and no matching prompts!"
-        #             )
-        #         instructions = instance.default_instructions
-        #         instance.__dict__[self.name] = instructions
-        #         instance.set_instructions = False
-
-        # instance.set_instructions = value != instance.default_instructions
+        new_value = self.validate(value, instance)
+
+        if new_value is not None:
+            instance.__dict__[self.name] = new_value
+        else:
+            instance.__dict__[self.name] = value
 
     def __set_name__(self, owner, name: str) -> None:
         """Set the name of the attribute."""
@@ -400,10 +379,24 @@ class QuestionTextDescriptor(BaseDescriptor):
         if contains_single_braced_substring(value):
             import warnings
 
+            # # warnings.warn(
+            # #     f"WARNING: Question text contains a single-braced substring: If you intended to parameterize the question with a Scenario this should be changed to a double-braced substring, e.g. {{variable}}.\nSee details on constructing Scenarios in the docs: https://docs.expectedparrot.com/en/latest/scenarios.html",
+            # #     UserWarning,
+            # # )
             warnings.warn(
-                f"WARNING: Question text contains a single-braced substring: If you intended to parameterize the question with a Scenario this should be changed to a double-braced substring, e.g. {{variable}}.\nSee details on constructing Scenarios in the docs: https://docs.expectedparrot.com/en/latest/scenarios.html",
+                "WARNING: Question text contains a single-braced substring. "
+                "If you intended to parameterize the question with a Scenario, this will "
+                "be changed to a double-braced substring, e.g. {{variable}}.\n"
+                "See details on constructing Scenarios in the docs: "
+                "https://docs.expectedparrot.com/en/latest/scenarios.html",
                 UserWarning,
             )
+            # Automatically replace single braces with double braces
+            # This is here because if the user is using an f-string, the double brace will get converted to a single brace.
+            # This undoes that.
+            value = re.sub(r"\{([^\{\}]+)\}", r"{{\1}}", value)
+            return value
+
 
         # iterate through all doubles braces and check if they are valid python identifiers
         for match in re.finditer(r"\{\{([^\{\}]+)\}\}", value):
             if " " in match.group(1).strip():
@@ -411,6 +404,8 @@ class QuestionTextDescriptor(BaseDescriptor):
                     f"Question text contains an invalid identifier: '{match.group(1)}'"
                 )
 
+        return None
+
 
 if __name__ == "__main__":
     import doctest
```
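Note on the change above: the descriptor contract now lets `validate` return a transformed value, which `__set__` stores, and `QuestionTextDescriptor` rewrites single-braced placeholders to double braces. The following is a minimal standalone sketch of that combined behavior, not edsl's actual classes: the single-brace check is approximated with a regex, and storing under `"_" + name` is an assumption of the sketch.

```python
import re
from typing import Any


class TextDescriptor:
    """Sketch of the new pattern: validate() may return a rewritten value."""

    def __set_name__(self, owner, name: str) -> None:
        self.name = "_" + name  # storage key is an assumption of this sketch

    def validate(self, value: str, instance) -> Any:
        # Approximation of contains_single_braced_substring() with a regex.
        if re.search(r"(?<!\{)\{[^\{\}]+\}(?!\})", value):
            # Same substitution as the diff: {var} -> {{var}}
            return re.sub(r"\{([^\{\}]+)\}", r"{{\1}}", value)
        return None  # None means "store the value as passed"

    def __set__(self, instance, value: str) -> None:
        new_value = self.validate(value, instance)
        instance.__dict__[self.name] = new_value if new_value is not None else value


class Question:
    question_text = TextDescriptor()


q = Question()
q.question_text = "How do you feel about {topic}?"
print(q.__dict__["_question_text"])  # How do you feel about {{topic}}?
```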
edsl/results/DatasetExportMixin.py
CHANGED
```diff
@@ -437,7 +437,30 @@ class DatasetExportMixin:
         b64 = base64.b64encode(csv_string.encode()).decode()
         return f'<a href="data:file/csv;base64,{b64}" download="my_data.csv">Download CSV file</a>'
 
-    def to_pandas(
+    def to_pandas(
+        self, remove_prefix: bool = False, lists_as_strings=False
+    ) -> "DataFrame":
+        """Convert the results to a pandas DataFrame, ensuring that lists remain as lists.
+
+        :param remove_prefix: Whether to remove the prefix from the column names.
+
+        """
+        return self._to_pandas_strings(remove_prefix)
+        # if lists_as_strings:
+        #     return self._to_pandas_strings(remove_prefix=remove_prefix)
+
+        # import pandas as pd
+
+        # df = pd.DataFrame(self.data)
+
+        # if remove_prefix:
+        #     # Optionally remove prefixes from column names
+        #     df.columns = [col.split(".")[-1] for col in df.columns]
+
+        # df_sorted = df.sort_index(axis=1) # Sort columns alphabetically
+        # return df_sorted
+
+    def _to_pandas_strings(self, remove_prefix: bool = False) -> "pd.DataFrame":
         """Convert the results to a pandas DataFrame.
 
         :param remove_prefix: Whether to remove the prefix from the column names.
@@ -451,6 +474,7 @@ class DatasetExportMixin:
         2 Terrible
         3 OK
         """
+
         import pandas as pd
 
         csv_string = self.to_csv(remove_prefix=remove_prefix)
```
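A usage sketch of the new signature, assuming edsl 0.1.37 is installed. `lists_as_strings` is accepted but, as the body above shows, the call currently just delegates to `_to_pandas_strings`; the printed column names are only an example.

```python
from edsl.results import Results

df = Results.example().to_pandas(remove_prefix=True, lists_as_strings=True)
print(df.columns.tolist())  # e.g. ['how_feeling', 'how_feeling_comment', ...]
```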
edsl/results/Result.py
CHANGED
```diff
@@ -117,6 +117,7 @@ class Result(Base, UserDict):
             "raw_model_response": raw_model_response or {},
             "question_to_attributes": question_to_attributes,
             "generated_tokens": generated_tokens or {},
+            "comments_dict": comments_dict or {},
         }
         super().__init__(**data)
         # but also store the data as attributes
@@ -155,15 +156,15 @@ class Result(Base, UserDict):
             if key in self.question_to_attributes:
                 # You might be tempted to just use the naked key
                 # but this is a bad idea because it pollutes the namespace
-                question_text_dict[
-                question_options_dict[
-                question_type_dict[
+                question_text_dict[
+                    key + "_question_text"
+                ] = self.question_to_attributes[key]["question_text"]
+                question_options_dict[
+                    key + "_question_options"
+                ] = self.question_to_attributes[key]["question_options"]
+                question_type_dict[
+                    key + "_question_type"
+                ] = self.question_to_attributes[key]["question_type"]
 
         return {
             "agent": self.agent.traits
@@ -256,10 +257,25 @@ class Result(Base, UserDict):
 
         """
         d = {}
-
+        problem_keys = []
+        data_types = sorted(self.sub_dicts.keys())
         for data_type in data_types:
             for key in self.sub_dicts[data_type]:
+                if key in d:
+                    import warnings
+
+                    warnings.warn(
+                        f"Key '{key}' of data type '{data_type}' is already in use. Renaming to {key}_{data_type}"
+                    )
+                    problem_keys.append((key, data_type))
+                    key = f"{key}_{data_type}"
+                    # raise ValueError(f"Key '{key}' is already in the dictionary")
                 d[key] = data_type
+
+        for key, data_type in problem_keys:
+            self.sub_dicts[data_type][f"{key}_{data_type}"] = self.sub_dicts[
+                data_type
+            ].pop(key)
         return d
 
     def rows(self, index) -> tuple[int, str, str, str]:
@@ -370,6 +386,7 @@ class Result(Base, UserDict):
             ),
             question_to_attributes=json_dict.get("question_to_attributes", None),
             generated_tokens=json_dict.get("generated_tokens", {}),
+            comments_dict=json_dict.get("comments_dict", {}),
         )
         return result
 
```
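The `key_to_data_type` change above renames colliding keys instead of silently overwriting them. A minimal standalone sketch of that logic, with plain dicts standing in for the real sub-dictionaries:

```python
import warnings


def key_to_data_type(sub_dicts: dict) -> dict:
    """Standalone sketch of the collision handling added above."""
    d = {}
    problem_keys = []
    for data_type in sorted(sub_dicts.keys()):
        for key in sub_dicts[data_type]:
            if key in d:
                warnings.warn(
                    f"Key '{key}' of data type '{data_type}' is already in use. "
                    f"Renaming to {key}_{data_type}"
                )
                problem_keys.append((key, data_type))
                key = f"{key}_{data_type}"
            d[key] = data_type
    # Rename the colliding keys inside their sub-dictionaries as well.
    for key, data_type in problem_keys:
        sub_dicts[data_type][f"{key}_{data_type}"] = sub_dicts[data_type].pop(key)
    return d


print(key_to_data_type({"answer": {"status": "OK"}, "model": {"status": "gpt-4"}}))
# {'status': 'answer', 'status_model': 'model'}
```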
edsl/results/Results.py
CHANGED
```diff
@@ -7,11 +7,17 @@ from __future__ import annotations
 import json
 import random
 from collections import UserList, defaultdict
-from typing import Optional, Callable, Any, Type, Union, List
+from typing import Optional, Callable, Any, Type, Union, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from edsl import Survey, Cache, AgentList, ModelList, ScenarioList
+    from edsl.results.Result import Result
+    from edsl.jobs.tasks.TaskHistory import TaskHistory
 
 from simpleeval import EvalWithCompoundTypes
 
 from edsl.exceptions.results import (
+    ResultsError,
     ResultsBadMutationstringError,
     ResultsColumnNotFoundError,
     ResultsInvalidNameError,
@@ -40,7 +46,7 @@ class Mixins(
     ResultsGGMixin,
     ResultsToolsMixin,
 ):
-    def print_long(self, max_rows=None) -> None:
+    def print_long(self, max_rows: int = None) -> None:
         """Print the results in long format.
 
         >>> from edsl.results import Results
@@ -84,13 +90,13 @@ class Results(UserList, Mixins, Base):
 
     def __init__(
         self,
-        survey: Optional[
-        data: Optional[list[
+        survey: Optional[Survey] = None,
+        data: Optional[list[Result]] = None,
         created_columns: Optional[list[str]] = None,
-        cache: Optional[
+        cache: Optional[Cache] = None,
         job_uuid: Optional[str] = None,
         total_results: Optional[int] = None,
-        task_history: Optional[
+        task_history: Optional[TaskHistory] = None,
     ):
         """Instantiate a `Results` object with a survey and a list of `Result` objects.
 
@@ -110,7 +116,7 @@ class Results(UserList, Mixins, Base):
         self._total_results = total_results
         self.cache = cache or Cache()
 
-        self.task_history = task_history or TaskHistory(interviews
+        self.task_history = task_history or TaskHistory(interviews=[])
 
         if hasattr(self, "_add_output_functions"):
             self._add_output_functions()
@@ -235,11 +241,11 @@ class Results(UserList, Mixins, Base):
         >>> r3 = r + r2
         """
         if self.survey != other.survey:
-            raise
-                "The surveys are not the same so
+            raise ResultsError(
+                "The surveys are not the same so the the results cannot be added together."
             )
         if self.created_columns != other.created_columns:
-            raise
+            raise ResultsError(
                 "The created columns are not the same so they cannot be added together."
             )
 
@@ -258,16 +264,7 @@ class Results(UserList, Mixins, Base):
         from IPython.display import HTML
 
         json_str = json.dumps(self.to_dict()["data"], indent=4)
-
-        from pygments.lexers import JsonLexer
-        from pygments.formatters import HtmlFormatter
-
-        formatted_json = highlight(
-            json_str,
-            JsonLexer(),
-            HtmlFormatter(style="default", full=True, noclasses=True),
-        )
-        return HTML(formatted_json).data
+        return f"<pre>{json_str}</pre>"
 
     def _to_dict(self, sort=False):
         from edsl.data.Cache import Cache
@@ -301,7 +298,7 @@ class Results(UserList, Mixins, Base):
             "b_not_a": [other_results[i] for i in indices_other],
         }
 
-    @property
+    @property
     def has_unfixed_exceptions(self):
         return self.task_history.has_unfixed_exceptions
 
@@ -326,7 +323,7 @@ class Results(UserList, Mixins, Base):
     def hashes(self) -> set:
         return set(hash(result) for result in self.data)
 
-    def sample(self, n: int) ->
+    def sample(self, n: int) -> Results:
         """Return a random sample of the results.
 
         :param n: The number of samples to return.
@@ -344,7 +341,7 @@ class Results(UserList, Mixins, Base):
                 indices = list(range(len(values)))
                 sampled_indices = random.sample(indices, n)
                 if n > len(indices):
-                    raise
+                    raise ResultsError(
                         f"Cannot sample {n} items from a list of length {len(indices)}."
                     )
                 entry[key] = [values[i] for i in sampled_indices]
@@ -397,11 +394,12 @@ class Results(UserList, Mixins, Base):
         - Uses the key_to_data_type property of the Result class.
         - Includes any columns that the user has created with `mutate`
         """
-        d = {}
+        d: dict = {}
         for result in self.data:
             d.update(result.key_to_data_type)
         for column in self.created_columns:
             d[column] = "answer"
+
         return d
 
     @property
@@ -451,7 +449,7 @@ class Results(UserList, Mixins, Base):
         from edsl.utilities.utilities import shorten_string
 
         if not self.survey:
-            raise
+            raise ResultsError("Survey is not defined so no answer keys are available.")
 
         answer_keys = self._data_type_to_keys["answer"]
         answer_keys = {k for k in answer_keys if "_comment" not in k}
@@ -464,7 +462,7 @@ class Results(UserList, Mixins, Base):
         return sorted_dict
 
     @property
-    def agents(self) ->
+    def agents(self) -> AgentList:
        """Return a list of all of the agents in the Results.
 
         Example:
@@ -478,7 +476,7 @@ class Results(UserList, Mixins, Base):
         return AgentList([r.agent for r in self.data])
 
     @property
-    def models(self) ->
+    def models(self) -> ModelList:
         """Return a list of all of the models in the Results.
 
         Example:
@@ -487,10 +485,12 @@ class Results(UserList, Mixins, Base):
         >>> r.models[0]
         Model(model_name = ...)
         """
-
+        from edsl import ModelList
+
+        return ModelList([r.model for r in self.data])
 
     @property
-    def scenarios(self) ->
+    def scenarios(self) -> ScenarioList:
         """Return a list of all of the scenarios in the Results.
 
         Example:
@@ -567,7 +567,7 @@ class Results(UserList, Mixins, Base):
         )
         return sorted(list(all_keys))
 
-    def first(self) ->
+    def first(self) -> Result:
         """Return the first observation in the results.
 
         Example:
@@ -817,7 +817,7 @@ class Results(UserList, Mixins, Base):
 
         return Results(survey=self.survey, data=new_data, created_columns=None)
 
-    def select(self, *columns: Union[str, list[str]]) ->
+    def select(self, *columns: Union[str, list[str]]) -> Results:
         """
         Select data from the results and format it.
 
@@ -830,93 +830,12 @@ class Results(UserList, Mixins, Base):
         Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
 
         >>> results.select('how_feeling', 'model', 'how_feeling')
-        Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'model.model': ['...', '...', '...', '...']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
+        Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'model.model': ['...', '...', '...', '...']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
 
         >>> from edsl import Results; r = Results.example(); r.select('answer.how_feeling_y')
         Dataset([{'answer.how_feeling_yesterday': ['Great', 'Good', 'OK', 'Terrible']}])
         """
 
-        # if len(self) == 0:
-        #     raise Exception("No data to select from---the Results object is empty.")
-
-        if not columns or columns == ("*",) or columns == (None,):
-            # is the users passes nothing, then we'll return all the columns
-            columns = ("*.*",)
-
-        if isinstance(columns[0], list):
-            columns = tuple(columns[0])
-
-        def get_data_types_to_return(parsed_data_type):
-            if parsed_data_type == "*":  # they want all of the columns
-                return self.known_data_types
-            else:
-                if parsed_data_type not in self.known_data_types:
-                    raise Exception(
-                        f"Data type {parsed_data_type} not found in data. Did you mean one of {self.known_data_types}"
-                    )
-                return [parsed_data_type]
-
-        # we're doing to populate this with the data we want to fetch
-        to_fetch = defaultdict(list)
-
-        new_data = []
-        items_in_order = []
-        # iterate through the passed columns
-        for column in columns:
-            # a user could pass 'result.how_feeling' or just 'how_feeling'
-            matches = self._matching_columns(column)
-            if len(matches) > 1:
-                raise Exception(
-                    f"Column '{column}' is ambiguous. Did you mean one of {matches}?"
-                )
-            if len(matches) == 0 and ".*" not in column:
-                raise Exception(f"Column '{column}' not found in data.")
-            if len(matches) == 1:
-                column = matches[0]
-
-            parsed_data_type, parsed_key = self._parse_column(column)
-            data_types = get_data_types_to_return(parsed_data_type)
-            found_once = False  # we need to track this to make sure we found the key at least once
-
-            for data_type in data_types:
-                # the keys for that data_type e.g.,# if data_type is 'answer', then the keys are 'how_feeling', 'how_feeling_comment', etc.
-                relevant_keys = self._data_type_to_keys[data_type]
-
-                for key in relevant_keys:
-                    if key == parsed_key or parsed_key == "*":
-                        found_once = True
-                        to_fetch[data_type].append(key)
-                        items_in_order.append(data_type + "." + key)
-
-            if not found_once:
-                raise Exception(f"Key {parsed_key} not found in data.")
-
-        for data_type in to_fetch:
-            for key in to_fetch[data_type]:
-                entries = self._fetch_list(data_type, key)
-                new_data.append({data_type + "." + key: entries})
-
-        def sort_by_key_order(dictionary):
-            # Extract the single key from the dictionary
-            single_key = next(iter(dictionary))
-            # Return the index of this key in the list_of_keys
-            return items_in_order.index(single_key)
-
-        # sorted(new_data, key=sort_by_key_order)
-        from edsl.results.Dataset import Dataset
-
-        sorted_new_data = []
-
-        # WORKS but slow
-        for key in items_in_order:
-            for d in new_data:
-                if key in d:
-                    sorted_new_data.append(d)
-                    break
-
-        return Dataset(sorted_new_data)
-
-    def select(self, *columns: Union[str, list[str]]) -> "Results":
         from edsl.results.Selector import Selector
 
         if len(self) == 0:
@@ -1026,6 +945,7 @@ class Results(UserList, Mixins, Base):
         Traceback (most recent call last):
         ...
         edsl.exceptions.results.ResultsFilterError: You must use '==' instead of '=' in the filter expression.
+        ...
 
         >>> r.filter("how_feeling == 'Great' or how_feeling == 'Terrible'").select('how_feeling').print()
         ┏━━━━━━━━━━━━━━┓
@@ -1103,6 +1023,7 @@ class Results(UserList, Mixins, Base):
             stop_on_exception=True,
             skip_retry=True,
             raise_validation_errors=True,
+            disable_remote_cache=True,
             disable_remote_inference=True,
         )
         return results
@@ -1110,14 +1031,6 @@ class Results(UserList, Mixins, Base):
     def rich_print(self):
         """Display an object as a table."""
         pass
-        # with io.StringIO() as buf:
-        #     console = Console(file=buf, record=True)
-
-        #     for index, result in enumerate(self):
-        #         console.print(f"Result {index}")
-        #         console.print(result.rich_print())
-
-        #     return console.export_text()
 
     def __str__(self):
         data = self.to_dict()["data"]
```
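The largest change above removes the old inline `select` implementation and routes `Results.select` through `edsl.results.Selector.Selector`; duplicate column requests are now preserved, per the updated doctest. A usage sketch, assuming edsl 0.1.37 is installed:

```python
from edsl.results import Results

r = Results.example()
ds = r.select("how_feeling")  # now delegated to Selector.select()
print(repr(ds))
# Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])  (per the doctest above)
```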
edsl/results/ResultsDBMixin.py
CHANGED
```diff
@@ -93,7 +93,7 @@ class ResultsDBMixin:
            from sqlalchemy import create_engine
 
            engine = create_engine("sqlite:///:memory:")
-            df = self.to_pandas(remove_prefix=remove_prefix)
+            df = self.to_pandas(remove_prefix=remove_prefix, lists_as_strings=True)
            df.to_sql("self", engine, index=False, if_exists="replace")
            return engine.connect()
        else:
```
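For context, the code path above loads the DataFrame into an in-memory SQLite table named `self` before queries run; the only change is that lists are now stringified via `to_pandas(..., lists_as_strings=True)`. A standalone sketch of that table-building step using only pandas and SQLAlchemy, with a made-up sample column:

```python
import pandas as pd
from sqlalchemy import create_engine

# Stand-in for the stringified results DataFrame.
df = pd.DataFrame({"how_feeling": ["OK", "Great", "Terrible", "OK"]})

engine = create_engine("sqlite:///:memory:")
df.to_sql("self", engine, index=False, if_exists="replace")

with engine.connect() as conn:
    out = pd.read_sql_query(
        'SELECT how_feeling, COUNT(*) AS n FROM "self" GROUP BY how_feeling', conn
    )
print(out)
```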
edsl/results/Selector.py
CHANGED
```diff
@@ -12,6 +12,7 @@ class Selector:
         fetch_list_func,
         columns: List[str],
     ):
+        """Selects columns from a Results object"""
         self.known_data_types = known_data_types
         self._data_type_to_keys = data_type_to_keys
         self._key_to_data_type = key_to_data_type
@@ -21,10 +22,19 @@ class Selector:
     def select(self, *columns: Union[str, List[str]]) -> "Dataset":
         columns = self._normalize_columns(columns)
         to_fetch = self._get_columns_to_fetch(columns)
+        # breakpoint()
         new_data = self._fetch_data(to_fetch)
         return Dataset(new_data)
 
     def _normalize_columns(self, columns: Union[str, List[str]]) -> tuple:
+        """Normalize the columns to a tuple of strings
+
+        >>> s = Selector([], {}, {}, lambda x, y: x, [])
+        >>> s._normalize_columns([["a", "b"], ])
+        ('a', 'b')
+        >>> s._normalize_columns(None)
+        ('*.*',)
+        """
         if not columns or columns == ("*",) or columns == (None,):
             return ("*.*",)
         if isinstance(columns[0], list):
@@ -37,6 +47,7 @@ class Selector:
 
         for column in columns:
             matches = self._find_matching_columns(column)
+            # breakpoint()
             self._validate_matches(column, matches)
 
             if len(matches) == 1:
@@ -52,7 +63,7 @@ class Selector:
             search_in_list = self.columns
         else:
             search_in_list = [s.split(".")[1] for s in self.columns]
-
+        # breakpoint()
         matches = [s for s in search_in_list if s.startswith(partial_name)]
         return [partial_name] if partial_name in matches else matches
 
@@ -116,3 +127,9 @@ class Selector:
             new_data.append({f"{data_type}.{key}": entries})
 
         return [d for key in self.items_in_order for d in new_data if key in d]
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()
```
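The new `_normalize_columns` doctest above can be exercised directly. Assuming edsl 0.1.37 is installed, the constructor arguments below are just the placeholders the doctest itself uses:

```python
from edsl.results.Selector import Selector

s = Selector([], {}, {}, lambda x, y: x, [])
print(s._normalize_columns([["a", "b"]]))  # ('a', 'b')
print(s._normalize_columns(None))          # ('*.*',)
```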
edsl/scenarios/FileStore.py
CHANGED
```diff
@@ -77,8 +77,19 @@ class FileStore(Scenario):
     def __str__(self):
         return "FileStore: self.path"
 
+    @classmethod
+    def example(self):
+        import tempfile
+
+        with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
+            f.write(b"Hello, World!")
+
+        return self(path=f.name)
+
     @property
     def size(self) -> int:
+        if self.base64_string != None:
+            return (len(self.base64_string) / 4.0) * 3  # from base64 to char size
         return os.path.getsize(self.path)
 
     def upload_google(self, refresh: bool = False) -> None:
@@ -93,7 +104,7 @@ class FileStore(Scenario):
         return cls(**d)
 
     def __repr__(self):
-        return f"FileStore({self.path})"
+        return f"FileStore(path='{self.path}')"
 
     def encode_file_to_base64_string(self, file_path: str):
         try:
@@ -272,7 +283,8 @@ class CSVFileStore(FileStore):
 
         with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as f:
             r.to_csv(filename=f.name)
-
+
+        return cls(f.name)
 
     def view(self):
         import pandas as pd
@@ -352,7 +364,8 @@ class PDFFileStore(FileStore):
 
         with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
             f.write(pdf_string.encode())
-
+
+        return cls(f.name)
 
 
 class PNGFileStore(FileStore):
@@ -367,7 +380,8 @@ class PNGFileStore(FileStore):
 
         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
             f.write(png_string.encode())
-
+
+        return cls(f.name)
 
     def view(self):
         import matplotlib.pyplot as plt
@@ -407,7 +421,8 @@ class HTMLFileStore(FileStore):
 
         with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as f:
             f.write("<html><body><h1>Test</h1></body></html>".encode())
-
+
+        return cls(f.name)
 
     def view(self):
         import webbrowser
```
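A usage sketch of the new `FileStore.example()` helper, assuming edsl 0.1.37 is installed. The printed path is whatever temporary file the helper created, and the size estimate depends on the stored base64 content:

```python
from edsl.scenarios.FileStore import FileStore

fs = FileStore.example()   # writes "Hello, World!" to a temporary .txt file
print(repr(fs))            # FileStore(path='...') with the temp file's path
print(fs.size)             # size in bytes, estimated from base64 when available
```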