edsl 0.1.28__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +18 -18
- edsl/__init__.py +24 -24
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +77 -41
- edsl/agents/AgentList.py +35 -6
- edsl/agents/Invigilator.py +19 -1
- edsl/agents/InvigilatorBase.py +15 -10
- edsl/agents/PromptConstructionMixin.py +342 -100
- edsl/agents/descriptors.py +2 -1
- edsl/base/Base.py +289 -0
- edsl/config.py +2 -1
- edsl/conjure/InputData.py +39 -8
- edsl/coop/coop.py +188 -151
- edsl/coop/utils.py +43 -75
- edsl/data/Cache.py +19 -5
- edsl/data/SQLiteDict.py +11 -3
- edsl/jobs/Answers.py +15 -1
- edsl/jobs/Jobs.py +92 -47
- edsl/jobs/buckets/ModelBuckets.py +4 -2
- edsl/jobs/buckets/TokenBucket.py +1 -2
- edsl/jobs/interviews/Interview.py +3 -9
- edsl/jobs/interviews/InterviewStatusMixin.py +3 -3
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +15 -10
- edsl/jobs/runners/JobsRunnerAsyncio.py +21 -25
- edsl/jobs/tasks/TaskHistory.py +4 -3
- edsl/language_models/LanguageModel.py +5 -11
- edsl/language_models/ModelList.py +3 -3
- edsl/language_models/repair.py +8 -7
- edsl/notebooks/Notebook.py +40 -3
- edsl/prompts/Prompt.py +31 -19
- edsl/questions/QuestionBase.py +38 -13
- edsl/questions/QuestionBudget.py +5 -6
- edsl/questions/QuestionCheckBox.py +7 -3
- edsl/questions/QuestionExtract.py +5 -3
- edsl/questions/QuestionFreeText.py +3 -3
- edsl/questions/QuestionFunctional.py +0 -3
- edsl/questions/QuestionList.py +3 -4
- edsl/questions/QuestionMultipleChoice.py +16 -8
- edsl/questions/QuestionNumerical.py +4 -3
- edsl/questions/QuestionRank.py +5 -3
- edsl/questions/__init__.py +4 -3
- edsl/questions/descriptors.py +4 -2
- edsl/questions/question_registry.py +20 -31
- edsl/questions/settings.py +1 -1
- edsl/results/Dataset.py +31 -0
- edsl/results/DatasetExportMixin.py +493 -0
- edsl/results/Result.py +22 -74
- edsl/results/Results.py +105 -67
- edsl/results/ResultsDBMixin.py +7 -3
- edsl/results/ResultsExportMixin.py +22 -537
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +5 -5
- edsl/scenarios/FileStore.py +140 -0
- edsl/scenarios/Scenario.py +5 -6
- edsl/scenarios/ScenarioList.py +44 -15
- edsl/scenarios/ScenarioListExportMixin.py +32 -0
- edsl/scenarios/ScenarioListPdfMixin.py +2 -1
- edsl/scenarios/__init__.py +1 -0
- edsl/study/ObjectEntry.py +89 -13
- edsl/study/ProofOfWork.py +5 -2
- edsl/study/SnapShot.py +4 -8
- edsl/study/Study.py +21 -14
- edsl/study/__init__.py +2 -0
- edsl/surveys/MemoryPlan.py +11 -4
- edsl/surveys/Survey.py +46 -7
- edsl/surveys/SurveyExportMixin.py +4 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
- edsl/tools/plotting.py +4 -2
- edsl/utilities/__init__.py +21 -21
- edsl/utilities/interface.py +66 -45
- edsl/utilities/utilities.py +11 -13
- {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/METADATA +11 -10
- {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/RECORD +75 -72
- edsl-0.1.28.dist-info/entry_points.txt +0 -3
- {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/LICENSE +0 -0
- {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/WHEEL +0 -0
edsl/jobs/buckets/TokenBucket.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1
1
|
from typing import Union, List, Any
|
2
2
|
import asyncio
|
3
3
|
import time
|
4
|
-
from collections import UserDict
|
5
|
-
from matplotlib import pyplot as plt
|
6
4
|
|
7
5
|
|
8
6
|
class TokenBucket:
|
@@ -114,6 +112,7 @@ class TokenBucket:
|
|
114
112
|
times, tokens = zip(*self.get_log())
|
115
113
|
start_time = times[0]
|
116
114
|
times = [t - start_time for t in times] # Normalize time to start from 0
|
115
|
+
from matplotlib import pyplot as plt
|
117
116
|
|
118
117
|
plt.figure(figsize=(10, 6))
|
119
118
|
plt.plot(times, tokens, label="Tokens Available")
|
@@ -6,15 +6,9 @@ import asyncio
|
|
6
6
|
import time
|
7
7
|
from typing import Any, Type, List, Generator, Optional
|
8
8
|
|
9
|
-
from edsl.agents import Agent
|
10
|
-
from edsl.language_models import LanguageModel
|
11
|
-
from edsl.scenarios import Scenario
|
12
|
-
from edsl.surveys import Survey
|
13
|
-
|
14
9
|
from edsl.jobs.Answers import Answers
|
15
10
|
from edsl.surveys.base import EndOfSurvey
|
16
11
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
17
|
-
|
18
12
|
from edsl.jobs.tasks.TaskCreators import TaskCreators
|
19
13
|
|
20
14
|
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
@@ -60,9 +54,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
60
54
|
self.debug = debug
|
61
55
|
self.iteration = iteration
|
62
56
|
self.cache = cache
|
63
|
-
self.answers: dict[
|
64
|
-
|
65
|
-
|
57
|
+
self.answers: dict[str, str] = (
|
58
|
+
Answers()
|
59
|
+
) # will get filled in as interview progresses
|
66
60
|
self.sidecar_model = sidecar_model
|
67
61
|
|
68
62
|
# Trackers
|
@@ -17,9 +17,9 @@ class InterviewStatusMixin:
|
|
17
17
|
The keys are the question names; the values are the lists of status log changes for each task.
|
18
18
|
"""
|
19
19
|
for task_creator in self.task_creators.values():
|
20
|
-
self._task_status_log_dict[
|
21
|
-
task_creator.
|
22
|
-
|
20
|
+
self._task_status_log_dict[task_creator.question.question_name] = (
|
21
|
+
task_creator.status_log
|
22
|
+
)
|
23
23
|
return self._task_status_log_dict
|
24
24
|
|
25
25
|
@property
|
@@ -5,17 +5,19 @@ import asyncio
|
|
5
5
|
import time
|
6
6
|
import traceback
|
7
7
|
from typing import Generator, Union
|
8
|
+
|
8
9
|
from edsl import CONFIG
|
9
10
|
from edsl.exceptions import InterviewTimeoutError
|
10
|
-
|
11
|
-
from edsl.questions.QuestionBase import QuestionBase
|
11
|
+
|
12
|
+
# from edsl.questions.QuestionBase import QuestionBase
|
12
13
|
from edsl.surveys.base import EndOfSurvey
|
13
14
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
14
15
|
from edsl.jobs.interviews.interview_exception_tracking import InterviewExceptionEntry
|
15
16
|
from edsl.jobs.interviews.retry_management import retry_strategy
|
16
17
|
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
17
18
|
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
18
|
-
|
19
|
+
|
20
|
+
# from edsl.agents.InvigilatorBase import InvigilatorBase
|
19
21
|
|
20
22
|
TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
|
21
23
|
|
@@ -44,6 +46,7 @@ class InterviewTaskBuildingMixin:
|
|
44
46
|
scenario=self.scenario,
|
45
47
|
model=self.model,
|
46
48
|
debug=debug,
|
49
|
+
survey=self.survey,
|
47
50
|
memory_plan=self.survey.memory_plan,
|
48
51
|
current_answers=self.answers,
|
49
52
|
iteration=self.iteration,
|
@@ -149,15 +152,17 @@ class InterviewTaskBuildingMixin:
|
|
149
152
|
async def _answer_question_and_record_task(
|
150
153
|
self,
|
151
154
|
*,
|
152
|
-
question: QuestionBase,
|
155
|
+
question: "QuestionBase",
|
153
156
|
debug: bool,
|
154
157
|
task=None,
|
155
|
-
) -> AgentResponseDict:
|
158
|
+
) -> "AgentResponseDict":
|
156
159
|
"""Answer a question and records the task.
|
157
160
|
|
158
161
|
This in turn calls the the passed-in agent's async_answer_question method, which returns a response dictionary.
|
159
162
|
Note that is updates answers dictionary with the response.
|
160
163
|
"""
|
164
|
+
from edsl.data_transfer_models import AgentResponseDict
|
165
|
+
|
161
166
|
try:
|
162
167
|
invigilator = self._get_invigilator(question, debug=debug)
|
163
168
|
|
@@ -253,11 +258,11 @@ class InterviewTaskBuildingMixin:
|
|
253
258
|
"""
|
254
259
|
current_question_index: int = self.to_index[current_question.question_name]
|
255
260
|
|
256
|
-
next_question: Union[
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
+
next_question: Union[int, EndOfSurvey] = (
|
262
|
+
self.survey.rule_collection.next_question(
|
263
|
+
q_now=current_question_index,
|
264
|
+
answers=self.answers | self.scenario | self.agent["traits"],
|
265
|
+
)
|
261
266
|
)
|
262
267
|
|
263
268
|
next_question_index = next_question.next_q
|
@@ -1,29 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import time
|
3
3
|
import asyncio
|
4
|
-
import
|
4
|
+
import time
|
5
5
|
from contextlib import contextmanager
|
6
6
|
|
7
7
|
from typing import Coroutine, List, AsyncGenerator, Optional, Union
|
8
8
|
|
9
|
-
from rich.live import Live
|
10
|
-
from rich.console import Console
|
11
|
-
|
12
9
|
from edsl import shared_globals
|
13
|
-
from edsl.results import Results, Result
|
14
|
-
|
15
10
|
from edsl.jobs.interviews.Interview import Interview
|
16
|
-
from edsl.utilities.decorators import jupyter_nb_handler
|
17
|
-
|
18
|
-
# from edsl.jobs.Jobs import Jobs
|
19
11
|
from edsl.jobs.runners.JobsRunnerStatusMixin import JobsRunnerStatusMixin
|
20
|
-
from edsl.language_models import LanguageModel
|
21
|
-
from edsl.data.Cache import Cache
|
22
|
-
|
23
12
|
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
24
13
|
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
25
|
-
|
26
|
-
import time
|
14
|
+
from edsl.utilities.decorators import jupyter_nb_handler
|
27
15
|
|
28
16
|
|
29
17
|
class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
@@ -42,13 +30,13 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
42
30
|
|
43
31
|
async def run_async_generator(
|
44
32
|
self,
|
45
|
-
cache: Cache,
|
33
|
+
cache: "Cache",
|
46
34
|
n: int = 1,
|
47
35
|
debug: bool = False,
|
48
36
|
stop_on_exception: bool = False,
|
49
37
|
sidecar_model: "LanguageModel" = None,
|
50
38
|
total_interviews: Optional[List["Interview"]] = None,
|
51
|
-
) -> AsyncGenerator[Result, None]:
|
39
|
+
) -> AsyncGenerator["Result", None]:
|
52
40
|
"""Creates the tasks, runs them asynchronously, and returns the results as a Results object.
|
53
41
|
|
54
42
|
Completed tasks are yielded as they are completed.
|
@@ -155,19 +143,21 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
155
143
|
|
156
144
|
prompt_dictionary = {}
|
157
145
|
for answer_key_name in answer_key_names:
|
158
|
-
prompt_dictionary[
|
159
|
-
answer_key_name
|
160
|
-
|
161
|
-
prompt_dictionary[
|
162
|
-
answer_key_name
|
163
|
-
|
146
|
+
prompt_dictionary[answer_key_name + "_user_prompt"] = (
|
147
|
+
question_name_to_prompts[answer_key_name]["user_prompt"]
|
148
|
+
)
|
149
|
+
prompt_dictionary[answer_key_name + "_system_prompt"] = (
|
150
|
+
question_name_to_prompts[answer_key_name]["system_prompt"]
|
151
|
+
)
|
164
152
|
|
165
153
|
raw_model_results_dictionary = {}
|
166
154
|
for result in valid_results:
|
167
155
|
question_name = result["question_name"]
|
168
|
-
raw_model_results_dictionary[
|
169
|
-
|
170
|
-
|
156
|
+
raw_model_results_dictionary[question_name + "_raw_model_response"] = (
|
157
|
+
result["raw_model_response"]
|
158
|
+
)
|
159
|
+
|
160
|
+
from edsl.results.Result import Result
|
171
161
|
|
172
162
|
result = Result(
|
173
163
|
agent=interview.agent,
|
@@ -197,6 +187,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
197
187
|
print_exceptions: bool = True,
|
198
188
|
) -> "Coroutine":
|
199
189
|
"""Runs a collection of interviews, handling both async and sync contexts."""
|
190
|
+
from rich.console import Console
|
191
|
+
|
200
192
|
console = Console()
|
201
193
|
self.results = []
|
202
194
|
self.start_time = time.monotonic()
|
@@ -204,6 +196,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
204
196
|
self.cache = cache
|
205
197
|
self.sidecar_model = sidecar_model
|
206
198
|
|
199
|
+
from edsl.results.Results import Results
|
200
|
+
|
207
201
|
if not progress_bar:
|
208
202
|
# print("Running without progress bar")
|
209
203
|
with cache as c:
|
@@ -225,6 +219,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
225
219
|
results = Results(survey=self.jobs.survey, data=self.results)
|
226
220
|
else:
|
227
221
|
# print("Running with progress bar")
|
222
|
+
from rich.live import Live
|
223
|
+
from rich.console import Console
|
228
224
|
|
229
225
|
def generate_table():
|
230
226
|
return self.status_table(self.results, self.elapsed_time)
|
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -1,8 +1,5 @@
|
|
1
1
|
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
2
|
-
from matplotlib import pyplot as plt
|
3
2
|
from typing import List, Optional
|
4
|
-
|
5
|
-
import matplotlib.pyplot as plt
|
6
3
|
from io import BytesIO
|
7
4
|
import base64
|
8
5
|
|
@@ -75,6 +72,8 @@ class TaskHistory:
|
|
75
72
|
|
76
73
|
def plot_completion_times(self):
|
77
74
|
"""Plot the completion times for each task."""
|
75
|
+
import matplotlib.pyplot as plt
|
76
|
+
|
78
77
|
updates = self.get_updates()
|
79
78
|
|
80
79
|
elapsed = [update.max_time - update.min_time for update in updates]
|
@@ -126,6 +125,8 @@ class TaskHistory:
|
|
126
125
|
rows = int(len(TaskStatus) ** 0.5) + 1
|
127
126
|
cols = (len(TaskStatus) + rows - 1) // rows # Ensure all plots fit
|
128
127
|
|
128
|
+
import matplotlib.pyplot as plt
|
129
|
+
|
129
130
|
fig, axes = plt.subplots(rows, cols, figsize=(15, 10))
|
130
131
|
axes = axes.flatten() # Flatten in case of a single row/column
|
131
132
|
|
@@ -7,26 +7,18 @@ import asyncio
|
|
7
7
|
import json
|
8
8
|
import time
|
9
9
|
import os
|
10
|
-
|
11
10
|
from typing import Coroutine, Any, Callable, Type, List, get_type_hints
|
12
|
-
|
13
|
-
from abc import ABC, abstractmethod, ABCMeta
|
14
|
-
|
15
|
-
from rich.table import Table
|
11
|
+
from abc import ABC, abstractmethod
|
16
12
|
|
17
13
|
from edsl.config import CONFIG
|
18
14
|
|
19
|
-
from edsl.utilities.utilities import clean_json
|
20
15
|
from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
|
21
16
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
17
|
+
|
22
18
|
from edsl.language_models.repair import repair
|
23
|
-
from edsl.exceptions.language_models import LanguageModelAttributeTypeError
|
24
19
|
from edsl.enums import InferenceServiceType
|
25
20
|
from edsl.Base import RichPrintingMixin, PersistenceMixin
|
26
|
-
from edsl.data.Cache import Cache
|
27
21
|
from edsl.enums import service_to_api_keyname
|
28
|
-
|
29
|
-
|
30
22
|
from edsl.exceptions import MissingAPIKeyError
|
31
23
|
from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
|
32
24
|
|
@@ -291,7 +283,7 @@ class LanguageModel(
|
|
291
283
|
self,
|
292
284
|
user_prompt: str,
|
293
285
|
system_prompt: str,
|
294
|
-
cache,
|
286
|
+
cache: "Cache",
|
295
287
|
iteration: int = 0,
|
296
288
|
encoded_image=None,
|
297
289
|
) -> tuple[dict, bool, str]:
|
@@ -490,6 +482,8 @@ class LanguageModel(
|
|
490
482
|
|
491
483
|
def rich_print(self):
|
492
484
|
"""Display an object as a table."""
|
485
|
+
from rich.table import Table
|
486
|
+
|
493
487
|
table = Table(title="Language Model")
|
494
488
|
table.add_column("Attribute", style="bold")
|
495
489
|
table.add_column("Value")
|
@@ -5,7 +5,7 @@ from edsl import Model
|
|
5
5
|
from edsl.language_models import LanguageModel
|
6
6
|
from edsl.Base import Base
|
7
7
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
8
|
-
from edsl.utilities import is_valid_variable_name
|
8
|
+
from edsl.utilities.utilities import is_valid_variable_name
|
9
9
|
from edsl.utilities.utilities import dict_hash
|
10
10
|
|
11
11
|
|
@@ -56,11 +56,11 @@ class ModelList(Base, UserList):
|
|
56
56
|
return {"models": [model._to_dict() for model in self]}
|
57
57
|
|
58
58
|
@classmethod
|
59
|
-
def from_names(self, *args):
|
59
|
+
def from_names(self, *args, **kwargs):
|
60
60
|
"""A a model list from a list of names"""
|
61
61
|
if len(args) == 1 and isinstance(args[0], list):
|
62
62
|
args = args[0]
|
63
|
-
return ModelList([Model(model_name) for model_name in args])
|
63
|
+
return ModelList([Model(model_name, **kwargs) for model_name in args])
|
64
64
|
|
65
65
|
@add_edsl_version
|
66
66
|
def to_dict(self):
|
edsl/language_models/repair.py
CHANGED
@@ -1,18 +1,13 @@
|
|
1
1
|
import json
|
2
2
|
import asyncio
|
3
3
|
import warnings
|
4
|
-
from rich import print
|
5
|
-
from rich.console import Console
|
6
|
-
from rich.syntax import Syntax
|
7
|
-
|
8
|
-
from edsl.utilities.utilities import clean_json
|
9
|
-
|
10
|
-
from edsl.utilities.repair_functions import extract_json_from_string
|
11
4
|
|
12
5
|
|
13
6
|
async def async_repair(
|
14
7
|
bad_json, error_message="", user_prompt=None, system_prompt=None, cache=None
|
15
8
|
):
|
9
|
+
from edsl.utilities.utilities import clean_json
|
10
|
+
|
16
11
|
s = clean_json(bad_json)
|
17
12
|
|
18
13
|
try:
|
@@ -27,6 +22,8 @@ async def async_repair(
|
|
27
22
|
return valid_dict, success
|
28
23
|
|
29
24
|
try:
|
25
|
+
from edsl.utilities.repair_functions import extract_json_from_string
|
26
|
+
|
30
27
|
valid_dict = extract_json_from_string(s)
|
31
28
|
success = True
|
32
29
|
except ValueError:
|
@@ -98,6 +95,10 @@ async def async_repair(
|
|
98
95
|
except json.JSONDecodeError:
|
99
96
|
valid_dict = {}
|
100
97
|
success = False
|
98
|
+
from rich import print
|
99
|
+
from rich.console import Console
|
100
|
+
from rich.syntax import Syntax
|
101
|
+
|
101
102
|
console = Console()
|
102
103
|
error_message = (
|
103
104
|
f"All repairs. failed. LLM Model given [red]{str(bad_json)}[/red]"
|
edsl/notebooks/Notebook.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1
1
|
"""A Notebook is a utility class that allows you to easily share/pull ipynbs from Coop."""
|
2
2
|
|
3
3
|
import json
|
4
|
-
import nbformat
|
5
|
-
from nbconvert import HTMLExporter
|
6
4
|
from typing import Dict, List, Optional
|
7
|
-
|
5
|
+
|
6
|
+
|
8
7
|
from edsl.Base import Base
|
9
8
|
from edsl.utilities.decorators import (
|
10
9
|
add_edsl_version,
|
@@ -34,6 +33,8 @@ class Notebook(Base):
|
|
34
33
|
If no path is provided, assume this code is run in a notebook and try to load the current notebook from file.
|
35
34
|
:param name: A name for the Notebook.
|
36
35
|
"""
|
36
|
+
import nbformat
|
37
|
+
|
37
38
|
# Load current notebook path as fallback (VS Code only)
|
38
39
|
path = path or globals().get("__vsc_ipynb_file__")
|
39
40
|
if data is not None:
|
@@ -56,6 +57,37 @@ class Notebook(Base):
|
|
56
57
|
|
57
58
|
self.name = name or self.default_name
|
58
59
|
|
60
|
+
@classmethod
|
61
|
+
def from_script(cls, path: str, name: Optional[str] = None) -> "Notebook":
|
62
|
+
# Read the script file
|
63
|
+
with open(path, "r") as script_file:
|
64
|
+
script_content = script_file.read()
|
65
|
+
|
66
|
+
# Create a new Jupyter notebook
|
67
|
+
nb = nbformat.v4.new_notebook()
|
68
|
+
|
69
|
+
# Add the script content to the first cell
|
70
|
+
first_cell = nbformat.v4.new_code_cell(script_content)
|
71
|
+
nb.cells.append(first_cell)
|
72
|
+
|
73
|
+
# Create a Notebook instance with the notebook data
|
74
|
+
notebook_instance = cls(nb)
|
75
|
+
|
76
|
+
return notebook_instance
|
77
|
+
|
78
|
+
@classmethod
|
79
|
+
def from_current_script(cls) -> "Notebook":
|
80
|
+
import inspect
|
81
|
+
import os
|
82
|
+
|
83
|
+
# Get the path to the current file
|
84
|
+
current_frame = inspect.currentframe()
|
85
|
+
caller_frame = inspect.getouterframes(current_frame, 2)
|
86
|
+
current_file_path = os.path.abspath(caller_frame[1].filename)
|
87
|
+
|
88
|
+
# Use from_script to create the notebook
|
89
|
+
return cls.from_script(current_file_path)
|
90
|
+
|
59
91
|
def __eq__(self, other):
|
60
92
|
"""
|
61
93
|
Check if two Notebooks are equal.
|
@@ -103,6 +135,9 @@ class Notebook(Base):
|
|
103
135
|
"""
|
104
136
|
Return HTML representation of Notebook.
|
105
137
|
"""
|
138
|
+
from nbconvert import HTMLExporter
|
139
|
+
import nbformat
|
140
|
+
|
106
141
|
notebook = nbformat.from_dict(self.data)
|
107
142
|
html_exporter = HTMLExporter(template_name="basic")
|
108
143
|
(body, _) = html_exporter.from_notebook_node(notebook)
|
@@ -143,6 +178,8 @@ class Notebook(Base):
|
|
143
178
|
"""
|
144
179
|
Display a Notebook as a rich table.
|
145
180
|
"""
|
181
|
+
from rich.table import Table
|
182
|
+
|
146
183
|
table_data, column_names = self._table()
|
147
184
|
table = Table(title=f"{self.__class__.__name__} Attributes")
|
148
185
|
for column in column_names:
|
edsl/prompts/Prompt.py
CHANGED
@@ -1,12 +1,16 @@
|
|
1
|
-
"""Class for creating prompts to be used in a survey."""
|
2
|
-
|
3
1
|
from __future__ import annotations
|
4
2
|
from typing import Optional
|
5
3
|
from abc import ABC
|
6
4
|
from typing import Any, List
|
7
5
|
|
8
6
|
from rich.table import Table
|
9
|
-
from jinja2 import Template, Environment, meta, TemplateSyntaxError
|
7
|
+
from jinja2 import Template, Environment, meta, TemplateSyntaxError, Undefined
|
8
|
+
|
9
|
+
|
10
|
+
class PreserveUndefined(Undefined):
|
11
|
+
def __str__(self):
|
12
|
+
return "{{ " + self._undefined_name + " }}"
|
13
|
+
|
10
14
|
|
11
15
|
from edsl.exceptions.prompts import TemplateRenderError
|
12
16
|
from edsl.prompts.prompt_config import (
|
@@ -35,6 +39,10 @@ class PromptBase(
|
|
35
39
|
|
36
40
|
return data_to_html(self.to_dict())
|
37
41
|
|
42
|
+
def __len__(self):
|
43
|
+
"""Return the length of the prompt text."""
|
44
|
+
return len(self.text)
|
45
|
+
|
38
46
|
@classmethod
|
39
47
|
def prompt_attributes(cls) -> List[str]:
|
40
48
|
"""Return the prompt class attributes."""
|
@@ -75,10 +83,10 @@ class PromptBase(
|
|
75
83
|
>>> p = Prompt("Hello, {{person}}")
|
76
84
|
>>> p2 = Prompt("How are you?")
|
77
85
|
>>> p + p2
|
78
|
-
Prompt(text
|
86
|
+
Prompt(text=\"""Hello, {{person}}How are you?\""")
|
79
87
|
|
80
88
|
>>> p + "How are you?"
|
81
|
-
Prompt(text
|
89
|
+
Prompt(text=\"""Hello, {{person}}How are you?\""")
|
82
90
|
"""
|
83
91
|
if isinstance(other_prompt, str):
|
84
92
|
return self.__class__(self.text + other_prompt)
|
@@ -114,7 +122,7 @@ class PromptBase(
|
|
114
122
|
Example:
|
115
123
|
>>> p = Prompt("Hello, {{person}}")
|
116
124
|
>>> p
|
117
|
-
Prompt(text
|
125
|
+
Prompt(text=\"""Hello, {{person}}\""")
|
118
126
|
"""
|
119
127
|
return f'Prompt(text="""{self.text}""")'
|
120
128
|
|
@@ -137,7 +145,7 @@ class PromptBase(
|
|
137
145
|
:param template: The template to find the variables in.
|
138
146
|
|
139
147
|
"""
|
140
|
-
env = Environment()
|
148
|
+
env = Environment(undefined=PreserveUndefined)
|
141
149
|
ast = env.parse(template)
|
142
150
|
return list(meta.find_undeclared_variables(ast))
|
143
151
|
|
@@ -186,13 +194,16 @@ class PromptBase(
|
|
186
194
|
|
187
195
|
>>> p = Prompt("Hello, {{person}}")
|
188
196
|
>>> p.render({"person": "John"})
|
189
|
-
|
197
|
+
Prompt(text=\"""Hello, John\""")
|
190
198
|
|
191
199
|
>>> p.render({"person": "Mr. {{last_name}}", "last_name": "Horton"})
|
192
|
-
|
200
|
+
Prompt(text=\"""Hello, Mr. Horton\""")
|
193
201
|
|
194
202
|
>>> p.render({"person": "Mr. {{last_name}}", "last_name": "Ho{{letter}}ton"}, max_nesting = 1)
|
195
|
-
|
203
|
+
Prompt(text=\"""Hello, Mr. Ho{{ letter }}ton\""")
|
204
|
+
|
205
|
+
>>> p.render({"person": "Mr. {{last_name}}"})
|
206
|
+
Prompt(text=\"""Hello, Mr. {{ last_name }}\""")
|
196
207
|
"""
|
197
208
|
new_text = self._render(
|
198
209
|
self.text, primary_replacement, **additional_replacements
|
@@ -216,12 +227,13 @@ class PromptBase(
|
|
216
227
|
>>> codebook = {"age": "Age"}
|
217
228
|
>>> p = Prompt("You are an agent named {{ name }}. {{ codebook['age']}}: {{ age }}")
|
218
229
|
>>> p.render({"name": "John", "age": 44}, codebook=codebook)
|
219
|
-
|
230
|
+
Prompt(text=\"""You are an agent named John. Age: 44\""")
|
220
231
|
"""
|
232
|
+
env = Environment(undefined=PreserveUndefined)
|
221
233
|
try:
|
222
234
|
previous_text = None
|
223
235
|
for _ in range(MAX_NESTING):
|
224
|
-
rendered_text =
|
236
|
+
rendered_text = env.from_string(text).render(
|
225
237
|
primary_replacement, **additional_replacements
|
226
238
|
)
|
227
239
|
if rendered_text == previous_text:
|
@@ -258,7 +270,7 @@ class PromptBase(
|
|
258
270
|
>>> p = Prompt("Hello, {{person}}")
|
259
271
|
>>> p2 = Prompt.from_dict(p.to_dict())
|
260
272
|
>>> p2
|
261
|
-
Prompt(text
|
273
|
+
Prompt(text=\"""Hello, {{person}}\""")
|
262
274
|
|
263
275
|
"""
|
264
276
|
class_name = data["class_name"]
|
@@ -290,6 +302,12 @@ class Prompt(PromptBase):
|
|
290
302
|
component_type = ComponentTypes.GENERIC
|
291
303
|
|
292
304
|
|
305
|
+
if __name__ == "__main__":
|
306
|
+
print("Running doctests...")
|
307
|
+
import doctest
|
308
|
+
|
309
|
+
doctest.testmod()
|
310
|
+
|
293
311
|
from edsl.prompts.library.question_multiple_choice import *
|
294
312
|
from edsl.prompts.library.agent_instructions import *
|
295
313
|
from edsl.prompts.library.agent_persona import *
|
@@ -302,9 +320,3 @@ from edsl.prompts.library.question_numerical import *
|
|
302
320
|
from edsl.prompts.library.question_rank import *
|
303
321
|
from edsl.prompts.library.question_extract import *
|
304
322
|
from edsl.prompts.library.question_list import *
|
305
|
-
|
306
|
-
|
307
|
-
if __name__ == "__main__":
|
308
|
-
import doctest
|
309
|
-
|
310
|
-
doctest.testmod()
|