edsl 0.1.29.dev6__py3-none-any.whl → 0.1.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +6 -3
- edsl/__init__.py +23 -23
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +43 -40
- edsl/agents/AgentList.py +23 -22
- edsl/agents/Invigilator.py +19 -2
- edsl/agents/descriptors.py +2 -1
- edsl/base/Base.py +289 -0
- edsl/config.py +2 -1
- edsl/conversation/car_buying.py +1 -1
- edsl/coop/utils.py +28 -1
- edsl/data/Cache.py +41 -18
- edsl/data/CacheEntry.py +6 -7
- edsl/data/SQLiteDict.py +11 -3
- edsl/data_transfer_models.py +4 -0
- edsl/jobs/Answers.py +15 -1
- edsl/jobs/Jobs.py +86 -33
- edsl/jobs/buckets/ModelBuckets.py +14 -2
- edsl/jobs/buckets/TokenBucket.py +32 -5
- edsl/jobs/interviews/Interview.py +99 -79
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +18 -24
- edsl/jobs/runners/JobsRunnerAsyncio.py +16 -16
- edsl/jobs/tasks/QuestionTaskCreator.py +10 -6
- edsl/jobs/tasks/TaskHistory.py +4 -3
- edsl/language_models/LanguageModel.py +17 -17
- edsl/language_models/ModelList.py +1 -1
- edsl/language_models/repair.py +8 -7
- edsl/notebooks/Notebook.py +16 -10
- edsl/questions/QuestionBase.py +6 -2
- edsl/questions/QuestionBudget.py +5 -6
- edsl/questions/QuestionCheckBox.py +7 -3
- edsl/questions/QuestionExtract.py +5 -3
- edsl/questions/QuestionFreeText.py +7 -5
- edsl/questions/QuestionFunctional.py +34 -5
- edsl/questions/QuestionList.py +3 -4
- edsl/questions/QuestionMultipleChoice.py +68 -12
- edsl/questions/QuestionNumerical.py +4 -3
- edsl/questions/QuestionRank.py +5 -3
- edsl/questions/__init__.py +4 -3
- edsl/questions/descriptors.py +46 -4
- edsl/results/DatasetExportMixin.py +570 -0
- edsl/results/Result.py +66 -70
- edsl/results/Results.py +160 -68
- edsl/results/ResultsDBMixin.py +7 -3
- edsl/results/ResultsExportMixin.py +22 -537
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +1 -4
- edsl/scenarios/FileStore.py +299 -0
- edsl/scenarios/Scenario.py +16 -24
- edsl/scenarios/ScenarioList.py +25 -14
- edsl/scenarios/ScenarioListExportMixin.py +32 -0
- edsl/scenarios/ScenarioListPdfMixin.py +2 -1
- edsl/scenarios/__init__.py +1 -0
- edsl/study/Study.py +5 -7
- edsl/surveys/MemoryPlan.py +11 -4
- edsl/surveys/Survey.py +52 -15
- edsl/surveys/SurveyExportMixin.py +4 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
- edsl/utilities/__init__.py +21 -21
- edsl/utilities/interface.py +66 -45
- edsl/utilities/utilities.py +11 -13
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/METADATA +1 -1
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/RECORD +65 -61
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/WHEEL +1 -1
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/LICENSE +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -1,30 +1,15 @@
 # """The Jobs class is a collection of agents, scenarios and models and one survey."""
 from __future__ import annotations
-import os
 import warnings
 from itertools import product
 from typing import Optional, Union, Sequence, Generator
-
-from edsl.agents import Agent
-from edsl.agents.AgentList import AgentList
+
 from edsl.Base import Base
-from edsl.data.Cache import Cache
-from edsl.data.CacheHandler import CacheHandler
-from edsl.results.Dataset import Dataset
 
-from edsl.exceptions.jobs import MissingRemoteInferenceError
 from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.jobs.interviews.Interview import Interview
-from edsl.language_models import LanguageModel
-from edsl.results import Results
-from edsl.scenarios import Scenario
-from edsl import ScenarioList
-from edsl.surveys import Survey
 from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
-
-from edsl.language_models.ModelList import ModelList
-
 from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
 
 
@@ -37,10 +22,10 @@ class Jobs(Base):
 
     def __init__(
         self,
-        survey: Survey,
-        agents: Optional[list[Agent]] = None,
-        models: Optional[list[LanguageModel]] = None,
-        scenarios: Optional[list[Scenario]] = None,
+        survey: "Survey",
+        agents: Optional[list["Agent"]] = None,
+        models: Optional[list["LanguageModel"]] = None,
+        scenarios: Optional[list["Scenario"]] = None,
     ):
         """Initialize a Jobs instance.
 
@@ -50,8 +35,8 @@ class Jobs(Base):
         :param scenarios: a list of scenarios
        """
         self.survey = survey
-        self.agents: AgentList = agents
-        self.scenarios: ScenarioList = scenarios
+        self.agents: "AgentList" = agents
+        self.scenarios: "ScenarioList" = scenarios
         self.models = models
 
         self.__bucket_collection = None
@@ -62,6 +47,8 @@ class Jobs(Base):
 
     @models.setter
     def models(self, value):
+        from edsl import ModelList
+
         if value:
             if not isinstance(value, ModelList):
                 self._models = ModelList(value)
@@ -76,6 +63,8 @@ class Jobs(Base):
 
     @agents.setter
     def agents(self, value):
+        from edsl import AgentList
+
         if value:
             if not isinstance(value, AgentList):
                 self._agents = AgentList(value)
@@ -90,6 +79,8 @@ class Jobs(Base):
 
     @scenarios.setter
     def scenarios(self, value):
+        from edsl import ScenarioList
+
         if value:
             if not isinstance(value, ScenarioList):
                 self._scenarios = ScenarioList(value)
@@ -101,10 +92,10 @@ class Jobs(Base):
     def by(
         self,
         *args: Union[
-            Agent,
-            Scenario,
-            LanguageModel,
-            Sequence[Union[Agent, Scenario, LanguageModel]],
+            "Agent",
+            "Scenario",
+            "LanguageModel",
+            Sequence[Union["Agent", "Scenario", "LanguageModel"]],
         ],
     ) -> Jobs:
         """
@@ -144,7 +135,7 @@ class Jobs(Base):
         setattr(self, objects_key, new_objects)  # update the job
         return self
 
-    def prompts(self) -> Dataset:
+    def prompts(self) -> "Dataset":
         """Return a Dataset of prompts that will be used.
 
 
@@ -160,6 +151,7 @@ class Jobs(Base):
         user_prompts = []
         system_prompts = []
         scenario_indices = []
+        from edsl.results.Dataset import Dataset
 
         for interview_index, interview in enumerate(interviews):
             invigilators = list(interview._build_invigilators(debug=False))
@@ -182,7 +174,10 @@ class Jobs(Base):
 
     @staticmethod
     def _get_container_class(object):
-        from edsl import AgentList
+        from edsl.agents.AgentList import AgentList
+        from edsl.agents.Agent import Agent
+        from edsl.scenarios.Scenario import Scenario
+        from edsl.scenarios.ScenarioList import ScenarioList
 
         if isinstance(object, Agent):
             return AgentList
@@ -218,6 +213,10 @@ class Jobs(Base):
     def _get_current_objects_of_this_type(
         self, object: Union[Agent, Scenario, LanguageModel]
     ) -> tuple[list, str]:
+        from edsl.agents.Agent import Agent
+        from edsl.scenarios.Scenario import Scenario
+        from edsl.language_models.LanguageModel import LanguageModel
+
         """Return the current objects of the same type as the first argument.
 
         >>> from edsl.jobs import Jobs
@@ -246,6 +245,9 @@ class Jobs(Base):
     @staticmethod
     def _get_empty_container_object(object):
         from edsl import AgentList
+        from edsl import Agent
+        from edsl import Scenario
+        from edsl import ScenarioList
 
         if isinstance(object, Agent):
             return AgentList([])
@@ -310,6 +312,10 @@ class Jobs(Base):
        with us filling in defaults.
        """
        # if no agents, models, or scenarios are set, set them to defaults
+        from edsl.agents.Agent import Agent
+        from edsl.language_models.registry import Model
+        from edsl.scenarios.Scenario import Scenario
+
        self.agents = self.agents or [Agent()]
        self.models = self.models or [Model()]
        self.scenarios = self.scenarios or [Scenario()]
@@ -325,6 +331,7 @@ class Jobs(Base):
        These buckets are used to track API calls and token usage.

        >>> from edsl.jobs import Jobs
+        >>> from edsl import Model
        >>> j = Jobs.example().by(Model(temperature = 1), Model(temperature = 0.5))
        >>> bc = j.create_bucket_collection()
        >>> bc
@@ -368,6 +375,8 @@ class Jobs(Base):
        """Check if the parameters in the survey and scenarios are consistent.

        >>> from edsl import QuestionFreeText
+        >>> from edsl import Survey
+        >>> from edsl import Scenario
        >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
        >>> j = Jobs(survey = Survey(questions=[q]))
        >>> with warnings.catch_warnings(record=True) as w:
@@ -452,6 +461,13 @@ class Jobs(Base):
             remote_inference = False
 
         if remote_inference:
+            from edsl.agents.Agent import Agent
+            from edsl.language_models.registry import Model
+            from edsl.results.Result import Result
+            from edsl.results.Results import Results
+            from edsl.scenarios.Scenario import Scenario
+            from edsl.surveys.Survey import Survey
+
             self._output("Remote inference activated. Sending job to server...")
             if remote_cache:
                 self._output(
@@ -464,8 +480,25 @@ class Jobs(Base):
                 status="queued",
             )
             self._output("Job sent!")
-
-
+            # Create mock results object to store job data
+            results = Results(
+                survey=Survey(),
+                data=[
+                    Result(
+                        agent=Agent.example(),
+                        scenario=Scenario.example(),
+                        model=Model(),
+                        iteration=1,
+                        answer={"info": "Remote job details"},
+                    )
+                ],
+            )
+            results.add_columns_from_dict([remote_job_data])
+            if self.verbose:
+                results.select(["info", "uuid", "status", "version"]).print(
+                    format="rich"
+                )
+            return results
         else:
             if check_api_keys:
                 for model in self.models + [Model()]:
@@ -477,8 +510,12 @@ class Jobs(Base):
 
         # handle cache
         if cache is None:
+            from edsl.data.CacheHandler import CacheHandler
+
             cache = CacheHandler().get_cache()
         if cache is False:
+            from edsl.data.Cache import Cache
+
             cache = Cache()
 
         if not remote_cache:
@@ -630,6 +667,11 @@ class Jobs(Base):
     @remove_edsl_version
     def from_dict(cls, data: dict) -> Jobs:
         """Creates a Jobs instance from a dictionary."""
+        from edsl import Survey
+        from edsl.agents.Agent import Agent
+        from edsl.language_models.LanguageModel import LanguageModel
+        from edsl.scenarios.Scenario import Scenario
+
         return cls(
             survey=Survey.from_dict(data["survey"]),
             agents=[Agent.from_dict(agent) for agent in data["agents"]],
@@ -645,7 +687,9 @@ class Jobs(Base):
     # Example methods
     #######################
     @classmethod
-    def example(cls, throw_exception_probability: int = 0) -> Jobs:
+    def example(
+        cls, throw_exception_probability: int = 0, randomize: bool = False
+    ) -> Jobs:
         """Return an example Jobs instance.
 
         :param throw_exception_probability: the probability that an exception will be thrown when answering a question. This is useful for testing error handling.
@@ -655,8 +699,12 @@ class Jobs(Base):
 
         """
         import random
+        from uuid import uuid4
         from edsl.questions import QuestionMultipleChoice
-        from edsl import Agent
+        from edsl.agents.Agent import Agent
+        from edsl.scenarios.Scenario import Scenario
+
+        addition = "" if not randomize else str(uuid4())
 
         # (status, question, period)
         agent_answers = {
@@ -695,10 +743,15 @@ class Jobs(Base):
             question_options=["Good", "Great", "OK", "Terrible"],
             question_name="how_feeling_yesterday",
         )
+        from edsl import Survey, ScenarioList
+
         base_survey = Survey(questions=[q1, q2])
 
         scenario_list = ScenarioList(
-            [Scenario({"period": "morning"}), Scenario({"period": "afternoon"})]
+            [
+                Scenario({"period": f"morning{addition}"}),
+                Scenario({"period": "afternoon"}),
+            ]
         )
         job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
 
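Most of the churn in Jobs.py follows one pattern: module-level imports are pushed down into the methods that use them, so the import runs at call time rather than at package load time. This is the standard way to break circular imports between sibling modules. A minimal sketch of the pattern (the mypkg/Heavy names are illustrative, not edsl's):

# before: importing at module scope runs at load time and can complete an import cycle
# from mypkg.heavy import Heavy

def build():
    # after: deferred import; runs only when build() is first called,
    # by which point both modules are fully initialized
    from mypkg.heavy import Heavy  # hypothetical module and class
    return Heavy()

The trade-off is a small per-call cost (a sys.modules lookup after the first import) in exchange for faster package import and no load-order constraints between modules.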
edsl/jobs/buckets/ModelBuckets.py
CHANGED
@@ -1,4 +1,4 @@
-from edsl.jobs.buckets.TokenBucket import TokenBucket
+# from edsl.jobs.buckets.TokenBucket import TokenBucket
 
 
 class ModelBuckets:
@@ -8,7 +8,7 @@ class ModelBuckets:
     A request is one call to the service. The number of tokens required for a request depends on parameters.
     """
 
-    def __init__(self, requests_bucket: TokenBucket, tokens_bucket: TokenBucket):
+    def __init__(self, requests_bucket: "TokenBucket", tokens_bucket: "TokenBucket"):
         """Initialize the model buckets.
 
         The requests bucket captures requests per unit of time.
@@ -25,9 +25,21 @@ class ModelBuckets:
             tokens_bucket=self.tokens_bucket + other.tokens_bucket,
         )
 
+    def turbo_mode_on(self):
+        """Set the refill rate to infinity for both buckets."""
+        self.requests_bucket.turbo_mode_on()
+        self.tokens_bucket.turbo_mode_on()
+
+    def turbo_mode_off(self):
+        """Restore the refill rate to its original value for both buckets."""
+        self.requests_bucket.turbo_mode_off()
+        self.tokens_bucket.turbo_mode_off()
+
     @classmethod
     def infinity_bucket(cls, model_name: str = "not_specified") -> "ModelBuckets":
         """Create a bucket with infinite capacity and refill rate."""
+        from edsl.jobs.buckets.TokenBucket import TokenBucket
+
         return cls(
             requests_bucket=TokenBucket(
                 bucket_name=model_name,
edsl/jobs/buckets/TokenBucket.py
CHANGED
@@ -1,8 +1,6 @@
 from typing import Union, List, Any
 import asyncio
 import time
-from collections import UserDict
-from matplotlib import pyplot as plt
 
 
 class TokenBucket:
@@ -19,11 +17,29 @@ class TokenBucket:
         self.bucket_name = bucket_name
         self.bucket_type = bucket_type
         self.capacity = capacity  # Maximum number of tokens
+        self._old_capacity = capacity
         self.tokens = capacity  # Current number of available tokens
         self.refill_rate = refill_rate  # Rate at which tokens are refilled
+        self._old_refill_rate = refill_rate
         self.last_refill = time.monotonic()  # Last refill time
-
         self.log: List[Any] = []
+        self.turbo_mode = False
+
+    def turbo_mode_on(self):
+        """Set the refill rate to infinity."""
+        if self.turbo_mode:
+            pass
+        else:
+            # pass
+            self.turbo_mode = True
+            self.capacity = float("inf")
+            self.refill_rate = float("inf")
+
+    def turbo_mode_off(self):
+        """Restore the refill rate to its original value."""
+        self.turbo_mode = False
+        self.capacity = self._old_capacity
+        self.refill_rate = self._old_refill_rate
 
     def __add__(self, other) -> "TokenBucket":
         """Combine two token buckets.
@@ -57,7 +73,17 @@ class TokenBucket:
         self.log.append((time.monotonic(), self.tokens))
 
     def refill(self) -> None:
-        """Refill the bucket with new tokens based on elapsed time."""
+        """Refill the bucket with new tokens based on elapsed time.
+
+
+
+        >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)
+        >>> bucket.tokens = 0
+        >>> bucket.refill()
+        >>> bucket.tokens > 0
+        True
+
+        """
         now = time.monotonic()
         elapsed = now - self.last_refill
         refill_amount = elapsed * self.refill_rate
@@ -100,7 +126,7 @@ class TokenBucket:
             raise ValueError(msg)
         while self.tokens < amount:
             self.refill()
-            await asyncio.sleep(0.
+            await asyncio.sleep(0.01)  # Sleep briefly to prevent busy waiting
         self.tokens -= amount
 
         now = time.monotonic()
@@ -114,6 +140,7 @@ class TokenBucket:
         times, tokens = zip(*self.get_log())
         start_time = times[0]
         times = [t - start_time for t in times]  # Normalize time to start from 0
+        from matplotlib import pyplot as plt
 
         plt.figure(figsize=(10, 6))
         plt.plot(times, tokens, label="Tokens Available")
edsl/jobs/interviews/Interview.py
CHANGED
@@ -6,15 +6,9 @@ import asyncio
 import time
 from typing import Any, Type, List, Generator, Optional
 
-from edsl.agents import Agent
-from edsl.language_models import LanguageModel
-from edsl.scenarios import Scenario
-from edsl.surveys import Survey
-
 from edsl.jobs.Answers import Answers
 from edsl.surveys.base import EndOfSurvey
 from edsl.jobs.buckets.ModelBuckets import ModelBuckets
-
 from edsl.jobs.tasks.TaskCreators import TaskCreators
 
 from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
@@ -26,6 +20,12 @@ from edsl.jobs.interviews.retry_management import retry_strategy
 from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
 from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
 
+import asyncio
+
+
+def run_async(coro):
+    return asyncio.run(coro)
+
 
 class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
     """
@@ -36,14 +36,14 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
 
     def __init__(
         self,
-        agent: Agent,
-        survey: Survey,
-        scenario: Scenario,
-        model: Type[LanguageModel],
-        debug: bool = False,
+        agent: "Agent",
+        survey: "Survey",
+        scenario: "Scenario",
+        model: Type["LanguageModel"],
+        debug: Optional[bool] = False,
         iteration: int = 0,
-        cache: "Cache" = None,
-        sidecar_model: LanguageModel = None,
+        cache: Optional["Cache"] = None,
+        sidecar_model: Optional["LanguageModel"] = None,
     ):
         """Initialize the Interview instance.
 
@@ -51,6 +51,24 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         :param survey: the survey being administered to the agent.
         :param scenario: the scenario that populates the survey questions.
         :param model: the language model used to answer the questions.
+        :param debug: if True, run without calls to the language model.
+        :param iteration: the iteration number of the interview.
+        :param cache: the cache used to store the answers.
+        :param sidecar_model: a sidecar model used to answer questions.
+
+        >>> i = Interview.example()
+        >>> i.task_creators
+        {}
+
+        >>> i.exceptions
+        {}
+
+        >>> _ = asyncio.run(i.async_conduct_interview())
+        >>> i.task_status_logs['q0']
+        [{'log_time': ..., 'value': <TaskStatus.NOT_STARTED: 1>}, {'log_time': ..., 'value': <TaskStatus.WAITING_FOR_DEPENDENCIES: 2>}, {'log_time': ..., 'value': <TaskStatus.API_CALL_IN_PROGRESS: 7>}, {'log_time': ..., 'value': <TaskStatus.SUCCESS: 8>}]
+
+        >>> i.to_index
+        {'q0': 0, 'q1': 1, 'q2': 2}
 
         """
         self.agent = agent
@@ -70,7 +88,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         self.exceptions = InterviewExceptionCollection()
         self._task_status_log_dict = InterviewStatusLog()
 
-        # dictionary mapping question names to their index in the survey.
+        # dictionary mapping question names to their index in the survey.
         self.to_index = {
             question_name: index
             for index, question_name in enumerate(self.survey.question_names)
@@ -82,14 +100,16 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         model_buckets: ModelBuckets = None,
         debug: bool = False,
         stop_on_exception: bool = False,
-        sidecar_model: Optional[LanguageModel] = None,
+        sidecar_model: Optional["LanguageModel"] = None,
     ) -> tuple["Answers", List[dict[str, Any]]]:
         """
         Conduct an Interview asynchronously.
+        It returns a tuple with the answers and a list of valid results.
 
         :param model_buckets: a dictionary of token buckets for the model.
         :param debug: run without calls to LLM.
         :param stop_on_exception: if True, stops the interview if an exception is raised.
+        :param sidecar_model: a sidecar model used to answer questions.
 
         Example usage:
 
@@ -97,16 +117,37 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         >>> result, _ = asyncio.run(i.async_conduct_interview())
         >>> result['q0']
         'yes'
+
+        >>> i = Interview.example(throw_exception = True)
+        >>> result, _ = asyncio.run(i.async_conduct_interview())
+        Attempt 1 failed with exception:This is a test error now waiting 1.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+        Attempt 2 failed with exception:This is a test error now waiting 2.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+        Attempt 3 failed with exception:This is a test error now waiting 4.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+        Attempt 4 failed with exception:This is a test error now waiting 8.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+
+        >>> i.exceptions
+        {'q0': [{'exception': "Exception('This is a test error')", 'time': ..., 'traceback': ...
+
+        >>> i = Interview.example()
+        >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
+        Traceback (most recent call last):
+        ...
+        asyncio.exceptions.CancelledError
         """
         self.sidecar_model = sidecar_model
 
         # if no model bucket is passed, create an 'infinity' bucket with no rate limits
-        # print("model_buckets", model_buckets)
         if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
             model_buckets = ModelBuckets.infinity_bucket()
 
-        # model_buckets = ModelBuckets.infinity_bucket()
-
         ## build the tasks using the InterviewTaskBuildingMixin
         ## This is the key part---it creates a task for each question,
         ## with dependencies on the questions that must be answered before this one can be answered.
@@ -128,6 +169,14 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         It iterates through the tasks and invigilators, and yields the results of the tasks that are done.
         If a task is not done, it raises a ValueError.
         If an exception is raised in the task, it records the exception in the Interview instance except if the task was cancelled, which is expected behavior.
+
+        >>> i = Interview.example()
+        >>> result, _ = asyncio.run(i.async_conduct_interview())
+        >>> results = list(i._extract_valid_results())
+        >>> len(results) == len(i.survey)
+        True
+        >>> type(results[0])
+        <class 'edsl.data_transfer_models.AgentResponseDict'>
         """
         assert len(self.tasks) == len(self.invigilators)
 
@@ -145,7 +194,18 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
             yield result
 
     def _record_exception(self, task, exception: Exception) -> None:
-        """Record an exception in the Interview instance."""
+        """Record an exception in the Interview instance.
+
+        It records the exception in the Interview instance, with the task name and the exception entry.
+
+        >>> i = Interview.example()
+        >>> result, _ = asyncio.run(i.async_conduct_interview())
+        >>> i.exceptions
+        {}
+        >>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
+        >>> i.exceptions
+        {'q0': [{'exception': "Exception('An exception occurred.')", 'time': ..., 'traceback': 'NoneType: None\\n'}]}
+        """
         exception_entry = InterviewExceptionEntry(
             exception=repr(exception),
             time=time.time(),
@@ -161,6 +221,10 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         It is used to determine the order in which questions should be answered.
         This reflects both agent 'memory' considerations and 'skip' logic.
         The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
+
+        >>> i = Interview.example()
+        >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
+        True
         """
         return self.survey.dag(textify=True)
 
@@ -171,8 +235,15 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         """Return a string representation of the Interview instance."""
         return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
 
-    def duplicate(self, iteration: int, cache: Cache) -> Interview:
-        """Duplicate the interview, but with a new iteration number and cache."""
+    def duplicate(self, iteration: int, cache: "Cache") -> Interview:
+        """Duplicate the interview, but with a new iteration number and cache.
+
+        >>> i = Interview.example()
+        >>> i2 = i.duplicate(1, None)
+        >>> i.iteration + 1 == i2.iteration
+        True
+
+        """
         return Interview(
             agent=self.agent,
             survey=self.survey,
@@ -183,7 +254,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         )
 
     @classmethod
-    def example(self):
+    def example(self, throw_exception: bool = False) -> Interview:
         """Return an example Interview instance."""
         from edsl.agents import Agent
         from edsl.surveys import Survey
@@ -198,66 +269,15 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         survey = Survey.example()
         scenario = Scenario.example()
         model = LanguageModel.example()
+        if throw_exception:
+            model = LanguageModel.example(test_model=True, throw_exception=True)
+            agent = Agent.example()
+            return Interview(agent=agent, survey=survey, scenario=scenario, model=model)
         return Interview(agent=agent, survey=survey, scenario=scenario, model=model)
 
 
 if __name__ == "__main__":
     import doctest
 
-
-
-    # from edsl.agents import Agent
-    # from edsl.surveys import Survey
-    # from edsl.scenarios import Scenario
-    # from edsl.questions import QuestionMultipleChoice
-
-    # # from edsl.jobs.Interview import Interview
-
-    # # a survey with skip logic
-    # q0 = QuestionMultipleChoice(
-    #     question_text="Do you like school?",
-    #     question_options=["yes", "no"],
-    #     question_name="q0",
-    # )
-    # q1 = QuestionMultipleChoice(
-    #     question_text="Why not?",
-    #     question_options=["killer bees in cafeteria", "other"],
-    #     question_name="q1",
-    # )
-    # q2 = QuestionMultipleChoice(
-    #     question_text="Why?",
-    #     question_options=["**lack*** of killer bees in cafeteria", "other"],
-    #     question_name="q2",
-    # )
-    # s = Survey(questions=[q0, q1, q2])
-    # s = s.add_rule(q0, "q0 == 'yes'", q2)
-
-    # # create an interview
-    # a = Agent(traits=None)
-
-    # def direct_question_answering_method(self, question, scenario):
-    #     """Answer a question directly."""
-    #     raise Exception("Error!")
-    #     # return "yes"
-
-    # a.add_direct_question_answering_method(direct_question_answering_method)
-    # scenario = Scenario()
-    # m = Model()
-    # I = Interview(agent=a, survey=s, scenario=scenario, model=m)
-
-    # result = asyncio.run(I.async_conduct_interview())
-    # # # conduct five interviews
-    # # for _ in range(5):
-    # #     I.conduct_interview(debug=True)
-
-    # # # replace missing answers
-    # # I
-    # # repr(I)
-    # # eval(repr(I))
-    # # print(I.task_status_logs.status_matrix(20))
-    # status_matrix = I.task_status_logs.status_matrix(20)
-    # numerical_matrix = I.task_status_logs.numerical_matrix(20)
-    # I.task_status_logs.visualize()
-
-    # I.exceptions.print()
-    # I.exceptions.ascii_table()
+    # add ellipsis
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
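The doctests added throughout Interview.py rely on the ELLIPSIS option enabled by the new testmod call: '...' in an expected output matches any substring, which is what lets nondeterministic fields such as log_time and tracebacks pass. A standalone illustration of the mechanism (stdlib only, not edsl code):

import doctest
import time


def timestamped():
    """
    >>> timestamped()
    {'log_time': ..., 'value': 'ok'}
    """
    return {"log_time": time.time(), "value": "ok"}


if __name__ == "__main__":
    # without doctest.ELLIPSIS, the float timestamp would fail the comparison
    doctest.testmod(optionflags=doctest.ELLIPSIS)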
|