edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- edsl/Base.py +24 -14
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +6 -6
- edsl/agents/Invigilator.py +28 -6
- edsl/agents/InvigilatorBase.py +8 -27
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
- edsl/agents/prompt_helpers.py +129 -0
- edsl/config.py +26 -34
- edsl/coop/coop.py +14 -4
- edsl/data_transfer_models.py +26 -73
- edsl/enums.py +2 -0
- edsl/inference_services/AnthropicService.py +5 -2
- edsl/inference_services/AwsBedrock.py +5 -2
- edsl/inference_services/AzureAI.py +5 -2
- edsl/inference_services/GoogleService.py +108 -33
- edsl/inference_services/InferenceServiceABC.py +44 -13
- edsl/inference_services/MistralAIService.py +5 -2
- edsl/inference_services/OpenAIService.py +10 -6
- edsl/inference_services/TestService.py +34 -16
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +109 -18
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +64 -10
- edsl/jobs/interviews/Interview.py +130 -49
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
- edsl/jobs/runners/JobsRunnerStatus.py +332 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
- edsl/jobs/tasks/TaskHistory.py +17 -0
- edsl/language_models/LanguageModel.py +36 -38
- edsl/language_models/registry.py +13 -9
- edsl/language_models/utilities.py +5 -2
- edsl/questions/QuestionBase.py +74 -16
- edsl/questions/QuestionBaseGenMixin.py +28 -0
- edsl/questions/QuestionBudget.py +93 -41
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionFreeText.py +6 -0
- edsl/questions/QuestionMultipleChoice.py +13 -24
- edsl/questions/QuestionNumerical.py +5 -4
- edsl/questions/Quick.py +41 -0
- edsl/questions/ResponseValidatorABC.py +11 -6
- edsl/questions/derived/QuestionLinearScale.py +4 -1
- edsl/questions/derived/QuestionTopK.py +4 -1
- edsl/questions/derived/QuestionYesNo.py +8 -2
- edsl/questions/descriptors.py +12 -11
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
- edsl/results/DatasetExportMixin.py +5 -1
- edsl/results/Result.py +1 -1
- edsl/results/Results.py +4 -1
- edsl/scenarios/FileStore.py +178 -34
- edsl/scenarios/Scenario.py +76 -37
- edsl/scenarios/ScenarioList.py +19 -2
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/Study.py +32 -0
- edsl/surveys/DAG.py +62 -0
- edsl/surveys/MemoryPlan.py +26 -0
- edsl/surveys/Rule.py +34 -1
- edsl/surveys/RuleCollection.py +55 -5
- edsl/surveys/Survey.py +189 -10
- edsl/surveys/base.py +4 -0
- edsl/templates/error_reporting/interview_details.html +6 -1
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
- edsl/jobs/interviews/retry_management.py +0 -39
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- edsl/scenarios/ScenarioImageMixin.py +0 -100
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/jobs/interviews/Interview.py

@@ -3,11 +3,20 @@
 from __future__ import annotations
 import asyncio
 from typing import Any, Type, List, Generator, Optional, Union
+import copy
+
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
+    RetryError,
+)
 
 from edsl import CONFIG
 from edsl.surveys.base import EndOfSurvey
 from edsl.exceptions import QuestionAnswerValidationError
-from edsl.exceptions import
+from edsl.exceptions import QuestionAnswerValidationError
 from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
 
 from edsl.jobs.buckets.ModelBuckets import ModelBuckets
@@ -15,21 +24,18 @@ from edsl.jobs.Answers import Answers
 from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
 from edsl.jobs.tasks.TaskCreators import TaskCreators
 from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
-from edsl.jobs.interviews.
+from edsl.jobs.interviews.InterviewExceptionCollection import (
     InterviewExceptionCollection,
 )
-
-from edsl.jobs.interviews.retry_management import retry_strategy
+
 from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
 
 from edsl.surveys.base import EndOfSurvey
 from edsl.jobs.buckets.ModelBuckets import ModelBuckets
 from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
-from edsl.jobs.interviews.retry_management import retry_strategy
 from edsl.jobs.tasks.task_status_enum import TaskStatus
 from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
 
-from edsl.exceptions import QuestionAnswerValidationError
 
 from edsl import Agent, Survey, Scenario, Cache
 from edsl.language_models import LanguageModel
@@ -39,8 +45,11 @@ from edsl.agents.InvigilatorBase import InvigilatorBase
 from edsl.exceptions.language_models import LanguageModelNoResponseError
 
 
-
-
+from edsl import CONFIG
+
+EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
+EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
+EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
 
 
 class Interview(InterviewStatusMixin):
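These three module-level constants feed the tenacity-based retry decorator added further down in this file: `EDSL_MAX_ATTEMPTS` becomes `stop_after_attempt`, `EDSL_BACKOFF_START_SEC` the `wait_exponential` multiplier, and `EDSL_BACKOFF_MAX_SEC` the cap on the wait between attempts. A hypothetical tuning sketch follows; it assumes edsl's `CONFIG` resolves these keys from environment variables with built-in defaults (edsl/config.py also changed in this release), so treat the resolution order as an assumption to verify:

```python
# Hypothetical sketch, not edsl documentation: assumes CONFIG reads these keys
# from the environment (with defaults) — verify against edsl/config.py.
import os

os.environ["EDSL_MAX_ATTEMPTS"] = "3"         # stop after 3 attempts
os.environ["EDSL_BACKOFF_START_SEC"] = "0.5"  # initial exponential-backoff step
os.environ["EDSL_BACKOFF_MAX_SEC"] = "10"     # cap on the wait between attempts

# Import after setting the environment so module-level constants pick up the values.
from edsl import CONFIG
print(CONFIG.get("EDSL_MAX_ATTEMPTS"))
```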
@@ -91,20 +100,25 @@ class Interview(InterviewStatusMixin):
 
         """
         self.agent = agent
-
+        # what I would like to do
+        self.survey = copy.deepcopy(survey) # survey copy.deepcopy(survey)
+        # self.survey = survey
         self.scenario = scenario
         self.model = model
         self.debug = debug
         self.iteration = iteration
         self.cache = cache
-        self.answers: dict[
-
-        ) # will get filled in as interview progresses
+        self.answers: dict[
+            str, str
+        ] = Answers() # will get filled in as interview progresses
         self.sidecar_model = sidecar_model
 
+        # self.stop_on_exception = False
+
         # Trackers
         self.task_creators = TaskCreators() # tracks the task creators
         self.exceptions = InterviewExceptionCollection()
+
         self._task_status_log_dict = InterviewStatusLog()
         self.skip_retry = skip_retry
         self.raise_validation_errors = raise_validation_errors
@@ -237,17 +251,24 @@ class Interview(InterviewStatusMixin):
 
     def _get_estimated_request_tokens(self, question) -> float:
         """Estimate the number of tokens that will be required to run the focal task."""
+        from edsl.scenarios.FileStore import FileStore
+
         invigilator = self._get_invigilator(question=question)
         # TODO: There should be a way to get a more accurate estimate.
         combined_text = ""
+        file_tokens = 0
         for prompt in invigilator.get_prompts().values():
             if hasattr(prompt, "text"):
                 combined_text += prompt.text
             elif isinstance(prompt, str):
                 combined_text += prompt
+            elif isinstance(prompt, list):
+                for file in prompt:
+                    if isinstance(file, FileStore):
+                        file_tokens += file.size * 0.25
             else:
                 raise ValueError(f"Prompt is of type {type(prompt)}")
-        return len(combined_text) / 4.0
+        return len(combined_text) / 4.0 + file_tokens
 
     async def _answer_question_and_record_task(
         self,
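The estimate above is a simple heuristic: roughly one token per four characters of prompt text, plus 0.25 tokens per byte of any attached FileStore object. A standalone sketch of the same arithmetic (function and variable names here are illustrative, not part of edsl):

```python
# Illustrative sketch of the heuristic used above (names are hypothetical).
def estimate_request_tokens(prompt_text: str, file_sizes_bytes: list[int]) -> float:
    """~1 token per 4 characters of text, plus 0.25 tokens per file byte."""
    text_tokens = len(prompt_text) / 4.0
    file_tokens = sum(size * 0.25 for size in file_sizes_bytes)
    return text_tokens + file_tokens

# e.g. a 400-character prompt with a 1,000-byte attachment -> 100 + 250 = 350 tokens
print(estimate_request_tokens("x" * 400, [1000]))
```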
@@ -257,44 +278,83 @@ class Interview(InterviewStatusMixin):
     ) -> "AgentResponseDict":
         """Answer a question and records the task."""
 
-
+        had_language_model_no_response_error = False
 
-
-
-
-
+        @retry(
+            stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
+            wait=wait_exponential(
+                multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
+            ),
+            retry=retry_if_exception_type(LanguageModelNoResponseError),
+            reraise=True,
+        )
+        async def attempt_answer():
+            nonlocal had_language_model_no_response_error
 
-
-        response: EDSLResultObjectInput = await invigilator.async_answer_question()
-        if response.validated:
-            self.answers.add_answer(response=response, question=question)
-            self._cancel_skipped_questions(question)
-        else:
-            if (
-                hasattr(response, "exception_occurred")
-                and response.exception_occurred
-            ):
-                raise response.exception_occurred
+            invigilator = self._get_invigilator(question)
 
-
-
-
+            if self._skip_this_question(question):
+                return invigilator.get_failed_task_result(
+                    failure_reason="Question skipped."
+                )
 
-
-
-
+            try:
+                response: EDSLResultObjectInput = (
+                    await invigilator.async_answer_question()
+                )
+                if response.validated:
+                    self.answers.add_answer(response=response, question=question)
+                    self._cancel_skipped_questions(question)
+                else:
+                    # When a question is not validated, it is not added to the answers.
+                    # this should also cancel and dependent children questions.
+                    # Is that happening now?
+                    if (
+                        hasattr(response, "exception_occurred")
+                        and response.exception_occurred
+                    ):
+                        raise response.exception_occurred
+
+            except QuestionAnswerValidationError as e:
+                self._handle_exception(e, invigilator, task)
+                return invigilator.get_failed_task_result(
+                    failure_reason="Question answer validation failed."
+                )
 
-
-
-
+            except asyncio.TimeoutError as e:
+                self._handle_exception(e, invigilator, task)
+                had_language_model_no_response_error = True
+                raise LanguageModelNoResponseError(
+                    f"Language model timed out for question '{question.question_name}.'"
+                )
 
-
+            except Exception as e:
+                self._handle_exception(e, invigilator, task)
 
-
-
-
+            if "response" not in locals():
+                had_language_model_no_response_error = True
+                raise LanguageModelNoResponseError(
+                    f"Language model did not return a response for question '{question.question_name}.'"
+                )
+
+            # if it gets here, it means the no response error was fixed
+            if (
+                question.question_name in self.exceptions
+                and had_language_model_no_response_error
+            ):
+                self.exceptions.record_fixed_question(question.question_name)
 
-
+            return response
+
+        try:
+            return await attempt_answer()
+        except RetryError as retry_error:
+            # All retries have failed for LanguageModelNoResponseError
+            original_error = retry_error.last_attempt.exception()
+            self._handle_exception(
+                original_error, self._get_invigilator(question), task
+            )
+            raise original_error  # Re-raise the original error after handling
 
     def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
         """Return an invigilator for the given question.
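The hunk above replaces the removed `retry_strategy` with a tenacity decorator driven by the new CONFIG constants. A minimal, self-contained sketch of the same pattern, with a simulated failing call and small stand-in constants (only the tenacity API is real; everything else is illustrative):

```python
import asyncio
from tenacity import (
    RetryError,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

class LanguageModelNoResponseError(Exception):
    """Stand-in for edsl's exception of the same name."""

# Stand-ins for the CONFIG-derived constants in the diff (kept small so this runs fast).
EDSL_BACKOFF_START_SEC = 0.1
EDSL_BACKOFF_MAX_SEC = 1.0
EDSL_MAX_ATTEMPTS = 3

@retry(
    stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
    wait=wait_exponential(multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC),
    retry=retry_if_exception_type(LanguageModelNoResponseError),
    reraise=True,  # re-raise the last exception instead of wrapping it in RetryError
)
async def attempt_answer() -> str:
    # Simulate a model that never responds; real code would await the invigilator here.
    raise LanguageModelNoResponseError("no response")

async def main() -> None:
    try:
        await attempt_answer()
    except RetryError as retry_error:  # only reachable if reraise were False
        raise retry_error.last_attempt.exception()
    except LanguageModelNoResponseError as e:
        print(f"gave up after {EDSL_MAX_ATTEMPTS} attempts: {e}")

asyncio.run(main())
```

Note that with `reraise=True` tenacity re-raises the last exception itself rather than a `RetryError`, so the `except RetryError` branch in the diff acts mainly as a defensive fallback.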
@@ -334,14 +394,32 @@ class Interview(InterviewStatusMixin):
     def _handle_exception(
         self, e: Exception, invigilator: "InvigilatorBase", task=None
    ):
+        import copy
+
+        # breakpoint()
+
+        answers = copy.copy(self.answers)
         exception_entry = InterviewExceptionEntry(
             exception=e,
             invigilator=invigilator,
+            answers=answers,
         )
         if task:
             task.task_status = TaskStatus.FAILED
         self.exceptions.add(invigilator.question.question_name, exception_entry)
 
+        if self.raise_validation_errors:
+            if isinstance(e, QuestionAnswerValidationError):
+                raise e
+
+        if hasattr(self, "stop_on_exception"):
+            stop_on_exception = self.stop_on_exception
+        else:
+            stop_on_exception = False
+
+        if stop_on_exception:
+            raise e
+
     def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
         """Cancel the tasks for questions that are skipped.
 
@@ -353,11 +431,11 @@ class Interview(InterviewStatusMixin):
         """
         current_question_index: int = self.to_index[current_question.question_name]
 
-        next_question: Union[
-
-
-
-
+        next_question: Union[
+            int, EndOfSurvey
+        ] = self.survey.rule_collection.next_question(
+            q_now=current_question_index,
+            answers=self.answers | self.scenario | self.agent["traits"],
         )
 
         next_question_index = next_question.next_q
@@ -411,6 +489,7 @@ class Interview(InterviewStatusMixin):
             asyncio.exceptions.CancelledError
         """
         self.sidecar_model = sidecar_model
+        self.stop_on_exception = stop_on_exception
 
         # if no model bucket is passed, create an 'infinity' bucket with no rate limits
         if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
@@ -424,7 +503,9 @@ class Interview(InterviewStatusMixin):
         self.invigilators = [
             self._get_invigilator(question) for question in self.survey.questions
         ]
-        await asyncio.gather(
+        await asyncio.gather(
+            *self.tasks, return_exceptions=not stop_on_exception
+        ) # not stop_on_exception)
         self.answers.replace_missing_answers_with_none(self.survey)
         valid_results = list(self._extract_valid_results())
         return self.answers, valid_results
edsl/jobs/interviews/InterviewExceptionCollection.py

@@ -6,6 +6,22 @@ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
 class InterviewExceptionCollection(UserDict):
     """A collection of exceptions that occurred during the interview."""
 
+    def __init__(self):
+        super().__init__()
+        self.fixed = set()
+
+    def unfixed_exceptions(self) -> list:
+        """Return a list of unfixed exceptions."""
+        return {k: v for k, v in self.data.items() if k not in self.fixed}
+
+    def num_unfixed(self) -> list:
+        """Return a list of unfixed questions."""
+        return len([k for k in self.data.keys() if k not in self.fixed])
+
+    def record_fixed_question(self, question_name: str) -> None:
+        """Record that a question has been fixed."""
+        self.fixed.add(question_name)
+
     def add(self, question_name: str, entry: InterviewExceptionEntry) -> None:
         """Add an exception entry to the collection."""
         question_name = question_name
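The new `fixed` set lets the interview record that a question which previously raised (for example, a transient no-response error) eventually succeeded on a retry, so `unfixed_exceptions()` and `num_unfixed()` report only the exceptions that were never resolved. A rough usage sketch with a simplified stand-in class (not the edsl implementation; the `add` body here is assumed):

```python
# Hypothetical sketch of the "fixed" bookkeeping added above.
from collections import UserDict

class ExceptionCollectionSketch(UserDict):
    def __init__(self):
        super().__init__()
        self.fixed = set()  # question names whose exception was later resolved

    def add(self, question_name: str, entry) -> None:
        self.data.setdefault(question_name, []).append(entry)

    def record_fixed_question(self, question_name: str) -> None:
        self.fixed.add(question_name)

    def unfixed_exceptions(self) -> dict:
        return {k: v for k, v in self.data.items() if k not in self.fixed}

exceptions = ExceptionCollectionSketch()
exceptions.add("q1", "LanguageModelNoResponseError: timed out")
exceptions.add("q2", "QuestionAnswerValidationError: bad answer")
exceptions.record_fixed_question("q1")  # q1 succeeded on a later retry
print(list(exceptions.unfixed_exceptions()))  # ['q2']
```

As the diff shows, the package's `unfixed_exceptions()` returns a dict despite its `-> list` annotation.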
edsl/jobs/interviews/InterviewExceptionEntry.py

@@ -15,12 +15,14 @@ class InterviewExceptionEntry:
         # failed_question: FailedQuestion,
         invigilator: "Invigilator",
         traceback_format="text",
+        answers=None,
     ):
         self.time = datetime.datetime.now().isoformat()
         self.exception = exception
         # self.failed_question = failed_question
         self.invigilator = invigilator
         self.traceback_format = traceback_format
+        self.answers = answers
 
     @property
     def question_type(self):