edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. edsl/Base.py +24 -14
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +28 -6
  6. edsl/agents/InvigilatorBase.py +8 -27
  7. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
  8. edsl/agents/prompt_helpers.py +129 -0
  9. edsl/config.py +26 -34
  10. edsl/coop/coop.py +14 -4
  11. edsl/data_transfer_models.py +26 -73
  12. edsl/enums.py +2 -0
  13. edsl/inference_services/AnthropicService.py +5 -2
  14. edsl/inference_services/AwsBedrock.py +5 -2
  15. edsl/inference_services/AzureAI.py +5 -2
  16. edsl/inference_services/GoogleService.py +108 -33
  17. edsl/inference_services/InferenceServiceABC.py +44 -13
  18. edsl/inference_services/MistralAIService.py +5 -2
  19. edsl/inference_services/OpenAIService.py +10 -6
  20. edsl/inference_services/TestService.py +34 -16
  21. edsl/inference_services/TogetherAIService.py +170 -0
  22. edsl/inference_services/registry.py +2 -0
  23. edsl/jobs/Jobs.py +109 -18
  24. edsl/jobs/buckets/BucketCollection.py +24 -15
  25. edsl/jobs/buckets/TokenBucket.py +64 -10
  26. edsl/jobs/interviews/Interview.py +130 -49
  27. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
  28. edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
  29. edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
  30. edsl/jobs/runners/JobsRunnerStatus.py +332 -0
  31. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  32. edsl/jobs/tasks/TaskHistory.py +17 -0
  33. edsl/language_models/LanguageModel.py +36 -38
  34. edsl/language_models/registry.py +13 -9
  35. edsl/language_models/utilities.py +5 -2
  36. edsl/questions/QuestionBase.py +74 -16
  37. edsl/questions/QuestionBaseGenMixin.py +28 -0
  38. edsl/questions/QuestionBudget.py +93 -41
  39. edsl/questions/QuestionCheckBox.py +1 -1
  40. edsl/questions/QuestionFreeText.py +6 -0
  41. edsl/questions/QuestionMultipleChoice.py +13 -24
  42. edsl/questions/QuestionNumerical.py +5 -4
  43. edsl/questions/Quick.py +41 -0
  44. edsl/questions/ResponseValidatorABC.py +11 -6
  45. edsl/questions/derived/QuestionLinearScale.py +4 -1
  46. edsl/questions/derived/QuestionTopK.py +4 -1
  47. edsl/questions/derived/QuestionYesNo.py +8 -2
  48. edsl/questions/descriptors.py +12 -11
  49. edsl/questions/templates/budget/__init__.py +0 -0
  50. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  51. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  52. edsl/questions/templates/extract/__init__.py +0 -0
  53. edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
  54. edsl/questions/templates/rank/__init__.py +0 -0
  55. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  56. edsl/results/DatasetExportMixin.py +5 -1
  57. edsl/results/Result.py +1 -1
  58. edsl/results/Results.py +4 -1
  59. edsl/scenarios/FileStore.py +178 -34
  60. edsl/scenarios/Scenario.py +76 -37
  61. edsl/scenarios/ScenarioList.py +19 -2
  62. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  63. edsl/study/Study.py +32 -0
  64. edsl/surveys/DAG.py +62 -0
  65. edsl/surveys/MemoryPlan.py +26 -0
  66. edsl/surveys/Rule.py +34 -1
  67. edsl/surveys/RuleCollection.py +55 -5
  68. edsl/surveys/Survey.py +189 -10
  69. edsl/surveys/base.py +4 -0
  70. edsl/templates/error_reporting/interview_details.html +6 -1
  71. edsl/utilities/utilities.py +9 -1
  72. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
  73. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
  74. edsl/jobs/interviews/retry_management.py +0 -39
  75. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  76. edsl/scenarios/ScenarioImageMixin.py +0 -100
  77. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
  78. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/jobs/interviews/Interview.py
@@ -3,11 +3,20 @@
3
3
  from __future__ import annotations
4
4
  import asyncio
5
5
  from typing import Any, Type, List, Generator, Optional, Union
6
+ import copy
7
+
8
+ from tenacity import (
9
+ retry,
10
+ stop_after_attempt,
11
+ wait_exponential,
12
+ retry_if_exception_type,
13
+ RetryError,
14
+ )
6
15
 
7
16
  from edsl import CONFIG
8
17
  from edsl.surveys.base import EndOfSurvey
9
18
  from edsl.exceptions import QuestionAnswerValidationError
10
- from edsl.exceptions import InterviewTimeoutError
19
+ from edsl.exceptions import QuestionAnswerValidationError
11
20
  from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
12
21
 
13
22
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
@@ -15,21 +24,18 @@ from edsl.jobs.Answers import Answers
15
24
  from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
16
25
  from edsl.jobs.tasks.TaskCreators import TaskCreators
17
26
  from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
18
- from edsl.jobs.interviews.interview_exception_tracking import (
27
+ from edsl.jobs.interviews.InterviewExceptionCollection import (
19
28
  InterviewExceptionCollection,
20
29
  )
21
- from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
22
- from edsl.jobs.interviews.retry_management import retry_strategy
30
+
23
31
  from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
24
32
 
25
33
  from edsl.surveys.base import EndOfSurvey
26
34
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
27
35
  from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
28
- from edsl.jobs.interviews.retry_management import retry_strategy
29
36
  from edsl.jobs.tasks.task_status_enum import TaskStatus
30
37
  from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
31
38
 
32
- from edsl.exceptions import QuestionAnswerValidationError
33
39
 
34
40
  from edsl import Agent, Survey, Scenario, Cache
35
41
  from edsl.language_models import LanguageModel
@@ -39,8 +45,11 @@ from edsl.agents.InvigilatorBase import InvigilatorBase
39
45
  from edsl.exceptions.language_models import LanguageModelNoResponseError
40
46
 
41
47
 
42
- class RetryableLanguageModelNoResponseError(LanguageModelNoResponseError):
43
- pass
48
+ from edsl import CONFIG
49
+
50
+ EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
51
+ EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
52
+ EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
44
53
 
45
54
 
46
55
  class Interview(InterviewStatusMixin):
@@ -91,20 +100,25 @@ class Interview(InterviewStatusMixin):
91
100
 
92
101
  """
93
102
  self.agent = agent
94
- self.survey = survey
103
+ # what I would like to do
104
+ self.survey = copy.deepcopy(survey) # survey copy.deepcopy(survey)
105
+ # self.survey = survey
95
106
  self.scenario = scenario
96
107
  self.model = model
97
108
  self.debug = debug
98
109
  self.iteration = iteration
99
110
  self.cache = cache
100
- self.answers: dict[str, str] = (
101
- Answers()
102
- ) # will get filled in as interview progresses
111
+ self.answers: dict[
112
+ str, str
113
+ ] = Answers() # will get filled in as interview progresses
103
114
  self.sidecar_model = sidecar_model
104
115
 
116
+ # self.stop_on_exception = False
117
+
105
118
  # Trackers
106
119
  self.task_creators = TaskCreators() # tracks the task creators
107
120
  self.exceptions = InterviewExceptionCollection()
121
+
108
122
  self._task_status_log_dict = InterviewStatusLog()
109
123
  self.skip_retry = skip_retry
110
124
  self.raise_validation_errors = raise_validation_errors
@@ -237,17 +251,24 @@ class Interview(InterviewStatusMixin):
237
251
 
238
252
  def _get_estimated_request_tokens(self, question) -> float:
239
253
  """Estimate the number of tokens that will be required to run the focal task."""
254
+ from edsl.scenarios.FileStore import FileStore
255
+
240
256
  invigilator = self._get_invigilator(question=question)
241
257
  # TODO: There should be a way to get a more accurate estimate.
242
258
  combined_text = ""
259
+ file_tokens = 0
243
260
  for prompt in invigilator.get_prompts().values():
244
261
  if hasattr(prompt, "text"):
245
262
  combined_text += prompt.text
246
263
  elif isinstance(prompt, str):
247
264
  combined_text += prompt
265
+ elif isinstance(prompt, list):
266
+ for file in prompt:
267
+ if isinstance(file, FileStore):
268
+ file_tokens += file.size * 0.25
248
269
  else:
249
270
  raise ValueError(f"Prompt is of type {type(prompt)}")
250
- return len(combined_text) / 4.0
271
+ return len(combined_text) / 4.0 + file_tokens
251
272
 
252
273
  async def _answer_question_and_record_task(
253
274
  self,
@@ -257,44 +278,83 @@ class Interview(InterviewStatusMixin):
257
278
  ) -> "AgentResponseDict":
258
279
  """Answer a question and records the task."""
259
280
 
260
- invigilator = self._get_invigilator(question)
281
+ had_language_model_no_response_error = False
261
282
 
262
- if self._skip_this_question(question):
263
- response = invigilator.get_failed_task_result(
264
- failure_reason="Question skipped."
265
- )
283
+ @retry(
284
+ stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
285
+ wait=wait_exponential(
286
+ multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
287
+ ),
288
+ retry=retry_if_exception_type(LanguageModelNoResponseError),
289
+ reraise=True,
290
+ )
291
+ async def attempt_answer():
292
+ nonlocal had_language_model_no_response_error
266
293
 
267
- try:
268
- response: EDSLResultObjectInput = await invigilator.async_answer_question()
269
- if response.validated:
270
- self.answers.add_answer(response=response, question=question)
271
- self._cancel_skipped_questions(question)
272
- else:
273
- if (
274
- hasattr(response, "exception_occurred")
275
- and response.exception_occurred
276
- ):
277
- raise response.exception_occurred
294
+ invigilator = self._get_invigilator(question)
278
295
 
279
- except QuestionAnswerValidationError as e:
280
- # there's a response, but it couldn't be validated
281
- self._handle_exception(e, invigilator, task)
296
+ if self._skip_this_question(question):
297
+ return invigilator.get_failed_task_result(
298
+ failure_reason="Question skipped."
299
+ )
282
300
 
283
- except asyncio.TimeoutError as e:
284
- # the API timed-out - this is recorded but as a response isn't generated, the LanguageModelNoResponseError will also be raised
285
- self._handle_exception(e, invigilator, task)
301
+ try:
302
+ response: EDSLResultObjectInput = (
303
+ await invigilator.async_answer_question()
304
+ )
305
+ if response.validated:
306
+ self.answers.add_answer(response=response, question=question)
307
+ self._cancel_skipped_questions(question)
308
+ else:
309
+ # When a question is not validated, it is not added to the answers.
310
+ # this should also cancel and dependent children questions.
311
+ # Is that happening now?
312
+ if (
313
+ hasattr(response, "exception_occurred")
314
+ and response.exception_occurred
315
+ ):
316
+ raise response.exception_occurred
317
+
318
+ except QuestionAnswerValidationError as e:
319
+ self._handle_exception(e, invigilator, task)
320
+ return invigilator.get_failed_task_result(
321
+ failure_reason="Question answer validation failed."
322
+ )
286
323
 
287
- except Exception as e:
288
- # there was some other exception
289
- self._handle_exception(e, invigilator, task)
324
+ except asyncio.TimeoutError as e:
325
+ self._handle_exception(e, invigilator, task)
326
+ had_language_model_no_response_error = True
327
+ raise LanguageModelNoResponseError(
328
+ f"Language model timed out for question '{question.question_name}.'"
329
+ )
290
330
 
291
- if "response" not in locals():
331
+ except Exception as e:
332
+ self._handle_exception(e, invigilator, task)
292
333
 
293
- raise LanguageModelNoResponseError(
294
- f"Language model did not return a response for question '{question.question_name}.'"
295
- )
334
+ if "response" not in locals():
335
+ had_language_model_no_response_error = True
336
+ raise LanguageModelNoResponseError(
337
+ f"Language model did not return a response for question '{question.question_name}.'"
338
+ )
339
+
340
+ # if it gets here, it means the no response error was fixed
341
+ if (
342
+ question.question_name in self.exceptions
343
+ and had_language_model_no_response_error
344
+ ):
345
+ self.exceptions.record_fixed_question(question.question_name)
296
346
 
297
- return response
347
+ return response
348
+
349
+ try:
350
+ return await attempt_answer()
351
+ except RetryError as retry_error:
352
+ # All retries have failed for LanguageModelNoResponseError
353
+ original_error = retry_error.last_attempt.exception()
354
+ self._handle_exception(
355
+ original_error, self._get_invigilator(question), task
356
+ )
357
+ raise original_error # Re-raise the original error after handling
298
358
 
299
359
  def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
300
360
  """Return an invigilator for the given question.
@@ -334,14 +394,32 @@ class Interview(InterviewStatusMixin):
334
394
  def _handle_exception(
335
395
  self, e: Exception, invigilator: "InvigilatorBase", task=None
336
396
  ):
397
+ import copy
398
+
399
+ # breakpoint()
400
+
401
+ answers = copy.copy(self.answers)
337
402
  exception_entry = InterviewExceptionEntry(
338
403
  exception=e,
339
404
  invigilator=invigilator,
405
+ answers=answers,
340
406
  )
341
407
  if task:
342
408
  task.task_status = TaskStatus.FAILED
343
409
  self.exceptions.add(invigilator.question.question_name, exception_entry)
344
410
 
411
+ if self.raise_validation_errors:
412
+ if isinstance(e, QuestionAnswerValidationError):
413
+ raise e
414
+
415
+ if hasattr(self, "stop_on_exception"):
416
+ stop_on_exception = self.stop_on_exception
417
+ else:
418
+ stop_on_exception = False
419
+
420
+ if stop_on_exception:
421
+ raise e
422
+
345
423
  def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
346
424
  """Cancel the tasks for questions that are skipped.
347
425
 
@@ -353,11 +431,11 @@ class Interview(InterviewStatusMixin):
353
431
  """
354
432
  current_question_index: int = self.to_index[current_question.question_name]
355
433
 
356
- next_question: Union[int, EndOfSurvey] = (
357
- self.survey.rule_collection.next_question(
358
- q_now=current_question_index,
359
- answers=self.answers | self.scenario | self.agent["traits"],
360
- )
434
+ next_question: Union[
435
+ int, EndOfSurvey
436
+ ] = self.survey.rule_collection.next_question(
437
+ q_now=current_question_index,
438
+ answers=self.answers | self.scenario | self.agent["traits"],
361
439
  )
362
440
 
363
441
  next_question_index = next_question.next_q
@@ -411,6 +489,7 @@ class Interview(InterviewStatusMixin):
411
489
  asyncio.exceptions.CancelledError
412
490
  """
413
491
  self.sidecar_model = sidecar_model
492
+ self.stop_on_exception = stop_on_exception
414
493
 
415
494
  # if no model bucket is passed, create an 'infinity' bucket with no rate limits
416
495
  if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
@@ -424,7 +503,9 @@ class Interview(InterviewStatusMixin):
424
503
  self.invigilators = [
425
504
  self._get_invigilator(question) for question in self.survey.questions
426
505
  ]
427
- await asyncio.gather(*self.tasks, return_exceptions=not stop_on_exception)
506
+ await asyncio.gather(
507
+ *self.tasks, return_exceptions=not stop_on_exception
508
+ ) # not stop_on_exception)
428
509
  self.answers.replace_missing_answers_with_none(self.survey)
429
510
  valid_results = list(self._extract_valid_results())
430
511
  return self.answers, valid_results
edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py}
@@ -6,6 +6,22 @@ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
6
6
  class InterviewExceptionCollection(UserDict):
7
7
  """A collection of exceptions that occurred during the interview."""
8
8
 
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.fixed = set()
12
+
13
+ def unfixed_exceptions(self) -> list:
14
+ """Return a list of unfixed exceptions."""
15
+ return {k: v for k, v in self.data.items() if k not in self.fixed}
16
+
17
+ def num_unfixed(self) -> list:
18
+ """Return a list of unfixed questions."""
19
+ return len([k for k in self.data.keys() if k not in self.fixed])
20
+
21
+ def record_fixed_question(self, question_name: str) -> None:
22
+ """Record that a question has been fixed."""
23
+ self.fixed.add(question_name)
24
+
9
25
  def add(self, question_name: str, entry: InterviewExceptionEntry) -> None:
10
26
  """Add an exception entry to the collection."""
11
27
  question_name = question_name
edsl/jobs/interviews/InterviewExceptionEntry.py
@@ -15,12 +15,14 @@ class InterviewExceptionEntry:
15
15
  # failed_question: FailedQuestion,
16
16
  invigilator: "Invigilator",
17
17
  traceback_format="text",
18
+ answers=None,
18
19
  ):
19
20
  self.time = datetime.datetime.now().isoformat()
20
21
  self.exception = exception
21
22
  # self.failed_question = failed_question
22
23
  self.invigilator = invigilator
23
24
  self.traceback_format = traceback_format
25
+ self.answers = answers
24
26
 
25
27
  @property
26
28
  def question_type(self):