edsl 0.1.38.dev2__py3-none-any.whl → 0.1.38.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (248)
  1. edsl/Base.py +303 -303
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +49 -49
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +858 -858
  7. edsl/agents/AgentList.py +362 -362
  8. edsl/agents/Invigilator.py +222 -222
  9. edsl/agents/InvigilatorBase.py +284 -284
  10. edsl/agents/PromptConstructor.py +353 -353
  11. edsl/agents/__init__.py +3 -3
  12. edsl/agents/descriptors.py +99 -99
  13. edsl/agents/prompt_helpers.py +129 -129
  14. edsl/auto/AutoStudy.py +117 -117
  15. edsl/auto/StageBase.py +230 -230
  16. edsl/auto/StageGenerateSurvey.py +178 -178
  17. edsl/auto/StageLabelQuestions.py +125 -125
  18. edsl/auto/StagePersona.py +61 -61
  19. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  20. edsl/auto/StagePersonaDimensionValues.py +74 -74
  21. edsl/auto/StagePersonaDimensions.py +69 -69
  22. edsl/auto/StageQuestions.py +73 -73
  23. edsl/auto/SurveyCreatorPipeline.py +21 -21
  24. edsl/auto/utilities.py +224 -224
  25. edsl/base/Base.py +279 -279
  26. edsl/config.py +149 -149
  27. edsl/conversation/Conversation.py +290 -290
  28. edsl/conversation/car_buying.py +58 -58
  29. edsl/conversation/chips.py +95 -95
  30. edsl/conversation/mug_negotiation.py +81 -81
  31. edsl/conversation/next_speaker_utilities.py +93 -93
  32. edsl/coop/PriceFetcher.py +54 -54
  33. edsl/coop/__init__.py +2 -2
  34. edsl/coop/coop.py +961 -961
  35. edsl/coop/utils.py +131 -131
  36. edsl/data/Cache.py +530 -530
  37. edsl/data/CacheEntry.py +228 -228
  38. edsl/data/CacheHandler.py +149 -149
  39. edsl/data/RemoteCacheSync.py +97 -97
  40. edsl/data/SQLiteDict.py +292 -292
  41. edsl/data/__init__.py +4 -4
  42. edsl/data/orm.py +10 -10
  43. edsl/data_transfer_models.py +73 -73
  44. edsl/enums.py +173 -173
  45. edsl/exceptions/BaseException.py +21 -21
  46. edsl/exceptions/__init__.py +54 -54
  47. edsl/exceptions/agents.py +42 -42
  48. edsl/exceptions/cache.py +5 -5
  49. edsl/exceptions/configuration.py +16 -16
  50. edsl/exceptions/coop.py +10 -10
  51. edsl/exceptions/data.py +14 -14
  52. edsl/exceptions/general.py +34 -34
  53. edsl/exceptions/jobs.py +33 -33
  54. edsl/exceptions/language_models.py +63 -63
  55. edsl/exceptions/prompts.py +15 -15
  56. edsl/exceptions/questions.py +91 -91
  57. edsl/exceptions/results.py +29 -29
  58. edsl/exceptions/scenarios.py +22 -22
  59. edsl/exceptions/surveys.py +37 -37
  60. edsl/inference_services/AnthropicService.py +87 -87
  61. edsl/inference_services/AwsBedrock.py +120 -120
  62. edsl/inference_services/AzureAI.py +217 -217
  63. edsl/inference_services/DeepInfraService.py +18 -18
  64. edsl/inference_services/GoogleService.py +156 -156
  65. edsl/inference_services/GroqService.py +20 -20
  66. edsl/inference_services/InferenceServiceABC.py +147 -147
  67. edsl/inference_services/InferenceServicesCollection.py +97 -97
  68. edsl/inference_services/MistralAIService.py +123 -123
  69. edsl/inference_services/OllamaService.py +18 -18
  70. edsl/inference_services/OpenAIService.py +224 -224
  71. edsl/inference_services/TestService.py +89 -89
  72. edsl/inference_services/TogetherAIService.py +170 -170
  73. edsl/inference_services/models_available_cache.py +118 -118
  74. edsl/inference_services/rate_limits_cache.py +25 -25
  75. edsl/inference_services/registry.py +39 -39
  76. edsl/inference_services/write_available.py +10 -10
  77. edsl/jobs/Answers.py +56 -56
  78. edsl/jobs/Jobs.py +1358 -1358
  79. edsl/jobs/__init__.py +1 -1
  80. edsl/jobs/buckets/BucketCollection.py +63 -63
  81. edsl/jobs/buckets/ModelBuckets.py +65 -65
  82. edsl/jobs/buckets/TokenBucket.py +251 -251
  83. edsl/jobs/interviews/Interview.py +661 -661
  84. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  85. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  86. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  87. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  88. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  89. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  90. edsl/jobs/interviews/ReportErrors.py +66 -66
  91. edsl/jobs/interviews/interview_status_enum.py +9 -9
  92. edsl/jobs/runners/JobsRunnerAsyncio.py +361 -361
  93. edsl/jobs/runners/JobsRunnerStatus.py +332 -332
  94. edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
  95. edsl/jobs/tasks/TaskCreators.py +64 -64
  96. edsl/jobs/tasks/TaskHistory.py +451 -451
  97. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  98. edsl/jobs/tasks/task_status_enum.py +163 -163
  99. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  100. edsl/jobs/tokens/TokenUsage.py +34 -34
  101. edsl/language_models/KeyLookup.py +30 -30
  102. edsl/language_models/LanguageModel.py +708 -708
  103. edsl/language_models/ModelList.py +109 -109
  104. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  105. edsl/language_models/__init__.py +3 -3
  106. edsl/language_models/fake_openai_call.py +15 -15
  107. edsl/language_models/fake_openai_service.py +61 -61
  108. edsl/language_models/registry.py +137 -137
  109. edsl/language_models/repair.py +156 -156
  110. edsl/language_models/unused/ReplicateBase.py +83 -83
  111. edsl/language_models/utilities.py +64 -64
  112. edsl/notebooks/Notebook.py +258 -258
  113. edsl/notebooks/__init__.py +1 -1
  114. edsl/prompts/Prompt.py +357 -357
  115. edsl/prompts/__init__.py +2 -2
  116. edsl/questions/AnswerValidatorMixin.py +289 -289
  117. edsl/questions/QuestionBase.py +660 -660
  118. edsl/questions/QuestionBaseGenMixin.py +161 -161
  119. edsl/questions/QuestionBasePromptsMixin.py +217 -217
  120. edsl/questions/QuestionBudget.py +227 -227
  121. edsl/questions/QuestionCheckBox.py +359 -359
  122. edsl/questions/QuestionExtract.py +183 -183
  123. edsl/questions/QuestionFreeText.py +114 -114
  124. edsl/questions/QuestionFunctional.py +166 -166
  125. edsl/questions/QuestionList.py +231 -231
  126. edsl/questions/QuestionMultipleChoice.py +286 -286
  127. edsl/questions/QuestionNumerical.py +153 -153
  128. edsl/questions/QuestionRank.py +324 -324
  129. edsl/questions/Quick.py +41 -41
  130. edsl/questions/RegisterQuestionsMeta.py +71 -71
  131. edsl/questions/ResponseValidatorABC.py +174 -174
  132. edsl/questions/SimpleAskMixin.py +73 -73
  133. edsl/questions/__init__.py +26 -26
  134. edsl/questions/compose_questions.py +98 -98
  135. edsl/questions/decorators.py +21 -21
  136. edsl/questions/derived/QuestionLikertFive.py +76 -76
  137. edsl/questions/derived/QuestionLinearScale.py +87 -87
  138. edsl/questions/derived/QuestionTopK.py +93 -93
  139. edsl/questions/derived/QuestionYesNo.py +82 -82
  140. edsl/questions/descriptors.py +413 -413
  141. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  142. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  143. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  144. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  145. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  146. edsl/questions/prompt_templates/question_list.jinja +17 -17
  147. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  148. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  149. edsl/questions/question_registry.py +147 -147
  150. edsl/questions/settings.py +12 -12
  151. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  152. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  153. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  154. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  155. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  156. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  157. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  158. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  159. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  160. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  161. edsl/questions/templates/list/question_presentation.jinja +5 -5
  162. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  163. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  164. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  165. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  166. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  167. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  168. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  169. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  170. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  171. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  172. edsl/results/Dataset.py +293 -293
  173. edsl/results/DatasetExportMixin.py +717 -717
  174. edsl/results/DatasetTree.py +145 -145
  175. edsl/results/Result.py +456 -456
  176. edsl/results/Results.py +1071 -1071
  177. edsl/results/ResultsDBMixin.py +238 -238
  178. edsl/results/ResultsExportMixin.py +43 -43
  179. edsl/results/ResultsFetchMixin.py +33 -33
  180. edsl/results/ResultsGGMixin.py +121 -121
  181. edsl/results/ResultsToolsMixin.py +98 -98
  182. edsl/results/Selector.py +135 -135
  183. edsl/results/__init__.py +2 -2
  184. edsl/results/tree_explore.py +115 -115
  185. edsl/scenarios/FileStore.py +458 -458
  186. edsl/scenarios/Scenario.py +544 -544
  187. edsl/scenarios/ScenarioHtmlMixin.py +64 -64
  188. edsl/scenarios/ScenarioList.py +1112 -1112
  189. edsl/scenarios/ScenarioListExportMixin.py +52 -52
  190. edsl/scenarios/ScenarioListPdfMixin.py +261 -261
  191. edsl/scenarios/__init__.py +4 -4
  192. edsl/shared.py +1 -1
  193. edsl/study/ObjectEntry.py +173 -173
  194. edsl/study/ProofOfWork.py +113 -113
  195. edsl/study/SnapShot.py +80 -80
  196. edsl/study/Study.py +528 -528
  197. edsl/study/__init__.py +4 -4
  198. edsl/surveys/DAG.py +148 -148
  199. edsl/surveys/Memory.py +31 -31
  200. edsl/surveys/MemoryPlan.py +244 -244
  201. edsl/surveys/Rule.py +326 -326
  202. edsl/surveys/RuleCollection.py +387 -387
  203. edsl/surveys/Survey.py +1787 -1787
  204. edsl/surveys/SurveyCSS.py +261 -261
  205. edsl/surveys/SurveyExportMixin.py +259 -259
  206. edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
  207. edsl/surveys/SurveyQualtricsImport.py +284 -284
  208. edsl/surveys/__init__.py +3 -3
  209. edsl/surveys/base.py +53 -53
  210. edsl/surveys/descriptors.py +56 -56
  211. edsl/surveys/instructions/ChangeInstruction.py +49 -49
  212. edsl/surveys/instructions/Instruction.py +53 -53
  213. edsl/surveys/instructions/InstructionCollection.py +77 -77
  214. edsl/templates/error_reporting/base.html +23 -23
  215. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  216. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  217. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  218. edsl/templates/error_reporting/interview_details.html +115 -115
  219. edsl/templates/error_reporting/interviews.html +9 -9
  220. edsl/templates/error_reporting/overview.html +4 -4
  221. edsl/templates/error_reporting/performance_plot.html +1 -1
  222. edsl/templates/error_reporting/report.css +73 -73
  223. edsl/templates/error_reporting/report.html +117 -117
  224. edsl/templates/error_reporting/report.js +25 -25
  225. edsl/tools/__init__.py +1 -1
  226. edsl/tools/clusters.py +192 -192
  227. edsl/tools/embeddings.py +27 -27
  228. edsl/tools/embeddings_plotting.py +118 -118
  229. edsl/tools/plotting.py +112 -112
  230. edsl/tools/summarize.py +18 -18
  231. edsl/utilities/SystemInfo.py +28 -28
  232. edsl/utilities/__init__.py +22 -22
  233. edsl/utilities/ast_utilities.py +25 -25
  234. edsl/utilities/data/Registry.py +6 -6
  235. edsl/utilities/data/__init__.py +1 -1
  236. edsl/utilities/data/scooter_results.json +1 -1
  237. edsl/utilities/decorators.py +77 -77
  238. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  239. edsl/utilities/interface.py +627 -627
  240. edsl/utilities/naming_utilities.py +263 -263
  241. edsl/utilities/repair_functions.py +28 -28
  242. edsl/utilities/restricted_python.py +70 -70
  243. edsl/utilities/utilities.py +409 -409
  244. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev3.dist-info}/LICENSE +21 -21
  245. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev3.dist-info}/METADATA +1 -1
  246. edsl-0.1.38.dev3.dist-info/RECORD +269 -0
  247. edsl-0.1.38.dev2.dist-info/RECORD +0 -269
  248. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev3.dist-info}/WHEEL +0 -0
@@ -1,361 +1,361 @@
1
- from __future__ import annotations
2
- import time
3
- import asyncio
4
- import threading
5
- import warnings
6
- from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator, Type
7
- from collections import UserList
8
-
9
- from edsl.results.Results import Results
10
- from edsl.jobs.interviews.Interview import Interview
11
- from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus, JobsRunnerStatusBase
12
-
13
- from edsl.jobs.tasks.TaskHistory import TaskHistory
14
- from edsl.jobs.buckets.BucketCollection import BucketCollection
15
- from edsl.utilities.decorators import jupyter_nb_handler
16
- from edsl.data.Cache import Cache
17
- from edsl.results.Result import Result
18
- from edsl.results.Results import Results
19
- from edsl.language_models.LanguageModel import LanguageModel
20
- from edsl.data.Cache import Cache
21
-
22
-
23
class StatusTracker(UserList):
    """A list of completed items that also records how many tasks are expected."""

    def __init__(self, total_tasks: int):
        # Remember the expected total, then let UserList create an empty ``self.data``.
        self.total_tasks = total_tasks
        super().__init__()

    def current_status(self):
        """Print an in-place "Completed: X of Y" progress line; returns ``None``."""
        finished = len(self.data)
        print(f"Completed: {finished} of {self.total_tasks}", end="\r")
30
-
31
-
32
class JobsRunnerAsyncio:
    """A class for running a collection of interviews asynchronously.

    It gets instantiated from a Jobs object.
    The Jobs object is a collection of interviews that are to be run.
    """

    def __init__(self, jobs: "Jobs"):
        """Capture the interviews and rate-limit buckets from ``jobs``.

        :param jobs: the Jobs object whose ``interviews()`` and
            ``bucket_collection`` drive this runner.
        """
        self.jobs = jobs
        self.interviews: List["Interview"] = jobs.interviews()
        self.bucket_collection: "BucketCollection" = jobs.bucket_collection
        # Populated by run_async_generator: one entry per interview iteration.
        self.total_interviews: List["Interview"] = []
        # Set once self.total_interviews has been populated.
        self._initialized = threading.Event()

    async def run_async_generator(
        self,
        cache: Cache,
        n: int = 1,
        stop_on_exception: bool = False,
        sidecar_model: Optional[LanguageModel] = None,
        total_interviews: Optional[List["Interview"]] = None,
        raise_validation_errors: bool = False,
    ) -> AsyncGenerator["Result", None]:
        """Create one task per interview, run them concurrently, and yield each Result.

        Results are yielded in completion order, not submission order.
        Callers must set ``self.jobs_runner_status`` and ``self.cache``
        beforehand. Each completed result is reported to
        ``self.jobs_runner_status``. Generated interviews are bound to
        ``self.cache``; the ``cache`` argument itself is not read here.

        :param cache: accepted for interface compatibility (see above).
        :param n: how many times to run each interview
        :param stop_on_exception: Whether to stop the interview if an exception is raised
        :param sidecar_model: a language model to use in addition to the interview's model
        :param total_interviews: A list of interviews to run can be provided instead.
        :param raise_validation_errors: Whether to raise validation errors
        """
        if total_interviews:  # caller already supplied the interviews to run
            self.total_interviews = total_interviews
        else:
            # Materialize the full list before any task is created.
            self.total_interviews = list(self._populate_total_interviews(n=n))

        self._initialized.set()  # Signal that total_interviews is ready

        tasks = [
            asyncio.create_task(
                self._build_interview_task(
                    interview=interview,
                    stop_on_exception=stop_on_exception,
                    sidecar_model=sidecar_model,
                    raise_validation_errors=raise_validation_errors,
                )
            )
            for interview in self.total_interviews
        ]

        for next_done in asyncio.as_completed(tasks):
            result = await next_done
            self.jobs_runner_status.add_completed_interview(result)
            yield result

    def _populate_total_interviews(
        self, n: int = 1
    ) -> Generator["Interview", None, None]:
        """Yield ``n`` runs of each interview, all bound to ``self.cache``.

        The first run reuses the original Interview object (its ``cache``
        attribute is reassigned). Later runs are ``interview.duplicate(...)``
        copies tagged with their iteration number.

        :param n: how many times to run each interview.
        """
        for interview in self.interviews:
            for iteration in range(n):
                if iteration > 0:
                    yield interview.duplicate(iteration=iteration, cache=self.cache)
                else:
                    interview.cache = self.cache
                    yield interview

    async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
        """Used for some other modules that have a non-standard way of running interviews.

        :param cache: cache to use; a fresh ``Cache()`` is created when None.
        :param n: how many times to run each interview.
        :returns: a Results object built from ``self.jobs.survey`` and the
            collected Result objects (in completion order).
        """
        self.jobs_runner_status = JobsRunnerStatus(self, n=n)
        self.cache = Cache() if cache is None else cache
        data = []
        async for result in self.run_async_generator(cache=self.cache, n=n):
            data.append(result)
        return Results(survey=self.jobs.survey, data=data)

    def simple_run(self):
        """Synchronously run every interview once with a fresh Cache.

        ``run_async`` already returns a fully built Results object, so it is
        returned as-is. It is not wrapped in a second Results.
        """
        return asyncio.run(self.run_async())

    async def _build_interview_task(
        self,
        *,
        interview: Interview,
        stop_on_exception: bool = False,
        sidecar_model: Optional["LanguageModel"] = None,
        raise_validation_errors: bool = False,
    ) -> "Result":
        """Conducts an interview and returns the result.

        The per-question outputs of the interview are flattened into the
        dictionaries a Result expects. Keys are the question name plus a
        suffix (``_generated_tokens``, ``_comment``, ``_user_prompt``, ...).

        :param interview: the interview to conduct
        :param stop_on_exception: stops the interview if an exception is raised
        :param sidecar_model: a language model to use in addition to the interview's model
        :param raise_validation_errors: forwarded to ``async_conduct_interview``
        :returns: a Result tagged with ``interview_hash = hash(interview)``.
        """
        # the model buckets are used to track usage rates
        model_buckets = self.bucket_collection[interview.model]

        # get the results of the interview
        answer, valid_results = await interview.async_conduct_interview(
            model_buckets=model_buckets,
            stop_on_exception=stop_on_exception,
            sidecar_model=sidecar_model,
            raise_validation_errors=raise_validation_errors,
        )

        # Index results by question name; a later duplicate name overwrites an earlier one.
        question_results = {r.question_name: r for r in valid_results}
        answer_key_names = list(question_results)

        generated_tokens_dict = {
            k + "_generated_tokens": question_results[k].generated_tokens
            for k in answer_key_names
        }
        comments_dict = {
            k + "_comment": question_results[k].comment for k in answer_key_names
        }

        # we should have a valid result for each question
        answer_dict = {k: answer[k] for k in answer_key_names}
        assert len(valid_results) == len(answer_key_names)

        # TODO: move this down into Interview
        # Re-assigning an existing key keeps its original position, so the
        # ordering and values match iterating answer_key_names.
        prompt_dictionary = {}
        for result in valid_results:
            name = result.question_name
            prompt_dictionary[name + "_user_prompt"] = result.prompts["user_prompt"]
            prompt_dictionary[name + "_system_prompt"] = result.prompts["system_prompt"]

        raw_model_results_dictionary = {}
        cache_used_dictionary = {}
        for result in valid_results:
            name = result.question_name
            raw_model_results_dictionary[name + "_raw_model_response"] = (
                result.raw_model_response
            )
            raw_model_results_dictionary[name + "_cost"] = result.cost
            # "NA" when the cost is unknown, a string, or zero (avoids division by zero).
            one_usd_buys = (
                "NA"
                if isinstance(result.cost, str)
                or result.cost == 0
                or result.cost is None
                else 1.0 / result.cost
            )
            raw_model_results_dictionary[name + "_one_usd_buys"] = one_usd_buys
            cache_used_dictionary[name] = result.cache_used

        interview_result = Result(
            agent=interview.agent,
            scenario=interview.scenario,
            model=interview.model,
            iteration=interview.iteration,
            answer=answer_dict,
            prompt=prompt_dictionary,
            raw_model_response=raw_model_results_dictionary,
            survey=interview.survey,
            generated_tokens=generated_tokens_dict,
            comments_dict=comments_dict,
            cache_used_dict=cache_used_dictionary,
        )
        # Used by process_results to restore the original interview order.
        interview_result.interview_hash = hash(interview)

        return interview_result

    @property
    def elapsed_time(self):
        """Seconds since ``run`` recorded ``self.start_time`` (monotonic clock)."""
        return time.monotonic() - self.start_time

    def process_results(
        self, raw_results: Results, cache: Cache, print_exceptions: bool
    ):
        """Put raw results back in interview order and wrap them in a Results object.

        :param raw_results: Result objects (typically in completion order), each
            carrying the ``interview_hash`` set by ``_build_interview_task``.
        :param cache: cache attached to the returned Results.
        :param print_exceptions: when True and the results have unfixed
            exceptions, the following happens:
            - a summary is printed;
            - the HTML exception report is opened;
            - the report is pushed when the user's Coop settings enable
              ``remote_logging``.
        :returns: the assembled Results, with ``bucket_collection`` attached.
        """
        # Rank of each distinct interview hash in first-seen order. This gives
        # the same ordering as list_of_hashes.index(h), but each lookup is O(1)
        # instead of O(n), so sorting is no longer quadratic.
        hash_rank = {
            h: rank
            for rank, h in enumerate(
                dict.fromkeys(hash(interview) for interview in self.total_interviews)
            )
        }

        task_history = TaskHistory(self.total_interviews, include_traceback=False)

        results = Results(
            survey=self.jobs.survey,
            data=sorted(raw_results, key=lambda r: hash_rank[r.interview_hash]),
            task_history=task_history,
            cache=cache,
        )
        results.bucket_collection = self.bucket_collection

        if results.has_unfixed_exceptions and print_exceptions:
            from edsl.scenarios.FileStore import HTMLFileStore
            from edsl.coop.coop import Coop

            failed_indices = results.task_history.indices
            msg = f"Exceptions were raised in {len(failed_indices)} out of {len(self.total_interviews)} interviews.\n"

            if len(failed_indices) > 5:
                msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"

            print(msg)
            # Write the HTML exception report, open it, and get its location.
            filepath = results.task_history.html(
                cta="Open report to see details.",
                open_in_browser=True,
                return_link=True,
            )

            # Best effort: any failure reading Coop settings disables remote logging.
            try:
                coop = Coop()
                user_edsl_settings = coop.edsl_settings
                remote_logging = user_edsl_settings["remote_logging"]
            except Exception as e:
                print(e)
                remote_logging = False

            if remote_logging:
                filestore = HTMLFileStore(filepath)
                coop_details = filestore.push(description="Error report")
                print(coop_details)

            print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")

        return results

    @jupyter_nb_handler
    async def run(
        self,
        cache: Union[Cache, False, None],
        n: int = 1,
        stop_on_exception: bool = False,
        progress_bar: bool = False,
        sidecar_model: Optional[LanguageModel] = None,
        jobs_runner_status: Optional[Type[JobsRunnerStatusBase]] = None,
        print_exceptions: bool = True,
        raise_validation_errors: bool = False,
    ) -> "Coroutine":
        """Runs a collection of interviews, handling both async and sync contexts.

        ``cache`` is used as a context manager for the whole run.

        A KeyboardInterrupt stops the run and returns the results collected so
        far. Any other exception is handled as follows:
        - re-raised when ``stop_on_exception`` is True;
        - otherwise a warning is issued and the partial results are returned.

        :param jobs_runner_status: status class to instantiate; defaults to
            JobsRunnerStatus.
        :param progress_bar: run the status object's ``update_progress`` in a
            background thread; this requires an Expected Parrot API key.
        """
        self.results = []
        self.start_time = time.monotonic()
        self.completed = False
        self.cache = cache
        self.sidecar_model = sidecar_model

        from edsl.coop import Coop

        coop = Coop()
        endpoint_url = coop.get_progress_bar_url()

        status_cls = (
            jobs_runner_status if jobs_runner_status is not None else JobsRunnerStatus
        )
        self.jobs_runner_status = status_cls(self, n=n, endpoint_url=endpoint_url)

        stop_event = threading.Event()

        async def consume_results(cache):
            """Drain run_async_generator into self.results, then mark completion."""
            async for result in self.run_async_generator(
                n=n,
                stop_on_exception=stop_on_exception,
                cache=cache,
                sidecar_model=sidecar_model,
                raise_validation_errors=raise_validation_errors,
            ):
                self.results.append(result)
            self.completed = True

        # Initialized up front so cleanup never depends on re-checking the API
        # key (which previously could leave this name unbound).
        progress_thread = None
        if progress_bar and self.jobs_runner_status.has_ep_api_key():
            self.jobs_runner_status.setup()
            progress_thread = threading.Thread(
                target=self.jobs_runner_status.update_progress, args=(stop_event,)
            )
            progress_thread.start()
        elif progress_bar:
            warnings.warn(
                "You need an Expected Parrot API key to view job progress bars."
            )

        try:
            with cache as c:
                await consume_results(cache=c)
        except KeyboardInterrupt:
            print("Keyboard interrupt received. Stopping gracefully...")
        except Exception as e:
            if stop_on_exception:
                raise  # propagates after the finally block below has run
            # Do not hide the failure: processing stopped early and the returned
            # results are partial.
            warnings.warn(
                f"Interview processing stopped early because of an exception: {e!r}"
            )
        finally:
            # Tell the progress thread to stop, then wait for it to finish.
            stop_event.set()
            if progress_thread is not None:
                progress_thread.join()

        return self.process_results(
            raw_results=self.results, cache=cache, print_exceptions=print_exceptions
        )
1
+ from __future__ import annotations
2
+ import time
3
+ import asyncio
4
+ import threading
5
+ import warnings
6
+ from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator, Type
7
+ from uuid import UUID
8
+ from collections import UserList
9
+
10
+ from edsl.results.Results import Results
11
+ from edsl.jobs.interviews.Interview import Interview
12
+ from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus, JobsRunnerStatusBase
13
+
14
+ from edsl.jobs.tasks.TaskHistory import TaskHistory
15
+ from edsl.jobs.buckets.BucketCollection import BucketCollection
16
+ from edsl.utilities.decorators import jupyter_nb_handler
17
+ from edsl.data.Cache import Cache
18
+ from edsl.results.Result import Result
19
+ from edsl.results.Results import Results
20
+ from edsl.language_models.LanguageModel import LanguageModel
21
+ from edsl.data.Cache import Cache
22
+
23
+
24
class StatusTracker(UserList):
    """A list of completed items that also records how many tasks are expected."""

    def __init__(self, total_tasks: int):
        # Remember the expected total, then let UserList create an empty ``self.data``.
        self.total_tasks = total_tasks
        super().__init__()

    def current_status(self):
        """Print an in-place "Completed: X of Y" progress line; returns ``None``."""
        finished = len(self.data)
        print(f"Completed: {finished} of {self.total_tasks}", end="\r")
31
+
32
+
33
+ class JobsRunnerAsyncio:
34
+ """A class for running a collection of interviews asynchronously.
35
+
36
+ It gets instaniated from a Jobs object.
37
+ The Jobs object is a collection of interviews that are to be run.
38
+ """
39
+
40
+ def __init__(self, jobs: "Jobs"):
41
+ self.jobs = jobs
42
+ self.interviews: List["Interview"] = jobs.interviews()
43
+ self.bucket_collection: "BucketCollection" = jobs.bucket_collection
44
+ self.total_interviews: List["Interview"] = []
45
+ self._initialized = threading.Event()
46
+
47
+ async def run_async_generator(
48
+ self,
49
+ cache: Cache,
50
+ n: int = 1,
51
+ stop_on_exception: bool = False,
52
+ sidecar_model: Optional[LanguageModel] = None,
53
+ total_interviews: Optional[List["Interview"]] = None,
54
+ raise_validation_errors: bool = False,
55
+ ) -> AsyncGenerator["Result", None]:
56
+ """Creates the tasks, runs them asynchronously, and returns the results as a Results object.
57
+
58
+ Completed tasks are yielded as they are completed.
59
+
60
+ :param n: how many times to run each interview
61
+ :param stop_on_exception: Whether to stop the interview if an exception is raised
62
+ :param sidecar_model: a language model to use in addition to the interview's model
63
+ :param total_interviews: A list of interviews to run can be provided instead.
64
+ :param raise_validation_errors: Whether to raise validation errors
65
+ """
66
+ tasks = []
67
+ if total_interviews: # was already passed in total interviews
68
+ self.total_interviews = total_interviews
69
+ else:
70
+ self.total_interviews = list(
71
+ self._populate_total_interviews(n=n)
72
+ ) # Populate self.total_interviews before creating tasks
73
+
74
+ self._initialized.set() # Signal that we're ready
75
+
76
+ for interview in self.total_interviews:
77
+ interviewing_task = self._build_interview_task(
78
+ interview=interview,
79
+ stop_on_exception=stop_on_exception,
80
+ sidecar_model=sidecar_model,
81
+ raise_validation_errors=raise_validation_errors,
82
+ )
83
+ tasks.append(asyncio.create_task(interviewing_task))
84
+
85
+ for task in asyncio.as_completed(tasks):
86
+ result = await task
87
+ self.jobs_runner_status.add_completed_interview(result)
88
+ yield result
89
+
90
+ def _populate_total_interviews(
91
+ self, n: int = 1
92
+ ) -> Generator["Interview", None, None]:
93
+ """Populates self.total_interviews with n copies of each interview.
94
+
95
+ :param n: how many times to run each interview.
96
+ """
97
+ for interview in self.interviews:
98
+ for iteration in range(n):
99
+ if iteration > 0:
100
+ yield interview.duplicate(iteration=iteration, cache=self.cache)
101
+ else:
102
+ interview.cache = self.cache
103
+ yield interview
104
+
105
+ async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
106
+ """Used for some other modules that have a non-standard way of running interviews."""
107
+ self.jobs_runner_status = JobsRunnerStatus(self, n=n)
108
+ self.cache = Cache() if cache is None else cache
109
+ data = []
110
+ async for result in self.run_async_generator(cache=self.cache, n=n):
111
+ data.append(result)
112
+ return Results(survey=self.jobs.survey, data=data)
113
+
114
+ def simple_run(self):
115
+ data = asyncio.run(self.run_async())
116
+ return Results(survey=self.jobs.survey, data=data)
117
+
118
async def _build_interview_task(
    self,
    *,
    interview: Interview,
    stop_on_exception: bool = False,
    sidecar_model: Optional["LanguageModel"] = None,
    raise_validation_errors: bool = False,
) -> "Result":
    """Conduct a single interview and package its answers into a ``Result``.

    :param interview: the interview to conduct
    :param stop_on_exception: forwarded to the interview; stops it if an exception is raised
    :param sidecar_model: a language model to use in addition to the interview's model
    :param raise_validation_errors: forwarded to ``async_conduct_interview``
    :returns: a ``Result`` whose ``interview_hash`` is set to ``hash(interview)``
    """
    # Buckets for this interview's model; used to track usage rates.
    buckets = self.bucket_collection[interview.model]

    answer, valid_results = await interview.async_conduct_interview(
        model_buckets=buckets,
        stop_on_exception=stop_on_exception,
        sidecar_model=sidecar_model,
        raise_validation_errors=raise_validation_errors,
    )

    # Per-question result objects keyed by question name (a later entry with
    # the same name would overwrite an earlier one).
    by_question = {item.question_name: item for item in valid_results}
    names = list(by_question)

    generated_tokens_dict = {
        name + "_generated_tokens": by_question[name].generated_tokens
        for name in names
    }
    comments_dict = {name + "_comment": by_question[name].comment for name in names}

    # There must be exactly one valid result per question.
    answer_dict = {name: answer[name] for name in names}
    assert len(valid_results) == len(names)

    # TODO: move this down into Interview
    prompts_by_question = {
        item.question_name: {
            "user_prompt": item.prompts["user_prompt"],
            "system_prompt": item.prompts["system_prompt"],
        }
        for item in valid_results
    }

    # Flatten into "<question>_user_prompt" / "<question>_system_prompt" keys.
    prompt_dictionary = {
        name + "_" + kind: prompts_by_question[name][kind]
        for name in names
        for kind in ("user_prompt", "system_prompt")
    }

    raw_model_results_dictionary = {}
    cache_used_dictionary = {}
    for item in valid_results:
        qname = item.question_name
        cost = item.cost
        raw_model_results_dictionary[qname + "_raw_model_response"] = (
            item.raw_model_response
        )
        raw_model_results_dictionary[qname + "_cost"] = cost
        # How many calls one US dollar buys; "NA" when the cost is unknown or zero.
        if isinstance(cost, str) or cost == 0 or cost is None:
            usd_buys = "NA"
        else:
            usd_buys = 1.0 / cost
        raw_model_results_dictionary[qname + "_one_usd_buys"] = usd_buys
        cache_used_dictionary[qname] = item.cache_used

    packaged = Result(
        agent=interview.agent,
        scenario=interview.scenario,
        model=interview.model,
        iteration=interview.iteration,
        answer=answer_dict,
        prompt=prompt_dictionary,
        raw_model_response=raw_model_results_dictionary,
        survey=interview.survey,
        generated_tokens=generated_tokens_dict,
        comments_dict=comments_dict,
        cache_used_dict=cache_used_dictionary,
    )
    # Lets process_results restore the original interview order.
    packaged.interview_hash = hash(interview)

    return packaged
213
+
214
@property
def elapsed_time(self):
    """Seconds since ``self.start_time``, measured on the monotonic clock."""
    now = time.monotonic()
    return now - self.start_time
217
+
218
def process_results(
    self, raw_results: Results, cache: Cache, print_exceptions: bool
):
    """Assemble the final ``Results`` for this job and report exceptions.

    Raw results are re-sorted to follow the order of
    ``self.total_interviews``, matched through each result's
    ``interview_hash`` attribute.

    :param raw_results: the ``Result`` objects collected from the interviews.
    :param cache: the cache used for the run; attached to the returned ``Results``.
    :param print_exceptions: when true and the results contain unfixed
        exceptions, print a summary, open an HTML exception report, and, if the
        user's Coop settings enable remote logging, push that report.
    :returns: the assembled ``Results`` object.
    """
    # Rank each distinct interview hash by first appearance. This yields the
    # same ordering as the previous list.index() lookup but in O(1) per result
    # instead of an O(n) scan, so sorting is no longer quadratic.
    hash_rank = {}
    for interview in self.total_interviews:
        hash_rank.setdefault(hash(interview), len(hash_rank))

    task_history = TaskHistory(self.total_interviews, include_traceback=False)

    results = Results(
        survey=self.jobs.survey,
        data=sorted(raw_results, key=lambda r: hash_rank[r.interview_hash]),
        task_history=task_history,
        cache=cache,
    )
    results.bucket_collection = self.bucket_collection

    if results.has_unfixed_exceptions and print_exceptions:
        from edsl.scenarios.FileStore import HTMLFileStore
        from edsl.coop.coop import Coop

        msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"

        if len(results.task_history.indices) > 5:
            msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"

        print(msg)
        # Render the exception report (opened in a browser); the returned
        # value is used below as the report's file path.
        filepath = results.task_history.html(
            cta="Open report to see details.",
            open_in_browser=True,
            return_link=True,
        )

        # Remote logging is opt-in via the user's Coop settings; any failure
        # to read those settings is printed and treated as "disabled".
        try:
            coop = Coop()
            user_edsl_settings = coop.edsl_settings
            remote_logging = user_edsl_settings["remote_logging"]
        except Exception as e:
            print(e)
            remote_logging = False

        if remote_logging:
            filestore = HTMLFileStore(filepath)
            coop_details = filestore.push(description="Error report")
            print(coop_details)

        print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")

    return results
273
+
274
@jupyter_nb_handler
async def run(
    self,
    cache: Union[Cache, bool, None],
    n: int = 1,
    stop_on_exception: bool = False,
    progress_bar: bool = False,
    sidecar_model: Optional[LanguageModel] = None,
    jobs_runner_status: Optional[Type[JobsRunnerStatusBase]] = None,
    job_uuid: Optional[UUID] = None,
    print_exceptions: bool = True,
    raise_validation_errors: bool = False,
) -> "Coroutine":
    """Runs a collection of interviews, handling both async and sync contexts.

    :param cache: cache used as a context manager while interviews run.
        (Annotated ``bool`` rather than the invalid typing form ``False``.)
    :param n: forwarded to the status tracker and ``run_async_generator``.
    :param stop_on_exception: if true, an exception escaping the run is
        re-raised after cleanup; otherwise a warning is emitted and the
        results collected so far are processed.
    :param progress_bar: show a progress bar (needs an Expected Parrot API key).
    :param sidecar_model: extra language model forwarded to the interviews.
    :param jobs_runner_status: optional status class; defaults to ``JobsRunnerStatus``.
    :param job_uuid: forwarded to the status tracker.
    :param print_exceptions: forwarded to ``process_results``.
    :param raise_validation_errors: forwarded to the interviews.
    :returns: the value of ``self.process_results`` for the collected results.
    """
    self.results = []
    self.start_time = time.monotonic()
    self.completed = False
    self.cache = cache
    self.sidecar_model = sidecar_model

    from edsl.coop import Coop

    coop = Coop()
    endpoint_url = coop.get_progress_bar_url()

    status_cls = (
        jobs_runner_status if jobs_runner_status is not None else JobsRunnerStatus
    )
    self.jobs_runner_status = status_cls(
        self, n=n, endpoint_url=endpoint_url, job_uuid=job_uuid
    )

    stop_event = threading.Event()

    async def collect_results(cache):
        """Append every result from the interview generator to ``self.results``."""
        async for result in self.run_async_generator(
            n=n,
            stop_on_exception=stop_on_exception,
            cache=cache,
            sidecar_model=sidecar_model,
            raise_validation_errors=raise_validation_errors,
        ):
            self.results.append(result)
        self.completed = True

    # Bound up front so the cleanup in `finally` joins exactly the thread that
    # was started (previously it re-checked the API key and could hit an
    # unbound name).
    progress_thread = None
    if progress_bar and self.jobs_runner_status.has_ep_api_key():
        self.jobs_runner_status.setup()
        progress_thread = threading.Thread(
            target=self.jobs_runner_status.update_progress, args=(stop_event,)
        )
        progress_thread.start()
    elif progress_bar:
        warnings.warn(
            "You need an Expected Parrot API key to view job progress bars."
        )

    exception_to_raise = None
    try:
        with cache as c:
            await collect_results(cache=c)
    except KeyboardInterrupt:
        print("Keyboard interrupt received. Stopping gracefully...")
    except Exception as e:
        if stop_on_exception:
            exception_to_raise = e
        else:
            # Don't lose the error silently: partial results are still returned.
            warnings.warn(
                f"Job run was interrupted by an exception; "
                f"returning results collected so far: {e!r}"
            )
    finally:
        # Signals the progress-bar thread (if any) to stop.
        stop_event.set()
        if progress_thread is not None:
            progress_thread.join()

    if exception_to_raise is not None:
        raise exception_to_raise

    return self.process_results(
        raw_results=self.results, cache=cache, print_exceptions=print_exceptions
    )