edsl 0.1.29.dev2__py3-none-any.whl → 0.1.29.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +12 -15
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +37 -2
- edsl/agents/AgentList.py +3 -4
- edsl/agents/InvigilatorBase.py +15 -10
- edsl/agents/PromptConstructionMixin.py +342 -100
- edsl/conjure/InputData.py +39 -8
- edsl/coop/coop.py +187 -150
- edsl/coop/utils.py +17 -76
- edsl/jobs/Jobs.py +23 -17
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +1 -0
- edsl/notebooks/Notebook.py +31 -0
- edsl/prompts/Prompt.py +31 -19
- edsl/questions/QuestionBase.py +32 -11
- edsl/questions/question_registry.py +20 -31
- edsl/questions/settings.py +1 -1
- edsl/results/Dataset.py +31 -0
- edsl/results/Results.py +6 -8
- edsl/results/ResultsToolsMixin.py +4 -1
- edsl/scenarios/ScenarioList.py +17 -3
- edsl/study/Study.py +3 -9
- edsl/surveys/Survey.py +37 -3
- edsl/tools/plotting.py +4 -2
- {edsl-0.1.29.dev2.dist-info → edsl-0.1.29.dev6.dist-info}/METADATA +11 -10
- {edsl-0.1.29.dev2.dist-info → edsl-0.1.29.dev6.dist-info}/RECORD +27 -28
- edsl-0.1.29.dev2.dist-info/entry_points.txt +0 -3
- {edsl-0.1.29.dev2.dist-info → edsl-0.1.29.dev6.dist-info}/LICENSE +0 -0
- {edsl-0.1.29.dev2.dist-info → edsl-0.1.29.dev6.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -312,10 +312,6 @@ class Jobs(Base):
|
|
312
312
|
# if no agents, models, or scenarios are set, set them to defaults
|
313
313
|
self.agents = self.agents or [Agent()]
|
314
314
|
self.models = self.models or [Model()]
|
315
|
-
# if remote, set all the models to remote
|
316
|
-
if hasattr(self, "remote") and self.remote:
|
317
|
-
for model in self.models:
|
318
|
-
model.remote = True
|
319
315
|
self.scenarios = self.scenarios or [Scenario()]
|
320
316
|
for agent, scenario, model in product(self.agents, self.scenarios, self.models):
|
321
317
|
yield Interview(
|
@@ -368,14 +364,14 @@ class Jobs(Base):
|
|
368
364
|
if self.verbose:
|
369
365
|
print(message)
|
370
366
|
|
371
|
-
def _check_parameters(self, strict=False, warn
|
367
|
+
def _check_parameters(self, strict=False, warn=False) -> None:
|
372
368
|
"""Check if the parameters in the survey and scenarios are consistent.
|
373
369
|
|
374
370
|
>>> from edsl import QuestionFreeText
|
375
371
|
>>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
|
376
372
|
>>> j = Jobs(survey = Survey(questions=[q]))
|
377
373
|
>>> with warnings.catch_warnings(record=True) as w:
|
378
|
-
... j._check_parameters()
|
374
|
+
... j._check_parameters(warn = True)
|
379
375
|
... assert len(w) == 1
|
380
376
|
... assert issubclass(w[-1].category, UserWarning)
|
381
377
|
... assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
|
@@ -413,15 +409,13 @@ class Jobs(Base):
|
|
413
409
|
progress_bar: bool = False,
|
414
410
|
stop_on_exception: bool = False,
|
415
411
|
cache: Union[Cache, bool] = None,
|
416
|
-
remote: bool = (
|
417
|
-
False if os.getenv("DEFAULT_RUN_MODE", "local") == "local" else True
|
418
|
-
),
|
419
412
|
check_api_keys: bool = False,
|
420
413
|
sidecar_model: Optional[LanguageModel] = None,
|
421
414
|
batch_mode: Optional[bool] = None,
|
422
415
|
verbose: bool = False,
|
423
416
|
print_exceptions=True,
|
424
417
|
remote_cache_description: Optional[str] = None,
|
418
|
+
remote_inference_description: Optional[str] = None,
|
425
419
|
) -> Results:
|
426
420
|
"""
|
427
421
|
Runs the Job: conducts Interviews and returns their results.
|
@@ -431,11 +425,11 @@ class Jobs(Base):
|
|
431
425
|
:param progress_bar: shows a progress bar
|
432
426
|
:param stop_on_exception: stops the job if an exception is raised
|
433
427
|
:param cache: a cache object to store results
|
434
|
-
:param remote: run the job remotely
|
435
428
|
:param check_api_keys: check if the API keys are valid
|
436
429
|
:param batch_mode: run the job in batch mode i.e., no expectation of interaction with the user
|
437
430
|
:param verbose: prints messages
|
438
431
|
:param remote_cache_description: specifies a description for this group of entries in the remote cache
|
432
|
+
:param remote_inference_description: specifies a description for the remote inference job
|
439
433
|
"""
|
440
434
|
from edsl.coop.coop import Coop
|
441
435
|
|
@@ -446,21 +440,33 @@ class Jobs(Base):
|
|
446
440
|
"Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
|
447
441
|
)
|
448
442
|
|
449
|
-
self.remote = remote
|
450
443
|
self.verbose = verbose
|
451
444
|
|
452
445
|
try:
|
453
446
|
coop = Coop()
|
454
|
-
|
447
|
+
user_edsl_settings = coop.edsl_settings
|
448
|
+
remote_cache = user_edsl_settings["remote_caching"]
|
449
|
+
remote_inference = user_edsl_settings["remote_inference"]
|
455
450
|
except Exception:
|
456
451
|
remote_cache = False
|
452
|
+
remote_inference = False
|
457
453
|
|
458
|
-
if
|
459
|
-
|
460
|
-
if
|
461
|
-
|
454
|
+
if remote_inference:
|
455
|
+
self._output("Remote inference activated. Sending job to server...")
|
456
|
+
if remote_cache:
|
457
|
+
self._output(
|
458
|
+
"Remote caching activated. The remote cache will be used for this job."
|
459
|
+
)
|
462
460
|
|
463
|
-
|
461
|
+
remote_job_data = coop.remote_inference_create(
|
462
|
+
self,
|
463
|
+
description=remote_inference_description,
|
464
|
+
status="queued",
|
465
|
+
)
|
466
|
+
self._output("Job sent!")
|
467
|
+
self._output(remote_job_data)
|
468
|
+
return remote_job_data
|
469
|
+
else:
|
464
470
|
if check_api_keys:
|
465
471
|
for model in self.models + [Model()]:
|
466
472
|
if not model.has_valid_api_key():
|
edsl/notebooks/Notebook.py
CHANGED
@@ -56,6 +56,37 @@ class Notebook(Base):
|
|
56
56
|
|
57
57
|
self.name = name or self.default_name
|
58
58
|
|
59
|
+
@classmethod
|
60
|
+
def from_script(cls, path: str, name: Optional[str] = None) -> "Notebook":
|
61
|
+
# Read the script file
|
62
|
+
with open(path, "r") as script_file:
|
63
|
+
script_content = script_file.read()
|
64
|
+
|
65
|
+
# Create a new Jupyter notebook
|
66
|
+
nb = nbformat.v4.new_notebook()
|
67
|
+
|
68
|
+
# Add the script content to the first cell
|
69
|
+
first_cell = nbformat.v4.new_code_cell(script_content)
|
70
|
+
nb.cells.append(first_cell)
|
71
|
+
|
72
|
+
# Create a Notebook instance with the notebook data
|
73
|
+
notebook_instance = cls(nb)
|
74
|
+
|
75
|
+
return notebook_instance
|
76
|
+
|
77
|
+
@classmethod
|
78
|
+
def from_current_script(cls) -> "Notebook":
|
79
|
+
import inspect
|
80
|
+
import os
|
81
|
+
|
82
|
+
# Get the path to the current file
|
83
|
+
current_frame = inspect.currentframe()
|
84
|
+
caller_frame = inspect.getouterframes(current_frame, 2)
|
85
|
+
current_file_path = os.path.abspath(caller_frame[1].filename)
|
86
|
+
|
87
|
+
# Use from_script to create the notebook
|
88
|
+
return cls.from_script(current_file_path)
|
89
|
+
|
59
90
|
def __eq__(self, other):
|
60
91
|
"""
|
61
92
|
Check if two Notebooks are equal.
|
edsl/prompts/Prompt.py
CHANGED
@@ -1,12 +1,16 @@
|
|
1
|
-
"""Class for creating prompts to be used in a survey."""
|
2
|
-
|
3
1
|
from __future__ import annotations
|
4
2
|
from typing import Optional
|
5
3
|
from abc import ABC
|
6
4
|
from typing import Any, List
|
7
5
|
|
8
6
|
from rich.table import Table
|
9
|
-
from jinja2 import Template, Environment, meta, TemplateSyntaxError
|
7
|
+
from jinja2 import Template, Environment, meta, TemplateSyntaxError, Undefined
|
8
|
+
|
9
|
+
|
10
|
+
class PreserveUndefined(Undefined):
|
11
|
+
def __str__(self):
|
12
|
+
return "{{ " + self._undefined_name + " }}"
|
13
|
+
|
10
14
|
|
11
15
|
from edsl.exceptions.prompts import TemplateRenderError
|
12
16
|
from edsl.prompts.prompt_config import (
|
@@ -35,6 +39,10 @@ class PromptBase(
|
|
35
39
|
|
36
40
|
return data_to_html(self.to_dict())
|
37
41
|
|
42
|
+
def __len__(self):
|
43
|
+
"""Return the length of the prompt text."""
|
44
|
+
return len(self.text)
|
45
|
+
|
38
46
|
@classmethod
|
39
47
|
def prompt_attributes(cls) -> List[str]:
|
40
48
|
"""Return the prompt class attributes."""
|
@@ -75,10 +83,10 @@ class PromptBase(
|
|
75
83
|
>>> p = Prompt("Hello, {{person}}")
|
76
84
|
>>> p2 = Prompt("How are you?")
|
77
85
|
>>> p + p2
|
78
|
-
Prompt(text
|
86
|
+
Prompt(text=\"""Hello, {{person}}How are you?\""")
|
79
87
|
|
80
88
|
>>> p + "How are you?"
|
81
|
-
Prompt(text
|
89
|
+
Prompt(text=\"""Hello, {{person}}How are you?\""")
|
82
90
|
"""
|
83
91
|
if isinstance(other_prompt, str):
|
84
92
|
return self.__class__(self.text + other_prompt)
|
@@ -114,7 +122,7 @@ class PromptBase(
|
|
114
122
|
Example:
|
115
123
|
>>> p = Prompt("Hello, {{person}}")
|
116
124
|
>>> p
|
117
|
-
Prompt(text
|
125
|
+
Prompt(text=\"""Hello, {{person}}\""")
|
118
126
|
"""
|
119
127
|
return f'Prompt(text="""{self.text}""")'
|
120
128
|
|
@@ -137,7 +145,7 @@ class PromptBase(
|
|
137
145
|
:param template: The template to find the variables in.
|
138
146
|
|
139
147
|
"""
|
140
|
-
env = Environment()
|
148
|
+
env = Environment(undefined=PreserveUndefined)
|
141
149
|
ast = env.parse(template)
|
142
150
|
return list(meta.find_undeclared_variables(ast))
|
143
151
|
|
@@ -186,13 +194,16 @@ class PromptBase(
|
|
186
194
|
|
187
195
|
>>> p = Prompt("Hello, {{person}}")
|
188
196
|
>>> p.render({"person": "John"})
|
189
|
-
|
197
|
+
Prompt(text=\"""Hello, John\""")
|
190
198
|
|
191
199
|
>>> p.render({"person": "Mr. {{last_name}}", "last_name": "Horton"})
|
192
|
-
|
200
|
+
Prompt(text=\"""Hello, Mr. Horton\""")
|
193
201
|
|
194
202
|
>>> p.render({"person": "Mr. {{last_name}}", "last_name": "Ho{{letter}}ton"}, max_nesting = 1)
|
195
|
-
|
203
|
+
Prompt(text=\"""Hello, Mr. Ho{{ letter }}ton\""")
|
204
|
+
|
205
|
+
>>> p.render({"person": "Mr. {{last_name}}"})
|
206
|
+
Prompt(text=\"""Hello, Mr. {{ last_name }}\""")
|
196
207
|
"""
|
197
208
|
new_text = self._render(
|
198
209
|
self.text, primary_replacement, **additional_replacements
|
@@ -216,12 +227,13 @@ class PromptBase(
|
|
216
227
|
>>> codebook = {"age": "Age"}
|
217
228
|
>>> p = Prompt("You are an agent named {{ name }}. {{ codebook['age']}}: {{ age }}")
|
218
229
|
>>> p.render({"name": "John", "age": 44}, codebook=codebook)
|
219
|
-
|
230
|
+
Prompt(text=\"""You are an agent named John. Age: 44\""")
|
220
231
|
"""
|
232
|
+
env = Environment(undefined=PreserveUndefined)
|
221
233
|
try:
|
222
234
|
previous_text = None
|
223
235
|
for _ in range(MAX_NESTING):
|
224
|
-
rendered_text =
|
236
|
+
rendered_text = env.from_string(text).render(
|
225
237
|
primary_replacement, **additional_replacements
|
226
238
|
)
|
227
239
|
if rendered_text == previous_text:
|
@@ -258,7 +270,7 @@ class PromptBase(
|
|
258
270
|
>>> p = Prompt("Hello, {{person}}")
|
259
271
|
>>> p2 = Prompt.from_dict(p.to_dict())
|
260
272
|
>>> p2
|
261
|
-
Prompt(text
|
273
|
+
Prompt(text=\"""Hello, {{person}}\""")
|
262
274
|
|
263
275
|
"""
|
264
276
|
class_name = data["class_name"]
|
@@ -290,6 +302,12 @@ class Prompt(PromptBase):
|
|
290
302
|
component_type = ComponentTypes.GENERIC
|
291
303
|
|
292
304
|
|
305
|
+
if __name__ == "__main__":
|
306
|
+
print("Running doctests...")
|
307
|
+
import doctest
|
308
|
+
|
309
|
+
doctest.testmod()
|
310
|
+
|
293
311
|
from edsl.prompts.library.question_multiple_choice import *
|
294
312
|
from edsl.prompts.library.agent_instructions import *
|
295
313
|
from edsl.prompts.library.agent_persona import *
|
@@ -302,9 +320,3 @@ from edsl.prompts.library.question_numerical import *
|
|
302
320
|
from edsl.prompts.library.question_rank import *
|
303
321
|
from edsl.prompts.library.question_extract import *
|
304
322
|
from edsl.prompts.library.question_list import *
|
305
|
-
|
306
|
-
|
307
|
-
if __name__ == "__main__":
|
308
|
-
import doctest
|
309
|
-
|
310
|
-
doctest.testmod()
|
edsl/questions/QuestionBase.py
CHANGED
@@ -173,15 +173,16 @@ class QuestionBase(
|
|
173
173
|
def add_model_instructions(
|
174
174
|
self, *, instructions: str, model: Optional[str] = None
|
175
175
|
) -> None:
|
176
|
-
"""Add model-specific instructions for the question.
|
176
|
+
"""Add model-specific instructions for the question that override the default instructions.
|
177
177
|
|
178
178
|
:param instructions: The instructions to add. This is typically a jinja2 template.
|
179
179
|
:param model: The language model for this instruction.
|
180
180
|
|
181
181
|
>>> from edsl.questions import QuestionFreeText
|
182
182
|
>>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite color?")
|
183
|
-
>>> q.add_model_instructions(instructions = "Answer in valid JSON like so {'answer': 'comment: <>}", model = "gpt3")
|
184
|
-
|
183
|
+
>>> q.add_model_instructions(instructions = "{{question_text}}. Answer in valid JSON like so {'answer': 'comment: <>}", model = "gpt3")
|
184
|
+
>>> q.get_instructions(model = "gpt3")
|
185
|
+
Prompt(text=\"""{{question_text}}. Answer in valid JSON like so {'answer': 'comment: <>}\""")
|
185
186
|
"""
|
186
187
|
from edsl import Model
|
187
188
|
|
@@ -201,6 +202,13 @@ class QuestionBase(
|
|
201
202
|
"""Get the matching question-answering instructions for the question.
|
202
203
|
|
203
204
|
:param model: The language model to use.
|
205
|
+
|
206
|
+
>>> from edsl import QuestionFreeText
|
207
|
+
>>> QuestionFreeText.example().get_instructions()
|
208
|
+
Prompt(text=\"""You are being asked the following question: {{question_text}}
|
209
|
+
Return a valid JSON formatted like this:
|
210
|
+
{"answer": "<put free text answer here>"}
|
211
|
+
\""")
|
204
212
|
"""
|
205
213
|
from edsl.prompts.Prompt import Prompt
|
206
214
|
|
@@ -293,7 +301,16 @@ class QuestionBase(
|
|
293
301
|
print_json(json.dumps(self.to_dict()))
|
294
302
|
|
295
303
|
def __call__(self, just_answer=True, model=None, agent=None, **kwargs):
|
296
|
-
"""Call the question.
|
304
|
+
"""Call the question.
|
305
|
+
|
306
|
+
>>> from edsl.language_models import LanguageModel
|
307
|
+
>>> m = LanguageModel.example(canned_response = "Yo, what's up?", test_model = True)
|
308
|
+
>>> from edsl import QuestionFreeText
|
309
|
+
>>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite color?")
|
310
|
+
>>> q(model = m)
|
311
|
+
"Yo, what's up?"
|
312
|
+
|
313
|
+
"""
|
297
314
|
survey = self.to_survey()
|
298
315
|
results = survey(model=model, agent=agent, **kwargs)
|
299
316
|
if just_answer:
|
@@ -304,7 +321,6 @@ class QuestionBase(
|
|
304
321
|
async def run_async(self, just_answer=True, model=None, agent=None, **kwargs):
|
305
322
|
"""Call the question."""
|
306
323
|
survey = self.to_survey()
|
307
|
-
## asyncio.run(survey.async_call());
|
308
324
|
results = await survey.run_async(model=model, agent=agent, **kwargs)
|
309
325
|
if just_answer:
|
310
326
|
return results.select(f"answer.{self.question_name}").first()
|
@@ -383,29 +399,34 @@ class QuestionBase(
|
|
383
399
|
s = Survey([self, other])
|
384
400
|
return s
|
385
401
|
|
386
|
-
def to_survey(self):
|
402
|
+
def to_survey(self) -> "Survey":
|
387
403
|
"""Turn a single question into a survey."""
|
388
404
|
from edsl.surveys.Survey import Survey
|
389
405
|
|
390
406
|
s = Survey([self])
|
391
407
|
return s
|
392
408
|
|
393
|
-
def run(self, *args, **kwargs):
|
409
|
+
def run(self, *args, **kwargs) -> "Results":
|
394
410
|
"""Turn a single question into a survey and run it."""
|
395
411
|
from edsl.surveys.Survey import Survey
|
396
412
|
|
397
413
|
s = self.to_survey()
|
398
414
|
return s.run(*args, **kwargs)
|
399
415
|
|
400
|
-
def by(self, *args):
|
401
|
-
"""Turn a single question into a survey and
|
416
|
+
def by(self, *args) -> "Jobs":
|
417
|
+
"""Turn a single question into a survey and then a Job."""
|
402
418
|
from edsl.surveys.Survey import Survey
|
403
419
|
|
404
420
|
s = Survey([self])
|
405
421
|
return s.by(*args)
|
406
422
|
|
407
|
-
def human_readable(self):
|
408
|
-
"""Print the question in a human readable format.
|
423
|
+
def human_readable(self) -> str:
|
424
|
+
"""Print the question in a human readable format.
|
425
|
+
|
426
|
+
>>> from edsl.questions import QuestionFreeText
|
427
|
+
>>> QuestionFreeText.example().human_readable()
|
428
|
+
'Question Type: free_text\\nQuestion: How are you?'
|
429
|
+
"""
|
409
430
|
lines = []
|
410
431
|
lines.append(f"Question Type: {self.question_type}")
|
411
432
|
lines.append(f"Question: {self.question_text}")
|
@@ -1,10 +1,10 @@
|
|
1
1
|
"""This module provides a factory class for creating question objects."""
|
2
2
|
|
3
3
|
import textwrap
|
4
|
-
from
|
4
|
+
from uuid import UUID
|
5
|
+
from typing import Any, Optional, Union
|
6
|
+
|
5
7
|
|
6
|
-
from edsl.exceptions import QuestionSerializationError
|
7
|
-
from edsl.exceptions import QuestionCreationValidationError
|
8
8
|
from edsl.questions.QuestionBase import RegisterQuestionsMeta
|
9
9
|
|
10
10
|
|
@@ -60,46 +60,35 @@ class Question(metaclass=Meta):
|
|
60
60
|
return q.example()
|
61
61
|
|
62
62
|
@classmethod
|
63
|
-
def pull(cls,
|
63
|
+
def pull(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
|
64
64
|
"""Pull the object from coop."""
|
65
65
|
from edsl.coop import Coop
|
66
66
|
|
67
|
-
|
68
|
-
|
69
|
-
id = id_or_url.split("/")[-1]
|
70
|
-
else:
|
71
|
-
id = id_or_url
|
72
|
-
from edsl.questions.QuestionBase import QuestionBase
|
73
|
-
|
74
|
-
return c._get_base(QuestionBase, id)
|
67
|
+
coop = Coop()
|
68
|
+
return coop.get(uuid, url, "question")
|
75
69
|
|
76
70
|
@classmethod
|
77
|
-
def delete(cls,
|
71
|
+
def delete(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
|
78
72
|
"""Delete the object from coop."""
|
79
73
|
from edsl.coop import Coop
|
80
74
|
|
81
|
-
|
82
|
-
|
83
|
-
id = id_or_url.split("/")[-1]
|
84
|
-
else:
|
85
|
-
id = id_or_url
|
86
|
-
from edsl.questions.QuestionBase import QuestionBase
|
87
|
-
|
88
|
-
return c._delete_base(QuestionBase, id)
|
75
|
+
coop = Coop()
|
76
|
+
return coop.delete(uuid, url)
|
89
77
|
|
90
78
|
@classmethod
|
91
|
-
def
|
92
|
-
|
79
|
+
def patch(
|
80
|
+
cls,
|
81
|
+
uuid: Optional[Union[str, UUID]] = None,
|
82
|
+
url: Optional[str] = None,
|
83
|
+
description: Optional[str] = None,
|
84
|
+
value: Optional[Any] = None,
|
85
|
+
visibility: Optional[str] = None,
|
86
|
+
):
|
87
|
+
"""Patch the object on coop."""
|
93
88
|
from edsl.coop import Coop
|
94
89
|
|
95
|
-
|
96
|
-
|
97
|
-
id = id_or_url.split("/")[-1]
|
98
|
-
else:
|
99
|
-
id = id_or_url
|
100
|
-
from edsl.questions.QuestionBase import QuestionBase
|
101
|
-
|
102
|
-
return c._update_base(QuestionBase, id, visibility)
|
90
|
+
coop = Coop()
|
91
|
+
return coop.patch(uuid, url, description, value, visibility)
|
103
92
|
|
104
93
|
@classmethod
|
105
94
|
def available(cls, show_class_names: bool = False) -> Union[list, dict]:
|
edsl/questions/settings.py
CHANGED
edsl/results/Dataset.py
CHANGED
@@ -78,6 +78,28 @@ class Dataset(UserList, ResultsExportMixin):
|
|
78
78
|
|
79
79
|
return get_values(self.data[0])[0]
|
80
80
|
|
81
|
+
def select(self, *keys):
|
82
|
+
"""Return a new dataset with only the selected keys.
|
83
|
+
|
84
|
+
:param keys: The keys to select.
|
85
|
+
|
86
|
+
>>> d = Dataset([{'a.b':[1,2,3,4]}, {'c.d':[5,6,7,8]}])
|
87
|
+
>>> d.select('a.b')
|
88
|
+
Dataset([{'a.b': [1, 2, 3, 4]}])
|
89
|
+
|
90
|
+
>>> d.select('a.b', 'c.d')
|
91
|
+
Dataset([{'a.b': [1, 2, 3, 4]}, {'c.d': [5, 6, 7, 8]}])
|
92
|
+
"""
|
93
|
+
if isinstance(keys, str):
|
94
|
+
keys = [keys]
|
95
|
+
|
96
|
+
new_data = []
|
97
|
+
for observation in self.data:
|
98
|
+
observation_key = list(observation.keys())[0]
|
99
|
+
if observation_key in keys:
|
100
|
+
new_data.append(observation)
|
101
|
+
return Dataset(new_data)
|
102
|
+
|
81
103
|
def _repr_html_(self) -> str:
|
82
104
|
"""Return an HTML representation of the dataset."""
|
83
105
|
from edsl.utilities.utilities import data_to_html
|
@@ -223,6 +245,15 @@ class Dataset(UserList, ResultsExportMixin):
|
|
223
245
|
|
224
246
|
return Dataset(new_data)
|
225
247
|
|
248
|
+
@classmethod
|
249
|
+
def example(self):
|
250
|
+
"""Return an example dataset.
|
251
|
+
|
252
|
+
>>> Dataset.example()
|
253
|
+
Dataset([{'a': [1, 2, 3, 4]}, {'b': [4, 3, 2, 1]}])
|
254
|
+
"""
|
255
|
+
return Dataset([{"a": [1, 2, 3, 4]}, {"b": [4, 3, 2, 1]}])
|
256
|
+
|
226
257
|
|
227
258
|
if __name__ == "__main__":
|
228
259
|
import doctest
|
edsl/results/Results.py
CHANGED
@@ -165,13 +165,7 @@ class Results(UserList, Mixins, Base):
|
|
165
165
|
)
|
166
166
|
|
167
167
|
def __repr__(self) -> str:
|
168
|
-
|
169
|
-
return f"""Results object
|
170
|
-
Size: {len(self.data)}.
|
171
|
-
Survey questions: {[q.question_name for q in self.survey.questions]}.
|
172
|
-
Created columns: {self.created_columns}
|
173
|
-
Hash: {hash(self)}
|
174
|
-
"""
|
168
|
+
return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
|
175
169
|
|
176
170
|
def _repr_html_(self) -> str:
|
177
171
|
json_str = json.dumps(self.to_dict()["data"], indent=4)
|
@@ -759,7 +753,10 @@ class Results(UserList, Mixins, Base):
|
|
759
753
|
|
760
754
|
def sort_by(self, *columns: str, reverse: bool = False) -> Results:
|
761
755
|
import warnings
|
762
|
-
|
756
|
+
|
757
|
+
warnings.warn(
|
758
|
+
"sort_by is deprecated. Use order_by instead.", DeprecationWarning
|
759
|
+
)
|
763
760
|
return self.order_by(*columns, reverse=reverse)
|
764
761
|
|
765
762
|
def order_by(self, *columns: str, reverse: bool = False) -> Results:
|
@@ -800,6 +797,7 @@ class Results(UserList, Mixins, Base):
|
|
800
797
|
│ Great │
|
801
798
|
└──────────────┘
|
802
799
|
"""
|
800
|
+
|
803
801
|
def to_numeric_if_possible(v):
|
804
802
|
try:
|
805
803
|
return float(v)
|
@@ -13,7 +13,10 @@ class ResultsToolsMixin:
|
|
13
13
|
progress_bar=False,
|
14
14
|
print_exceptions=False,
|
15
15
|
) -> list:
|
16
|
-
values =
|
16
|
+
values = [
|
17
|
+
str(txt)[:1000]
|
18
|
+
for txt in self.shuffle(seed=seed).select(field).to_list()[:max_values]
|
19
|
+
]
|
17
20
|
from edsl import ScenarioList
|
18
21
|
|
19
22
|
q = QuestionList(
|
edsl/scenarios/ScenarioList.py
CHANGED
@@ -119,7 +119,7 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin, ResultsExportMixin):
|
|
119
119
|
|
120
120
|
return ScenarioList(random.sample(self.data, n))
|
121
121
|
|
122
|
-
def expand(self, expand_field: str, number_field
|
122
|
+
def expand(self, expand_field: str, number_field=False) -> ScenarioList:
|
123
123
|
"""Expand the ScenarioList by a field.
|
124
124
|
|
125
125
|
Example:
|
@@ -137,7 +137,7 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin, ResultsExportMixin):
|
|
137
137
|
new_scenario = scenario.copy()
|
138
138
|
new_scenario[expand_field] = value
|
139
139
|
if number_field:
|
140
|
-
new_scenario[expand_field +
|
140
|
+
new_scenario[expand_field + "_number"] = index + 1
|
141
141
|
new_scenarios.append(new_scenario)
|
142
142
|
return ScenarioList(new_scenarios)
|
143
143
|
|
@@ -192,7 +192,7 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin, ResultsExportMixin):
|
|
192
192
|
|
193
193
|
def get_sort_key(scenario: Any) -> tuple:
|
194
194
|
return tuple(scenario[field] for field in fields)
|
195
|
-
|
195
|
+
|
196
196
|
return ScenarioList(sorted(self, key=get_sort_key, reverse=reverse))
|
197
197
|
|
198
198
|
def filter(self, expression: str) -> ScenarioList:
|
@@ -343,6 +343,20 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin, ResultsExportMixin):
|
|
343
343
|
"""
|
344
344
|
return cls([Scenario(row) for row in df.to_dict(orient="records")])
|
345
345
|
|
346
|
+
def to_key_value(self, field, value=None) -> Union[dict, set]:
|
347
|
+
"""Return the set of values in the field.
|
348
|
+
|
349
|
+
Example:
|
350
|
+
|
351
|
+
>>> s = ScenarioList([Scenario({'name': 'Alice'}), Scenario({'name': 'Bob'})])
|
352
|
+
>>> s.to_key_value('name') == {'Alice', 'Bob'}
|
353
|
+
True
|
354
|
+
"""
|
355
|
+
if value is None:
|
356
|
+
return {scenario[field] for scenario in self}
|
357
|
+
else:
|
358
|
+
return {scenario[field]: scenario[value] for scenario in self}
|
359
|
+
|
346
360
|
@classmethod
|
347
361
|
def from_csv(cls, filename: str) -> ScenarioList:
|
348
362
|
"""Create a ScenarioList from a CSV file.
|
edsl/study/Study.py
CHANGED
@@ -472,18 +472,12 @@ class Study:
|
|
472
472
|
coop.create(self, description=self.description)
|
473
473
|
|
474
474
|
@classmethod
|
475
|
-
def pull(cls,
|
475
|
+
def pull(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
|
476
476
|
"""Pull the object from coop."""
|
477
477
|
from edsl.coop import Coop
|
478
478
|
|
479
|
-
|
480
|
-
|
481
|
-
else:
|
482
|
-
uuid_value = id_or_url
|
483
|
-
|
484
|
-
c = Coop()
|
485
|
-
|
486
|
-
return c._get_base(cls, uuid_value, exec_profile=exec_profile)
|
479
|
+
coop = Coop()
|
480
|
+
return coop.get(uuid, url, "study")
|
487
481
|
|
488
482
|
def __repr__(self):
|
489
483
|
return f"""Study(name = "{self.name}", description = "{self.description}", objects = {self.objects}, cache = {self.cache}, filename = "{self.filename}", coop = {self.coop}, use_study_cache = {self.use_study_cache}, overwrite_on_change = {self.overwrite_on_change})"""
|