edsl 0.1.39__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +0 -28
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +17 -9
- edsl/agents/Invigilator.py +14 -13
- edsl/agents/InvigilatorBase.py +1 -4
- edsl/agents/PromptConstructor.py +22 -42
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +5 -18
- edsl/auto/StageBase.py +40 -53
- edsl/auto/StageQuestions.py +1 -2
- edsl/auto/utilities.py +6 -0
- edsl/coop/coop.py +5 -21
- edsl/data/Cache.py +18 -29
- edsl/data/CacheHandler.py +2 -0
- edsl/data/RemoteCacheSync.py +46 -154
- edsl/enums.py +0 -7
- edsl/inference_services/AnthropicService.py +16 -38
- edsl/inference_services/AvailableModelFetcher.py +1 -7
- edsl/inference_services/GoogleService.py +1 -5
- edsl/inference_services/InferenceServicesCollection.py +2 -18
- edsl/inference_services/OpenAIService.py +31 -46
- edsl/inference_services/TestService.py +3 -1
- edsl/inference_services/TogetherAIService.py +3 -5
- edsl/inference_services/data_structures.py +2 -74
- edsl/jobs/AnswerQuestionFunctionConstructor.py +113 -148
- edsl/jobs/FetchInvigilator.py +3 -10
- edsl/jobs/InterviewsConstructor.py +4 -6
- edsl/jobs/Jobs.py +233 -299
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +136 -160
- edsl/jobs/interviews/Interview.py +42 -80
- edsl/jobs/runners/JobsRunnerAsyncio.py +358 -88
- edsl/jobs/runners/JobsRunnerStatus.py +165 -133
- edsl/jobs/tasks/TaskHistory.py +3 -24
- edsl/language_models/LanguageModel.py +4 -59
- edsl/language_models/ModelList.py +8 -19
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/registry.py +180 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +26 -35
- edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +49 -52
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +7 -5
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +15 -9
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +18 -6
- edsl/questions/{response_validator_factory.py → ResponseValidatorFactory.py} +1 -7
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/results/DatasetExportMixin.py +119 -60
- edsl/results/Result.py +3 -109
- edsl/results/Results.py +39 -50
- edsl/scenarios/FileStore.py +0 -32
- edsl/scenarios/ScenarioList.py +7 -35
- edsl/scenarios/handlers/csv.py +0 -11
- edsl/surveys/Survey.py +20 -71
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +1 -1
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/RECORD +78 -84
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +1 -1
- edsl/jobs/async_interview_runner.py +0 -138
- edsl/jobs/check_survey_scenario_compatibility.py +0 -85
- edsl/jobs/data_structures.py +0 -120
- edsl/jobs/results_exceptions_handler.py +0 -98
- edsl/language_models/model.py +0 -256
- edsl/questions/data_structures.py +0 -20
- edsl/results/file_exports.py +0 -252
- /edsl/agents/{question_option_processor.py → QuestionOptionProcessor.py} +0 -0
- /edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +0 -0
- /edsl/questions/{loop_processor.py → LoopProcessor.py} +0 -0
- /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
- /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
- /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
- /edsl/results/{results_selector.py → Selector.py} +0 -0
- /edsl/scenarios/{directory_scanner.py → DirectoryScanner.py} +0 -0
- /edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +0 -0
- /edsl/scenarios/{scenario_selector.py → ScenarioSelector.py} +0 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
@@ -143,17 +143,15 @@ class TogetherAIService(OpenAIService):
|
|
143
143
|
_async_client_ = openai.AsyncOpenAI
|
144
144
|
|
145
145
|
@classmethod
|
146
|
-
def get_model_list(cls
|
146
|
+
def get_model_list(cls):
|
147
147
|
# Togheter.ai has a different response in model list then openai
|
148
148
|
# and the OpenAI class returns an error when calling .models.list()
|
149
149
|
import requests
|
150
150
|
import os
|
151
151
|
|
152
152
|
url = "https://api.together.xyz/v1/models?filter=serverless"
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
headers = {"accept": "application/json", "authorization": f"Bearer {api_token}"}
|
153
|
+
token = os.getenv(cls._env_key_name_)
|
154
|
+
headers = {"accept": "application/json", "authorization": f"Bearer {token}"}
|
157
155
|
|
158
156
|
response = requests.get(url, headers=headers)
|
159
157
|
return response.json()
|
@@ -1,33 +1,14 @@
|
|
1
1
|
from collections import UserDict, defaultdict, UserList
|
2
|
-
from typing import Union
|
2
|
+
from typing import Union
|
3
3
|
from edsl.enums import InferenceServiceLiteral
|
4
4
|
from dataclasses import dataclass
|
5
5
|
|
6
6
|
|
7
7
|
@dataclass
|
8
8
|
class LanguageModelInfo:
|
9
|
-
"""A dataclass for storing information about a language model.
|
10
|
-
|
11
|
-
|
12
|
-
>>> LanguageModelInfo("gpt-4-1106-preview", "openai")
|
13
|
-
LanguageModelInfo(model_name='gpt-4-1106-preview', service_name='openai')
|
14
|
-
|
15
|
-
>>> model_name, service = LanguageModelInfo.example()
|
16
|
-
>>> model_name
|
17
|
-
'gpt-4-1106-preview'
|
18
|
-
|
19
|
-
>>> LanguageModelInfo.example().service_name
|
20
|
-
'openai'
|
21
|
-
|
22
|
-
"""
|
23
|
-
|
24
9
|
model_name: str
|
25
10
|
service_name: str
|
26
11
|
|
27
|
-
def __iter__(self):
|
28
|
-
yield self.model_name
|
29
|
-
yield self.service_name
|
30
|
-
|
31
12
|
def __getitem__(self, key: int) -> str:
|
32
13
|
import warnings
|
33
14
|
|
@@ -45,18 +26,13 @@ class LanguageModelInfo:
|
|
45
26
|
else:
|
46
27
|
raise IndexError("Index out of range")
|
47
28
|
|
48
|
-
@classmethod
|
49
|
-
def example(cls) -> "LanguageModelInfo":
|
50
|
-
return cls("gpt-4-1106-preview", "openai")
|
51
|
-
|
52
29
|
|
53
30
|
class ModelNamesList(UserList):
|
54
31
|
pass
|
55
32
|
|
56
33
|
|
57
34
|
class AvailableModels(UserList):
|
58
|
-
|
59
|
-
def __init__(self, data: List[LanguageModelInfo]) -> None:
|
35
|
+
def __init__(self, data: list) -> None:
|
60
36
|
super().__init__(data)
|
61
37
|
|
62
38
|
def __contains__(self, model_name: str) -> bool:
|
@@ -65,54 +41,6 @@ class AvailableModels(UserList):
|
|
65
41
|
return True
|
66
42
|
return False
|
67
43
|
|
68
|
-
def print(self):
|
69
|
-
return self.to_dataset().print()
|
70
|
-
|
71
|
-
def to_dataset(self):
|
72
|
-
from edsl.scenarios.ScenarioList import ScenarioList
|
73
|
-
|
74
|
-
models, services = zip(
|
75
|
-
*[(model.model_name, model.service_name) for model in self]
|
76
|
-
)
|
77
|
-
return (
|
78
|
-
ScenarioList.from_list("model", models)
|
79
|
-
.add_list("service", services)
|
80
|
-
.to_dataset()
|
81
|
-
)
|
82
|
-
|
83
|
-
def to_model_list(self):
|
84
|
-
from edsl.language_models.ModelList import ModelList
|
85
|
-
|
86
|
-
return ModelList.from_available_models(self)
|
87
|
-
|
88
|
-
def search(
|
89
|
-
self, pattern: str, service_name: Optional[str] = None, regex: bool = False
|
90
|
-
) -> "AvailableModels":
|
91
|
-
import re
|
92
|
-
|
93
|
-
if not regex:
|
94
|
-
# Escape special regex characters except *
|
95
|
-
pattern = re.escape(pattern).replace(r"\*", ".*")
|
96
|
-
|
97
|
-
try:
|
98
|
-
regex = re.compile(pattern)
|
99
|
-
avm = AvailableModels(
|
100
|
-
[
|
101
|
-
entry
|
102
|
-
for entry in self
|
103
|
-
if regex.search(entry.model_name)
|
104
|
-
and (service_name is None or entry.service_name == service_name)
|
105
|
-
]
|
106
|
-
)
|
107
|
-
if len(avm) == 0:
|
108
|
-
raise ValueError(
|
109
|
-
"No models found matching the search pattern: " + pattern
|
110
|
-
)
|
111
|
-
else:
|
112
|
-
return avm
|
113
|
-
except re.error as e:
|
114
|
-
raise ValueError(f"Invalid regular expression pattern: {e}")
|
115
|
-
|
116
44
|
|
117
45
|
class ServiceToModelsMapping(UserDict):
|
118
46
|
def __init__(self, data: dict) -> None:
|
@@ -5,8 +5,6 @@ from typing import Union, Type, Callable, TYPE_CHECKING
|
|
5
5
|
|
6
6
|
if TYPE_CHECKING:
|
7
7
|
from edsl.questions.QuestionBase import QuestionBase
|
8
|
-
from edsl.jobs.interviews.Interview import Interview
|
9
|
-
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
10
8
|
|
11
9
|
from edsl.surveys.base import EndOfSurvey
|
12
10
|
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
@@ -19,97 +17,34 @@ from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
|
19
17
|
from edsl.jobs.Answers import Answers
|
20
18
|
|
21
19
|
|
22
|
-
class
|
23
|
-
|
24
|
-
|
25
|
-
EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
|
26
|
-
EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
|
27
|
-
EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
|
28
|
-
|
29
|
-
|
30
|
-
class SkipHandler:
|
31
|
-
|
32
|
-
def __init__(self, interview: "Interview"):
|
20
|
+
class AnswerQuestionFunctionConstructor:
|
21
|
+
def __init__(self, interview):
|
33
22
|
self.interview = interview
|
23
|
+
self.had_language_model_no_response_error = False
|
34
24
|
self.question_index = self.interview.to_index
|
35
25
|
|
36
26
|
self.skip_function: Callable = (
|
37
27
|
self.interview.survey.rule_collection.skip_question_before_running
|
38
28
|
)
|
39
29
|
|
40
|
-
def
|
41
|
-
|
42
|
-
current_question_index = self.question_index[current_question.question_name]
|
43
|
-
combined_answers = (
|
44
|
-
self.interview.answers
|
45
|
-
| self.interview.scenario
|
46
|
-
| self.interview.agent["traits"]
|
47
|
-
)
|
48
|
-
return self.skip_function(current_question_index, combined_answers)
|
49
|
-
|
50
|
-
def cancel_skipped_questions(self, current_question: "QuestionBase") -> None:
|
51
|
-
"""Cancel the tasks for questions that should be skipped."""
|
52
|
-
current_question_index: int = self.interview.to_index[
|
53
|
-
current_question.question_name
|
54
|
-
]
|
55
|
-
answers = (
|
56
|
-
self.interview.answers
|
57
|
-
| self.interview.scenario
|
58
|
-
| self.interview.agent["traits"]
|
59
|
-
)
|
60
|
-
|
61
|
-
# Get the index of the next question, which could also be the end of the survey
|
62
|
-
next_question: Union[int, EndOfSurvey] = (
|
63
|
-
self.interview.survey.rule_collection.next_question(
|
64
|
-
q_now=current_question_index,
|
65
|
-
answers=answers,
|
66
|
-
)
|
67
|
-
)
|
30
|
+
def _combined_answers(self) -> Answers:
|
31
|
+
return self.answers | self.interview.scenario | self.interview.agent["traits"]
|
68
32
|
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
self.interview.tasks[i].cancel()
|
33
|
+
@property
|
34
|
+
def answers(self) -> Answers:
|
35
|
+
return self.interview.answers
|
73
36
|
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
return
|
79
|
-
|
80
|
-
if next_question_index > (current_question_index + 1):
|
81
|
-
cancel_between(current_question_index + 1, next_question_index)
|
82
|
-
|
83
|
-
|
84
|
-
class AnswerQuestionFunctionConstructor:
|
85
|
-
"""Constructs a function that answers a question and records the answer."""
|
86
|
-
|
87
|
-
def __init__(self, interview: "Interview", key_lookup: "KeyLookup"):
|
88
|
-
self.interview = interview
|
89
|
-
self.key_lookup = key_lookup
|
90
|
-
|
91
|
-
self.had_language_model_no_response_error: bool = False
|
92
|
-
self.question_index = self.interview.to_index
|
93
|
-
|
94
|
-
self.skip_function: Callable = (
|
95
|
-
self.interview.survey.rule_collection.skip_question_before_running
|
96
|
-
)
|
97
|
-
|
98
|
-
self.invigilator_fetcher = FetchInvigilator(
|
99
|
-
self.interview, key_lookup=self.key_lookup
|
100
|
-
)
|
101
|
-
self.skip_handler = SkipHandler(self.interview)
|
37
|
+
def _skip_this_question(self, current_question: "QuestionBase") -> bool:
|
38
|
+
current_question_index = self.question_index[current_question.question_name]
|
39
|
+
combined_answers = self._combined_answers()
|
40
|
+
return self.skip_function(current_question_index, combined_answers)
|
102
41
|
|
103
42
|
def _handle_exception(
|
104
43
|
self, e: Exception, invigilator: "InvigilatorBase", task=None
|
105
44
|
):
|
106
|
-
"""Handle an exception that occurred while answering a question."""
|
107
|
-
|
108
45
|
from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
|
109
46
|
|
110
|
-
answers = copy.copy(
|
111
|
-
self.interview.answers
|
112
|
-
) # copy to freeze the answers here for logging
|
47
|
+
answers = copy.copy(self.answers) # copy to freeze the answers here for logging
|
113
48
|
exception_entry = InterviewExceptionEntry(
|
114
49
|
exception=e,
|
115
50
|
invigilator=invigilator,
|
@@ -117,7 +52,6 @@ class AnswerQuestionFunctionConstructor:
|
|
117
52
|
)
|
118
53
|
if task:
|
119
54
|
task.task_status = TaskStatus.FAILED
|
120
|
-
|
121
55
|
self.interview.exceptions.add(
|
122
56
|
invigilator.question.question_name, exception_entry
|
123
57
|
)
|
@@ -131,15 +65,41 @@ class AnswerQuestionFunctionConstructor:
|
|
131
65
|
if stop_on_exception:
|
132
66
|
raise e
|
133
67
|
|
68
|
+
def _cancel_skipped_questions(self, current_question: "QuestionBase") -> None:
|
69
|
+
current_question_index: int = self.interview.to_index[
|
70
|
+
current_question.question_name
|
71
|
+
]
|
72
|
+
answers = (
|
73
|
+
self.answers | self.interview.scenario | self.interview.agent["traits"]
|
74
|
+
)
|
75
|
+
|
76
|
+
# Get the index of the next question, which could also be the end of the survey
|
77
|
+
next_question: Union[
|
78
|
+
int, EndOfSurvey
|
79
|
+
] = self.interview.survey.rule_collection.next_question(
|
80
|
+
q_now=current_question_index,
|
81
|
+
answers=answers,
|
82
|
+
)
|
83
|
+
|
84
|
+
def cancel_between(start, end):
|
85
|
+
for i in range(start, end):
|
86
|
+
self.interview.tasks[i].cancel()
|
87
|
+
|
88
|
+
if (next_question_index := next_question.next_q) == EndOfSurvey:
|
89
|
+
cancel_between(
|
90
|
+
current_question_index + 1, len(self.interview.survey.questions)
|
91
|
+
)
|
92
|
+
return
|
93
|
+
|
94
|
+
if next_question_index > (current_question_index + 1):
|
95
|
+
cancel_between(current_question_index + 1, next_question_index)
|
96
|
+
|
134
97
|
def __call__(self):
|
135
|
-
|
98
|
+
from edsl.config import CONFIG
|
136
99
|
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
question: "QuestionBase",
|
141
|
-
task=None,
|
142
|
-
) -> "AgentResponseDict":
|
100
|
+
EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
|
101
|
+
EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
|
102
|
+
EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
|
143
103
|
|
144
104
|
from tenacity import (
|
145
105
|
retry,
|
@@ -149,75 +109,80 @@ class AnswerQuestionFunctionConstructor:
|
|
149
109
|
RetryError,
|
150
110
|
)
|
151
111
|
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
112
|
+
async def answer_question_and_record_task(
|
113
|
+
*,
|
114
|
+
question: "QuestionBase",
|
115
|
+
task=None,
|
116
|
+
) -> "AgentResponseDict":
|
117
|
+
@retry(
|
118
|
+
stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
|
119
|
+
wait=wait_exponential(
|
120
|
+
multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
|
121
|
+
),
|
122
|
+
retry=retry_if_exception_type(LanguageModelNoResponseError),
|
123
|
+
reraise=True,
|
124
|
+
)
|
125
|
+
async def attempt_answer():
|
126
|
+
invigilator = FetchInvigilator(self.interview)(question)
|
163
127
|
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
128
|
+
if self._skip_this_question(question):
|
129
|
+
return invigilator.get_failed_task_result(
|
130
|
+
failure_reason="Question skipped."
|
131
|
+
)
|
168
132
|
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
133
|
+
try:
|
134
|
+
response: EDSLResultObjectInput = (
|
135
|
+
await invigilator.async_answer_question()
|
136
|
+
)
|
137
|
+
if response.validated:
|
138
|
+
self.answers.add_answer(response=response, question=question)
|
139
|
+
self._cancel_skipped_questions(question)
|
140
|
+
else:
|
141
|
+
if (
|
142
|
+
hasattr(response, "exception_occurred")
|
143
|
+
and response.exception_occurred
|
144
|
+
):
|
145
|
+
raise response.exception_occurred
|
146
|
+
|
147
|
+
except QuestionAnswerValidationError as e:
|
148
|
+
self._handle_exception(e, invigilator, task)
|
149
|
+
return invigilator.get_failed_task_result(
|
150
|
+
failure_reason="Question answer validation failed."
|
176
151
|
)
|
177
152
|
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
)
|
184
|
-
raise response.exception_occurred
|
185
|
-
|
186
|
-
except QuestionAnswerValidationError as e:
|
187
|
-
self._handle_exception(e, invigilator, task)
|
188
|
-
return invigilator.get_failed_task_result(
|
189
|
-
failure_reason="Question answer validation failed."
|
190
|
-
)
|
153
|
+
except asyncio.TimeoutError as e:
|
154
|
+
self._handle_exception(e, invigilator, task)
|
155
|
+
had_language_model_no_response_error = True
|
156
|
+
raise LanguageModelNoResponseError(
|
157
|
+
f"Language model timed out for question '{question.question_name}.'"
|
158
|
+
)
|
191
159
|
|
192
|
-
|
193
|
-
|
194
|
-
had_language_model_no_response_error = True
|
195
|
-
raise LanguageModelNoResponseError(
|
196
|
-
f"Language model timed out for question '{question.question_name}.'"
|
197
|
-
)
|
160
|
+
except Exception as e:
|
161
|
+
self._handle_exception(e, invigilator, task)
|
198
162
|
|
199
|
-
|
200
|
-
|
163
|
+
if "response" not in locals():
|
164
|
+
had_language_model_no_response_error = True
|
165
|
+
raise LanguageModelNoResponseError(
|
166
|
+
f"Language model did not return a response for question '{question.question_name}.'"
|
167
|
+
)
|
201
168
|
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
169
|
+
if (
|
170
|
+
question.question_name in self.interview.exceptions
|
171
|
+
and had_language_model_no_response_error
|
172
|
+
):
|
173
|
+
self.interview.exceptions.record_fixed_question(
|
174
|
+
question.question_name
|
175
|
+
)
|
207
176
|
|
208
|
-
|
209
|
-
question.question_name in self.interview.exceptions
|
210
|
-
and had_language_model_no_response_error
|
211
|
-
):
|
212
|
-
self.interview.exceptions.record_fixed_question(question.question_name)
|
177
|
+
return response
|
213
178
|
|
214
|
-
|
179
|
+
try:
|
180
|
+
return await attempt_answer()
|
181
|
+
except RetryError as retry_error:
|
182
|
+
original_error = retry_error.last_attempt.exception()
|
183
|
+
self._handle_exception(
|
184
|
+
original_error, FetchInvigilator(self.interview)(question), task
|
185
|
+
)
|
186
|
+
raise original_error
|
215
187
|
|
216
|
-
|
217
|
-
return await attempt_answer()
|
218
|
-
except RetryError as retry_error:
|
219
|
-
original_error = retry_error.last_attempt.exception()
|
220
|
-
self._handle_exception(
|
221
|
-
original_error, self.invigilator_fetcher(question), task
|
222
|
-
)
|
223
|
-
raise original_error
|
188
|
+
return answer_question_and_record_task
|
edsl/jobs/FetchInvigilator.py
CHANGED
@@ -3,23 +3,15 @@ from typing import List, Dict, Any, Optional, TYPE_CHECKING
|
|
3
3
|
if TYPE_CHECKING:
|
4
4
|
from edsl.questions.QuestionBase import QuestionBase
|
5
5
|
from edsl.agents.InvigilatorBase import InvigilatorBase
|
6
|
-
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
7
|
-
from edsl.jobs.interviews.Interview import Interview
|
8
6
|
|
9
7
|
|
10
8
|
class FetchInvigilator:
|
11
|
-
def __init__(
|
12
|
-
self,
|
13
|
-
interview: "Interview",
|
14
|
-
current_answers: Optional[Dict[str, Any]] = None,
|
15
|
-
key_lookup: Optional["KeyLookup"] = None,
|
16
|
-
):
|
9
|
+
def __init__(self, interview, current_answers: Optional[Dict[str, Any]] = None):
|
17
10
|
self.interview = interview
|
18
11
|
if current_answers is None:
|
19
12
|
self.current_answers = self.interview.answers
|
20
13
|
else:
|
21
14
|
self.current_answers = current_answers
|
22
|
-
self.key_lookup = key_lookup
|
23
15
|
|
24
16
|
def get_invigilator(self, question: "QuestionBase") -> "InvigilatorBase":
|
25
17
|
"""Return an invigilator for the given question.
|
@@ -32,13 +24,14 @@ class FetchInvigilator:
|
|
32
24
|
question=question,
|
33
25
|
scenario=self.interview.scenario,
|
34
26
|
model=self.interview.model,
|
27
|
+
# debug=False,
|
35
28
|
survey=self.interview.survey,
|
36
29
|
memory_plan=self.interview.survey.memory_plan,
|
37
30
|
current_answers=self.current_answers, # not yet known
|
38
31
|
iteration=self.interview.iteration,
|
39
32
|
cache=self.interview.cache,
|
33
|
+
# sidecar_model=self.interview.sidecar_model,
|
40
34
|
raise_validation_errors=self.interview.raise_validation_errors,
|
41
|
-
key_lookup=self.key_lookup,
|
42
35
|
)
|
43
36
|
"""Return an invigilator for the given question."""
|
44
37
|
return invigilator
|
@@ -6,9 +6,8 @@ if TYPE_CHECKING:
|
|
6
6
|
|
7
7
|
|
8
8
|
class InterviewsConstructor:
|
9
|
-
def __init__(self, jobs
|
9
|
+
def __init__(self, jobs):
|
10
10
|
self.jobs = jobs
|
11
|
-
self.cache = cache
|
12
11
|
|
13
12
|
def create_interviews(self) -> Generator["Interview", None, None]:
|
14
13
|
"""
|
@@ -35,13 +34,12 @@ class InterviewsConstructor:
|
|
35
34
|
self.jobs.agents, self.jobs.scenarios, self.jobs.models
|
36
35
|
):
|
37
36
|
yield Interview(
|
38
|
-
survey=self.jobs.survey
|
37
|
+
survey=self.jobs.survey,
|
39
38
|
agent=agent,
|
40
39
|
scenario=scenario,
|
41
40
|
model=model,
|
42
|
-
|
43
|
-
|
44
|
-
raise_validation_errors=self.jobs.run_config.parameters.raise_validation_errors,
|
41
|
+
skip_retry=self.jobs.skip_retry,
|
42
|
+
raise_validation_errors=self.jobs.raise_validation_errors,
|
45
43
|
indices={
|
46
44
|
"agent": agent_index[hash(agent)],
|
47
45
|
"model": model_index[hash(model)],
|