edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +28 -0
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -16
- edsl/agents/Invigilator.py +13 -14
- edsl/agents/InvigilatorBase.py +4 -1
- edsl/agents/PromptConstructor.py +42 -22
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/coop/coop.py +21 -5
- edsl/data/Cache.py +29 -18
- edsl/data/CacheHandler.py +0 -2
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/hack.py +10 -0
- edsl/enums.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -16
- edsl/inference_services/AvailableModelFetcher.py +7 -1
- edsl/inference_services/GoogleService.py +5 -1
- edsl/inference_services/InferenceServicesCollection.py +18 -2
- edsl/inference_services/OpenAIService.py +46 -31
- edsl/inference_services/TestService.py +1 -3
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +74 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
- edsl/jobs/FetchInvigilator.py +10 -3
- edsl/jobs/InterviewsConstructor.py +6 -4
- edsl/jobs/Jobs.py +299 -233
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/interviews/Interview.py +80 -42
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
- edsl/jobs/runners/JobsRunnerStatus.py +131 -164
- edsl/jobs/tasks/TaskHistory.py +24 -3
- edsl/language_models/LanguageModel.py +59 -4
- edsl/language_models/ModelList.py +19 -8
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +35 -26
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +9 -15
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/questions/data_structures.py +20 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
- edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
- edsl/results/DatasetExportMixin.py +60 -119
- edsl/results/Result.py +109 -3
- edsl/results/Results.py +50 -39
- edsl/results/file_exports.py +252 -0
- edsl/scenarios/ScenarioList.py +35 -7
- edsl/surveys/Survey.py +71 -20
- edsl/test_h +1 -0
- edsl/utilities/gcp_bucket/example.py +50 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
- edsl/language_models/registry.py +0 -180
- /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
- /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
- /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{Selector.py → results_selector.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
- /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
- /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
@@ -0,0 +1,120 @@
|
|
1
|
+
from typing import Optional, Literal
|
2
|
+
from dataclasses import dataclass, asdict
|
3
|
+
|
4
|
+
# from edsl.data_transfer_models import VisibilityType
|
5
|
+
from edsl.data.Cache import Cache
|
6
|
+
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
7
|
+
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
8
|
+
from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
|
9
|
+
|
10
|
+
# Allowed visibility values; used as the type of
# RunParameters.remote_inference_results_visibility.
VisibilityType = Literal["private", "public", "unlisted"]
|
11
|
+
from edsl.Base import Base
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass
class RunEnvironment:
    """Runtime resources that a job run may draw on.

    Every field is optional and defaults to ``None``. ``RunConfig`` exposes
    ``add_*`` helpers that fill these in after construction.
    """

    # Cache object handed down to interviews (may be None).
    cache: Optional[Cache] = None
    # Bucket collection queried per model; when absent, interviews fall back
    # to an unlimited ("infinity") bucket.
    bucket_collection: Optional[BucketCollection] = None
    # Key lookup forwarded to answer-function construction and invigilator fetching.
    key_lookup: Optional[KeyLookup] = None
    # Status tracker for the running job (presumably for progress reporting).
    jobs_runner_status: Optional[JobsRunnerStatus] = None
|
20
|
+
|
21
|
+
|
22
|
+
@dataclass
class RunParameters(Base):
    """Tunable, serializable parameters that control how a job is run.

    All fields have defaults, so ``RunParameters()`` is a valid configuration.
    """

    n: int = 1
    progress_bar: bool = False
    # When True, interview tasks are gathered without return_exceptions, so the
    # first exception propagates.
    stop_on_exception: bool = False
    check_api_keys: bool = False
    verbose: bool = True
    print_exceptions: bool = True
    remote_cache_description: Optional[str] = None
    remote_inference_description: Optional[str] = None
    remote_inference_results_visibility: Optional[VisibilityType] = "unlisted"
    skip_retry: bool = False
    raise_validation_errors: bool = False
    disable_remote_cache: bool = False
    disable_remote_inference: bool = False
    job_uuid: Optional[str] = None

    # Keys added by to_dict(add_edsl_version=True) that are not dataclass fields.
    _SERIALIZATION_METADATA_KEYS = ("edsl_version", "edsl_class_name")

    def to_dict(self, add_edsl_version=False) -> dict:
        """Return the parameters as a plain dict.

        :param add_edsl_version: if True, also record the installed edsl version
            and this object's class name (``"RunParameters"``).
        """
        d = asdict(self)
        if add_edsl_version:
            from edsl import __version__

            d["edsl_version"] = __version__
            # Must name *this* class so the serialized dict is attributed to
            # RunParameters rather than to the RunConfig wrapper.
            d["edsl_class_name"] = "RunParameters"
        return d

    @classmethod
    def from_dict(cls, data: dict) -> "RunParameters":
        """Rebuild a RunParameters from the output of :meth:`to_dict`.

        Serialization metadata (``edsl_version``, ``edsl_class_name``) is dropped
        first; passing it to the constructor would raise TypeError.
        """
        field_values = {
            key: value
            for key, value in data.items()
            if key not in cls._SERIALIZATION_METADATA_KEYS
        }
        return cls(**field_values)

    def code(self) -> str:
        """Return Python source that reconstructs these parameters."""
        return f"RunParameters(**{self.to_dict()})"

    @classmethod
    def example(cls) -> "RunParameters":
        """Return an example instance (all defaults)."""
        return cls()
|
58
|
+
|
59
|
+
|
60
|
+
@dataclass
class RunConfig:
    """Bundle of a job's runtime environment and its run parameters.

    The ``add_*`` helpers mutate this config in place and return None.
    """

    environment: RunEnvironment
    parameters: RunParameters

    def add_environment(self, environment: RunEnvironment):
        """Swap in an entirely new environment object."""
        self.environment = environment

    def add_bucket_collection(self, bucket_collection: BucketCollection):
        """Attach a bucket collection to the current environment."""
        env = self.environment
        env.bucket_collection = bucket_collection

    def add_cache(self, cache: Cache):
        """Attach a cache to the current environment."""
        env = self.environment
        env.cache = cache

    def add_key_lookup(self, key_lookup: KeyLookup):
        """Attach a key lookup to the current environment."""
        env = self.environment
        env.key_lookup = key_lookup
|
76
|
+
|
77
|
+
|
78
|
+
"""This module contains the Answers class, which is a helper class to hold the answers to a survey."""
|
79
|
+
|
80
|
+
from collections import UserDict
|
81
|
+
from edsl.data_transfer_models import EDSLResultObjectInput
|
82
|
+
|
83
|
+
|
84
|
+
class Answers(UserDict):
    """Dictionary of survey answers keyed by question name.

    Alongside each answer, two optional companion keys may be stored:
    ``<question_name>_generated_tokens`` and ``<question_name>_comment``.
    """

    def add_answer(
        self, response: EDSLResultObjectInput, question: "QuestionBase"
    ) -> None:
        """Record ``response`` under the name of ``question``.

        The generated-tokens entry and the comment entry are only written when
        the corresponding response value is truthy.
        """
        answer, comment, tokens = (
            response.answer,
            response.comment,
            response.generated_tokens,
        )
        name = question.question_name

        # Insertion order: generated tokens (if any), then the answer, then the comment.
        if tokens:
            self[name + "_generated_tokens"] = tokens
        self[name] = answer
        if comment:
            self[name + "_comment"] = comment

    def replace_missing_answers_with_none(self, survey: "Survey") -> None:
        """Fill in ``None`` for every survey question that has no answer yet.

        Answers can be missing when the agent skips a question.
        """
        for name in survey.question_names:
            self.setdefault(name, None)

    def to_dict(self):
        """Return the underlying answers dict (the live object, not a copy)."""
        return self.data

    @classmethod
    def from_dict(cls, d):
        """Build an Answers mapping from a plain dict."""
        return cls(d)
|
115
|
+
|
116
|
+
|
117
|
+
if __name__ == "__main__":
    # Running the module directly executes its doctests.
    from doctest import testmod

    testmod()
|
@@ -4,10 +4,10 @@ from __future__ import annotations
|
|
4
4
|
import asyncio
|
5
5
|
from typing import Any, Type, List, Generator, Optional, Union, TYPE_CHECKING
|
6
6
|
import copy
|
7
|
+
from dataclasses import dataclass
|
7
8
|
|
8
|
-
# from edsl.
|
9
|
-
|
10
|
-
from edsl.jobs.Answers import Answers
|
9
|
+
# from edsl.jobs.Answers import Answers
|
10
|
+
from edsl.jobs.data_structures import Answers
|
11
11
|
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
12
12
|
from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
|
13
13
|
from edsl.jobs.interviews.InterviewExceptionCollection import (
|
@@ -22,6 +22,7 @@ from edsl.jobs.InterviewTaskManager import InterviewTaskManager
|
|
22
22
|
from edsl.jobs.FetchInvigilator import FetchInvigilator
|
23
23
|
from edsl.jobs.RequestTokenEstimator import RequestTokenEstimator
|
24
24
|
|
25
|
+
|
25
26
|
if TYPE_CHECKING:
|
26
27
|
from edsl.agents.Agent import Agent
|
27
28
|
from edsl.surveys.Survey import Survey
|
@@ -29,6 +30,16 @@ if TYPE_CHECKING:
|
|
29
30
|
from edsl.data.Cache import Cache
|
30
31
|
from edsl.language_models.LanguageModel import LanguageModel
|
31
32
|
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
33
|
+
from edsl.agents.InvigilatorBase import InvigilatorBase
|
34
|
+
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
35
|
+
|
36
|
+
|
37
|
+
@dataclass
class InterviewRunningConfig:
    """Execution settings carried by a single Interview.

    Defaults are plain scalars. A trailing comma (``= (False,)``) would make a
    default a one-element tuple, which is always truthy. That matters because
    ``Interview.__init__`` does not pass ``stop_on_exception``, so the default
    below is exactly what it ends up with.
    """

    cache: Optional["Cache"] = None
    skip_retry: bool = False  # COULD BE SET WITH CONFIG
    raise_validation_errors: bool = True
    stop_on_exception: bool = False
|
32
43
|
|
33
44
|
|
34
45
|
class Interview:
|
@@ -45,13 +56,11 @@ class Interview:
|
|
45
56
|
survey: Survey,
|
46
57
|
scenario: Scenario,
|
47
58
|
model: Type["LanguageModel"],
|
48
|
-
debug: Optional[bool] = False, # DEPRECATE
|
49
59
|
iteration: int = 0,
|
60
|
+
indices: dict = None, # explain?
|
50
61
|
cache: Optional["Cache"] = None,
|
51
|
-
sidecar_model: Optional["LanguageModel"] = None, # DEPRECATE
|
52
62
|
skip_retry: bool = False, # COULD BE SET WITH CONFIG
|
53
63
|
raise_validation_errors: bool = True,
|
54
|
-
indices: dict = None, # explain?
|
55
64
|
):
|
56
65
|
"""Initialize the Interview instance.
|
57
66
|
|
@@ -59,10 +68,9 @@ class Interview:
|
|
59
68
|
:param survey: the survey being administered to the agent.
|
60
69
|
:param scenario: the scenario that populates the survey questions.
|
61
70
|
:param model: the language model used to answer the questions.
|
62
|
-
:param debug: if True, run without calls to the language model.
|
71
|
+
# :param debug: if True, run without calls to the language model.
|
63
72
|
:param iteration: the iteration number of the interview.
|
64
73
|
:param cache: the cache used to store the answers.
|
65
|
-
:param sidecar_model: a sidecar model used to answer questions.
|
66
74
|
|
67
75
|
>>> i = Interview.example()
|
68
76
|
>>> i.task_manager.task_creators
|
@@ -83,12 +91,9 @@ class Interview:
|
|
83
91
|
self.survey = copy.deepcopy(survey) # why do we need to deepcopy the survey?
|
84
92
|
self.scenario = scenario
|
85
93
|
self.model = model
|
86
|
-
self.debug = debug
|
87
94
|
self.iteration = iteration
|
88
|
-
self.cache = cache
|
89
95
|
|
90
96
|
self.answers = Answers() # will get filled in as interview progresses
|
91
|
-
self.sidecar_model = sidecar_model
|
92
97
|
|
93
98
|
self.task_manager = InterviewTaskManager(
|
94
99
|
survey=self.survey,
|
@@ -97,6 +102,13 @@ class Interview:
|
|
97
102
|
|
98
103
|
self.exceptions = InterviewExceptionCollection()
|
99
104
|
|
105
|
+
self.running_config = InterviewRunningConfig(
|
106
|
+
cache=cache,
|
107
|
+
skip_retry=skip_retry,
|
108
|
+
raise_validation_errors=raise_validation_errors,
|
109
|
+
)
|
110
|
+
|
111
|
+
self.cache = cache
|
100
112
|
self.skip_retry = skip_retry
|
101
113
|
self.raise_validation_errors = raise_validation_errors
|
102
114
|
|
@@ -109,6 +121,7 @@ class Interview:
|
|
109
121
|
self.failed_questions = []
|
110
122
|
|
111
123
|
self.indices = indices
|
124
|
+
self.initial_hash = hash(self)
|
112
125
|
|
113
126
|
@property
|
114
127
|
def has_exceptions(self) -> bool:
|
@@ -134,7 +147,6 @@ class Interview:
|
|
134
147
|
# return self.task_creators.interview_status
|
135
148
|
return self.task_manager.interview_status
|
136
149
|
|
137
|
-
# region: Serialization
|
138
150
|
def to_dict(self, include_exceptions=True, add_edsl_version=True) -> dict[str, Any]:
|
139
151
|
"""Return a dictionary representation of the Interview instance.
|
140
152
|
This is just for hashing purposes.
|
@@ -198,13 +210,13 @@ class Interview:
|
|
198
210
|
"""
|
199
211
|
return hash(self) == hash(other)
|
200
212
|
|
201
|
-
# region: Conducting the interview
|
202
213
|
async def async_conduct_interview(
|
203
214
|
self,
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
raise_validation_errors: bool = True,
|
215
|
+
run_config: Optional["RunConfig"] = None,
|
216
|
+
# model_buckets: Optional[ModelBuckets] = None,
|
217
|
+
# stop_on_exception: bool = False,
|
218
|
+
# raise_validation_errors: bool = True,
|
219
|
+
# key_lookup: Optional[KeyLookup] = None,
|
208
220
|
) -> tuple["Answers", List[dict[str, Any]]]:
|
209
221
|
"""
|
210
222
|
Conduct an Interview asynchronously.
|
@@ -213,7 +225,6 @@ class Interview:
|
|
213
225
|
:param model_buckets: a dictionary of token buckets for the model.
|
214
226
|
:param debug: run without calls to LLM.
|
215
227
|
:param stop_on_exception: if True, stops the interview if an exception is raised.
|
216
|
-
:param sidecar_model: a sidecar model used to answer questions.
|
217
228
|
|
218
229
|
Example usage:
|
219
230
|
|
@@ -227,21 +238,39 @@ class Interview:
|
|
227
238
|
>>> i.exceptions
|
228
239
|
{'q0': ...
|
229
240
|
>>> i = Interview.example()
|
230
|
-
>>>
|
241
|
+
>>> from edsl.jobs.Jobs import RunConfig, RunParameters, RunEnvironment
|
242
|
+
>>> run_config = RunConfig(parameters = RunParameters(), environment = RunEnvironment())
|
243
|
+
>>> run_config.parameters.stop_on_exception = True
|
244
|
+
>>> result, _ = asyncio.run(i.async_conduct_interview(run_config))
|
231
245
|
Traceback (most recent call last):
|
232
246
|
...
|
233
247
|
asyncio.exceptions.CancelledError
|
234
248
|
"""
|
235
|
-
|
236
|
-
|
249
|
+
from edsl.jobs.Jobs import RunConfig, RunParameters, RunEnvironment
|
250
|
+
|
251
|
+
if run_config is None:
|
252
|
+
run_config = RunConfig(
|
253
|
+
parameters=RunParameters(),
|
254
|
+
environment=RunEnvironment(),
|
255
|
+
)
|
256
|
+
self.stop_on_exception = run_config.parameters.stop_on_exception
|
237
257
|
|
238
258
|
# if no model bucket is passed, create an 'infinity' bucket with no rate limits
|
259
|
+
bucket_collection = run_config.environment.bucket_collection
|
260
|
+
|
261
|
+
if bucket_collection:
|
262
|
+
model_buckets = bucket_collection.get(self.model)
|
263
|
+
else:
|
264
|
+
model_buckets = None
|
265
|
+
|
239
266
|
if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
|
240
267
|
model_buckets = ModelBuckets.infinity_bucket()
|
241
268
|
|
242
269
|
# was "self.tasks" - is that necessary?
|
243
270
|
self.tasks = self.task_manager.build_question_tasks(
|
244
|
-
answer_func=AnswerQuestionFunctionConstructor(
|
271
|
+
answer_func=AnswerQuestionFunctionConstructor(
|
272
|
+
self, key_lookup=run_config.environment.key_lookup
|
273
|
+
)(),
|
245
274
|
token_estimator=RequestTokenEstimator(self),
|
246
275
|
model_buckets=model_buckets,
|
247
276
|
)
|
@@ -250,23 +279,26 @@ class Interview:
|
|
250
279
|
## with dependencies on the questions that must be answered before this one can be answered.
|
251
280
|
|
252
281
|
## 'Invigilators' are used to administer the survey.
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
282
|
+
fetcher = FetchInvigilator(
|
283
|
+
interview=self,
|
284
|
+
current_answers=self.answers,
|
285
|
+
key_lookup=run_config.environment.key_lookup,
|
286
|
+
)
|
287
|
+
self.invigilators = [fetcher(question) for question in self.survey.questions]
|
288
|
+
await asyncio.gather(
|
289
|
+
*self.tasks, return_exceptions=not run_config.parameters.stop_on_exception
|
290
|
+
)
|
258
291
|
self.answers.replace_missing_answers_with_none(self.survey)
|
259
292
|
valid_results = list(
|
260
293
|
self._extract_valid_results(self.tasks, self.invigilators, self.exceptions)
|
261
294
|
)
|
262
295
|
return self.answers, valid_results
|
263
296
|
|
264
|
-
# endregion
|
265
|
-
|
266
|
-
# region: Extracting results and recording errors
|
267
297
|
@staticmethod
|
268
298
|
def _extract_valid_results(
|
269
|
-
tasks
|
299
|
+
tasks: List["asyncio.Task"],
|
300
|
+
invigilators: List["InvigilatorBase"],
|
301
|
+
exceptions: InterviewExceptionCollection,
|
270
302
|
) -> Generator["Answers", None, None]:
|
271
303
|
"""Extract the valid results from the list of results.
|
272
304
|
|
@@ -279,10 +311,7 @@ class Interview:
|
|
279
311
|
"""
|
280
312
|
assert len(tasks) == len(invigilators)
|
281
313
|
|
282
|
-
|
283
|
-
if not task.done():
|
284
|
-
raise ValueError(f"Task {task.get_name()} is not done.")
|
285
|
-
|
314
|
+
def handle_task(task, invigilator):
|
286
315
|
try:
|
287
316
|
result = task.result()
|
288
317
|
except asyncio.CancelledError as e: # task was cancelled
|
@@ -298,17 +327,21 @@ class Interview:
|
|
298
327
|
invigilator=invigilator,
|
299
328
|
)
|
300
329
|
exceptions.add(task.get_name(), exception_entry)
|
330
|
+
return result
|
301
331
|
|
302
|
-
|
332
|
+
for task, invigilator in zip(tasks, invigilators):
|
333
|
+
if not task.done():
|
334
|
+
raise ValueError(f"Task {task.get_name()} is not done.")
|
303
335
|
|
304
|
-
|
336
|
+
yield handle_task(task, invigilator)
|
305
337
|
|
306
|
-
# region: Magic methods
|
307
338
|
def __repr__(self) -> str:
|
308
339
|
"""Return a string representation of the Interview instance."""
|
309
340
|
return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
|
310
341
|
|
311
|
-
def duplicate(
|
342
|
+
def duplicate(
|
343
|
+
self, iteration: int, cache: "Cache", randomize_survey: Optional[bool] = True
|
344
|
+
) -> Interview:
|
312
345
|
"""Duplicate the interview, but with a new iteration number and cache.
|
313
346
|
|
314
347
|
>>> i = Interview.example()
|
@@ -317,14 +350,19 @@ class Interview:
|
|
317
350
|
True
|
318
351
|
|
319
352
|
"""
|
353
|
+
if randomize_survey:
|
354
|
+
new_survey = self.survey.draw()
|
355
|
+
else:
|
356
|
+
new_survey = self.survey
|
357
|
+
|
320
358
|
return Interview(
|
321
359
|
agent=self.agent,
|
322
|
-
survey=
|
360
|
+
survey=new_survey,
|
323
361
|
scenario=self.scenario,
|
324
362
|
model=self.model,
|
325
363
|
iteration=iteration,
|
326
|
-
cache=cache,
|
327
|
-
skip_retry=self.skip_retry,
|
364
|
+
cache=self.running_config.cache,
|
365
|
+
skip_retry=self.running_config.skip_retry,
|
328
366
|
indices=self.indices,
|
329
367
|
)
|
330
368
|
|
@@ -0,0 +1,98 @@
|
|
1
|
+
from typing import Optional, TYPE_CHECKING, Protocol
|
2
|
+
import sys
|
3
|
+
from edsl.scenarios.FileStore import HTMLFileStore
|
4
|
+
from edsl.config import CONFIG
|
5
|
+
from edsl.coop.coop import Coop
|
6
|
+
|
7
|
+
|
8
|
+
class ResultsProtocol(Protocol):
    """Structural interface a Results object needs for exception reporting."""

    @property
    def has_unfixed_exceptions(self) -> bool:
        """Truthy when exception reporting should be considered."""

    @property
    def task_history(self) -> "TaskHistoryProtocol":
        """History object used to summarize and render the exceptions."""
|
16
|
+
|
17
|
+
|
18
|
+
class TaskHistoryProtocol(Protocol):
    """Structural interface a TaskHistory object needs for exception reporting."""

    @property
    def indices(self) -> list:
        """Indices of the interviews in which exceptions were raised."""

    def html(self, cta: str, open_in_browser: bool, return_link: bool) -> str:
        """Render the HTML exception report and return a string (used as a file path)."""
|
25
|
+
|
26
|
+
|
27
|
+
class RunParametersProtocol(Protocol):
    """Structural interface for the run parameters consulted during reporting."""

    @property
    def print_exceptions(self) -> bool:
        """Whether exception messages and reports should be emitted."""
|
32
|
+
|
33
|
+
|
34
|
+
class ResultsExceptionsHandler:
    """Handles exception reporting and display functionality.

    Two settings are resolved once, at construction time:

    * ``open_in_browser`` comes from the ``EDSL_OPEN_EXCEPTION_REPORT_URL``
      CONFIG entry. It must spell a boolean ("True"/"False", any letter case).
    * ``remote_logging`` comes from Coop ``edsl_settings["remote_logging"]``.
      Any failure to obtain it disables remote logging.
    """

    _BROWSER_SETTING_NAME = "EDSL_OPEN_EXCEPTION_REPORT_URL"

    def __init__(
        self, results: "ResultsProtocol", parameters: "RunParametersProtocol"
    ) -> None:
        """Store the inputs and resolve the reporting settings.

        :param results: object exposing ``has_unfixed_exceptions`` and ``task_history``.
        :param parameters: object exposing ``print_exceptions``.
        :raises ValueError: if the browser setting is not a boolean string.
        """
        self.results = results
        self.parameters = parameters

        self.open_in_browser = self._get_browser_setting()
        self.remote_logging = self._get_remote_logging_setting()

    @staticmethod
    def _parse_bool_setting(name: str, value) -> bool:
        """Interpret a configuration value as a boolean.

        Accepts "True"/"False" in any letter case, ignoring surrounding
        whitespace (and real bools, via ``str``). Anything else, including
        ``None``, is rejected.

        :raises ValueError: if ``value`` does not spell a boolean.
        """
        normalized = str(value).strip().lower()
        if normalized == "true":
            return True
        if normalized == "false":
            return False
        raise ValueError(f"{name} must be either True or False, got {value!r}")

    def _get_browser_setting(self) -> bool:
        """Determine if exceptions should be opened in browser based on config."""
        name = self._BROWSER_SETTING_NAME
        return self._parse_bool_setting(name, CONFIG.get(name))

    def _get_remote_logging_setting(self) -> bool:
        """Get the remote logging setting from Coop.

        Best effort: any exception raised while creating the Coop client or
        reading the setting disables remote logging rather than interrupting
        exception reporting.
        """
        try:
            return Coop().edsl_settings["remote_logging"]
        except Exception:
            return False

    def _generate_error_message(self, indices) -> str:
        """Build the summary message for the interviews that raised exceptions.

        The explicit list of indices is only appended when more than five
        interviews raised exceptions.
        """
        msg = f"Exceptions were raised in {len(indices)} interviews.\n"
        if len(indices) > 5:
            msg += f"Exceptions were raised in the following interviews: {indices}.\n"
        return msg

    def handle_exceptions(self) -> None:
        """Print a summary, build the HTML report and optionally push it to Coop.

        Does nothing unless the results have unfixed exceptions AND the run
        parameters ask for exceptions to be printed.
        """
        if not (
            self.results.has_unfixed_exceptions and self.parameters.print_exceptions
        ):
            return

        task_history = self.results.task_history

        # Summary goes to stderr so it does not mix with regular output.
        error_msg = self._generate_error_message(task_history.indices)
        print(error_msg, file=sys.stderr)

        # Generate HTML report; the returned value is used as a file path below.
        filepath = task_history.html(
            cta="Open report to see details.",
            open_in_browser=self.open_in_browser,
            return_link=True,
        )

        # Handle remote logging if enabled
        if self.remote_logging:
            filestore = HTMLFileStore(filepath)
            coop_details = filestore.push(description="Error report")
            print(coop_details)

        print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")
|