edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334) hide show
  1. edsl/Base.py +332 -385
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +49 -57
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +867 -1079
  7. edsl/agents/AgentList.py +413 -551
  8. edsl/agents/Invigilator.py +233 -285
  9. edsl/agents/InvigilatorBase.py +270 -254
  10. edsl/agents/PromptConstructor.py +354 -252
  11. edsl/agents/__init__.py +3 -2
  12. edsl/agents/descriptors.py +99 -99
  13. edsl/agents/prompt_helpers.py +129 -129
  14. edsl/auto/AutoStudy.py +117 -117
  15. edsl/auto/StageBase.py +230 -230
  16. edsl/auto/StageGenerateSurvey.py +178 -178
  17. edsl/auto/StageLabelQuestions.py +125 -125
  18. edsl/auto/StagePersona.py +61 -61
  19. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  20. edsl/auto/StagePersonaDimensionValues.py +74 -74
  21. edsl/auto/StagePersonaDimensions.py +69 -69
  22. edsl/auto/StageQuestions.py +73 -73
  23. edsl/auto/SurveyCreatorPipeline.py +21 -21
  24. edsl/auto/utilities.py +224 -224
  25. edsl/base/Base.py +279 -279
  26. edsl/config.py +157 -177
  27. edsl/conversation/Conversation.py +290 -290
  28. edsl/conversation/car_buying.py +58 -59
  29. edsl/conversation/chips.py +95 -95
  30. edsl/conversation/mug_negotiation.py +81 -81
  31. edsl/conversation/next_speaker_utilities.py +93 -93
  32. edsl/coop/PriceFetcher.py +54 -54
  33. edsl/coop/__init__.py +2 -2
  34. edsl/coop/coop.py +1028 -1090
  35. edsl/coop/utils.py +131 -131
  36. edsl/data/Cache.py +555 -562
  37. edsl/data/CacheEntry.py +233 -230
  38. edsl/data/CacheHandler.py +149 -170
  39. edsl/data/RemoteCacheSync.py +78 -78
  40. edsl/data/SQLiteDict.py +292 -292
  41. edsl/data/__init__.py +4 -5
  42. edsl/data/orm.py +10 -10
  43. edsl/data_transfer_models.py +73 -74
  44. edsl/enums.py +175 -195
  45. edsl/exceptions/BaseException.py +21 -21
  46. edsl/exceptions/__init__.py +54 -54
  47. edsl/exceptions/agents.py +42 -54
  48. edsl/exceptions/cache.py +5 -5
  49. edsl/exceptions/configuration.py +16 -16
  50. edsl/exceptions/coop.py +10 -10
  51. edsl/exceptions/data.py +14 -14
  52. edsl/exceptions/general.py +34 -34
  53. edsl/exceptions/jobs.py +33 -33
  54. edsl/exceptions/language_models.py +63 -63
  55. edsl/exceptions/prompts.py +15 -15
  56. edsl/exceptions/questions.py +91 -109
  57. edsl/exceptions/results.py +29 -29
  58. edsl/exceptions/scenarios.py +22 -29
  59. edsl/exceptions/surveys.py +37 -37
  60. edsl/inference_services/AnthropicService.py +87 -84
  61. edsl/inference_services/AwsBedrock.py +120 -118
  62. edsl/inference_services/AzureAI.py +217 -215
  63. edsl/inference_services/DeepInfraService.py +18 -18
  64. edsl/inference_services/GoogleService.py +148 -139
  65. edsl/inference_services/GroqService.py +20 -20
  66. edsl/inference_services/InferenceServiceABC.py +147 -80
  67. edsl/inference_services/InferenceServicesCollection.py +97 -122
  68. edsl/inference_services/MistralAIService.py +123 -120
  69. edsl/inference_services/OllamaService.py +18 -18
  70. edsl/inference_services/OpenAIService.py +224 -221
  71. edsl/inference_services/PerplexityService.py +163 -160
  72. edsl/inference_services/TestService.py +89 -92
  73. edsl/inference_services/TogetherAIService.py +170 -170
  74. edsl/inference_services/models_available_cache.py +118 -118
  75. edsl/inference_services/rate_limits_cache.py +25 -25
  76. edsl/inference_services/registry.py +41 -41
  77. edsl/inference_services/write_available.py +10 -10
  78. edsl/jobs/Answers.py +56 -43
  79. edsl/jobs/Jobs.py +898 -757
  80. edsl/jobs/JobsChecks.py +147 -172
  81. edsl/jobs/JobsPrompts.py +268 -270
  82. edsl/jobs/JobsRemoteInferenceHandler.py +239 -287
  83. edsl/jobs/__init__.py +1 -1
  84. edsl/jobs/buckets/BucketCollection.py +63 -104
  85. edsl/jobs/buckets/ModelBuckets.py +65 -65
  86. edsl/jobs/buckets/TokenBucket.py +251 -283
  87. edsl/jobs/interviews/Interview.py +661 -358
  88. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  89. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  90. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  91. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  92. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  93. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  94. edsl/jobs/interviews/ReportErrors.py +66 -66
  95. edsl/jobs/interviews/interview_status_enum.py +9 -9
  96. edsl/jobs/runners/JobsRunnerAsyncio.py +466 -421
  97. edsl/jobs/runners/JobsRunnerStatus.py +330 -330
  98. edsl/jobs/tasks/QuestionTaskCreator.py +242 -244
  99. edsl/jobs/tasks/TaskCreators.py +64 -64
  100. edsl/jobs/tasks/TaskHistory.py +450 -449
  101. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  102. edsl/jobs/tasks/task_status_enum.py +163 -161
  103. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  104. edsl/jobs/tokens/TokenUsage.py +34 -34
  105. edsl/language_models/KeyLookup.py +30 -0
  106. edsl/language_models/LanguageModel.py +668 -571
  107. edsl/language_models/ModelList.py +155 -153
  108. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  109. edsl/language_models/__init__.py +3 -2
  110. edsl/language_models/fake_openai_call.py +15 -15
  111. edsl/language_models/fake_openai_service.py +61 -61
  112. edsl/language_models/registry.py +190 -180
  113. edsl/language_models/repair.py +156 -156
  114. edsl/language_models/unused/ReplicateBase.py +83 -0
  115. edsl/language_models/utilities.py +64 -65
  116. edsl/notebooks/Notebook.py +258 -263
  117. edsl/notebooks/__init__.py +1 -1
  118. edsl/prompts/Prompt.py +362 -352
  119. edsl/prompts/__init__.py +2 -2
  120. edsl/questions/AnswerValidatorMixin.py +289 -334
  121. edsl/questions/QuestionBase.py +664 -509
  122. edsl/questions/QuestionBaseGenMixin.py +161 -165
  123. edsl/questions/QuestionBasePromptsMixin.py +217 -221
  124. edsl/questions/QuestionBudget.py +227 -227
  125. edsl/questions/QuestionCheckBox.py +359 -359
  126. edsl/questions/QuestionExtract.py +182 -182
  127. edsl/questions/QuestionFreeText.py +114 -113
  128. edsl/questions/QuestionFunctional.py +166 -166
  129. edsl/questions/QuestionList.py +231 -229
  130. edsl/questions/QuestionMultipleChoice.py +286 -330
  131. edsl/questions/QuestionNumerical.py +153 -151
  132. edsl/questions/QuestionRank.py +324 -314
  133. edsl/questions/Quick.py +41 -41
  134. edsl/questions/RegisterQuestionsMeta.py +71 -71
  135. edsl/questions/ResponseValidatorABC.py +174 -200
  136. edsl/questions/SimpleAskMixin.py +73 -74
  137. edsl/questions/__init__.py +26 -27
  138. edsl/questions/compose_questions.py +98 -98
  139. edsl/questions/decorators.py +21 -21
  140. edsl/questions/derived/QuestionLikertFive.py +76 -76
  141. edsl/questions/derived/QuestionLinearScale.py +87 -90
  142. edsl/questions/derived/QuestionTopK.py +93 -93
  143. edsl/questions/derived/QuestionYesNo.py +82 -82
  144. edsl/questions/descriptors.py +413 -427
  145. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  146. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  147. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  148. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  149. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  150. edsl/questions/prompt_templates/question_list.jinja +17 -17
  151. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  152. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  153. edsl/questions/question_registry.py +177 -177
  154. edsl/questions/settings.py +12 -12
  155. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  156. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  157. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  158. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  159. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  160. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  161. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  162. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  163. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  164. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  165. edsl/questions/templates/list/question_presentation.jinja +5 -5
  166. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  167. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  168. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  169. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  170. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  171. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  172. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  173. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  174. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  175. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  176. edsl/results/CSSParameterizer.py +108 -108
  177. edsl/results/Dataset.py +424 -587
  178. edsl/results/DatasetExportMixin.py +731 -653
  179. edsl/results/DatasetTree.py +275 -295
  180. edsl/results/Result.py +465 -451
  181. edsl/results/Results.py +1165 -1172
  182. edsl/results/ResultsDBMixin.py +238 -0
  183. edsl/results/ResultsExportMixin.py +43 -45
  184. edsl/results/ResultsFetchMixin.py +33 -33
  185. edsl/results/ResultsGGMixin.py +121 -121
  186. edsl/results/ResultsToolsMixin.py +98 -98
  187. edsl/results/Selector.py +135 -145
  188. edsl/results/TableDisplay.py +198 -125
  189. edsl/results/__init__.py +2 -2
  190. edsl/results/table_display.css +77 -77
  191. edsl/results/tree_explore.py +115 -115
  192. edsl/scenarios/FileStore.py +632 -511
  193. edsl/scenarios/Scenario.py +601 -498
  194. edsl/scenarios/ScenarioHtmlMixin.py +64 -65
  195. edsl/scenarios/ScenarioJoin.py +127 -131
  196. edsl/scenarios/ScenarioList.py +1287 -1430
  197. edsl/scenarios/ScenarioListExportMixin.py +52 -45
  198. edsl/scenarios/ScenarioListPdfMixin.py +261 -239
  199. edsl/scenarios/__init__.py +4 -3
  200. edsl/shared.py +1 -1
  201. edsl/study/ObjectEntry.py +173 -173
  202. edsl/study/ProofOfWork.py +113 -113
  203. edsl/study/SnapShot.py +80 -80
  204. edsl/study/Study.py +528 -521
  205. edsl/study/__init__.py +4 -4
  206. edsl/surveys/DAG.py +148 -148
  207. edsl/surveys/Memory.py +31 -31
  208. edsl/surveys/MemoryPlan.py +244 -244
  209. edsl/surveys/Rule.py +326 -327
  210. edsl/surveys/RuleCollection.py +387 -385
  211. edsl/surveys/Survey.py +1801 -1229
  212. edsl/surveys/SurveyCSS.py +261 -273
  213. edsl/surveys/SurveyExportMixin.py +259 -259
  214. edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +179 -181
  215. edsl/surveys/SurveyQualtricsImport.py +284 -284
  216. edsl/surveys/__init__.py +3 -5
  217. edsl/surveys/base.py +53 -53
  218. edsl/surveys/descriptors.py +56 -60
  219. edsl/surveys/instructions/ChangeInstruction.py +49 -48
  220. edsl/surveys/instructions/Instruction.py +65 -56
  221. edsl/surveys/instructions/InstructionCollection.py +77 -82
  222. edsl/templates/error_reporting/base.html +23 -23
  223. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  224. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  225. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  226. edsl/templates/error_reporting/interview_details.html +115 -115
  227. edsl/templates/error_reporting/interviews.html +19 -19
  228. edsl/templates/error_reporting/overview.html +4 -4
  229. edsl/templates/error_reporting/performance_plot.html +1 -1
  230. edsl/templates/error_reporting/report.css +73 -73
  231. edsl/templates/error_reporting/report.html +117 -117
  232. edsl/templates/error_reporting/report.js +25 -25
  233. edsl/tools/__init__.py +1 -1
  234. edsl/tools/clusters.py +192 -192
  235. edsl/tools/embeddings.py +27 -27
  236. edsl/tools/embeddings_plotting.py +118 -118
  237. edsl/tools/plotting.py +112 -112
  238. edsl/tools/summarize.py +18 -18
  239. edsl/utilities/SystemInfo.py +28 -28
  240. edsl/utilities/__init__.py +22 -22
  241. edsl/utilities/ast_utilities.py +25 -25
  242. edsl/utilities/data/Registry.py +6 -6
  243. edsl/utilities/data/__init__.py +1 -1
  244. edsl/utilities/data/scooter_results.json +1 -1
  245. edsl/utilities/decorators.py +77 -77
  246. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  247. edsl/utilities/interface.py +627 -627
  248. edsl/utilities/naming_utilities.py +263 -263
  249. edsl/utilities/repair_functions.py +28 -28
  250. edsl/utilities/restricted_python.py +70 -70
  251. edsl/utilities/utilities.py +424 -436
  252. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev3.dist-info}/LICENSE +21 -21
  253. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev3.dist-info}/METADATA +10 -12
  254. edsl-0.1.39.dev3.dist-info/RECORD +277 -0
  255. edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
  256. edsl/agents/QuestionOptionProcessor.py +0 -172
  257. edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
  258. edsl/coop/CoopFunctionsMixin.py +0 -15
  259. edsl/coop/ExpectedParrotKeyHandler.py +0 -125
  260. edsl/exceptions/inference_services.py +0 -5
  261. edsl/inference_services/AvailableModelCacheHandler.py +0 -184
  262. edsl/inference_services/AvailableModelFetcher.py +0 -209
  263. edsl/inference_services/ServiceAvailability.py +0 -135
  264. edsl/inference_services/data_structures.py +0 -62
  265. edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -188
  266. edsl/jobs/FetchInvigilator.py +0 -40
  267. edsl/jobs/InterviewTaskManager.py +0 -98
  268. edsl/jobs/InterviewsConstructor.py +0 -48
  269. edsl/jobs/JobsComponentConstructor.py +0 -189
  270. edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
  271. edsl/jobs/RequestTokenEstimator.py +0 -30
  272. edsl/jobs/buckets/TokenBucketAPI.py +0 -211
  273. edsl/jobs/buckets/TokenBucketClient.py +0 -191
  274. edsl/jobs/decorators.py +0 -35
  275. edsl/jobs/jobs_status_enums.py +0 -9
  276. edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
  277. edsl/language_models/ComputeCost.py +0 -63
  278. edsl/language_models/PriceManager.py +0 -127
  279. edsl/language_models/RawResponseHandler.py +0 -106
  280. edsl/language_models/ServiceDataSources.py +0 -0
  281. edsl/language_models/key_management/KeyLookup.py +0 -63
  282. edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
  283. edsl/language_models/key_management/KeyLookupCollection.py +0 -38
  284. edsl/language_models/key_management/__init__.py +0 -0
  285. edsl/language_models/key_management/models.py +0 -131
  286. edsl/notebooks/NotebookToLaTeX.py +0 -142
  287. edsl/questions/ExceptionExplainer.py +0 -77
  288. edsl/questions/HTMLQuestion.py +0 -103
  289. edsl/questions/LoopProcessor.py +0 -149
  290. edsl/questions/QuestionMatrix.py +0 -265
  291. edsl/questions/ResponseValidatorFactory.py +0 -28
  292. edsl/questions/templates/matrix/__init__.py +0 -1
  293. edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
  294. edsl/questions/templates/matrix/question_presentation.jinja +0 -20
  295. edsl/results/MarkdownToDocx.py +0 -122
  296. edsl/results/MarkdownToPDF.py +0 -111
  297. edsl/results/TextEditor.py +0 -50
  298. edsl/results/smart_objects.py +0 -96
  299. edsl/results/table_data_class.py +0 -12
  300. edsl/results/table_renderers.py +0 -118
  301. edsl/scenarios/ConstructDownloadLink.py +0 -109
  302. edsl/scenarios/DirectoryScanner.py +0 -96
  303. edsl/scenarios/DocumentChunker.py +0 -102
  304. edsl/scenarios/DocxScenario.py +0 -16
  305. edsl/scenarios/PdfExtractor.py +0 -40
  306. edsl/scenarios/ScenarioSelector.py +0 -156
  307. edsl/scenarios/file_methods.py +0 -85
  308. edsl/scenarios/handlers/__init__.py +0 -13
  309. edsl/scenarios/handlers/csv.py +0 -38
  310. edsl/scenarios/handlers/docx.py +0 -76
  311. edsl/scenarios/handlers/html.py +0 -37
  312. edsl/scenarios/handlers/json.py +0 -111
  313. edsl/scenarios/handlers/latex.py +0 -5
  314. edsl/scenarios/handlers/md.py +0 -51
  315. edsl/scenarios/handlers/pdf.py +0 -68
  316. edsl/scenarios/handlers/png.py +0 -39
  317. edsl/scenarios/handlers/pptx.py +0 -105
  318. edsl/scenarios/handlers/py.py +0 -294
  319. edsl/scenarios/handlers/sql.py +0 -313
  320. edsl/scenarios/handlers/sqlite.py +0 -149
  321. edsl/scenarios/handlers/txt.py +0 -33
  322. edsl/surveys/ConstructDAG.py +0 -92
  323. edsl/surveys/EditSurvey.py +0 -221
  324. edsl/surveys/InstructionHandler.py +0 -100
  325. edsl/surveys/MemoryManagement.py +0 -72
  326. edsl/surveys/RuleManager.py +0 -172
  327. edsl/surveys/Simulator.py +0 -75
  328. edsl/surveys/SurveyToApp.py +0 -141
  329. edsl/utilities/PrettyList.py +0 -56
  330. edsl/utilities/is_notebook.py +0 -18
  331. edsl/utilities/is_valid_variable_name.py +0 -11
  332. edsl/utilities/remove_edsl_version.py +0 -24
  333. edsl-0.1.39.dev2.dist-info/RECORD +0 -352
  334. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev3.dist-info}/WHEEL +0 -0
@@ -1,421 +1,466 @@
1
- from __future__ import annotations
2
- import time
3
- import asyncio
4
- import threading
5
- import warnings
6
- from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator, Type
7
- from uuid import UUID
8
- from collections import UserList
9
-
10
- from edsl.results.Results import Results
11
- from edsl.jobs.interviews.Interview import Interview
12
- from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus, JobsRunnerStatusBase
13
-
14
- from edsl.jobs.tasks.TaskHistory import TaskHistory
15
- from edsl.jobs.buckets.BucketCollection import BucketCollection
16
- from edsl.utilities.decorators import jupyter_nb_handler
17
- from edsl.data.Cache import Cache
18
- from edsl.results.Result import Result
19
- from edsl.results.Results import Results
20
- from edsl.language_models.LanguageModel import LanguageModel
21
- from edsl.data.Cache import Cache
22
-
23
-
24
- class StatusTracker(UserList):
25
- def __init__(self, total_tasks: int):
26
- self.total_tasks = total_tasks
27
- super().__init__()
28
-
29
- def current_status(self):
30
- return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
31
-
32
-
33
- class JobsRunnerAsyncio:
34
- """A class for running a collection of interviews asynchronously.
35
-
36
- It gets instaniated from a Jobs object.
37
- The Jobs object is a collection of interviews that are to be run.
38
- """
39
-
40
- def __init__(self, jobs: "Jobs", bucket_collection: "BucketCollection"):
41
- self.jobs = jobs
42
- self.interviews: List["Interview"] = jobs.interviews()
43
- self.bucket_collection: "BucketCollection" = bucket_collection
44
-
45
- self.total_interviews: List["Interview"] = []
46
- self._initialized = threading.Event()
47
-
48
- from edsl.config import CONFIG
49
-
50
- self.MAX_CONCURRENT = int(CONFIG.get("EDSL_MAX_CONCURRENT_TASKS"))
51
-
52
- async def run_async_generator(
53
- self,
54
- cache: Cache,
55
- n: int = 1,
56
- stop_on_exception: bool = False,
57
- sidecar_model: Optional[LanguageModel] = None,
58
- total_interviews: Optional[List["Interview"]] = None,
59
- raise_validation_errors: bool = False,
60
- ) -> AsyncGenerator["Result", None]:
61
- """Creates and processes tasks asynchronously, yielding results as they complete.
62
-
63
- Tasks are created and processed in a streaming fashion rather than building the full list upfront.
64
- Results are yielded as soon as they are available.
65
-
66
- :param n: how many times to run each interview
67
- :param stop_on_exception: Whether to stop the interview if an exception is raised
68
- :param sidecar_model: a language model to use in addition to the interview's model
69
- :param total_interviews: A list of interviews to run can be provided instead.
70
- :param raise_validation_errors: Whether to raise validation errors
71
- """
72
- # Initialize interviews iterator
73
- if total_interviews:
74
- interviews_iter = iter(total_interviews)
75
- self.total_interviews = total_interviews
76
- else:
77
- interviews_iter = self._populate_total_interviews(n=n)
78
- self.total_interviews = list(interviews_iter)
79
- interviews_iter = iter(self.total_interviews) # Create fresh iterator
80
-
81
- self._initialized.set() # Signal that we're ready
82
-
83
- # Keep track of active tasks
84
- active_tasks = set()
85
-
86
- try:
87
- while True:
88
- # Add new tasks if we're below max_concurrent and there are more interviews
89
- while len(active_tasks) < self.MAX_CONCURRENT:
90
- try:
91
- interview = next(interviews_iter)
92
- task = asyncio.create_task(
93
- self._build_interview_task(
94
- interview=interview,
95
- stop_on_exception=stop_on_exception,
96
- sidecar_model=sidecar_model,
97
- raise_validation_errors=raise_validation_errors,
98
- )
99
- )
100
- active_tasks.add(task)
101
- # Add callback to remove task from set when done
102
- task.add_done_callback(active_tasks.discard)
103
- except StopIteration:
104
- break
105
-
106
- if not active_tasks:
107
- break
108
-
109
- # Wait for next completed task
110
- done, _ = await asyncio.wait(
111
- active_tasks, return_when=asyncio.FIRST_COMPLETED
112
- )
113
-
114
- # Process completed tasks
115
- for task in done:
116
- try:
117
- result = await task
118
- self.jobs_runner_status.add_completed_interview(result)
119
- yield result
120
- except Exception as e:
121
- if stop_on_exception:
122
- # Cancel remaining tasks
123
- for t in active_tasks:
124
- if not t.done():
125
- t.cancel()
126
- raise
127
- else:
128
- # Log error and continue
129
- # logger.error(f"Task failed with error: {e}")
130
- continue
131
- finally:
132
- # Ensure we cancel any remaining tasks if we exit early
133
- for task in active_tasks:
134
- if not task.done():
135
- task.cancel()
136
-
137
- def _populate_total_interviews(
138
- self, n: int = 1
139
- ) -> Generator["Interview", None, None]:
140
- """Populates self.total_interviews with n copies of each interview.
141
-
142
- :param n: how many times to run each interview.
143
- """
144
- for interview in self.interviews:
145
- for iteration in range(n):
146
- if iteration > 0:
147
- yield interview.duplicate(iteration=iteration, cache=self.cache)
148
- else:
149
- interview.cache = self.cache
150
- yield interview
151
-
152
- async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
153
- """Used for some other modules that have a non-standard way of running interviews."""
154
- self.jobs_runner_status = JobsRunnerStatus(self, n=n)
155
- self.cache = Cache() if cache is None else cache
156
- data = []
157
- async for result in self.run_async_generator(cache=self.cache, n=n):
158
- data.append(result)
159
- return Results(survey=self.jobs.survey, data=data)
160
-
161
- def simple_run(self):
162
- data = asyncio.run(self.run_async())
163
- return Results(survey=self.jobs.survey, data=data)
164
-
165
- async def _build_interview_task(
166
- self,
167
- *,
168
- interview: Interview,
169
- stop_on_exception: bool = False,
170
- sidecar_model: Optional["LanguageModel"] = None,
171
- raise_validation_errors: bool = False,
172
- ) -> "Result":
173
- """Conducts an interview and returns the result.
174
-
175
- :param interview: the interview to conduct
176
- :param stop_on_exception: stops the interview if an exception is raised
177
- :param sidecar_model: a language model to use in addition to the interview's model
178
- """
179
- # the model buckets are used to track usage rates
180
- model_buckets = self.bucket_collection[interview.model]
181
-
182
- # get the results of the interview
183
- answer, valid_results = await interview.async_conduct_interview(
184
- model_buckets=model_buckets,
185
- stop_on_exception=stop_on_exception,
186
- sidecar_model=sidecar_model,
187
- raise_validation_errors=raise_validation_errors,
188
- )
189
-
190
- question_results = {}
191
- for result in valid_results:
192
- question_results[result.question_name] = result
193
-
194
- answer_key_names = list(question_results.keys())
195
-
196
- generated_tokens_dict = {
197
- k + "_generated_tokens": question_results[k].generated_tokens
198
- for k in answer_key_names
199
- }
200
- comments_dict = {
201
- k + "_comment": question_results[k].comment for k in answer_key_names
202
- }
203
-
204
- # we should have a valid result for each question
205
- answer_dict = {k: answer[k] for k in answer_key_names}
206
- assert len(valid_results) == len(answer_key_names)
207
-
208
- # TODO: move this down into Interview
209
- question_name_to_prompts = dict({})
210
- for result in valid_results:
211
- question_name = result.question_name
212
- question_name_to_prompts[question_name] = {
213
- "user_prompt": result.prompts["user_prompt"],
214
- "system_prompt": result.prompts["system_prompt"],
215
- }
216
-
217
- prompt_dictionary = {}
218
- for answer_key_name in answer_key_names:
219
- prompt_dictionary[
220
- answer_key_name + "_user_prompt"
221
- ] = question_name_to_prompts[answer_key_name]["user_prompt"]
222
- prompt_dictionary[
223
- answer_key_name + "_system_prompt"
224
- ] = question_name_to_prompts[answer_key_name]["system_prompt"]
225
-
226
- raw_model_results_dictionary = {}
227
- cache_used_dictionary = {}
228
- for result in valid_results:
229
- question_name = result.question_name
230
- raw_model_results_dictionary[
231
- question_name + "_raw_model_response"
232
- ] = result.raw_model_response
233
- raw_model_results_dictionary[question_name + "_cost"] = result.cost
234
- one_use_buys = (
235
- "NA"
236
- if isinstance(result.cost, str)
237
- or result.cost == 0
238
- or result.cost is None
239
- else 1.0 / result.cost
240
- )
241
- raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
242
- cache_used_dictionary[question_name] = result.cache_used
243
-
244
- result = Result(
245
- agent=interview.agent,
246
- scenario=interview.scenario,
247
- model=interview.model,
248
- iteration=interview.iteration,
249
- answer=answer_dict,
250
- prompt=prompt_dictionary,
251
- raw_model_response=raw_model_results_dictionary,
252
- survey=interview.survey,
253
- generated_tokens=generated_tokens_dict,
254
- comments_dict=comments_dict,
255
- cache_used_dict=cache_used_dictionary,
256
- indices=interview.indices,
257
- )
258
- result.interview_hash = hash(interview)
259
-
260
- return result
261
-
262
- @property
263
- def elapsed_time(self):
264
- return time.monotonic() - self.start_time
265
-
266
- def process_results(
267
- self, raw_results: Results, cache: Cache, print_exceptions: bool
268
- ):
269
- interview_lookup = {
270
- hash(interview): index
271
- for index, interview in enumerate(self.total_interviews)
272
- }
273
- interview_hashes = list(interview_lookup.keys())
274
-
275
- task_history = TaskHistory(self.total_interviews, include_traceback=False)
276
-
277
- results = Results(
278
- survey=self.jobs.survey,
279
- data=sorted(
280
- raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
281
- ),
282
- task_history=task_history,
283
- cache=cache,
284
- )
285
- results.bucket_collection = self.bucket_collection
286
-
287
- if results.has_unfixed_exceptions and print_exceptions:
288
- from edsl.scenarios.FileStore import HTMLFileStore
289
- from edsl.config import CONFIG
290
- from edsl.coop.coop import Coop
291
-
292
- msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
293
-
294
- if len(results.task_history.indices) > 5:
295
- msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
296
-
297
- import sys
298
-
299
- print(msg, file=sys.stderr)
300
- from edsl.config import CONFIG
301
-
302
- if CONFIG.get("EDSL_OPEN_EXCEPTION_REPORT_URL") == "True":
303
- open_in_browser = True
304
- elif CONFIG.get("EDSL_OPEN_EXCEPTION_REPORT_URL") == "False":
305
- open_in_browser = False
306
- else:
307
- raise Exception(
308
- "EDSL_OPEN_EXCEPTION_REPORT_URL", "must be either True or False"
309
- )
310
-
311
- filepath = results.task_history.html(
312
- cta="Open report to see details.",
313
- open_in_browser=open_in_browser,
314
- return_link=True,
315
- )
316
-
317
- try:
318
- coop = Coop()
319
- user_edsl_settings = coop.edsl_settings
320
- remote_logging = user_edsl_settings["remote_logging"]
321
- except Exception as e:
322
- print(e)
323
- remote_logging = False
324
-
325
- if remote_logging:
326
- filestore = HTMLFileStore(filepath)
327
- coop_details = filestore.push(description="Error report")
328
- print(coop_details)
329
-
330
- print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")
331
-
332
- return results
333
-
334
- @jupyter_nb_handler
335
- async def run(
336
- self,
337
- cache: Union[Cache, False, None],
338
- n: int = 1,
339
- stop_on_exception: bool = False,
340
- progress_bar: bool = False,
341
- sidecar_model: Optional[LanguageModel] = None,
342
- jobs_runner_status: Optional[Type[JobsRunnerStatusBase]] = None,
343
- job_uuid: Optional[UUID] = None,
344
- print_exceptions: bool = True,
345
- raise_validation_errors: bool = False,
346
- ) -> "Coroutine":
347
- """Runs a collection of interviews, handling both async and sync contexts."""
348
-
349
- self.results = []
350
- self.start_time = time.monotonic()
351
- self.completed = False
352
- self.cache = cache
353
- self.sidecar_model = sidecar_model
354
-
355
- from edsl.coop import Coop
356
-
357
- coop = Coop()
358
- endpoint_url = coop.get_progress_bar_url()
359
-
360
- if jobs_runner_status is not None:
361
- self.jobs_runner_status = jobs_runner_status(
362
- self, n=n, endpoint_url=endpoint_url, job_uuid=job_uuid
363
- )
364
- else:
365
- self.jobs_runner_status = JobsRunnerStatus(
366
- self, n=n, endpoint_url=endpoint_url, job_uuid=job_uuid
367
- )
368
-
369
- stop_event = threading.Event()
370
-
371
- async def process_results(cache):
372
- """Processes results from interviews."""
373
- async for result in self.run_async_generator(
374
- n=n,
375
- stop_on_exception=stop_on_exception,
376
- cache=cache,
377
- sidecar_model=sidecar_model,
378
- raise_validation_errors=raise_validation_errors,
379
- ):
380
- self.results.append(result)
381
- self.completed = True
382
-
383
- def run_progress_bar(stop_event):
384
- """Runs the progress bar in a separate thread."""
385
- self.jobs_runner_status.update_progress(stop_event)
386
-
387
- if progress_bar and self.jobs_runner_status.has_ep_api_key():
388
- self.jobs_runner_status.setup()
389
- progress_thread = threading.Thread(
390
- target=run_progress_bar, args=(stop_event,)
391
- )
392
- progress_thread.start()
393
- elif progress_bar:
394
- warnings.warn(
395
- "You need an Expected Parrot API key to view job progress bars."
396
- )
397
-
398
- exception_to_raise = None
399
- try:
400
- with cache as c:
401
- await process_results(cache=c)
402
- except KeyboardInterrupt:
403
- print("Keyboard interrupt received. Stopping gracefully...")
404
- stop_event.set()
405
- except Exception as e:
406
- if stop_on_exception:
407
- exception_to_raise = e
408
- stop_event.set()
409
- finally:
410
- stop_event.set()
411
- if progress_bar and self.jobs_runner_status.has_ep_api_key():
412
- # self.jobs_runner_status.stop_event.set()
413
- if progress_thread:
414
- progress_thread.join()
415
-
416
- if exception_to_raise:
417
- raise exception_to_raise
418
-
419
- return self.process_results(
420
- raw_results=self.results, cache=cache, print_exceptions=print_exceptions
421
- )
1
+ from __future__ import annotations
2
+ import time
3
+ import asyncio
4
+ import threading
5
+ import warnings
6
+ from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator, Type
7
+ from uuid import UUID
8
+ from collections import UserList
9
+
10
+ from edsl.results.Results import Results
11
+ from edsl.jobs.interviews.Interview import Interview
12
+ from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus, JobsRunnerStatusBase
13
+
14
+ from edsl.jobs.tasks.TaskHistory import TaskHistory
15
+ from edsl.jobs.buckets.BucketCollection import BucketCollection
16
+ from edsl.utilities.decorators import jupyter_nb_handler
17
+ from edsl.data.Cache import Cache
18
+ from edsl.results.Result import Result
19
+ from edsl.results.Results import Results
20
+ from edsl.language_models.LanguageModel import LanguageModel
21
+ from edsl.data.Cache import Cache
22
+
23
+
24
class StatusTracker(UserList):
    """A list of completed items that also remembers the expected total.

    :param total_tasks: the number of tasks expected overall; reported by
        :meth:`current_status` next to the current length of the list.
    """

    def __init__(self, total_tasks: int):
        super().__init__()
        self.total_tasks = total_tasks

    def current_status(self) -> None:
        """Print ``Completed: <done> of <total>`` terminated by a carriage return.

        The carriage return (instead of a newline) lets successive calls
        overwrite the same terminal line. Nothing is returned.
        """
        print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
31
+
32
+
33
class JobsRunnerAsyncio:
    """A class for running a collection of interviews asynchronously.

    It gets instantiated from a Jobs object.
    The Jobs object is a collection of interviews that are to be run.
    """

    MAX_CONCURRENT_DEFAULT = 500

    def __init__(self, jobs: "Jobs"):
        """Capture the interviews and bucket collection of ``jobs``.

        :param jobs: the Jobs object whose ``interviews()`` and
            ``bucket_collection`` are run by this runner.
        """
        self.jobs = jobs
        self.interviews: List["Interview"] = jobs.interviews()
        self.bucket_collection: "BucketCollection" = jobs.bucket_collection
        self.total_interviews: List["Interview"] = []
        # Set by run_async_generator once self.total_interviews is populated.
        self._initialized = threading.Event()

        from edsl.config import CONFIG

        # Upper bound on the number of interview tasks in flight at once.
        self.MAX_CONCURRENT = int(CONFIG.get("EDSL_MAX_CONCURRENT_TASKS"))

    async def run_async_generator(
        self,
        cache: Cache,
        n: int = 1,
        stop_on_exception: bool = False,
        sidecar_model: Optional[LanguageModel] = None,
        total_interviews: Optional[List["Interview"]] = None,
        raise_validation_errors: bool = False,
    ) -> AsyncGenerator["Result", None]:
        """Run the interviews concurrently, yielding results as they complete.

        The full list of interviews is materialised first (and stored in
        ``self.total_interviews``); tasks are then created lazily so that at
        most ``self.MAX_CONCURRENT`` are in flight at any time.

        :param cache: accepted for interface compatibility; not read here
            (interviews receive ``self.cache`` in ``_populate_total_interviews``).
        :param n: how many times to run each interview
        :param stop_on_exception: Whether to re-raise the first task failure;
            when False a failed task is dropped and the run continues.
        :param sidecar_model: a language model to use in addition to the interview's model
        :param total_interviews: A list of interviews to run can be provided instead.
        :param raise_validation_errors: Whether to raise validation errors
        """
        if total_interviews:
            self.total_interviews = total_interviews
        else:
            self.total_interviews = list(self._populate_total_interviews(n=n))
        interviews_iter = iter(self.total_interviews)

        self._initialized.set()  # Signal that we're ready

        active_tasks = set()
        exhausted = object()  # sentinel: no interviews left to schedule

        try:
            while True:
                # Top up the pool until the concurrency cap is hit or the
                # interviews run out.
                while len(active_tasks) < self.MAX_CONCURRENT:
                    interview = next(interviews_iter, exhausted)
                    if interview is exhausted:
                        break
                    task = asyncio.create_task(
                        self._build_interview_task(
                            interview=interview,
                            stop_on_exception=stop_on_exception,
                            sidecar_model=sidecar_model,
                            raise_validation_errors=raise_validation_errors,
                        )
                    )
                    active_tasks.add(task)
                    # Finished tasks remove themselves from the pool.
                    task.add_done_callback(active_tasks.discard)

                if not active_tasks:
                    break

                done, _ = await asyncio.wait(
                    active_tasks, return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    try:
                        result = await task
                        self.jobs_runner_status.add_completed_interview(result)
                        yield result
                    except Exception:
                        if stop_on_exception:
                            # The finally clause cancels whatever is still running.
                            raise
                        # Best effort: drop the failed interview and keep going.
                        continue
        finally:
            # Covers normal exit, early close of the generator, and errors.
            for task in list(active_tasks):
                if not task.done():
                    task.cancel()

    def _populate_total_interviews(
        self, n: int = 1
    ) -> Generator["Interview", None, None]:
        """Yield ``n`` runs of each interview, each bound to ``self.cache``.

        The first run of an interview is the interview itself (its ``cache``
        attribute is overwritten); later runs are duplicates created with
        ``interview.duplicate``. ``self.cache`` must have been set (by ``run``
        or ``run_async``) before this generator is consumed.

        :param n: how many times to run each interview.
        """
        for interview in self.interviews:
            for iteration in range(n):
                if iteration > 0:
                    yield interview.duplicate(iteration=iteration, cache=self.cache)
                else:
                    interview.cache = self.cache
                    yield interview

    async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
        """Used for some other modules that have a non-standard way of running interviews.

        :param cache: cache to use; a fresh ``Cache()`` is created when None.
        :param n: how many times to run each interview.
        :return: a Results built from the results in completion order.
        """
        self.jobs_runner_status = JobsRunnerStatus(self, n=n)
        self.cache = Cache() if cache is None else cache
        data = [
            result async for result in self.run_async_generator(cache=self.cache, n=n)
        ]
        return Results(survey=self.jobs.survey, data=data)

    def simple_run(self):
        """Run ``run_async`` to completion on a new event loop and return its Results."""
        # run_async already returns a Results for this survey; no re-wrapping needed.
        return asyncio.run(self.run_async())

    async def _build_interview_task(
        self,
        *,
        interview: Interview,
        stop_on_exception: bool = False,
        sidecar_model: Optional["LanguageModel"] = None,
        raise_validation_errors: bool = False,
    ) -> "Result":
        """Conducts an interview and returns the result.

        :param interview: the interview to conduct
        :param stop_on_exception: stops the interview if an exception is raised
        :param sidecar_model: a language model to use in addition to the interview's model
        :param raise_validation_errors: passed through to the interview
        :return: a Result carrying, per answered question, the answer, prompts,
            raw model response, cost figures, generated tokens, comment and
            cache-use flag, tagged with ``interview_hash``.
        """
        # the model buckets are used to track usage rates
        model_buckets = self.bucket_collection[interview.model]

        answer, valid_results = await interview.async_conduct_interview(
            model_buckets=model_buckets,
            stop_on_exception=stop_on_exception,
            sidecar_model=sidecar_model,
            raise_validation_errors=raise_validation_errors,
        )

        question_results = {
            question_result.question_name: question_result
            for question_result in valid_results
        }
        answer_key_names = list(question_results)

        generated_tokens_dict = {
            k + "_generated_tokens": question_results[k].generated_tokens
            for k in answer_key_names
        }
        comments_dict = {
            k + "_comment": question_results[k].comment for k in answer_key_names
        }

        # we should have a valid result for each question
        answer_dict = {k: answer[k] for k in answer_key_names}
        assert len(valid_results) == len(answer_key_names)

        # TODO: move this down into Interview
        # Past the assert, question names are unique, so one pass over the
        # name -> result mapping visits every valid result exactly once.
        prompt_dictionary = {}
        raw_model_results_dictionary = {}
        cache_used_dictionary = {}
        for question_name, question_result in question_results.items():
            prompts = question_result.prompts
            prompt_dictionary[question_name + "_user_prompt"] = prompts["user_prompt"]
            prompt_dictionary[question_name + "_system_prompt"] = prompts[
                "system_prompt"
            ]

            cost = question_result.cost
            raw_model_results_dictionary[
                question_name + "_raw_model_response"
            ] = question_result.raw_model_response
            raw_model_results_dictionary[question_name + "_cost"] = cost
            # "NA" whenever the cost cannot be inverted (string, zero or missing).
            if isinstance(cost, str) or cost == 0 or cost is None:
                one_usd_buys = "NA"
            else:
                one_usd_buys = 1.0 / cost
            raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_usd_buys

            cache_used_dictionary[question_name] = question_result.cache_used

        interview_result = Result(
            agent=interview.agent,
            scenario=interview.scenario,
            model=interview.model,
            iteration=interview.iteration,
            answer=answer_dict,
            prompt=prompt_dictionary,
            raw_model_response=raw_model_results_dictionary,
            survey=interview.survey,
            generated_tokens=generated_tokens_dict,
            comments_dict=comments_dict,
            cache_used_dict=cache_used_dictionary,
        )
        # Lets process_results restore the original interview order.
        interview_result.interview_hash = hash(interview)

        return interview_result

    @property
    def elapsed_time(self):
        """Seconds since ``run`` recorded ``self.start_time`` (monotonic clock)."""
        return time.monotonic() - self.start_time

    def process_results(
        self, raw_results: Results, cache: Cache, print_exceptions: bool
    ):
        """Assemble the final Results, ordered like ``self.total_interviews``.

        :param raw_results: results in completion order, each carrying the
            ``interview_hash`` assigned by ``_build_interview_task``.
        :param cache: the cache to attach to the returned Results.
        :param print_exceptions: whether to report unfixed exceptions.
        :raises ValueError: if a result's ``interview_hash`` does not belong to
            any interview in ``self.total_interviews``.
        :return: the ordered Results with task history and bucket collection.
        """
        # Rank of each distinct interview hash in first-seen order; a dict
        # lookup replaces a linear list.index() call per sorted element.
        interview_rank = {}
        for interview in self.total_interviews:
            interview_rank.setdefault(hash(interview), len(interview_rank))

        def position(result):
            try:
                return interview_rank[result.interview_hash]
            except KeyError:
                # Same exception type list.index() raised here before.
                raise ValueError(
                    f"{result.interview_hash!r} is not in list"
                ) from None

        task_history = TaskHistory(self.total_interviews, include_traceback=False)

        results = Results(
            survey=self.jobs.survey,
            data=sorted(raw_results, key=position),
            task_history=task_history,
            cache=cache,
        )
        results.bucket_collection = self.bucket_collection

        if results.has_unfixed_exceptions and print_exceptions:
            self._report_exceptions(results)

        return results

    def _report_exceptions(self, results: Results) -> None:
        """Print an exception summary, build the HTML report and optionally push it."""
        import sys

        from edsl.scenarios.FileStore import HTMLFileStore
        from edsl.config import CONFIG
        from edsl.coop.coop import Coop

        failed_indices = results.task_history.indices
        msg = f"Exceptions were raised in {len(failed_indices)} out of {len(self.total_interviews)} interviews.\n"
        if len(failed_indices) > 5:
            msg += f"Exceptions were raised in the following interviews: {failed_indices}.\n"
        print(msg, file=sys.stderr)

        setting = CONFIG.get("EDSL_OPEN_EXCEPTION_REPORT_URL")
        if setting == "True":
            open_in_browser = True
        elif setting == "False":
            open_in_browser = False
        else:
            raise Exception(
                "EDSL_OPEN_EXCEPTION_REPORT_URL must be either True or False, "
                f"got {setting!r}"
            )

        filepath = results.task_history.html(
            cta="Open report to see details.",
            open_in_browser=open_in_browser,
            return_link=True,
        )

        # Best effort: any failure reading the remote settings disables the push.
        try:
            coop = Coop()
            user_edsl_settings = coop.edsl_settings
            remote_logging = user_edsl_settings["remote_logging"]
        except Exception as e:
            print(e)
            remote_logging = False

        if remote_logging:
            filestore = HTMLFileStore(filepath)
            coop_details = filestore.push(description="Error report")
            print(coop_details)

        print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")

    @jupyter_nb_handler
    async def run(
        self,
        cache: Union[Cache, bool, None],
        n: int = 1,
        stop_on_exception: bool = False,
        progress_bar: bool = False,
        sidecar_model: Optional[LanguageModel] = None,
        jobs_runner_status: Optional[Type[JobsRunnerStatusBase]] = None,
        job_uuid: Optional[UUID] = None,
        print_exceptions: bool = True,
        raise_validation_errors: bool = False,
    ) -> "Coroutine":
        """Runs a collection of interviews, handling both async and sync contexts.

        :param cache: the cache for this run. It is entered as a context
            manager below, so callers presumably resolve ``False``/``None`` to
            a real Cache before calling — TODO confirm against the caller.
        :param n: how many times to run each interview
        :param stop_on_exception: re-raise the first exception after cleanup;
            when False exceptions from the run are suppressed.
        :param progress_bar: whether to run the progress updater in a thread
        :param sidecar_model: a language model to use in addition to the interview's model
        :param jobs_runner_status: status class to instantiate instead of
            ``JobsRunnerStatus``
        :param job_uuid: forwarded to the status object
        :param print_exceptions: forwarded to ``process_results``
        :param raise_validation_errors: Whether to raise validation errors
        :return: the value of ``process_results`` for the collected results.
        """
        self.results = []
        self.start_time = time.monotonic()
        self.completed = False
        self.cache = cache
        self.sidecar_model = sidecar_model

        from edsl.coop import Coop

        endpoint_url = Coop().get_progress_bar_url()

        status_class = (
            jobs_runner_status if jobs_runner_status is not None else JobsRunnerStatus
        )
        self.jobs_runner_status = status_class(
            self, n=n, endpoint_url=endpoint_url, job_uuid=job_uuid
        )

        stop_event = threading.Event()

        # Stays None unless the progress thread is actually started, so the
        # cleanup below never touches an unbound name.
        progress_thread = None
        if progress_bar:
            if self.jobs_runner_status.has_ep_api_key():
                self.jobs_runner_status.setup()
                progress_thread = threading.Thread(
                    target=self.jobs_runner_status.update_progress,
                    args=(stop_event,),
                )
                progress_thread.start()
            else:
                warnings.warn(
                    "You need an Expected Parrot API key to view job progress bars."
                )

        exception_to_raise = None
        try:
            with cache as c:
                async for result in self.run_async_generator(
                    n=n,
                    stop_on_exception=stop_on_exception,
                    cache=c,
                    sidecar_model=sidecar_model,
                    raise_validation_errors=raise_validation_errors,
                ):
                    self.results.append(result)
                self.completed = True
        except KeyboardInterrupt:
            print("Keyboard interrupt received. Stopping gracefully...")
        except Exception as e:
            # Deferred so the progress thread is stopped and joined first.
            if stop_on_exception:
                exception_to_raise = e
        finally:
            stop_event.set()
            if progress_thread is not None:
                progress_thread.join()

        if exception_to_raise:
            raise exception_to_raise

        return self.process_results(
            raw_results=self.results, cache=cache, print_exceptions=print_exceptions
        )