edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +28 -0
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -16
- edsl/agents/Invigilator.py +13 -14
- edsl/agents/InvigilatorBase.py +4 -1
- edsl/agents/PromptConstructor.py +42 -22
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/coop/coop.py +21 -5
- edsl/data/Cache.py +29 -18
- edsl/data/CacheHandler.py +0 -2
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/hack.py +10 -0
- edsl/enums.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -16
- edsl/inference_services/AvailableModelFetcher.py +7 -1
- edsl/inference_services/GoogleService.py +5 -1
- edsl/inference_services/InferenceServicesCollection.py +18 -2
- edsl/inference_services/OpenAIService.py +46 -31
- edsl/inference_services/TestService.py +1 -3
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +74 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
- edsl/jobs/FetchInvigilator.py +10 -3
- edsl/jobs/InterviewsConstructor.py +6 -4
- edsl/jobs/Jobs.py +299 -233
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/interviews/Interview.py +80 -42
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
- edsl/jobs/runners/JobsRunnerStatus.py +131 -164
- edsl/jobs/tasks/TaskHistory.py +24 -3
- edsl/language_models/LanguageModel.py +59 -4
- edsl/language_models/ModelList.py +19 -8
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +35 -26
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +9 -15
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/questions/data_structures.py +20 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
- edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
- edsl/results/DatasetExportMixin.py +60 -119
- edsl/results/Result.py +109 -3
- edsl/results/Results.py +50 -39
- edsl/results/file_exports.py +252 -0
- edsl/scenarios/ScenarioList.py +35 -7
- edsl/surveys/Survey.py +71 -20
- edsl/test_h +1 -0
- edsl/utilities/gcp_bucket/example.py +50 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
- edsl/language_models/registry.py +0 -180
- /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
- /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
- /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{Selector.py → results_selector.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
- /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
- /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
@@ -51,6 +51,7 @@ class TestService(InferenceServiceABC):
|
|
51
51
|
@property
|
52
52
|
def _canned_response(self):
|
53
53
|
if hasattr(self, "canned_response"):
|
54
|
+
|
54
55
|
return self.canned_response
|
55
56
|
else:
|
56
57
|
return "Hello, world"
|
@@ -63,9 +64,6 @@ class TestService(InferenceServiceABC):
|
|
63
64
|
files_list: Optional[List["File"]] = None,
|
64
65
|
) -> dict[str, Any]:
|
65
66
|
await asyncio.sleep(0.1)
|
66
|
-
# return {"message": """{"answer": "Hello, world"}"""}
|
67
|
-
|
68
|
-
# breakpoint()
|
69
67
|
|
70
68
|
if hasattr(self, "throw_exception") and self.throw_exception:
|
71
69
|
if hasattr(self, "exception_probability"):
|
@@ -143,15 +143,17 @@ class TogetherAIService(OpenAIService):
|
|
143
143
|
_async_client_ = openai.AsyncOpenAI
|
144
144
|
|
145
145
|
@classmethod
|
146
|
-
def get_model_list(cls):
|
146
|
+
def get_model_list(cls, api_token=None):
|
147
147
|
# Together.ai has a different response in model list than openai
|
148
148
|
# and the OpenAI class returns an error when calling .models.list()
|
149
149
|
import requests
|
150
150
|
import os
|
151
151
|
|
152
152
|
url = "https://api.together.xyz/v1/models?filter=serverless"
|
153
|
-
|
154
|
-
|
153
|
+
if api_token is None:
|
154
|
+
api_token = os.getenv(cls._env_key_name_)
|
155
|
+
|
156
|
+
headers = {"accept": "application/json", "authorization": f"Bearer {api_token}"}
|
155
157
|
|
156
158
|
response = requests.get(url, headers=headers)
|
157
159
|
return response.json()
|
@@ -1,14 +1,33 @@
|
|
1
1
|
from collections import UserDict, defaultdict, UserList
|
2
|
-
from typing import Union
|
2
|
+
from typing import Union, Optional, List
|
3
3
|
from edsl.enums import InferenceServiceLiteral
|
4
4
|
from dataclasses import dataclass
|
5
5
|
|
6
6
|
|
7
7
|
@dataclass
|
8
8
|
class LanguageModelInfo:
|
9
|
+
"""A dataclass for storing information about a language model.
|
10
|
+
|
11
|
+
|
12
|
+
>>> LanguageModelInfo("gpt-4-1106-preview", "openai")
|
13
|
+
LanguageModelInfo(model_name='gpt-4-1106-preview', service_name='openai')
|
14
|
+
|
15
|
+
>>> model_name, service = LanguageModelInfo.example()
|
16
|
+
>>> model_name
|
17
|
+
'gpt-4-1106-preview'
|
18
|
+
|
19
|
+
>>> LanguageModelInfo.example().service_name
|
20
|
+
'openai'
|
21
|
+
|
22
|
+
"""
|
23
|
+
|
9
24
|
model_name: str
|
10
25
|
service_name: str
|
11
26
|
|
27
|
+
def __iter__(self):
|
28
|
+
yield self.model_name
|
29
|
+
yield self.service_name
|
30
|
+
|
12
31
|
def __getitem__(self, key: int) -> str:
|
13
32
|
import warnings
|
14
33
|
|
@@ -26,13 +45,18 @@ class LanguageModelInfo:
|
|
26
45
|
else:
|
27
46
|
raise IndexError("Index out of range")
|
28
47
|
|
48
|
+
@classmethod
|
49
|
+
def example(cls) -> "LanguageModelInfo":
|
50
|
+
return cls("gpt-4-1106-preview", "openai")
|
51
|
+
|
29
52
|
|
30
53
|
class ModelNamesList(UserList):
|
31
54
|
pass
|
32
55
|
|
33
56
|
|
34
57
|
class AvailableModels(UserList):
|
35
|
-
|
58
|
+
|
59
|
+
def __init__(self, data: List[LanguageModelInfo]) -> None:
|
36
60
|
super().__init__(data)
|
37
61
|
|
38
62
|
def __contains__(self, model_name: str) -> bool:
|
@@ -41,6 +65,54 @@ class AvailableModels(UserList):
|
|
41
65
|
return True
|
42
66
|
return False
|
43
67
|
|
68
|
+
def print(self):
|
69
|
+
return self.to_dataset().print()
|
70
|
+
|
71
|
+
def to_dataset(self):
|
72
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
73
|
+
|
74
|
+
models, services = zip(
|
75
|
+
*[(model.model_name, model.service_name) for model in self]
|
76
|
+
)
|
77
|
+
return (
|
78
|
+
ScenarioList.from_list("model", models)
|
79
|
+
.add_list("service", services)
|
80
|
+
.to_dataset()
|
81
|
+
)
|
82
|
+
|
83
|
+
def to_model_list(self):
|
84
|
+
from edsl.language_models.ModelList import ModelList
|
85
|
+
|
86
|
+
return ModelList.from_available_models(self)
|
87
|
+
|
88
|
+
def search(
|
89
|
+
self, pattern: str, service_name: Optional[str] = None, regex: bool = False
|
90
|
+
) -> "AvailableModels":
|
91
|
+
import re
|
92
|
+
|
93
|
+
if not regex:
|
94
|
+
# Escape special regex characters except *
|
95
|
+
pattern = re.escape(pattern).replace(r"\*", ".*")
|
96
|
+
|
97
|
+
try:
|
98
|
+
regex = re.compile(pattern)
|
99
|
+
avm = AvailableModels(
|
100
|
+
[
|
101
|
+
entry
|
102
|
+
for entry in self
|
103
|
+
if regex.search(entry.model_name)
|
104
|
+
and (service_name is None or entry.service_name == service_name)
|
105
|
+
]
|
106
|
+
)
|
107
|
+
if len(avm) == 0:
|
108
|
+
raise ValueError(
|
109
|
+
"No models found matching the search pattern: " + pattern
|
110
|
+
)
|
111
|
+
else:
|
112
|
+
return avm
|
113
|
+
except re.error as e:
|
114
|
+
raise ValueError(f"Invalid regular expression pattern: {e}")
|
115
|
+
|
44
116
|
|
45
117
|
class ServiceToModelsMapping(UserDict):
|
46
118
|
def __init__(self, data: dict) -> None:
|
@@ -5,6 +5,8 @@ from typing import Union, Type, Callable, TYPE_CHECKING
|
|
5
5
|
|
6
6
|
if TYPE_CHECKING:
|
7
7
|
from edsl.questions.QuestionBase import QuestionBase
|
8
|
+
from edsl.jobs.interviews.Interview import Interview
|
9
|
+
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
8
10
|
|
9
11
|
from edsl.surveys.base import EndOfSurvey
|
10
12
|
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
@@ -17,34 +19,97 @@ from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
|
17
19
|
from edsl.jobs.Answers import Answers
|
18
20
|
|
19
21
|
|
20
|
-
class
|
21
|
-
|
22
|
+
class RetryConfig:
|
23
|
+
from edsl.config import CONFIG
|
24
|
+
|
25
|
+
EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
|
26
|
+
EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
|
27
|
+
EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
|
28
|
+
|
29
|
+
|
30
|
+
class SkipHandler:
|
31
|
+
|
32
|
+
def __init__(self, interview: "Interview"):
|
22
33
|
self.interview = interview
|
23
|
-
self.had_language_model_no_response_error = False
|
24
34
|
self.question_index = self.interview.to_index
|
25
35
|
|
26
36
|
self.skip_function: Callable = (
|
27
37
|
self.interview.survey.rule_collection.skip_question_before_running
|
28
38
|
)
|
29
39
|
|
30
|
-
def
|
31
|
-
|
32
|
-
|
33
|
-
@property
|
34
|
-
def answers(self) -> Answers:
|
35
|
-
return self.interview.answers
|
36
|
-
|
37
|
-
def _skip_this_question(self, current_question: "QuestionBase") -> bool:
|
40
|
+
def should_skip(self, current_question: "QuestionBase") -> bool:
|
41
|
+
"""Determine if the current question should be skipped."""
|
38
42
|
current_question_index = self.question_index[current_question.question_name]
|
39
|
-
combined_answers =
|
43
|
+
combined_answers = (
|
44
|
+
self.interview.answers
|
45
|
+
| self.interview.scenario
|
46
|
+
| self.interview.agent["traits"]
|
47
|
+
)
|
40
48
|
return self.skip_function(current_question_index, combined_answers)
|
41
49
|
|
50
|
+
def cancel_skipped_questions(self, current_question: "QuestionBase") -> None:
|
51
|
+
"""Cancel the tasks for questions that should be skipped."""
|
52
|
+
current_question_index: int = self.interview.to_index[
|
53
|
+
current_question.question_name
|
54
|
+
]
|
55
|
+
answers = (
|
56
|
+
self.interview.answers
|
57
|
+
| self.interview.scenario
|
58
|
+
| self.interview.agent["traits"]
|
59
|
+
)
|
60
|
+
|
61
|
+
# Get the index of the next question, which could also be the end of the survey
|
62
|
+
next_question: Union[int, EndOfSurvey] = (
|
63
|
+
self.interview.survey.rule_collection.next_question(
|
64
|
+
q_now=current_question_index,
|
65
|
+
answers=answers,
|
66
|
+
)
|
67
|
+
)
|
68
|
+
|
69
|
+
def cancel_between(start, end):
|
70
|
+
"""Cancel the tasks for questions between the start and end indices."""
|
71
|
+
for i in range(start, end):
|
72
|
+
self.interview.tasks[i].cancel()
|
73
|
+
|
74
|
+
if (next_question_index := next_question.next_q) == EndOfSurvey:
|
75
|
+
cancel_between(
|
76
|
+
current_question_index + 1, len(self.interview.survey.questions)
|
77
|
+
)
|
78
|
+
return
|
79
|
+
|
80
|
+
if next_question_index > (current_question_index + 1):
|
81
|
+
cancel_between(current_question_index + 1, next_question_index)
|
82
|
+
|
83
|
+
|
84
|
+
class AnswerQuestionFunctionConstructor:
|
85
|
+
"""Constructs a function that answers a question and records the answer."""
|
86
|
+
|
87
|
+
def __init__(self, interview: "Interview", key_lookup: "KeyLookup"):
|
88
|
+
self.interview = interview
|
89
|
+
self.key_lookup = key_lookup
|
90
|
+
|
91
|
+
self.had_language_model_no_response_error: bool = False
|
92
|
+
self.question_index = self.interview.to_index
|
93
|
+
|
94
|
+
self.skip_function: Callable = (
|
95
|
+
self.interview.survey.rule_collection.skip_question_before_running
|
96
|
+
)
|
97
|
+
|
98
|
+
self.invigilator_fetcher = FetchInvigilator(
|
99
|
+
self.interview, key_lookup=self.key_lookup
|
100
|
+
)
|
101
|
+
self.skip_handler = SkipHandler(self.interview)
|
102
|
+
|
42
103
|
def _handle_exception(
|
43
104
|
self, e: Exception, invigilator: "InvigilatorBase", task=None
|
44
105
|
):
|
106
|
+
"""Handle an exception that occurred while answering a question."""
|
107
|
+
|
45
108
|
from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
|
46
109
|
|
47
|
-
answers = copy.copy(
|
110
|
+
answers = copy.copy(
|
111
|
+
self.interview.answers
|
112
|
+
) # copy to freeze the answers here for logging
|
48
113
|
exception_entry = InterviewExceptionEntry(
|
49
114
|
exception=e,
|
50
115
|
invigilator=invigilator,
|
@@ -52,6 +117,7 @@ class AnswerQuestionFunctionConstructor:
|
|
52
117
|
)
|
53
118
|
if task:
|
54
119
|
task.task_status = TaskStatus.FAILED
|
120
|
+
|
55
121
|
self.interview.exceptions.add(
|
56
122
|
invigilator.question.question_name, exception_entry
|
57
123
|
)
|
@@ -65,41 +131,15 @@ class AnswerQuestionFunctionConstructor:
|
|
65
131
|
if stop_on_exception:
|
66
132
|
raise e
|
67
133
|
|
68
|
-
def _cancel_skipped_questions(self, current_question: "QuestionBase") -> None:
|
69
|
-
current_question_index: int = self.interview.to_index[
|
70
|
-
current_question.question_name
|
71
|
-
]
|
72
|
-
answers = (
|
73
|
-
self.answers | self.interview.scenario | self.interview.agent["traits"]
|
74
|
-
)
|
75
|
-
|
76
|
-
# Get the index of the next question, which could also be the end of the survey
|
77
|
-
next_question: Union[
|
78
|
-
int, EndOfSurvey
|
79
|
-
] = self.interview.survey.rule_collection.next_question(
|
80
|
-
q_now=current_question_index,
|
81
|
-
answers=answers,
|
82
|
-
)
|
83
|
-
|
84
|
-
def cancel_between(start, end):
|
85
|
-
for i in range(start, end):
|
86
|
-
self.interview.tasks[i].cancel()
|
87
|
-
|
88
|
-
if (next_question_index := next_question.next_q) == EndOfSurvey:
|
89
|
-
cancel_between(
|
90
|
-
current_question_index + 1, len(self.interview.survey.questions)
|
91
|
-
)
|
92
|
-
return
|
93
|
-
|
94
|
-
if next_question_index > (current_question_index + 1):
|
95
|
-
cancel_between(current_question_index + 1, next_question_index)
|
96
|
-
|
97
134
|
def __call__(self):
|
98
|
-
|
135
|
+
return self.answer_question_and_record_task
|
99
136
|
|
100
|
-
|
101
|
-
|
102
|
-
|
137
|
+
async def answer_question_and_record_task(
|
138
|
+
self,
|
139
|
+
*,
|
140
|
+
question: "QuestionBase",
|
141
|
+
task=None,
|
142
|
+
) -> "AgentResponseDict":
|
103
143
|
|
104
144
|
from tenacity import (
|
105
145
|
retry,
|
@@ -109,80 +149,75 @@ class AnswerQuestionFunctionConstructor:
|
|
109
149
|
RetryError,
|
110
150
|
)
|
111
151
|
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
reraise=True,
|
124
|
-
)
|
125
|
-
async def attempt_answer():
|
126
|
-
invigilator = FetchInvigilator(self.interview)(question)
|
152
|
+
@retry(
|
153
|
+
stop=stop_after_attempt(RetryConfig.EDSL_MAX_ATTEMPTS),
|
154
|
+
wait=wait_exponential(
|
155
|
+
multiplier=RetryConfig.EDSL_BACKOFF_START_SEC,
|
156
|
+
max=RetryConfig.EDSL_BACKOFF_MAX_SEC,
|
157
|
+
),
|
158
|
+
retry=retry_if_exception_type(LanguageModelNoResponseError),
|
159
|
+
reraise=True,
|
160
|
+
)
|
161
|
+
async def attempt_answer():
|
162
|
+
invigilator = self.invigilator_fetcher(question)
|
127
163
|
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
164
|
+
if self.skip_handler.should_skip(question):
|
165
|
+
return invigilator.get_failed_task_result(
|
166
|
+
failure_reason="Question skipped."
|
167
|
+
)
|
132
168
|
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
else:
|
141
|
-
if (
|
142
|
-
hasattr(response, "exception_occurred")
|
143
|
-
and response.exception_occurred
|
144
|
-
):
|
145
|
-
raise response.exception_occurred
|
146
|
-
|
147
|
-
except QuestionAnswerValidationError as e:
|
148
|
-
self._handle_exception(e, invigilator, task)
|
149
|
-
return invigilator.get_failed_task_result(
|
150
|
-
failure_reason="Question answer validation failed."
|
169
|
+
try:
|
170
|
+
response: EDSLResultObjectInput = (
|
171
|
+
await invigilator.async_answer_question()
|
172
|
+
)
|
173
|
+
if response.validated:
|
174
|
+
self.interview.answers.add_answer(
|
175
|
+
response=response, question=question
|
151
176
|
)
|
152
177
|
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
)
|
178
|
+
self.skip_handler.cancel_skipped_questions(question)
|
179
|
+
else:
|
180
|
+
if (
|
181
|
+
hasattr(response, "exception_occurred")
|
182
|
+
and response.exception_occurred
|
183
|
+
):
|
184
|
+
raise response.exception_occurred
|
185
|
+
|
186
|
+
except QuestionAnswerValidationError as e:
|
187
|
+
self._handle_exception(e, invigilator, task)
|
188
|
+
return invigilator.get_failed_task_result(
|
189
|
+
failure_reason="Question answer validation failed."
|
190
|
+
)
|
159
191
|
|
160
|
-
|
161
|
-
|
192
|
+
except asyncio.TimeoutError as e:
|
193
|
+
self._handle_exception(e, invigilator, task)
|
194
|
+
had_language_model_no_response_error = True
|
195
|
+
raise LanguageModelNoResponseError(
|
196
|
+
f"Language model timed out for question '{question.question_name}.'"
|
197
|
+
)
|
162
198
|
|
163
|
-
|
164
|
-
|
165
|
-
raise LanguageModelNoResponseError(
|
166
|
-
f"Language model did not return a response for question '{question.question_name}.'"
|
167
|
-
)
|
199
|
+
except Exception as e:
|
200
|
+
self._handle_exception(e, invigilator, task)
|
168
201
|
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
question.question_name
|
175
|
-
)
|
202
|
+
if "response" not in locals():
|
203
|
+
had_language_model_no_response_error = True
|
204
|
+
raise LanguageModelNoResponseError(
|
205
|
+
f"Language model did not return a response for question '{question.question_name}.'"
|
206
|
+
)
|
176
207
|
|
177
|
-
|
208
|
+
if (
|
209
|
+
question.question_name in self.interview.exceptions
|
210
|
+
and had_language_model_no_response_error
|
211
|
+
):
|
212
|
+
self.interview.exceptions.record_fixed_question(question.question_name)
|
178
213
|
|
179
|
-
|
180
|
-
return await attempt_answer()
|
181
|
-
except RetryError as retry_error:
|
182
|
-
original_error = retry_error.last_attempt.exception()
|
183
|
-
self._handle_exception(
|
184
|
-
original_error, FetchInvigilator(self.interview)(question), task
|
185
|
-
)
|
186
|
-
raise original_error
|
214
|
+
return response
|
187
215
|
|
188
|
-
|
216
|
+
try:
|
217
|
+
return await attempt_answer()
|
218
|
+
except RetryError as retry_error:
|
219
|
+
original_error = retry_error.last_attempt.exception()
|
220
|
+
self._handle_exception(
|
221
|
+
original_error, self.invigilator_fetcher(question), task
|
222
|
+
)
|
223
|
+
raise original_error
|
edsl/jobs/FetchInvigilator.py
CHANGED
@@ -3,15 +3,23 @@ from typing import List, Dict, Any, Optional, TYPE_CHECKING
|
|
3
3
|
if TYPE_CHECKING:
|
4
4
|
from edsl.questions.QuestionBase import QuestionBase
|
5
5
|
from edsl.agents.InvigilatorBase import InvigilatorBase
|
6
|
+
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
7
|
+
from edsl.jobs.interviews.Interview import Interview
|
6
8
|
|
7
9
|
|
8
10
|
class FetchInvigilator:
|
9
|
-
def __init__(
|
11
|
+
def __init__(
|
12
|
+
self,
|
13
|
+
interview: "Interview",
|
14
|
+
current_answers: Optional[Dict[str, Any]] = None,
|
15
|
+
key_lookup: Optional["KeyLookup"] = None,
|
16
|
+
):
|
10
17
|
self.interview = interview
|
11
18
|
if current_answers is None:
|
12
19
|
self.current_answers = self.interview.answers
|
13
20
|
else:
|
14
21
|
self.current_answers = current_answers
|
22
|
+
self.key_lookup = key_lookup
|
15
23
|
|
16
24
|
def get_invigilator(self, question: "QuestionBase") -> "InvigilatorBase":
|
17
25
|
"""Return an invigilator for the given question.
|
@@ -24,14 +32,13 @@ class FetchInvigilator:
|
|
24
32
|
question=question,
|
25
33
|
scenario=self.interview.scenario,
|
26
34
|
model=self.interview.model,
|
27
|
-
# debug=False,
|
28
35
|
survey=self.interview.survey,
|
29
36
|
memory_plan=self.interview.survey.memory_plan,
|
30
37
|
current_answers=self.current_answers, # not yet known
|
31
38
|
iteration=self.interview.iteration,
|
32
39
|
cache=self.interview.cache,
|
33
|
-
# sidecar_model=self.interview.sidecar_model,
|
34
40
|
raise_validation_errors=self.interview.raise_validation_errors,
|
41
|
+
key_lookup=self.key_lookup,
|
35
42
|
)
|
36
43
|
"""Return an invigilator for the given question."""
|
37
44
|
return invigilator
|
@@ -6,8 +6,9 @@ if TYPE_CHECKING:
|
|
6
6
|
|
7
7
|
|
8
8
|
class InterviewsConstructor:
|
9
|
-
def __init__(self, jobs):
|
9
|
+
def __init__(self, jobs: "Jobs", cache: "Cache"):
|
10
10
|
self.jobs = jobs
|
11
|
+
self.cache = cache
|
11
12
|
|
12
13
|
def create_interviews(self) -> Generator["Interview", None, None]:
|
13
14
|
"""
|
@@ -34,12 +35,13 @@ class InterviewsConstructor:
|
|
34
35
|
self.jobs.agents, self.jobs.scenarios, self.jobs.models
|
35
36
|
):
|
36
37
|
yield Interview(
|
37
|
-
survey=self.jobs.survey,
|
38
|
+
survey=self.jobs.survey.draw(),
|
38
39
|
agent=agent,
|
39
40
|
scenario=scenario,
|
40
41
|
model=model,
|
41
|
-
|
42
|
-
|
42
|
+
cache=self.cache,
|
43
|
+
skip_retry=self.jobs.run_config.parameters.skip_retry,
|
44
|
+
raise_validation_errors=self.jobs.run_config.parameters.raise_validation_errors,
|
43
45
|
indices={
|
44
46
|
"agent": agent_index[hash(agent)],
|
45
47
|
"model": model_index[hash(model)],
|