edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. edsl/Base.py +169 -116
  2. edsl/__init__.py +14 -6
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +358 -146
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +88 -36
  7. edsl/agents/InvigilatorBase.py +59 -70
  8. edsl/agents/PromptConstructor.py +117 -219
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionOptionProcessor.py +172 -0
  11. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  12. edsl/agents/__init__.py +0 -1
  13. edsl/agents/prompt_helpers.py +3 -3
  14. edsl/config.py +22 -2
  15. edsl/conversation/car_buying.py +2 -1
  16. edsl/coop/CoopFunctionsMixin.py +15 -0
  17. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  18. edsl/coop/PriceFetcher.py +1 -1
  19. edsl/coop/coop.py +104 -42
  20. edsl/coop/utils.py +14 -14
  21. edsl/data/Cache.py +21 -14
  22. edsl/data/CacheEntry.py +12 -15
  23. edsl/data/CacheHandler.py +33 -12
  24. edsl/data/__init__.py +4 -3
  25. edsl/data_transfer_models.py +2 -1
  26. edsl/enums.py +20 -0
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +12 -0
  29. edsl/exceptions/inference_services.py +5 -0
  30. edsl/exceptions/questions.py +24 -6
  31. edsl/exceptions/scenarios.py +7 -0
  32. edsl/inference_services/AnthropicService.py +0 -3
  33. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  34. edsl/inference_services/AvailableModelFetcher.py +209 -0
  35. edsl/inference_services/AwsBedrock.py +0 -2
  36. edsl/inference_services/AzureAI.py +0 -2
  37. edsl/inference_services/GoogleService.py +2 -11
  38. edsl/inference_services/InferenceServiceABC.py +18 -85
  39. edsl/inference_services/InferenceServicesCollection.py +105 -80
  40. edsl/inference_services/MistralAIService.py +0 -3
  41. edsl/inference_services/OpenAIService.py +1 -4
  42. edsl/inference_services/PerplexityService.py +0 -3
  43. edsl/inference_services/ServiceAvailability.py +135 -0
  44. edsl/inference_services/TestService.py +11 -8
  45. edsl/inference_services/data_structures.py +62 -0
  46. edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
  47. edsl/jobs/Answers.py +1 -14
  48. edsl/jobs/FetchInvigilator.py +40 -0
  49. edsl/jobs/InterviewTaskManager.py +98 -0
  50. edsl/jobs/InterviewsConstructor.py +48 -0
  51. edsl/jobs/Jobs.py +102 -243
  52. edsl/jobs/JobsChecks.py +35 -10
  53. edsl/jobs/JobsComponentConstructor.py +189 -0
  54. edsl/jobs/JobsPrompts.py +5 -3
  55. edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
  56. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  57. edsl/jobs/RequestTokenEstimator.py +30 -0
  58. edsl/jobs/buckets/BucketCollection.py +44 -3
  59. edsl/jobs/buckets/TokenBucket.py +53 -21
  60. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  61. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  62. edsl/jobs/decorators.py +35 -0
  63. edsl/jobs/interviews/Interview.py +77 -380
  64. edsl/jobs/jobs_status_enums.py +9 -0
  65. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  66. edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
  67. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  68. edsl/jobs/tasks/TaskHistory.py +14 -15
  69. edsl/jobs/tasks/task_status_enum.py +0 -2
  70. edsl/language_models/ComputeCost.py +63 -0
  71. edsl/language_models/LanguageModel.py +137 -234
  72. edsl/language_models/ModelList.py +11 -13
  73. edsl/language_models/PriceManager.py +127 -0
  74. edsl/language_models/RawResponseHandler.py +106 -0
  75. edsl/language_models/ServiceDataSources.py +0 -0
  76. edsl/language_models/__init__.py +0 -1
  77. edsl/language_models/key_management/KeyLookup.py +63 -0
  78. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  79. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  80. edsl/language_models/key_management/__init__.py +0 -0
  81. edsl/language_models/key_management/models.py +131 -0
  82. edsl/language_models/registry.py +49 -59
  83. edsl/language_models/repair.py +2 -2
  84. edsl/language_models/utilities.py +5 -4
  85. edsl/notebooks/Notebook.py +19 -14
  86. edsl/notebooks/NotebookToLaTeX.py +142 -0
  87. edsl/prompts/Prompt.py +29 -39
  88. edsl/questions/AnswerValidatorMixin.py +47 -2
  89. edsl/questions/ExceptionExplainer.py +77 -0
  90. edsl/questions/HTMLQuestion.py +103 -0
  91. edsl/questions/LoopProcessor.py +149 -0
  92. edsl/questions/QuestionBase.py +37 -192
  93. edsl/questions/QuestionBaseGenMixin.py +52 -48
  94. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  95. edsl/questions/QuestionCheckBox.py +1 -1
  96. edsl/questions/QuestionExtract.py +1 -1
  97. edsl/questions/QuestionFreeText.py +1 -2
  98. edsl/questions/QuestionList.py +3 -5
  99. edsl/questions/QuestionMatrix.py +265 -0
  100. edsl/questions/QuestionMultipleChoice.py +66 -22
  101. edsl/questions/QuestionNumerical.py +1 -3
  102. edsl/questions/QuestionRank.py +6 -16
  103. edsl/questions/ResponseValidatorABC.py +37 -11
  104. edsl/questions/ResponseValidatorFactory.py +28 -0
  105. edsl/questions/SimpleAskMixin.py +4 -3
  106. edsl/questions/__init__.py +1 -0
  107. edsl/questions/derived/QuestionLinearScale.py +6 -3
  108. edsl/questions/derived/QuestionTopK.py +1 -1
  109. edsl/questions/descriptors.py +17 -3
  110. edsl/questions/question_registry.py +1 -1
  111. edsl/questions/templates/matrix/__init__.py +1 -0
  112. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  113. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  114. edsl/results/CSSParameterizer.py +1 -1
  115. edsl/results/Dataset.py +170 -7
  116. edsl/results/DatasetExportMixin.py +224 -302
  117. edsl/results/DatasetTree.py +28 -8
  118. edsl/results/MarkdownToDocx.py +122 -0
  119. edsl/results/MarkdownToPDF.py +111 -0
  120. edsl/results/Result.py +192 -206
  121. edsl/results/Results.py +120 -113
  122. edsl/results/ResultsExportMixin.py +2 -0
  123. edsl/results/Selector.py +23 -13
  124. edsl/results/TableDisplay.py +98 -171
  125. edsl/results/TextEditor.py +50 -0
  126. edsl/results/__init__.py +1 -1
  127. edsl/results/smart_objects.py +96 -0
  128. edsl/results/table_data_class.py +12 -0
  129. edsl/results/table_renderers.py +118 -0
  130. edsl/scenarios/ConstructDownloadLink.py +109 -0
  131. edsl/scenarios/DirectoryScanner.py +96 -0
  132. edsl/scenarios/DocumentChunker.py +102 -0
  133. edsl/scenarios/DocxScenario.py +16 -0
  134. edsl/scenarios/FileStore.py +118 -239
  135. edsl/scenarios/PdfExtractor.py +40 -0
  136. edsl/scenarios/Scenario.py +90 -193
  137. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  138. edsl/scenarios/ScenarioJoin.py +10 -6
  139. edsl/scenarios/ScenarioList.py +383 -240
  140. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  141. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  142. edsl/scenarios/ScenarioSelector.py +156 -0
  143. edsl/scenarios/__init__.py +1 -2
  144. edsl/scenarios/file_methods.py +85 -0
  145. edsl/scenarios/handlers/__init__.py +13 -0
  146. edsl/scenarios/handlers/csv.py +38 -0
  147. edsl/scenarios/handlers/docx.py +76 -0
  148. edsl/scenarios/handlers/html.py +37 -0
  149. edsl/scenarios/handlers/json.py +111 -0
  150. edsl/scenarios/handlers/latex.py +5 -0
  151. edsl/scenarios/handlers/md.py +51 -0
  152. edsl/scenarios/handlers/pdf.py +68 -0
  153. edsl/scenarios/handlers/png.py +39 -0
  154. edsl/scenarios/handlers/pptx.py +105 -0
  155. edsl/scenarios/handlers/py.py +294 -0
  156. edsl/scenarios/handlers/sql.py +313 -0
  157. edsl/scenarios/handlers/sqlite.py +149 -0
  158. edsl/scenarios/handlers/txt.py +33 -0
  159. edsl/study/ObjectEntry.py +1 -1
  160. edsl/study/SnapShot.py +1 -1
  161. edsl/study/Study.py +5 -12
  162. edsl/surveys/ConstructDAG.py +92 -0
  163. edsl/surveys/EditSurvey.py +221 -0
  164. edsl/surveys/InstructionHandler.py +100 -0
  165. edsl/surveys/MemoryManagement.py +72 -0
  166. edsl/surveys/Rule.py +5 -4
  167. edsl/surveys/RuleCollection.py +25 -27
  168. edsl/surveys/RuleManager.py +172 -0
  169. edsl/surveys/Simulator.py +75 -0
  170. edsl/surveys/Survey.py +199 -771
  171. edsl/surveys/SurveyCSS.py +20 -8
  172. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  173. edsl/surveys/SurveyToApp.py +141 -0
  174. edsl/surveys/__init__.py +4 -2
  175. edsl/surveys/descriptors.py +6 -2
  176. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  177. edsl/surveys/instructions/Instruction.py +4 -13
  178. edsl/surveys/instructions/InstructionCollection.py +11 -6
  179. edsl/templates/error_reporting/interview_details.html +1 -1
  180. edsl/templates/error_reporting/report.html +1 -1
  181. edsl/tools/plotting.py +1 -1
  182. edsl/utilities/PrettyList.py +56 -0
  183. edsl/utilities/is_notebook.py +18 -0
  184. edsl/utilities/is_valid_variable_name.py +11 -0
  185. edsl/utilities/remove_edsl_version.py +24 -0
  186. edsl/utilities/utilities.py +35 -23
  187. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
  188. edsl-0.1.39.dev2.dist-info/RECORD +352 -0
  189. edsl/language_models/KeyLookup.py +0 -30
  190. edsl/language_models/unused/ReplicateBase.py +0 -83
  191. edsl/results/ResultsDBMixin.py +0 -238
  192. edsl-0.1.39.dev1.dist-info/RECORD +0 -277
  193. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
  194. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
edsl/jobs/interviews/Interview.py

@@ -2,58 +2,33 @@

  from __future__ import annotations
  import asyncio
- from typing import Any, Type, List, Generator, Optional, Union
+ from typing import Any, Type, List, Generator, Optional, Union, TYPE_CHECKING
  import copy

- from tenacity import (
-     retry,
-     stop_after_attempt,
-     wait_exponential,
-     retry_if_exception_type,
-     RetryError,
- )
-
- from edsl import CONFIG
- from edsl.surveys.base import EndOfSurvey
- from edsl.exceptions import QuestionAnswerValidationError
- from edsl.exceptions import QuestionAnswerValidationError
- from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
+ # from edsl.config import CONFIG

- from edsl.jobs.buckets.ModelBuckets import ModelBuckets
  from edsl.jobs.Answers import Answers
- from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
- from edsl.jobs.tasks.TaskCreators import TaskCreators
  from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
+ from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
  from edsl.jobs.interviews.InterviewExceptionCollection import (
      InterviewExceptionCollection,
  )
-
- # from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
-
- from edsl.surveys.base import EndOfSurvey
- from edsl.jobs.buckets.ModelBuckets import ModelBuckets
  from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
- from edsl.jobs.tasks.task_status_enum import TaskStatus
- from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
-
-
- from edsl import Agent, Survey, Scenario, Cache
- from edsl.language_models import LanguageModel
- from edsl.questions import QuestionBase
- from edsl.agents.InvigilatorBase import InvigilatorBase
-
- from edsl.exceptions.language_models import LanguageModelNoResponseError
-
- from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
- from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
- from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
-
-
- from edsl import CONFIG
+ from edsl.jobs.buckets.ModelBuckets import ModelBuckets
+ from edsl.jobs.AnswerQuestionFunctionConstructor import (
+     AnswerQuestionFunctionConstructor,
+ )
+ from edsl.jobs.InterviewTaskManager import InterviewTaskManager
+ from edsl.jobs.FetchInvigilator import FetchInvigilator
+ from edsl.jobs.RequestTokenEstimator import RequestTokenEstimator

- EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
- EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
- EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
+ if TYPE_CHECKING:
+     from edsl.agents.Agent import Agent
+     from edsl.surveys.Survey import Survey
+     from edsl.scenarios.Scenario import Scenario
+     from edsl.data.Cache import Cache
+     from edsl.language_models.LanguageModel import LanguageModel
+     from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage


  class Interview:
@@ -70,12 +45,13 @@ class Interview:
          survey: Survey,
          scenario: Scenario,
          model: Type["LanguageModel"],
-         debug: Optional[bool] = False,
+         debug: Optional[bool] = False,  # DEPRECATE
          iteration: int = 0,
          cache: Optional["Cache"] = None,
-         sidecar_model: Optional["LanguageModel"] = None,
-         skip_retry: bool = False,
+         sidecar_model: Optional["LanguageModel"] = None,  # DEPRECATE
+         skip_retry: bool = False,  # COULD BE SET WITH CONFIG
          raise_validation_errors: bool = True,
+         indices: dict = None,  # explain?
      ):
          """Initialize the Interview instance.

@@ -89,7 +65,7 @@ class Interview:
          :param sidecar_model: a sidecar model used to answer questions.

          >>> i = Interview.example()
-         >>> i.task_creators
+         >>> i.task_manager.task_creators
          {}

          >>> i.exceptions
@@ -104,22 +80,23 @@ class Interview:

          """
          self.agent = agent
-         self.survey = copy.deepcopy(survey)
+         self.survey = copy.deepcopy(survey)  # why do we need to deepcopy the survey?
          self.scenario = scenario
          self.model = model
          self.debug = debug
          self.iteration = iteration
          self.cache = cache
-         self.answers: dict[
-             str, str
-         ] = Answers()  # will get filled in as interview progresses
+
+         self.answers = Answers()  # will get filled in as interview progresses
          self.sidecar_model = sidecar_model

-         # Trackers
-         self.task_creators = TaskCreators()  # tracks the task creators
+         self.task_manager = InterviewTaskManager(
+             survey=self.survey,
+             iteration=iteration,
+         )
+
          self.exceptions = InterviewExceptionCollection()

-         self._task_status_log_dict = InterviewStatusLog()
          self.skip_retry = skip_retry
          self.raise_validation_errors = raise_validation_errors

@@ -131,6 +108,8 @@ class Interview:

          self.failed_questions = []

+         self.indices = indices
+
      @property
      def has_exceptions(self) -> bool:
          """Return True if there are exceptions."""
@@ -142,21 +121,18 @@ class Interview:

          The keys are the question names; the values are the lists of status log changes for each task.
          """
-         for task_creator in self.task_creators.values():
-             self._task_status_log_dict[
-                 task_creator.question.question_name
-             ] = task_creator.status_log
-         return self._task_status_log_dict
+         return self.task_manager.task_status_logs

      @property
      def token_usage(self) -> InterviewTokenUsage:
          """Determine how many tokens were used for the interview."""
-         return self.task_creators.token_usage
+         return self.task_manager.token_usage  # task_creators.token_usage

      @property
      def interview_status(self) -> InterviewStatusDictionary:
          """Return a dictionary mapping task status codes to counts."""
-         return self.task_creators.interview_status
+         # return self.task_creators.interview_status
+         return self.task_manager.interview_status

      # region: Serialization
      def to_dict(self, include_exceptions=True, add_edsl_version=True) -> dict[str, Any]:
@@ -165,7 +141,7 @@ class Interview:

          >>> i = Interview.example()
          >>> hash(i)
-         1217840301076717434
+         193593189022259693
          """
          d = {
              "agent": self.agent.to_dict(add_edsl_version=add_edsl_version),
@@ -177,23 +153,34 @@ class Interview:
          }
          if include_exceptions:
              d["exceptions"] = self.exceptions.to_dict()
+         if hasattr(self, "indices"):
+             d["indices"] = self.indices
          return d

      @classmethod
      def from_dict(cls, d: dict[str, Any]) -> "Interview":
          """Return an Interview instance from a dictionary."""
+
+         from edsl.agents.Agent import Agent
+         from edsl.surveys.Survey import Survey
+         from edsl.scenarios.Scenario import Scenario
+         from edsl.language_models.LanguageModel import LanguageModel
+
          agent = Agent.from_dict(d["agent"])
          survey = Survey.from_dict(d["survey"])
          scenario = Scenario.from_dict(d["scenario"])
          model = LanguageModel.from_dict(d["model"])
          iteration = d["iteration"]
-         interview = cls(
-             agent=agent,
-             survey=survey,
-             scenario=scenario,
-             model=model,
-             iteration=iteration,
-         )
+         params = {
+             "agent": agent,
+             "survey": survey,
+             "scenario": scenario,
+             "model": model,
+             "iteration": iteration,
+         }
+         if "indices" in d:
+             params["indices"] = d["indices"]
+         interview = cls(**params)
          if "exceptions" in d:
              exceptions = InterviewExceptionCollection.from_dict(d["exceptions"])
              interview.exceptions = exceptions
@@ -211,304 +198,6 @@ class Interview:
          """
          return hash(self) == hash(other)

-     # endregion
-
-     # region: Creating tasks
-     @property
-     def dag(self) -> "DAG":
-         """Return the directed acyclic graph for the survey.
-
-         The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
-         It is used to determine the order in which questions should be answered.
-         This reflects both agent 'memory' considerations and 'skip' logic.
-         The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
-
-         >>> i = Interview.example()
-         >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
-         True
-         """
-         return self.survey.dag(textify=True)
-
-     def _build_question_tasks(
-         self,
-         model_buckets: ModelBuckets,
-     ) -> list[asyncio.Task]:
-         """Create a task for each question, with dependencies on the questions that must be answered before this one can be answered.
-
-         :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
-         :param model_buckets: the model buckets used to track and control usage rates.
-         """
-         tasks = []
-         for question in self.survey.questions:
-             tasks_that_must_be_completed_before = list(
-                 self._get_tasks_that_must_be_completed_before(
-                     tasks=tasks, question=question
-                 )
-             )
-             question_task = self._create_question_task(
-                 question=question,
-                 tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
-                 model_buckets=model_buckets,
-                 iteration=self.iteration,
-             )
-             tasks.append(question_task)
-         return tuple(tasks)
-
-     def _get_tasks_that_must_be_completed_before(
-         self, *, tasks: list[asyncio.Task], question: "QuestionBase"
-     ) -> Generator[asyncio.Task, None, None]:
-         """Return the tasks that must be completed before the given question can be answered.
-
-         :param tasks: a list of tasks that have been created so far.
-         :param question: the question for which we are determining dependencies.
-
-         If a question has no dependencies, this will be an empty list, [].
-         """
-         parents_of_focal_question = self.dag.get(question.question_name, [])
-         for parent_question_name in parents_of_focal_question:
-             yield tasks[self.to_index[parent_question_name]]
-
-     def _create_question_task(
-         self,
-         *,
-         question: QuestionBase,
-         tasks_that_must_be_completed_before: list[asyncio.Task],
-         model_buckets: ModelBuckets,
-         iteration: int = 0,
-     ) -> asyncio.Task:
-         """Create a task that depends on the passed-in dependencies that are awaited before the task is run.
-
-         :param question: the question to be answered. This is the question we are creating a task for.
-         :param tasks_that_must_be_completed_before: the tasks that must be completed before the focal task is run.
-         :param model_buckets: the model buckets used to track and control usage rates.
-         :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
-         :param iteration: the iteration number for the interview.
-
-         The task is created by a `QuestionTaskCreator`, which is responsible for creating the task and managing its dependencies.
-         It is passed a reference to the function that will be called to answer the question.
-         It is passed a list "tasks_that_must_be_completed_before" that are awaited before the task is run.
-         These are added as a dependency to the focal task.
-         """
-         task_creator = QuestionTaskCreator(
-             question=question,
-             answer_question_func=self._answer_question_and_record_task,
-             token_estimator=self._get_estimated_request_tokens,
-             model_buckets=model_buckets,
-             iteration=iteration,
-         )
-         for task in tasks_that_must_be_completed_before:
-             task_creator.add_dependency(task)
-
-         self.task_creators.update(
-             {question.question_name: task_creator}
-         )  # track this task creator
-         return task_creator.generate_task()
-
-     def _get_estimated_request_tokens(self, question) -> float:
-         """Estimate the number of tokens that will be required to run the focal task."""
-         from edsl.scenarios.FileStore import FileStore
-
-         invigilator = self._get_invigilator(question=question)
-         # TODO: There should be a way to get a more accurate estimate.
-         combined_text = ""
-         file_tokens = 0
-         for prompt in invigilator.get_prompts().values():
-             if hasattr(prompt, "text"):
-                 combined_text += prompt.text
-             elif isinstance(prompt, str):
-                 combined_text += prompt
-             elif isinstance(prompt, list):
-                 for file in prompt:
-                     if isinstance(file, FileStore):
-                         file_tokens += file.size * 0.25
-             else:
-                 raise ValueError(f"Prompt is of type {type(prompt)}")
-         return len(combined_text) / 4.0 + file_tokens
-
-     async def _answer_question_and_record_task(
-         self,
-         *,
-         question: "QuestionBase",
-         task=None,
-     ) -> "AgentResponseDict":
-         """Answer a question and records the task."""
-
-         had_language_model_no_response_error = False
-
-         @retry(
-             stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
-             wait=wait_exponential(
-                 multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
-             ),
-             retry=retry_if_exception_type(LanguageModelNoResponseError),
-             reraise=True,
-         )
-         async def attempt_answer():
-             nonlocal had_language_model_no_response_error
-
-             invigilator = self._get_invigilator(question)
-
-             if self._skip_this_question(question):
-                 return invigilator.get_failed_task_result(
-                     failure_reason="Question skipped."
-                 )
-
-             try:
-                 response: EDSLResultObjectInput = (
-                     await invigilator.async_answer_question()
-                 )
-                 if response.validated:
-                     self.answers.add_answer(response=response, question=question)
-                     self._cancel_skipped_questions(question)
-                 else:
-                     # When a question is not validated, it is not added to the answers.
-                     # this should also cancel and dependent children questions.
-                     # Is that happening now?
-                     if (
-                         hasattr(response, "exception_occurred")
-                         and response.exception_occurred
-                     ):
-                         raise response.exception_occurred
-
-             except QuestionAnswerValidationError as e:
-                 self._handle_exception(e, invigilator, task)
-                 return invigilator.get_failed_task_result(
-                     failure_reason="Question answer validation failed."
-                 )
-
-             except asyncio.TimeoutError as e:
-                 self._handle_exception(e, invigilator, task)
-                 had_language_model_no_response_error = True
-                 raise LanguageModelNoResponseError(
-                     f"Language model timed out for question '{question.question_name}.'"
-                 )
-
-             except Exception as e:
-                 self._handle_exception(e, invigilator, task)
-
-             if "response" not in locals():
-                 had_language_model_no_response_error = True
-                 raise LanguageModelNoResponseError(
-                     f"Language model did not return a response for question '{question.question_name}.'"
-                 )
-
-             # if it gets here, it means the no response error was fixed
-             if (
-                 question.question_name in self.exceptions
-                 and had_language_model_no_response_error
-             ):
-                 self.exceptions.record_fixed_question(question.question_name)
-
-             return response
-
-         try:
-             return await attempt_answer()
-         except RetryError as retry_error:
-             # All retries have failed for LanguageModelNoResponseError
-             original_error = retry_error.last_attempt.exception()
-             self._handle_exception(
-                 original_error, self._get_invigilator(question), task
-             )
-             raise original_error  # Re-raise the original error after handling
-
-     def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
-         """Return an invigilator for the given question.
-
-         :param question: the question to be answered
-         :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
-         """
-         invigilator = self.agent.create_invigilator(
-             question=question,
-             scenario=self.scenario,
-             model=self.model,
-             debug=False,
-             survey=self.survey,
-             memory_plan=self.survey.memory_plan,
-             current_answers=self.answers,
-             iteration=self.iteration,
-             cache=self.cache,
-             sidecar_model=self.sidecar_model,
-             raise_validation_errors=self.raise_validation_errors,
-         )
-         """Return an invigilator for the given question."""
-         return invigilator
-
-     def _skip_this_question(self, current_question: "QuestionBase") -> bool:
-         """Determine if the current question should be skipped.
-
-         :param current_question: the question to be answered.
-         """
-         current_question_index = self.to_index[current_question.question_name]
-
-         answers = self.answers | self.scenario | self.agent["traits"]
-         skip = self.survey.rule_collection.skip_question_before_running(
-             current_question_index, answers
-         )
-         return skip
-
-     def _handle_exception(
-         self, e: Exception, invigilator: "InvigilatorBase", task=None
-     ):
-         import copy
-
-         # breakpoint()
-
-         answers = copy.copy(self.answers)
-         exception_entry = InterviewExceptionEntry(
-             exception=e,
-             invigilator=invigilator,
-             answers=answers,
-         )
-         if task:
-             task.task_status = TaskStatus.FAILED
-         self.exceptions.add(invigilator.question.question_name, exception_entry)
-
-         if self.raise_validation_errors:
-             if isinstance(e, QuestionAnswerValidationError):
-                 raise e
-
-         if hasattr(self, "stop_on_exception"):
-             stop_on_exception = self.stop_on_exception
-         else:
-             stop_on_exception = False
-
-         if stop_on_exception:
-             raise e
-
-     def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
-         """Cancel the tasks for questions that are skipped.
-
-         :param current_question: the question that was just answered.
-
-         It first determines the next question, given the current question and the current answers.
-         If the next question is the end of the survey, it cancels all remaining tasks.
-         If the next question is after the current question, it cancels all tasks between the current question and the next question.
-         """
-         current_question_index: int = self.to_index[current_question.question_name]
-
-         next_question: Union[
-             int, EndOfSurvey
-         ] = self.survey.rule_collection.next_question(
-             q_now=current_question_index,
-             answers=self.answers | self.scenario | self.agent["traits"],
-         )
-
-         next_question_index = next_question.next_q
-
-         def cancel_between(start, end):
-             """Cancel the tasks between the start and end indices."""
-             for i in range(start, end):
-                 self.tasks[i].cancel()
-
-         if next_question_index == EndOfSurvey:
-             cancel_between(current_question_index + 1, len(self.survey.questions))
-             return
-
-         if next_question_index > (current_question_index + 1):
-             cancel_between(current_question_index + 1, next_question_index)
-
-     # endregion
-
      # region: Conducting the interview
      async def async_conduct_interview(
          self,
@@ -550,25 +239,35 @@ class Interview:
          if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
              model_buckets = ModelBuckets.infinity_bucket()

+         # was "self.tasks" - is that necessary?
+         self.tasks = self.task_manager.build_question_tasks(
+             answer_func=AnswerQuestionFunctionConstructor(self)(),
+             token_estimator=RequestTokenEstimator(self),
+             model_buckets=model_buckets,
+         )
+
          ## This is the key part---it creates a task for each question,
          ## with dependencies on the questions that must be answered before this one can be answered.
-         self.tasks = self._build_question_tasks(model_buckets=model_buckets)

-         ## 'Invigilators' are used to administer the survey
+         ## 'Invigilators' are used to administer the survey.
          self.invigilators = [
-             self._get_invigilator(question) for question in self.survey.questions
+             FetchInvigilator(interview=self, current_answers=self.answers)(question)
+             for question in self.survey.questions
          ]
-         await asyncio.gather(
-             *self.tasks, return_exceptions=not stop_on_exception
-         )  # not stop_on_exception)
+         await asyncio.gather(*self.tasks, return_exceptions=not stop_on_exception)
          self.answers.replace_missing_answers_with_none(self.survey)
-         valid_results = list(self._extract_valid_results())
+         valid_results = list(
+             self._extract_valid_results(self.tasks, self.invigilators, self.exceptions)
+         )
          return self.answers, valid_results

      # endregion

      # region: Extracting results and recording errors
-     def _extract_valid_results(self) -> Generator["Answers", None, None]:
+     @staticmethod
+     def _extract_valid_results(
+         tasks, invigilators: List["InvigilatorABC"], exceptions
+     ) -> Generator["Answers", None, None]:
          """Extract the valid results from the list of results.

          It iterates through the tasks and invigilators, and yields the results of the tasks that are done.
@@ -577,13 +276,10 @@ class Interview:

          >>> i = Interview.example()
          >>> result, _ = asyncio.run(i.async_conduct_interview())
-         >>> results = list(i._extract_valid_results())
-         >>> len(results) == len(i.survey)
-         True
          """
-         assert len(self.tasks) == len(self.invigilators)
+         assert len(tasks) == len(invigilators)

-         for task, invigilator in zip(self.tasks, self.invigilators):
+         for task, invigilator in zip(tasks, invigilators):
              if not task.done():
                  raise ValueError(f"Task {task.get_name()} is not done.")

@@ -601,7 +297,7 @@ class Interview:
                      exception=e,
                      invigilator=invigilator,
                  )
-                 self.exceptions.add(task.get_name(), exception_entry)
+                 exceptions.add(task.get_name(), exception_entry)

              yield result

@@ -629,6 +325,7 @@ class Interview:
              iteration=iteration,
              cache=cache,
              skip_retry=self.skip_retry,
+             indices=self.indices,
          )

      @classmethod
edsl/jobs/jobs_status_enums.py

@@ -0,0 +1,9 @@
+ from enum import Enum
+
+
+ class JobsStatus(Enum):
+     QUEUED = "queued"
+     RUNNING = "running"
+     COMPLETED = "completed"
+     FAILED = "failed"
+     CANCELLED = "cancelled"
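For context on the last hunk: JobsStatus is a plain Enum whose members wrap the status strings, so consuming code would typically branch on the members or round-trip their .value strings. The snippet below is an illustrative sketch only; the is_terminal helper and the usage lines are assumptions, not code from the package diff.

    from enum import Enum


    class JobsStatus(Enum):
        QUEUED = "queued"
        RUNNING = "running"
        COMPLETED = "completed"
        FAILED = "failed"
        CANCELLED = "cancelled"


    def is_terminal(status: JobsStatus) -> bool:
        # Hypothetical helper: a job is finished once it reaches any of these states.
        return status in {JobsStatus.COMPLETED, JobsStatus.FAILED, JobsStatus.CANCELLED}


    print(JobsStatus("running"))           # JobsStatus.RUNNING, reconstructed from its stored string
    print(is_terminal(JobsStatus.QUEUED))  # False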