edsl 0.1.29.dev3__py3-none-any.whl → 0.1.29.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +12 -0
- edsl/agents/AgentList.py +3 -4
- edsl/agents/InvigilatorBase.py +15 -10
- edsl/agents/PromptConstructionMixin.py +340 -100
- edsl/conjure/InputData.py +37 -8
- edsl/coop/coop.py +68 -15
- edsl/coop/utils.py +2 -0
- edsl/jobs/Jobs.py +22 -16
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +1 -0
- edsl/notebooks/Notebook.py +30 -0
- edsl/prompts/Prompt.py +31 -19
- edsl/questions/QuestionBase.py +32 -11
- edsl/questions/settings.py +1 -1
- edsl/results/Dataset.py +31 -0
- edsl/results/ResultsToolsMixin.py +2 -1
- edsl/scenarios/ScenarioList.py +19 -3
- edsl/surveys/Survey.py +37 -3
- edsl/tools/plotting.py +4 -2
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.29.dev5.dist-info}/METADATA +11 -10
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.29.dev5.dist-info}/RECORD +23 -24
- edsl-0.1.29.dev3.dist-info/entry_points.txt +0 -3
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.29.dev5.dist-info}/LICENSE +0 -0
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.29.dev5.dist-info}/WHEEL +0 -0
edsl/conjure/InputData.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
|
-
import
|
2
|
-
|
1
|
+
import base64
|
3
2
|
from abc import ABC, abstractmethod
|
4
3
|
from typing import Dict, Callable, Optional, List, Generator, Tuple, Union
|
5
4
|
from collections import namedtuple
|
@@ -52,6 +51,7 @@ class InputDataABC(
|
|
52
51
|
config: Optional[dict] = None,
|
53
52
|
naming_function: Optional[Callable] = sanitize_string,
|
54
53
|
raw_data: Optional[List] = None,
|
54
|
+
binary: Optional[str] = None,
|
55
55
|
question_names: Optional[List[str]] = None,
|
56
56
|
question_texts: Optional[List[str]] = None,
|
57
57
|
answer_codebook: Optional[Dict] = None,
|
@@ -83,6 +83,15 @@ class InputDataABC(
|
|
83
83
|
self.config = config
|
84
84
|
self.naming_function = naming_function
|
85
85
|
|
86
|
+
if binary is not None:
|
87
|
+
self.binary = binary
|
88
|
+
else:
|
89
|
+
try:
|
90
|
+
with open(self.datafile_name, 'rb') as file:
|
91
|
+
self.binary = base64.b64encode(file.read()).decode()
|
92
|
+
except FileNotFoundError:
|
93
|
+
self.binary = None
|
94
|
+
|
86
95
|
def default_repair_func(x):
|
87
96
|
return (
|
88
97
|
x.replace("#", "_num")
|
@@ -118,6 +127,14 @@ class InputDataABC(
|
|
118
127
|
if order_options:
|
119
128
|
self.order_options()
|
120
129
|
|
130
|
+
@property
|
131
|
+
def download_link(self):
|
132
|
+
from IPython.display import HTML
|
133
|
+
actual_file_name = self.datafile_name.split("/")[-1]
|
134
|
+
download_link =f'<a href="data:text/plain;base64,{self.binary}" download="{actual_file_name}">Download {self.datafile_name}</a>'
|
135
|
+
return HTML(download_link)
|
136
|
+
|
137
|
+
|
121
138
|
@abstractmethod
|
122
139
|
def get_question_texts(self) -> List[str]:
|
123
140
|
"""Get the text of the questions
|
@@ -151,7 +168,7 @@ class InputDataABC(
|
|
151
168
|
"""
|
152
169
|
raise NotImplementedError
|
153
170
|
|
154
|
-
def rename_questions(self, rename_dict: Dict[str, str]) -> "InputData":
|
171
|
+
def rename_questions(self, rename_dict: Dict[str, str], ignore_missing = False) -> "InputData":
|
155
172
|
"""Rename a question.
|
156
173
|
|
157
174
|
>>> id = InputDataABC.example()
|
@@ -160,10 +177,10 @@ class InputDataABC(
|
|
160
177
|
|
161
178
|
"""
|
162
179
|
for old_name, new_name in rename_dict.items():
|
163
|
-
self.rename(old_name, new_name)
|
180
|
+
self.rename(old_name, new_name, ignore_missing = ignore_missing)
|
164
181
|
return self
|
165
182
|
|
166
|
-
def rename(self, old_name, new_name) -> "InputData":
|
183
|
+
def rename(self, old_name, new_name, ignore_missing = False) -> "InputData":
|
167
184
|
"""Rename a question.
|
168
185
|
|
169
186
|
>>> id = InputDataABC.example()
|
@@ -171,13 +188,19 @@ class InputDataABC(
|
|
171
188
|
['evening', 'feeling']
|
172
189
|
|
173
190
|
"""
|
191
|
+
if old_name not in self.question_names:
|
192
|
+
if ignore_missing:
|
193
|
+
return self
|
194
|
+
else:
|
195
|
+
raise ValueError(f"Question {old_name} not found.")
|
196
|
+
|
174
197
|
idx = self.question_names.index(old_name)
|
175
198
|
self.question_names[idx] = new_name
|
176
199
|
self.answer_codebook[new_name] = self.answer_codebook.pop(old_name, {})
|
177
200
|
|
178
201
|
return self
|
179
202
|
|
180
|
-
def _drop_question(self, question_name):
|
203
|
+
def _drop_question(self, question_name, ignore_missing=False):
|
181
204
|
"""Drop a question
|
182
205
|
|
183
206
|
>>> id = InputDataABC.example()
|
@@ -185,6 +208,11 @@ class InputDataABC(
|
|
185
208
|
['feeling']
|
186
209
|
|
187
210
|
"""
|
211
|
+
if question_name not in self.question_names:
|
212
|
+
if ignore_missing:
|
213
|
+
return self
|
214
|
+
else:
|
215
|
+
raise ValueError(f"Question {question_name} not found.")
|
188
216
|
idx = self.question_names.index(question_name)
|
189
217
|
self._question_names.pop(idx)
|
190
218
|
self._question_texts.pop(idx)
|
@@ -206,7 +234,7 @@ class InputDataABC(
|
|
206
234
|
self._drop_question(qn)
|
207
235
|
return self
|
208
236
|
|
209
|
-
def keep(self, *question_names_to_keep) -> "InputDataABC":
|
237
|
+
def keep(self, *question_names_to_keep, ignore_missing = False) -> "InputDataABC":
|
210
238
|
"""Keep a question.
|
211
239
|
|
212
240
|
>>> id = InputDataABC.example()
|
@@ -217,7 +245,7 @@ class InputDataABC(
|
|
217
245
|
all_question_names = self._question_names[:]
|
218
246
|
for qn in all_question_names:
|
219
247
|
if qn not in question_names_to_keep:
|
220
|
-
self._drop_question(qn)
|
248
|
+
self._drop_question(qn, ignore_missing = ignore_missing)
|
221
249
|
return self
|
222
250
|
|
223
251
|
def modify_question_type(
|
@@ -284,6 +312,7 @@ class InputDataABC(
|
|
284
312
|
"raw_data": self.raw_data,
|
285
313
|
"question_names": self.question_names,
|
286
314
|
"question_texts": self.question_texts,
|
315
|
+
"binary": self.binary,
|
287
316
|
"answer_codebook": self.answer_codebook,
|
288
317
|
"question_types": self.question_types,
|
289
318
|
}
|
edsl/coop/coop.py
CHANGED
@@ -5,8 +5,14 @@ import requests
|
|
5
5
|
from typing import Any, Optional, Union, Literal
|
6
6
|
from uuid import UUID
|
7
7
|
import edsl
|
8
|
-
from edsl import CONFIG, CacheEntry
|
9
|
-
from edsl.coop.utils import
|
8
|
+
from edsl import CONFIG, CacheEntry, Jobs
|
9
|
+
from edsl.coop.utils import (
|
10
|
+
EDSLObject,
|
11
|
+
ObjectRegistry,
|
12
|
+
ObjectType,
|
13
|
+
RemoteJobStatus,
|
14
|
+
VisibilityType,
|
15
|
+
)
|
10
16
|
|
11
17
|
|
12
18
|
class Coop:
|
@@ -494,6 +500,40 @@ class Coop:
|
|
494
500
|
################
|
495
501
|
# Remote Inference
|
496
502
|
################
|
503
|
+
def remote_inference_create(
|
504
|
+
self,
|
505
|
+
job: Jobs,
|
506
|
+
description: Optional[str] = None,
|
507
|
+
status: RemoteJobStatus = "queued",
|
508
|
+
visibility: Optional[VisibilityType] = "unlisted",
|
509
|
+
) -> dict:
|
510
|
+
"""
|
511
|
+
Send a remote inference job to the server.
|
512
|
+
"""
|
513
|
+
response = self._send_server_request(
|
514
|
+
uri="api/v0/remote-inference",
|
515
|
+
method="POST",
|
516
|
+
payload={
|
517
|
+
"json_string": json.dumps(
|
518
|
+
job.to_dict(),
|
519
|
+
default=self._json_handle_none,
|
520
|
+
),
|
521
|
+
"description": description,
|
522
|
+
"status": status,
|
523
|
+
"visibility": visibility,
|
524
|
+
"version": self._edsl_version,
|
525
|
+
},
|
526
|
+
)
|
527
|
+
self._resolve_server_response(response)
|
528
|
+
response_json = response.json()
|
529
|
+
return {
|
530
|
+
"uuid": response_json.get("jobs_uuid"),
|
531
|
+
"description": response_json.get("description"),
|
532
|
+
"status": response_json.get("status"),
|
533
|
+
"visibility": response_json.get("visibility"),
|
534
|
+
"version": self._edsl_version,
|
535
|
+
}
|
536
|
+
|
497
537
|
def remote_inference_get(self, job_uuid: str) -> dict:
|
498
538
|
"""
|
499
539
|
Get the results of a remote inference job.
|
@@ -508,13 +548,34 @@ class Coop:
|
|
508
548
|
return {
|
509
549
|
"jobs_uuid": data.get("jobs_uuid"),
|
510
550
|
"results_uuid": data.get("results_uuid"),
|
511
|
-
"results_url": "
|
551
|
+
"results_url": f"{self.url}/content/{data.get('results_uuid')}",
|
512
552
|
"status": data.get("status"),
|
513
553
|
"reason": data.get("reason"),
|
514
554
|
"price": data.get("price"),
|
515
555
|
"version": data.get("version"),
|
516
556
|
}
|
517
557
|
|
558
|
+
def remote_inference_cost(
|
559
|
+
self,
|
560
|
+
job: Jobs,
|
561
|
+
) -> dict:
|
562
|
+
"""
|
563
|
+
Get the cost of a remote inference job.
|
564
|
+
"""
|
565
|
+
response = self._send_server_request(
|
566
|
+
uri="api/v0/remote-inference/cost",
|
567
|
+
method="POST",
|
568
|
+
payload={
|
569
|
+
"json_string": json.dumps(
|
570
|
+
job.to_dict(),
|
571
|
+
default=self._json_handle_none,
|
572
|
+
),
|
573
|
+
},
|
574
|
+
)
|
575
|
+
self._resolve_server_response(response)
|
576
|
+
response_json = response.json()
|
577
|
+
return response_json.get("cost")
|
578
|
+
|
518
579
|
################
|
519
580
|
# Remote Errors
|
520
581
|
################
|
@@ -705,18 +766,10 @@ if __name__ == "__main__":
|
|
705
766
|
##############
|
706
767
|
from edsl.jobs import Jobs
|
707
768
|
|
708
|
-
|
709
|
-
coop.
|
710
|
-
|
711
|
-
|
712
|
-
# post a job
|
713
|
-
response = coop.create(Jobs.example())
|
714
|
-
# get job and results
|
715
|
-
coop.remote_inference_get(response.get("uuid"))
|
716
|
-
coop.get(
|
717
|
-
object_type="results",
|
718
|
-
uuid=coop.remote_inference_get(response.get("uuid")).get("results_uuid"),
|
719
|
-
)
|
769
|
+
job = Jobs.example()
|
770
|
+
coop.remote_inference_cost(job)
|
771
|
+
results = coop.remote_inference_create(job)
|
772
|
+
coop.remote_inference_get(results.get("uuid"))
|
720
773
|
|
721
774
|
##############
|
722
775
|
# D. Errors
|
edsl/coop/utils.py
CHANGED
edsl/jobs/Jobs.py
CHANGED
@@ -312,10 +312,6 @@ class Jobs(Base):
|
|
312
312
|
# if no agents, models, or scenarios are set, set them to defaults
|
313
313
|
self.agents = self.agents or [Agent()]
|
314
314
|
self.models = self.models or [Model()]
|
315
|
-
# if remote, set all the models to remote
|
316
|
-
if hasattr(self, "remote") and self.remote:
|
317
|
-
for model in self.models:
|
318
|
-
model.remote = True
|
319
315
|
self.scenarios = self.scenarios or [Scenario()]
|
320
316
|
for agent, scenario, model in product(self.agents, self.scenarios, self.models):
|
321
317
|
yield Interview(
|
@@ -368,7 +364,7 @@ class Jobs(Base):
|
|
368
364
|
if self.verbose:
|
369
365
|
print(message)
|
370
366
|
|
371
|
-
def _check_parameters(self, strict=False, warn
|
367
|
+
def _check_parameters(self, strict=False, warn=True) -> None:
|
372
368
|
"""Check if the parameters in the survey and scenarios are consistent.
|
373
369
|
|
374
370
|
>>> from edsl import QuestionFreeText
|
@@ -413,15 +409,13 @@ class Jobs(Base):
|
|
413
409
|
progress_bar: bool = False,
|
414
410
|
stop_on_exception: bool = False,
|
415
411
|
cache: Union[Cache, bool] = None,
|
416
|
-
remote: bool = (
|
417
|
-
False if os.getenv("DEFAULT_RUN_MODE", "local") == "local" else True
|
418
|
-
),
|
419
412
|
check_api_keys: bool = False,
|
420
413
|
sidecar_model: Optional[LanguageModel] = None,
|
421
414
|
batch_mode: Optional[bool] = None,
|
422
415
|
verbose: bool = False,
|
423
416
|
print_exceptions=True,
|
424
417
|
remote_cache_description: Optional[str] = None,
|
418
|
+
remote_inference_description: Optional[str] = None,
|
425
419
|
) -> Results:
|
426
420
|
"""
|
427
421
|
Runs the Job: conducts Interviews and returns their results.
|
@@ -431,11 +425,11 @@ class Jobs(Base):
|
|
431
425
|
:param progress_bar: shows a progress bar
|
432
426
|
:param stop_on_exception: stops the job if an exception is raised
|
433
427
|
:param cache: a cache object to store results
|
434
|
-
:param remote: run the job remotely
|
435
428
|
:param check_api_keys: check if the API keys are valid
|
436
429
|
:param batch_mode: run the job in batch mode i.e., no expecation of interaction with the user
|
437
430
|
:param verbose: prints messages
|
438
431
|
:param remote_cache_description: specifies a description for this group of entries in the remote cache
|
432
|
+
:param remote_inference_description: specifies a description for the remote inference job
|
439
433
|
"""
|
440
434
|
from edsl.coop.coop import Coop
|
441
435
|
|
@@ -446,21 +440,33 @@ class Jobs(Base):
|
|
446
440
|
"Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
|
447
441
|
)
|
448
442
|
|
449
|
-
self.remote = remote
|
450
443
|
self.verbose = verbose
|
451
444
|
|
452
445
|
try:
|
453
446
|
coop = Coop()
|
454
|
-
|
447
|
+
user_edsl_settings = coop.edsl_settings
|
448
|
+
remote_cache = user_edsl_settings["remote_caching"]
|
449
|
+
remote_inference = user_edsl_settings["remote_inference"]
|
455
450
|
except Exception:
|
456
451
|
remote_cache = False
|
452
|
+
remote_inference = False
|
457
453
|
|
458
|
-
if
|
459
|
-
|
460
|
-
if
|
461
|
-
|
454
|
+
if remote_inference:
|
455
|
+
self._output("Remote inference activated. Sending job to server...")
|
456
|
+
if remote_cache:
|
457
|
+
self._output(
|
458
|
+
"Remote caching activated. The remote cache will be used for this job."
|
459
|
+
)
|
462
460
|
|
463
|
-
|
461
|
+
remote_job_data = coop.remote_inference_create(
|
462
|
+
self,
|
463
|
+
description=remote_inference_description,
|
464
|
+
status="queued",
|
465
|
+
)
|
466
|
+
self._output("Job sent!")
|
467
|
+
self._output(remote_job_data)
|
468
|
+
return remote_job_data
|
469
|
+
else:
|
464
470
|
if check_api_keys:
|
465
471
|
for model in self.models + [Model()]:
|
466
472
|
if not model.has_valid_api_key():
|
edsl/notebooks/Notebook.py
CHANGED
@@ -56,6 +56,36 @@ class Notebook(Base):
|
|
56
56
|
|
57
57
|
self.name = name or self.default_name
|
58
58
|
|
59
|
+
@classmethod
|
60
|
+
def from_script(cls, path: str, name: Optional[str] = None) -> "Notebook":
|
61
|
+
# Read the script file
|
62
|
+
with open(path, 'r') as script_file:
|
63
|
+
script_content = script_file.read()
|
64
|
+
|
65
|
+
# Create a new Jupyter notebook
|
66
|
+
nb = nbformat.v4.new_notebook()
|
67
|
+
|
68
|
+
# Add the script content to the first cell
|
69
|
+
first_cell = nbformat.v4.new_code_cell(script_content)
|
70
|
+
nb.cells.append(first_cell)
|
71
|
+
|
72
|
+
# Create a Notebook instance with the notebook data
|
73
|
+
notebook_instance = cls(nb)
|
74
|
+
|
75
|
+
return notebook_instance
|
76
|
+
|
77
|
+
@classmethod
|
78
|
+
def from_current_script(cls) -> "Notebook":
|
79
|
+
import inspect
|
80
|
+
import os
|
81
|
+
# Get the path to the current file
|
82
|
+
current_frame = inspect.currentframe()
|
83
|
+
caller_frame = inspect.getouterframes(current_frame, 2)
|
84
|
+
current_file_path = os.path.abspath(caller_frame[1].filename)
|
85
|
+
|
86
|
+
# Use from_script to create the notebook
|
87
|
+
return cls.from_script(current_file_path)
|
88
|
+
|
59
89
|
def __eq__(self, other):
|
60
90
|
"""
|
61
91
|
Check if two Notebooks are equal.
|
edsl/prompts/Prompt.py
CHANGED
@@ -1,12 +1,16 @@
|
|
1
|
-
"""Class for creating prompts to be used in a survey."""
|
2
|
-
|
3
1
|
from __future__ import annotations
|
4
2
|
from typing import Optional
|
5
3
|
from abc import ABC
|
6
4
|
from typing import Any, List
|
7
5
|
|
8
6
|
from rich.table import Table
|
9
|
-
from jinja2 import Template, Environment, meta, TemplateSyntaxError
|
7
|
+
from jinja2 import Template, Environment, meta, TemplateSyntaxError, Undefined
|
8
|
+
|
9
|
+
|
10
|
+
class PreserveUndefined(Undefined):
|
11
|
+
def __str__(self):
|
12
|
+
return "{{ " + self._undefined_name + " }}"
|
13
|
+
|
10
14
|
|
11
15
|
from edsl.exceptions.prompts import TemplateRenderError
|
12
16
|
from edsl.prompts.prompt_config import (
|
@@ -35,6 +39,10 @@ class PromptBase(
|
|
35
39
|
|
36
40
|
return data_to_html(self.to_dict())
|
37
41
|
|
42
|
+
def __len__(self):
|
43
|
+
"""Return the length of the prompt text."""
|
44
|
+
return len(self.text)
|
45
|
+
|
38
46
|
@classmethod
|
39
47
|
def prompt_attributes(cls) -> List[str]:
|
40
48
|
"""Return the prompt class attributes."""
|
@@ -75,10 +83,10 @@ class PromptBase(
|
|
75
83
|
>>> p = Prompt("Hello, {{person}}")
|
76
84
|
>>> p2 = Prompt("How are you?")
|
77
85
|
>>> p + p2
|
78
|
-
Prompt(text
|
86
|
+
Prompt(text=\"""Hello, {{person}}How are you?\""")
|
79
87
|
|
80
88
|
>>> p + "How are you?"
|
81
|
-
Prompt(text
|
89
|
+
Prompt(text=\"""Hello, {{person}}How are you?\""")
|
82
90
|
"""
|
83
91
|
if isinstance(other_prompt, str):
|
84
92
|
return self.__class__(self.text + other_prompt)
|
@@ -114,7 +122,7 @@ class PromptBase(
|
|
114
122
|
Example:
|
115
123
|
>>> p = Prompt("Hello, {{person}}")
|
116
124
|
>>> p
|
117
|
-
Prompt(text
|
125
|
+
Prompt(text=\"""Hello, {{person}}\""")
|
118
126
|
"""
|
119
127
|
return f'Prompt(text="""{self.text}""")'
|
120
128
|
|
@@ -137,7 +145,7 @@ class PromptBase(
|
|
137
145
|
:param template: The template to find the variables in.
|
138
146
|
|
139
147
|
"""
|
140
|
-
env = Environment()
|
148
|
+
env = Environment(undefined=PreserveUndefined)
|
141
149
|
ast = env.parse(template)
|
142
150
|
return list(meta.find_undeclared_variables(ast))
|
143
151
|
|
@@ -186,13 +194,16 @@ class PromptBase(
|
|
186
194
|
|
187
195
|
>>> p = Prompt("Hello, {{person}}")
|
188
196
|
>>> p.render({"person": "John"})
|
189
|
-
|
197
|
+
Prompt(text=\"""Hello, John\""")
|
190
198
|
|
191
199
|
>>> p.render({"person": "Mr. {{last_name}}", "last_name": "Horton"})
|
192
|
-
|
200
|
+
Prompt(text=\"""Hello, Mr. Horton\""")
|
193
201
|
|
194
202
|
>>> p.render({"person": "Mr. {{last_name}}", "last_name": "Ho{{letter}}ton"}, max_nesting = 1)
|
195
|
-
|
203
|
+
Prompt(text=\"""Hello, Mr. Ho{{ letter }}ton\""")
|
204
|
+
|
205
|
+
>>> p.render({"person": "Mr. {{last_name}}"})
|
206
|
+
Prompt(text=\"""Hello, Mr. {{ last_name }}\""")
|
196
207
|
"""
|
197
208
|
new_text = self._render(
|
198
209
|
self.text, primary_replacement, **additional_replacements
|
@@ -216,12 +227,13 @@ class PromptBase(
|
|
216
227
|
>>> codebook = {"age": "Age"}
|
217
228
|
>>> p = Prompt("You are an agent named {{ name }}. {{ codebook['age']}}: {{ age }}")
|
218
229
|
>>> p.render({"name": "John", "age": 44}, codebook=codebook)
|
219
|
-
|
230
|
+
Prompt(text=\"""You are an agent named John. Age: 44\""")
|
220
231
|
"""
|
232
|
+
env = Environment(undefined=PreserveUndefined)
|
221
233
|
try:
|
222
234
|
previous_text = None
|
223
235
|
for _ in range(MAX_NESTING):
|
224
|
-
rendered_text =
|
236
|
+
rendered_text = env.from_string(text).render(
|
225
237
|
primary_replacement, **additional_replacements
|
226
238
|
)
|
227
239
|
if rendered_text == previous_text:
|
@@ -258,7 +270,7 @@ class PromptBase(
|
|
258
270
|
>>> p = Prompt("Hello, {{person}}")
|
259
271
|
>>> p2 = Prompt.from_dict(p.to_dict())
|
260
272
|
>>> p2
|
261
|
-
Prompt(text
|
273
|
+
Prompt(text=\"""Hello, {{person}}\""")
|
262
274
|
|
263
275
|
"""
|
264
276
|
class_name = data["class_name"]
|
@@ -290,6 +302,12 @@ class Prompt(PromptBase):
|
|
290
302
|
component_type = ComponentTypes.GENERIC
|
291
303
|
|
292
304
|
|
305
|
+
if __name__ == "__main__":
|
306
|
+
print("Running doctests...")
|
307
|
+
import doctest
|
308
|
+
|
309
|
+
doctest.testmod()
|
310
|
+
|
293
311
|
from edsl.prompts.library.question_multiple_choice import *
|
294
312
|
from edsl.prompts.library.agent_instructions import *
|
295
313
|
from edsl.prompts.library.agent_persona import *
|
@@ -302,9 +320,3 @@ from edsl.prompts.library.question_numerical import *
|
|
302
320
|
from edsl.prompts.library.question_rank import *
|
303
321
|
from edsl.prompts.library.question_extract import *
|
304
322
|
from edsl.prompts.library.question_list import *
|
305
|
-
|
306
|
-
|
307
|
-
if __name__ == "__main__":
|
308
|
-
import doctest
|
309
|
-
|
310
|
-
doctest.testmod()
|
edsl/questions/QuestionBase.py
CHANGED
@@ -173,15 +173,16 @@ class QuestionBase(
|
|
173
173
|
def add_model_instructions(
|
174
174
|
self, *, instructions: str, model: Optional[str] = None
|
175
175
|
) -> None:
|
176
|
-
"""Add model-specific instructions for the question.
|
176
|
+
"""Add model-specific instructions for the question that override the default instructions.
|
177
177
|
|
178
178
|
:param instructions: The instructions to add. This is typically a jinja2 template.
|
179
179
|
:param model: The language model for this instruction.
|
180
180
|
|
181
181
|
>>> from edsl.questions import QuestionFreeText
|
182
182
|
>>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite color?")
|
183
|
-
>>> q.add_model_instructions(instructions = "Answer in valid JSON like so {'answer': 'comment: <>}", model = "gpt3")
|
184
|
-
|
183
|
+
>>> q.add_model_instructions(instructions = "{{question_text}}. Answer in valid JSON like so {'answer': 'comment: <>}", model = "gpt3")
|
184
|
+
>>> q.get_instructions(model = "gpt3")
|
185
|
+
Prompt(text=\"""{{question_text}}. Answer in valid JSON like so {'answer': 'comment: <>}\""")
|
185
186
|
"""
|
186
187
|
from edsl import Model
|
187
188
|
|
@@ -201,6 +202,13 @@ class QuestionBase(
|
|
201
202
|
"""Get the mathcing question-answering instructions for the question.
|
202
203
|
|
203
204
|
:param model: The language model to use.
|
205
|
+
|
206
|
+
>>> from edsl import QuestionFreeText
|
207
|
+
>>> QuestionFreeText.example().get_instructions()
|
208
|
+
Prompt(text=\"""You are being asked the following question: {{question_text}}
|
209
|
+
Return a valid JSON formatted like this:
|
210
|
+
{"answer": "<put free text answer here>"}
|
211
|
+
\""")
|
204
212
|
"""
|
205
213
|
from edsl.prompts.Prompt import Prompt
|
206
214
|
|
@@ -293,7 +301,16 @@ class QuestionBase(
|
|
293
301
|
print_json(json.dumps(self.to_dict()))
|
294
302
|
|
295
303
|
def __call__(self, just_answer=True, model=None, agent=None, **kwargs):
|
296
|
-
"""Call the question.
|
304
|
+
"""Call the question.
|
305
|
+
|
306
|
+
>>> from edsl.language_models import LanguageModel
|
307
|
+
>>> m = LanguageModel.example(canned_response = "Yo, what's up?", test_model = True)
|
308
|
+
>>> from edsl import QuestionFreeText
|
309
|
+
>>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite color?")
|
310
|
+
>>> q(model = m)
|
311
|
+
"Yo, what's up?"
|
312
|
+
|
313
|
+
"""
|
297
314
|
survey = self.to_survey()
|
298
315
|
results = survey(model=model, agent=agent, **kwargs)
|
299
316
|
if just_answer:
|
@@ -304,7 +321,6 @@ class QuestionBase(
|
|
304
321
|
async def run_async(self, just_answer=True, model=None, agent=None, **kwargs):
|
305
322
|
"""Call the question."""
|
306
323
|
survey = self.to_survey()
|
307
|
-
## asyncio.run(survey.async_call());
|
308
324
|
results = await survey.run_async(model=model, agent=agent, **kwargs)
|
309
325
|
if just_answer:
|
310
326
|
return results.select(f"answer.{self.question_name}").first()
|
@@ -383,29 +399,34 @@ class QuestionBase(
|
|
383
399
|
s = Survey([self, other])
|
384
400
|
return s
|
385
401
|
|
386
|
-
def to_survey(self):
|
402
|
+
def to_survey(self) -> "Survey":
|
387
403
|
"""Turn a single question into a survey."""
|
388
404
|
from edsl.surveys.Survey import Survey
|
389
405
|
|
390
406
|
s = Survey([self])
|
391
407
|
return s
|
392
408
|
|
393
|
-
def run(self, *args, **kwargs):
|
409
|
+
def run(self, *args, **kwargs) -> "Results":
|
394
410
|
"""Turn a single question into a survey and run it."""
|
395
411
|
from edsl.surveys.Survey import Survey
|
396
412
|
|
397
413
|
s = self.to_survey()
|
398
414
|
return s.run(*args, **kwargs)
|
399
415
|
|
400
|
-
def by(self, *args):
|
401
|
-
"""Turn a single question into a survey and
|
416
|
+
def by(self, *args) -> "Jobs":
|
417
|
+
"""Turn a single question into a survey and then a Job."""
|
402
418
|
from edsl.surveys.Survey import Survey
|
403
419
|
|
404
420
|
s = Survey([self])
|
405
421
|
return s.by(*args)
|
406
422
|
|
407
|
-
def human_readable(self):
|
408
|
-
"""Print the question in a human readable format.
|
423
|
+
def human_readable(self) -> str:
|
424
|
+
"""Print the question in a human readable format.
|
425
|
+
|
426
|
+
>>> from edsl.questions import QuestionFreeText
|
427
|
+
>>> QuestionFreeText.example().human_readable()
|
428
|
+
'Question Type: free_text\\nQuestion: How are you?'
|
429
|
+
"""
|
409
430
|
lines = []
|
410
431
|
lines.append(f"Question Type: {self.question_type}")
|
411
432
|
lines.append(f"Question: {self.question_text}")
|
edsl/questions/settings.py
CHANGED
edsl/results/Dataset.py
CHANGED
@@ -77,6 +77,28 @@ class Dataset(UserList, ResultsExportMixin):
|
|
77
77
|
return list(d.values())[0]
|
78
78
|
|
79
79
|
return get_values(self.data[0])[0]
|
80
|
+
|
81
|
+
def select(self, *keys):
|
82
|
+
"""Return a new dataset with only the selected keys.
|
83
|
+
|
84
|
+
:param keys: The keys to select.
|
85
|
+
|
86
|
+
>>> d = Dataset([{'a.b':[1,2,3,4]}, {'c.d':[5,6,7,8]}])
|
87
|
+
>>> d.select('a.b')
|
88
|
+
Dataset([{'a.b': [1, 2, 3, 4]}])
|
89
|
+
|
90
|
+
>>> d.select('a.b', 'c.d')
|
91
|
+
Dataset([{'a.b': [1, 2, 3, 4]}, {'c.d': [5, 6, 7, 8]}])
|
92
|
+
"""
|
93
|
+
if isinstance(keys, str):
|
94
|
+
keys = [keys]
|
95
|
+
|
96
|
+
new_data = []
|
97
|
+
for observation in self.data:
|
98
|
+
observation_key = list(observation.keys())[0]
|
99
|
+
if observation_key in keys:
|
100
|
+
new_data.append(observation)
|
101
|
+
return Dataset(new_data)
|
80
102
|
|
81
103
|
def _repr_html_(self) -> str:
|
82
104
|
"""Return an HTML representation of the dataset."""
|
@@ -222,6 +244,15 @@ class Dataset(UserList, ResultsExportMixin):
|
|
222
244
|
new_data.append({key: new_values})
|
223
245
|
|
224
246
|
return Dataset(new_data)
|
247
|
+
|
248
|
+
@classmethod
|
249
|
+
def example(self):
|
250
|
+
"""Return an example dataset.
|
251
|
+
|
252
|
+
>>> Dataset.example()
|
253
|
+
Dataset([{'a': [1, 2, 3, 4]}, {'b': [4, 3, 2, 1]}])
|
254
|
+
"""
|
255
|
+
return Dataset([{'a':[1,2,3,4]}, {'b':[4,3,2,1]}])
|
225
256
|
|
226
257
|
|
227
258
|
if __name__ == "__main__":
|