edsl 0.1.29.dev3__py3-none-any.whl → 0.1.30__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- edsl/Base.py +18 -18
- edsl/__init__.py +23 -23
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +79 -41
- edsl/agents/AgentList.py +26 -26
- edsl/agents/Invigilator.py +19 -2
- edsl/agents/InvigilatorBase.py +15 -10
- edsl/agents/PromptConstructionMixin.py +342 -100
- edsl/agents/descriptors.py +2 -1
- edsl/base/Base.py +289 -0
- edsl/config.py +2 -1
- edsl/conjure/InputData.py +39 -8
- edsl/conversation/car_buying.py +1 -1
- edsl/coop/coop.py +187 -150
- edsl/coop/utils.py +43 -75
- edsl/data/Cache.py +41 -18
- edsl/data/CacheEntry.py +6 -7
- edsl/data/SQLiteDict.py +11 -3
- edsl/data_transfer_models.py +4 -0
- edsl/jobs/Answers.py +15 -1
- edsl/jobs/Jobs.py +108 -49
- edsl/jobs/buckets/ModelBuckets.py +14 -2
- edsl/jobs/buckets/TokenBucket.py +32 -5
- edsl/jobs/interviews/Interview.py +99 -79
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +19 -24
- edsl/jobs/runners/JobsRunnerAsyncio.py +16 -16
- edsl/jobs/tasks/QuestionTaskCreator.py +10 -6
- edsl/jobs/tasks/TaskHistory.py +4 -3
- edsl/language_models/LanguageModel.py +17 -17
- edsl/language_models/ModelList.py +1 -1
- edsl/language_models/repair.py +8 -7
- edsl/notebooks/Notebook.py +47 -10
- edsl/prompts/Prompt.py +31 -19
- edsl/questions/QuestionBase.py +38 -13
- edsl/questions/QuestionBudget.py +5 -6
- edsl/questions/QuestionCheckBox.py +7 -3
- edsl/questions/QuestionExtract.py +5 -3
- edsl/questions/QuestionFreeText.py +7 -5
- edsl/questions/QuestionFunctional.py +34 -5
- edsl/questions/QuestionList.py +3 -4
- edsl/questions/QuestionMultipleChoice.py +68 -12
- edsl/questions/QuestionNumerical.py +4 -3
- edsl/questions/QuestionRank.py +5 -3
- edsl/questions/__init__.py +4 -3
- edsl/questions/descriptors.py +46 -4
- edsl/questions/question_registry.py +20 -31
- edsl/questions/settings.py +1 -1
- edsl/results/Dataset.py +31 -0
- edsl/results/DatasetExportMixin.py +570 -0
- edsl/results/Result.py +66 -70
- edsl/results/Results.py +160 -68
- edsl/results/ResultsDBMixin.py +7 -3
- edsl/results/ResultsExportMixin.py +22 -537
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +5 -5
- edsl/scenarios/FileStore.py +299 -0
- edsl/scenarios/Scenario.py +16 -24
- edsl/scenarios/ScenarioList.py +42 -17
- edsl/scenarios/ScenarioListExportMixin.py +32 -0
- edsl/scenarios/ScenarioListPdfMixin.py +2 -1
- edsl/scenarios/__init__.py +1 -0
- edsl/study/Study.py +8 -16
- edsl/surveys/MemoryPlan.py +11 -4
- edsl/surveys/Survey.py +88 -17
- edsl/surveys/SurveyExportMixin.py +4 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
- edsl/tools/plotting.py +4 -2
- edsl/utilities/__init__.py +21 -21
- edsl/utilities/interface.py +66 -45
- edsl/utilities/utilities.py +11 -13
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/METADATA +11 -10
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/RECORD +74 -71
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/WHEEL +1 -1
- edsl-0.1.29.dev3.dist-info/entry_points.txt +0 -3
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/LICENSE +0 -0
edsl/jobs/buckets/TokenBucket.py
CHANGED
```diff
@@ -1,8 +1,6 @@
 from typing import Union, List, Any
 import asyncio
 import time
-from collections import UserDict
-from matplotlib import pyplot as plt


 class TokenBucket:
@@ -19,11 +17,29 @@ class TokenBucket:
         self.bucket_name = bucket_name
         self.bucket_type = bucket_type
         self.capacity = capacity  # Maximum number of tokens
+        self._old_capacity = capacity
         self.tokens = capacity  # Current number of available tokens
         self.refill_rate = refill_rate  # Rate at which tokens are refilled
+        self._old_refill_rate = refill_rate
         self.last_refill = time.monotonic()  # Last refill time
-
         self.log: List[Any] = []
+        self.turbo_mode = False
+
+    def turbo_mode_on(self):
+        """Set the refill rate to infinity."""
+        if self.turbo_mode:
+            pass
+        else:
+            # pass
+            self.turbo_mode = True
+            self.capacity = float("inf")
+            self.refill_rate = float("inf")
+
+    def turbo_mode_off(self):
+        """Restore the refill rate to its original value."""
+        self.turbo_mode = False
+        self.capacity = self._old_capacity
+        self.refill_rate = self._old_refill_rate

     def __add__(self, other) -> "TokenBucket":
         """Combine two token buckets.
@@ -57,7 +73,17 @@ class TokenBucket:
         self.log.append((time.monotonic(), self.tokens))

     def refill(self) -> None:
-        """Refill the bucket with new tokens based on elapsed time."""
+        """Refill the bucket with new tokens based on elapsed time.
+
+
+
+        >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)
+        >>> bucket.tokens = 0
+        >>> bucket.refill()
+        >>> bucket.tokens > 0
+        True
+
+        """
         now = time.monotonic()
         elapsed = now - self.last_refill
         refill_amount = elapsed * self.refill_rate
@@ -100,7 +126,7 @@
             raise ValueError(msg)
         while self.tokens < amount:
             self.refill()
-            await asyncio.sleep(0.
+            await asyncio.sleep(0.01)  # Sleep briefly to prevent busy waiting
         self.tokens -= amount

         now = time.monotonic()
@@ -114,6 +140,7 @@
         times, tokens = zip(*self.get_log())
         start_time = times[0]
         times = [t - start_time for t in times]  # Normalize time to start from 0
+        from matplotlib import pyplot as plt

         plt.figure(figsize=(10, 6))
         plt.plot(times, tokens, label="Tokens Available")
```
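The headline change is turbo mode: the constructor now stashes the original `capacity` and `refill_rate`, `turbo_mode_on()` swaps both for `float("inf")` so callers never wait, and `turbo_mode_off()` restores the saved values. A minimal standalone sketch of the same save-and-restore pattern (simplified constructor, not edsl's actual class):

```python
class Bucket:
    """Simplified rate limiter illustrating the save/restore turbo pattern."""

    def __init__(self, capacity: float, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate
        # Remember the real limits so turbo mode is reversible.
        self._old_capacity = capacity
        self._old_refill_rate = refill_rate
        self.turbo_mode = False

    def turbo_mode_on(self) -> None:
        # Infinite capacity and refill rate mean consumers never wait.
        if not self.turbo_mode:
            self.turbo_mode = True
            self.capacity = float("inf")
            self.refill_rate = float("inf")

    def turbo_mode_off(self) -> None:
        self.turbo_mode = False
        self.capacity = self._old_capacity
        self.refill_rate = self._old_refill_rate


bucket = Bucket(capacity=10, refill_rate=1)
bucket.turbo_mode_on()
assert bucket.refill_rate == float("inf")
bucket.turbo_mode_off()
assert bucket.refill_rate == 1
```

The file also stops importing matplotlib at module level, pulling `pyplot` in only inside the plotting method; the same deferral shows up again in `TaskHistory.py` below.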
edsl/jobs/interviews/Interview.py
CHANGED
```diff
@@ -6,15 +6,9 @@ import asyncio
 import time
 from typing import Any, Type, List, Generator, Optional

-from edsl.agents import Agent
-from edsl.language_models import LanguageModel
-from edsl.scenarios import Scenario
-from edsl.surveys import Survey
-
 from edsl.jobs.Answers import Answers
 from edsl.surveys.base import EndOfSurvey
 from edsl.jobs.buckets.ModelBuckets import ModelBuckets
-
 from edsl.jobs.tasks.TaskCreators import TaskCreators

 from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
@@ -26,6 +20,12 @@ from edsl.jobs.interviews.retry_management import retry_strategy
 from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
 from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin

+import asyncio
+
+
+def run_async(coro):
+    return asyncio.run(coro)
+

 class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
     """
@@ -36,14 +36,14 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):

     def __init__(
         self,
-        agent: Agent,
-        survey: Survey,
-        scenario: Scenario,
-        model: Type[LanguageModel],
-        debug: bool = False,
+        agent: "Agent",
+        survey: "Survey",
+        scenario: "Scenario",
+        model: Type["LanguageModel"],
+        debug: Optional[bool] = False,
         iteration: int = 0,
-        cache: "Cache" = None,
-        sidecar_model: LanguageModel = None,
+        cache: Optional["Cache"] = None,
+        sidecar_model: Optional["LanguageModel"] = None,
     ):
         """Initialize the Interview instance.

@@ -51,6 +51,24 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         :param survey: the survey being administered to the agent.
         :param scenario: the scenario that populates the survey questions.
         :param model: the language model used to answer the questions.
+        :param debug: if True, run without calls to the language model.
+        :param iteration: the iteration number of the interview.
+        :param cache: the cache used to store the answers.
+        :param sidecar_model: a sidecar model used to answer questions.
+
+        >>> i = Interview.example()
+        >>> i.task_creators
+        {}
+
+        >>> i.exceptions
+        {}
+
+        >>> _ = asyncio.run(i.async_conduct_interview())
+        >>> i.task_status_logs['q0']
+        [{'log_time': ..., 'value': <TaskStatus.NOT_STARTED: 1>}, {'log_time': ..., 'value': <TaskStatus.WAITING_FOR_DEPENDENCIES: 2>}, {'log_time': ..., 'value': <TaskStatus.API_CALL_IN_PROGRESS: 7>}, {'log_time': ..., 'value': <TaskStatus.SUCCESS: 8>}]
+
+        >>> i.to_index
+        {'q0': 0, 'q1': 1, 'q2': 2}

         """
         self.agent = agent
@@ -70,7 +88,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         self.exceptions = InterviewExceptionCollection()
         self._task_status_log_dict = InterviewStatusLog()

-        # dictionary mapping question names to their index in the survey.
+        # dictionary mapping question names to their index in the survey.
         self.to_index = {
             question_name: index
             for index, question_name in enumerate(self.survey.question_names)
@@ -82,14 +100,16 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         model_buckets: ModelBuckets = None,
         debug: bool = False,
         stop_on_exception: bool = False,
-        sidecar_model: Optional[LanguageModel] = None,
+        sidecar_model: Optional["LanguageModel"] = None,
     ) -> tuple["Answers", List[dict[str, Any]]]:
         """
         Conduct an Interview asynchronously.
+        It returns a tuple with the answers and a list of valid results.

         :param model_buckets: a dictionary of token buckets for the model.
         :param debug: run without calls to LLM.
         :param stop_on_exception: if True, stops the interview if an exception is raised.
+        :param sidecar_model: a sidecar model used to answer questions.

         Example usage:

@@ -97,16 +117,37 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         >>> result, _ = asyncio.run(i.async_conduct_interview())
         >>> result['q0']
         'yes'
+
+        >>> i = Interview.example(throw_exception = True)
+        >>> result, _ = asyncio.run(i.async_conduct_interview())
+        Attempt 1 failed with exception:This is a test error now waiting 1.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+        Attempt 2 failed with exception:This is a test error now waiting 2.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+        Attempt 3 failed with exception:This is a test error now waiting 4.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+        Attempt 4 failed with exception:This is a test error now waiting 8.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+
+        >>> i.exceptions
+        {'q0': [{'exception': "Exception('This is a test error')", 'time': ..., 'traceback': ...
+
+        >>> i = Interview.example()
+        >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
+        Traceback (most recent call last):
+        ...
+        asyncio.exceptions.CancelledError
         """
         self.sidecar_model = sidecar_model

         # if no model bucket is passed, create an 'infinity' bucket with no rate limits
-        # print("model_buckets", model_buckets)
         if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
             model_buckets = ModelBuckets.infinity_bucket()

-        # model_buckets = ModelBuckets.infinity_bucket()
-
         ## build the tasks using the InterviewTaskBuildingMixin
         ## This is the key part---it creates a task for each question,
         ## with dependencies on the questions that must be answered before this one can be answered.
@@ -128,6 +169,14 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         It iterates through the tasks and invigilators, and yields the results of the tasks that are done.
         If a task is not done, it raises a ValueError.
         If an exception is raised in the task, it records the exception in the Interview instance except if the task was cancelled, which is expected behavior.
+
+        >>> i = Interview.example()
+        >>> result, _ = asyncio.run(i.async_conduct_interview())
+        >>> results = list(i._extract_valid_results())
+        >>> len(results) == len(i.survey)
+        True
+        >>> type(results[0])
+        <class 'edsl.data_transfer_models.AgentResponseDict'>
         """
         assert len(self.tasks) == len(self.invigilators)

@@ -145,7 +194,18 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
             yield result

     def _record_exception(self, task, exception: Exception) -> None:
-        """Record an exception in the Interview instance."""
+        """Record an exception in the Interview instance.
+
+        It records the exception in the Interview instance, with the task name and the exception entry.
+
+        >>> i = Interview.example()
+        >>> result, _ = asyncio.run(i.async_conduct_interview())
+        >>> i.exceptions
+        {}
+        >>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
+        >>> i.exceptions
+        {'q0': [{'exception': "Exception('An exception occurred.')", 'time': ..., 'traceback': 'NoneType: None\\n'}]}
+        """
         exception_entry = InterviewExceptionEntry(
             exception=repr(exception),
             time=time.time(),
@@ -161,6 +221,10 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         It is used to determine the order in which questions should be answered.
         This reflects both agent 'memory' considerations and 'skip' logic.
         The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
+
+        >>> i = Interview.example()
+        >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
+        True
         """
         return self.survey.dag(textify=True)

@@ -171,8 +235,15 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         """Return a string representation of the Interview instance."""
         return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"

-    def duplicate(self, iteration: int, cache: Cache) -> Interview:
-        """Duplicate the interview, but with a new iteration number and cache."""
+    def duplicate(self, iteration: int, cache: "Cache") -> Interview:
+        """Duplicate the interview, but with a new iteration number and cache.
+
+        >>> i = Interview.example()
+        >>> i2 = i.duplicate(1, None)
+        >>> i.iteration + 1 == i2.iteration
+        True
+
+        """
         return Interview(
             agent=self.agent,
             survey=self.survey,
@@ -183,7 +254,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         )

     @classmethod
-    def example(self):
+    def example(self, throw_exception: bool = False) -> Interview:
         """Return an example Interview instance."""
         from edsl.agents import Agent
         from edsl.surveys import Survey
@@ -198,66 +269,15 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         survey = Survey.example()
         scenario = Scenario.example()
         model = LanguageModel.example()
+        if throw_exception:
+            model = LanguageModel.example(test_model=True, throw_exception=True)
+            agent = Agent.example()
+            return Interview(agent=agent, survey=survey, scenario=scenario, model=model)
         return Interview(agent=agent, survey=survey, scenario=scenario, model=model)


 if __name__ == "__main__":
     import doctest

-
-
-    # from edsl.agents import Agent
-    # from edsl.surveys import Survey
-    # from edsl.scenarios import Scenario
-    # from edsl.questions import QuestionMultipleChoice
-
-    # # from edsl.jobs.Interview import Interview
-
-    # # a survey with skip logic
-    # q0 = QuestionMultipleChoice(
-    #     question_text="Do you like school?",
-    #     question_options=["yes", "no"],
-    #     question_name="q0",
-    # )
-    # q1 = QuestionMultipleChoice(
-    #     question_text="Why not?",
-    #     question_options=["killer bees in cafeteria", "other"],
-    #     question_name="q1",
-    # )
-    # q2 = QuestionMultipleChoice(
-    #     question_text="Why?",
-    #     question_options=["**lack*** of killer bees in cafeteria", "other"],
-    #     question_name="q2",
-    # )
-    # s = Survey(questions=[q0, q1, q2])
-    # s = s.add_rule(q0, "q0 == 'yes'", q2)
-
-    # # create an interview
-    # a = Agent(traits=None)
-
-    # def direct_question_answering_method(self, question, scenario):
-    #     """Answer a question directly."""
-    #     raise Exception("Error!")
-    #     # return "yes"
-
-    # a.add_direct_question_answering_method(direct_question_answering_method)
-    # scenario = Scenario()
-    # m = Model()
-    # I = Interview(agent=a, survey=s, scenario=scenario, model=m)
-
-    # result = asyncio.run(I.async_conduct_interview())
-    # # # conduct five interviews
-    # # for _ in range(5):
-    # #     I.conduct_interview(debug=True)
-
-    # # # replace missing answers
-    # # I
-    # # repr(I)
-    # # eval(repr(I))
-    # # print(I.task_status_logs.status_matrix(20))
-    # status_matrix = I.task_status_logs.status_matrix(20)
-    # numerical_matrix = I.task_status_logs.numerical_matrix(20)
-    # I.task_status_logs.visualize()
-
-    # I.exceptions.print()
-    # I.exceptions.ascii_table()
+    # add ellipsis
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
```
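The new `throw_exception=True` doctest documents the retry behavior: the printed waits double from 1.00 to 8.00 seconds under `start=1.0, max=60.0, max_attempts=5`, which is consistent with capped exponential backoff. A rough sketch of such a schedule (the actual decorator lives in `edsl.jobs.interviews.retry_management` and is not shown in this diff, so treat this as an illustration of the printed parameters, not the library's code):

```python
def backoff_schedule(start: float = 1.0, max_wait: float = 60.0, max_attempts: int = 5):
    """Yield the wait in seconds after each failed attempt: 1, 2, 4, 8 here."""
    for attempt in range(1, max_attempts):  # no wait after the final attempt
        yield min(start * 2 ** (attempt - 1), max_wait)


assert list(backoff_schedule()) == [1.0, 2.0, 4.0, 8.0]
```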
edsl/jobs/interviews/InterviewTaskBuildingMixin.py
CHANGED
```diff
@@ -5,17 +5,19 @@ import asyncio
 import time
 import traceback
 from typing import Generator, Union
+
 from edsl import CONFIG
 from edsl.exceptions import InterviewTimeoutError
-
-from edsl.questions.QuestionBase import QuestionBase
+
+# from edsl.questions.QuestionBase import QuestionBase
 from edsl.surveys.base import EndOfSurvey
 from edsl.jobs.buckets.ModelBuckets import ModelBuckets
 from edsl.jobs.interviews.interview_exception_tracking import InterviewExceptionEntry
 from edsl.jobs.interviews.retry_management import retry_strategy
 from edsl.jobs.tasks.task_status_enum import TaskStatus
 from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
-
+
+# from edsl.agents.InvigilatorBase import InvigilatorBase

 TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))

@@ -23,7 +25,7 @@ TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
 class InterviewTaskBuildingMixin:
     def _build_invigilators(
         self, debug: bool
-    ) -> Generator[InvigilatorBase, None, None]:
+    ) -> Generator["InvigilatorBase", None, None]:
         """Create an invigilator for each question.

         :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
@@ -33,7 +35,7 @@ class InterviewTaskBuildingMixin:
         for question in self.survey.questions:
             yield self._get_invigilator(question=question, debug=debug)

-    def _get_invigilator(self, question: QuestionBase, debug: bool) -> "Invigilator":
+    def _get_invigilator(self, question: "QuestionBase", debug: bool) -> "Invigilator":
         """Return an invigilator for the given question.

         :param question: the question to be answered
@@ -44,6 +46,7 @@ class InterviewTaskBuildingMixin:
             scenario=self.scenario,
             model=self.model,
             debug=debug,
+            survey=self.survey,
             memory_plan=self.survey.memory_plan,
             current_answers=self.answers,
             iteration=self.iteration,
@@ -81,7 +84,7 @@ class InterviewTaskBuildingMixin:
         return tuple(tasks)  # , invigilators

     def _get_tasks_that_must_be_completed_before(
-        self, *, tasks: list[asyncio.Task], question: QuestionBase
+        self, *, tasks: list[asyncio.Task], question: "QuestionBase"
     ) -> Generator[asyncio.Task, None, None]:
         """Return the tasks that must be completed before the given question can be answered.

@@ -97,7 +100,7 @@ class InterviewTaskBuildingMixin:
     def _create_question_task(
         self,
         *,
-        question: QuestionBase,
+        question: "QuestionBase",
         tasks_that_must_be_completed_before: list[asyncio.Task],
         model_buckets: ModelBuckets,
         debug: bool,
@@ -149,15 +152,17 @@ class InterviewTaskBuildingMixin:
     async def _answer_question_and_record_task(
         self,
         *,
-        question: QuestionBase,
+        question: "QuestionBase",
         debug: bool,
         task=None,
-    ) -> AgentResponseDict:
+    ) -> "AgentResponseDict":
         """Answer a question and records the task.

         This in turn calls the the passed-in agent's async_answer_question method, which returns a response dictionary.
         Note that is updates answers dictionary with the response.
         """
+        from edsl.data_transfer_models import AgentResponseDict
+
         try:
             invigilator = self._get_invigilator(question, debug=debug)

@@ -170,24 +175,14 @@ class InterviewTaskBuildingMixin:

             self._add_answer(response=response, question=question)

-            # With the answer to the question, we can now cancel any skipped questions
             self._cancel_skipped_questions(question)
             return AgentResponseDict(**response)
         except Exception as e:
             raise e
-
-
-
-
-            # # Extract and print the traceback info
-            # tb = e.__traceback__
-            # while tb is not None:
-            #     print(f"File {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}, in {tb.tb_frame.f_code.co_name}")
-            #     tb = tb.tb_next
-            # breakpoint()
-            # raise e
-
-    def _add_answer(self, response: AgentResponseDict, question: QuestionBase) -> None:
+
+    def _add_answer(
+        self, response: "AgentResponseDict", question: "QuestionBase"
+    ) -> None:
         """Add the answer to the answers dictionary.

         :param response: the response to the question.
@@ -195,7 +190,7 @@ class InterviewTaskBuildingMixin:
         """
         self.answers.add_answer(response=response, question=question)

-    def _skip_this_question(self, current_question: QuestionBase) -> bool:
+    def _skip_this_question(self, current_question: "QuestionBase") -> bool:
         """Determine if the current question should be skipped.

         :param current_question: the question to be answered.
```
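The pattern across this file: concrete annotations such as `question: QuestionBase` become string annotations (`question: "QuestionBase"`), and the corresponding module-level imports are commented out, so the annotations no longer force an import at module load time. The diff relies on plain string annotations; a common complementary idiom, shown below as a hypothetical sketch rather than edsl's actual code, is to keep the import behind `typing.TYPE_CHECKING` so static type checkers can still resolve the name:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime, so it adds
    # no import-time cost and cannot create a circular import.
    from edsl.questions.QuestionBase import QuestionBase


def describe(question: "QuestionBase") -> str:
    # The string annotation is never evaluated at runtime.
    return f"Question named {question.question_name!r}"
```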
edsl/jobs/runners/JobsRunnerAsyncio.py
CHANGED
```diff
@@ -1,29 +1,17 @@
 from __future__ import annotations
 import time
 import asyncio
-import
+import time
 from contextlib import contextmanager

 from typing import Coroutine, List, AsyncGenerator, Optional, Union

-from rich.live import Live
-from rich.console import Console
-
 from edsl import shared_globals
-from edsl.results import Results, Result
-
 from edsl.jobs.interviews.Interview import Interview
-from edsl.utilities.decorators import jupyter_nb_handler
-
-# from edsl.jobs.Jobs import Jobs
 from edsl.jobs.runners.JobsRunnerStatusMixin import JobsRunnerStatusMixin
-from edsl.language_models import LanguageModel
-from edsl.data.Cache import Cache
-
 from edsl.jobs.tasks.TaskHistory import TaskHistory
 from edsl.jobs.buckets.BucketCollection import BucketCollection
-
-import time
+from edsl.utilities.decorators import jupyter_nb_handler


 class JobsRunnerAsyncio(JobsRunnerStatusMixin):
@@ -42,13 +30,13 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):

     async def run_async_generator(
         self,
-        cache: Cache,
+        cache: "Cache",
         n: int = 1,
         debug: bool = False,
         stop_on_exception: bool = False,
         sidecar_model: "LanguageModel" = None,
         total_interviews: Optional[List["Interview"]] = None,
-    ) -> AsyncGenerator[Result, None]:
+    ) -> AsyncGenerator["Result", None]:
         """Creates the tasks, runs them asynchronously, and returns the results as a Results object.

         Completed tasks are yielded as they are completed.
@@ -100,6 +88,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
             self.total_interviews.append(interview)

     async def run_async(self, cache=None) -> Results:
+        from edsl.results.Results import Results
+
         if cache is None:
             self.cache = Cache()
         else:
@@ -110,6 +100,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
         return Results(survey=self.jobs.survey, data=data)

     def simple_run(self):
+        from edsl.results.Results import Results
+
         data = asyncio.run(self.run_async())
         return Results(survey=self.jobs.survey, data=data)

@@ -169,6 +161,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
                     question_name + "_raw_model_response"
                 ] = result["raw_model_response"]

+        from edsl.results.Result import Result
+
         result = Result(
             agent=interview.agent,
             scenario=interview.scenario,
@@ -197,6 +191,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
         print_exceptions: bool = True,
     ) -> "Coroutine":
         """Runs a collection of interviews, handling both async and sync contexts."""
+        from rich.console import Console
+
         console = Console()
         self.results = []
         self.start_time = time.monotonic()
@@ -204,6 +200,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
         self.cache = cache
         self.sidecar_model = sidecar_model

+        from edsl.results.Results import Results
+
         if not progress_bar:
             # print("Running without progress bar")
             with cache as c:
@@ -225,6 +223,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
                 results = Results(survey=self.jobs.survey, data=self.results)
             else:
                 # print("Running with progress bar")
+                from rich.live import Live
+                from rich.console import Console

                 def generate_table():
                     return self.status_table(self.results, self.elapsed_time)
```
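Every change in this file is the same refactor: module-level imports of `rich`, `Results`, `Result`, and `Cache` move inside the methods that use them, so their cost is paid on first call rather than at `import edsl` time. A generic sketch of the trade-off (a hypothetical function, not edsl's code):

```python
def render_table(rows: list) -> None:
    # Deferred import: rich is loaded only if a table is actually rendered,
    # keeping the importing package fast to load on paths that never print one.
    from rich.console import Console
    from rich.table import Table

    table = Table()
    table.add_column("value")
    for row in rows:
        table.add_row(str(row))
    Console().print(table)


render_table([1, 2, 3])
```

The cost is a small per-call overhead (a `sys.modules` lookup after the first import) and imports that are easier to miss when scanning the top of the file.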
edsl/jobs/tasks/QuestionTaskCreator.py
CHANGED
```diff
@@ -144,12 +144,16 @@ class QuestionTaskCreator(UserList):
             self.task_status = TaskStatus.FAILED
             raise e

-        if "
-
-
-
-
-
+        if results.get("cache_used", False):
+            self.tokens_bucket.add_tokens(requested_tokens)
+            self.requests_bucket.add_tokens(1)
+            self.from_cache = True
+            # Turbo mode means that we don't wait for tokens or requests.
+            self.tokens_bucket.turbo_mode_on()
+            self.requests_bucket.turbo_mode_on()
+        else:
+            self.tokens_bucket.turbo_mode_off()
+            self.requests_bucket.turbo_mode_off()

         _ = results.pop("cached_response", None)

```
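This is the consumer side of `TokenBucket`'s new turbo mode: when the response came from the cache, the tokens and the request slot that were debited are credited back, and both buckets switch to turbo mode so a run of cache hits never waits on rate limits; the first non-cached response turns the limits back on. A self-contained sketch of that control flow (with a tiny stand-in class, not edsl's real buckets):

```python
class _Bucket:
    """Tiny stand-in for edsl's TokenBucket (illustration only)."""

    def __init__(self, refill_rate: float):
        self.refill_rate = self._old_refill_rate = refill_rate

    def turbo_mode_on(self) -> None:
        self.refill_rate = float("inf")

    def turbo_mode_off(self) -> None:
        self.refill_rate = self._old_refill_rate


def after_task(results: dict, bucket: _Bucket) -> None:
    # Mirrors the added branch: cache hits lift the limits, live calls restore them.
    if results.get("cache_used", False):
        bucket.turbo_mode_on()
    else:
        bucket.turbo_mode_off()


bucket = _Bucket(refill_rate=1.0)
after_task({"cache_used": True}, bucket)
assert bucket.refill_rate == float("inf")
after_task({"cache_used": False}, bucket)
assert bucket.refill_rate == 1.0
```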
edsl/jobs/tasks/TaskHistory.py
CHANGED
```diff
@@ -1,8 +1,5 @@
 from edsl.jobs.tasks.task_status_enum import TaskStatus
-from matplotlib import pyplot as plt
 from typing import List, Optional
-
-import matplotlib.pyplot as plt
 from io import BytesIO
 import base64

@@ -75,6 +72,8 @@ class TaskHistory:

     def plot_completion_times(self):
         """Plot the completion times for each task."""
+        import matplotlib.pyplot as plt
+
         updates = self.get_updates()

         elapsed = [update.max_time - update.min_time for update in updates]
@@ -126,6 +125,8 @@ class TaskHistory:
         rows = int(len(TaskStatus) ** 0.5) + 1
         cols = (len(TaskStatus) + rows - 1) // rows  # Ensure all plots fit

+        import matplotlib.pyplot as plt
+
         fig, axes = plt.subplots(rows, cols, figsize=(15, 10))
         axes = axes.flatten()  # Flatten in case of a single row/column

```
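As in `TokenBucket.py`, the module-level matplotlib imports (one of which was duplicated) move into the two plotting methods. matplotlib is a comparatively heavy dependency, so deferring it plausibly shaves noticeable time off `import edsl`. A rough way to measure the cost on your own machine (assuming matplotlib is installed; numbers will vary):

```python
import time

start = time.perf_counter()
import matplotlib.pyplot as plt  # pulls in numpy, font caches, and a backend

print(f"matplotlib.pyplot import took {time.perf_counter() - start:.2f}s")
```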