edsl 0.1.35__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +5 -0
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +37 -9
- edsl/agents/Invigilator.py +2 -1
- edsl/agents/InvigilatorBase.py +5 -1
- edsl/agents/PromptConstructor.py +31 -67
- edsl/conversation/Conversation.py +1 -1
- edsl/coop/PriceFetcher.py +14 -18
- edsl/coop/coop.py +42 -8
- edsl/data/RemoteCacheSync.py +97 -0
- edsl/exceptions/coop.py +8 -0
- edsl/inference_services/InferenceServiceABC.py +28 -0
- edsl/inference_services/InferenceServicesCollection.py +10 -4
- edsl/inference_services/models_available_cache.py +25 -1
- edsl/inference_services/registry.py +24 -16
- edsl/jobs/Jobs.py +327 -206
- edsl/jobs/interviews/Interview.py +65 -10
- edsl/jobs/interviews/InterviewExceptionCollection.py +9 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +31 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +8 -13
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -5
- edsl/jobs/tasks/TaskHistory.py +23 -7
- edsl/language_models/LanguageModel.py +3 -0
- edsl/prompts/Prompt.py +24 -38
- edsl/prompts/__init__.py +1 -1
- edsl/questions/QuestionBasePromptsMixin.py +18 -18
- edsl/questions/QuestionFunctional.py +7 -3
- edsl/questions/descriptors.py +24 -24
- edsl/results/Dataset.py +12 -0
- edsl/results/Result.py +2 -0
- edsl/results/Results.py +13 -1
- edsl/scenarios/FileStore.py +20 -5
- edsl/scenarios/Scenario.py +15 -1
- edsl/scenarios/__init__.py +2 -0
- edsl/surveys/Survey.py +3 -0
- edsl/surveys/instructions/Instruction.py +20 -3
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/METADATA +1 -1
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/RECORD +41 -57
- edsl/jobs/FailedQuestion.py +0 -78
- edsl/jobs/interviews/InterviewStatusMixin.py +0 -33
- edsl/jobs/tasks/task_management.py +0 -13
- edsl/prompts/QuestionInstructionsBase.py +0 -10
- edsl/prompts/library/agent_instructions.py +0 -38
- edsl/prompts/library/agent_persona.py +0 -21
- edsl/prompts/library/question_budget.py +0 -30
- edsl/prompts/library/question_checkbox.py +0 -38
- edsl/prompts/library/question_extract.py +0 -23
- edsl/prompts/library/question_freetext.py +0 -18
- edsl/prompts/library/question_linear_scale.py +0 -24
- edsl/prompts/library/question_list.py +0 -26
- edsl/prompts/library/question_multiple_choice.py +0 -54
- edsl/prompts/library/question_numerical.py +0 -35
- edsl/prompts/library/question_rank.py +0 -25
- edsl/prompts/prompt_config.py +0 -37
- edsl/prompts/registry.py +0 -202
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/LICENSE +0 -0
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/WHEEL +0 -0
@@ -28,7 +28,7 @@ from edsl.jobs.interviews.InterviewExceptionCollection import (
|
|
28
28
|
InterviewExceptionCollection,
|
29
29
|
)
|
30
30
|
|
31
|
-
from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
|
31
|
+
# from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
|
32
32
|
|
33
33
|
from edsl.surveys.base import EndOfSurvey
|
34
34
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
@@ -44,6 +44,10 @@ from edsl.agents.InvigilatorBase import InvigilatorBase
|
|
44
44
|
|
45
45
|
from edsl.exceptions.language_models import LanguageModelNoResponseError
|
46
46
|
|
47
|
+
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
48
|
+
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
49
|
+
from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
|
50
|
+
|
47
51
|
|
48
52
|
from edsl import CONFIG
|
49
53
|
|
@@ -52,7 +56,7 @@ EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
|
|
52
56
|
EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
|
53
57
|
|
54
58
|
|
55
|
-
class Interview(InterviewStatusMixin):
|
59
|
+
class Interview:
|
56
60
|
"""
|
57
61
|
An 'interview' is one agent answering one survey, with one language model, for a given scenario.
|
58
62
|
|
@@ -100,9 +104,7 @@ class Interview(InterviewStatusMixin):
|
|
100
104
|
|
101
105
|
"""
|
102
106
|
self.agent = agent
|
103
|
-
|
104
|
-
self.survey = copy.deepcopy(survey) # survey copy.deepcopy(survey)
|
105
|
-
# self.survey = survey
|
107
|
+
self.survey = copy.deepcopy(survey)
|
106
108
|
self.scenario = scenario
|
107
109
|
self.model = model
|
108
110
|
self.debug = debug
|
@@ -113,8 +115,6 @@ class Interview(InterviewStatusMixin):
|
|
113
115
|
] = Answers() # will get filled in as interview progresses
|
114
116
|
self.sidecar_model = sidecar_model
|
115
117
|
|
116
|
-
# self.stop_on_exception = False
|
117
|
-
|
118
118
|
# Trackers
|
119
119
|
self.task_creators = TaskCreators() # tracks the task creators
|
120
120
|
self.exceptions = InterviewExceptionCollection()
|
@@ -131,14 +131,41 @@ class Interview(InterviewStatusMixin):
|
|
131
131
|
|
132
132
|
self.failed_questions = []
|
133
133
|
|
134
|
+
@property
|
135
|
+
def has_exceptions(self) -> bool:
|
136
|
+
"""Return True if there are exceptions."""
|
137
|
+
return len(self.exceptions) > 0
|
138
|
+
|
139
|
+
@property
|
140
|
+
def task_status_logs(self) -> InterviewStatusLog:
|
141
|
+
"""Return the task status logs for the interview.
|
142
|
+
|
143
|
+
The keys are the question names; the values are the lists of status log changes for each task.
|
144
|
+
"""
|
145
|
+
for task_creator in self.task_creators.values():
|
146
|
+
self._task_status_log_dict[
|
147
|
+
task_creator.question.question_name
|
148
|
+
] = task_creator.status_log
|
149
|
+
return self._task_status_log_dict
|
150
|
+
|
151
|
+
@property
|
152
|
+
def token_usage(self) -> InterviewTokenUsage:
|
153
|
+
"""Determine how many tokens were used for the interview."""
|
154
|
+
return self.task_creators.token_usage
|
155
|
+
|
156
|
+
@property
|
157
|
+
def interview_status(self) -> InterviewStatusDictionary:
|
158
|
+
"""Return a dictionary mapping task status codes to counts."""
|
159
|
+
return self.task_creators.interview_status
|
160
|
+
|
134
161
|
# region: Serialization
|
135
|
-
def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
|
162
|
+
def _to_dict(self, include_exceptions=True) -> dict[str, Any]:
|
136
163
|
"""Return a dictionary representation of the Interview instance.
|
137
164
|
This is just for hashing purposes.
|
138
165
|
|
139
166
|
>>> i = Interview.example()
|
140
167
|
>>> hash(i)
|
141
|
-
|
168
|
+
1217840301076717434
|
142
169
|
"""
|
143
170
|
d = {
|
144
171
|
"agent": self.agent._to_dict(),
|
@@ -150,11 +177,39 @@ class Interview(InterviewStatusMixin):
|
|
150
177
|
}
|
151
178
|
if include_exceptions:
|
152
179
|
d["exceptions"] = self.exceptions.to_dict()
|
180
|
+
return d
|
181
|
+
|
182
|
+
@classmethod
|
183
|
+
def from_dict(cls, d: dict[str, Any]) -> "Interview":
|
184
|
+
"""Return an Interview instance from a dictionary."""
|
185
|
+
agent = Agent.from_dict(d["agent"])
|
186
|
+
survey = Survey.from_dict(d["survey"])
|
187
|
+
scenario = Scenario.from_dict(d["scenario"])
|
188
|
+
model = LanguageModel.from_dict(d["model"])
|
189
|
+
iteration = d["iteration"]
|
190
|
+
interview = cls(
|
191
|
+
agent=agent,
|
192
|
+
survey=survey,
|
193
|
+
scenario=scenario,
|
194
|
+
model=model,
|
195
|
+
iteration=iteration,
|
196
|
+
)
|
197
|
+
if "exceptions" in d:
|
198
|
+
exceptions = InterviewExceptionCollection.from_dict(d["exceptions"])
|
199
|
+
interview.exceptions = exceptions
|
200
|
+
return interview
|
153
201
|
|
154
202
|
def __hash__(self) -> int:
|
155
203
|
from edsl.utilities.utilities import dict_hash
|
156
204
|
|
157
|
-
return dict_hash(self._to_dict())
|
205
|
+
return dict_hash(self._to_dict(include_exceptions=False))
|
206
|
+
|
207
|
+
def __eq__(self, other: "Interview") -> bool:
|
208
|
+
"""
|
209
|
+
>>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i._to_dict(); i2 = Interview.from_dict(d); i == i2
|
210
|
+
True
|
211
|
+
"""
|
212
|
+
return hash(self) == hash(other)
|
158
213
|
|
159
214
|
# endregion
|
160
215
|
|
@@ -34,6 +34,15 @@ class InterviewExceptionCollection(UserDict):
|
|
34
34
|
newdata = {k: [e.to_dict() for e in v] for k, v in self.data.items()}
|
35
35
|
return newdata
|
36
36
|
|
37
|
+
@classmethod
|
38
|
+
def from_dict(cls, data: dict) -> "InterviewExceptionCollection":
|
39
|
+
"""Create an InterviewExceptionCollection from a dictionary."""
|
40
|
+
collection = cls()
|
41
|
+
for question_name, entries in data.items():
|
42
|
+
for entry in entries:
|
43
|
+
collection.add(question_name, InterviewExceptionEntry.from_dict(entry))
|
44
|
+
return collection
|
45
|
+
|
37
46
|
def _repr_html_(self) -> str:
|
38
47
|
from edsl.utilities.utilities import data_to_html
|
39
48
|
|
@@ -1,8 +1,5 @@
|
|
1
1
|
import traceback
|
2
2
|
import datetime
|
3
|
-
import time
|
4
|
-
from collections import UserDict
|
5
|
-
from edsl.jobs.FailedQuestion import FailedQuestion
|
6
3
|
|
7
4
|
|
8
5
|
class InterviewExceptionEntry:
|
@@ -12,7 +9,6 @@ class InterviewExceptionEntry:
|
|
12
9
|
self,
|
13
10
|
*,
|
14
11
|
exception: Exception,
|
15
|
-
# failed_question: FailedQuestion,
|
16
12
|
invigilator: "Invigilator",
|
17
13
|
traceback_format="text",
|
18
14
|
answers=None,
|
@@ -137,22 +133,48 @@ class InterviewExceptionEntry:
|
|
137
133
|
console.print(tb)
|
138
134
|
return html_output.getvalue()
|
139
135
|
|
136
|
+
@staticmethod
|
137
|
+
def serialize_exception(exception: Exception) -> dict:
|
138
|
+
return {
|
139
|
+
"type": type(exception).__name__,
|
140
|
+
"message": str(exception),
|
141
|
+
"traceback": "".join(
|
142
|
+
traceback.format_exception(
|
143
|
+
type(exception), exception, exception.__traceback__
|
144
|
+
)
|
145
|
+
),
|
146
|
+
}
|
147
|
+
|
148
|
+
@staticmethod
|
149
|
+
def deserialize_exception(data: dict) -> Exception:
|
150
|
+
try:
|
151
|
+
exception_class = globals()[data["type"]]
|
152
|
+
except KeyError:
|
153
|
+
exception_class = Exception
|
154
|
+
return exception_class(data["message"])
|
155
|
+
|
140
156
|
def to_dict(self) -> dict:
|
141
157
|
"""Return the exception as a dictionary.
|
142
158
|
|
143
159
|
>>> entry = InterviewExceptionEntry.example()
|
144
|
-
>>> entry.to_dict()
|
145
|
-
ValueError()
|
146
|
-
|
160
|
+
>>> _ = entry.to_dict()
|
147
161
|
"""
|
148
162
|
return {
|
149
|
-
"exception": self.exception,
|
163
|
+
"exception": self.serialize_exception(self.exception),
|
150
164
|
"time": self.time,
|
151
165
|
"traceback": self.traceback,
|
152
|
-
# "failed_question": self.failed_question.to_dict(),
|
153
166
|
"invigilator": self.invigilator.to_dict(),
|
154
167
|
}
|
155
168
|
|
169
|
+
@classmethod
|
170
|
+
def from_dict(cls, data: dict) -> "InterviewExceptionEntry":
|
171
|
+
"""Create an InterviewExceptionEntry from a dictionary."""
|
172
|
+
from edsl.agents.Invigilator import InvigilatorAI
|
173
|
+
|
174
|
+
exception = cls.deserialize_exception(data["exception"])
|
175
|
+
invigilator = InvigilatorAI.from_dict(data["invigilator"])
|
176
|
+
return cls(exception=exception, invigilator=invigilator)
|
177
|
+
|
156
178
|
def push(self):
|
157
179
|
from edsl import Coop
|
158
180
|
|
@@ -1,18 +1,12 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import time
|
3
|
-
import math
|
4
3
|
import asyncio
|
5
|
-
import functools
|
6
4
|
import threading
|
7
5
|
from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
|
8
6
|
from contextlib import contextmanager
|
9
7
|
from collections import UserList
|
10
8
|
|
11
|
-
from rich.live import Live
|
12
|
-
from rich.console import Console
|
13
|
-
|
14
9
|
from edsl.results.Results import Results
|
15
|
-
from edsl import shared_globals
|
16
10
|
from edsl.jobs.interviews.Interview import Interview
|
17
11
|
from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
|
18
12
|
|
@@ -48,8 +42,6 @@ class JobsRunnerAsyncio:
|
|
48
42
|
self.bucket_collection: "BucketCollection" = jobs.bucket_collection
|
49
43
|
self.total_interviews: List["Interview"] = []
|
50
44
|
|
51
|
-
# self.jobs_runner_status = JobsRunnerStatus(self, n=1)
|
52
|
-
|
53
45
|
async def run_async_generator(
|
54
46
|
self,
|
55
47
|
cache: Cache,
|
@@ -181,6 +173,7 @@ class JobsRunnerAsyncio:
|
|
181
173
|
] = question_name_to_prompts[answer_key_name]["system_prompt"]
|
182
174
|
|
183
175
|
raw_model_results_dictionary = {}
|
176
|
+
cache_used_dictionary = {}
|
184
177
|
for result in valid_results:
|
185
178
|
question_name = result.question_name
|
186
179
|
raw_model_results_dictionary[
|
@@ -195,6 +188,7 @@ class JobsRunnerAsyncio:
|
|
195
188
|
else 1.0 / result.cost
|
196
189
|
)
|
197
190
|
raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
|
191
|
+
cache_used_dictionary[question_name] = result.cache_used
|
198
192
|
|
199
193
|
result = Result(
|
200
194
|
agent=interview.agent,
|
@@ -207,6 +201,7 @@ class JobsRunnerAsyncio:
|
|
207
201
|
survey=interview.survey,
|
208
202
|
generated_tokens=generated_tokens_dict,
|
209
203
|
comments_dict=comments_dict,
|
204
|
+
cache_used_dict=cache_used_dictionary,
|
210
205
|
)
|
211
206
|
result.interview_hash = hash(interview)
|
212
207
|
|
@@ -225,17 +220,16 @@ class JobsRunnerAsyncio:
|
|
225
220
|
}
|
226
221
|
interview_hashes = list(interview_lookup.keys())
|
227
222
|
|
223
|
+
task_history = TaskHistory(self.total_interviews, include_traceback=False)
|
224
|
+
|
228
225
|
results = Results(
|
229
226
|
survey=self.jobs.survey,
|
230
227
|
data=sorted(
|
231
228
|
raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
|
232
229
|
),
|
230
|
+
task_history=task_history,
|
231
|
+
cache=cache,
|
233
232
|
)
|
234
|
-
results.cache = cache
|
235
|
-
results.task_history = TaskHistory(
|
236
|
-
self.total_interviews, include_traceback=False
|
237
|
-
)
|
238
|
-
results.has_unfixed_exceptions = results.task_history.has_unfixed_exceptions
|
239
233
|
results.bucket_collection = self.bucket_collection
|
240
234
|
|
241
235
|
if results.has_unfixed_exceptions and print_exceptions:
|
@@ -263,6 +257,7 @@ class JobsRunnerAsyncio:
|
|
263
257
|
except Exception as e:
|
264
258
|
print(e)
|
265
259
|
remote_logging = False
|
260
|
+
|
266
261
|
if remote_logging:
|
267
262
|
filestore = HTMLFileStore(filepath)
|
268
263
|
coop_details = filestore.push(description="Error report")
|
@@ -1,20 +1,16 @@
|
|
1
1
|
import asyncio
|
2
2
|
from typing import Callable, Union, List
|
3
3
|
from collections import UserList, UserDict
|
4
|
-
import time
|
5
4
|
|
6
5
|
from edsl.jobs.buckets import ModelBuckets
|
7
6
|
from edsl.exceptions import InterviewErrorPriorTaskCanceled
|
8
7
|
|
9
8
|
from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
|
10
9
|
from edsl.jobs.tasks.task_status_enum import TaskStatus, TaskStatusDescriptor
|
11
|
-
|
12
|
-
# from edsl.jobs.tasks.task_management import TokensUsed
|
13
10
|
from edsl.jobs.tasks.TaskStatusLog import TaskStatusLog
|
14
11
|
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
15
12
|
from edsl.jobs.tokens.TokenUsage import TokenUsage
|
16
13
|
from edsl.jobs.Answers import Answers
|
17
|
-
|
18
14
|
from edsl.questions.QuestionBase import QuestionBase
|
19
15
|
|
20
16
|
|
@@ -130,7 +126,7 @@ class QuestionTaskCreator(UserList):
|
|
130
126
|
await self.tokens_bucket.get_tokens(requested_tokens)
|
131
127
|
|
132
128
|
if (estimated_wait_time := self.requests_bucket.wait_time(1)) > 0:
|
133
|
-
self.waiting = True
|
129
|
+
self.waiting = True # do we need this?
|
134
130
|
self.task_status = TaskStatus.WAITING_FOR_REQUEST_CAPACITY
|
135
131
|
|
136
132
|
await self.requests_bucket.get_tokens(1, cheat_bucket_capacity=True)
|
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -8,7 +8,7 @@ from edsl.jobs.tasks.task_status_enum import TaskStatus
|
|
8
8
|
|
9
9
|
|
10
10
|
class TaskHistory:
|
11
|
-
def __init__(self, interviews: List["Interview"], include_traceback=False):
|
11
|
+
def __init__(self, interviews: List["Interview"], include_traceback: bool = False):
|
12
12
|
"""
|
13
13
|
The structure of a TaskHistory exception
|
14
14
|
|
@@ -25,6 +25,7 @@ class TaskHistory:
|
|
25
25
|
|
26
26
|
@classmethod
|
27
27
|
def example(cls):
|
28
|
+
""" """
|
28
29
|
from edsl.jobs.interviews.Interview import Interview
|
29
30
|
|
30
31
|
from edsl.jobs.Jobs import Jobs
|
@@ -38,6 +39,7 @@ class TaskHistory:
|
|
38
39
|
skip_retry=True,
|
39
40
|
cache=False,
|
40
41
|
raise_validation_errors=True,
|
42
|
+
disable_remote_inference=True,
|
41
43
|
)
|
42
44
|
|
43
45
|
return cls(results.task_history.total_interviews)
|
@@ -72,14 +74,29 @@ class TaskHistory:
|
|
72
74
|
|
73
75
|
def to_dict(self):
|
74
76
|
"""Return the TaskHistory as a dictionary."""
|
77
|
+
# return {
|
78
|
+
# "exceptions": [
|
79
|
+
# e.to_dict(include_traceback=self.include_traceback)
|
80
|
+
# for e in self.exceptions
|
81
|
+
# ],
|
82
|
+
# "indices": self.indices,
|
83
|
+
# }
|
75
84
|
return {
|
76
|
-
"exceptions": [
|
77
|
-
e.to_dict(include_traceback=self.include_traceback)
|
78
|
-
for e in self.exceptions
|
79
|
-
],
|
80
|
-
"indices": self.indices,
|
85
|
+
"interviews": [i._to_dict() for i in self.total_interviews],
|
86
|
+
"include_traceback": self.include_traceback,
|
81
87
|
}
|
82
88
|
|
89
|
+
@classmethod
|
90
|
+
def from_dict(cls, data: dict):
|
91
|
+
"""Create a TaskHistory from a dictionary."""
|
92
|
+
if data is None:
|
93
|
+
return cls([], include_traceback=False)
|
94
|
+
|
95
|
+
from edsl.jobs.interviews.Interview import Interview
|
96
|
+
|
97
|
+
interviews = [Interview.from_dict(i) for i in data["interviews"]]
|
98
|
+
return cls(interviews, include_traceback=data["include_traceback"])
|
99
|
+
|
83
100
|
@property
|
84
101
|
def has_exceptions(self) -> bool:
|
85
102
|
"""Return True if there are any exceptions.
|
@@ -259,7 +276,6 @@ class TaskHistory:
|
|
259
276
|
question_type = interview.survey.get_question(
|
260
277
|
question_name
|
261
278
|
).question_type
|
262
|
-
# breakpoint()
|
263
279
|
if (question_name, question_type) not in exceptions_by_question_name:
|
264
280
|
exceptions_by_question_name[(question_name, question_type)] = 0
|
265
281
|
exceptions_by_question_name[(question_name, question_type)] += len(
|
@@ -348,6 +348,9 @@ class LanguageModel(
|
|
348
348
|
"""
|
349
349
|
# parameters = dict({})
|
350
350
|
|
351
|
+
# this is the case when data is loaded from a dict after serialization
|
352
|
+
if "parameters" in passed_parameter_dict:
|
353
|
+
passed_parameter_dict = passed_parameter_dict["parameters"]
|
351
354
|
return {
|
352
355
|
parameter_name: passed_parameter_dict.get(parameter_name, default_value)
|
353
356
|
for parameter_name, default_value in default_parameter_dict.items()
|
edsl/prompts/Prompt.py
CHANGED
@@ -17,25 +17,23 @@ class PreserveUndefined(Undefined):
|
|
17
17
|
|
18
18
|
|
19
19
|
from edsl.exceptions.prompts import TemplateRenderError
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
20
|
+
|
21
|
+
# from edsl.prompts.prompt_config import (
|
22
|
+
# C2A,
|
23
|
+
# names_to_component_types,
|
24
|
+
# ComponentTypes,
|
25
|
+
# NEGATIVE_INFINITY,
|
26
|
+
# )
|
27
|
+
# from edsl.prompts.registry import RegisterPromptsMeta
|
27
28
|
from edsl.Base import PersistenceMixin, RichPrintingMixin
|
28
29
|
|
29
30
|
MAX_NESTING = 100
|
30
31
|
|
31
32
|
|
32
|
-
class PromptBase(
|
33
|
-
PersistenceMixin, RichPrintingMixin, ABC, metaclass=RegisterPromptsMeta
|
34
|
-
):
|
33
|
+
class Prompt(PersistenceMixin, RichPrintingMixin):
|
35
34
|
"""Class for creating a prompt to be used in a survey."""
|
36
35
|
|
37
36
|
default_instructions: Optional[str] = "Do good things, friendly LLM!"
|
38
|
-
component_type = ComponentTypes.GENERIC
|
39
37
|
|
40
38
|
def _repr_html_(self):
|
41
39
|
"""Return an HTML representation of the Prompt."""
|
@@ -111,12 +109,6 @@ class PromptBase(
|
|
111
109
|
with open(folder_path.joinpath(file_name), "r") as f:
|
112
110
|
text = f.read()
|
113
111
|
return cls(text=text)
|
114
|
-
# Resolve the path to get the absolute path
|
115
|
-
# absolute_path = folder_path.resolve()
|
116
|
-
# env = Environment(loader=FileSystemLoader(absolute_path))
|
117
|
-
# template = env.get_template(file_name)
|
118
|
-
# rendered_text = template.render({})
|
119
|
-
# return cls(text=rendered_text)
|
120
112
|
|
121
113
|
@property
|
122
114
|
def text(self):
|
@@ -281,6 +273,7 @@ class PromptBase(
|
|
281
273
|
try:
|
282
274
|
previous_text = None
|
283
275
|
for _ in range(MAX_NESTING):
|
276
|
+
# breakpoint()
|
284
277
|
rendered_text = env.from_string(text).render(
|
285
278
|
primary_replacement, **additional_replacements
|
286
279
|
)
|
@@ -323,9 +316,8 @@ class PromptBase(
|
|
323
316
|
Prompt(text=\"""Hello, {{person}}\""")
|
324
317
|
|
325
318
|
"""
|
326
|
-
class_name = data["class_name"]
|
327
|
-
|
328
|
-
return cls(text=data["text"])
|
319
|
+
# class_name = data["class_name"]
|
320
|
+
return Prompt(text=data["text"])
|
329
321
|
|
330
322
|
def rich_print(self):
|
331
323
|
"""Display an object as a table."""
|
@@ -346,27 +338,21 @@ class PromptBase(
|
|
346
338
|
return cls(cls.default_instructions)
|
347
339
|
|
348
340
|
|
349
|
-
class Prompt(PromptBase):
|
350
|
-
"""A prompt to be used in a survey."""
|
351
|
-
|
352
|
-
component_type = ComponentTypes.GENERIC
|
353
|
-
|
354
|
-
|
355
341
|
if __name__ == "__main__":
|
356
342
|
print("Running doctests...")
|
357
343
|
import doctest
|
358
344
|
|
359
345
|
doctest.testmod()
|
360
346
|
|
361
|
-
from edsl.prompts.library.question_multiple_choice import *
|
362
|
-
from edsl.prompts.library.agent_instructions import *
|
363
|
-
from edsl.prompts.library.agent_persona import *
|
364
|
-
|
365
|
-
from edsl.prompts.library.question_budget import *
|
366
|
-
from edsl.prompts.library.question_checkbox import *
|
367
|
-
from edsl.prompts.library.question_freetext import *
|
368
|
-
from edsl.prompts.library.question_linear_scale import *
|
369
|
-
from edsl.prompts.library.question_numerical import *
|
370
|
-
from edsl.prompts.library.question_rank import *
|
371
|
-
from edsl.prompts.library.question_extract import *
|
372
|
-
from edsl.prompts.library.question_list import *
|
347
|
+
# from edsl.prompts.library.question_multiple_choice import *
|
348
|
+
# from edsl.prompts.library.agent_instructions import *
|
349
|
+
# from edsl.prompts.library.agent_persona import *
|
350
|
+
|
351
|
+
# from edsl.prompts.library.question_budget import *
|
352
|
+
# from edsl.prompts.library.question_checkbox import *
|
353
|
+
# from edsl.prompts.library.question_freetext import *
|
354
|
+
# from edsl.prompts.library.question_linear_scale import *
|
355
|
+
# from edsl.prompts.library.question_numerical import *
|
356
|
+
# from edsl.prompts.library.question_rank import *
|
357
|
+
# from edsl.prompts.library.question_extract import *
|
358
|
+
# from edsl.prompts.library.question_list import *
|
edsl/prompts/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
from edsl.prompts.registry import get_classes
|
1
|
+
# from edsl.prompts.registry import get_classes
|
2
2
|
from edsl.prompts.Prompt import Prompt
|
@@ -38,29 +38,29 @@ class QuestionBasePromptsMixin:
|
|
38
38
|
# ) as file:
|
39
39
|
# return file.read()
|
40
40
|
|
41
|
-
@classmethod
|
42
|
-
def applicable_prompts(
|
43
|
-
|
44
|
-
) -> list[type["PromptBase"]]:
|
45
|
-
|
41
|
+
# @classmethod
|
42
|
+
# def applicable_prompts(
|
43
|
+
# cls, model: Optional[str] = None
|
44
|
+
# ) -> list[type["PromptBase"]]:
|
45
|
+
# """Get the prompts that are applicable to the question type.
|
46
46
|
|
47
|
-
|
47
|
+
# :param model: The language model to use.
|
48
48
|
|
49
|
-
|
50
|
-
|
51
|
-
|
49
|
+
# >>> from edsl.questions import QuestionFreeText
|
50
|
+
# >>> QuestionFreeText.applicable_prompts()
|
51
|
+
# [<class 'edsl.prompts.library.question_freetext.FreeText'>]
|
52
52
|
|
53
|
-
|
53
|
+
# :param model: The language model to use. If None, assumes does not matter.
|
54
54
|
|
55
|
-
|
56
|
-
|
55
|
+
# """
|
56
|
+
# from edsl.prompts.registry import get_classes as prompt_lookup
|
57
57
|
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
58
|
+
# applicable_prompts = prompt_lookup(
|
59
|
+
# component_type="question_instructions",
|
60
|
+
# question_type=cls.question_type,
|
61
|
+
# model=model,
|
62
|
+
# )
|
63
|
+
# return applicable_prompts
|
64
64
|
|
65
65
|
@property
|
66
66
|
def model_instructions(self) -> dict:
|
@@ -50,6 +50,7 @@ class QuestionFunctional(QuestionBase):
|
|
50
50
|
requires_loop: Optional[bool] = False,
|
51
51
|
function_source_code: Optional[str] = None,
|
52
52
|
function_name: Optional[str] = None,
|
53
|
+
unsafe: Optional[bool] = False,
|
53
54
|
):
|
54
55
|
super().__init__()
|
55
56
|
if func:
|
@@ -61,9 +62,12 @@ class QuestionFunctional(QuestionBase):
|
|
61
62
|
|
62
63
|
self.requires_loop = requires_loop
|
63
64
|
|
64
|
-
|
65
|
-
self.
|
66
|
-
|
65
|
+
if unsafe:
|
66
|
+
self.func = func
|
67
|
+
else:
|
68
|
+
self.func = create_restricted_function(
|
69
|
+
self.function_name, self.function_source_code
|
70
|
+
)
|
67
71
|
|
68
72
|
self.question_name = question_name
|
69
73
|
self.question_text = question_text
|
edsl/questions/descriptors.py
CHANGED
@@ -54,32 +54,32 @@ class BaseDescriptor(ABC):
|
|
54
54
|
def __set__(self, instance, value: Any) -> None:
|
55
55
|
"""Set the value of the attribute."""
|
56
56
|
self.validate(value, instance)
|
57
|
-
from edsl.prompts.registry import get_classes
|
57
|
+
# from edsl.prompts.registry import get_classes
|
58
58
|
|
59
59
|
instance.__dict__[self.name] = value
|
60
|
-
if self.name == "_instructions":
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
60
|
+
# if self.name == "_instructions":
|
61
|
+
# instructions = value
|
62
|
+
# if value is not None:
|
63
|
+
# instance.__dict__[self.name] = instructions
|
64
|
+
# instance.set_instructions = True
|
65
|
+
# else:
|
66
|
+
# potential_prompt_classes = get_classes(
|
67
|
+
# question_type=instance.question_type
|
68
|
+
# )
|
69
|
+
# if len(potential_prompt_classes) > 0:
|
70
|
+
# instructions = potential_prompt_classes[0]().text
|
71
|
+
# instance.__dict__[self.name] = instructions
|
72
|
+
# instance.set_instructions = False
|
73
|
+
# else:
|
74
|
+
# if not hasattr(instance, "default_instructions"):
|
75
|
+
# raise Exception(
|
76
|
+
# "No default instructions found and no matching prompts!"
|
77
|
+
# )
|
78
|
+
# instructions = instance.default_instructions
|
79
|
+
# instance.__dict__[self.name] = instructions
|
80
|
+
# instance.set_instructions = False
|
81
|
+
|
82
|
+
# instance.set_instructions = value != instance.default_instructions
|
83
83
|
|
84
84
|
def __set_name__(self, owner, name: str) -> None:
|
85
85
|
"""Set the name of the attribute."""
|
edsl/results/Dataset.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
import random
|
5
|
+
import json
|
5
6
|
from collections import UserList
|
6
7
|
from typing import Any, Union, Optional
|
7
8
|
|
@@ -110,6 +111,17 @@ class Dataset(UserList, ResultsExportMixin):
|
|
110
111
|
new_data.append(observation)
|
111
112
|
return Dataset(new_data)
|
112
113
|
|
114
|
+
def to_json(self):
|
115
|
+
"""Return a JSON representation of the dataset.
|
116
|
+
|
117
|
+
>>> d = Dataset([{'a.b':[1,2,3,4]}])
|
118
|
+
>>> d.to_json()
|
119
|
+
[{'a.b': [1, 2, 3, 4]}]
|
120
|
+
"""
|
121
|
+
return json.loads(
|
122
|
+
json.dumps(self.data)
|
123
|
+
) # janky but I want to make sure it's serializable & deserializable
|
124
|
+
|
113
125
|
def _repr_html_(self) -> str:
|
114
126
|
"""Return an HTML representation of the dataset."""
|
115
127
|
from edsl.utilities.utilities import data_to_html
|