edsl 0.1.39.dev3__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (344)
  1. edsl/Base.py +413 -332
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +57 -49
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +1071 -867
  7. edsl/agents/AgentList.py +551 -413
  8. edsl/agents/Invigilator.py +284 -233
  9. edsl/agents/InvigilatorBase.py +257 -270
  10. edsl/agents/PromptConstructor.py +272 -354
  11. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  12. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  13. edsl/agents/__init__.py +2 -3
  14. edsl/agents/descriptors.py +99 -99
  15. edsl/agents/prompt_helpers.py +129 -129
  16. edsl/agents/question_option_processor.py +172 -0
  17. edsl/auto/AutoStudy.py +130 -117
  18. edsl/auto/StageBase.py +243 -230
  19. edsl/auto/StageGenerateSurvey.py +178 -178
  20. edsl/auto/StageLabelQuestions.py +125 -125
  21. edsl/auto/StagePersona.py +61 -61
  22. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  23. edsl/auto/StagePersonaDimensionValues.py +74 -74
  24. edsl/auto/StagePersonaDimensions.py +69 -69
  25. edsl/auto/StageQuestions.py +74 -73
  26. edsl/auto/SurveyCreatorPipeline.py +21 -21
  27. edsl/auto/utilities.py +218 -224
  28. edsl/base/Base.py +279 -279
  29. edsl/config.py +177 -157
  30. edsl/conversation/Conversation.py +290 -290
  31. edsl/conversation/car_buying.py +59 -58
  32. edsl/conversation/chips.py +95 -95
  33. edsl/conversation/mug_negotiation.py +81 -81
  34. edsl/conversation/next_speaker_utilities.py +93 -93
  35. edsl/coop/CoopFunctionsMixin.py +15 -0
  36. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  37. edsl/coop/PriceFetcher.py +54 -54
  38. edsl/coop/__init__.py +2 -2
  39. edsl/coop/coop.py +1106 -1028
  40. edsl/coop/utils.py +131 -131
  41. edsl/data/Cache.py +573 -555
  42. edsl/data/CacheEntry.py +230 -233
  43. edsl/data/CacheHandler.py +168 -149
  44. edsl/data/RemoteCacheSync.py +186 -78
  45. edsl/data/SQLiteDict.py +292 -292
  46. edsl/data/__init__.py +5 -4
  47. edsl/data/hack.py +10 -0
  48. edsl/data/orm.py +10 -10
  49. edsl/data_transfer_models.py +74 -73
  50. edsl/enums.py +202 -175
  51. edsl/exceptions/BaseException.py +21 -21
  52. edsl/exceptions/__init__.py +54 -54
  53. edsl/exceptions/agents.py +54 -42
  54. edsl/exceptions/cache.py +5 -5
  55. edsl/exceptions/configuration.py +16 -16
  56. edsl/exceptions/coop.py +10 -10
  57. edsl/exceptions/data.py +14 -14
  58. edsl/exceptions/general.py +34 -34
  59. edsl/exceptions/inference_services.py +5 -0
  60. edsl/exceptions/jobs.py +33 -33
  61. edsl/exceptions/language_models.py +63 -63
  62. edsl/exceptions/prompts.py +15 -15
  63. edsl/exceptions/questions.py +109 -91
  64. edsl/exceptions/results.py +29 -29
  65. edsl/exceptions/scenarios.py +29 -22
  66. edsl/exceptions/surveys.py +37 -37
  67. edsl/inference_services/AnthropicService.py +106 -87
  68. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  69. edsl/inference_services/AvailableModelFetcher.py +215 -0
  70. edsl/inference_services/AwsBedrock.py +118 -120
  71. edsl/inference_services/AzureAI.py +215 -217
  72. edsl/inference_services/DeepInfraService.py +18 -18
  73. edsl/inference_services/GoogleService.py +143 -148
  74. edsl/inference_services/GroqService.py +20 -20
  75. edsl/inference_services/InferenceServiceABC.py +80 -147
  76. edsl/inference_services/InferenceServicesCollection.py +138 -97
  77. edsl/inference_services/MistralAIService.py +120 -123
  78. edsl/inference_services/OllamaService.py +18 -18
  79. edsl/inference_services/OpenAIService.py +236 -224
  80. edsl/inference_services/PerplexityService.py +160 -163
  81. edsl/inference_services/ServiceAvailability.py +135 -0
  82. edsl/inference_services/TestService.py +90 -89
  83. edsl/inference_services/TogetherAIService.py +172 -170
  84. edsl/inference_services/data_structures.py +134 -0
  85. edsl/inference_services/models_available_cache.py +118 -118
  86. edsl/inference_services/rate_limits_cache.py +25 -25
  87. edsl/inference_services/registry.py +41 -41
  88. edsl/inference_services/write_available.py +10 -10
  89. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  90. edsl/jobs/Answers.py +43 -56
  91. edsl/jobs/FetchInvigilator.py +47 -0
  92. edsl/jobs/InterviewTaskManager.py +98 -0
  93. edsl/jobs/InterviewsConstructor.py +50 -0
  94. edsl/jobs/Jobs.py +823 -898
  95. edsl/jobs/JobsChecks.py +172 -147
  96. edsl/jobs/JobsComponentConstructor.py +189 -0
  97. edsl/jobs/JobsPrompts.py +270 -268
  98. edsl/jobs/JobsRemoteInferenceHandler.py +311 -239
  99. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  100. edsl/jobs/RequestTokenEstimator.py +30 -0
  101. edsl/jobs/__init__.py +1 -1
  102. edsl/jobs/async_interview_runner.py +138 -0
  103. edsl/jobs/buckets/BucketCollection.py +104 -63
  104. edsl/jobs/buckets/ModelBuckets.py +65 -65
  105. edsl/jobs/buckets/TokenBucket.py +283 -251
  106. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  107. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  108. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  109. edsl/jobs/data_structures.py +120 -0
  110. edsl/jobs/decorators.py +35 -0
  111. edsl/jobs/interviews/Interview.py +396 -661
  112. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  113. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  114. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  115. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  116. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  117. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  118. edsl/jobs/interviews/ReportErrors.py +66 -66
  119. edsl/jobs/interviews/interview_status_enum.py +9 -9
  120. edsl/jobs/jobs_status_enums.py +9 -0
  121. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  122. edsl/jobs/results_exceptions_handler.py +98 -0
  123. edsl/jobs/runners/JobsRunnerAsyncio.py +151 -466
  124. edsl/jobs/runners/JobsRunnerStatus.py +297 -330
  125. edsl/jobs/tasks/QuestionTaskCreator.py +244 -242
  126. edsl/jobs/tasks/TaskCreators.py +64 -64
  127. edsl/jobs/tasks/TaskHistory.py +470 -450
  128. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  129. edsl/jobs/tasks/task_status_enum.py +161 -163
  130. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  131. edsl/jobs/tokens/TokenUsage.py +34 -34
  132. edsl/language_models/ComputeCost.py +63 -0
  133. edsl/language_models/LanguageModel.py +626 -668
  134. edsl/language_models/ModelList.py +164 -155
  135. edsl/language_models/PriceManager.py +127 -0
  136. edsl/language_models/RawResponseHandler.py +106 -0
  137. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  138. edsl/language_models/ServiceDataSources.py +0 -0
  139. edsl/language_models/__init__.py +2 -3
  140. edsl/language_models/fake_openai_call.py +15 -15
  141. edsl/language_models/fake_openai_service.py +61 -61
  142. edsl/language_models/key_management/KeyLookup.py +63 -0
  143. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  144. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  145. edsl/language_models/key_management/__init__.py +0 -0
  146. edsl/language_models/key_management/models.py +131 -0
  147. edsl/language_models/model.py +256 -0
  148. edsl/language_models/repair.py +156 -156
  149. edsl/language_models/utilities.py +65 -64
  150. edsl/notebooks/Notebook.py +263 -258
  151. edsl/notebooks/NotebookToLaTeX.py +142 -0
  152. edsl/notebooks/__init__.py +1 -1
  153. edsl/prompts/Prompt.py +352 -362
  154. edsl/prompts/__init__.py +2 -2
  155. edsl/questions/ExceptionExplainer.py +77 -0
  156. edsl/questions/HTMLQuestion.py +103 -0
  157. edsl/questions/QuestionBase.py +518 -664
  158. edsl/questions/QuestionBasePromptsMixin.py +221 -217
  159. edsl/questions/QuestionBudget.py +227 -227
  160. edsl/questions/QuestionCheckBox.py +359 -359
  161. edsl/questions/QuestionExtract.py +180 -182
  162. edsl/questions/QuestionFreeText.py +113 -114
  163. edsl/questions/QuestionFunctional.py +166 -166
  164. edsl/questions/QuestionList.py +223 -231
  165. edsl/questions/QuestionMatrix.py +265 -0
  166. edsl/questions/QuestionMultipleChoice.py +330 -286
  167. edsl/questions/QuestionNumerical.py +151 -153
  168. edsl/questions/QuestionRank.py +314 -324
  169. edsl/questions/Quick.py +41 -41
  170. edsl/questions/SimpleAskMixin.py +74 -73
  171. edsl/questions/__init__.py +27 -26
  172. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +334 -289
  173. edsl/questions/compose_questions.py +98 -98
  174. edsl/questions/data_structures.py +20 -0
  175. edsl/questions/decorators.py +21 -21
  176. edsl/questions/derived/QuestionLikertFive.py +76 -76
  177. edsl/questions/derived/QuestionLinearScale.py +90 -87
  178. edsl/questions/derived/QuestionTopK.py +93 -93
  179. edsl/questions/derived/QuestionYesNo.py +82 -82
  180. edsl/questions/descriptors.py +427 -413
  181. edsl/questions/loop_processor.py +149 -0
  182. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  183. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  184. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  185. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  186. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  187. edsl/questions/prompt_templates/question_list.jinja +17 -17
  188. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  189. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  190. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +168 -161
  191. edsl/questions/question_registry.py +177 -177
  192. edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +71 -71
  193. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +188 -174
  194. edsl/questions/response_validator_factory.py +34 -0
  195. edsl/questions/settings.py +12 -12
  196. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  197. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  198. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  199. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  200. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  201. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  202. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  203. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  204. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  205. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  206. edsl/questions/templates/list/question_presentation.jinja +5 -5
  207. edsl/questions/templates/matrix/__init__.py +1 -0
  208. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  209. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  210. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  211. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  212. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  213. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  214. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  215. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  216. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  217. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  218. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  219. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  220. edsl/results/CSSParameterizer.py +108 -108
  221. edsl/results/Dataset.py +587 -424
  222. edsl/results/DatasetExportMixin.py +594 -731
  223. edsl/results/DatasetTree.py +295 -275
  224. edsl/results/MarkdownToDocx.py +122 -0
  225. edsl/results/MarkdownToPDF.py +111 -0
  226. edsl/results/Result.py +557 -465
  227. edsl/results/Results.py +1183 -1165
  228. edsl/results/ResultsExportMixin.py +45 -43
  229. edsl/results/ResultsGGMixin.py +121 -121
  230. edsl/results/TableDisplay.py +125 -198
  231. edsl/results/TextEditor.py +50 -0
  232. edsl/results/__init__.py +2 -2
  233. edsl/results/file_exports.py +252 -0
  234. edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +33 -33
  235. edsl/results/{Selector.py → results_selector.py} +145 -135
  236. edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +98 -98
  237. edsl/results/smart_objects.py +96 -0
  238. edsl/results/table_data_class.py +12 -0
  239. edsl/results/table_display.css +77 -77
  240. edsl/results/table_renderers.py +118 -0
  241. edsl/results/tree_explore.py +115 -115
  242. edsl/scenarios/ConstructDownloadLink.py +109 -0
  243. edsl/scenarios/DocumentChunker.py +102 -0
  244. edsl/scenarios/DocxScenario.py +16 -0
  245. edsl/scenarios/FileStore.py +511 -632
  246. edsl/scenarios/PdfExtractor.py +40 -0
  247. edsl/scenarios/Scenario.py +498 -601
  248. edsl/scenarios/ScenarioHtmlMixin.py +65 -64
  249. edsl/scenarios/ScenarioList.py +1458 -1287
  250. edsl/scenarios/ScenarioListExportMixin.py +45 -52
  251. edsl/scenarios/ScenarioListPdfMixin.py +239 -261
  252. edsl/scenarios/__init__.py +3 -4
  253. edsl/scenarios/directory_scanner.py +96 -0
  254. edsl/scenarios/file_methods.py +85 -0
  255. edsl/scenarios/handlers/__init__.py +13 -0
  256. edsl/scenarios/handlers/csv.py +38 -0
  257. edsl/scenarios/handlers/docx.py +76 -0
  258. edsl/scenarios/handlers/html.py +37 -0
  259. edsl/scenarios/handlers/json.py +111 -0
  260. edsl/scenarios/handlers/latex.py +5 -0
  261. edsl/scenarios/handlers/md.py +51 -0
  262. edsl/scenarios/handlers/pdf.py +68 -0
  263. edsl/scenarios/handlers/png.py +39 -0
  264. edsl/scenarios/handlers/pptx.py +105 -0
  265. edsl/scenarios/handlers/py.py +294 -0
  266. edsl/scenarios/handlers/sql.py +313 -0
  267. edsl/scenarios/handlers/sqlite.py +149 -0
  268. edsl/scenarios/handlers/txt.py +33 -0
  269. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +131 -127
  270. edsl/scenarios/scenario_selector.py +156 -0
  271. edsl/shared.py +1 -1
  272. edsl/study/ObjectEntry.py +173 -173
  273. edsl/study/ProofOfWork.py +113 -113
  274. edsl/study/SnapShot.py +80 -80
  275. edsl/study/Study.py +521 -528
  276. edsl/study/__init__.py +4 -4
  277. edsl/surveys/ConstructDAG.py +92 -0
  278. edsl/surveys/DAG.py +148 -148
  279. edsl/surveys/EditSurvey.py +221 -0
  280. edsl/surveys/InstructionHandler.py +100 -0
  281. edsl/surveys/Memory.py +31 -31
  282. edsl/surveys/MemoryManagement.py +72 -0
  283. edsl/surveys/MemoryPlan.py +244 -244
  284. edsl/surveys/Rule.py +327 -326
  285. edsl/surveys/RuleCollection.py +385 -387
  286. edsl/surveys/RuleManager.py +172 -0
  287. edsl/surveys/Simulator.py +75 -0
  288. edsl/surveys/Survey.py +1280 -1801
  289. edsl/surveys/SurveyCSS.py +273 -261
  290. edsl/surveys/SurveyExportMixin.py +259 -259
  291. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +181 -179
  292. edsl/surveys/SurveyQualtricsImport.py +284 -284
  293. edsl/surveys/SurveyToApp.py +141 -0
  294. edsl/surveys/__init__.py +5 -3
  295. edsl/surveys/base.py +53 -53
  296. edsl/surveys/descriptors.py +60 -56
  297. edsl/surveys/instructions/ChangeInstruction.py +48 -49
  298. edsl/surveys/instructions/Instruction.py +56 -65
  299. edsl/surveys/instructions/InstructionCollection.py +82 -77
  300. edsl/templates/error_reporting/base.html +23 -23
  301. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  302. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  303. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  304. edsl/templates/error_reporting/interview_details.html +115 -115
  305. edsl/templates/error_reporting/interviews.html +19 -19
  306. edsl/templates/error_reporting/overview.html +4 -4
  307. edsl/templates/error_reporting/performance_plot.html +1 -1
  308. edsl/templates/error_reporting/report.css +73 -73
  309. edsl/templates/error_reporting/report.html +117 -117
  310. edsl/templates/error_reporting/report.js +25 -25
  311. edsl/test_h +1 -0
  312. edsl/tools/__init__.py +1 -1
  313. edsl/tools/clusters.py +192 -192
  314. edsl/tools/embeddings.py +27 -27
  315. edsl/tools/embeddings_plotting.py +118 -118
  316. edsl/tools/plotting.py +112 -112
  317. edsl/tools/summarize.py +18 -18
  318. edsl/utilities/PrettyList.py +56 -0
  319. edsl/utilities/SystemInfo.py +28 -28
  320. edsl/utilities/__init__.py +22 -22
  321. edsl/utilities/ast_utilities.py +25 -25
  322. edsl/utilities/data/Registry.py +6 -6
  323. edsl/utilities/data/__init__.py +1 -1
  324. edsl/utilities/data/scooter_results.json +1 -1
  325. edsl/utilities/decorators.py +77 -77
  326. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  327. edsl/utilities/gcp_bucket/example.py +50 -0
  328. edsl/utilities/interface.py +627 -627
  329. edsl/utilities/is_notebook.py +18 -0
  330. edsl/utilities/is_valid_variable_name.py +11 -0
  331. edsl/utilities/naming_utilities.py +263 -263
  332. edsl/utilities/remove_edsl_version.py +24 -0
  333. edsl/utilities/repair_functions.py +28 -28
  334. edsl/utilities/restricted_python.py +70 -70
  335. edsl/utilities/utilities.py +436 -424
  336. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +21 -21
  337. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +13 -11
  338. edsl-0.1.39.dev4.dist-info/RECORD +361 -0
  339. edsl/language_models/KeyLookup.py +0 -30
  340. edsl/language_models/registry.py +0 -190
  341. edsl/language_models/unused/ReplicateBase.py +0 -83
  342. edsl/results/ResultsDBMixin.py +0 -238
  343. edsl-0.1.39.dev3.dist-info/RECORD +0 -277
  344. {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
@@ -1,466 +1,151 @@
1
- from __future__ import annotations
2
- import time
3
- import asyncio
4
- import threading
5
- import warnings
6
- from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator, Type
7
- from uuid import UUID
8
- from collections import UserList
9
-
10
- from edsl.results.Results import Results
11
- from edsl.jobs.interviews.Interview import Interview
12
- from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus, JobsRunnerStatusBase
13
-
14
- from edsl.jobs.tasks.TaskHistory import TaskHistory
15
- from edsl.jobs.buckets.BucketCollection import BucketCollection
16
- from edsl.utilities.decorators import jupyter_nb_handler
17
- from edsl.data.Cache import Cache
18
- from edsl.results.Result import Result
19
- from edsl.results.Results import Results
20
- from edsl.language_models.LanguageModel import LanguageModel
21
- from edsl.data.Cache import Cache
22
-
23
-
24
class StatusTracker(UserList):
    """A list of completed items that can report progress against a known total.

    NOTE(review): ``UserList`` operations that return a new sequence (slicing,
    ``copy()``, ``+``) construct ``type(self)(data)``, which here would bind the
    data to ``total_tasks`` and leave the new tracker empty. Avoid those
    operations on a tracker.
    """

    def __init__(self, total_tasks: int):
        """Create an empty tracker.

        :param total_tasks: the total number of tasks expected to complete.
        """
        self.total_tasks = total_tasks
        super().__init__()

    def current_status(self) -> None:
        """Print ``Completed: <done> of <total>``, ending with a carriage return.

        No newline is written, so successive calls overwrite the same line on
        terminals that honour carriage returns. Nothing is returned; the
        original ``return print(...)`` only ever returned ``None``.
        """
        print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
31
-
32
-
33
- class JobsRunnerAsyncio:
34
- """A class for running a collection of interviews asynchronously.
35
-
36
- It gets instaniated from a Jobs object.
37
- The Jobs object is a collection of interviews that are to be run.
38
- """
39
-
40
- MAX_CONCURRENT_DEFAULT = 500
41
-
42
def __init__(self, jobs: "Jobs"):
    """Bind a runner to *jobs* and read the concurrency cap from config.

    :param jobs: the Jobs object to run. Its ``interviews()`` result and its
        ``bucket_collection`` are captured here, in that order.

    ``self.MAX_CONCURRENT`` is ``int(CONFIG.get("EDSL_MAX_CONCURRENT_TASKS"))``.
    """
    self.jobs = jobs
    self.interviews: List["Interview"] = self.jobs.interviews()
    self.bucket_collection: "BucketCollection" = self.jobs.bucket_collection

    # Populated later by run_async_generator once the interview list is known.
    self.total_interviews: List["Interview"] = list()
    # run_async_generator sets this once total_interviews has been populated.
    self._initialized = threading.Event()

    from edsl.config import CONFIG as _config

    raw_limit = _config.get("EDSL_MAX_CONCURRENT_TASKS")
    self.MAX_CONCURRENT = int(raw_limit)
53
-
54
async def run_async_generator(
    self,
    cache: Cache,
    n: int = 1,
    stop_on_exception: bool = False,
    sidecar_model: Optional[LanguageModel] = None,
    total_interviews: Optional[List["Interview"]] = None,
    raise_validation_errors: bool = False,
) -> AsyncGenerator["Result", None]:
    """Run interviews concurrently and yield each Result as it becomes available.

    The full interview list is built up front and stored on
    ``self.total_interviews``, because other code reads it. At most
    ``self.MAX_CONCURRENT`` interview tasks are in flight at once, and new
    tasks are started as earlier ones finish.

    Results are yielded in completion order. Tasks that finish in the same
    wake-up are yielded in arbitrary order.

    :param cache: not read by this method. Interviews built here get
        ``self.cache`` via ``_populate_total_interviews``.
    :param n: how many times to run each interview.
    :param stop_on_exception: if True, the first failing task cancels all other
        in-flight tasks and its exception propagates to the caller. If False,
        the failure is logged and the remaining tasks continue.
    :param sidecar_model: a language model to use in addition to the interview's model.
    :param total_interviews: an explicit list of interviews to run instead of
        generating them from ``self.interviews``. An empty list runs nothing.
    :param raise_validation_errors: whether to raise validation errors.
    """
    import logging

    # `is not None` (not truthiness) so an explicit empty list is honoured
    # instead of silently falling back to every interview.
    if total_interviews is not None:
        self.total_interviews = total_interviews
    else:
        self.total_interviews = list(self._populate_total_interviews(n=n))
    interviews_iter = iter(self.total_interviews)

    self._initialized.set()  # signal that total_interviews is ready

    active_tasks = set()

    try:
        while True:
            # Top up the pool of in-flight tasks to the concurrency limit.
            while len(active_tasks) < self.MAX_CONCURRENT:
                try:
                    interview = next(interviews_iter)
                except StopIteration:
                    break
                task = asyncio.create_task(
                    self._build_interview_task(
                        interview=interview,
                        stop_on_exception=stop_on_exception,
                        sidecar_model=sidecar_model,
                        raise_validation_errors=raise_validation_errors,
                    )
                )
                active_tasks.add(task)
                # Remove the task from the pool as soon as it finishes.
                task.add_done_callback(active_tasks.discard)

            if not active_tasks:
                break

            done, _ = await asyncio.wait(
                active_tasks, return_when=asyncio.FIRST_COMPLETED
            )

            for task in done:
                try:
                    result = await task
                    self.jobs_runner_status.add_completed_interview(result)
                except Exception:
                    if stop_on_exception:
                        for pending in active_tasks:
                            if not pending.done():
                                pending.cancel()
                        raise
                    # Best-effort mode: record the failure instead of dropping it silently.
                    logging.getLogger(__name__).exception(
                        "Interview task failed; continuing with remaining interviews."
                    )
                    continue
                # Yield outside the try so exceptions thrown into the generator
                # by the consumer are not swallowed by the handler above.
                yield result
    finally:
        # Cancel anything still running if we exit early (error or consumer stop).
        for task in active_tasks:
            if not task.done():
                task.cancel()
180
-
181
- def _populate_total_interviews(
182
- self, n: int = 1
183
- ) -> Generator["Interview", None, None]:
184
- """Populates self.total_interviews with n copies of each interview.
185
-
186
- :param n: how many times to run each interview.
187
- """
188
- for interview in self.interviews:
189
- for iteration in range(n):
190
- if iteration > 0:
191
- yield interview.duplicate(iteration=iteration, cache=self.cache)
192
- else:
193
- interview.cache = self.cache
194
- yield interview
195
-
196
async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
    """Run every interview ``n`` times and collect the outcomes into Results.

    Used by modules that drive interviews outside the normal ``run`` path.
    The method:
    - installs a fresh ``JobsRunnerStatus``;
    - falls back to a new ``Cache`` when none is supplied;
    - gathers results in the order ``run_async_generator`` yields them.
    """
    self.jobs_runner_status = JobsRunnerStatus(self, n=n)
    self.cache = cache if cache is not None else Cache()
    collected = [
        item async for item in self.run_async_generator(cache=self.cache, n=n)
    ]
    return Results(survey=self.jobs.survey, data=collected)
204
-
205
def simple_run(self):
    """Synchronously run all interviews once and return their Results.

    ``run_async`` already returns a ``Results`` object, so that object is
    returned as-is. Previously it was wrapped in a second ``Results`` with
    ``data=<Results>``.

    ``asyncio.run`` raises RuntimeError if called from a running event loop,
    so this must not be called from async code.
    """
    return asyncio.run(self.run_async())
208
-
209
async def _build_interview_task(
    self,
    *,
    interview: Interview,
    stop_on_exception: bool = False,
    sidecar_model: Optional["LanguageModel"] = None,
    raise_validation_errors: bool = False,
) -> "Result":
    """Conduct one interview and package its answers into a Result.

    :param interview: the interview to conduct.
    :param stop_on_exception: stops the interview if an exception is raised.
    :param sidecar_model: a language model to use in addition to the interview's model.
    :param raise_validation_errors: passed through to ``async_conduct_interview``.
    :returns: a Result built from per-question dictionaries (answers, prompts,
        raw responses with cost and ``one_usd_buys``, generated tokens,
        comments, cache usage), with ``interview_hash`` set to ``hash(interview)``.
    :raises AssertionError: if two valid results share a question name.
    :raises KeyError: if ``answer`` has no entry for a question that has a result.
    """
    # The model buckets are used to track usage rates for this interview's model.
    model_buckets = self.bucket_collection[interview.model]

    answer, valid_results = await interview.async_conduct_interview(
        model_buckets=model_buckets,
        stop_on_exception=stop_on_exception,
        sidecar_model=sidecar_model,
        raise_validation_errors=raise_validation_errors,
    )

    # Index results by question name; each question should appear exactly once.
    question_results = {result.question_name: result for result in valid_results}
    if len(question_results) != len(valid_results):
        # Explicit raise rather than `assert`, so the check survives `python -O`.
        raise AssertionError(
            "Expected one valid result per question, got "
            f"{len(valid_results)} results for {len(question_results)} questions"
        )

    answer_dict = {}
    prompt_dictionary = {}
    generated_tokens_dict = {}
    comments_dict = {}
    raw_model_results_dictionary = {}
    cache_used_dictionary = {}

    # Single pass over the results; key order matches the old per-dict loops.
    for name, result in question_results.items():
        answer_dict[name] = answer[name]
        prompt_dictionary[name + "_user_prompt"] = result.prompts["user_prompt"]
        prompt_dictionary[name + "_system_prompt"] = result.prompts["system_prompt"]
        generated_tokens_dict[name + "_generated_tokens"] = result.generated_tokens
        comments_dict[name + "_comment"] = result.comment

        cost = result.cost
        # 1 / cost is only computed for a non-zero, non-None, non-string cost.
        one_usd_buys = (
            "NA"
            if isinstance(cost, str) or cost == 0 or cost is None
            else 1.0 / cost
        )
        raw_model_results_dictionary[name + "_raw_model_response"] = (
            result.raw_model_response
        )
        raw_model_results_dictionary[name + "_cost"] = cost
        raw_model_results_dictionary[name + "_one_usd_buys"] = one_usd_buys
        cache_used_dictionary[name] = result.cache_used

    interview_result = Result(
        agent=interview.agent,
        scenario=interview.scenario,
        model=interview.model,
        iteration=interview.iteration,
        answer=answer_dict,
        prompt=prompt_dictionary,
        raw_model_response=raw_model_results_dictionary,
        survey=interview.survey,
        generated_tokens=generated_tokens_dict,
        comments_dict=comments_dict,
        cache_used_dict=cache_used_dictionary,
    )
    interview_result.interview_hash = hash(interview)

    return interview_result
304
-
305
@property
def elapsed_time(self):
    """Seconds since ``self.start_time``, measured with ``time.monotonic``.

    ``start_time`` is assigned when ``run`` starts. Reading this property
    earlier raises AttributeError.
    """
    now = time.monotonic()
    return now - self.start_time
308
-
309
def process_results(
    self, raw_results: Results, cache: Cache, print_exceptions: bool
):
    """Assemble the final Results object and optionally report exceptions.

    Ordering and attachments:
    - Results arrive in completion order. They are re-sorted to follow
      ``self.total_interviews``, matched via each result's ``interview_hash``.
    - A TaskHistory over ``self.total_interviews`` and
      ``self.bucket_collection`` are attached to the returned Results.

    When ``print_exceptions`` is true and the Results report unfixed
    exceptions:
    - a summary is printed to stderr;
    - an HTML report is produced via ``task_history.html``, with
      ``open_in_browser`` taken from ``EDSL_OPEN_EXCEPTION_REPORT_URL``;
    - if the user's Coop ``edsl_settings`` enable ``remote_logging``, the
      report is pushed with ``HTMLFileStore.push``.

    :param raw_results: the Result objects produced by the run.
    :param cache: the cache to attach to the returned Results.
    :param print_exceptions: whether to report unfixed exceptions.
    :returns: the assembled Results.
    :raises ValueError: if exceptions are being reported and
        ``EDSL_OPEN_EXCEPTION_REPORT_URL`` is neither ``"True"`` nor ``"False"``.
    :raises KeyError: if a result's ``interview_hash`` matches no interview.
    """
    # Rank each distinct interview hash by its first position in
    # total_interviews. This gives the same order as calling list.index on the
    # de-duplicated hashes, but is an O(1) lookup per result instead of O(n).
    hash_rank = {}
    for interview in self.total_interviews:
        hash_rank.setdefault(hash(interview), len(hash_rank))

    task_history = TaskHistory(self.total_interviews, include_traceback=False)

    results = Results(
        survey=self.jobs.survey,
        data=sorted(raw_results, key=lambda item: hash_rank[item.interview_hash]),
        task_history=task_history,
        cache=cache,
    )
    results.bucket_collection = self.bucket_collection

    if results.has_unfixed_exceptions and print_exceptions:
        import sys

        from edsl.config import CONFIG
        from edsl.coop.coop import Coop
        from edsl.scenarios.FileStore import HTMLFileStore

        failed_indices = results.task_history.indices
        msg = f"Exceptions were raised in {len(failed_indices)} out of {len(self.total_interviews)} interviews.\n"
        if len(failed_indices) > 5:
            msg += f"Exceptions were raised in the following interviews: {failed_indices}.\n"
        print(msg, file=sys.stderr)

        open_report_setting = CONFIG.get("EDSL_OPEN_EXCEPTION_REPORT_URL")
        if open_report_setting == "True":
            open_in_browser = True
        elif open_report_setting == "False":
            open_in_browser = False
        else:
            # ValueError subclasses Exception, so existing `except Exception` still works.
            raise ValueError(
                "EDSL_OPEN_EXCEPTION_REPORT_URL must be either True or False, "
                f"got {open_report_setting!r}"
            )

        filepath = results.task_history.html(
            cta="Open report to see details.",
            open_in_browser=open_in_browser,
            return_link=True,
        )

        # Remote logging is best-effort: any failure reading the Coop settings
        # is printed and simply disables it.
        try:
            coop = Coop()
            user_edsl_settings = coop.edsl_settings
            remote_logging = user_edsl_settings["remote_logging"]
        except Exception as e:
            print(e)
            remote_logging = False

        if remote_logging:
            filestore = HTMLFileStore(filepath)
            coop_details = filestore.push(description="Error report")
            print(coop_details)

        print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")

    return results
378
-
379
@jupyter_nb_handler
async def run(
    self,
    cache: Union[Cache, False, None],
    n: int = 1,
    stop_on_exception: bool = False,
    progress_bar: bool = False,
    sidecar_model: Optional[LanguageModel] = None,
    jobs_runner_status: Optional[Type[JobsRunnerStatusBase]] = None,
    job_uuid: Optional[UUID] = None,
    print_exceptions: bool = True,
    raise_validation_errors: bool = False,
) -> "Coroutine":
    """Runs a collection of interviews, handling both async and sync contexts.

    Builds the jobs-runner status object (``jobs_runner_status`` class if
    given, else ``JobsRunnerStatus``), optionally starts a progress-bar thread,
    collects every result from ``run_async_generator`` into ``self.results``
    inside ``with cache``, then returns ``self.process_results(...)``.

    Exceptions raised while collecting results are re-raised only when
    ``stop_on_exception`` is true; otherwise they are swallowed here.
    """
    self.results = []
    self.start_time = time.monotonic()
    self.completed = False
    self.cache = cache
    self.sidecar_model = sidecar_model

    from edsl.coop import Coop

    coop = Coop()
    endpoint_url = coop.get_progress_bar_url()

    if jobs_runner_status is not None:
        self.jobs_runner_status = jobs_runner_status(
            self, n=n, endpoint_url=endpoint_url, job_uuid=job_uuid
        )
    else:
        self.jobs_runner_status = JobsRunnerStatus(
            self, n=n, endpoint_url=endpoint_url, job_uuid=job_uuid
        )

    stop_event = threading.Event()

    async def process_results(cache):
        """Processes results from interviews."""
        async for result in self.run_async_generator(
            n=n,
            stop_on_exception=stop_on_exception,
            cache=cache,
            sidecar_model=sidecar_model,
            raise_validation_errors=raise_validation_errors,
        ):
            self.results.append(result)
        self.completed = True

    def run_progress_bar(stop_event):
        """Runs the progress bar in a separate thread."""
        self.jobs_runner_status.update_progress(stop_event)

    # Bind the thread name up front and decide only once whether a progress
    # bar runs, so the cleanup in ``finally`` can never hit an unbound name
    # or disagree with the setup decision.
    progress_thread = None
    if progress_bar and self.jobs_runner_status.has_ep_api_key():
        self.jobs_runner_status.setup()
        progress_thread = threading.Thread(
            target=run_progress_bar, args=(stop_event,)
        )
        progress_thread.start()
    elif progress_bar:
        warnings.warn(
            "You need an Expected Parrot API key to view job progress bars."
        )

    exception_to_raise = None
    try:
        with cache as c:
            await process_results(cache=c)
    except KeyboardInterrupt:
        print("Keyboard interrupt received. Stopping gracefully...")
        stop_event.set()
    except Exception as e:
        if stop_on_exception:
            exception_to_raise = e
        stop_event.set()
    finally:
        # Always signal the progress thread to stop, then wait for it.
        stop_event.set()
        if progress_thread is not None:
            progress_thread.join()

    if exception_to_raise is not None:
        raise exception_to_raise

    return self.process_results(
        raw_results=self.results, cache=cache, print_exceptions=print_exceptions
    )
1
+ from __future__ import annotations
2
+ import time
3
+ import asyncio
4
+ import threading
5
+ import warnings
6
+ from typing import TYPE_CHECKING
7
+
8
+ from edsl.results.Results import Results
9
+ from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
10
+ from edsl.jobs.tasks.TaskHistory import TaskHistory
11
+ from edsl.utilities.decorators import jupyter_nb_handler
12
+ from edsl.jobs.async_interview_runner import AsyncInterviewRunner
13
+ from edsl.jobs.data_structures import RunEnvironment, RunParameters, RunConfig
14
+
15
+ if TYPE_CHECKING:
16
+ from edsl.jobs.Jobs import Jobs
17
+
18
+
19
class JobsRunnerAsyncio:
    """A class for running a collection of interviews asynchronously.

    It gets instantiated from a Jobs object.
    The Jobs object is a collection of interviews that are to be run.
    """

    def __init__(self, jobs: "Jobs", environment: RunEnvironment):
        """Store the jobs to run and the environment to run them in.

        :param jobs: the Jobs object whose interviews will be conducted.
        :param environment: run environment; the run methods read its
            ``cache``, ``bucket_collection`` and ``jobs_runner_status``.
        """
        self.jobs = jobs
        self.environment = environment

    def __len__(self):
        """Return ``len(self.jobs)``."""
        return len(self.jobs)

    async def run_async(self, parameters: RunParameters) -> Results:
        """Used for some other modules that have a non-standard way of running interviews.

        Conducts every interview via ``AsyncInterviewRunner``. It starts no
        progress bar and does no exception reporting.

        Note: this replaces ``self.environment.jobs_runner_status`` with a
        fresh ``JobsRunnerStatus`` instance.
        """
        self.environment.jobs_runner_status = JobsRunnerStatus(self, n=parameters.n)
        data = []
        task_history = TaskHistory(include_traceback=False)

        run_config = RunConfig(parameters=parameters, environment=self.environment)
        result_generator = AsyncInterviewRunner(self.jobs, run_config)

        async for result, interview in result_generator.run():
            data.append(result)
            task_history.add_interview(interview)

        return Results(survey=self.jobs.survey, task_history=task_history, data=data)

    def simple_run(self, parameters: "RunParameters | None" = None) -> Results:
        """Synchronously run all interviews via ``run_async`` and return its Results.

        :param parameters: run parameters. When omitted, ``RunParameters()`` is
            used. NOTE(review): this assumes every RunParameters field has a
            default; confirm against ``edsl.jobs.data_structures``.
        """
        if parameters is None:
            parameters = RunParameters()
        # run_async already returns a fully built Results object; wrapping it
        # again would nest a Results inside Results.data.
        return asyncio.run(self.run_async(parameters))

    @jupyter_nb_handler
    async def run(self, parameters: RunParameters) -> Results:
        """Runs a collection of interviews, handling both async and sync contexts.

        Steps:
        1. Builds a jobs-runner status object from
           ``self.environment.jobs_runner_status`` (a class) or the default
           ``JobsRunnerStatus``.
        2. Optionally starts a progress-bar thread.
        3. Conducts all interviews.
        4. Attaches the environment's new-entries cache and bucket collection
           to the Results.
        5. Passes the Results to ``ResultsExceptionsHandler``.

        Exceptions raised while conducting interviews are re-raised only when
        ``parameters.stop_on_exception`` is true; otherwise they are swallowed.
        """
        run_config = RunConfig(parameters=parameters, environment=self.environment)

        self.start_time = time.monotonic()
        self.completed = False

        from edsl.coop import Coop

        coop = Coop()
        endpoint_url = coop.get_progress_bar_url()

        def set_up_jobs_runner_status(jobs_runner_status):
            """Instantiate the given status class, or JobsRunnerStatus if None."""
            if jobs_runner_status is not None:
                return jobs_runner_status(
                    self,
                    n=parameters.n,
                    endpoint_url=endpoint_url,
                    job_uuid=parameters.job_uuid,
                )
            else:
                return JobsRunnerStatus(
                    self,
                    n=parameters.n,
                    endpoint_url=endpoint_url,
                    job_uuid=parameters.job_uuid,
                )

        # NOTE(review): run_config.environment is self.environment, so the
        # status *class* stored there is overwritten by an *instance* here;
        # a second call to run() on the same object would then try to call
        # that instance. Confirm whether repeated runs need to be supported.
        run_config.environment.jobs_runner_status = set_up_jobs_runner_status(
            self.environment.jobs_runner_status
        )

        async def get_results(results) -> None:
            """Conduct the interviews and append each result to ``results``."""
            result_generator = AsyncInterviewRunner(self.jobs, run_config)
            async for result, interview in result_generator.run():
                results.append(result)
                results.task_history.add_interview(interview)

            self.completed = True

        def run_progress_bar(stop_event, jobs_runner_status) -> None:
            """Runs the progress bar in a separate thread."""
            # The status object is passed in explicitly: this class never sets
            # ``self.jobs_runner_status``, so reading it here would raise
            # AttributeError inside the thread.
            jobs_runner_status.update_progress(stop_event)

        def set_up_progress_bar(progress_bar: bool, jobs_runner_status):
            """Start and return the progress-bar thread, or None if none runs."""
            progress_thread = None
            if progress_bar and jobs_runner_status.has_ep_api_key():
                jobs_runner_status.setup()
                progress_thread = threading.Thread(
                    target=run_progress_bar, args=(stop_event, jobs_runner_status)
                )
                progress_thread.start()
            elif progress_bar:
                warnings.warn(
                    "You need an Expected Parrot API key to view job progress bars."
                )
            return progress_thread

        results = Results(
            survey=self.jobs.survey,
            data=[],
            task_history=TaskHistory(),
            cache=self.environment.cache.new_entries_cache(),
        )
        stop_event = threading.Event()
        progress_thread = set_up_progress_bar(
            parameters.progress_bar, run_config.environment.jobs_runner_status
        )

        exception_to_raise = None
        try:
            await get_results(results)
        except KeyboardInterrupt:
            print("Keyboard interrupt received. Stopping gracefully...")
            stop_event.set()
        except Exception as e:
            if parameters.stop_on_exception:
                exception_to_raise = e
            stop_event.set()
        finally:
            # Always signal the progress thread to stop, then wait for it.
            stop_event.set()
            if progress_thread is not None:
                progress_thread.join()

        if exception_to_raise is not None:
            raise exception_to_raise

        results.cache = self.environment.cache.new_entries_cache()
        results.bucket_collection = self.environment.bucket_collection

        from edsl.jobs.results_exceptions_handler import ResultsExceptionsHandler

        results_exceptions_handler = ResultsExceptionsHandler(results, parameters)

        results_exceptions_handler.handle_exceptions()
        return results