edsl 0.1.33.dev1__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. edsl/TemplateLoader.py +24 -0
  2. edsl/__init__.py +8 -4
  3. edsl/agents/Agent.py +46 -14
  4. edsl/agents/AgentList.py +43 -0
  5. edsl/agents/Invigilator.py +125 -212
  6. edsl/agents/InvigilatorBase.py +140 -32
  7. edsl/agents/PromptConstructionMixin.py +43 -66
  8. edsl/agents/__init__.py +1 -0
  9. edsl/auto/AutoStudy.py +117 -0
  10. edsl/auto/StageBase.py +230 -0
  11. edsl/auto/StageGenerateSurvey.py +178 -0
  12. edsl/auto/StageLabelQuestions.py +125 -0
  13. edsl/auto/StagePersona.py +61 -0
  14. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  15. edsl/auto/StagePersonaDimensionValues.py +74 -0
  16. edsl/auto/StagePersonaDimensions.py +69 -0
  17. edsl/auto/StageQuestions.py +73 -0
  18. edsl/auto/SurveyCreatorPipeline.py +21 -0
  19. edsl/auto/utilities.py +224 -0
  20. edsl/config.py +38 -39
  21. edsl/coop/PriceFetcher.py +58 -0
  22. edsl/coop/coop.py +39 -5
  23. edsl/data/Cache.py +35 -1
  24. edsl/data_transfer_models.py +120 -38
  25. edsl/enums.py +2 -0
  26. edsl/exceptions/language_models.py +25 -1
  27. edsl/exceptions/questions.py +62 -5
  28. edsl/exceptions/results.py +4 -0
  29. edsl/inference_services/AnthropicService.py +13 -11
  30. edsl/inference_services/AwsBedrock.py +19 -17
  31. edsl/inference_services/AzureAI.py +37 -20
  32. edsl/inference_services/GoogleService.py +16 -12
  33. edsl/inference_services/GroqService.py +2 -0
  34. edsl/inference_services/InferenceServiceABC.py +24 -0
  35. edsl/inference_services/MistralAIService.py +120 -0
  36. edsl/inference_services/OpenAIService.py +41 -50
  37. edsl/inference_services/TestService.py +71 -0
  38. edsl/inference_services/models_available_cache.py +0 -6
  39. edsl/inference_services/registry.py +4 -0
  40. edsl/jobs/Answers.py +10 -12
  41. edsl/jobs/FailedQuestion.py +78 -0
  42. edsl/jobs/Jobs.py +18 -13
  43. edsl/jobs/buckets/TokenBucket.py +39 -14
  44. edsl/jobs/interviews/Interview.py +297 -77
  45. edsl/jobs/interviews/InterviewExceptionEntry.py +83 -19
  46. edsl/jobs/interviews/interview_exception_tracking.py +0 -70
  47. edsl/jobs/interviews/retry_management.py +3 -1
  48. edsl/jobs/runners/JobsRunnerAsyncio.py +116 -70
  49. edsl/jobs/runners/JobsRunnerStatusMixin.py +1 -1
  50. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  51. edsl/jobs/tasks/TaskHistory.py +131 -213
  52. edsl/language_models/LanguageModel.py +239 -129
  53. edsl/language_models/ModelList.py +2 -2
  54. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  55. edsl/language_models/fake_openai_call.py +15 -0
  56. edsl/language_models/fake_openai_service.py +61 -0
  57. edsl/language_models/registry.py +15 -2
  58. edsl/language_models/repair.py +0 -19
  59. edsl/language_models/utilities.py +61 -0
  60. edsl/prompts/Prompt.py +52 -2
  61. edsl/questions/AnswerValidatorMixin.py +23 -26
  62. edsl/questions/QuestionBase.py +273 -242
  63. edsl/questions/QuestionBaseGenMixin.py +133 -0
  64. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  65. edsl/questions/QuestionBudget.py +6 -0
  66. edsl/questions/QuestionCheckBox.py +227 -35
  67. edsl/questions/QuestionExtract.py +98 -27
  68. edsl/questions/QuestionFreeText.py +46 -29
  69. edsl/questions/QuestionFunctional.py +7 -0
  70. edsl/questions/QuestionList.py +141 -22
  71. edsl/questions/QuestionMultipleChoice.py +173 -64
  72. edsl/questions/QuestionNumerical.py +87 -46
  73. edsl/questions/QuestionRank.py +182 -24
  74. edsl/questions/RegisterQuestionsMeta.py +31 -12
  75. edsl/questions/ResponseValidatorABC.py +169 -0
  76. edsl/questions/__init__.py +3 -4
  77. edsl/questions/decorators.py +21 -0
  78. edsl/questions/derived/QuestionLikertFive.py +10 -5
  79. edsl/questions/derived/QuestionLinearScale.py +11 -1
  80. edsl/questions/derived/QuestionTopK.py +6 -0
  81. edsl/questions/derived/QuestionYesNo.py +16 -1
  82. edsl/questions/descriptors.py +43 -7
  83. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  84. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  85. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  86. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  87. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  88. edsl/questions/prompt_templates/question_list.jinja +17 -0
  89. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  90. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  91. edsl/questions/question_registry.py +6 -2
  92. edsl/questions/templates/__init__.py +0 -0
  93. edsl/questions/templates/checkbox/__init__.py +0 -0
  94. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  95. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  96. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  97. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  98. edsl/questions/templates/free_text/__init__.py +0 -0
  99. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  100. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  101. edsl/questions/templates/likert_five/__init__.py +0 -0
  102. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  103. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  104. edsl/questions/templates/linear_scale/__init__.py +0 -0
  105. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  106. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  107. edsl/questions/templates/list/__init__.py +0 -0
  108. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  109. edsl/questions/templates/list/question_presentation.jinja +5 -0
  110. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  111. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  112. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  113. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  114. edsl/questions/templates/numerical/__init__.py +0 -0
  115. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  116. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  117. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  118. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  119. edsl/questions/templates/top_k/__init__.py +0 -0
  120. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  121. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  122. edsl/questions/templates/yes_no/__init__.py +0 -0
  123. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  124. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  125. edsl/results/Dataset.py +20 -0
  126. edsl/results/DatasetExportMixin.py +41 -47
  127. edsl/results/DatasetTree.py +145 -0
  128. edsl/results/Result.py +32 -5
  129. edsl/results/Results.py +131 -45
  130. edsl/results/ResultsDBMixin.py +3 -3
  131. edsl/results/Selector.py +118 -0
  132. edsl/results/tree_explore.py +115 -0
  133. edsl/scenarios/Scenario.py +10 -4
  134. edsl/scenarios/ScenarioList.py +348 -39
  135. edsl/scenarios/ScenarioListExportMixin.py +9 -0
  136. edsl/study/SnapShot.py +8 -1
  137. edsl/surveys/RuleCollection.py +2 -2
  138. edsl/surveys/Survey.py +634 -315
  139. edsl/surveys/SurveyExportMixin.py +71 -9
  140. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  141. edsl/surveys/SurveyQualtricsImport.py +75 -4
  142. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  143. edsl/surveys/instructions/Instruction.py +34 -0
  144. edsl/surveys/instructions/InstructionCollection.py +77 -0
  145. edsl/surveys/instructions/__init__.py +0 -0
  146. edsl/templates/error_reporting/base.html +24 -0
  147. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  148. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  149. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  150. edsl/templates/error_reporting/interview_details.html +111 -0
  151. edsl/templates/error_reporting/interviews.html +10 -0
  152. edsl/templates/error_reporting/overview.html +5 -0
  153. edsl/templates/error_reporting/performance_plot.html +2 -0
  154. edsl/templates/error_reporting/report.css +74 -0
  155. edsl/templates/error_reporting/report.html +118 -0
  156. edsl/templates/error_reporting/report.js +25 -0
  157. {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +4 -2
  158. edsl-0.1.33.dev2.dist-info/RECORD +289 -0
  159. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
  160. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  161. edsl-0.1.33.dev1.dist-info/RECORD +0 -209
  162. {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
  163. {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py CHANGED
@@ -156,7 +156,11 @@ class Jobs(Base):
156
156
  from edsl.results.Dataset import Dataset
157
157
 
158
158
  for interview_index, interview in enumerate(interviews):
159
- invigilators = list(interview._build_invigilators(debug=False))
159
+ invigilators = [
160
+ interview._get_invigilator(question)
161
+ for question in self.survey.questions
162
+ ]
163
+ # list(interview._build_invigilators(debug=False))
160
164
  for _, invigilator in enumerate(invigilators):
161
165
  prompts = invigilator.get_prompts()
162
166
  user_prompts.append(prompts["user_prompt"])
@@ -344,6 +348,7 @@ class Jobs(Base):
344
348
  scenario=scenario,
345
349
  model=model,
346
350
  skip_retry=self.skip_retry,
351
+ raise_validation_errors=self.raise_validation_errors,
347
352
  )
348
353
 
349
354
  def create_bucket_collection(self) -> BucketCollection:
@@ -461,27 +466,31 @@ class Jobs(Base):
461
466
  return False
462
467
  return self._skip_retry
463
468
 
469
+ @property
470
+ def raise_validation_errors(self):
471
+ if not hasattr(self, "_raise_validation_errors"):
472
+ return False
473
+ return self._raise_validation_errors
474
+
464
475
  def run(
465
476
  self,
466
477
  n: int = 1,
467
- debug: bool = False,
468
478
  progress_bar: bool = False,
469
479
  stop_on_exception: bool = False,
470
480
  cache: Union[Cache, bool] = None,
471
481
  check_api_keys: bool = False,
472
482
  sidecar_model: Optional[LanguageModel] = None,
473
- batch_mode: Optional[bool] = None,
474
483
  verbose: bool = False,
475
484
  print_exceptions=True,
476
485
  remote_cache_description: Optional[str] = None,
477
486
  remote_inference_description: Optional[str] = None,
478
487
  skip_retry: bool = False,
488
+ raise_validation_errors: bool = False,
479
489
  ) -> Results:
480
490
  """
481
491
  Runs the Job: conducts Interviews and returns their results.
482
492
 
483
493
  :param n: how many times to run each interview
484
- :param debug: prints debug messages
485
494
  :param progress_bar: shows a progress bar
486
495
  :param stop_on_exception: stops the job if an exception is raised
487
496
  :param cache: a cache object to store results
@@ -495,11 +504,7 @@ class Jobs(Base):
495
504
 
496
505
  self._check_parameters()
497
506
  self._skip_retry = skip_retry
498
-
499
- if batch_mode is not None:
500
- raise NotImplementedError(
501
- "Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
502
- )
507
+ self._raise_validation_errors = raise_validation_errors
503
508
 
504
509
  self.verbose = verbose
505
510
 
@@ -587,7 +592,7 @@ class Jobs(Base):
587
592
  )
588
593
 
589
594
  # handle cache
590
- if cache is None:
595
+ if cache is None or cache is True:
591
596
  from edsl.data.CacheHandler import CacheHandler
592
597
 
593
598
  cache = CacheHandler().get_cache()
@@ -599,12 +604,12 @@ class Jobs(Base):
599
604
  if not remote_cache:
600
605
  results = self._run_local(
601
606
  n=n,
602
- debug=debug,
603
607
  progress_bar=progress_bar,
604
608
  cache=cache,
605
609
  stop_on_exception=stop_on_exception,
606
610
  sidecar_model=sidecar_model,
607
611
  print_exceptions=print_exceptions,
612
+ raise_validation_errors=raise_validation_errors,
608
613
  )
609
614
 
610
615
  results.cache = cache.new_entries_cache()
@@ -643,12 +648,12 @@ class Jobs(Base):
643
648
  self._output("Running job...")
644
649
  results = self._run_local(
645
650
  n=n,
646
- debug=debug,
647
651
  progress_bar=progress_bar,
648
652
  cache=cache,
649
653
  stop_on_exception=stop_on_exception,
650
654
  sidecar_model=sidecar_model,
651
655
  print_exceptions=print_exceptions,
656
+ raise_validation_errors=raise_validation_errors,
652
657
  )
653
658
  self._output("Job completed!")
654
659
 
@@ -883,7 +888,7 @@ def main():
883
888
 
884
889
  job = Jobs.example()
885
890
  len(job) == 8
886
- results = job.run(debug=True, cache=Cache())
891
+ results = job.run(cache=Cache())
887
892
  len(results) == 8
888
893
  results
889
894
 
@@ -82,23 +82,30 @@ class TokenBucket:
82
82
  >>> bucket.refill()
83
83
  >>> bucket.tokens > 0
84
84
  True
85
-
86
85
  """
86
+ """Refill the bucket with new tokens based on elapsed time."""
87
87
  now = time.monotonic()
88
+ # print(f"Time is now: {now}; Last refill time: {self.last_refill}")
88
89
  elapsed = now - self.last_refill
90
+ # print("Elapsed time: ", elapsed)
89
91
  refill_amount = elapsed * self.refill_rate
90
92
  self.tokens = min(self.capacity, self.tokens + refill_amount)
91
93
  self.last_refill = now
92
94
 
95
+ if self.tokens < self.capacity:
96
+ pass
97
+ # print(f"Refilled. Current tokens: {self.tokens:.4f}")
98
+ # print(f"Elapsed time: {elapsed:.4f} seconds")
99
+ # print(f"Refill amount: {refill_amount:.4f}")
100
+
93
101
  self.log.append((now, self.tokens))
94
102
 
95
103
  def wait_time(self, requested_tokens: Union[float, int]) -> float:
96
104
  """Calculate the time to wait for the requested number of tokens."""
97
- now = time.monotonic()
98
- elapsed = now - self.last_refill
99
- refill_amount = elapsed * self.refill_rate
100
- available_tokens = min(self.capacity, self.tokens + refill_amount)
101
- return max(0, requested_tokens - available_tokens) / self.refill_rate
105
+ # self.refill() # Update the current token count
106
+ if self.tokens >= requested_tokens:
107
+ return 0
108
+ return (requested_tokens - self.tokens) / self.refill_rate
102
109
 
103
110
  async def get_tokens(
104
111
  self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
@@ -123,22 +130,40 @@ class TokenBucket:
123
130
  ...
124
131
  ValueError: Requested amount exceeds bucket capacity. Bucket capacity: 10, requested amount: 11. As the bucket never overflows, the requested amount will never be available.
125
132
  >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))
133
+ >>> bucket.capacity
134
+ 12.100000000000001
126
135
  """
127
- if amount > self.capacity:
136
+ if amount >= self.capacity:
128
137
  if not cheat_bucket_capacity:
129
138
  msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
130
139
  raise ValueError(msg)
131
140
  else:
132
- self.tokens = 0 # clear the bucket but let it go through
133
- return
134
-
135
- while self.tokens < amount:
136
- self.refill()
137
- await asyncio.sleep(0.01) # Sleep briefly to prevent busy waiting
138
- self.tokens -= amount
141
+ # self.tokens = 0 # clear the bucket but let it go through
142
+ # print(
143
+ # f"""The requested amount, {amount}, exceeds the current bucket capacity of {self.capacity}.Increasing bucket capacity to {amount} * 1.10 accommodate the requested amount."""
144
+ # )
145
+ self.capacity = amount * 1.10
146
+ self._old_capacity = self.capacity
147
+
148
+ start_time = time.monotonic()
149
+ while True:
150
+ self.refill() # Refill based on elapsed time
151
+ if self.tokens >= amount:
152
+ self.tokens -= amount
153
+ break
154
+
155
+ wait_time = self.wait_time(amount)
156
+ # print(f"Waiting for {wait_time:.4f} seconds")
157
+ if wait_time > 0:
158
+ # print(f"Waiting for {wait_time:.4f} seconds")
159
+ await asyncio.sleep(wait_time)
160
+
161
+ # total_elapsed = time.monotonic() - start_time
162
+ # print(f"Total time to acquire tokens: {total_elapsed:.4f} seconds")
139
163
 
140
164
  now = time.monotonic()
141
165
  self.log.append((now, self.tokens))
166
+ return None
142
167
 
143
168
  def get_log(self) -> list[tuple]:
144
169
  return self.log
@@ -1,50 +1,68 @@
1
1
  """This module contains the Interview class, which is responsible for conducting an interview asynchronously."""
2
2
 
3
3
  from __future__ import annotations
4
- import traceback
5
4
  import asyncio
6
- import time
7
- from typing import Any, Type, List, Generator, Optional
5
+ from typing import Any, Type, List, Generator, Optional, Union
8
6
 
9
- from edsl.jobs.Answers import Answers
7
+ from edsl import CONFIG
10
8
  from edsl.surveys.base import EndOfSurvey
9
+ from edsl.exceptions import QuestionAnswerValidationError
10
+ from edsl.exceptions import InterviewTimeoutError
11
+ from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
12
+
11
13
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
14
+ from edsl.jobs.Answers import Answers
15
+ from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
12
16
  from edsl.jobs.tasks.TaskCreators import TaskCreators
13
-
14
17
  from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
15
18
  from edsl.jobs.interviews.interview_exception_tracking import (
16
19
  InterviewExceptionCollection,
17
20
  )
18
21
  from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
19
22
  from edsl.jobs.interviews.retry_management import retry_strategy
20
- from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
21
23
  from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
22
24
 
23
- import asyncio
25
+ from edsl.surveys.base import EndOfSurvey
26
+ from edsl.jobs.buckets.ModelBuckets import ModelBuckets
27
+ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
28
+ from edsl.jobs.interviews.retry_management import retry_strategy
29
+ from edsl.jobs.tasks.task_status_enum import TaskStatus
30
+ from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
31
+
32
+ from edsl.exceptions import QuestionAnswerValidationError
33
+
34
+ from edsl import Agent, Survey, Scenario, Cache
35
+ from edsl.language_models import LanguageModel
36
+ from edsl.questions import QuestionBase
37
+ from edsl.agents.InvigilatorBase import InvigilatorBase
38
+
39
+ from edsl.exceptions.language_models import LanguageModelNoResponseError
24
40
 
25
41
 
26
- def run_async(coro):
27
- return asyncio.run(coro)
42
+ class RetryableLanguageModelNoResponseError(LanguageModelNoResponseError):
43
+ pass
28
44
 
29
45
 
30
- class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
46
+ class Interview(InterviewStatusMixin):
31
47
  """
32
48
  An 'interview' is one agent answering one survey, with one language model, for a given scenario.
33
49
 
34
50
  The main method is `async_conduct_interview`, which conducts the interview asynchronously.
51
+ Most of the class is dedicated to creating the tasks for each question in the survey, and then running them.
35
52
  """
36
53
 
37
54
  def __init__(
38
55
  self,
39
- agent: "Agent",
40
- survey: "Survey",
41
- scenario: "Scenario",
56
+ agent: Agent,
57
+ survey: Survey,
58
+ scenario: Scenario,
42
59
  model: Type["LanguageModel"],
43
60
  debug: Optional[bool] = False,
44
61
  iteration: int = 0,
45
62
  cache: Optional["Cache"] = None,
46
63
  sidecar_model: Optional["LanguageModel"] = None,
47
- skip_retry=False,
64
+ skip_retry: bool = False,
65
+ raise_validation_errors: bool = True,
48
66
  ):
49
67
  """Initialize the Interview instance.
50
68
 
@@ -79,9 +97,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
79
97
  self.debug = debug
80
98
  self.iteration = iteration
81
99
  self.cache = cache
82
- self.answers: dict[
83
- str, str
84
- ] = Answers() # will get filled in as interview progresses
100
+ self.answers: dict[str, str] = (
101
+ Answers()
102
+ ) # will get filled in as interview progresses
85
103
  self.sidecar_model = sidecar_model
86
104
 
87
105
  # Trackers
@@ -89,6 +107,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
89
107
  self.exceptions = InterviewExceptionCollection()
90
108
  self._task_status_log_dict = InterviewStatusLog()
91
109
  self.skip_retry = skip_retry
110
+ self.raise_validation_errors = raise_validation_errors
92
111
 
93
112
  # dictionary mapping question names to their index in the survey.
94
113
  self.to_index = {
@@ -96,6 +115,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
96
115
  for index, question_name in enumerate(self.survey.question_names)
97
116
  }
98
117
 
118
+ self.failed_questions = []
119
+
120
+ # region: Serialization
99
121
  def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
100
122
  """Return a dictionary representation of the Interview instance.
101
123
  This is just for hashing purposes.
@@ -120,13 +142,247 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
120
142
 
121
143
  return dict_hash(self._to_dict())
122
144
 
123
- async def async_conduct_interview(
145
+ # endregion
146
+
147
+ # region: Creating tasks
148
+ @property
149
+ def dag(self) -> "DAG":
150
+ """Return the directed acyclic graph for the survey.
151
+
152
+ The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
153
+ It is used to determine the order in which questions should be answered.
154
+ This reflects both agent 'memory' considerations and 'skip' logic.
155
+ The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
156
+
157
+ >>> i = Interview.example()
158
+ >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
159
+ True
160
+ """
161
+ return self.survey.dag(textify=True)
162
+
163
+ def _build_question_tasks(
164
+ self,
165
+ model_buckets: ModelBuckets,
166
+ ) -> list[asyncio.Task]:
167
+ """Create a task for each question, with dependencies on the questions that must be answered before this one can be answered.
168
+
169
+ :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
170
+ :param model_buckets: the model buckets used to track and control usage rates.
171
+ """
172
+ tasks = []
173
+ for question in self.survey.questions:
174
+ tasks_that_must_be_completed_before = list(
175
+ self._get_tasks_that_must_be_completed_before(
176
+ tasks=tasks, question=question
177
+ )
178
+ )
179
+ question_task = self._create_question_task(
180
+ question=question,
181
+ tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
182
+ model_buckets=model_buckets,
183
+ iteration=self.iteration,
184
+ )
185
+ tasks.append(question_task)
186
+ return tuple(tasks)
187
+
188
+ def _get_tasks_that_must_be_completed_before(
189
+ self, *, tasks: list[asyncio.Task], question: "QuestionBase"
190
+ ) -> Generator[asyncio.Task, None, None]:
191
+ """Return the tasks that must be completed before the given question can be answered.
192
+
193
+ :param tasks: a list of tasks that have been created so far.
194
+ :param question: the question for which we are determining dependencies.
195
+
196
+ If a question has no dependencies, this will be an empty list, [].
197
+ """
198
+ parents_of_focal_question = self.dag.get(question.question_name, [])
199
+ for parent_question_name in parents_of_focal_question:
200
+ yield tasks[self.to_index[parent_question_name]]
201
+
202
+ def _create_question_task(
203
+ self,
204
+ *,
205
+ question: QuestionBase,
206
+ tasks_that_must_be_completed_before: list[asyncio.Task],
207
+ model_buckets: ModelBuckets,
208
+ iteration: int = 0,
209
+ ) -> asyncio.Task:
210
+ """Create a task that depends on the passed-in dependencies that are awaited before the task is run.
211
+
212
+ :param question: the question to be answered. This is the question we are creating a task for.
213
+ :param tasks_that_must_be_completed_before: the tasks that must be completed before the focal task is run.
214
+ :param model_buckets: the model buckets used to track and control usage rates.
215
+ :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
216
+ :param iteration: the iteration number for the interview.
217
+
218
+ The task is created by a `QuestionTaskCreator`, which is responsible for creating the task and managing its dependencies.
219
+ It is passed a reference to the function that will be called to answer the question.
220
+ It is passed a list "tasks_that_must_be_completed_before" that are awaited before the task is run.
221
+ These are added as a dependency to the focal task.
222
+ """
223
+ task_creator = QuestionTaskCreator(
224
+ question=question,
225
+ answer_question_func=self._answer_question_and_record_task,
226
+ token_estimator=self._get_estimated_request_tokens,
227
+ model_buckets=model_buckets,
228
+ iteration=iteration,
229
+ )
230
+ for task in tasks_that_must_be_completed_before:
231
+ task_creator.add_dependency(task)
232
+
233
+ self.task_creators.update(
234
+ {question.question_name: task_creator}
235
+ ) # track this task creator
236
+ return task_creator.generate_task()
237
+
238
+ def _get_estimated_request_tokens(self, question) -> float:
239
+ """Estimate the number of tokens that will be required to run the focal task."""
240
+ invigilator = self._get_invigilator(question=question)
241
+ # TODO: There should be a way to get a more accurate estimate.
242
+ combined_text = ""
243
+ for prompt in invigilator.get_prompts().values():
244
+ if hasattr(prompt, "text"):
245
+ combined_text += prompt.text
246
+ elif isinstance(prompt, str):
247
+ combined_text += prompt
248
+ else:
249
+ raise ValueError(f"Prompt is of type {type(prompt)}")
250
+ return len(combined_text) / 4.0
251
+
252
+ async def _answer_question_and_record_task(
124
253
  self,
125
254
  *,
126
- model_buckets: ModelBuckets = None,
127
- debug: bool = False,
255
+ question: "QuestionBase",
256
+ task=None,
257
+ ) -> "AgentResponseDict":
258
+ """Answer a question and records the task."""
259
+
260
+ invigilator = self._get_invigilator(question)
261
+
262
+ if self._skip_this_question(question):
263
+ response = invigilator.get_failed_task_result(
264
+ failure_reason="Question skipped."
265
+ )
266
+
267
+ try:
268
+ response: EDSLResultObjectInput = await invigilator.async_answer_question()
269
+ if response.validated:
270
+ self.answers.add_answer(response=response, question=question)
271
+ self._cancel_skipped_questions(question)
272
+ else:
273
+ if (
274
+ hasattr(response, "exception_occurred")
275
+ and response.exception_occurred
276
+ ):
277
+ raise response.exception_occurred
278
+
279
+ except QuestionAnswerValidationError as e:
280
+ # there's a response, but it couldn't be validated
281
+ self._handle_exception(e, invigilator, task)
282
+
283
+ except asyncio.TimeoutError as e:
284
+ # the API timed-out - this is recorded but as a response isn't generated, the LanguageModelNoResponseError will also be raised
285
+ self._handle_exception(e, invigilator, task)
286
+
287
+ except Exception as e:
288
+ # there was some other exception
289
+ self._handle_exception(e, invigilator, task)
290
+
291
+ if "response" not in locals():
292
+
293
+ raise LanguageModelNoResponseError(
294
+ f"Language model did not return a response for question '{question.question_name}.'"
295
+ )
296
+
297
+ return response
298
+
299
+ def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
300
+ """Return an invigilator for the given question.
301
+
302
+ :param question: the question to be answered
303
+ :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
304
+ """
305
+ invigilator = self.agent.create_invigilator(
306
+ question=question,
307
+ scenario=self.scenario,
308
+ model=self.model,
309
+ debug=False,
310
+ survey=self.survey,
311
+ memory_plan=self.survey.memory_plan,
312
+ current_answers=self.answers,
313
+ iteration=self.iteration,
314
+ cache=self.cache,
315
+ sidecar_model=self.sidecar_model,
316
+ raise_validation_errors=self.raise_validation_errors,
317
+ )
318
+ """Return an invigilator for the given question."""
319
+ return invigilator
320
+
321
+ def _skip_this_question(self, current_question: "QuestionBase") -> bool:
322
+ """Determine if the current question should be skipped.
323
+
324
+ :param current_question: the question to be answered.
325
+ """
326
+ current_question_index = self.to_index[current_question.question_name]
327
+
328
+ answers = self.answers | self.scenario | self.agent["traits"]
329
+ skip = self.survey.rule_collection.skip_question_before_running(
330
+ current_question_index, answers
331
+ )
332
+ return skip
333
+
334
+ def _handle_exception(
335
+ self, e: Exception, invigilator: "InvigilatorBase", task=None
336
+ ):
337
+ exception_entry = InterviewExceptionEntry(
338
+ exception=e,
339
+ invigilator=invigilator,
340
+ )
341
+ if task:
342
+ task.task_status = TaskStatus.FAILED
343
+ self.exceptions.add(invigilator.question.question_name, exception_entry)
344
+
345
+ def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
346
+ """Cancel the tasks for questions that are skipped.
347
+
348
+ :param current_question: the question that was just answered.
349
+
350
+ It first determines the next question, given the current question and the current answers.
351
+ If the next question is the end of the survey, it cancels all remaining tasks.
352
+ If the next question is after the current question, it cancels all tasks between the current question and the next question.
353
+ """
354
+ current_question_index: int = self.to_index[current_question.question_name]
355
+
356
+ next_question: Union[int, EndOfSurvey] = (
357
+ self.survey.rule_collection.next_question(
358
+ q_now=current_question_index,
359
+ answers=self.answers | self.scenario | self.agent["traits"],
360
+ )
361
+ )
362
+
363
+ next_question_index = next_question.next_q
364
+
365
+ def cancel_between(start, end):
366
+ """Cancel the tasks between the start and end indices."""
367
+ for i in range(start, end):
368
+ self.tasks[i].cancel()
369
+
370
+ if next_question_index == EndOfSurvey:
371
+ cancel_between(current_question_index + 1, len(self.survey.questions))
372
+ return
373
+
374
+ if next_question_index > (current_question_index + 1):
375
+ cancel_between(current_question_index + 1, next_question_index)
376
+
377
+ # endregion
378
+
379
+ # region: Conducting the interview
380
+ async def async_conduct_interview(
381
+ self,
382
+ model_buckets: Optional[ModelBuckets] = None,
128
383
  stop_on_exception: bool = False,
129
384
  sidecar_model: Optional["LanguageModel"] = None,
385
+ raise_validation_errors: bool = True,
130
386
  ) -> tuple["Answers", List[dict[str, Any]]]:
131
387
  """
132
388
  Conduct an Interview asynchronously.
@@ -146,19 +402,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
146
402
 
147
403
  >>> i = Interview.example(throw_exception = True)
148
404
  >>> result, _ = asyncio.run(i.async_conduct_interview())
149
- Attempt 1 failed with exception:This is a test error now waiting 1.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
150
- <BLANKLINE>
151
- <BLANKLINE>
152
- Attempt 2 failed with exception:This is a test error now waiting 2.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
153
- <BLANKLINE>
154
- <BLANKLINE>
155
- Attempt 3 failed with exception:This is a test error now waiting 4.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
156
- <BLANKLINE>
157
- <BLANKLINE>
158
- Attempt 4 failed with exception:This is a test error now waiting 8.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
159
- <BLANKLINE>
160
- <BLANKLINE>
161
-
162
405
  >>> i.exceptions
163
406
  {'q0': ...
164
407
  >>> i = Interview.example()
@@ -173,21 +416,22 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
173
416
  if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
174
417
  model_buckets = ModelBuckets.infinity_bucket()
175
418
 
176
- ## build the tasks using the InterviewTaskBuildingMixin
177
419
  ## This is the key part---it creates a task for each question,
178
420
  ## with dependencies on the questions that must be answered before this one can be answered.
179
- self.tasks = self._build_question_tasks(
180
- debug=debug, model_buckets=model_buckets
181
- )
421
+ self.tasks = self._build_question_tasks(model_buckets=model_buckets)
182
422
 
183
423
  ## 'Invigilators' are used to administer the survey
184
- self.invigilators = list(self._build_invigilators(debug=debug))
185
- # await the tasks being conducted
424
+ self.invigilators = [
425
+ self._get_invigilator(question) for question in self.survey.questions
426
+ ]
186
427
  await asyncio.gather(*self.tasks, return_exceptions=not stop_on_exception)
187
428
  self.answers.replace_missing_answers_with_none(self.survey)
188
429
  valid_results = list(self._extract_valid_results())
189
430
  return self.answers, valid_results
190
431
 
432
+ # endregion
433
+
434
+ # region: Extracting results and recording errors
191
435
  def _extract_valid_results(self) -> Generator["Answers", None, None]:
192
436
  """Extract the valid results from the list of results.
193
437
 
@@ -200,8 +444,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
200
444
  >>> results = list(i._extract_valid_results())
201
445
  >>> len(results) == len(i.survey)
202
446
  True
203
- >>> type(results[0])
204
- <class 'edsl.data_transfer_models.AgentResponseDict'>
205
447
  """
206
448
  assert len(self.tasks) == len(self.invigilators)
207
449
 
@@ -212,46 +454,24 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
212
454
  try:
213
455
  result = task.result()
214
456
  except asyncio.CancelledError as e: # task was cancelled
215
- result = invigilator.get_failed_task_result()
457
+ result = invigilator.get_failed_task_result(
458
+ failure_reason="Task was cancelled."
459
+ )
216
460
  except Exception as e: # any other kind of exception in the task
217
- result = invigilator.get_failed_task_result()
218
- self._record_exception(task, e)
219
- yield result
461
+ result = invigilator.get_failed_task_result(
462
+ failure_reason=f"Task failed with exception: {str(e)}."
463
+ )
464
+ exception_entry = InterviewExceptionEntry(
465
+ exception=e,
466
+ invigilator=invigilator,
467
+ )
468
+ self.exceptions.add(task.get_name(), exception_entry)
220
469
 
221
- def _record_exception(self, task, exception: Exception) -> None:
222
- """Record an exception in the Interview instance.
223
-
224
- It records the exception in the Interview instance, with the task name and the exception entry.
225
-
226
- >>> i = Interview.example()
227
- >>> result, _ = asyncio.run(i.async_conduct_interview())
228
- >>> i.exceptions
229
- {}
230
- >>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
231
- >>> i.exceptions
232
- {'q0': ...
233
- """
234
- exception_entry = InterviewExceptionEntry(exception)
235
- self.exceptions.add(task.get_name(), exception_entry)
236
-
237
- @property
238
- def dag(self) -> "DAG":
239
- """Return the directed acyclic graph for the survey.
240
-
241
- The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
242
- It is used to determine the order in which questions should be answered.
243
- This reflects both agent 'memory' considerations and 'skip' logic.
244
- The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
470
+ yield result
245
471
 
246
- >>> i = Interview.example()
247
- >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
248
- True
249
- """
250
- return self.survey.dag(textify=True)
472
+ # endregion
251
473
 
252
- #######################
253
- # Dunder methods
254
- #######################
474
+ # region: Magic methods
255
475
  def __repr__(self) -> str:
256
476
  """Return a string representation of the Interview instance."""
257
477
  return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"