edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. edsl/Base.py +116 -197
  2. edsl/__init__.py +7 -15
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +147 -351
  5. edsl/agents/AgentList.py +73 -211
  6. edsl/agents/Invigilator.py +50 -101
  7. edsl/agents/InvigilatorBase.py +70 -62
  8. edsl/agents/PromptConstructor.py +225 -143
  9. edsl/agents/__init__.py +1 -0
  10. edsl/agents/prompt_helpers.py +3 -3
  11. edsl/auto/AutoStudy.py +5 -18
  12. edsl/auto/StageBase.py +40 -53
  13. edsl/auto/StageQuestions.py +1 -2
  14. edsl/auto/utilities.py +6 -0
  15. edsl/config.py +2 -22
  16. edsl/conversation/car_buying.py +1 -2
  17. edsl/coop/PriceFetcher.py +1 -1
  18. edsl/coop/coop.py +47 -125
  19. edsl/coop/utils.py +14 -14
  20. edsl/data/Cache.py +27 -45
  21. edsl/data/CacheEntry.py +15 -12
  22. edsl/data/CacheHandler.py +12 -31
  23. edsl/data/RemoteCacheSync.py +46 -154
  24. edsl/data/__init__.py +3 -4
  25. edsl/data_transfer_models.py +1 -2
  26. edsl/enums.py +0 -27
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +0 -12
  29. edsl/exceptions/questions.py +6 -24
  30. edsl/exceptions/scenarios.py +0 -7
  31. edsl/inference_services/AnthropicService.py +19 -38
  32. edsl/inference_services/AwsBedrock.py +2 -0
  33. edsl/inference_services/AzureAI.py +2 -0
  34. edsl/inference_services/GoogleService.py +12 -7
  35. edsl/inference_services/InferenceServiceABC.py +85 -18
  36. edsl/inference_services/InferenceServicesCollection.py +79 -120
  37. edsl/inference_services/MistralAIService.py +3 -0
  38. edsl/inference_services/OpenAIService.py +35 -47
  39. edsl/inference_services/PerplexityService.py +3 -0
  40. edsl/inference_services/TestService.py +10 -11
  41. edsl/inference_services/TogetherAIService.py +3 -5
  42. edsl/jobs/Answers.py +14 -1
  43. edsl/jobs/Jobs.py +431 -356
  44. edsl/jobs/JobsChecks.py +10 -35
  45. edsl/jobs/JobsPrompts.py +4 -6
  46. edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
  47. edsl/jobs/buckets/BucketCollection.py +3 -44
  48. edsl/jobs/buckets/TokenBucket.py +21 -53
  49. edsl/jobs/interviews/Interview.py +408 -143
  50. edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
  51. edsl/jobs/runners/JobsRunnerStatus.py +165 -133
  52. edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
  53. edsl/jobs/tasks/TaskHistory.py +18 -38
  54. edsl/jobs/tasks/task_status_enum.py +2 -0
  55. edsl/language_models/KeyLookup.py +30 -0
  56. edsl/language_models/LanguageModel.py +236 -194
  57. edsl/language_models/ModelList.py +19 -28
  58. edsl/language_models/__init__.py +2 -1
  59. edsl/language_models/registry.py +190 -0
  60. edsl/language_models/repair.py +2 -2
  61. edsl/language_models/unused/ReplicateBase.py +83 -0
  62. edsl/language_models/utilities.py +4 -5
  63. edsl/notebooks/Notebook.py +14 -19
  64. edsl/prompts/Prompt.py +39 -29
  65. edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
  66. edsl/questions/QuestionBase.py +214 -68
  67. edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
  68. edsl/questions/QuestionBasePromptsMixin.py +3 -7
  69. edsl/questions/QuestionBudget.py +1 -1
  70. edsl/questions/QuestionCheckBox.py +3 -3
  71. edsl/questions/QuestionExtract.py +7 -5
  72. edsl/questions/QuestionFreeText.py +3 -2
  73. edsl/questions/QuestionList.py +18 -10
  74. edsl/questions/QuestionMultipleChoice.py +23 -67
  75. edsl/questions/QuestionNumerical.py +4 -2
  76. edsl/questions/QuestionRank.py +17 -7
  77. edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
  78. edsl/questions/SimpleAskMixin.py +3 -4
  79. edsl/questions/__init__.py +1 -2
  80. edsl/questions/derived/QuestionLinearScale.py +3 -6
  81. edsl/questions/derived/QuestionTopK.py +1 -1
  82. edsl/questions/descriptors.py +3 -17
  83. edsl/questions/question_registry.py +1 -1
  84. edsl/results/CSSParameterizer.py +1 -1
  85. edsl/results/Dataset.py +7 -170
  86. edsl/results/DatasetExportMixin.py +305 -168
  87. edsl/results/DatasetTree.py +8 -28
  88. edsl/results/Result.py +206 -298
  89. edsl/results/Results.py +131 -149
  90. edsl/results/ResultsDBMixin.py +238 -0
  91. edsl/results/ResultsExportMixin.py +0 -2
  92. edsl/results/{results_selector.py → Selector.py} +13 -23
  93. edsl/results/TableDisplay.py +171 -98
  94. edsl/results/__init__.py +1 -1
  95. edsl/scenarios/FileStore.py +239 -150
  96. edsl/scenarios/Scenario.py +193 -90
  97. edsl/scenarios/ScenarioHtmlMixin.py +3 -4
  98. edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
  99. edsl/scenarios/ScenarioList.py +244 -415
  100. edsl/scenarios/ScenarioListExportMixin.py +7 -0
  101. edsl/scenarios/ScenarioListPdfMixin.py +37 -15
  102. edsl/scenarios/__init__.py +2 -1
  103. edsl/study/ObjectEntry.py +1 -1
  104. edsl/study/SnapShot.py +1 -1
  105. edsl/study/Study.py +12 -5
  106. edsl/surveys/Rule.py +4 -5
  107. edsl/surveys/RuleCollection.py +27 -25
  108. edsl/surveys/Survey.py +791 -270
  109. edsl/surveys/SurveyCSS.py +8 -20
  110. edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
  111. edsl/surveys/__init__.py +2 -4
  112. edsl/surveys/descriptors.py +2 -6
  113. edsl/surveys/instructions/ChangeInstruction.py +2 -1
  114. edsl/surveys/instructions/Instruction.py +13 -4
  115. edsl/surveys/instructions/InstructionCollection.py +6 -11
  116. edsl/templates/error_reporting/interview_details.html +1 -1
  117. edsl/templates/error_reporting/report.html +1 -1
  118. edsl/tools/plotting.py +1 -1
  119. edsl/utilities/utilities.py +23 -35
  120. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
  121. edsl-0.1.39.dev1.dist-info/RECORD +277 -0
  122. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
  123. edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
  124. edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
  125. edsl/agents/question_option_processor.py +0 -172
  126. edsl/coop/CoopFunctionsMixin.py +0 -15
  127. edsl/coop/ExpectedParrotKeyHandler.py +0 -125
  128. edsl/exceptions/inference_services.py +0 -5
  129. edsl/inference_services/AvailableModelCacheHandler.py +0 -184
  130. edsl/inference_services/AvailableModelFetcher.py +0 -215
  131. edsl/inference_services/ServiceAvailability.py +0 -135
  132. edsl/inference_services/data_structures.py +0 -134
  133. edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
  134. edsl/jobs/FetchInvigilator.py +0 -47
  135. edsl/jobs/InterviewTaskManager.py +0 -98
  136. edsl/jobs/InterviewsConstructor.py +0 -50
  137. edsl/jobs/JobsComponentConstructor.py +0 -189
  138. edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
  139. edsl/jobs/RequestTokenEstimator.py +0 -30
  140. edsl/jobs/async_interview_runner.py +0 -138
  141. edsl/jobs/buckets/TokenBucketAPI.py +0 -211
  142. edsl/jobs/buckets/TokenBucketClient.py +0 -191
  143. edsl/jobs/check_survey_scenario_compatibility.py +0 -85
  144. edsl/jobs/data_structures.py +0 -120
  145. edsl/jobs/decorators.py +0 -35
  146. edsl/jobs/jobs_status_enums.py +0 -9
  147. edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
  148. edsl/jobs/results_exceptions_handler.py +0 -98
  149. edsl/language_models/ComputeCost.py +0 -63
  150. edsl/language_models/PriceManager.py +0 -127
  151. edsl/language_models/RawResponseHandler.py +0 -106
  152. edsl/language_models/ServiceDataSources.py +0 -0
  153. edsl/language_models/key_management/KeyLookup.py +0 -63
  154. edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
  155. edsl/language_models/key_management/KeyLookupCollection.py +0 -38
  156. edsl/language_models/key_management/__init__.py +0 -0
  157. edsl/language_models/key_management/models.py +0 -131
  158. edsl/language_models/model.py +0 -256
  159. edsl/notebooks/NotebookToLaTeX.py +0 -142
  160. edsl/questions/ExceptionExplainer.py +0 -77
  161. edsl/questions/HTMLQuestion.py +0 -103
  162. edsl/questions/QuestionMatrix.py +0 -265
  163. edsl/questions/data_structures.py +0 -20
  164. edsl/questions/loop_processor.py +0 -149
  165. edsl/questions/response_validator_factory.py +0 -34
  166. edsl/questions/templates/matrix/__init__.py +0 -1
  167. edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
  168. edsl/questions/templates/matrix/question_presentation.jinja +0 -20
  169. edsl/results/MarkdownToDocx.py +0 -122
  170. edsl/results/MarkdownToPDF.py +0 -111
  171. edsl/results/TextEditor.py +0 -50
  172. edsl/results/file_exports.py +0 -252
  173. edsl/results/smart_objects.py +0 -96
  174. edsl/results/table_data_class.py +0 -12
  175. edsl/results/table_renderers.py +0 -118
  176. edsl/scenarios/ConstructDownloadLink.py +0 -109
  177. edsl/scenarios/DocumentChunker.py +0 -102
  178. edsl/scenarios/DocxScenario.py +0 -16
  179. edsl/scenarios/PdfExtractor.py +0 -40
  180. edsl/scenarios/directory_scanner.py +0 -96
  181. edsl/scenarios/file_methods.py +0 -85
  182. edsl/scenarios/handlers/__init__.py +0 -13
  183. edsl/scenarios/handlers/csv.py +0 -49
  184. edsl/scenarios/handlers/docx.py +0 -76
  185. edsl/scenarios/handlers/html.py +0 -37
  186. edsl/scenarios/handlers/json.py +0 -111
  187. edsl/scenarios/handlers/latex.py +0 -5
  188. edsl/scenarios/handlers/md.py +0 -51
  189. edsl/scenarios/handlers/pdf.py +0 -68
  190. edsl/scenarios/handlers/png.py +0 -39
  191. edsl/scenarios/handlers/pptx.py +0 -105
  192. edsl/scenarios/handlers/py.py +0 -294
  193. edsl/scenarios/handlers/sql.py +0 -313
  194. edsl/scenarios/handlers/sqlite.py +0 -149
  195. edsl/scenarios/handlers/txt.py +0 -33
  196. edsl/scenarios/scenario_selector.py +0 -156
  197. edsl/surveys/ConstructDAG.py +0 -92
  198. edsl/surveys/EditSurvey.py +0 -221
  199. edsl/surveys/InstructionHandler.py +0 -100
  200. edsl/surveys/MemoryManagement.py +0 -72
  201. edsl/surveys/RuleManager.py +0 -172
  202. edsl/surveys/Simulator.py +0 -75
  203. edsl/surveys/SurveyToApp.py +0 -141
  204. edsl/utilities/PrettyList.py +0 -56
  205. edsl/utilities/is_notebook.py +0 -18
  206. edsl/utilities/is_valid_variable_name.py +0 -11
  207. edsl/utilities/remove_edsl_version.py +0 -24
  208. edsl-0.1.39.dist-info/RECORD +0 -358
  209. /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
  210. /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
  211. /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
  212. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
@@ -2,44 +2,58 @@
2
2
 
3
3
  from __future__ import annotations
4
4
  import asyncio
5
- from typing import Any, Type, List, Generator, Optional, Union, TYPE_CHECKING
5
+ from typing import Any, Type, List, Generator, Optional, Union
6
6
  import copy
7
- from dataclasses import dataclass
8
7
 
9
- # from edsl.jobs.Answers import Answers
10
- from edsl.jobs.data_structures import Answers
8
+ from tenacity import (
9
+ retry,
10
+ stop_after_attempt,
11
+ wait_exponential,
12
+ retry_if_exception_type,
13
+ RetryError,
14
+ )
15
+
16
+ from edsl import CONFIG
17
+ from edsl.surveys.base import EndOfSurvey
18
+ from edsl.exceptions import QuestionAnswerValidationError
19
+ from edsl.exceptions import QuestionAnswerValidationError
20
+ from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
21
+
22
+ from edsl.jobs.buckets.ModelBuckets import ModelBuckets
23
+ from edsl.jobs.Answers import Answers
24
+ from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
25
+ from edsl.jobs.tasks.TaskCreators import TaskCreators
11
26
  from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
12
- from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
13
27
  from edsl.jobs.interviews.InterviewExceptionCollection import (
14
28
  InterviewExceptionCollection,
15
29
  )
16
- from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
30
+
31
+ # from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
32
+
33
+ from edsl.surveys.base import EndOfSurvey
17
34
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
18
- from edsl.jobs.AnswerQuestionFunctionConstructor import (
19
- AnswerQuestionFunctionConstructor,
20
- )
21
- from edsl.jobs.InterviewTaskManager import InterviewTaskManager
22
- from edsl.jobs.FetchInvigilator import FetchInvigilator
23
- from edsl.jobs.RequestTokenEstimator import RequestTokenEstimator
35
+ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
36
+ from edsl.jobs.tasks.task_status_enum import TaskStatus
37
+ from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
38
+
39
+
40
+ from edsl import Agent, Survey, Scenario, Cache
41
+ from edsl.language_models import LanguageModel
42
+ from edsl.questions import QuestionBase
43
+ from edsl.agents.InvigilatorBase import InvigilatorBase
44
+
45
+ from edsl.exceptions.language_models import LanguageModelNoResponseError
24
46
 
47
+ from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
48
+ from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
49
+ from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
25
50
 
26
- if TYPE_CHECKING:
27
- from edsl.agents.Agent import Agent
28
- from edsl.surveys.Survey import Survey
29
- from edsl.scenarios.Scenario import Scenario
30
- from edsl.data.Cache import Cache
31
- from edsl.language_models.LanguageModel import LanguageModel
32
- from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
33
- from edsl.agents.InvigilatorBase import InvigilatorBase
34
- from edsl.language_models.key_management.KeyLookup import KeyLookup
35
51
 
52
+ from edsl import CONFIG
36
53
 
37
- @dataclass
38
- class InterviewRunningConfig:
39
- cache: Optional["Cache"] = (None,)
40
- skip_retry: bool = (False,) # COULD BE SET WITH CONFIG
41
- raise_validation_errors: bool = (True,)
42
- stop_on_exception: bool = (False,)
54
+ EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
55
+ EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
56
+ EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
43
57
 
44
58
 
45
59
  class Interview:
@@ -56,10 +70,11 @@ class Interview:
56
70
  survey: Survey,
57
71
  scenario: Scenario,
58
72
  model: Type["LanguageModel"],
73
+ debug: Optional[bool] = False,
59
74
  iteration: int = 0,
60
- indices: dict = None, # explain?
61
75
  cache: Optional["Cache"] = None,
62
- skip_retry: bool = False, # COULD BE SET WITH CONFIG
76
+ sidecar_model: Optional["LanguageModel"] = None,
77
+ skip_retry: bool = False,
63
78
  raise_validation_errors: bool = True,
64
79
  ):
65
80
  """Initialize the Interview instance.
@@ -68,12 +83,13 @@ class Interview:
68
83
  :param survey: the survey being administered to the agent.
69
84
  :param scenario: the scenario that populates the survey questions.
70
85
  :param model: the language model used to answer the questions.
71
- # :param debug: if True, run without calls to the language model.
86
+ :param debug: if True, run without calls to the language model.
72
87
  :param iteration: the iteration number of the interview.
73
88
  :param cache: the cache used to store the answers.
89
+ :param sidecar_model: a sidecar model used to answer questions.
74
90
 
75
91
  >>> i = Interview.example()
76
- >>> i.task_manager.task_creators
92
+ >>> i.task_creators
77
93
  {}
78
94
 
79
95
  >>> i.exceptions
@@ -88,27 +104,22 @@ class Interview:
88
104
 
89
105
  """
90
106
  self.agent = agent
91
- self.survey = copy.deepcopy(survey) # why do we need to deepcopy the survey?
107
+ self.survey = copy.deepcopy(survey)
92
108
  self.scenario = scenario
93
109
  self.model = model
110
+ self.debug = debug
94
111
  self.iteration = iteration
112
+ self.cache = cache
113
+ self.answers: dict[
114
+ str, str
115
+ ] = Answers() # will get filled in as interview progresses
116
+ self.sidecar_model = sidecar_model
95
117
 
96
- self.answers = Answers() # will get filled in as interview progresses
97
-
98
- self.task_manager = InterviewTaskManager(
99
- survey=self.survey,
100
- iteration=iteration,
101
- )
102
-
118
+ # Trackers
119
+ self.task_creators = TaskCreators() # tracks the task creators
103
120
  self.exceptions = InterviewExceptionCollection()
104
121
 
105
- self.running_config = InterviewRunningConfig(
106
- cache=cache,
107
- skip_retry=skip_retry,
108
- raise_validation_errors=raise_validation_errors,
109
- )
110
-
111
- self.cache = cache
122
+ self._task_status_log_dict = InterviewStatusLog()
112
123
  self.skip_retry = skip_retry
113
124
  self.raise_validation_errors = raise_validation_errors
114
125
 
@@ -120,9 +131,6 @@ class Interview:
120
131
 
121
132
  self.failed_questions = []
122
133
 
123
- self.indices = indices
124
- self.initial_hash = hash(self)
125
-
126
134
  @property
127
135
  def has_exceptions(self) -> bool:
128
136
  """Return True if there are exceptions."""
@@ -134,26 +142,30 @@ class Interview:
134
142
 
135
143
  The keys are the question names; the values are the lists of status log changes for each task.
136
144
  """
137
- return self.task_manager.task_status_logs
145
+ for task_creator in self.task_creators.values():
146
+ self._task_status_log_dict[
147
+ task_creator.question.question_name
148
+ ] = task_creator.status_log
149
+ return self._task_status_log_dict
138
150
 
139
151
  @property
140
152
  def token_usage(self) -> InterviewTokenUsage:
141
153
  """Determine how many tokens were used for the interview."""
142
- return self.task_manager.token_usage # task_creators.token_usage
154
+ return self.task_creators.token_usage
143
155
 
144
156
  @property
145
157
  def interview_status(self) -> InterviewStatusDictionary:
146
158
  """Return a dictionary mapping task status codes to counts."""
147
- # return self.task_creators.interview_status
148
- return self.task_manager.interview_status
159
+ return self.task_creators.interview_status
149
160
 
161
+ # region: Serialization
150
162
  def to_dict(self, include_exceptions=True, add_edsl_version=True) -> dict[str, Any]:
151
163
  """Return a dictionary representation of the Interview instance.
152
164
  This is just for hashing purposes.
153
165
 
154
166
  >>> i = Interview.example()
155
167
  >>> hash(i)
156
- 193593189022259693
168
+ 1217840301076717434
157
169
  """
158
170
  d = {
159
171
  "agent": self.agent.to_dict(add_edsl_version=add_edsl_version),
@@ -165,34 +177,23 @@ class Interview:
165
177
  }
166
178
  if include_exceptions:
167
179
  d["exceptions"] = self.exceptions.to_dict()
168
- if hasattr(self, "indices"):
169
- d["indices"] = self.indices
170
180
  return d
171
181
 
172
182
  @classmethod
173
183
  def from_dict(cls, d: dict[str, Any]) -> "Interview":
174
184
  """Return an Interview instance from a dictionary."""
175
-
176
- from edsl.agents.Agent import Agent
177
- from edsl.surveys.Survey import Survey
178
- from edsl.scenarios.Scenario import Scenario
179
- from edsl.language_models.LanguageModel import LanguageModel
180
-
181
185
  agent = Agent.from_dict(d["agent"])
182
186
  survey = Survey.from_dict(d["survey"])
183
187
  scenario = Scenario.from_dict(d["scenario"])
184
188
  model = LanguageModel.from_dict(d["model"])
185
189
  iteration = d["iteration"]
186
- params = {
187
- "agent": agent,
188
- "survey": survey,
189
- "scenario": scenario,
190
- "model": model,
191
- "iteration": iteration,
192
- }
193
- if "indices" in d:
194
- params["indices"] = d["indices"]
195
- interview = cls(**params)
190
+ interview = cls(
191
+ agent=agent,
192
+ survey=survey,
193
+ scenario=scenario,
194
+ model=model,
195
+ iteration=iteration,
196
+ )
196
197
  if "exceptions" in d:
197
198
  exceptions = InterviewExceptionCollection.from_dict(d["exceptions"])
198
199
  interview.exceptions = exceptions
@@ -210,13 +211,311 @@ class Interview:
210
211
  """
211
212
  return hash(self) == hash(other)
212
213
 
214
+ # endregion
215
+
216
+ # region: Creating tasks
217
+ @property
218
+ def dag(self) -> "DAG":
219
+ """Return the directed acyclic graph for the survey.
220
+
221
+ The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
222
+ It is used to determine the order in which questions should be answered.
223
+ This reflects both agent 'memory' considerations and 'skip' logic.
224
+ The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
225
+
226
+ >>> i = Interview.example()
227
+ >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
228
+ True
229
+ """
230
+ return self.survey.dag(textify=True)
231
+
232
+ def _build_question_tasks(
233
+ self,
234
+ model_buckets: ModelBuckets,
235
+ ) -> list[asyncio.Task]:
236
+ """Create a task for each question, with dependencies on the questions that must be answered before this one can be answered.
237
+
238
+ :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
239
+ :param model_buckets: the model buckets used to track and control usage rates.
240
+ """
241
+ tasks = []
242
+ for question in self.survey.questions:
243
+ tasks_that_must_be_completed_before = list(
244
+ self._get_tasks_that_must_be_completed_before(
245
+ tasks=tasks, question=question
246
+ )
247
+ )
248
+ question_task = self._create_question_task(
249
+ question=question,
250
+ tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
251
+ model_buckets=model_buckets,
252
+ iteration=self.iteration,
253
+ )
254
+ tasks.append(question_task)
255
+ return tuple(tasks)
256
+
257
+ def _get_tasks_that_must_be_completed_before(
258
+ self, *, tasks: list[asyncio.Task], question: "QuestionBase"
259
+ ) -> Generator[asyncio.Task, None, None]:
260
+ """Return the tasks that must be completed before the given question can be answered.
261
+
262
+ :param tasks: a list of tasks that have been created so far.
263
+ :param question: the question for which we are determining dependencies.
264
+
265
+ If a question has no dependencies, this will be an empty list, [].
266
+ """
267
+ parents_of_focal_question = self.dag.get(question.question_name, [])
268
+ for parent_question_name in parents_of_focal_question:
269
+ yield tasks[self.to_index[parent_question_name]]
270
+
271
+ def _create_question_task(
272
+ self,
273
+ *,
274
+ question: QuestionBase,
275
+ tasks_that_must_be_completed_before: list[asyncio.Task],
276
+ model_buckets: ModelBuckets,
277
+ iteration: int = 0,
278
+ ) -> asyncio.Task:
279
+ """Create a task that depends on the passed-in dependencies that are awaited before the task is run.
280
+
281
+ :param question: the question to be answered. This is the question we are creating a task for.
282
+ :param tasks_that_must_be_completed_before: the tasks that must be completed before the focal task is run.
283
+ :param model_buckets: the model buckets used to track and control usage rates.
284
+ :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
285
+ :param iteration: the iteration number for the interview.
286
+
287
+ The task is created by a `QuestionTaskCreator`, which is responsible for creating the task and managing its dependencies.
288
+ It is passed a reference to the function that will be called to answer the question.
289
+ It is passed a list "tasks_that_must_be_completed_before" that are awaited before the task is run.
290
+ These are added as a dependency to the focal task.
291
+ """
292
+ task_creator = QuestionTaskCreator(
293
+ question=question,
294
+ answer_question_func=self._answer_question_and_record_task,
295
+ token_estimator=self._get_estimated_request_tokens,
296
+ model_buckets=model_buckets,
297
+ iteration=iteration,
298
+ )
299
+ for task in tasks_that_must_be_completed_before:
300
+ task_creator.add_dependency(task)
301
+
302
+ self.task_creators.update(
303
+ {question.question_name: task_creator}
304
+ ) # track this task creator
305
+ return task_creator.generate_task()
306
+
307
+ def _get_estimated_request_tokens(self, question) -> float:
308
+ """Estimate the number of tokens that will be required to run the focal task."""
309
+ from edsl.scenarios.FileStore import FileStore
310
+
311
+ invigilator = self._get_invigilator(question=question)
312
+ # TODO: There should be a way to get a more accurate estimate.
313
+ combined_text = ""
314
+ file_tokens = 0
315
+ for prompt in invigilator.get_prompts().values():
316
+ if hasattr(prompt, "text"):
317
+ combined_text += prompt.text
318
+ elif isinstance(prompt, str):
319
+ combined_text += prompt
320
+ elif isinstance(prompt, list):
321
+ for file in prompt:
322
+ if isinstance(file, FileStore):
323
+ file_tokens += file.size * 0.25
324
+ else:
325
+ raise ValueError(f"Prompt is of type {type(prompt)}")
326
+ return len(combined_text) / 4.0 + file_tokens
327
+
328
+ async def _answer_question_and_record_task(
329
+ self,
330
+ *,
331
+ question: "QuestionBase",
332
+ task=None,
333
+ ) -> "AgentResponseDict":
334
+ """Answer a question and records the task."""
335
+
336
+ had_language_model_no_response_error = False
337
+
338
+ @retry(
339
+ stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
340
+ wait=wait_exponential(
341
+ multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
342
+ ),
343
+ retry=retry_if_exception_type(LanguageModelNoResponseError),
344
+ reraise=True,
345
+ )
346
+ async def attempt_answer():
347
+ nonlocal had_language_model_no_response_error
348
+
349
+ invigilator = self._get_invigilator(question)
350
+
351
+ if self._skip_this_question(question):
352
+ return invigilator.get_failed_task_result(
353
+ failure_reason="Question skipped."
354
+ )
355
+
356
+ try:
357
+ response: EDSLResultObjectInput = (
358
+ await invigilator.async_answer_question()
359
+ )
360
+ if response.validated:
361
+ self.answers.add_answer(response=response, question=question)
362
+ self._cancel_skipped_questions(question)
363
+ else:
364
+ # When a question is not validated, it is not added to the answers.
365
+ # this should also cancel and dependent children questions.
366
+ # Is that happening now?
367
+ if (
368
+ hasattr(response, "exception_occurred")
369
+ and response.exception_occurred
370
+ ):
371
+ raise response.exception_occurred
372
+
373
+ except QuestionAnswerValidationError as e:
374
+ self._handle_exception(e, invigilator, task)
375
+ return invigilator.get_failed_task_result(
376
+ failure_reason="Question answer validation failed."
377
+ )
378
+
379
+ except asyncio.TimeoutError as e:
380
+ self._handle_exception(e, invigilator, task)
381
+ had_language_model_no_response_error = True
382
+ raise LanguageModelNoResponseError(
383
+ f"Language model timed out for question '{question.question_name}.'"
384
+ )
385
+
386
+ except Exception as e:
387
+ self._handle_exception(e, invigilator, task)
388
+
389
+ if "response" not in locals():
390
+ had_language_model_no_response_error = True
391
+ raise LanguageModelNoResponseError(
392
+ f"Language model did not return a response for question '{question.question_name}.'"
393
+ )
394
+
395
+ # if it gets here, it means the no response error was fixed
396
+ if (
397
+ question.question_name in self.exceptions
398
+ and had_language_model_no_response_error
399
+ ):
400
+ self.exceptions.record_fixed_question(question.question_name)
401
+
402
+ return response
403
+
404
+ try:
405
+ return await attempt_answer()
406
+ except RetryError as retry_error:
407
+ # All retries have failed for LanguageModelNoResponseError
408
+ original_error = retry_error.last_attempt.exception()
409
+ self._handle_exception(
410
+ original_error, self._get_invigilator(question), task
411
+ )
412
+ raise original_error # Re-raise the original error after handling
413
+
414
+ def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
415
+ """Return an invigilator for the given question.
416
+
417
+ :param question: the question to be answered
418
+ :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
419
+ """
420
+ invigilator = self.agent.create_invigilator(
421
+ question=question,
422
+ scenario=self.scenario,
423
+ model=self.model,
424
+ debug=False,
425
+ survey=self.survey,
426
+ memory_plan=self.survey.memory_plan,
427
+ current_answers=self.answers,
428
+ iteration=self.iteration,
429
+ cache=self.cache,
430
+ sidecar_model=self.sidecar_model,
431
+ raise_validation_errors=self.raise_validation_errors,
432
+ )
433
+ """Return an invigilator for the given question."""
434
+ return invigilator
435
+
436
+ def _skip_this_question(self, current_question: "QuestionBase") -> bool:
437
+ """Determine if the current question should be skipped.
438
+
439
+ :param current_question: the question to be answered.
440
+ """
441
+ current_question_index = self.to_index[current_question.question_name]
442
+
443
+ answers = self.answers | self.scenario | self.agent["traits"]
444
+ skip = self.survey.rule_collection.skip_question_before_running(
445
+ current_question_index, answers
446
+ )
447
+ return skip
448
+
449
+ def _handle_exception(
450
+ self, e: Exception, invigilator: "InvigilatorBase", task=None
451
+ ):
452
+ import copy
453
+
454
+ # breakpoint()
455
+
456
+ answers = copy.copy(self.answers)
457
+ exception_entry = InterviewExceptionEntry(
458
+ exception=e,
459
+ invigilator=invigilator,
460
+ answers=answers,
461
+ )
462
+ if task:
463
+ task.task_status = TaskStatus.FAILED
464
+ self.exceptions.add(invigilator.question.question_name, exception_entry)
465
+
466
+ if self.raise_validation_errors:
467
+ if isinstance(e, QuestionAnswerValidationError):
468
+ raise e
469
+
470
+ if hasattr(self, "stop_on_exception"):
471
+ stop_on_exception = self.stop_on_exception
472
+ else:
473
+ stop_on_exception = False
474
+
475
+ if stop_on_exception:
476
+ raise e
477
+
478
+ def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
479
+ """Cancel the tasks for questions that are skipped.
480
+
481
+ :param current_question: the question that was just answered.
482
+
483
+ It first determines the next question, given the current question and the current answers.
484
+ If the next question is the end of the survey, it cancels all remaining tasks.
485
+ If the next question is after the current question, it cancels all tasks between the current question and the next question.
486
+ """
487
+ current_question_index: int = self.to_index[current_question.question_name]
488
+
489
+ next_question: Union[
490
+ int, EndOfSurvey
491
+ ] = self.survey.rule_collection.next_question(
492
+ q_now=current_question_index,
493
+ answers=self.answers | self.scenario | self.agent["traits"],
494
+ )
495
+
496
+ next_question_index = next_question.next_q
497
+
498
+ def cancel_between(start, end):
499
+ """Cancel the tasks between the start and end indices."""
500
+ for i in range(start, end):
501
+ self.tasks[i].cancel()
502
+
503
+ if next_question_index == EndOfSurvey:
504
+ cancel_between(current_question_index + 1, len(self.survey.questions))
505
+ return
506
+
507
+ if next_question_index > (current_question_index + 1):
508
+ cancel_between(current_question_index + 1, next_question_index)
509
+
510
+ # endregion
511
+
512
+ # region: Conducting the interview
213
513
  async def async_conduct_interview(
214
514
  self,
215
- run_config: Optional["RunConfig"] = None,
216
- # model_buckets: Optional[ModelBuckets] = None,
217
- # stop_on_exception: bool = False,
218
- # raise_validation_errors: bool = True,
219
- # key_lookup: Optional[KeyLookup] = None,
515
+ model_buckets: Optional[ModelBuckets] = None,
516
+ stop_on_exception: bool = False,
517
+ sidecar_model: Optional["LanguageModel"] = None,
518
+ raise_validation_errors: bool = True,
220
519
  ) -> tuple["Answers", List[dict[str, Any]]]:
221
520
  """
222
521
  Conduct an Interview asynchronously.
@@ -225,6 +524,7 @@ class Interview:
225
524
  :param model_buckets: a dictionary of token buckets for the model.
226
525
  :param debug: run without calls to LLM.
227
526
  :param stop_on_exception: if True, stops the interview if an exception is raised.
527
+ :param sidecar_model: a sidecar model used to answer questions.
228
528
 
229
529
  Example usage:
230
530
 
@@ -238,68 +538,37 @@ class Interview:
238
538
  >>> i.exceptions
239
539
  {'q0': ...
240
540
  >>> i = Interview.example()
241
- >>> from edsl.jobs.Jobs import RunConfig, RunParameters, RunEnvironment
242
- >>> run_config = RunConfig(parameters = RunParameters(), environment = RunEnvironment())
243
- >>> run_config.parameters.stop_on_exception = True
244
- >>> result, _ = asyncio.run(i.async_conduct_interview(run_config))
541
+ >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
245
542
  Traceback (most recent call last):
246
543
  ...
247
544
  asyncio.exceptions.CancelledError
248
545
  """
249
- from edsl.jobs.Jobs import RunConfig, RunParameters, RunEnvironment
250
-
251
- if run_config is None:
252
- run_config = RunConfig(
253
- parameters=RunParameters(),
254
- environment=RunEnvironment(),
255
- )
256
- self.stop_on_exception = run_config.parameters.stop_on_exception
546
+ self.sidecar_model = sidecar_model
547
+ self.stop_on_exception = stop_on_exception
257
548
 
258
549
  # if no model bucket is passed, create an 'infinity' bucket with no rate limits
259
- bucket_collection = run_config.environment.bucket_collection
260
-
261
- if bucket_collection:
262
- model_buckets = bucket_collection.get(self.model)
263
- else:
264
- model_buckets = None
265
-
266
550
  if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
267
551
  model_buckets = ModelBuckets.infinity_bucket()
268
552
 
269
- # was "self.tasks" - is that necessary?
270
- self.tasks = self.task_manager.build_question_tasks(
271
- answer_func=AnswerQuestionFunctionConstructor(
272
- self, key_lookup=run_config.environment.key_lookup
273
- )(),
274
- token_estimator=RequestTokenEstimator(self),
275
- model_buckets=model_buckets,
276
- )
277
-
278
553
  ## This is the key part---it creates a task for each question,
279
554
  ## with dependencies on the questions that must be answered before this one can be answered.
555
+ self.tasks = self._build_question_tasks(model_buckets=model_buckets)
280
556
 
281
- ## 'Invigilators' are used to administer the survey.
282
- fetcher = FetchInvigilator(
283
- interview=self,
284
- current_answers=self.answers,
285
- key_lookup=run_config.environment.key_lookup,
286
- )
287
- self.invigilators = [fetcher(question) for question in self.survey.questions]
557
+ ## 'Invigilators' are used to administer the survey
558
+ self.invigilators = [
559
+ self._get_invigilator(question) for question in self.survey.questions
560
+ ]
288
561
  await asyncio.gather(
289
- *self.tasks, return_exceptions=not run_config.parameters.stop_on_exception
290
- )
562
+ *self.tasks, return_exceptions=not stop_on_exception
563
+ ) # not stop_on_exception)
291
564
  self.answers.replace_missing_answers_with_none(self.survey)
292
- valid_results = list(
293
- self._extract_valid_results(self.tasks, self.invigilators, self.exceptions)
294
- )
565
+ valid_results = list(self._extract_valid_results())
295
566
  return self.answers, valid_results
296
567
 
297
- @staticmethod
298
- def _extract_valid_results(
299
- tasks: List["asyncio.Task"],
300
- invigilators: List["InvigilatorBase"],
301
- exceptions: InterviewExceptionCollection,
302
- ) -> Generator["Answers", None, None]:
568
+ # endregion
569
+
570
+ # region: Extracting results and recording errors
571
+ def _extract_valid_results(self) -> Generator["Answers", None, None]:
303
572
  """Extract the valid results from the list of results.
304
573
 
305
574
  It iterates through the tasks and invigilators, and yields the results of the tasks that are done.
@@ -308,10 +577,16 @@ class Interview:
308
577
 
309
578
  >>> i = Interview.example()
310
579
  >>> result, _ = asyncio.run(i.async_conduct_interview())
580
+ >>> results = list(i._extract_valid_results())
581
+ >>> len(results) == len(i.survey)
582
+ True
311
583
  """
312
- assert len(tasks) == len(invigilators)
584
+ assert len(self.tasks) == len(self.invigilators)
585
+
586
+ for task, invigilator in zip(self.tasks, self.invigilators):
587
+ if not task.done():
588
+ raise ValueError(f"Task {task.get_name()} is not done.")
313
589
 
314
- def handle_task(task, invigilator):
315
590
  try:
316
591
  result = task.result()
317
592
  except asyncio.CancelledError as e: # task was cancelled
@@ -326,22 +601,18 @@ class Interview:
326
601
  exception=e,
327
602
  invigilator=invigilator,
328
603
  )
329
- exceptions.add(task.get_name(), exception_entry)
330
- return result
604
+ self.exceptions.add(task.get_name(), exception_entry)
331
605
 
332
- for task, invigilator in zip(tasks, invigilators):
333
- if not task.done():
334
- raise ValueError(f"Task {task.get_name()} is not done.")
606
+ yield result
335
607
 
336
- yield handle_task(task, invigilator)
608
+ # endregion
337
609
 
610
+ # region: Magic methods
338
611
  def __repr__(self) -> str:
339
612
  """Return a string representation of the Interview instance."""
340
613
  return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
341
614
 
342
- def duplicate(
343
- self, iteration: int, cache: "Cache", randomize_survey: Optional[bool] = True
344
- ) -> Interview:
615
+ def duplicate(self, iteration: int, cache: "Cache") -> Interview:
345
616
  """Duplicate the interview, but with a new iteration number and cache.
346
617
 
347
618
  >>> i = Interview.example()
@@ -350,20 +621,14 @@ class Interview:
350
621
  True
351
622
 
352
623
  """
353
- if randomize_survey:
354
- new_survey = self.survey.draw()
355
- else:
356
- new_survey = self.survey
357
-
358
624
  return Interview(
359
625
  agent=self.agent,
360
- survey=new_survey,
626
+ survey=self.survey,
361
627
  scenario=self.scenario,
362
628
  model=self.model,
363
629
  iteration=iteration,
364
- cache=self.running_config.cache,
365
- skip_retry=self.running_config.skip_retry,
366
- indices=self.indices,
630
+ cache=cache,
631
+ skip_retry=self.skip_retry,
367
632
  )
368
633
 
369
634
  @classmethod