edsl 0.1.33.dev1__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -4
- edsl/agents/Agent.py +46 -14
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +125 -212
- edsl/agents/InvigilatorBase.py +140 -32
- edsl/agents/PromptConstructionMixin.py +43 -66
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +38 -39
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +39 -5
- edsl/data/Cache.py +35 -1
- edsl/data_transfer_models.py +120 -38
- edsl/enums.py +2 -0
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +19 -17
- edsl/inference_services/AzureAI.py +37 -20
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +2 -0
- edsl/inference_services/InferenceServiceABC.py +24 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OpenAIService.py +41 -50
- edsl/inference_services/TestService.py +71 -0
- edsl/inference_services/models_available_cache.py +0 -6
- edsl/inference_services/registry.py +4 -0
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +18 -13
- edsl/jobs/buckets/TokenBucket.py +39 -14
- edsl/jobs/interviews/Interview.py +297 -77
- edsl/jobs/interviews/InterviewExceptionEntry.py +83 -19
- edsl/jobs/interviews/interview_exception_tracking.py +0 -70
- edsl/jobs/interviews/retry_management.py +3 -1
- edsl/jobs/runners/JobsRunnerAsyncio.py +116 -70
- edsl/jobs/runners/JobsRunnerStatusMixin.py +1 -1
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskHistory.py +131 -213
- edsl/language_models/LanguageModel.py +239 -129
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +15 -2
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +273 -242
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +6 -0
- edsl/questions/QuestionCheckBox.py +227 -35
- edsl/questions/QuestionExtract.py +98 -27
- edsl/questions/QuestionFreeText.py +46 -29
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -22
- edsl/questions/QuestionMultipleChoice.py +173 -64
- edsl/questions/QuestionNumerical.py +87 -46
- edsl/questions/QuestionRank.py +182 -24
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +169 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +11 -1
- edsl/questions/derived/QuestionTopK.py +6 -0
- edsl/questions/derived/QuestionYesNo.py +16 -1
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +41 -47
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +131 -45
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/Scenario.py +10 -4
- edsl/scenarios/ScenarioList.py +348 -39
- edsl/scenarios/ScenarioListExportMixin.py +9 -0
- edsl/study/SnapShot.py +8 -1
- edsl/surveys/RuleCollection.py +2 -2
- edsl/surveys/Survey.py +634 -315
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +75 -4
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +111 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +4 -2
- edsl-0.1.33.dev2.dist-info/RECORD +289 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.33.dev1.dist-info/RECORD +0 -209
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -156,7 +156,11 @@ class Jobs(Base):
|
|
156
156
|
from edsl.results.Dataset import Dataset
|
157
157
|
|
158
158
|
for interview_index, interview in enumerate(interviews):
|
159
|
-
invigilators =
|
159
|
+
invigilators = [
|
160
|
+
interview._get_invigilator(question)
|
161
|
+
for question in self.survey.questions
|
162
|
+
]
|
163
|
+
# list(interview._build_invigilators(debug=False))
|
160
164
|
for _, invigilator in enumerate(invigilators):
|
161
165
|
prompts = invigilator.get_prompts()
|
162
166
|
user_prompts.append(prompts["user_prompt"])
|
@@ -344,6 +348,7 @@ class Jobs(Base):
|
|
344
348
|
scenario=scenario,
|
345
349
|
model=model,
|
346
350
|
skip_retry=self.skip_retry,
|
351
|
+
raise_validation_errors=self.raise_validation_errors,
|
347
352
|
)
|
348
353
|
|
349
354
|
def create_bucket_collection(self) -> BucketCollection:
|
@@ -461,27 +466,31 @@ class Jobs(Base):
|
|
461
466
|
return False
|
462
467
|
return self._skip_retry
|
463
468
|
|
469
|
+
@property
|
470
|
+
def raise_validation_errors(self):
|
471
|
+
if not hasattr(self, "_raise_validation_errors"):
|
472
|
+
return False
|
473
|
+
return self._raise_validation_errors
|
474
|
+
|
464
475
|
def run(
|
465
476
|
self,
|
466
477
|
n: int = 1,
|
467
|
-
debug: bool = False,
|
468
478
|
progress_bar: bool = False,
|
469
479
|
stop_on_exception: bool = False,
|
470
480
|
cache: Union[Cache, bool] = None,
|
471
481
|
check_api_keys: bool = False,
|
472
482
|
sidecar_model: Optional[LanguageModel] = None,
|
473
|
-
batch_mode: Optional[bool] = None,
|
474
483
|
verbose: bool = False,
|
475
484
|
print_exceptions=True,
|
476
485
|
remote_cache_description: Optional[str] = None,
|
477
486
|
remote_inference_description: Optional[str] = None,
|
478
487
|
skip_retry: bool = False,
|
488
|
+
raise_validation_errors: bool = False,
|
479
489
|
) -> Results:
|
480
490
|
"""
|
481
491
|
Runs the Job: conducts Interviews and returns their results.
|
482
492
|
|
483
493
|
:param n: how many times to run each interview
|
484
|
-
:param debug: prints debug messages
|
485
494
|
:param progress_bar: shows a progress bar
|
486
495
|
:param stop_on_exception: stops the job if an exception is raised
|
487
496
|
:param cache: a cache object to store results
|
@@ -495,11 +504,7 @@ class Jobs(Base):
|
|
495
504
|
|
496
505
|
self._check_parameters()
|
497
506
|
self._skip_retry = skip_retry
|
498
|
-
|
499
|
-
if batch_mode is not None:
|
500
|
-
raise NotImplementedError(
|
501
|
-
"Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
|
502
|
-
)
|
507
|
+
self._raise_validation_errors = raise_validation_errors
|
503
508
|
|
504
509
|
self.verbose = verbose
|
505
510
|
|
@@ -587,7 +592,7 @@ class Jobs(Base):
|
|
587
592
|
)
|
588
593
|
|
589
594
|
# handle cache
|
590
|
-
if cache is None:
|
595
|
+
if cache is None or cache is True:
|
591
596
|
from edsl.data.CacheHandler import CacheHandler
|
592
597
|
|
593
598
|
cache = CacheHandler().get_cache()
|
@@ -599,12 +604,12 @@ class Jobs(Base):
|
|
599
604
|
if not remote_cache:
|
600
605
|
results = self._run_local(
|
601
606
|
n=n,
|
602
|
-
debug=debug,
|
603
607
|
progress_bar=progress_bar,
|
604
608
|
cache=cache,
|
605
609
|
stop_on_exception=stop_on_exception,
|
606
610
|
sidecar_model=sidecar_model,
|
607
611
|
print_exceptions=print_exceptions,
|
612
|
+
raise_validation_errors=raise_validation_errors,
|
608
613
|
)
|
609
614
|
|
610
615
|
results.cache = cache.new_entries_cache()
|
@@ -643,12 +648,12 @@ class Jobs(Base):
|
|
643
648
|
self._output("Running job...")
|
644
649
|
results = self._run_local(
|
645
650
|
n=n,
|
646
|
-
debug=debug,
|
647
651
|
progress_bar=progress_bar,
|
648
652
|
cache=cache,
|
649
653
|
stop_on_exception=stop_on_exception,
|
650
654
|
sidecar_model=sidecar_model,
|
651
655
|
print_exceptions=print_exceptions,
|
656
|
+
raise_validation_errors=raise_validation_errors,
|
652
657
|
)
|
653
658
|
self._output("Job completed!")
|
654
659
|
|
@@ -883,7 +888,7 @@ def main():
|
|
883
888
|
|
884
889
|
job = Jobs.example()
|
885
890
|
len(job) == 8
|
886
|
-
results = job.run(
|
891
|
+
results = job.run(cache=Cache())
|
887
892
|
len(results) == 8
|
888
893
|
results
|
889
894
|
|
edsl/jobs/buckets/TokenBucket.py
CHANGED
@@ -82,23 +82,30 @@ class TokenBucket:
|
|
82
82
|
>>> bucket.refill()
|
83
83
|
>>> bucket.tokens > 0
|
84
84
|
True
|
85
|
-
|
86
85
|
"""
|
86
|
+
"""Refill the bucket with new tokens based on elapsed time."""
|
87
87
|
now = time.monotonic()
|
88
|
+
# print(f"Time is now: {now}; Last refill time: {self.last_refill}")
|
88
89
|
elapsed = now - self.last_refill
|
90
|
+
# print("Elapsed time: ", elapsed)
|
89
91
|
refill_amount = elapsed * self.refill_rate
|
90
92
|
self.tokens = min(self.capacity, self.tokens + refill_amount)
|
91
93
|
self.last_refill = now
|
92
94
|
|
95
|
+
if self.tokens < self.capacity:
|
96
|
+
pass
|
97
|
+
# print(f"Refilled. Current tokens: {self.tokens:.4f}")
|
98
|
+
# print(f"Elapsed time: {elapsed:.4f} seconds")
|
99
|
+
# print(f"Refill amount: {refill_amount:.4f}")
|
100
|
+
|
93
101
|
self.log.append((now, self.tokens))
|
94
102
|
|
95
103
|
def wait_time(self, requested_tokens: Union[float, int]) -> float:
|
96
104
|
"""Calculate the time to wait for the requested number of tokens."""
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
return max(0, requested_tokens - available_tokens) / self.refill_rate
|
105
|
+
# self.refill() # Update the current token count
|
106
|
+
if self.tokens >= requested_tokens:
|
107
|
+
return 0
|
108
|
+
return (requested_tokens - self.tokens) / self.refill_rate
|
102
109
|
|
103
110
|
async def get_tokens(
|
104
111
|
self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
|
@@ -123,22 +130,40 @@ class TokenBucket:
|
|
123
130
|
...
|
124
131
|
ValueError: Requested amount exceeds bucket capacity. Bucket capacity: 10, requested amount: 11. As the bucket never overflows, the requested amount will never be available.
|
125
132
|
>>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))
|
133
|
+
>>> bucket.capacity
|
134
|
+
12.100000000000001
|
126
135
|
"""
|
127
|
-
if amount
|
136
|
+
if amount >= self.capacity:
|
128
137
|
if not cheat_bucket_capacity:
|
129
138
|
msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
|
130
139
|
raise ValueError(msg)
|
131
140
|
else:
|
132
|
-
self.tokens = 0 # clear the bucket but let it go through
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
141
|
+
# self.tokens = 0 # clear the bucket but let it go through
|
142
|
+
# print(
|
143
|
+
# f"""The requested amount, {amount}, exceeds the current bucket capacity of {self.capacity}.Increasing bucket capacity to {amount} * 1.10 accommodate the requested amount."""
|
144
|
+
# )
|
145
|
+
self.capacity = amount * 1.10
|
146
|
+
self._old_capacity = self.capacity
|
147
|
+
|
148
|
+
start_time = time.monotonic()
|
149
|
+
while True:
|
150
|
+
self.refill() # Refill based on elapsed time
|
151
|
+
if self.tokens >= amount:
|
152
|
+
self.tokens -= amount
|
153
|
+
break
|
154
|
+
|
155
|
+
wait_time = self.wait_time(amount)
|
156
|
+
# print(f"Waiting for {wait_time:.4f} seconds")
|
157
|
+
if wait_time > 0:
|
158
|
+
# print(f"Waiting for {wait_time:.4f} seconds")
|
159
|
+
await asyncio.sleep(wait_time)
|
160
|
+
|
161
|
+
# total_elapsed = time.monotonic() - start_time
|
162
|
+
# print(f"Total time to acquire tokens: {total_elapsed:.4f} seconds")
|
139
163
|
|
140
164
|
now = time.monotonic()
|
141
165
|
self.log.append((now, self.tokens))
|
166
|
+
return None
|
142
167
|
|
143
168
|
def get_log(self) -> list[tuple]:
|
144
169
|
return self.log
|
edsl/jobs/interviews/Interview.py
CHANGED
@@ -1,50 +1,68 @@
|
|
1
1
|
"""This module contains the Interview class, which is responsible for conducting an interview asynchronously."""
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
|
-
import traceback
|
5
4
|
import asyncio
|
6
|
-
import
|
7
|
-
from typing import Any, Type, List, Generator, Optional
|
5
|
+
from typing import Any, Type, List, Generator, Optional, Union
|
8
6
|
|
9
|
-
from edsl
|
7
|
+
from edsl import CONFIG
|
10
8
|
from edsl.surveys.base import EndOfSurvey
|
9
|
+
from edsl.exceptions import QuestionAnswerValidationError
|
10
|
+
from edsl.exceptions import InterviewTimeoutError
|
11
|
+
from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
12
|
+
|
11
13
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
14
|
+
from edsl.jobs.Answers import Answers
|
15
|
+
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
12
16
|
from edsl.jobs.tasks.TaskCreators import TaskCreators
|
13
|
-
|
14
17
|
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
15
18
|
from edsl.jobs.interviews.interview_exception_tracking import (
|
16
19
|
InterviewExceptionCollection,
|
17
20
|
)
|
18
21
|
from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
|
19
22
|
from edsl.jobs.interviews.retry_management import retry_strategy
|
20
|
-
from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
|
21
23
|
from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
|
22
24
|
|
23
|
-
import
|
25
|
+
from edsl.surveys.base import EndOfSurvey
|
26
|
+
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
27
|
+
from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
|
28
|
+
from edsl.jobs.interviews.retry_management import retry_strategy
|
29
|
+
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
30
|
+
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
31
|
+
|
32
|
+
from edsl.exceptions import QuestionAnswerValidationError
|
33
|
+
|
34
|
+
from edsl import Agent, Survey, Scenario, Cache
|
35
|
+
from edsl.language_models import LanguageModel
|
36
|
+
from edsl.questions import QuestionBase
|
37
|
+
from edsl.agents.InvigilatorBase import InvigilatorBase
|
38
|
+
|
39
|
+
from edsl.exceptions.language_models import LanguageModelNoResponseError
|
24
40
|
|
25
41
|
|
26
|
-
|
27
|
-
|
42
|
+
class RetryableLanguageModelNoResponseError(LanguageModelNoResponseError):
|
43
|
+
pass
|
28
44
|
|
29
45
|
|
30
|
-
class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
46
|
+
class Interview(InterviewStatusMixin):
|
31
47
|
"""
|
32
48
|
An 'interview' is one agent answering one survey, with one language model, for a given scenario.
|
33
49
|
|
34
50
|
The main method is `async_conduct_interview`, which conducts the interview asynchronously.
|
51
|
+
Most of the class is dedicated to creating the tasks for each question in the survey, and then running them.
|
35
52
|
"""
|
36
53
|
|
37
54
|
def __init__(
|
38
55
|
self,
|
39
|
-
agent:
|
40
|
-
survey:
|
41
|
-
scenario:
|
56
|
+
agent: Agent,
|
57
|
+
survey: Survey,
|
58
|
+
scenario: Scenario,
|
42
59
|
model: Type["LanguageModel"],
|
43
60
|
debug: Optional[bool] = False,
|
44
61
|
iteration: int = 0,
|
45
62
|
cache: Optional["Cache"] = None,
|
46
63
|
sidecar_model: Optional["LanguageModel"] = None,
|
47
|
-
skip_retry=False,
|
64
|
+
skip_retry: bool = False,
|
65
|
+
raise_validation_errors: bool = True,
|
48
66
|
):
|
49
67
|
"""Initialize the Interview instance.
|
50
68
|
|
@@ -79,9 +97,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
79
97
|
self.debug = debug
|
80
98
|
self.iteration = iteration
|
81
99
|
self.cache = cache
|
82
|
-
self.answers: dict[
|
83
|
-
|
84
|
-
|
100
|
+
self.answers: dict[str, str] = (
|
101
|
+
Answers()
|
102
|
+
) # will get filled in as interview progresses
|
85
103
|
self.sidecar_model = sidecar_model
|
86
104
|
|
87
105
|
# Trackers
|
@@ -89,6 +107,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
89
107
|
self.exceptions = InterviewExceptionCollection()
|
90
108
|
self._task_status_log_dict = InterviewStatusLog()
|
91
109
|
self.skip_retry = skip_retry
|
110
|
+
self.raise_validation_errors = raise_validation_errors
|
92
111
|
|
93
112
|
# dictionary mapping question names to their index in the survey.
|
94
113
|
self.to_index = {
|
@@ -96,6 +115,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
96
115
|
for index, question_name in enumerate(self.survey.question_names)
|
97
116
|
}
|
98
117
|
|
118
|
+
self.failed_questions = []
|
119
|
+
|
120
|
+
# region: Serialization
|
99
121
|
def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
|
100
122
|
"""Return a dictionary representation of the Interview instance.
|
101
123
|
This is just for hashing purposes.
|
@@ -120,13 +142,247 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
120
142
|
|
121
143
|
return dict_hash(self._to_dict())
|
122
144
|
|
123
|
-
|
145
|
+
# endregion
|
146
|
+
|
147
|
+
# region: Creating tasks
|
148
|
+
@property
|
149
|
+
def dag(self) -> "DAG":
|
150
|
+
"""Return the directed acyclic graph for the survey.
|
151
|
+
|
152
|
+
The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
|
153
|
+
It is used to determine the order in which questions should be answered.
|
154
|
+
This reflects both agent 'memory' considerations and 'skip' logic.
|
155
|
+
The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
|
156
|
+
|
157
|
+
>>> i = Interview.example()
|
158
|
+
>>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
|
159
|
+
True
|
160
|
+
"""
|
161
|
+
return self.survey.dag(textify=True)
|
162
|
+
|
163
|
+
def _build_question_tasks(
|
164
|
+
self,
|
165
|
+
model_buckets: ModelBuckets,
|
166
|
+
) -> list[asyncio.Task]:
|
167
|
+
"""Create a task for each question, with dependencies on the questions that must be answered before this one can be answered.
|
168
|
+
|
169
|
+
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
170
|
+
:param model_buckets: the model buckets used to track and control usage rates.
|
171
|
+
"""
|
172
|
+
tasks = []
|
173
|
+
for question in self.survey.questions:
|
174
|
+
tasks_that_must_be_completed_before = list(
|
175
|
+
self._get_tasks_that_must_be_completed_before(
|
176
|
+
tasks=tasks, question=question
|
177
|
+
)
|
178
|
+
)
|
179
|
+
question_task = self._create_question_task(
|
180
|
+
question=question,
|
181
|
+
tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
|
182
|
+
model_buckets=model_buckets,
|
183
|
+
iteration=self.iteration,
|
184
|
+
)
|
185
|
+
tasks.append(question_task)
|
186
|
+
return tuple(tasks)
|
187
|
+
|
188
|
+
def _get_tasks_that_must_be_completed_before(
|
189
|
+
self, *, tasks: list[asyncio.Task], question: "QuestionBase"
|
190
|
+
) -> Generator[asyncio.Task, None, None]:
|
191
|
+
"""Return the tasks that must be completed before the given question can be answered.
|
192
|
+
|
193
|
+
:param tasks: a list of tasks that have been created so far.
|
194
|
+
:param question: the question for which we are determining dependencies.
|
195
|
+
|
196
|
+
If a question has no dependencies, this will be an empty list, [].
|
197
|
+
"""
|
198
|
+
parents_of_focal_question = self.dag.get(question.question_name, [])
|
199
|
+
for parent_question_name in parents_of_focal_question:
|
200
|
+
yield tasks[self.to_index[parent_question_name]]
|
201
|
+
|
202
|
+
def _create_question_task(
|
203
|
+
self,
|
204
|
+
*,
|
205
|
+
question: QuestionBase,
|
206
|
+
tasks_that_must_be_completed_before: list[asyncio.Task],
|
207
|
+
model_buckets: ModelBuckets,
|
208
|
+
iteration: int = 0,
|
209
|
+
) -> asyncio.Task:
|
210
|
+
"""Create a task that depends on the passed-in dependencies that are awaited before the task is run.
|
211
|
+
|
212
|
+
:param question: the question to be answered. This is the question we are creating a task for.
|
213
|
+
:param tasks_that_must_be_completed_before: the tasks that must be completed before the focal task is run.
|
214
|
+
:param model_buckets: the model buckets used to track and control usage rates.
|
215
|
+
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
216
|
+
:param iteration: the iteration number for the interview.
|
217
|
+
|
218
|
+
The task is created by a `QuestionTaskCreator`, which is responsible for creating the task and managing its dependencies.
|
219
|
+
It is passed a reference to the function that will be called to answer the question.
|
220
|
+
It is passed a list "tasks_that_must_be_completed_before" that are awaited before the task is run.
|
221
|
+
These are added as a dependency to the focal task.
|
222
|
+
"""
|
223
|
+
task_creator = QuestionTaskCreator(
|
224
|
+
question=question,
|
225
|
+
answer_question_func=self._answer_question_and_record_task,
|
226
|
+
token_estimator=self._get_estimated_request_tokens,
|
227
|
+
model_buckets=model_buckets,
|
228
|
+
iteration=iteration,
|
229
|
+
)
|
230
|
+
for task in tasks_that_must_be_completed_before:
|
231
|
+
task_creator.add_dependency(task)
|
232
|
+
|
233
|
+
self.task_creators.update(
|
234
|
+
{question.question_name: task_creator}
|
235
|
+
) # track this task creator
|
236
|
+
return task_creator.generate_task()
|
237
|
+
|
238
|
+
def _get_estimated_request_tokens(self, question) -> float:
|
239
|
+
"""Estimate the number of tokens that will be required to run the focal task."""
|
240
|
+
invigilator = self._get_invigilator(question=question)
|
241
|
+
# TODO: There should be a way to get a more accurate estimate.
|
242
|
+
combined_text = ""
|
243
|
+
for prompt in invigilator.get_prompts().values():
|
244
|
+
if hasattr(prompt, "text"):
|
245
|
+
combined_text += prompt.text
|
246
|
+
elif isinstance(prompt, str):
|
247
|
+
combined_text += prompt
|
248
|
+
else:
|
249
|
+
raise ValueError(f"Prompt is of type {type(prompt)}")
|
250
|
+
return len(combined_text) / 4.0
|
251
|
+
|
252
|
+
async def _answer_question_and_record_task(
|
124
253
|
self,
|
125
254
|
*,
|
126
|
-
|
127
|
-
|
255
|
+
question: "QuestionBase",
|
256
|
+
task=None,
|
257
|
+
) -> "AgentResponseDict":
|
258
|
+
"""Answer a question and records the task."""
|
259
|
+
|
260
|
+
invigilator = self._get_invigilator(question)
|
261
|
+
|
262
|
+
if self._skip_this_question(question):
|
263
|
+
response = invigilator.get_failed_task_result(
|
264
|
+
failure_reason="Question skipped."
|
265
|
+
)
|
266
|
+
|
267
|
+
try:
|
268
|
+
response: EDSLResultObjectInput = await invigilator.async_answer_question()
|
269
|
+
if response.validated:
|
270
|
+
self.answers.add_answer(response=response, question=question)
|
271
|
+
self._cancel_skipped_questions(question)
|
272
|
+
else:
|
273
|
+
if (
|
274
|
+
hasattr(response, "exception_occurred")
|
275
|
+
and response.exception_occurred
|
276
|
+
):
|
277
|
+
raise response.exception_occurred
|
278
|
+
|
279
|
+
except QuestionAnswerValidationError as e:
|
280
|
+
# there's a response, but it couldn't be validated
|
281
|
+
self._handle_exception(e, invigilator, task)
|
282
|
+
|
283
|
+
except asyncio.TimeoutError as e:
|
284
|
+
# the API timed-out - this is recorded but as a response isn't generated, the LanguageModelNoResponseError will also be raised
|
285
|
+
self._handle_exception(e, invigilator, task)
|
286
|
+
|
287
|
+
except Exception as e:
|
288
|
+
# there was some other exception
|
289
|
+
self._handle_exception(e, invigilator, task)
|
290
|
+
|
291
|
+
if "response" not in locals():
|
292
|
+
|
293
|
+
raise LanguageModelNoResponseError(
|
294
|
+
f"Language model did not return a response for question '{question.question_name}.'"
|
295
|
+
)
|
296
|
+
|
297
|
+
return response
|
298
|
+
|
299
|
+
def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
|
300
|
+
"""Return an invigilator for the given question.
|
301
|
+
|
302
|
+
:param question: the question to be answered
|
303
|
+
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
304
|
+
"""
|
305
|
+
invigilator = self.agent.create_invigilator(
|
306
|
+
question=question,
|
307
|
+
scenario=self.scenario,
|
308
|
+
model=self.model,
|
309
|
+
debug=False,
|
310
|
+
survey=self.survey,
|
311
|
+
memory_plan=self.survey.memory_plan,
|
312
|
+
current_answers=self.answers,
|
313
|
+
iteration=self.iteration,
|
314
|
+
cache=self.cache,
|
315
|
+
sidecar_model=self.sidecar_model,
|
316
|
+
raise_validation_errors=self.raise_validation_errors,
|
317
|
+
)
|
318
|
+
"""Return an invigilator for the given question."""
|
319
|
+
return invigilator
|
320
|
+
|
321
|
+
def _skip_this_question(self, current_question: "QuestionBase") -> bool:
|
322
|
+
"""Determine if the current question should be skipped.
|
323
|
+
|
324
|
+
:param current_question: the question to be answered.
|
325
|
+
"""
|
326
|
+
current_question_index = self.to_index[current_question.question_name]
|
327
|
+
|
328
|
+
answers = self.answers | self.scenario | self.agent["traits"]
|
329
|
+
skip = self.survey.rule_collection.skip_question_before_running(
|
330
|
+
current_question_index, answers
|
331
|
+
)
|
332
|
+
return skip
|
333
|
+
|
334
|
+
def _handle_exception(
|
335
|
+
self, e: Exception, invigilator: "InvigilatorBase", task=None
|
336
|
+
):
|
337
|
+
exception_entry = InterviewExceptionEntry(
|
338
|
+
exception=e,
|
339
|
+
invigilator=invigilator,
|
340
|
+
)
|
341
|
+
if task:
|
342
|
+
task.task_status = TaskStatus.FAILED
|
343
|
+
self.exceptions.add(invigilator.question.question_name, exception_entry)
|
344
|
+
|
345
|
+
def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
|
346
|
+
"""Cancel the tasks for questions that are skipped.
|
347
|
+
|
348
|
+
:param current_question: the question that was just answered.
|
349
|
+
|
350
|
+
It first determines the next question, given the current question and the current answers.
|
351
|
+
If the next question is the end of the survey, it cancels all remaining tasks.
|
352
|
+
If the next question is after the current question, it cancels all tasks between the current question and the next question.
|
353
|
+
"""
|
354
|
+
current_question_index: int = self.to_index[current_question.question_name]
|
355
|
+
|
356
|
+
next_question: Union[int, EndOfSurvey] = (
|
357
|
+
self.survey.rule_collection.next_question(
|
358
|
+
q_now=current_question_index,
|
359
|
+
answers=self.answers | self.scenario | self.agent["traits"],
|
360
|
+
)
|
361
|
+
)
|
362
|
+
|
363
|
+
next_question_index = next_question.next_q
|
364
|
+
|
365
|
+
def cancel_between(start, end):
|
366
|
+
"""Cancel the tasks between the start and end indices."""
|
367
|
+
for i in range(start, end):
|
368
|
+
self.tasks[i].cancel()
|
369
|
+
|
370
|
+
if next_question_index == EndOfSurvey:
|
371
|
+
cancel_between(current_question_index + 1, len(self.survey.questions))
|
372
|
+
return
|
373
|
+
|
374
|
+
if next_question_index > (current_question_index + 1):
|
375
|
+
cancel_between(current_question_index + 1, next_question_index)
|
376
|
+
|
377
|
+
# endregion
|
378
|
+
|
379
|
+
# region: Conducting the interview
|
380
|
+
async def async_conduct_interview(
|
381
|
+
self,
|
382
|
+
model_buckets: Optional[ModelBuckets] = None,
|
128
383
|
stop_on_exception: bool = False,
|
129
384
|
sidecar_model: Optional["LanguageModel"] = None,
|
385
|
+
raise_validation_errors: bool = True,
|
130
386
|
) -> tuple["Answers", List[dict[str, Any]]]:
|
131
387
|
"""
|
132
388
|
Conduct an Interview asynchronously.
|
@@ -146,19 +402,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
146
402
|
|
147
403
|
>>> i = Interview.example(throw_exception = True)
|
148
404
|
>>> result, _ = asyncio.run(i.async_conduct_interview())
|
149
|
-
Attempt 1 failed with exception:This is a test error now waiting 1.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
|
150
|
-
<BLANKLINE>
|
151
|
-
<BLANKLINE>
|
152
|
-
Attempt 2 failed with exception:This is a test error now waiting 2.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
|
153
|
-
<BLANKLINE>
|
154
|
-
<BLANKLINE>
|
155
|
-
Attempt 3 failed with exception:This is a test error now waiting 4.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
|
156
|
-
<BLANKLINE>
|
157
|
-
<BLANKLINE>
|
158
|
-
Attempt 4 failed with exception:This is a test error now waiting 8.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
|
159
|
-
<BLANKLINE>
|
160
|
-
<BLANKLINE>
|
161
|
-
|
162
405
|
>>> i.exceptions
|
163
406
|
{'q0': ...
|
164
407
|
>>> i = Interview.example()
|
@@ -173,21 +416,22 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
173
416
|
if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
|
174
417
|
model_buckets = ModelBuckets.infinity_bucket()
|
175
418
|
|
176
|
-
## build the tasks using the InterviewTaskBuildingMixin
|
177
419
|
## This is the key part---it creates a task for each question,
|
178
420
|
## with dependencies on the questions that must be answered before this one can be answered.
|
179
|
-
self.tasks = self._build_question_tasks(
|
180
|
-
debug=debug, model_buckets=model_buckets
|
181
|
-
)
|
421
|
+
self.tasks = self._build_question_tasks(model_buckets=model_buckets)
|
182
422
|
|
183
423
|
## 'Invigilators' are used to administer the survey
|
184
|
-
self.invigilators =
|
185
|
-
|
424
|
+
self.invigilators = [
|
425
|
+
self._get_invigilator(question) for question in self.survey.questions
|
426
|
+
]
|
186
427
|
await asyncio.gather(*self.tasks, return_exceptions=not stop_on_exception)
|
187
428
|
self.answers.replace_missing_answers_with_none(self.survey)
|
188
429
|
valid_results = list(self._extract_valid_results())
|
189
430
|
return self.answers, valid_results
|
190
431
|
|
432
|
+
# endregion
|
433
|
+
|
434
|
+
# region: Extracting results and recording errors
|
191
435
|
def _extract_valid_results(self) -> Generator["Answers", None, None]:
|
192
436
|
"""Extract the valid results from the list of results.
|
193
437
|
|
@@ -200,8 +444,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
200
444
|
>>> results = list(i._extract_valid_results())
|
201
445
|
>>> len(results) == len(i.survey)
|
202
446
|
True
|
203
|
-
>>> type(results[0])
|
204
|
-
<class 'edsl.data_transfer_models.AgentResponseDict'>
|
205
447
|
"""
|
206
448
|
assert len(self.tasks) == len(self.invigilators)
|
207
449
|
|
@@ -212,46 +454,24 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
212
454
|
try:
|
213
455
|
result = task.result()
|
214
456
|
except asyncio.CancelledError as e: # task was cancelled
|
215
|
-
result = invigilator.get_failed_task_result(
|
457
|
+
result = invigilator.get_failed_task_result(
|
458
|
+
failure_reason="Task was cancelled."
|
459
|
+
)
|
216
460
|
except Exception as e: # any other kind of exception in the task
|
217
|
-
result = invigilator.get_failed_task_result(
|
218
|
-
|
219
|
-
|
461
|
+
result = invigilator.get_failed_task_result(
|
462
|
+
failure_reason=f"Task failed with exception: {str(e)}."
|
463
|
+
)
|
464
|
+
exception_entry = InterviewExceptionEntry(
|
465
|
+
exception=e,
|
466
|
+
invigilator=invigilator,
|
467
|
+
)
|
468
|
+
self.exceptions.add(task.get_name(), exception_entry)
|
220
469
|
|
221
|
-
|
222
|
-
"""Record an exception in the Interview instance.
|
223
|
-
|
224
|
-
It records the exception in the Interview instance, with the task name and the exception entry.
|
225
|
-
|
226
|
-
>>> i = Interview.example()
|
227
|
-
>>> result, _ = asyncio.run(i.async_conduct_interview())
|
228
|
-
>>> i.exceptions
|
229
|
-
{}
|
230
|
-
>>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
|
231
|
-
>>> i.exceptions
|
232
|
-
{'q0': ...
|
233
|
-
"""
|
234
|
-
exception_entry = InterviewExceptionEntry(exception)
|
235
|
-
self.exceptions.add(task.get_name(), exception_entry)
|
236
|
-
|
237
|
-
@property
|
238
|
-
def dag(self) -> "DAG":
|
239
|
-
"""Return the directed acyclic graph for the survey.
|
240
|
-
|
241
|
-
The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
|
242
|
-
It is used to determine the order in which questions should be answered.
|
243
|
-
This reflects both agent 'memory' considerations and 'skip' logic.
|
244
|
-
The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
|
470
|
+
yield result
|
245
471
|
|
246
|
-
|
247
|
-
>>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
|
248
|
-
True
|
249
|
-
"""
|
250
|
-
return self.survey.dag(textify=True)
|
472
|
+
# endregion
|
251
473
|
|
252
|
-
|
253
|
-
# Dunder methods
|
254
|
-
#######################
|
474
|
+
# region: Magic methods
|
255
475
|
def __repr__(self) -> str:
|
256
476
|
"""Return a string representation of the Interview instance."""
|
257
477
|
return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
|