edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212):
  1. edsl/Base.py +197 -116
  2. edsl/__init__.py +15 -7
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +351 -147
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +101 -50
  7. edsl/agents/InvigilatorBase.py +62 -70
  8. edsl/agents/PromptConstructor.py +143 -225
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  11. edsl/agents/__init__.py +0 -1
  12. edsl/agents/prompt_helpers.py +3 -3
  13. edsl/agents/question_option_processor.py +172 -0
  14. edsl/auto/AutoStudy.py +18 -5
  15. edsl/auto/StageBase.py +53 -40
  16. edsl/auto/StageQuestions.py +2 -1
  17. edsl/auto/utilities.py +0 -6
  18. edsl/config.py +22 -2
  19. edsl/conversation/car_buying.py +2 -1
  20. edsl/coop/CoopFunctionsMixin.py +15 -0
  21. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  22. edsl/coop/PriceFetcher.py +1 -1
  23. edsl/coop/coop.py +125 -47
  24. edsl/coop/utils.py +14 -14
  25. edsl/data/Cache.py +45 -27
  26. edsl/data/CacheEntry.py +12 -15
  27. edsl/data/CacheHandler.py +31 -12
  28. edsl/data/RemoteCacheSync.py +154 -46
  29. edsl/data/__init__.py +4 -3
  30. edsl/data_transfer_models.py +2 -1
  31. edsl/enums.py +27 -0
  32. edsl/exceptions/__init__.py +50 -50
  33. edsl/exceptions/agents.py +12 -0
  34. edsl/exceptions/inference_services.py +5 -0
  35. edsl/exceptions/questions.py +24 -6
  36. edsl/exceptions/scenarios.py +7 -0
  37. edsl/inference_services/AnthropicService.py +38 -19
  38. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  39. edsl/inference_services/AvailableModelFetcher.py +215 -0
  40. edsl/inference_services/AwsBedrock.py +0 -2
  41. edsl/inference_services/AzureAI.py +0 -2
  42. edsl/inference_services/GoogleService.py +7 -12
  43. edsl/inference_services/InferenceServiceABC.py +18 -85
  44. edsl/inference_services/InferenceServicesCollection.py +120 -79
  45. edsl/inference_services/MistralAIService.py +0 -3
  46. edsl/inference_services/OpenAIService.py +47 -35
  47. edsl/inference_services/PerplexityService.py +0 -3
  48. edsl/inference_services/ServiceAvailability.py +135 -0
  49. edsl/inference_services/TestService.py +11 -10
  50. edsl/inference_services/TogetherAIService.py +5 -3
  51. edsl/inference_services/data_structures.py +134 -0
  52. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  53. edsl/jobs/Answers.py +1 -14
  54. edsl/jobs/FetchInvigilator.py +47 -0
  55. edsl/jobs/InterviewTaskManager.py +98 -0
  56. edsl/jobs/InterviewsConstructor.py +50 -0
  57. edsl/jobs/Jobs.py +356 -431
  58. edsl/jobs/JobsChecks.py +35 -10
  59. edsl/jobs/JobsComponentConstructor.py +189 -0
  60. edsl/jobs/JobsPrompts.py +6 -4
  61. edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
  62. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  63. edsl/jobs/RequestTokenEstimator.py +30 -0
  64. edsl/jobs/async_interview_runner.py +138 -0
  65. edsl/jobs/buckets/BucketCollection.py +44 -3
  66. edsl/jobs/buckets/TokenBucket.py +53 -21
  67. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  68. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  69. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  70. edsl/jobs/data_structures.py +120 -0
  71. edsl/jobs/decorators.py +35 -0
  72. edsl/jobs/interviews/Interview.py +143 -408
  73. edsl/jobs/jobs_status_enums.py +9 -0
  74. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  75. edsl/jobs/results_exceptions_handler.py +98 -0
  76. edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
  77. edsl/jobs/runners/JobsRunnerStatus.py +133 -165
  78. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  79. edsl/jobs/tasks/TaskHistory.py +38 -18
  80. edsl/jobs/tasks/task_status_enum.py +0 -2
  81. edsl/language_models/ComputeCost.py +63 -0
  82. edsl/language_models/LanguageModel.py +194 -236
  83. edsl/language_models/ModelList.py +28 -19
  84. edsl/language_models/PriceManager.py +127 -0
  85. edsl/language_models/RawResponseHandler.py +106 -0
  86. edsl/language_models/ServiceDataSources.py +0 -0
  87. edsl/language_models/__init__.py +1 -2
  88. edsl/language_models/key_management/KeyLookup.py +63 -0
  89. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  90. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  91. edsl/language_models/key_management/__init__.py +0 -0
  92. edsl/language_models/key_management/models.py +131 -0
  93. edsl/language_models/model.py +256 -0
  94. edsl/language_models/repair.py +2 -2
  95. edsl/language_models/utilities.py +5 -4
  96. edsl/notebooks/Notebook.py +19 -14
  97. edsl/notebooks/NotebookToLaTeX.py +142 -0
  98. edsl/prompts/Prompt.py +29 -39
  99. edsl/questions/ExceptionExplainer.py +77 -0
  100. edsl/questions/HTMLQuestion.py +103 -0
  101. edsl/questions/QuestionBase.py +68 -214
  102. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  103. edsl/questions/QuestionBudget.py +1 -1
  104. edsl/questions/QuestionCheckBox.py +3 -3
  105. edsl/questions/QuestionExtract.py +5 -7
  106. edsl/questions/QuestionFreeText.py +2 -3
  107. edsl/questions/QuestionList.py +10 -18
  108. edsl/questions/QuestionMatrix.py +265 -0
  109. edsl/questions/QuestionMultipleChoice.py +67 -23
  110. edsl/questions/QuestionNumerical.py +2 -4
  111. edsl/questions/QuestionRank.py +7 -17
  112. edsl/questions/SimpleAskMixin.py +4 -3
  113. edsl/questions/__init__.py +2 -1
  114. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
  115. edsl/questions/data_structures.py +20 -0
  116. edsl/questions/derived/QuestionLinearScale.py +6 -3
  117. edsl/questions/derived/QuestionTopK.py +1 -1
  118. edsl/questions/descriptors.py +17 -3
  119. edsl/questions/loop_processor.py +149 -0
  120. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
  121. edsl/questions/question_registry.py +1 -1
  122. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
  123. edsl/questions/response_validator_factory.py +34 -0
  124. edsl/questions/templates/matrix/__init__.py +1 -0
  125. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  126. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  127. edsl/results/CSSParameterizer.py +1 -1
  128. edsl/results/Dataset.py +170 -7
  129. edsl/results/DatasetExportMixin.py +168 -305
  130. edsl/results/DatasetTree.py +28 -8
  131. edsl/results/MarkdownToDocx.py +122 -0
  132. edsl/results/MarkdownToPDF.py +111 -0
  133. edsl/results/Result.py +298 -206
  134. edsl/results/Results.py +149 -131
  135. edsl/results/ResultsExportMixin.py +2 -0
  136. edsl/results/TableDisplay.py +98 -171
  137. edsl/results/TextEditor.py +50 -0
  138. edsl/results/__init__.py +1 -1
  139. edsl/results/file_exports.py +252 -0
  140. edsl/results/{Selector.py → results_selector.py} +23 -13
  141. edsl/results/smart_objects.py +96 -0
  142. edsl/results/table_data_class.py +12 -0
  143. edsl/results/table_renderers.py +118 -0
  144. edsl/scenarios/ConstructDownloadLink.py +109 -0
  145. edsl/scenarios/DocumentChunker.py +102 -0
  146. edsl/scenarios/DocxScenario.py +16 -0
  147. edsl/scenarios/FileStore.py +150 -239
  148. edsl/scenarios/PdfExtractor.py +40 -0
  149. edsl/scenarios/Scenario.py +90 -193
  150. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  151. edsl/scenarios/ScenarioList.py +415 -244
  152. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  153. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  154. edsl/scenarios/__init__.py +1 -2
  155. edsl/scenarios/directory_scanner.py +96 -0
  156. edsl/scenarios/file_methods.py +85 -0
  157. edsl/scenarios/handlers/__init__.py +13 -0
  158. edsl/scenarios/handlers/csv.py +49 -0
  159. edsl/scenarios/handlers/docx.py +76 -0
  160. edsl/scenarios/handlers/html.py +37 -0
  161. edsl/scenarios/handlers/json.py +111 -0
  162. edsl/scenarios/handlers/latex.py +5 -0
  163. edsl/scenarios/handlers/md.py +51 -0
  164. edsl/scenarios/handlers/pdf.py +68 -0
  165. edsl/scenarios/handlers/png.py +39 -0
  166. edsl/scenarios/handlers/pptx.py +105 -0
  167. edsl/scenarios/handlers/py.py +294 -0
  168. edsl/scenarios/handlers/sql.py +313 -0
  169. edsl/scenarios/handlers/sqlite.py +149 -0
  170. edsl/scenarios/handlers/txt.py +33 -0
  171. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
  172. edsl/scenarios/scenario_selector.py +156 -0
  173. edsl/study/ObjectEntry.py +1 -1
  174. edsl/study/SnapShot.py +1 -1
  175. edsl/study/Study.py +5 -12
  176. edsl/surveys/ConstructDAG.py +92 -0
  177. edsl/surveys/EditSurvey.py +221 -0
  178. edsl/surveys/InstructionHandler.py +100 -0
  179. edsl/surveys/MemoryManagement.py +72 -0
  180. edsl/surveys/Rule.py +5 -4
  181. edsl/surveys/RuleCollection.py +25 -27
  182. edsl/surveys/RuleManager.py +172 -0
  183. edsl/surveys/Simulator.py +75 -0
  184. edsl/surveys/Survey.py +270 -791
  185. edsl/surveys/SurveyCSS.py +20 -8
  186. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  187. edsl/surveys/SurveyToApp.py +141 -0
  188. edsl/surveys/__init__.py +4 -2
  189. edsl/surveys/descriptors.py +6 -2
  190. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  191. edsl/surveys/instructions/Instruction.py +4 -13
  192. edsl/surveys/instructions/InstructionCollection.py +11 -6
  193. edsl/templates/error_reporting/interview_details.html +1 -1
  194. edsl/templates/error_reporting/report.html +1 -1
  195. edsl/tools/plotting.py +1 -1
  196. edsl/utilities/PrettyList.py +56 -0
  197. edsl/utilities/is_notebook.py +18 -0
  198. edsl/utilities/is_valid_variable_name.py +11 -0
  199. edsl/utilities/remove_edsl_version.py +24 -0
  200. edsl/utilities/utilities.py +35 -23
  201. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
  202. edsl-0.1.39.dist-info/RECORD +358 -0
  203. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
  204. edsl/language_models/KeyLookup.py +0 -30
  205. edsl/language_models/registry.py +0 -190
  206. edsl/language_models/unused/ReplicateBase.py +0 -83
  207. edsl/results/ResultsDBMixin.py +0 -238
  208. edsl-0.1.38.dev4.dist-info/RECORD +0 -277
  209. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  210. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  211. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  212. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
@@ -2,58 +2,44 @@
2
2
 
3
3
  from __future__ import annotations
4
4
  import asyncio
5
- from typing import Any, Type, List, Generator, Optional, Union
5
+ from typing import Any, Type, List, Generator, Optional, Union, TYPE_CHECKING
6
6
  import copy
7
+ from dataclasses import dataclass
7
8
 
8
- from tenacity import (
9
- retry,
10
- stop_after_attempt,
11
- wait_exponential,
12
- retry_if_exception_type,
13
- RetryError,
14
- )
15
-
16
- from edsl import CONFIG
17
- from edsl.surveys.base import EndOfSurvey
18
- from edsl.exceptions import QuestionAnswerValidationError
19
- from edsl.exceptions import QuestionAnswerValidationError
20
- from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
21
-
22
- from edsl.jobs.buckets.ModelBuckets import ModelBuckets
23
- from edsl.jobs.Answers import Answers
24
- from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
25
- from edsl.jobs.tasks.TaskCreators import TaskCreators
9
+ # from edsl.jobs.Answers import Answers
10
+ from edsl.jobs.data_structures import Answers
26
11
  from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
12
+ from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
27
13
  from edsl.jobs.interviews.InterviewExceptionCollection import (
28
14
  InterviewExceptionCollection,
29
15
  )
30
-
31
- # from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
32
-
33
- from edsl.surveys.base import EndOfSurvey
34
- from edsl.jobs.buckets.ModelBuckets import ModelBuckets
35
16
  from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
36
- from edsl.jobs.tasks.task_status_enum import TaskStatus
37
- from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
38
-
39
-
40
- from edsl import Agent, Survey, Scenario, Cache
41
- from edsl.language_models import LanguageModel
42
- from edsl.questions import QuestionBase
43
- from edsl.agents.InvigilatorBase import InvigilatorBase
44
-
45
- from edsl.exceptions.language_models import LanguageModelNoResponseError
17
+ from edsl.jobs.buckets.ModelBuckets import ModelBuckets
18
+ from edsl.jobs.AnswerQuestionFunctionConstructor import (
19
+ AnswerQuestionFunctionConstructor,
20
+ )
21
+ from edsl.jobs.InterviewTaskManager import InterviewTaskManager
22
+ from edsl.jobs.FetchInvigilator import FetchInvigilator
23
+ from edsl.jobs.RequestTokenEstimator import RequestTokenEstimator
46
24
 
47
- from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
48
- from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
49
- from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
50
25
 
26
+ if TYPE_CHECKING:
27
+ from edsl.agents.Agent import Agent
28
+ from edsl.surveys.Survey import Survey
29
+ from edsl.scenarios.Scenario import Scenario
30
+ from edsl.data.Cache import Cache
31
+ from edsl.language_models.LanguageModel import LanguageModel
32
+ from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
33
+ from edsl.agents.InvigilatorBase import InvigilatorBase
34
+ from edsl.language_models.key_management.KeyLookup import KeyLookup
51
35
 
52
- from edsl import CONFIG
53
36
 
54
- EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
55
- EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
56
- EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
37
+ @dataclass
38
+ class InterviewRunningConfig:
39
+ cache: Optional["Cache"] = (None,)
40
+ skip_retry: bool = (False,) # COULD BE SET WITH CONFIG
41
+ raise_validation_errors: bool = (True,)
42
+ stop_on_exception: bool = (False,)
57
43
 
58
44
 
59
45
  class Interview:
@@ -70,11 +56,10 @@ class Interview:
70
56
  survey: Survey,
71
57
  scenario: Scenario,
72
58
  model: Type["LanguageModel"],
73
- debug: Optional[bool] = False,
74
59
  iteration: int = 0,
60
+ indices: dict = None, # explain?
75
61
  cache: Optional["Cache"] = None,
76
- sidecar_model: Optional["LanguageModel"] = None,
77
- skip_retry: bool = False,
62
+ skip_retry: bool = False, # COULD BE SET WITH CONFIG
78
63
  raise_validation_errors: bool = True,
79
64
  ):
80
65
  """Initialize the Interview instance.
@@ -83,13 +68,12 @@ class Interview:
83
68
  :param survey: the survey being administered to the agent.
84
69
  :param scenario: the scenario that populates the survey questions.
85
70
  :param model: the language model used to answer the questions.
86
- :param debug: if True, run without calls to the language model.
71
+ # :param debug: if True, run without calls to the language model.
87
72
  :param iteration: the iteration number of the interview.
88
73
  :param cache: the cache used to store the answers.
89
- :param sidecar_model: a sidecar model used to answer questions.
90
74
 
91
75
  >>> i = Interview.example()
92
- >>> i.task_creators
76
+ >>> i.task_manager.task_creators
93
77
  {}
94
78
 
95
79
  >>> i.exceptions
@@ -104,22 +88,27 @@ class Interview:
104
88
 
105
89
  """
106
90
  self.agent = agent
107
- self.survey = copy.deepcopy(survey)
91
+ self.survey = copy.deepcopy(survey) # why do we need to deepcopy the survey?
108
92
  self.scenario = scenario
109
93
  self.model = model
110
- self.debug = debug
111
94
  self.iteration = iteration
112
- self.cache = cache
113
- self.answers: dict[
114
- str, str
115
- ] = Answers() # will get filled in as interview progresses
116
- self.sidecar_model = sidecar_model
117
95
 
118
- # Trackers
119
- self.task_creators = TaskCreators() # tracks the task creators
96
+ self.answers = Answers() # will get filled in as interview progresses
97
+
98
+ self.task_manager = InterviewTaskManager(
99
+ survey=self.survey,
100
+ iteration=iteration,
101
+ )
102
+
120
103
  self.exceptions = InterviewExceptionCollection()
121
104
 
122
- self._task_status_log_dict = InterviewStatusLog()
105
+ self.running_config = InterviewRunningConfig(
106
+ cache=cache,
107
+ skip_retry=skip_retry,
108
+ raise_validation_errors=raise_validation_errors,
109
+ )
110
+
111
+ self.cache = cache
123
112
  self.skip_retry = skip_retry
124
113
  self.raise_validation_errors = raise_validation_errors
125
114
 
@@ -131,6 +120,9 @@ class Interview:
131
120
 
132
121
  self.failed_questions = []
133
122
 
123
+ self.indices = indices
124
+ self.initial_hash = hash(self)
125
+
134
126
  @property
135
127
  def has_exceptions(self) -> bool:
136
128
  """Return True if there are exceptions."""
@@ -142,30 +134,26 @@ class Interview:
142
134
 
143
135
  The keys are the question names; the values are the lists of status log changes for each task.
144
136
  """
145
- for task_creator in self.task_creators.values():
146
- self._task_status_log_dict[
147
- task_creator.question.question_name
148
- ] = task_creator.status_log
149
- return self._task_status_log_dict
137
+ return self.task_manager.task_status_logs
150
138
 
151
139
  @property
152
140
  def token_usage(self) -> InterviewTokenUsage:
153
141
  """Determine how many tokens were used for the interview."""
154
- return self.task_creators.token_usage
142
+ return self.task_manager.token_usage # task_creators.token_usage
155
143
 
156
144
  @property
157
145
  def interview_status(self) -> InterviewStatusDictionary:
158
146
  """Return a dictionary mapping task status codes to counts."""
159
- return self.task_creators.interview_status
147
+ # return self.task_creators.interview_status
148
+ return self.task_manager.interview_status
160
149
 
161
- # region: Serialization
162
150
  def to_dict(self, include_exceptions=True, add_edsl_version=True) -> dict[str, Any]:
163
151
  """Return a dictionary representation of the Interview instance.
164
152
  This is just for hashing purposes.
165
153
 
166
154
  >>> i = Interview.example()
167
155
  >>> hash(i)
168
- 1217840301076717434
156
+ 193593189022259693
169
157
  """
170
158
  d = {
171
159
  "agent": self.agent.to_dict(add_edsl_version=add_edsl_version),
@@ -177,23 +165,34 @@ class Interview:
177
165
  }
178
166
  if include_exceptions:
179
167
  d["exceptions"] = self.exceptions.to_dict()
168
+ if hasattr(self, "indices"):
169
+ d["indices"] = self.indices
180
170
  return d
181
171
 
182
172
  @classmethod
183
173
  def from_dict(cls, d: dict[str, Any]) -> "Interview":
184
174
  """Return an Interview instance from a dictionary."""
175
+
176
+ from edsl.agents.Agent import Agent
177
+ from edsl.surveys.Survey import Survey
178
+ from edsl.scenarios.Scenario import Scenario
179
+ from edsl.language_models.LanguageModel import LanguageModel
180
+
185
181
  agent = Agent.from_dict(d["agent"])
186
182
  survey = Survey.from_dict(d["survey"])
187
183
  scenario = Scenario.from_dict(d["scenario"])
188
184
  model = LanguageModel.from_dict(d["model"])
189
185
  iteration = d["iteration"]
190
- interview = cls(
191
- agent=agent,
192
- survey=survey,
193
- scenario=scenario,
194
- model=model,
195
- iteration=iteration,
196
- )
186
+ params = {
187
+ "agent": agent,
188
+ "survey": survey,
189
+ "scenario": scenario,
190
+ "model": model,
191
+ "iteration": iteration,
192
+ }
193
+ if "indices" in d:
194
+ params["indices"] = d["indices"]
195
+ interview = cls(**params)
197
196
  if "exceptions" in d:
198
197
  exceptions = InterviewExceptionCollection.from_dict(d["exceptions"])
199
198
  interview.exceptions = exceptions
@@ -211,311 +210,13 @@ class Interview:
211
210
  """
212
211
  return hash(self) == hash(other)
213
212
 
214
- # endregion
215
-
216
- # region: Creating tasks
217
- @property
218
- def dag(self) -> "DAG":
219
- """Return the directed acyclic graph for the survey.
220
-
221
- The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
222
- It is used to determine the order in which questions should be answered.
223
- This reflects both agent 'memory' considerations and 'skip' logic.
224
- The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
225
-
226
- >>> i = Interview.example()
227
- >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
228
- True
229
- """
230
- return self.survey.dag(textify=True)
231
-
232
- def _build_question_tasks(
233
- self,
234
- model_buckets: ModelBuckets,
235
- ) -> list[asyncio.Task]:
236
- """Create a task for each question, with dependencies on the questions that must be answered before this one can be answered.
237
-
238
- :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
239
- :param model_buckets: the model buckets used to track and control usage rates.
240
- """
241
- tasks = []
242
- for question in self.survey.questions:
243
- tasks_that_must_be_completed_before = list(
244
- self._get_tasks_that_must_be_completed_before(
245
- tasks=tasks, question=question
246
- )
247
- )
248
- question_task = self._create_question_task(
249
- question=question,
250
- tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
251
- model_buckets=model_buckets,
252
- iteration=self.iteration,
253
- )
254
- tasks.append(question_task)
255
- return tuple(tasks)
256
-
257
- def _get_tasks_that_must_be_completed_before(
258
- self, *, tasks: list[asyncio.Task], question: "QuestionBase"
259
- ) -> Generator[asyncio.Task, None, None]:
260
- """Return the tasks that must be completed before the given question can be answered.
261
-
262
- :param tasks: a list of tasks that have been created so far.
263
- :param question: the question for which we are determining dependencies.
264
-
265
- If a question has no dependencies, this will be an empty list, [].
266
- """
267
- parents_of_focal_question = self.dag.get(question.question_name, [])
268
- for parent_question_name in parents_of_focal_question:
269
- yield tasks[self.to_index[parent_question_name]]
270
-
271
- def _create_question_task(
272
- self,
273
- *,
274
- question: QuestionBase,
275
- tasks_that_must_be_completed_before: list[asyncio.Task],
276
- model_buckets: ModelBuckets,
277
- iteration: int = 0,
278
- ) -> asyncio.Task:
279
- """Create a task that depends on the passed-in dependencies that are awaited before the task is run.
280
-
281
- :param question: the question to be answered. This is the question we are creating a task for.
282
- :param tasks_that_must_be_completed_before: the tasks that must be completed before the focal task is run.
283
- :param model_buckets: the model buckets used to track and control usage rates.
284
- :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
285
- :param iteration: the iteration number for the interview.
286
-
287
- The task is created by a `QuestionTaskCreator`, which is responsible for creating the task and managing its dependencies.
288
- It is passed a reference to the function that will be called to answer the question.
289
- It is passed a list "tasks_that_must_be_completed_before" that are awaited before the task is run.
290
- These are added as a dependency to the focal task.
291
- """
292
- task_creator = QuestionTaskCreator(
293
- question=question,
294
- answer_question_func=self._answer_question_and_record_task,
295
- token_estimator=self._get_estimated_request_tokens,
296
- model_buckets=model_buckets,
297
- iteration=iteration,
298
- )
299
- for task in tasks_that_must_be_completed_before:
300
- task_creator.add_dependency(task)
301
-
302
- self.task_creators.update(
303
- {question.question_name: task_creator}
304
- ) # track this task creator
305
- return task_creator.generate_task()
306
-
307
- def _get_estimated_request_tokens(self, question) -> float:
308
- """Estimate the number of tokens that will be required to run the focal task."""
309
- from edsl.scenarios.FileStore import FileStore
310
-
311
- invigilator = self._get_invigilator(question=question)
312
- # TODO: There should be a way to get a more accurate estimate.
313
- combined_text = ""
314
- file_tokens = 0
315
- for prompt in invigilator.get_prompts().values():
316
- if hasattr(prompt, "text"):
317
- combined_text += prompt.text
318
- elif isinstance(prompt, str):
319
- combined_text += prompt
320
- elif isinstance(prompt, list):
321
- for file in prompt:
322
- if isinstance(file, FileStore):
323
- file_tokens += file.size * 0.25
324
- else:
325
- raise ValueError(f"Prompt is of type {type(prompt)}")
326
- return len(combined_text) / 4.0 + file_tokens
327
-
328
- async def _answer_question_and_record_task(
329
- self,
330
- *,
331
- question: "QuestionBase",
332
- task=None,
333
- ) -> "AgentResponseDict":
334
- """Answer a question and records the task."""
335
-
336
- had_language_model_no_response_error = False
337
-
338
- @retry(
339
- stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
340
- wait=wait_exponential(
341
- multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
342
- ),
343
- retry=retry_if_exception_type(LanguageModelNoResponseError),
344
- reraise=True,
345
- )
346
- async def attempt_answer():
347
- nonlocal had_language_model_no_response_error
348
-
349
- invigilator = self._get_invigilator(question)
350
-
351
- if self._skip_this_question(question):
352
- return invigilator.get_failed_task_result(
353
- failure_reason="Question skipped."
354
- )
355
-
356
- try:
357
- response: EDSLResultObjectInput = (
358
- await invigilator.async_answer_question()
359
- )
360
- if response.validated:
361
- self.answers.add_answer(response=response, question=question)
362
- self._cancel_skipped_questions(question)
363
- else:
364
- # When a question is not validated, it is not added to the answers.
365
- # this should also cancel and dependent children questions.
366
- # Is that happening now?
367
- if (
368
- hasattr(response, "exception_occurred")
369
- and response.exception_occurred
370
- ):
371
- raise response.exception_occurred
372
-
373
- except QuestionAnswerValidationError as e:
374
- self._handle_exception(e, invigilator, task)
375
- return invigilator.get_failed_task_result(
376
- failure_reason="Question answer validation failed."
377
- )
378
-
379
- except asyncio.TimeoutError as e:
380
- self._handle_exception(e, invigilator, task)
381
- had_language_model_no_response_error = True
382
- raise LanguageModelNoResponseError(
383
- f"Language model timed out for question '{question.question_name}.'"
384
- )
385
-
386
- except Exception as e:
387
- self._handle_exception(e, invigilator, task)
388
-
389
- if "response" not in locals():
390
- had_language_model_no_response_error = True
391
- raise LanguageModelNoResponseError(
392
- f"Language model did not return a response for question '{question.question_name}.'"
393
- )
394
-
395
- # if it gets here, it means the no response error was fixed
396
- if (
397
- question.question_name in self.exceptions
398
- and had_language_model_no_response_error
399
- ):
400
- self.exceptions.record_fixed_question(question.question_name)
401
-
402
- return response
403
-
404
- try:
405
- return await attempt_answer()
406
- except RetryError as retry_error:
407
- # All retries have failed for LanguageModelNoResponseError
408
- original_error = retry_error.last_attempt.exception()
409
- self._handle_exception(
410
- original_error, self._get_invigilator(question), task
411
- )
412
- raise original_error # Re-raise the original error after handling
413
-
414
- def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
415
- """Return an invigilator for the given question.
416
-
417
- :param question: the question to be answered
418
- :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
419
- """
420
- invigilator = self.agent.create_invigilator(
421
- question=question,
422
- scenario=self.scenario,
423
- model=self.model,
424
- debug=False,
425
- survey=self.survey,
426
- memory_plan=self.survey.memory_plan,
427
- current_answers=self.answers,
428
- iteration=self.iteration,
429
- cache=self.cache,
430
- sidecar_model=self.sidecar_model,
431
- raise_validation_errors=self.raise_validation_errors,
432
- )
433
- """Return an invigilator for the given question."""
434
- return invigilator
435
-
436
- def _skip_this_question(self, current_question: "QuestionBase") -> bool:
437
- """Determine if the current question should be skipped.
438
-
439
- :param current_question: the question to be answered.
440
- """
441
- current_question_index = self.to_index[current_question.question_name]
442
-
443
- answers = self.answers | self.scenario | self.agent["traits"]
444
- skip = self.survey.rule_collection.skip_question_before_running(
445
- current_question_index, answers
446
- )
447
- return skip
448
-
449
- def _handle_exception(
450
- self, e: Exception, invigilator: "InvigilatorBase", task=None
451
- ):
452
- import copy
453
-
454
- # breakpoint()
455
-
456
- answers = copy.copy(self.answers)
457
- exception_entry = InterviewExceptionEntry(
458
- exception=e,
459
- invigilator=invigilator,
460
- answers=answers,
461
- )
462
- if task:
463
- task.task_status = TaskStatus.FAILED
464
- self.exceptions.add(invigilator.question.question_name, exception_entry)
465
-
466
- if self.raise_validation_errors:
467
- if isinstance(e, QuestionAnswerValidationError):
468
- raise e
469
-
470
- if hasattr(self, "stop_on_exception"):
471
- stop_on_exception = self.stop_on_exception
472
- else:
473
- stop_on_exception = False
474
-
475
- if stop_on_exception:
476
- raise e
477
-
478
- def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
479
- """Cancel the tasks for questions that are skipped.
480
-
481
- :param current_question: the question that was just answered.
482
-
483
- It first determines the next question, given the current question and the current answers.
484
- If the next question is the end of the survey, it cancels all remaining tasks.
485
- If the next question is after the current question, it cancels all tasks between the current question and the next question.
486
- """
487
- current_question_index: int = self.to_index[current_question.question_name]
488
-
489
- next_question: Union[
490
- int, EndOfSurvey
491
- ] = self.survey.rule_collection.next_question(
492
- q_now=current_question_index,
493
- answers=self.answers | self.scenario | self.agent["traits"],
494
- )
495
-
496
- next_question_index = next_question.next_q
497
-
498
- def cancel_between(start, end):
499
- """Cancel the tasks between the start and end indices."""
500
- for i in range(start, end):
501
- self.tasks[i].cancel()
502
-
503
- if next_question_index == EndOfSurvey:
504
- cancel_between(current_question_index + 1, len(self.survey.questions))
505
- return
506
-
507
- if next_question_index > (current_question_index + 1):
508
- cancel_between(current_question_index + 1, next_question_index)
509
-
510
- # endregion
511
-
512
- # region: Conducting the interview
513
213
  async def async_conduct_interview(
514
214
  self,
515
- model_buckets: Optional[ModelBuckets] = None,
516
- stop_on_exception: bool = False,
517
- sidecar_model: Optional["LanguageModel"] = None,
518
- raise_validation_errors: bool = True,
215
+ run_config: Optional["RunConfig"] = None,
216
+ # model_buckets: Optional[ModelBuckets] = None,
217
+ # stop_on_exception: bool = False,
218
+ # raise_validation_errors: bool = True,
219
+ # key_lookup: Optional[KeyLookup] = None,
519
220
  ) -> tuple["Answers", List[dict[str, Any]]]:
520
221
  """
521
222
  Conduct an Interview asynchronously.
@@ -524,7 +225,6 @@ class Interview:
524
225
  :param model_buckets: a dictionary of token buckets for the model.
525
226
  :param debug: run without calls to LLM.
526
227
  :param stop_on_exception: if True, stops the interview if an exception is raised.
527
- :param sidecar_model: a sidecar model used to answer questions.
528
228
 
529
229
  Example usage:
530
230
 
@@ -538,37 +238,68 @@ class Interview:
538
238
  >>> i.exceptions
539
239
  {'q0': ...
540
240
  >>> i = Interview.example()
541
- >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
241
+ >>> from edsl.jobs.Jobs import RunConfig, RunParameters, RunEnvironment
242
+ >>> run_config = RunConfig(parameters = RunParameters(), environment = RunEnvironment())
243
+ >>> run_config.parameters.stop_on_exception = True
244
+ >>> result, _ = asyncio.run(i.async_conduct_interview(run_config))
542
245
  Traceback (most recent call last):
543
246
  ...
544
247
  asyncio.exceptions.CancelledError
545
248
  """
546
- self.sidecar_model = sidecar_model
547
- self.stop_on_exception = stop_on_exception
249
+ from edsl.jobs.Jobs import RunConfig, RunParameters, RunEnvironment
250
+
251
+ if run_config is None:
252
+ run_config = RunConfig(
253
+ parameters=RunParameters(),
254
+ environment=RunEnvironment(),
255
+ )
256
+ self.stop_on_exception = run_config.parameters.stop_on_exception
548
257
 
549
258
  # if no model bucket is passed, create an 'infinity' bucket with no rate limits
259
+ bucket_collection = run_config.environment.bucket_collection
260
+
261
+ if bucket_collection:
262
+ model_buckets = bucket_collection.get(self.model)
263
+ else:
264
+ model_buckets = None
265
+
550
266
  if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
551
267
  model_buckets = ModelBuckets.infinity_bucket()
552
268
 
269
+ # was "self.tasks" - is that necessary?
270
+ self.tasks = self.task_manager.build_question_tasks(
271
+ answer_func=AnswerQuestionFunctionConstructor(
272
+ self, key_lookup=run_config.environment.key_lookup
273
+ )(),
274
+ token_estimator=RequestTokenEstimator(self),
275
+ model_buckets=model_buckets,
276
+ )
277
+
553
278
  ## This is the key part---it creates a task for each question,
554
279
  ## with dependencies on the questions that must be answered before this one can be answered.
555
- self.tasks = self._build_question_tasks(model_buckets=model_buckets)
556
280
 
557
- ## 'Invigilators' are used to administer the survey
558
- self.invigilators = [
559
- self._get_invigilator(question) for question in self.survey.questions
560
- ]
281
+ ## 'Invigilators' are used to administer the survey.
282
+ fetcher = FetchInvigilator(
283
+ interview=self,
284
+ current_answers=self.answers,
285
+ key_lookup=run_config.environment.key_lookup,
286
+ )
287
+ self.invigilators = [fetcher(question) for question in self.survey.questions]
561
288
  await asyncio.gather(
562
- *self.tasks, return_exceptions=not stop_on_exception
563
- ) # not stop_on_exception)
289
+ *self.tasks, return_exceptions=not run_config.parameters.stop_on_exception
290
+ )
564
291
  self.answers.replace_missing_answers_with_none(self.survey)
565
- valid_results = list(self._extract_valid_results())
292
+ valid_results = list(
293
+ self._extract_valid_results(self.tasks, self.invigilators, self.exceptions)
294
+ )
566
295
  return self.answers, valid_results
567
296
 
568
- # endregion
569
-
570
- # region: Extracting results and recording errors
571
- def _extract_valid_results(self) -> Generator["Answers", None, None]:
297
+ @staticmethod
298
+ def _extract_valid_results(
299
+ tasks: List["asyncio.Task"],
300
+ invigilators: List["InvigilatorBase"],
301
+ exceptions: InterviewExceptionCollection,
302
+ ) -> Generator["Answers", None, None]:
572
303
  """Extract the valid results from the list of results.
573
304
 
574
305
  It iterates through the tasks and invigilators, and yields the results of the tasks that are done.
@@ -577,16 +308,10 @@ class Interview:
577
308
 
578
309
  >>> i = Interview.example()
579
310
  >>> result, _ = asyncio.run(i.async_conduct_interview())
580
- >>> results = list(i._extract_valid_results())
581
- >>> len(results) == len(i.survey)
582
- True
583
311
  """
584
- assert len(self.tasks) == len(self.invigilators)
585
-
586
- for task, invigilator in zip(self.tasks, self.invigilators):
587
- if not task.done():
588
- raise ValueError(f"Task {task.get_name()} is not done.")
312
+ assert len(tasks) == len(invigilators)
589
313
 
314
+ def handle_task(task, invigilator):
590
315
  try:
591
316
  result = task.result()
592
317
  except asyncio.CancelledError as e: # task was cancelled
@@ -601,18 +326,22 @@ class Interview:
601
326
  exception=e,
602
327
  invigilator=invigilator,
603
328
  )
604
- self.exceptions.add(task.get_name(), exception_entry)
329
+ exceptions.add(task.get_name(), exception_entry)
330
+ return result
605
331
 
606
- yield result
332
+ for task, invigilator in zip(tasks, invigilators):
333
+ if not task.done():
334
+ raise ValueError(f"Task {task.get_name()} is not done.")
607
335
 
608
- # endregion
336
+ yield handle_task(task, invigilator)
609
337
 
610
- # region: Magic methods
611
338
  def __repr__(self) -> str:
612
339
  """Return a string representation of the Interview instance."""
613
340
  return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
614
341
 
615
- def duplicate(self, iteration: int, cache: "Cache") -> Interview:
342
+ def duplicate(
343
+ self, iteration: int, cache: "Cache", randomize_survey: Optional[bool] = True
344
+ ) -> Interview:
616
345
  """Duplicate the interview, but with a new iteration number and cache.
617
346
 
618
347
  >>> i = Interview.example()
@@ -621,14 +350,20 @@ class Interview:
621
350
  True
622
351
 
623
352
  """
353
+ if randomize_survey:
354
+ new_survey = self.survey.draw()
355
+ else:
356
+ new_survey = self.survey
357
+
624
358
  return Interview(
625
359
  agent=self.agent,
626
- survey=self.survey,
360
+ survey=new_survey,
627
361
  scenario=self.scenario,
628
362
  model=self.model,
629
363
  iteration=iteration,
630
- cache=cache,
631
- skip_retry=self.skip_retry,
364
+ cache=self.running_config.cache,
365
+ skip_retry=self.running_config.skip_retry,
366
+ indices=self.indices,
632
367
  )
633
368
 
634
369
  @classmethod