edsl 0.1.33__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. edsl/Base.py +3 -9
  2. edsl/__init__.py +0 -1
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +3 -6
  6. edsl/agents/InvigilatorBase.py +27 -8
  7. edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +29 -101
  8. edsl/config.py +34 -26
  9. edsl/coop/coop.py +2 -11
  10. edsl/data_transfer_models.py +73 -26
  11. edsl/enums.py +0 -2
  12. edsl/inference_services/GoogleService.py +1 -1
  13. edsl/inference_services/InferenceServiceABC.py +13 -44
  14. edsl/inference_services/OpenAIService.py +4 -7
  15. edsl/inference_services/TestService.py +15 -24
  16. edsl/inference_services/registry.py +0 -2
  17. edsl/jobs/Jobs.py +8 -18
  18. edsl/jobs/buckets/BucketCollection.py +15 -24
  19. edsl/jobs/buckets/TokenBucket.py +10 -64
  20. edsl/jobs/interviews/Interview.py +47 -115
  21. edsl/jobs/interviews/InterviewExceptionEntry.py +0 -2
  22. edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +0 -16
  23. edsl/jobs/interviews/retry_management.py +39 -0
  24. edsl/jobs/runners/JobsRunnerAsyncio.py +170 -95
  25. edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
  26. edsl/jobs/tasks/TaskHistory.py +0 -17
  27. edsl/language_models/LanguageModel.py +31 -26
  28. edsl/language_models/registry.py +9 -13
  29. edsl/questions/QuestionBase.py +14 -63
  30. edsl/questions/QuestionBudget.py +41 -93
  31. edsl/questions/QuestionFreeText.py +0 -6
  32. edsl/questions/QuestionMultipleChoice.py +23 -8
  33. edsl/questions/QuestionNumerical.py +4 -5
  34. edsl/questions/ResponseValidatorABC.py +5 -6
  35. edsl/questions/derived/QuestionLinearScale.py +1 -4
  36. edsl/questions/derived/QuestionTopK.py +1 -4
  37. edsl/questions/derived/QuestionYesNo.py +2 -8
  38. edsl/results/DatasetExportMixin.py +1 -5
  39. edsl/results/Result.py +1 -1
  40. edsl/results/Results.py +1 -4
  41. edsl/scenarios/FileStore.py +10 -71
  42. edsl/scenarios/Scenario.py +21 -86
  43. edsl/scenarios/ScenarioImageMixin.py +2 -2
  44. edsl/scenarios/ScenarioList.py +0 -13
  45. edsl/scenarios/ScenarioListPdfMixin.py +4 -150
  46. edsl/study/Study.py +0 -32
  47. edsl/surveys/Rule.py +1 -10
  48. edsl/surveys/RuleCollection.py +3 -19
  49. edsl/surveys/Survey.py +0 -7
  50. edsl/templates/error_reporting/interview_details.html +1 -6
  51. edsl/utilities/utilities.py +1 -9
  52. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +1 -2
  53. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/RECORD +55 -61
  54. edsl/inference_services/TogetherAIService.py +0 -170
  55. edsl/jobs/runners/JobsRunnerStatus.py +0 -331
  56. edsl/questions/Quick.py +0 -41
  57. edsl/questions/templates/budget/__init__.py +0 -0
  58. edsl/questions/templates/budget/answering_instructions.jinja +0 -7
  59. edsl/questions/templates/budget/question_presentation.jinja +0 -7
  60. edsl/questions/templates/extract/__init__.py +0 -0
  61. edsl/questions/templates/rank/__init__.py +0 -0
  62. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
  63. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
@@ -4,18 +4,10 @@ from __future__ import annotations
4
4
  import asyncio
5
5
  from typing import Any, Type, List, Generator, Optional, Union
6
6
 
7
- from tenacity import (
8
- retry,
9
- stop_after_attempt,
10
- wait_exponential,
11
- retry_if_exception_type,
12
- RetryError,
13
- )
14
-
15
7
  from edsl import CONFIG
16
8
  from edsl.surveys.base import EndOfSurvey
17
9
  from edsl.exceptions import QuestionAnswerValidationError
18
- from edsl.exceptions import QuestionAnswerValidationError
10
+ from edsl.exceptions import InterviewTimeoutError
19
11
  from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
20
12
 
21
13
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
@@ -23,18 +15,21 @@ from edsl.jobs.Answers import Answers
23
15
  from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
24
16
  from edsl.jobs.tasks.TaskCreators import TaskCreators
25
17
  from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
26
- from edsl.jobs.interviews.InterviewExceptionCollection import (
18
+ from edsl.jobs.interviews.interview_exception_tracking import (
27
19
  InterviewExceptionCollection,
28
20
  )
29
-
21
+ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
22
+ from edsl.jobs.interviews.retry_management import retry_strategy
30
23
  from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
31
24
 
32
25
  from edsl.surveys.base import EndOfSurvey
33
26
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
34
27
  from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
28
+ from edsl.jobs.interviews.retry_management import retry_strategy
35
29
  from edsl.jobs.tasks.task_status_enum import TaskStatus
36
30
  from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
37
31
 
32
+ from edsl.exceptions import QuestionAnswerValidationError
38
33
 
39
34
  from edsl import Agent, Survey, Scenario, Cache
40
35
  from edsl.language_models import LanguageModel
@@ -44,11 +39,8 @@ from edsl.agents.InvigilatorBase import InvigilatorBase
44
39
  from edsl.exceptions.language_models import LanguageModelNoResponseError
45
40
 
46
41
 
47
- from edsl import CONFIG
48
-
49
- EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
50
- EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
51
- EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
42
+ class RetryableLanguageModelNoResponseError(LanguageModelNoResponseError):
43
+ pass
52
44
 
53
45
 
54
46
  class Interview(InterviewStatusMixin):
@@ -105,17 +97,14 @@ class Interview(InterviewStatusMixin):
105
97
  self.debug = debug
106
98
  self.iteration = iteration
107
99
  self.cache = cache
108
- self.answers: dict[
109
- str, str
110
- ] = Answers() # will get filled in as interview progresses
100
+ self.answers: dict[str, str] = (
101
+ Answers()
102
+ ) # will get filled in as interview progresses
111
103
  self.sidecar_model = sidecar_model
112
104
 
113
- # self.stop_on_exception = False
114
-
115
105
  # Trackers
116
106
  self.task_creators = TaskCreators() # tracks the task creators
117
107
  self.exceptions = InterviewExceptionCollection()
118
-
119
108
  self._task_status_log_dict = InterviewStatusLog()
120
109
  self.skip_retry = skip_retry
121
110
  self.raise_validation_errors = raise_validation_errors
@@ -268,80 +257,44 @@ class Interview(InterviewStatusMixin):
268
257
  ) -> "AgentResponseDict":
269
258
  """Answer a question and records the task."""
270
259
 
271
- had_language_model_no_response_error = False
272
-
273
- @retry(
274
- stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
275
- wait=wait_exponential(
276
- multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
277
- ),
278
- retry=retry_if_exception_type(LanguageModelNoResponseError),
279
- reraise=True,
280
- )
281
- async def attempt_answer():
282
- nonlocal had_language_model_no_response_error
283
-
284
- invigilator = self._get_invigilator(question)
285
-
286
- if self._skip_this_question(question):
287
- return invigilator.get_failed_task_result(
288
- failure_reason="Question skipped."
289
- )
260
+ invigilator = self._get_invigilator(question)
290
261
 
291
- try:
292
- response: EDSLResultObjectInput = (
293
- await invigilator.async_answer_question()
294
- )
295
- if response.validated:
296
- self.answers.add_answer(response=response, question=question)
297
- self._cancel_skipped_questions(question)
298
- else:
299
- if (
300
- hasattr(response, "exception_occurred")
301
- and response.exception_occurred
302
- ):
303
- raise response.exception_occurred
304
-
305
- except QuestionAnswerValidationError as e:
306
- self._handle_exception(e, invigilator, task)
307
- return invigilator.get_failed_task_result(
308
- failure_reason="Question answer validation failed."
309
- )
262
+ if self._skip_this_question(question):
263
+ response = invigilator.get_failed_task_result(
264
+ failure_reason="Question skipped."
265
+ )
310
266
 
311
- except asyncio.TimeoutError as e:
312
- self._handle_exception(e, invigilator, task)
313
- had_language_model_no_response_error = True
314
- raise LanguageModelNoResponseError(
315
- f"Language model timed out for question '{question.question_name}.'"
316
- )
267
+ try:
268
+ response: EDSLResultObjectInput = await invigilator.async_answer_question()
269
+ if response.validated:
270
+ self.answers.add_answer(response=response, question=question)
271
+ self._cancel_skipped_questions(question)
272
+ else:
273
+ if (
274
+ hasattr(response, "exception_occurred")
275
+ and response.exception_occurred
276
+ ):
277
+ raise response.exception_occurred
317
278
 
318
- except Exception as e:
319
- self._handle_exception(e, invigilator, task)
279
+ except QuestionAnswerValidationError as e:
280
+ # there's a response, but it couldn't be validated
281
+ self._handle_exception(e, invigilator, task)
320
282
 
321
- if "response" not in locals():
322
- had_language_model_no_response_error = True
323
- raise LanguageModelNoResponseError(
324
- f"Language model did not return a response for question '{question.question_name}.'"
325
- )
283
+ except asyncio.TimeoutError as e:
284
+ # the API timed-out - this is recorded but as a response isn't generated, the LanguageModelNoResponseError will also be raised
285
+ self._handle_exception(e, invigilator, task)
326
286
 
327
- # if it gets here, it means the no response error was fixed
328
- if (
329
- question.question_name in self.exceptions
330
- and had_language_model_no_response_error
331
- ):
332
- self.exceptions.record_fixed_question(question.question_name)
287
+ except Exception as e:
288
+ # there was some other exception
289
+ self._handle_exception(e, invigilator, task)
333
290
 
334
- return response
291
+ if "response" not in locals():
335
292
 
336
- try:
337
- return await attempt_answer()
338
- except RetryError as retry_error:
339
- # All retries have failed for LanguageModelNoResponseError
340
- original_error = retry_error.last_attempt.exception()
341
- self._handle_exception(
342
- original_error, self._get_invigilator(question), task
293
+ raise LanguageModelNoResponseError(
294
+ f"Language model did not return a response for question '{question.question_name}.'"
343
295
  )
344
- raise original_error # Re-raise the original error after handling
296
+
297
+ return response
345
298
 
346
299
  def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
347
300
  """Return an invigilator for the given question.
@@ -381,32 +334,14 @@ class Interview(InterviewStatusMixin):
381
334
  def _handle_exception(
382
335
  self, e: Exception, invigilator: "InvigilatorBase", task=None
383
336
  ):
384
- import copy
385
-
386
- # breakpoint()
387
-
388
- answers = copy.copy(self.answers)
389
337
  exception_entry = InterviewExceptionEntry(
390
338
  exception=e,
391
339
  invigilator=invigilator,
392
- answers=answers,
393
340
  )
394
341
  if task:
395
342
  task.task_status = TaskStatus.FAILED
396
343
  self.exceptions.add(invigilator.question.question_name, exception_entry)
397
344
 
398
- if self.raise_validation_errors:
399
- if isinstance(e, QuestionAnswerValidationError):
400
- raise e
401
-
402
- if hasattr(self, "stop_on_exception"):
403
- stop_on_exception = self.stop_on_exception
404
- else:
405
- stop_on_exception = False
406
-
407
- if stop_on_exception:
408
- raise e
409
-
410
345
  def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
411
346
  """Cancel the tasks for questions that are skipped.
412
347
 
@@ -418,11 +353,11 @@ class Interview(InterviewStatusMixin):
418
353
  """
419
354
  current_question_index: int = self.to_index[current_question.question_name]
420
355
 
421
- next_question: Union[
422
- int, EndOfSurvey
423
- ] = self.survey.rule_collection.next_question(
424
- q_now=current_question_index,
425
- answers=self.answers | self.scenario | self.agent["traits"],
356
+ next_question: Union[int, EndOfSurvey] = (
357
+ self.survey.rule_collection.next_question(
358
+ q_now=current_question_index,
359
+ answers=self.answers | self.scenario | self.agent["traits"],
360
+ )
426
361
  )
427
362
 
428
363
  next_question_index = next_question.next_q
@@ -476,7 +411,6 @@ class Interview(InterviewStatusMixin):
476
411
  asyncio.exceptions.CancelledError
477
412
  """
478
413
  self.sidecar_model = sidecar_model
479
- self.stop_on_exception = stop_on_exception
480
414
 
481
415
  # if no model bucket is passed, create an 'infinity' bucket with no rate limits
482
416
  if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
@@ -490,9 +424,7 @@ class Interview(InterviewStatusMixin):
490
424
  self.invigilators = [
491
425
  self._get_invigilator(question) for question in self.survey.questions
492
426
  ]
493
- await asyncio.gather(
494
- *self.tasks, return_exceptions=not stop_on_exception
495
- ) # not stop_on_exception)
427
+ await asyncio.gather(*self.tasks, return_exceptions=not stop_on_exception)
496
428
  self.answers.replace_missing_answers_with_none(self.survey)
497
429
  valid_results = list(self._extract_valid_results())
498
430
  return self.answers, valid_results
@@ -15,14 +15,12 @@ class InterviewExceptionEntry:
15
15
  # failed_question: FailedQuestion,
16
16
  invigilator: "Invigilator",
17
17
  traceback_format="text",
18
- answers=None,
19
18
  ):
20
19
  self.time = datetime.datetime.now().isoformat()
21
20
  self.exception = exception
22
21
  # self.failed_question = failed_question
23
22
  self.invigilator = invigilator
24
23
  self.traceback_format = traceback_format
25
- self.answers = answers
26
24
 
27
25
  @property
28
26
  def question_type(self):
@@ -6,22 +6,6 @@ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
6
6
  class InterviewExceptionCollection(UserDict):
7
7
  """A collection of exceptions that occurred during the interview."""
8
8
 
9
- def __init__(self):
10
- super().__init__()
11
- self.fixed = set()
12
-
13
- def unfixed_exceptions(self) -> list:
14
- """Return a list of unfixed exceptions."""
15
- return {k: v for k, v in self.data.items() if k not in self.fixed}
16
-
17
- def num_unfixed(self) -> list:
18
- """Return a list of unfixed questions."""
19
- return len([k for k in self.data.keys() if k not in self.fixed])
20
-
21
- def record_fixed_question(self, question_name: str) -> None:
22
- """Record that a question has been fixed."""
23
- self.fixed.add(question_name)
24
-
25
9
  def add(self, question_name: str, entry: InterviewExceptionEntry) -> None:
26
10
  """Add an exception entry to the collection."""
27
11
  question_name = question_name
@@ -0,0 +1,39 @@
1
+ from edsl import CONFIG
2
+
3
+ from tenacity import (
4
+ retry,
5
+ wait_exponential,
6
+ stop_after_attempt,
7
+ retry_if_exception_type,
8
+ before_sleep,
9
+ )
10
+
11
+ EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
12
+ EDSL_MAX_BACKOFF_SEC = float(CONFIG.get("EDSL_MAX_BACKOFF_SEC"))
13
+ EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
14
+
15
+
16
+ def print_retry(retry_state, print_to_terminal=True):
17
+ "Prints details on tenacity retries."
18
+ attempt_number = retry_state.attempt_number
19
+ exception = retry_state.outcome.exception()
20
+ wait_time = retry_state.next_action.sleep
21
+ exception_name = type(exception).__name__
22
+ if print_to_terminal:
23
+ print(
24
+ f"Attempt {attempt_number} failed with exception '{exception_name}':"
25
+ f"{exception}",
26
+ f"now waiting {wait_time:.2f} seconds before retrying."
27
+ f"Parameters: start={EDSL_BACKOFF_START_SEC}, max={EDSL_MAX_BACKOFF_SEC}, max_attempts={EDSL_MAX_ATTEMPTS}."
28
+ "\n\n",
29
+ )
30
+
31
+
32
+ retry_strategy = retry(
33
+ wait=wait_exponential(
34
+ multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_MAX_BACKOFF_SEC
35
+ ), # Exponential back-off starting at 1s, doubling, maxing out at 60s
36
+ stop=stop_after_attempt(EDSL_MAX_ATTEMPTS), # Stop after 5 attempts
37
+ # retry=retry_if_exception_type(Exception), # Customize this as per your specific retry-able exception
38
+ before_sleep=print_retry, # Use custom print function for retries
39
+ )