edsl 0.1.30__py3-none-any.whl → 0.1.30.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +6 -8
- edsl/agents/AgentList.py +19 -9
- edsl/agents/Invigilator.py +5 -4
- edsl/conversation/car_buying.py +1 -1
- edsl/data/Cache.py +16 -25
- edsl/data/CacheEntry.py +7 -6
- edsl/data_transfer_models.py +0 -4
- edsl/jobs/Jobs.py +2 -17
- edsl/jobs/buckets/ModelBuckets.py +0 -10
- edsl/jobs/buckets/TokenBucket.py +3 -31
- edsl/jobs/interviews/Interview.py +73 -99
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +19 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +0 -4
- edsl/jobs/tasks/QuestionTaskCreator.py +6 -10
- edsl/language_models/LanguageModel.py +6 -12
- edsl/notebooks/Notebook.py +9 -9
- edsl/questions/QuestionFreeText.py +2 -4
- edsl/questions/QuestionFunctional.py +2 -34
- edsl/questions/QuestionMultipleChoice.py +8 -57
- edsl/questions/descriptors.py +2 -42
- edsl/results/DatasetExportMixin.py +5 -84
- edsl/results/Result.py +5 -53
- edsl/results/Results.py +30 -70
- edsl/scenarios/FileStore.py +4 -163
- edsl/scenarios/Scenario.py +19 -12
- edsl/scenarios/ScenarioList.py +6 -8
- edsl/study/Study.py +7 -5
- edsl/surveys/Survey.py +12 -44
- {edsl-0.1.30.dist-info → edsl-0.1.30.dev1.dist-info}/METADATA +1 -1
- {edsl-0.1.30.dist-info → edsl-0.1.30.dev1.dist-info}/RECORD +33 -33
- {edsl-0.1.30.dist-info → edsl-0.1.30.dev1.dist-info}/WHEEL +1 -1
- {edsl-0.1.30.dist-info → edsl-0.1.30.dev1.dist-info}/LICENSE +0 -0
@@ -25,7 +25,7 @@ TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
|
|
25
25
|
class InterviewTaskBuildingMixin:
|
26
26
|
def _build_invigilators(
|
27
27
|
self, debug: bool
|
28
|
-
) -> Generator[
|
28
|
+
) -> Generator[InvigilatorBase, None, None]:
|
29
29
|
"""Create an invigilator for each question.
|
30
30
|
|
31
31
|
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
@@ -35,7 +35,7 @@ class InterviewTaskBuildingMixin:
|
|
35
35
|
for question in self.survey.questions:
|
36
36
|
yield self._get_invigilator(question=question, debug=debug)
|
37
37
|
|
38
|
-
def _get_invigilator(self, question:
|
38
|
+
def _get_invigilator(self, question: QuestionBase, debug: bool) -> "Invigilator":
|
39
39
|
"""Return an invigilator for the given question.
|
40
40
|
|
41
41
|
:param question: the question to be answered
|
@@ -84,7 +84,7 @@ class InterviewTaskBuildingMixin:
|
|
84
84
|
return tuple(tasks) # , invigilators
|
85
85
|
|
86
86
|
def _get_tasks_that_must_be_completed_before(
|
87
|
-
self, *, tasks: list[asyncio.Task], question:
|
87
|
+
self, *, tasks: list[asyncio.Task], question: QuestionBase
|
88
88
|
) -> Generator[asyncio.Task, None, None]:
|
89
89
|
"""Return the tasks that must be completed before the given question can be answered.
|
90
90
|
|
@@ -100,7 +100,7 @@ class InterviewTaskBuildingMixin:
|
|
100
100
|
def _create_question_task(
|
101
101
|
self,
|
102
102
|
*,
|
103
|
-
question:
|
103
|
+
question: QuestionBase,
|
104
104
|
tasks_that_must_be_completed_before: list[asyncio.Task],
|
105
105
|
model_buckets: ModelBuckets,
|
106
106
|
debug: bool,
|
@@ -175,14 +175,24 @@ class InterviewTaskBuildingMixin:
|
|
175
175
|
|
176
176
|
self._add_answer(response=response, question=question)
|
177
177
|
|
178
|
+
# With the answer to the question, we can now cancel any skipped questions
|
178
179
|
self._cancel_skipped_questions(question)
|
179
180
|
return AgentResponseDict(**response)
|
180
181
|
except Exception as e:
|
181
182
|
raise e
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
183
|
+
# import traceback
|
184
|
+
# print("Exception caught:")
|
185
|
+
# traceback.print_exc()
|
186
|
+
|
187
|
+
# # Extract and print the traceback info
|
188
|
+
# tb = e.__traceback__
|
189
|
+
# while tb is not None:
|
190
|
+
# print(f"File {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}, in {tb.tb_frame.f_code.co_name}")
|
191
|
+
# tb = tb.tb_next
|
192
|
+
# breakpoint()
|
193
|
+
# raise e
|
194
|
+
|
195
|
+
def _add_answer(self, response: AgentResponseDict, question: QuestionBase) -> None:
|
186
196
|
"""Add the answer to the answers dictionary.
|
187
197
|
|
188
198
|
:param response: the response to the question.
|
@@ -190,7 +200,7 @@ class InterviewTaskBuildingMixin:
|
|
190
200
|
"""
|
191
201
|
self.answers.add_answer(response=response, question=question)
|
192
202
|
|
193
|
-
def _skip_this_question(self, current_question:
|
203
|
+
def _skip_this_question(self, current_question: QuestionBase) -> bool:
|
194
204
|
"""Determine if the current question should be skipped.
|
195
205
|
|
196
206
|
:param current_question: the question to be answered.
|
@@ -88,8 +88,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
88
88
|
self.total_interviews.append(interview)
|
89
89
|
|
90
90
|
async def run_async(self, cache=None) -> Results:
|
91
|
-
from edsl.results.Results import Results
|
92
|
-
|
93
91
|
if cache is None:
|
94
92
|
self.cache = Cache()
|
95
93
|
else:
|
@@ -100,8 +98,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
100
98
|
return Results(survey=self.jobs.survey, data=data)
|
101
99
|
|
102
100
|
def simple_run(self):
|
103
|
-
from edsl.results.Results import Results
|
104
|
-
|
105
101
|
data = asyncio.run(self.run_async())
|
106
102
|
return Results(survey=self.jobs.survey, data=data)
|
107
103
|
|
@@ -144,16 +144,12 @@ class QuestionTaskCreator(UserList):
|
|
144
144
|
self.task_status = TaskStatus.FAILED
|
145
145
|
raise e
|
146
146
|
|
147
|
-
if
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
self.requests_bucket.turbo_mode_on()
|
154
|
-
else:
|
155
|
-
self.tokens_bucket.turbo_mode_off()
|
156
|
-
self.requests_bucket.turbo_mode_off()
|
147
|
+
if "cached_response" in results:
|
148
|
+
if results["cached_response"]:
|
149
|
+
# Gives back the tokens b/c the API was not called.
|
150
|
+
self.tokens_bucket.add_tokens(requested_tokens)
|
151
|
+
self.requests_bucket.add_tokens(1)
|
152
|
+
self.from_cache = True
|
157
153
|
|
158
154
|
_ = results.pop("cached_response", None)
|
159
155
|
|
@@ -323,10 +323,12 @@ class LanguageModel(
|
|
323
323
|
image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
|
324
324
|
cache_call_params["user_prompt"] = f"{user_prompt} {image_hash}"
|
325
325
|
|
326
|
-
cached_response
|
326
|
+
cached_response = cache.fetch(**cache_call_params)
|
327
|
+
|
327
328
|
if cached_response:
|
328
329
|
response = json.loads(cached_response)
|
329
330
|
cache_used = True
|
331
|
+
cache_key = None
|
330
332
|
else:
|
331
333
|
remote_call = hasattr(self, "remote") and self.remote
|
332
334
|
f = (
|
@@ -338,7 +340,7 @@ class LanguageModel(
|
|
338
340
|
if encoded_image:
|
339
341
|
params["encoded_image"] = encoded_image
|
340
342
|
response = await f(**params)
|
341
|
-
|
343
|
+
cache_key = cache.store(
|
342
344
|
user_prompt=user_prompt,
|
343
345
|
model=str(self.model),
|
344
346
|
parameters=self.parameters,
|
@@ -346,7 +348,6 @@ class LanguageModel(
|
|
346
348
|
response=response,
|
347
349
|
iteration=iteration,
|
348
350
|
)
|
349
|
-
assert new_cache_key == cache_key
|
350
351
|
cache_used = False
|
351
352
|
|
352
353
|
return response, cache_used, cache_key
|
@@ -411,7 +412,7 @@ class LanguageModel(
|
|
411
412
|
|
412
413
|
dict_response.update(
|
413
414
|
{
|
414
|
-
"
|
415
|
+
"cached_used": cache_used,
|
415
416
|
"cache_key": cache_key,
|
416
417
|
"usage": raw_response.get("usage", {}),
|
417
418
|
"raw_model_response": raw_response,
|
@@ -494,12 +495,7 @@ class LanguageModel(
|
|
494
495
|
return table
|
495
496
|
|
496
497
|
@classmethod
|
497
|
-
def example(
|
498
|
-
cls,
|
499
|
-
test_model: bool = False,
|
500
|
-
canned_response: str = "Hello world",
|
501
|
-
throw_exception: bool = False,
|
502
|
-
):
|
498
|
+
def example(cls, test_model: bool = False, canned_response: str = "Hello world"):
|
503
499
|
"""Return a default instance of the class.
|
504
500
|
|
505
501
|
>>> from edsl.language_models import LanguageModel
|
@@ -524,8 +520,6 @@ class LanguageModel(
|
|
524
520
|
) -> dict[str, Any]:
|
525
521
|
await asyncio.sleep(0.1)
|
526
522
|
# return {"message": """{"answer": "Hello, world"}"""}
|
527
|
-
if throw_exception:
|
528
|
-
raise Exception("This is a test error")
|
529
523
|
return {"message": f'{{"answer": "{canned_response}"}}'}
|
530
524
|
|
531
525
|
def parse_response(self, raw_response: dict[str, Any]) -> str:
|
edsl/notebooks/Notebook.py
CHANGED
@@ -1,11 +1,14 @@
|
|
1
1
|
"""A Notebook is a utility class that allows you to easily share/pull ipynbs from Coop."""
|
2
2
|
|
3
|
-
from __future__ import annotations
|
4
3
|
import json
|
5
4
|
from typing import Dict, List, Optional
|
6
|
-
|
5
|
+
|
6
|
+
|
7
7
|
from edsl.Base import Base
|
8
|
-
from edsl.utilities.decorators import
|
8
|
+
from edsl.utilities.decorators import (
|
9
|
+
add_edsl_version,
|
10
|
+
remove_edsl_version,
|
11
|
+
)
|
9
12
|
|
10
13
|
|
11
14
|
class Notebook(Base):
|
@@ -189,13 +192,10 @@ class Notebook(Base):
|
|
189
192
|
return table
|
190
193
|
|
191
194
|
@classmethod
|
192
|
-
def example(cls
|
195
|
+
def example(cls) -> "Notebook":
|
193
196
|
"""
|
194
|
-
|
195
|
-
|
196
|
-
:param randomize: If True, adds a random string one of the cells' output.
|
197
|
+
Return an example Notebook.
|
197
198
|
"""
|
198
|
-
addition = "" if not randomize else str(uuid4())
|
199
199
|
cells = [
|
200
200
|
{
|
201
201
|
"cell_type": "markdown",
|
@@ -210,7 +210,7 @@ class Notebook(Base):
|
|
210
210
|
{
|
211
211
|
"name": "stdout",
|
212
212
|
"output_type": "stream",
|
213
|
-
"text":
|
213
|
+
"text": "Hello world!\n",
|
214
214
|
}
|
215
215
|
],
|
216
216
|
"source": 'print("Hello world!")',
|
@@ -1,7 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import textwrap
|
3
3
|
from typing import Any, Optional
|
4
|
-
from uuid import uuid4
|
5
4
|
from edsl.questions.QuestionBase import QuestionBase
|
6
5
|
|
7
6
|
|
@@ -66,10 +65,9 @@ class QuestionFreeText(QuestionBase):
|
|
66
65
|
return question_html_content
|
67
66
|
|
68
67
|
@classmethod
|
69
|
-
def example(cls
|
68
|
+
def example(cls) -> QuestionFreeText:
|
70
69
|
"""Return an example instance of a free text question."""
|
71
|
-
|
72
|
-
return cls(question_name="how_are_you", question_text=f"How are you?{addition}")
|
70
|
+
return cls(question_name="how_are_you", question_text="How are you?")
|
73
71
|
|
74
72
|
|
75
73
|
def main():
|
@@ -4,34 +4,10 @@ import inspect
|
|
4
4
|
from edsl.questions.QuestionBase import QuestionBase
|
5
5
|
|
6
6
|
from edsl.utilities.restricted_python import create_restricted_function
|
7
|
-
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
8
7
|
|
9
8
|
|
10
9
|
class QuestionFunctional(QuestionBase):
|
11
|
-
"""A special type of question that is *not* answered by an LLM.
|
12
|
-
|
13
|
-
>>> from edsl import Scenario, Agent
|
14
|
-
|
15
|
-
# Create an instance of QuestionFunctional with the new function
|
16
|
-
>>> question = QuestionFunctional.example()
|
17
|
-
|
18
|
-
# Activate and test the function
|
19
|
-
>>> question.activate()
|
20
|
-
>>> scenario = Scenario({"numbers": [1, 2, 3, 4, 5]})
|
21
|
-
>>> agent = Agent(traits={"multiplier": 10})
|
22
|
-
>>> results = question.by(scenario).by(agent).run()
|
23
|
-
>>> results.select("answer.*").to_list()[0] == 150
|
24
|
-
True
|
25
|
-
|
26
|
-
# Serialize the question to a dictionary
|
27
|
-
|
28
|
-
>>> from edsl.questions.QuestionBase import QuestionBase
|
29
|
-
>>> new_question = QuestionBase.from_dict(question.to_dict())
|
30
|
-
>>> results = new_question.by(scenario).by(agent).run()
|
31
|
-
>>> results.select("answer.*").to_list()[0] == 150
|
32
|
-
True
|
33
|
-
|
34
|
-
"""
|
10
|
+
"""A special type of question that is *not* answered by an LLM."""
|
35
11
|
|
36
12
|
question_type = "functional"
|
37
13
|
default_instructions = ""
|
@@ -97,7 +73,6 @@ class QuestionFunctional(QuestionBase):
|
|
97
73
|
"""Required by Question, but not used by QuestionFunctional."""
|
98
74
|
raise NotImplementedError
|
99
75
|
|
100
|
-
@add_edsl_version
|
101
76
|
def to_dict(self):
|
102
77
|
return {
|
103
78
|
"question_name": self.question_name,
|
@@ -138,11 +113,4 @@ def main():
|
|
138
113
|
scenario = Scenario({"numbers": [1, 2, 3, 4, 5]})
|
139
114
|
agent = Agent(traits={"multiplier": 10})
|
140
115
|
results = question.by(scenario).by(agent).run()
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
if __name__ == "__main__":
|
145
|
-
# main()
|
146
|
-
import doctest
|
147
|
-
|
148
|
-
doctest.testmod(optionflags=doctest.ELLIPSIS)
|
116
|
+
print(results)
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
2
2
|
import time
|
3
3
|
from typing import Union
|
4
4
|
import random
|
5
|
-
|
5
|
+
|
6
6
|
from jinja2 import Template
|
7
7
|
|
8
8
|
from edsl.questions.QuestionBase import QuestionBase
|
@@ -10,11 +10,7 @@ from edsl.questions.descriptors import QuestionOptionsDescriptor
|
|
10
10
|
|
11
11
|
|
12
12
|
class QuestionMultipleChoice(QuestionBase):
|
13
|
-
"""This question prompts the agent to select one option from a list of options.
|
14
|
-
|
15
|
-
https://docs.expectedparrot.com/en/latest/questions.html#questionmultiplechoice-class
|
16
|
-
|
17
|
-
"""
|
13
|
+
"""This question prompts the agent to select one option from a list of options."""
|
18
14
|
|
19
15
|
question_type = "multiple_choice"
|
20
16
|
purpose = "When options are known and limited"
|
@@ -39,71 +35,27 @@ class QuestionMultipleChoice(QuestionBase):
|
|
39
35
|
self.question_text = question_text
|
40
36
|
self.question_options = question_options
|
41
37
|
|
42
|
-
# @property
|
43
|
-
# def question_options(self) -> Union[list[str], list[list], list[float], list[int]]:
|
44
|
-
# """Return the question options."""
|
45
|
-
# return self._question_options
|
46
|
-
|
47
38
|
################
|
48
39
|
# Answer methods
|
49
40
|
################
|
50
41
|
def _validate_answer(
|
51
42
|
self, answer: dict[str, Union[str, int]]
|
52
43
|
) -> dict[str, Union[str, int]]:
|
53
|
-
"""Validate the answer.
|
54
|
-
|
55
|
-
>>> q = QuestionMultipleChoice.example()
|
56
|
-
>>> q._validate_answer({"answer": 0, "comment": "I like custard"})
|
57
|
-
{'answer': 0, 'comment': 'I like custard'}
|
58
|
-
|
59
|
-
>>> q = QuestionMultipleChoice(question_name="how_feeling", question_text="How are you?", question_options=["Good", "Great", "OK", "Bad"])
|
60
|
-
>>> q._validate_answer({"answer": -1, "comment": "I like custard"})
|
61
|
-
Traceback (most recent call last):
|
62
|
-
...
|
63
|
-
edsl.exceptions.questions.QuestionAnswerValidationError: Answer code must be a non-negative integer (got -1).
|
64
|
-
"""
|
44
|
+
"""Validate the answer."""
|
65
45
|
self._validate_answer_template_basic(answer)
|
66
46
|
self._validate_answer_multiple_choice(answer)
|
67
47
|
return answer
|
68
48
|
|
69
49
|
def _translate_answer_code_to_answer(
|
70
|
-
self, answer_code
|
50
|
+
self, answer_code, scenario: "Scenario" = None
|
71
51
|
):
|
72
|
-
"""Translate the answer code to the actual answer.
|
73
|
-
|
74
|
-
It is used to translate the answer code to the actual answer.
|
75
|
-
The question options might be templates, so they need to be rendered with the scenario.
|
76
|
-
|
77
|
-
>>> q = QuestionMultipleChoice.example()
|
78
|
-
>>> q._translate_answer_code_to_answer(0, {})
|
79
|
-
'Good'
|
80
|
-
|
81
|
-
>>> q = QuestionMultipleChoice(question_name="how_feeling", question_text="How are you?", question_options=["{{emotion[0]}}", "emotion[1]"])
|
82
|
-
>>> q._translate_answer_code_to_answer(0, {"emotion": ["Happy", "Sad"]})
|
83
|
-
'Happy'
|
84
|
-
|
85
|
-
"""
|
52
|
+
"""Translate the answer code to the actual answer."""
|
86
53
|
from edsl.scenarios.Scenario import Scenario
|
87
54
|
|
88
55
|
scenario = scenario or Scenario()
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
from jinja2 import Environment, meta
|
93
|
-
|
94
|
-
env = Environment()
|
95
|
-
parsed_content = env.parse(self.question_options)
|
96
|
-
question_option_key = list(meta.find_undeclared_variables(parsed_content))[
|
97
|
-
0
|
98
|
-
]
|
99
|
-
translated_options = scenario.get(question_option_key)
|
100
|
-
else:
|
101
|
-
translated_options = [
|
102
|
-
Template(str(option)).render(scenario)
|
103
|
-
for option in self.question_options
|
104
|
-
]
|
105
|
-
# print("Translated options:", translated_options)
|
106
|
-
# breakpoint()
|
56
|
+
translated_options = [
|
57
|
+
Template(str(option)).render(scenario) for option in self.question_options
|
58
|
+
]
|
107
59
|
return translated_options[int(answer_code)]
|
108
60
|
|
109
61
|
def _simulate_answer(
|
@@ -123,7 +75,6 @@ class QuestionMultipleChoice(QuestionBase):
|
|
123
75
|
|
124
76
|
@property
|
125
77
|
def question_html_content(self) -> str:
|
126
|
-
"""Return the HTML version of the question."""
|
127
78
|
if hasattr(self, "option_labels"):
|
128
79
|
option_labels = self.option_labels
|
129
80
|
else:
|
edsl/questions/descriptors.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
import re
|
5
|
-
from typing import Any, Callable
|
5
|
+
from typing import Any, Callable
|
6
6
|
from edsl.exceptions import (
|
7
7
|
QuestionCreationValidationError,
|
8
8
|
QuestionAnswerValidationError,
|
@@ -242,16 +242,6 @@ class QuestionNameDescriptor(BaseDescriptor):
|
|
242
242
|
class QuestionOptionsDescriptor(BaseDescriptor):
|
243
243
|
"""Validate that `question_options` is a list, does not exceed the min/max lengths, and has unique items."""
|
244
244
|
|
245
|
-
@classmethod
|
246
|
-
def example(cls):
|
247
|
-
class TestQuestion:
|
248
|
-
question_options = QuestionOptionsDescriptor()
|
249
|
-
|
250
|
-
def __init__(self, question_options: List[str]):
|
251
|
-
self.question_options = question_options
|
252
|
-
|
253
|
-
return TestQuestion
|
254
|
-
|
255
245
|
def __init__(
|
256
246
|
self,
|
257
247
|
num_choices: int = None,
|
@@ -264,31 +254,7 @@ class QuestionOptionsDescriptor(BaseDescriptor):
|
|
264
254
|
self.q_budget = q_budget
|
265
255
|
|
266
256
|
def validate(self, value: Any, instance) -> None:
|
267
|
-
"""Validate the question options.
|
268
|
-
|
269
|
-
>>> q_class = QuestionOptionsDescriptor.example()
|
270
|
-
>>> _ = q_class(["a", "b", "c"])
|
271
|
-
>>> _ = q_class(["a", "b", "c", "d", "d"])
|
272
|
-
Traceback (most recent call last):
|
273
|
-
...
|
274
|
-
edsl.exceptions.questions.QuestionCreationValidationError: Question options must be unique (got ['a', 'b', 'c', 'd', 'd']).
|
275
|
-
|
276
|
-
We allow dynamic question options, which are strings of the form '{{ question_options }}'.
|
277
|
-
|
278
|
-
>>> _ = q_class("{{dynamic_options}}")
|
279
|
-
>>> _ = q_class("dynamic_options")
|
280
|
-
Traceback (most recent call last):
|
281
|
-
...
|
282
|
-
edsl.exceptions.questions.QuestionCreationValidationError: Dynamic question options must be of the form: '{{ question_options }}'.
|
283
|
-
"""
|
284
|
-
if isinstance(value, str):
|
285
|
-
# Check if the string is a dynamic question option
|
286
|
-
if "{{" in value and "}}" in value:
|
287
|
-
return None
|
288
|
-
else:
|
289
|
-
raise QuestionCreationValidationError(
|
290
|
-
"Dynamic question options must be of the form: '{{ question_options }}'."
|
291
|
-
)
|
257
|
+
"""Validate the question options."""
|
292
258
|
if not isinstance(value, list):
|
293
259
|
raise QuestionCreationValidationError(
|
294
260
|
f"Question options must be a list (got {value})."
|
@@ -373,9 +339,3 @@ class QuestionTextDescriptor(BaseDescriptor):
|
|
373
339
|
f"WARNING: Question text contains a single-braced substring: If you intended to parameterize the question with a Scenario this should be changed to a double-braced substring, e.g. {{variable}}.\nSee details on constructing Scenarios in the docs: https://docs.expectedparrot.com/en/latest/scenarios.html",
|
374
340
|
UserWarning,
|
375
341
|
)
|
376
|
-
|
377
|
-
|
378
|
-
if __name__ == "__main__":
|
379
|
-
import doctest
|
380
|
-
|
381
|
-
doctest.testmod(optionflags=doctest.ELLIPSIS)
|
@@ -15,9 +15,6 @@ class DatasetExportMixin:
|
|
15
15
|
) -> list:
|
16
16
|
"""Return the set of keys that are present in the dataset.
|
17
17
|
|
18
|
-
:param data_type: The data type to filter by.
|
19
|
-
:param remove_prefix: Whether to remove the prefix from the column names.
|
20
|
-
|
21
18
|
>>> from edsl.results.Dataset import Dataset
|
22
19
|
>>> d = Dataset([{'a.b':[1,2,3,4]}])
|
23
20
|
>>> d.relevant_columns()
|
@@ -30,6 +27,7 @@ class DatasetExportMixin:
|
|
30
27
|
['answer.how_feeling', 'answer.how_feeling_yesterday']
|
31
28
|
"""
|
32
29
|
columns = [list(x.keys())[0] for x in self]
|
30
|
+
# columns = set([list(result.keys())[0] for result in self.data])
|
33
31
|
if remove_prefix:
|
34
32
|
columns = [column.split(".")[-1] for column in columns]
|
35
33
|
|
@@ -73,15 +71,7 @@ class DatasetExportMixin:
|
|
73
71
|
return header, rows
|
74
72
|
|
75
73
|
def print_long(self):
|
76
|
-
"""Print the results in a long format.
|
77
|
-
>>> from edsl.results import Results
|
78
|
-
>>> r = Results.example()
|
79
|
-
>>> r.select('how_feeling').print_long()
|
80
|
-
answer.how_feeling: OK
|
81
|
-
answer.how_feeling: Great
|
82
|
-
answer.how_feeling: Terrible
|
83
|
-
answer.how_feeling: OK
|
84
|
-
"""
|
74
|
+
"""Print the results in a long format."""
|
85
75
|
for entry in self:
|
86
76
|
key, list_of_values = list(entry.items())[0]
|
87
77
|
for value in list_of_values:
|
@@ -127,42 +117,6 @@ class DatasetExportMixin:
|
|
127
117
|
│ OK │
|
128
118
|
└──────────────┘
|
129
119
|
|
130
|
-
>>> r = Results.example()
|
131
|
-
>>> r2 = r.select("how_feeling").print(format = "rich", tee = True, max_rows = 2)
|
132
|
-
┏━━━━━━━━━━━━━━┓
|
133
|
-
┃ answer ┃
|
134
|
-
┃ .how_feeling ┃
|
135
|
-
┡━━━━━━━━━━━━━━┩
|
136
|
-
│ OK │
|
137
|
-
├──────────────┤
|
138
|
-
│ Great │
|
139
|
-
└──────────────┘
|
140
|
-
>>> r2
|
141
|
-
Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
|
142
|
-
|
143
|
-
>>> r.select('how_feeling').print(format = "rich", max_rows = 2)
|
144
|
-
┏━━━━━━━━━━━━━━┓
|
145
|
-
┃ answer ┃
|
146
|
-
┃ .how_feeling ┃
|
147
|
-
┡━━━━━━━━━━━━━━┩
|
148
|
-
│ OK │
|
149
|
-
├──────────────┤
|
150
|
-
│ Great │
|
151
|
-
└──────────────┘
|
152
|
-
|
153
|
-
>>> r.select('how_feeling').print(format = "rich", split_at_dot = False)
|
154
|
-
┏━━━━━━━━━━━━━━━━━━━━┓
|
155
|
-
┃ answer.how_feeling ┃
|
156
|
-
┡━━━━━━━━━━━━━━━━━━━━┩
|
157
|
-
│ OK │
|
158
|
-
├────────────────────┤
|
159
|
-
│ Great │
|
160
|
-
├────────────────────┤
|
161
|
-
│ Terrible │
|
162
|
-
├────────────────────┤
|
163
|
-
│ OK │
|
164
|
-
└────────────────────┘
|
165
|
-
|
166
120
|
Example: using the pretty_labels parameter
|
167
121
|
|
168
122
|
>>> r.select('how_feeling').print(format="rich", pretty_labels = {'answer.how_feeling': "How are you feeling"})
|
@@ -200,9 +154,6 @@ class DatasetExportMixin:
|
|
200
154
|
|
201
155
|
if pretty_labels is None:
|
202
156
|
pretty_labels = {}
|
203
|
-
else:
|
204
|
-
# if the user passes in pretty_labels, we don't want to split at the dot
|
205
|
-
split_at_dot = False
|
206
157
|
|
207
158
|
if format not in ["rich", "html", "markdown", "latex"]:
|
208
159
|
raise ValueError("format must be one of 'rich', 'html', or 'markdown'.")
|
@@ -217,6 +168,7 @@ class DatasetExportMixin:
|
|
217
168
|
for key in entry:
|
218
169
|
actual_rows = len(entry[key])
|
219
170
|
entry[key] = entry[key][:max_rows]
|
171
|
+
# print(f"Showing only the first {max_rows} rows of {actual_rows} rows.")
|
220
172
|
|
221
173
|
if format == "rich":
|
222
174
|
from edsl.utilities.interface import print_dataset_with_rich
|
@@ -293,10 +245,6 @@ class DatasetExportMixin:
|
|
293
245
|
>>> r = Results.example()
|
294
246
|
>>> r.select('how_feeling').to_csv()
|
295
247
|
'answer.how_feeling\\r\\nOK\\r\\nGreat\\r\\nTerrible\\r\\nOK\\r\\n'
|
296
|
-
|
297
|
-
>>> r.select('how_feeling').to_csv(pretty_labels = {'answer.how_feeling': "How are you feeling"})
|
298
|
-
'How are you feeling\\r\\nOK\\r\\nGreat\\r\\nTerrible\\r\\nOK\\r\\n'
|
299
|
-
|
300
248
|
"""
|
301
249
|
if pretty_labels is None:
|
302
250
|
pretty_labels = {}
|
@@ -361,15 +309,6 @@ class DatasetExportMixin:
|
|
361
309
|
return ScenarioList([Scenario(d) for d in list_of_dicts])
|
362
310
|
|
363
311
|
def to_agent_list(self, remove_prefix: bool = True):
|
364
|
-
"""Convert the results to a list of dictionaries, one per agent.
|
365
|
-
|
366
|
-
:param remove_prefix: Whether to remove the prefix from the column names.
|
367
|
-
|
368
|
-
>>> from edsl.results import Results
|
369
|
-
>>> r = Results.example()
|
370
|
-
>>> r.select('how_feeling').to_agent_list()
|
371
|
-
AgentList([Agent(traits = {'how_feeling': 'OK'}), Agent(traits = {'how_feeling': 'Great'}), Agent(traits = {'how_feeling': 'Terrible'}), Agent(traits = {'how_feeling': 'OK'})])
|
372
|
-
"""
|
373
312
|
from edsl import AgentList, Agent
|
374
313
|
|
375
314
|
list_of_dicts = self.to_dicts(remove_prefix=remove_prefix)
|
@@ -405,9 +344,6 @@ class DatasetExportMixin:
|
|
405
344
|
def to_list(self, flatten=False, remove_none=False) -> list[list]:
|
406
345
|
"""Convert the results to a list of lists.
|
407
346
|
|
408
|
-
:param flatten: Whether to flatten the list of lists.
|
409
|
-
:param remove_none: Whether to remove None values from the list.
|
410
|
-
|
411
347
|
>>> from edsl.results import Results
|
412
348
|
>>> Results.example().select('how_feeling', 'how_feeling_yesterday')
|
413
349
|
Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling_yesterday': ['Great', 'Good', 'OK', 'Terrible']}])
|
@@ -418,18 +354,6 @@ class DatasetExportMixin:
|
|
418
354
|
>>> r = Results.example()
|
419
355
|
>>> r.select('how_feeling').to_list()
|
420
356
|
['OK', 'Great', 'Terrible', 'OK']
|
421
|
-
|
422
|
-
>>> from edsl.results.Dataset import Dataset
|
423
|
-
>>> Dataset([{'a.b': [[1, 9], 2, 3, 4]}]).select('a.b').to_list(flatten = True)
|
424
|
-
[1, 9, 2, 3, 4]
|
425
|
-
|
426
|
-
>>> from edsl.results.Dataset import Dataset
|
427
|
-
>>> Dataset([{'a.b': [[1, 9], 2, 3, 4]}, {'c': [6, 2, 3, 4]}]).select('a.b', 'c').to_list(flatten = True)
|
428
|
-
Traceback (most recent call last):
|
429
|
-
...
|
430
|
-
ValueError: Cannot flatten a list of lists when there are multiple columns selected.
|
431
|
-
|
432
|
-
|
433
357
|
"""
|
434
358
|
if len(self.relevant_columns()) > 1 and flatten:
|
435
359
|
raise ValueError(
|
@@ -461,10 +385,7 @@ class DatasetExportMixin:
|
|
461
385
|
return list_to_return
|
462
386
|
|
463
387
|
def html(
|
464
|
-
self,
|
465
|
-
filename: Optional[str] = None,
|
466
|
-
cta: str = "Open in browser",
|
467
|
-
return_link: bool = False,
|
388
|
+
self, filename: str = None, cta: str = "Open in browser", return_link=False
|
468
389
|
):
|
469
390
|
import os
|
470
391
|
import tempfile
|
@@ -498,7 +419,7 @@ class DatasetExportMixin:
|
|
498
419
|
return filename
|
499
420
|
|
500
421
|
def tally(
|
501
|
-
self, *fields: Optional[str], top_n
|
422
|
+
self, *fields: Optional[str], top_n=None, output="dict"
|
502
423
|
) -> Union[dict, "Dataset"]:
|
503
424
|
"""Tally the values of a field or perform a cross-tab of multiple fields.
|
504
425
|
|