edsl 0.1.30.dev4__py3-none-any.whl → 0.1.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +7 -2
- edsl/agents/PromptConstructionMixin.py +18 -1
- edsl/config.py +4 -0
- edsl/conjure/Conjure.py +6 -0
- edsl/coop/coop.py +4 -0
- edsl/coop/utils.py +9 -1
- edsl/data/CacheHandler.py +3 -4
- edsl/enums.py +2 -0
- edsl/inference_services/DeepInfraService.py +6 -91
- edsl/inference_services/GroqService.py +18 -0
- edsl/inference_services/InferenceServicesCollection.py +13 -5
- edsl/inference_services/OpenAIService.py +64 -21
- edsl/inference_services/registry.py +2 -1
- edsl/jobs/Jobs.py +80 -33
- edsl/jobs/buckets/TokenBucket.py +24 -5
- edsl/jobs/interviews/Interview.py +122 -75
- edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +58 -52
- edsl/jobs/interviews/interview_exception_tracking.py +68 -10
- edsl/jobs/runners/JobsRunnerAsyncio.py +112 -81
- edsl/jobs/runners/JobsRunnerStatusData.py +0 -237
- edsl/jobs/runners/JobsRunnerStatusMixin.py +291 -35
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -5
- edsl/jobs/tasks/TaskCreators.py +8 -2
- edsl/jobs/tasks/TaskHistory.py +145 -1
- edsl/language_models/LanguageModel.py +135 -75
- edsl/language_models/ModelList.py +8 -2
- edsl/language_models/registry.py +16 -0
- edsl/questions/QuestionFunctional.py +34 -2
- edsl/questions/QuestionMultipleChoice.py +58 -8
- edsl/questions/QuestionNumerical.py +0 -1
- edsl/questions/descriptors.py +42 -2
- edsl/results/DatasetExportMixin.py +258 -75
- edsl/results/Result.py +53 -5
- edsl/results/Results.py +66 -27
- edsl/results/ResultsToolsMixin.py +1 -1
- edsl/scenarios/Scenario.py +14 -0
- edsl/scenarios/ScenarioList.py +59 -21
- edsl/scenarios/ScenarioListExportMixin.py +16 -5
- edsl/scenarios/ScenarioListPdfMixin.py +3 -0
- edsl/study/Study.py +2 -2
- edsl/surveys/Survey.py +35 -1
- {edsl-0.1.30.dev4.dist-info → edsl-0.1.31.dist-info}/METADATA +4 -2
- {edsl-0.1.30.dev4.dist-info → edsl-0.1.31.dist-info}/RECORD +47 -45
- {edsl-0.1.30.dev4.dist-info → edsl-0.1.31.dist-info}/WHEEL +1 -1
- {edsl-0.1.30.dev4.dist-info → edsl-0.1.31.dist-info}/LICENSE +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -3,9 +3,7 @@ from __future__ import annotations
 import warnings
 from itertools import product
 from typing import Optional, Union, Sequence, Generator
-
 from edsl.Base import Base
-
 from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.jobs.interviews.Interview import Interview
@@ -321,7 +319,11 @@ class Jobs(Base):
         self.scenarios = self.scenarios or [Scenario()]
         for agent, scenario, model in product(self.agents, self.scenarios, self.models):
             yield Interview(
-                survey=self.survey, agent=agent, scenario=scenario, model=model
+                survey=self.survey,
+                agent=agent,
+                scenario=scenario,
+                model=model,
+                skip_retry=self.skip_retry,
             )

     def create_bucket_collection(self) -> BucketCollection:
@@ -411,6 +413,12 @@ class Jobs(Base):
         if warn:
             warnings.warn(message)

+    @property
+    def skip_retry(self):
+        if not hasattr(self, "_skip_retry"):
+            return False
+        return self._skip_retry
+
     def run(
         self,
         n: int = 1,
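The new skip_retry property guards with hasattr because _skip_retry is only assigned inside run(); a Jobs object that was deserialized or has not been run yet still answers False. A minimal standalone sketch of the pattern (illustrative class name, not edsl code):

class JobsLike:
    @property
    def skip_retry(self):
        # _skip_retry is set elsewhere (in run()); default to False until then.
        if not hasattr(self, "_skip_retry"):
            return False
        return self._skip_retry

job = JobsLike()
assert job.skip_retry is False  # attribute never set
job._skip_retry = True
assert job.skip_retry is True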
@@ -425,6 +433,7 @@ class Jobs(Base):
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
         remote_inference_description: Optional[str] = None,
+        skip_retry: bool = False,
     ) -> Results:
         """
         Runs the Job: conducts Interviews and returns their results.
@@ -443,6 +452,7 @@ class Jobs(Base):
         from edsl.coop.coop import Coop

         self._check_parameters()
+        self._skip_retry = skip_retry

         if batch_mode is not None:
             raise NotImplementedError(
@@ -461,12 +471,11 @@ class Jobs(Base):
             remote_inference = False

         if remote_inference:
-
-            from
-            from edsl.
-
-
-            from edsl.surveys.Survey import Survey
+            import time
+            from datetime import datetime
+            from edsl.config import CONFIG
+
+            expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")

             self._output("Remote inference activated. Sending job to server...")
             if remote_cache:
@@ -474,33 +483,60 @@ class Jobs(Base):
                     "Remote caching activated. The remote cache will be used for this job."
                 )

-
+            remote_job_creation_data = coop.remote_inference_create(
                 self,
                 description=remote_inference_description,
                 status="queued",
+                iterations=n,
             )
-
-
-
-
-
-
-
-
-
-
-
+            time_queued = datetime.now().strftime("%m/%d/%Y %I:%M:%S %p")
+            job_uuid = remote_job_creation_data.get("uuid")
+            print(f"Remote inference started (Job uuid={job_uuid}).")
+            # print(f"Job queued at {time_queued}.")
+            job_in_queue = True
+            while job_in_queue:
+                remote_job_data = coop.remote_inference_get(job_uuid)
+                status = remote_job_data.get("status")
+                if status == "cancelled":
+                    print("\r" + " " * 80 + "\r", end="")
+                    print("Job cancelled by the user.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
                     )
-
-
-
-
-
-
-
+                    return None
+                elif status == "failed":
+                    print("\r" + " " * 80 + "\r", end="")
+                    print("Job failed.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
+                    )
+                    return None
+                elif status == "completed":
+                    results_uuid = remote_job_data.get("results_uuid")
+                    results = coop.get(results_uuid, expected_object_type="results")
+                    print("\r" + " " * 80 + "\r", end="")
+                    print(
+                        f"Job completed and Results stored on Coop (Results uuid={results_uuid})."
+                    )
+                    return results
+                else:
+                    duration = 10 if len(self) < 10 else 60
+                    time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
+                    frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
+                    start_time = time.time()
+                    i = 0
+                    while time.time() - start_time < duration:
+                        print(
+                            f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+                            end="",
+                            flush=True,
+                        )
+                        time.sleep(0.1)
+                        i += 1
         else:
             if check_api_keys:
+                from edsl import Model
+
                 for model in self.models + [Model()]:
                     if not model.has_valid_api_key():
                         raise MissingAPIKeyError(
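The rewritten remote-inference branch replaces the old fire-and-forget behavior with a blocking poll loop: it creates the job, then repeatedly fetches its status, returning Results on completion, None on cancellation or failure, and otherwise animating a spinner for 10 seconds (60 for jobs of ten or more interviews) before the next check. A condensed, self-contained sketch of the same poll-and-spin pattern; fetch_status is a stand-in for coop.remote_inference_get:

import time

def poll_with_spinner(fetch_status, wait_seconds: float = 1.0) -> str:
    """Poll fetch_status() until a terminal status, spinning between polls."""
    frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
    while True:
        status = fetch_status()
        if status in ("cancelled", "failed", "completed"):
            print("\r" + " " * 80 + "\r", end="")  # wipe the spinner line
            return status
        start, i = time.time(), 0
        while time.time() - start < wait_seconds:
            print(f"\r{frames[i % len(frames)]} Job status: {status}", end="", flush=True)
            time.sleep(0.1)
            i += 1

# Toy usage: reaches a terminal status after three polls.
statuses = iter(["queued", "running", "completed"])
print(poll_with_spinner(lambda: next(statuses), wait_seconds=0.3))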
@@ -606,9 +642,9 @@ class Jobs(Base):
         results = JobsRunnerAsyncio(self).run(*args, **kwargs)
         return results

-    async def run_async(self, cache=None, **kwargs):
+    async def run_async(self, cache=None, n=1, **kwargs):
         """Run the job asynchronously."""
-        results = await JobsRunnerAsyncio(self).run_async(cache=cache, **kwargs)
+        results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
         return results

     def all_question_parameters(self):
@@ -688,7 +724,10 @@ class Jobs(Base):
     #######################
     @classmethod
     def example(
-        cls,
+        cls,
+        throw_exception_probability: int = 0,
+        randomize: bool = False,
+        test_model=False,
     ) -> Jobs:
         """Return an example Jobs instance.
@@ -706,6 +745,11 @@ class Jobs(Base):

         addition = "" if not randomize else str(uuid4())

+        if test_model:
+            from edsl.language_models import LanguageModel
+
+            m = LanguageModel.example(test_model=True)
+
         # (status, question, period)
         agent_answers = {
             ("Joyful", "how_feeling", "morning"): "OK",
@@ -753,7 +797,10 @@ class Jobs(Base):
                 Scenario({"period": "afternoon"}),
             ]
         )
-        job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
+        if test_model:
+            job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
+        else:
+            job = base_survey.by(scenario_list).by(joy_agent, sad_agent)

         return job
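With test_model=True, example() now wires in the canned LanguageModel test model, so the example job can run end to end without live API calls. A hedged usage sketch against the 0.1.31 signatures shown above (not verified against the release):

from edsl.jobs.Jobs import Jobs

# Build the example job backed by the test model, then run it with
# retries disabled via the new skip_retry flag added to run().
job = Jobs.example(test_model=True)
results = job.run(skip_retry=True)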
edsl/jobs/buckets/TokenBucket.py
CHANGED
@@ -30,6 +30,7 @@ class TokenBucket:
         if self.turbo_mode:
             pass
         else:
+            # pass
             self.turbo_mode = True
             self.capacity = float("inf")
             self.refill_rate = float("inf")
@@ -72,7 +73,17 @@ class TokenBucket:
         self.log.append((time.monotonic(), self.tokens))

     def refill(self) -> None:
-        """Refill the bucket with new tokens based on elapsed time.
+        """Refill the bucket with new tokens based on elapsed time.
+
+
+
+        >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)
+        >>> bucket.tokens = 0
+        >>> bucket.refill()
+        >>> bucket.tokens > 0
+        True
+
+        """
         now = time.monotonic()
         elapsed = now - self.last_refill
         refill_amount = elapsed * self.refill_rate
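The refill() doctest relies on simple elapsed-time arithmetic: tokens accrue at refill_rate per second since the last refill, capped at capacity. The same logic in a minimal standalone form (illustrative MiniBucket, not the edsl class):

import time

class MiniBucket:
    def __init__(self, capacity: float, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = capacity
        self.last_refill = time.monotonic()

    def refill(self) -> None:
        now = time.monotonic()
        # Accrue elapsed * rate tokens, never exceeding capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.last_refill) * self.refill_rate)
        self.last_refill = now

b = MiniBucket(capacity=10, refill_rate=1)
b.tokens = 0
time.sleep(0.05)
b.refill()
assert 0 < b.tokens <= 10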
@@ -89,7 +100,9 @@ class TokenBucket:
         available_tokens = min(self.capacity, self.tokens + refill_amount)
         return max(0, requested_tokens - available_tokens) / self.refill_rate

-    async def get_tokens(self, amount: Union[int, float] = 1) -> None:
+    async def get_tokens(
+        self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
+    ) -> None:
         """Wait for the specified number of tokens to become available.

@@ -105,14 +118,20 @@ class TokenBucket:
         True

         >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)
-        >>> asyncio.run(bucket.get_tokens(11))
+        >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=False))
         Traceback (most recent call last):
         ...
         ValueError: Requested amount exceeds bucket capacity. Bucket capacity: 10, requested amount: 11. As the bucket never overflows, the requested amount will never be available.
+        >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))
         """
         if amount > self.capacity:
-            msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
-            raise ValueError(msg)
+            if not cheat_bucket_capacity:
+                msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
+                raise ValueError(msg)
+            else:
+                self.tokens = 0  # clear the bucket but let it go through
+                return
+
         while self.tokens < amount:
             self.refill()
             await asyncio.sleep(0.01)  # Sleep briefly to prevent busy waiting
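The new cheat_bucket_capacity flag (default True) turns an over-capacity request from a guaranteed-to-hang error case into a drain-and-proceed: the bucket is emptied and the caller continues immediately. Exercising both branches with the signatures from this diff:

import asyncio
from edsl.jobs.buckets.TokenBucket import TokenBucket

bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)

# Default behavior: an impossible 11-token request clears the bucket and returns.
asyncio.run(bucket.get_tokens(11))

# Strict behavior: the same request raises, since 11 tokens can never accumulate.
try:
    asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=False))
except ValueError as err:
    print(err)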
edsl/jobs/interviews/Interview.py
CHANGED
@@ -14,12 +14,18 @@ from edsl.jobs.tasks.TaskCreators import TaskCreators
 from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
 from edsl.jobs.interviews.interview_exception_tracking import (
     InterviewExceptionCollection,
-    InterviewExceptionEntry,
 )
+from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
 from edsl.jobs.interviews.retry_management import retry_strategy
 from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
 from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin

+import asyncio
+
+
+def run_async(coro):
+    return asyncio.run(coro)
+

 class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
     """
@@ -36,8 +42,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         model: Type["LanguageModel"],
         debug: Optional[bool] = False,
         iteration: int = 0,
-        cache: "Cache" = None,
-        sidecar_model: "LanguageModel" = None,
+        cache: Optional["Cache"] = None,
+        sidecar_model: Optional["LanguageModel"] = None,
+        skip_retry=False,
     ):
         """Initialize the Interview instance.
@@ -45,6 +52,24 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         :param survey: the survey being administered to the agent.
         :param scenario: the scenario that populates the survey questions.
         :param model: the language model used to answer the questions.
+        :param debug: if True, run without calls to the language model.
+        :param iteration: the iteration number of the interview.
+        :param cache: the cache used to store the answers.
+        :param sidecar_model: a sidecar model used to answer questions.
+
+        >>> i = Interview.example()
+        >>> i.task_creators
+        {}
+
+        >>> i.exceptions
+        {}
+
+        >>> _ = asyncio.run(i.async_conduct_interview())
+        >>> i.task_status_logs['q0']
+        [{'log_time': ..., 'value': <TaskStatus.NOT_STARTED: 1>}, {'log_time': ..., 'value': <TaskStatus.WAITING_FOR_DEPENDENCIES: 2>}, {'log_time': ..., 'value': <TaskStatus.API_CALL_IN_PROGRESS: 7>}, {'log_time': ..., 'value': <TaskStatus.SUCCESS: 8>}]
+
+        >>> i.to_index
+        {'q0': 0, 'q1': 1, 'q2': 2}

         """
         self.agent = agent
@@ -63,27 +88,54 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         self.task_creators = TaskCreators()  # tracks the task creators
         self.exceptions = InterviewExceptionCollection()
         self._task_status_log_dict = InterviewStatusLog()
+        self.skip_retry = skip_retry

-        # dictionary mapping question names to their index in the survey.
+        # dictionary mapping question names to their index in the survey.
         self.to_index = {
             question_name: index
             for index, question_name in enumerate(self.survey.question_names)
         }

+    def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
+        """Return a dictionary representation of the Interview instance.
+        This is just for hashing purposes.
+
+        >>> i = Interview.example()
+        >>> hash(i)
+        1646262796627658719
+        """
+        d = {
+            "agent": self.agent._to_dict(),
+            "survey": self.survey._to_dict(),
+            "scenario": self.scenario._to_dict(),
+            "model": self.model._to_dict(),
+            "iteration": self.iteration,
+            "exceptions": {},
+        }
+        if include_exceptions:
+            d["exceptions"] = self.exceptions.to_dict()
+        return d
+
+    def __hash__(self) -> int:
+        from edsl.utilities.utilities import dict_hash
+
+        return dict_hash(self._to_dict())
+
     async def async_conduct_interview(
         self,
         *,
         model_buckets: ModelBuckets = None,
         debug: bool = False,
         stop_on_exception: bool = False,
-        sidecar_model: Optional[LanguageModel] = None,
+        sidecar_model: Optional["LanguageModel"] = None,
     ) -> tuple["Answers", List[dict[str, Any]]]:
         """
         Conduct an Interview asynchronously.
+        It returns a tuple with the answers and a list of valid results.

         :param model_buckets: a dictionary of token buckets for the model.
         :param debug: run without calls to LLM.
         :param stop_on_exception: if True, stops the interview if an exception is raised.
+        :param sidecar_model: a sidecar model used to answer questions.

         Example usage:
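The new _to_dict/__hash__ pair gives interviews content-based identity: the hash is computed from the serialized agent, survey, scenario, model, and iteration (exceptions excluded by default), so identically constructed interviews hash alike. Per the doctest above:

from edsl.jobs.interviews.Interview import Interview

i1 = Interview.example()
i2 = Interview.example()
# Distinct objects, same components, same content hash.
assert i1 is not i2
assert hash(i1) == hash(i2)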
@@ -91,17 +143,36 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         >>> result, _ = asyncio.run(i.async_conduct_interview())
         >>> result['q0']
         'yes'
+
+        >>> i = Interview.example(throw_exception = True)
+        >>> result, _ = asyncio.run(i.async_conduct_interview())
+        Attempt 1 failed with exception:This is a test error now waiting 1.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+        Attempt 2 failed with exception:This is a test error now waiting 2.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+        Attempt 3 failed with exception:This is a test error now waiting 4.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+        Attempt 4 failed with exception:This is a test error now waiting 8.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
+        <BLANKLINE>
+        <BLANKLINE>
+
+        >>> i.exceptions
+        {'q0': ...
+        >>> i = Interview.example()
+        >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
+        Traceback (most recent call last):
+        ...
+        asyncio.exceptions.CancelledError
         """
         self.sidecar_model = sidecar_model

         # if no model bucket is passed, create an 'infinity' bucket with no rate limits
-        # print("model_buckets", model_buckets)
         if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
             model_buckets = ModelBuckets.infinity_bucket()

-        # FOR TESTING
-        # model_buckets = ModelBuckets.infinity_bucket()
-
         ## build the tasks using the InterviewTaskBuildingMixin
         ## This is the key part---it creates a task for each question,
         ## with dependencies on the questions that must be answered before this one can be answered.
@@ -123,6 +194,14 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         It iterates through the tasks and invigilators, and yields the results of the tasks that are done.
         If a task is not done, it raises a ValueError.
         If an exception is raised in the task, it records the exception in the Interview instance except if the task was cancelled, which is expected behavior.
+
+        >>> i = Interview.example()
+        >>> result, _ = asyncio.run(i.async_conduct_interview())
+        >>> results = list(i._extract_valid_results())
+        >>> len(results) == len(i.survey)
+        True
+        >>> type(results[0])
+        <class 'edsl.data_transfer_models.AgentResponseDict'>
         """
         assert len(self.tasks) == len(self.invigilators)
@@ -140,12 +219,19 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
             yield result

     def _record_exception(self, task, exception: Exception) -> None:
-        """Record an exception in the Interview instance.
-
-
-
-
-        )
+        """Record an exception in the Interview instance.
+
+        It records the exception in the Interview instance, with the task name and the exception entry.
+
+        >>> i = Interview.example()
+        >>> result, _ = asyncio.run(i.async_conduct_interview())
+        >>> i.exceptions
+        {}
+        >>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
+        >>> i.exceptions
+        {'q0': ...
+        """
+        exception_entry = InterviewExceptionEntry(exception)
         self.exceptions.add(task.get_name(), exception_entry)

     @property
@@ -156,6 +242,10 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         It is used to determine the order in which questions should be answered.
         This reflects both agent 'memory' considerations and 'skip' logic.
         The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
+
+        >>> i = Interview.example()
+        >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
+        True
         """
         return self.survey.dag(textify=True)
@@ -166,8 +256,15 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         """Return a string representation of the Interview instance."""
         return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"

-    def duplicate(self, iteration: int, cache: Cache) -> Interview:
-        """Duplicate the interview, but with a new iteration number and cache.
+    def duplicate(self, iteration: int, cache: "Cache") -> Interview:
+        """Duplicate the interview, but with a new iteration number and cache.
+
+        >>> i = Interview.example()
+        >>> i2 = i.duplicate(1, None)
+        >>> i.iteration + 1 == i2.iteration
+        True
+
+        """
         return Interview(
             agent=self.agent,
             survey=self.survey,
@@ -175,10 +272,11 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
             model=self.model,
             iteration=iteration,
             cache=cache,
+            skip_retry=self.skip_retry,
         )

     @classmethod
-    def example(self):
+    def example(self, throw_exception: bool = False) -> Interview:
         """Return an example Interview instance."""
         from edsl.agents import Agent
         from edsl.surveys import Survey
@@ -193,66 +291,15 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         survey = Survey.example()
         scenario = Scenario.example()
         model = LanguageModel.example()
+        if throw_exception:
+            model = LanguageModel.example(test_model=True, throw_exception=True)
+            agent = Agent.example()
+            return Interview(agent=agent, survey=survey, scenario=scenario, model=model)
         return Interview(agent=agent, survey=survey, scenario=scenario, model=model)


 if __name__ == "__main__":
     import doctest

-
-
-    # from edsl.agents import Agent
-    # from edsl.surveys import Survey
-    # from edsl.scenarios import Scenario
-    # from edsl.questions import QuestionMultipleChoice
-
-    # # from edsl.jobs.Interview import Interview
-
-    # # a survey with skip logic
-    # q0 = QuestionMultipleChoice(
-    #     question_text="Do you like school?",
-    #     question_options=["yes", "no"],
-    #     question_name="q0",
-    # )
-    # q1 = QuestionMultipleChoice(
-    #     question_text="Why not?",
-    #     question_options=["killer bees in cafeteria", "other"],
-    #     question_name="q1",
-    # )
-    # q2 = QuestionMultipleChoice(
-    #     question_text="Why?",
-    #     question_options=["**lack*** of killer bees in cafeteria", "other"],
-    #     question_name="q2",
-    # )
-    # s = Survey(questions=[q0, q1, q2])
-    # s = s.add_rule(q0, "q0 == 'yes'", q2)
-
-    # # create an interview
-    # a = Agent(traits=None)
-
-    # def direct_question_answering_method(self, question, scenario):
-    #     """Answer a question directly."""
-    #     raise Exception("Error!")
-    #     # return "yes"
-
-    # a.add_direct_question_answering_method(direct_question_answering_method)
-    # scenario = Scenario()
-    # m = Model()
-    # I = Interview(agent=a, survey=s, scenario=scenario, model=m)
-
-    # result = asyncio.run(I.async_conduct_interview())
-    # # # conduct five interviews
-    # # for _ in range(5):
-    # #     I.conduct_interview(debug=True)
-
-    # # # replace missing answers
-    # # I
-    # # repr(I)
-    # # eval(repr(I))
-    # # print(I.task_status_logs.status_matrix(20))
-    # status_matrix = I.task_status_logs.status_matrix(20)
-    # numerical_matrix = I.task_status_logs.numerical_matrix(20)
-    # I.task_status_logs.visualize()
-
-    # I.exceptions.print()
-    # I.exceptions.ascii_table()
+    # add ellipsis
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
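The __main__ block now passes doctest.ELLIPSIS so the '...' placeholders in the new doctests (truncated exception dicts, variable timestamps) match arbitrary output. A tiny self-contained demonstration of the flag:

import doctest
import time

def stamped() -> str:
    """
    >>> stamped()
    'stamped at ...'
    """
    return f"stamped at {time.time()}"

# Without ELLIPSIS this doctest fails; with it, '...' matches the timestamp.
doctest.testmod(optionflags=doctest.ELLIPSIS)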
edsl/jobs/interviews/InterviewExceptionEntry.py
ADDED
@@ -0,0 +1,101 @@
+import traceback
+import datetime
+import time
+from collections import UserDict
+
+# traceback=traceback.format_exc(),
+# traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
+# traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
+
+
+class InterviewExceptionEntry:
+    """Class to record an exception that occurred during the interview.
+
+    >>> entry = InterviewExceptionEntry.example()
+    >>> entry.to_dict()['exception']
+    "ValueError('An error occurred.')"
+    """
+
+    def __init__(self, exception: Exception, traceback_format="html"):
+        self.time = datetime.datetime.now().isoformat()
+        self.exception = exception
+        self.traceback_format = traceback_format
+
+    def __getitem__(self, key):
+        # Support dict-like access obj['a']
+        return str(getattr(self, key))
+
+    @classmethod
+    def example(cls):
+        try:
+            raise ValueError("An error occurred.")
+        except Exception as e:
+            entry = InterviewExceptionEntry(e)
+        return entry
+
+    @property
+    def traceback(self):
+        """Return the exception as HTML."""
+        if self.traceback_format == "html":
+            return self.html_traceback
+        else:
+            return self.text_traceback
+
+    @property
+    def text_traceback(self):
+        """
+        >>> entry = InterviewExceptionEntry.example()
+        >>> entry.text_traceback
+        'Traceback (most recent call last):...'
+        """
+        e = self.exception
+        tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+        return tb_str
+
+    @property
+    def html_traceback(self):
+        from rich.console import Console
+        from rich.table import Table
+        from rich.traceback import Traceback
+
+        from io import StringIO
+
+        html_output = StringIO()
+
+        console = Console(file=html_output, record=True)
+
+        tb = Traceback.from_exception(
+            type(self.exception),
+            self.exception,
+            self.exception.__traceback__,
+            show_locals=True,
+        )
+        console.print(tb)
+        return html_output.getvalue()
+
+    def to_dict(self) -> dict:
+        """Return the exception as a dictionary.
+
+        >>> entry = InterviewExceptionEntry.example()
+        >>> entry.to_dict()['exception']
+        "ValueError('An error occurred.')"
+
+        """
+        return {
+            "exception": repr(self.exception),
+            "time": self.time,
+            "traceback": self.traceback,
+        }
+
+    def push(self):
+        from edsl import Coop
+
+        coop = Coop()
+        results = coop.error_create(self.to_dict())
+        return results
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)