edsl 0.1.14__py3-none-any.whl → 0.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (407)
  1. edsl/Base.py +348 -38
  2. edsl/BaseDiff.py +260 -0
  3. edsl/TemplateLoader.py +24 -0
  4. edsl/__init__.py +46 -10
  5. edsl/__version__.py +1 -0
  6. edsl/agents/Agent.py +842 -144
  7. edsl/agents/AgentList.py +521 -25
  8. edsl/agents/Invigilator.py +250 -374
  9. edsl/agents/InvigilatorBase.py +257 -0
  10. edsl/agents/PromptConstructor.py +272 -0
  11. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  12. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  13. edsl/agents/descriptors.py +43 -13
  14. edsl/agents/prompt_helpers.py +129 -0
  15. edsl/agents/question_option_processor.py +172 -0
  16. edsl/auto/AutoStudy.py +130 -0
  17. edsl/auto/StageBase.py +243 -0
  18. edsl/auto/StageGenerateSurvey.py +178 -0
  19. edsl/auto/StageLabelQuestions.py +125 -0
  20. edsl/auto/StagePersona.py +61 -0
  21. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  22. edsl/auto/StagePersonaDimensionValues.py +74 -0
  23. edsl/auto/StagePersonaDimensions.py +69 -0
  24. edsl/auto/StageQuestions.py +74 -0
  25. edsl/auto/SurveyCreatorPipeline.py +21 -0
  26. edsl/auto/utilities.py +218 -0
  27. edsl/base/Base.py +279 -0
  28. edsl/config.py +121 -104
  29. edsl/conversation/Conversation.py +290 -0
  30. edsl/conversation/car_buying.py +59 -0
  31. edsl/conversation/chips.py +95 -0
  32. edsl/conversation/mug_negotiation.py +81 -0
  33. edsl/conversation/next_speaker_utilities.py +93 -0
  34. edsl/coop/CoopFunctionsMixin.py +15 -0
  35. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  36. edsl/coop/PriceFetcher.py +54 -0
  37. edsl/coop/__init__.py +1 -0
  38. edsl/coop/coop.py +1029 -134
  39. edsl/coop/utils.py +131 -0
  40. edsl/data/Cache.py +560 -89
  41. edsl/data/CacheEntry.py +230 -0
  42. edsl/data/CacheHandler.py +168 -0
  43. edsl/data/RemoteCacheSync.py +186 -0
  44. edsl/data/SQLiteDict.py +292 -0
  45. edsl/data/__init__.py +5 -3
  46. edsl/data/orm.py +6 -33
  47. edsl/data_transfer_models.py +74 -27
  48. edsl/enums.py +165 -8
  49. edsl/exceptions/BaseException.py +21 -0
  50. edsl/exceptions/__init__.py +52 -46
  51. edsl/exceptions/agents.py +33 -15
  52. edsl/exceptions/cache.py +5 -0
  53. edsl/exceptions/coop.py +8 -0
  54. edsl/exceptions/general.py +34 -0
  55. edsl/exceptions/inference_services.py +5 -0
  56. edsl/exceptions/jobs.py +15 -0
  57. edsl/exceptions/language_models.py +46 -1
  58. edsl/exceptions/questions.py +80 -5
  59. edsl/exceptions/results.py +16 -5
  60. edsl/exceptions/scenarios.py +29 -0
  61. edsl/exceptions/surveys.py +13 -10
  62. edsl/inference_services/AnthropicService.py +106 -0
  63. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  64. edsl/inference_services/AvailableModelFetcher.py +215 -0
  65. edsl/inference_services/AwsBedrock.py +118 -0
  66. edsl/inference_services/AzureAI.py +215 -0
  67. edsl/inference_services/DeepInfraService.py +18 -0
  68. edsl/inference_services/GoogleService.py +143 -0
  69. edsl/inference_services/GroqService.py +20 -0
  70. edsl/inference_services/InferenceServiceABC.py +80 -0
  71. edsl/inference_services/InferenceServicesCollection.py +138 -0
  72. edsl/inference_services/MistralAIService.py +120 -0
  73. edsl/inference_services/OllamaService.py +18 -0
  74. edsl/inference_services/OpenAIService.py +236 -0
  75. edsl/inference_services/PerplexityService.py +160 -0
  76. edsl/inference_services/ServiceAvailability.py +135 -0
  77. edsl/inference_services/TestService.py +90 -0
  78. edsl/inference_services/TogetherAIService.py +172 -0
  79. edsl/inference_services/data_structures.py +134 -0
  80. edsl/inference_services/models_available_cache.py +118 -0
  81. edsl/inference_services/rate_limits_cache.py +25 -0
  82. edsl/inference_services/registry.py +41 -0
  83. edsl/inference_services/write_available.py +10 -0
  84. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  85. edsl/jobs/Answers.py +21 -20
  86. edsl/jobs/FetchInvigilator.py +47 -0
  87. edsl/jobs/InterviewTaskManager.py +98 -0
  88. edsl/jobs/InterviewsConstructor.py +50 -0
  89. edsl/jobs/Jobs.py +684 -204
  90. edsl/jobs/JobsChecks.py +172 -0
  91. edsl/jobs/JobsComponentConstructor.py +189 -0
  92. edsl/jobs/JobsPrompts.py +270 -0
  93. edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
  94. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  95. edsl/jobs/RequestTokenEstimator.py +30 -0
  96. edsl/jobs/async_interview_runner.py +138 -0
  97. edsl/jobs/buckets/BucketCollection.py +104 -0
  98. edsl/jobs/buckets/ModelBuckets.py +65 -0
  99. edsl/jobs/buckets/TokenBucket.py +283 -0
  100. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  101. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  102. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  103. edsl/jobs/data_structures.py +120 -0
  104. edsl/jobs/decorators.py +35 -0
  105. edsl/jobs/interviews/Interview.py +392 -0
  106. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -0
  107. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -0
  108. edsl/jobs/interviews/InterviewStatistic.py +63 -0
  109. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -0
  110. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -0
  111. edsl/jobs/interviews/InterviewStatusLog.py +92 -0
  112. edsl/jobs/interviews/ReportErrors.py +66 -0
  113. edsl/jobs/interviews/interview_status_enum.py +9 -0
  114. edsl/jobs/jobs_status_enums.py +9 -0
  115. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  116. edsl/jobs/results_exceptions_handler.py +98 -0
  117. edsl/jobs/runners/JobsRunnerAsyncio.py +151 -110
  118. edsl/jobs/runners/JobsRunnerStatus.py +298 -0
  119. edsl/jobs/tasks/QuestionTaskCreator.py +244 -0
  120. edsl/jobs/tasks/TaskCreators.py +64 -0
  121. edsl/jobs/tasks/TaskHistory.py +470 -0
  122. edsl/jobs/tasks/TaskStatusLog.py +23 -0
  123. edsl/jobs/tasks/task_status_enum.py +161 -0
  124. edsl/jobs/tokens/InterviewTokenUsage.py +27 -0
  125. edsl/jobs/tokens/TokenUsage.py +34 -0
  126. edsl/language_models/ComputeCost.py +63 -0
  127. edsl/language_models/LanguageModel.py +507 -386
  128. edsl/language_models/ModelList.py +164 -0
  129. edsl/language_models/PriceManager.py +127 -0
  130. edsl/language_models/RawResponseHandler.py +106 -0
  131. edsl/language_models/RegisterLanguageModelsMeta.py +184 -0
  132. edsl/language_models/__init__.py +1 -8
  133. edsl/language_models/fake_openai_call.py +15 -0
  134. edsl/language_models/fake_openai_service.py +61 -0
  135. edsl/language_models/key_management/KeyLookup.py +63 -0
  136. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  137. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  138. edsl/language_models/key_management/__init__.py +0 -0
  139. edsl/language_models/key_management/models.py +131 -0
  140. edsl/language_models/model.py +256 -0
  141. edsl/language_models/repair.py +109 -41
  142. edsl/language_models/utilities.py +65 -0
  143. edsl/notebooks/Notebook.py +263 -0
  144. edsl/notebooks/NotebookToLaTeX.py +142 -0
  145. edsl/notebooks/__init__.py +1 -0
  146. edsl/prompts/Prompt.py +222 -93
  147. edsl/prompts/__init__.py +1 -1
  148. edsl/questions/ExceptionExplainer.py +77 -0
  149. edsl/questions/HTMLQuestion.py +103 -0
  150. edsl/questions/QuestionBase.py +518 -0
  151. edsl/questions/QuestionBasePromptsMixin.py +221 -0
  152. edsl/questions/QuestionBudget.py +164 -67
  153. edsl/questions/QuestionCheckBox.py +281 -62
  154. edsl/questions/QuestionDict.py +343 -0
  155. edsl/questions/QuestionExtract.py +136 -50
  156. edsl/questions/QuestionFreeText.py +79 -55
  157. edsl/questions/QuestionFunctional.py +138 -41
  158. edsl/questions/QuestionList.py +184 -57
  159. edsl/questions/QuestionMatrix.py +265 -0
  160. edsl/questions/QuestionMultipleChoice.py +293 -69
  161. edsl/questions/QuestionNumerical.py +109 -56
  162. edsl/questions/QuestionRank.py +244 -49
  163. edsl/questions/Quick.py +41 -0
  164. edsl/questions/SimpleAskMixin.py +74 -0
  165. edsl/questions/__init__.py +9 -6
  166. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +153 -38
  167. edsl/questions/compose_questions.py +13 -7
  168. edsl/questions/data_structures.py +20 -0
  169. edsl/questions/decorators.py +21 -0
  170. edsl/questions/derived/QuestionLikertFive.py +28 -26
  171. edsl/questions/derived/QuestionLinearScale.py +41 -28
  172. edsl/questions/derived/QuestionTopK.py +34 -26
  173. edsl/questions/derived/QuestionYesNo.py +40 -27
  174. edsl/questions/descriptors.py +228 -74
  175. edsl/questions/loop_processor.py +149 -0
  176. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  177. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  178. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  179. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  180. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  181. edsl/questions/prompt_templates/question_list.jinja +17 -0
  182. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  183. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  184. edsl/questions/question_base_gen_mixin.py +168 -0
  185. edsl/questions/question_registry.py +130 -46
  186. edsl/questions/register_questions_meta.py +71 -0
  187. edsl/questions/response_validator_abc.py +188 -0
  188. edsl/questions/response_validator_factory.py +34 -0
  189. edsl/questions/settings.py +5 -2
  190. edsl/questions/templates/__init__.py +0 -0
  191. edsl/questions/templates/budget/__init__.py +0 -0
  192. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  193. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  194. edsl/questions/templates/checkbox/__init__.py +0 -0
  195. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  196. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  197. edsl/questions/templates/dict/__init__.py +0 -0
  198. edsl/questions/templates/dict/answering_instructions.jinja +21 -0
  199. edsl/questions/templates/dict/question_presentation.jinja +1 -0
  200. edsl/questions/templates/extract/__init__.py +0 -0
  201. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  202. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  203. edsl/questions/templates/free_text/__init__.py +0 -0
  204. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  205. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  206. edsl/questions/templates/likert_five/__init__.py +0 -0
  207. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  208. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  209. edsl/questions/templates/linear_scale/__init__.py +0 -0
  210. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  211. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  212. edsl/questions/templates/list/__init__.py +0 -0
  213. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  214. edsl/questions/templates/list/question_presentation.jinja +5 -0
  215. edsl/questions/templates/matrix/__init__.py +1 -0
  216. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  217. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  218. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  219. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  220. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  221. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  222. edsl/questions/templates/numerical/__init__.py +0 -0
  223. edsl/questions/templates/numerical/answering_instructions.jinja +7 -0
  224. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  225. edsl/questions/templates/rank/__init__.py +0 -0
  226. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  227. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  228. edsl/questions/templates/top_k/__init__.py +0 -0
  229. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  230. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  231. edsl/questions/templates/yes_no/__init__.py +0 -0
  232. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  233. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  234. edsl/results/CSSParameterizer.py +108 -0
  235. edsl/results/Dataset.py +550 -19
  236. edsl/results/DatasetExportMixin.py +594 -0
  237. edsl/results/DatasetTree.py +295 -0
  238. edsl/results/MarkdownToDocx.py +122 -0
  239. edsl/results/MarkdownToPDF.py +111 -0
  240. edsl/results/Result.py +477 -173
  241. edsl/results/Results.py +987 -269
  242. edsl/results/ResultsExportMixin.py +28 -125
  243. edsl/results/ResultsGGMixin.py +83 -15
  244. edsl/results/TableDisplay.py +125 -0
  245. edsl/results/TextEditor.py +50 -0
  246. edsl/results/__init__.py +1 -1
  247. edsl/results/file_exports.py +252 -0
  248. edsl/results/results_fetch_mixin.py +33 -0
  249. edsl/results/results_selector.py +145 -0
  250. edsl/results/results_tools_mixin.py +98 -0
  251. edsl/results/smart_objects.py +96 -0
  252. edsl/results/table_data_class.py +12 -0
  253. edsl/results/table_display.css +78 -0
  254. edsl/results/table_renderers.py +118 -0
  255. edsl/results/tree_explore.py +115 -0
  256. edsl/scenarios/ConstructDownloadLink.py +109 -0
  257. edsl/scenarios/DocumentChunker.py +102 -0
  258. edsl/scenarios/DocxScenario.py +16 -0
  259. edsl/scenarios/FileStore.py +543 -0
  260. edsl/scenarios/PdfExtractor.py +40 -0
  261. edsl/scenarios/Scenario.py +431 -62
  262. edsl/scenarios/ScenarioHtmlMixin.py +65 -0
  263. edsl/scenarios/ScenarioList.py +1415 -45
  264. edsl/scenarios/ScenarioListExportMixin.py +45 -0
  265. edsl/scenarios/ScenarioListPdfMixin.py +239 -0
  266. edsl/scenarios/__init__.py +2 -0
  267. edsl/scenarios/directory_scanner.py +96 -0
  268. edsl/scenarios/file_methods.py +85 -0
  269. edsl/scenarios/handlers/__init__.py +13 -0
  270. edsl/scenarios/handlers/csv.py +49 -0
  271. edsl/scenarios/handlers/docx.py +76 -0
  272. edsl/scenarios/handlers/html.py +37 -0
  273. edsl/scenarios/handlers/json.py +111 -0
  274. edsl/scenarios/handlers/latex.py +5 -0
  275. edsl/scenarios/handlers/md.py +51 -0
  276. edsl/scenarios/handlers/pdf.py +68 -0
  277. edsl/scenarios/handlers/png.py +39 -0
  278. edsl/scenarios/handlers/pptx.py +105 -0
  279. edsl/scenarios/handlers/py.py +294 -0
  280. edsl/scenarios/handlers/sql.py +313 -0
  281. edsl/scenarios/handlers/sqlite.py +149 -0
  282. edsl/scenarios/handlers/txt.py +33 -0
  283. edsl/scenarios/scenario_join.py +131 -0
  284. edsl/scenarios/scenario_selector.py +156 -0
  285. edsl/shared.py +1 -0
  286. edsl/study/ObjectEntry.py +173 -0
  287. edsl/study/ProofOfWork.py +113 -0
  288. edsl/study/SnapShot.py +80 -0
  289. edsl/study/Study.py +521 -0
  290. edsl/study/__init__.py +4 -0
  291. edsl/surveys/ConstructDAG.py +92 -0
  292. edsl/surveys/DAG.py +92 -11
  293. edsl/surveys/EditSurvey.py +221 -0
  294. edsl/surveys/InstructionHandler.py +100 -0
  295. edsl/surveys/Memory.py +9 -4
  296. edsl/surveys/MemoryManagement.py +72 -0
  297. edsl/surveys/MemoryPlan.py +156 -35
  298. edsl/surveys/Rule.py +221 -74
  299. edsl/surveys/RuleCollection.py +241 -61
  300. edsl/surveys/RuleManager.py +172 -0
  301. edsl/surveys/Simulator.py +75 -0
  302. edsl/surveys/Survey.py +1079 -339
  303. edsl/surveys/SurveyCSS.py +273 -0
  304. edsl/surveys/SurveyExportMixin.py +235 -40
  305. edsl/surveys/SurveyFlowVisualization.py +181 -0
  306. edsl/surveys/SurveyQualtricsImport.py +284 -0
  307. edsl/surveys/SurveyToApp.py +141 -0
  308. edsl/surveys/__init__.py +4 -2
  309. edsl/surveys/base.py +19 -3
  310. edsl/surveys/descriptors.py +17 -6
  311. edsl/surveys/instructions/ChangeInstruction.py +48 -0
  312. edsl/surveys/instructions/Instruction.py +56 -0
  313. edsl/surveys/instructions/InstructionCollection.py +82 -0
  314. edsl/surveys/instructions/__init__.py +0 -0
  315. edsl/templates/error_reporting/base.html +24 -0
  316. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  317. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  318. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  319. edsl/templates/error_reporting/interview_details.html +116 -0
  320. edsl/templates/error_reporting/interviews.html +19 -0
  321. edsl/templates/error_reporting/overview.html +5 -0
  322. edsl/templates/error_reporting/performance_plot.html +2 -0
  323. edsl/templates/error_reporting/report.css +74 -0
  324. edsl/templates/error_reporting/report.html +118 -0
  325. edsl/templates/error_reporting/report.js +25 -0
  326. edsl/tools/__init__.py +1 -0
  327. edsl/tools/clusters.py +192 -0
  328. edsl/tools/embeddings.py +27 -0
  329. edsl/tools/embeddings_plotting.py +118 -0
  330. edsl/tools/plotting.py +112 -0
  331. edsl/tools/summarize.py +18 -0
  332. edsl/utilities/PrettyList.py +56 -0
  333. edsl/utilities/SystemInfo.py +5 -0
  334. edsl/utilities/__init__.py +21 -20
  335. edsl/utilities/ast_utilities.py +3 -0
  336. edsl/utilities/data/Registry.py +2 -0
  337. edsl/utilities/decorators.py +41 -0
  338. edsl/utilities/gcp_bucket/__init__.py +0 -0
  339. edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
  340. edsl/utilities/interface.py +310 -60
  341. edsl/utilities/is_notebook.py +18 -0
  342. edsl/utilities/is_valid_variable_name.py +11 -0
  343. edsl/utilities/naming_utilities.py +263 -0
  344. edsl/utilities/remove_edsl_version.py +24 -0
  345. edsl/utilities/repair_functions.py +28 -0
  346. edsl/utilities/restricted_python.py +70 -0
  347. edsl/utilities/utilities.py +203 -13
  348. edsl-0.1.40.dist-info/METADATA +111 -0
  349. edsl-0.1.40.dist-info/RECORD +362 -0
  350. {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/WHEEL +1 -1
  351. edsl/agents/AgentListExportMixin.py +0 -24
  352. edsl/coop/old.py +0 -31
  353. edsl/data/Database.py +0 -141
  354. edsl/data/crud.py +0 -121
  355. edsl/jobs/Interview.py +0 -417
  356. edsl/jobs/JobsRunner.py +0 -63
  357. edsl/jobs/JobsRunnerStatusMixin.py +0 -115
  358. edsl/jobs/base.py +0 -47
  359. edsl/jobs/buckets.py +0 -166
  360. edsl/jobs/runners/JobsRunnerDryRun.py +0 -19
  361. edsl/jobs/runners/JobsRunnerStreaming.py +0 -54
  362. edsl/jobs/task_management.py +0 -218
  363. edsl/jobs/token_tracking.py +0 -78
  364. edsl/language_models/DeepInfra.py +0 -69
  365. edsl/language_models/OpenAI.py +0 -98
  366. edsl/language_models/model_interfaces/GeminiPro.py +0 -66
  367. edsl/language_models/model_interfaces/LanguageModelOpenAIFour.py +0 -8
  368. edsl/language_models/model_interfaces/LanguageModelOpenAIThreeFiveTurbo.py +0 -8
  369. edsl/language_models/model_interfaces/LlamaTwo13B.py +0 -21
  370. edsl/language_models/model_interfaces/LlamaTwo70B.py +0 -21
  371. edsl/language_models/model_interfaces/Mixtral8x7B.py +0 -24
  372. edsl/language_models/registry.py +0 -81
  373. edsl/language_models/schemas.py +0 -15
  374. edsl/language_models/unused/ReplicateBase.py +0 -83
  375. edsl/prompts/QuestionInstructionsBase.py +0 -6
  376. edsl/prompts/library/agent_instructions.py +0 -29
  377. edsl/prompts/library/agent_persona.py +0 -17
  378. edsl/prompts/library/question_budget.py +0 -26
  379. edsl/prompts/library/question_checkbox.py +0 -32
  380. edsl/prompts/library/question_extract.py +0 -19
  381. edsl/prompts/library/question_freetext.py +0 -14
  382. edsl/prompts/library/question_linear_scale.py +0 -20
  383. edsl/prompts/library/question_list.py +0 -22
  384. edsl/prompts/library/question_multiple_choice.py +0 -44
  385. edsl/prompts/library/question_numerical.py +0 -31
  386. edsl/prompts/library/question_rank.py +0 -21
  387. edsl/prompts/prompt_config.py +0 -33
  388. edsl/prompts/registry.py +0 -185
  389. edsl/questions/Question.py +0 -240
  390. edsl/report/InputOutputDataTypes.py +0 -134
  391. edsl/report/RegressionMixin.py +0 -28
  392. edsl/report/ReportOutputs.py +0 -1228
  393. edsl/report/ResultsFetchMixin.py +0 -106
  394. edsl/report/ResultsOutputMixin.py +0 -14
  395. edsl/report/demo.ipynb +0 -645
  396. edsl/results/ResultsDBMixin.py +0 -184
  397. edsl/surveys/SurveyFlowVisualizationMixin.py +0 -92
  398. edsl/trackers/Tracker.py +0 -91
  399. edsl/trackers/TrackerAPI.py +0 -196
  400. edsl/trackers/TrackerTasks.py +0 -70
  401. edsl/utilities/pastebin.py +0 -141
  402. edsl-0.1.14.dist-info/METADATA +0 -69
  403. edsl-0.1.14.dist-info/RECORD +0 -141
  404. /edsl/{language_models/model_interfaces → inference_services}/__init__.py +0 -0
  405. /edsl/{report/__init__.py → jobs/runners/JobsRunnerStatusData.py} +0 -0
  406. /edsl/{trackers/__init__.py → language_models/ServiceDataSources.py} +0 -0
  407. {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/LICENSE +0 -0
@@ -0,0 +1,211 @@
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import Union, Dict
4
+ from typing import Union, List, Any, Optional
5
+ from threading import RLock
6
+ from edsl.jobs.buckets.TokenBucket import TokenBucket # Original implementation
7
+
8
+
9
def safe_float_for_json(value: float) -> Union[float, str]:
    """Make a number safe for strict JSON serialization.

    Strict JSON has no representation for infinities or NaN, so those values
    are replaced by string markers that ``float()`` parses back
    (``float("infinity")``, ``float("-infinity")``, ``float("nan")``).
    Every other value is returned unchanged.

    Args:
        value: The number to convert.

    Returns:
        ``"infinity"`` for positive infinity, ``"-infinity"`` for negative
        infinity, ``"nan"`` for NaN, otherwise the original value.
    """
    if value != value:  # NaN is the only value not equal to itself
        return "nan"
    if value == float("inf"):
        return "infinity"
    if value == float("-inf"):
        return "-infinity"
    return value
21
+
22
+
23
# ASGI application object; every route in this module is registered on it.
app = FastAPI()

# Process-local registry of live TokenBucket objects, keyed by the
# "<bucket_name>_<bucket_type>" id that create_bucket builds. Nothing is
# persisted, so all buckets disappear when the process exits.
buckets: Dict[str, TokenBucket] = dict()
27
+
28
+
29
class TokenBucketCreate(BaseModel):
    # Request body accepted by POST /bucket.
    # bucket_name and bucket_type together form the server-side bucket id.
    bucket_name: str
    bucket_type: str
    # Forwarded unchanged to TokenBucket; the client labels refill_rate as
    # tokens per second when plotting.
    capacity: Union[int, float]
    refill_rate: Union[int, float]
34
+
35
+
36
@app.get("/buckets")
async def list_buckets(
    bucket_type: Optional[str] = None,
    bucket_name: Optional[str] = None,
    include_logs: bool = False,
):
    """List all buckets and their current status.

    Args:
        bucket_type: If given (non-empty), only buckets of this type are listed.
        bucket_name: If given (non-empty), only buckets with this name are listed.
        include_logs: Whether to include each bucket's ``(timestamp, value)`` log.

    Returns:
        A dict mapping bucket id to a JSON-safe summary of that bucket.
    """
    result = {}

    for bucket_id, bucket in buckets.items():
        # Empty or missing filters match every bucket.
        if bucket_type and bucket.bucket_type != bucket_type:
            continue
        if bucket_name and bucket.bucket_name != bucket_name:
            continue

        bucket_info = {
            "bucket_name": bucket.bucket_name,
            "bucket_type": bucket.bucket_type,
            "tokens": bucket.tokens,
            "capacity": bucket.capacity,
            "refill_rate": bucket.refill_rate,
            "turbo_mode": bucket.turbo_mode,
            "num_requests": bucket.num_requests,
            "num_released": bucket.num_released,
            "tokens_returned": bucket.tokens_returned,
        }
        # Non-finite floats cannot be encoded as strict JSON.
        for key, value in bucket_info.items():
            if isinstance(value, float):
                bucket_info[key] = safe_float_for_json(value)

        if include_logs:
            # Log values need the same treatment as the fields above (the
            # single-bucket status endpoint sanitizes them too). Build a copy
            # so the bucket's own log is never modified.
            bucket_info["log"] = [
                (ts, safe_float_for_json(value)) for ts, value in bucket.log
            ]

        result[bucket_id] = bucket_info

    return result
81
+
82
+
83
@app.post("/bucket/{bucket_id}/add_tokens")
async def add_tokens(bucket_id: str, amount: float):
    """Add tokens to an existing bucket.

    Args:
        bucket_id: Id of the bucket (``f"{bucket_name}_{bucket_type}"``).
        amount: Number of tokens to add; must be a finite number.

    Returns:
        ``{"status": "success", "current_tokens": ...}`` with the bucket's
        token count after the addition, encoded JSON-safely.

    Raises:
        HTTPException: 404 if the bucket does not exist, 400 if ``amount``
            is NaN or infinite.
    """
    if bucket_id not in buckets:
        raise HTTPException(status_code=404, detail="Bucket not found")

    if not isinstance(amount, (int, float)) or amount != amount:  # Check for NaN
        raise HTTPException(status_code=400, detail="Invalid amount specified")

    if amount == float("inf") or amount == float("-inf"):
        raise HTTPException(status_code=400, detail="Amount cannot be infinite")

    bucket = buckets[bucket_id]
    bucket.add_tokens(amount)

    current_tokens = float(bucket.tokens)
    # Positive infinity and large finite values are reported as-is (infinity
    # via safe_float_for_json's string marker) instead of being clamped to a
    # misleading 0. NaN and -inf have no meaningful token count and still fall
    # back to 0.0.
    if current_tokens != current_tokens or current_tokens == float("-inf"):
        current_tokens = 0.0

    return {"status": "success", "current_tokens": safe_float_for_json(current_tokens)}
104
+
105
+
106
@app.post("/bucket")
async def create_bucket(bucket: TokenBucketCreate):
    """Create a token bucket, or report the settings of an existing one.

    The bucket id is ``f"{bucket_name}_{bucket_type}"``. Re-creating an id
    that already exists is not an error: the stored bucket is left untouched
    and its capacity and refill rate are returned so the caller can adopt them.

    Args:
        bucket: Validated request body with name, type, capacity and refill rate.

    Returns:
        ``{"status": "created"}`` for a new bucket, or
        ``{"status": "existing", "bucket": {"capacity": ..., "refill_rate": ...}}``
        with JSON-safe values for an existing one.

    Raises:
        HTTPException: 400 if the capacity or refill rate is NaN or infinite.
    """
    # NaN is the only value that is not equal to itself.
    if (
        not isinstance(bucket.capacity, (int, float))
        or bucket.capacity != bucket.capacity
    ):
        raise HTTPException(status_code=400, detail="Invalid capacity value")
    if (
        not isinstance(bucket.refill_rate, (int, float))
        or bucket.refill_rate != bucket.refill_rate
    ):
        raise HTTPException(status_code=400, detail="Invalid refill rate value")
    # Reject infinities of both signs, matching add_tokens' amount validation.
    infinities = (float("inf"), float("-inf"))
    if bucket.capacity in infinities or bucket.refill_rate in infinities:
        raise HTTPException(status_code=400, detail="Values cannot be infinite")

    bucket_id = f"{bucket.bucket_name}_{bucket.bucket_type}"
    if bucket_id in buckets:
        # Instead of an error, return success with "existing" status and the
        # stored bucket's current settings.
        return {
            "status": "existing",
            "bucket": {
                "capacity": safe_float_for_json(buckets[bucket_id].capacity),
                "refill_rate": safe_float_for_json(buckets[bucket_id].refill_rate),
            },
        }

    buckets[bucket_id] = TokenBucket(
        bucket_name=bucket.bucket_name,
        bucket_type=bucket.bucket_type,
        capacity=bucket.capacity,
        refill_rate=bucket.refill_rate,
    )
    return {"status": "created"}
155
+
156
+
157
@app.post("/bucket/{bucket_id}/get_tokens")
async def get_tokens(bucket_id: str, amount: float, cheat_bucket_capacity: bool = True):
    """Take ``amount`` tokens from a bucket.

    The request completes only after the bucket's async ``get_tokens`` returns
    (presumably once enough tokens are available — see TokenBucket).

    Args:
        bucket_id: Id of the bucket (``f"{bucket_name}_{bucket_type}"``).
        amount: Number of tokens requested; must be a finite number.
        cheat_bucket_capacity: Passed through unchanged to ``TokenBucket.get_tokens``.

    Raises:
        HTTPException: 404 if the bucket does not exist, 400 if ``amount`` is
            NaN or infinite (the same validation ``add_tokens`` applies).
    """
    if bucket_id not in buckets:
        raise HTTPException(status_code=404, detail="Bucket not found")

    if amount != amount:  # NaN is the only value not equal to itself
        raise HTTPException(status_code=400, detail="Invalid amount specified")
    if amount == float("inf") or amount == float("-inf"):
        raise HTTPException(status_code=400, detail="Amount cannot be infinite")

    bucket = buckets[bucket_id]
    await bucket.get_tokens(amount, cheat_bucket_capacity)
    return {"status": "success"}
165
+
166
+
167
@app.post("/bucket/{bucket_id}/turbo_mode/{state}")
async def set_turbo_mode(bucket_id: str, state: bool):
    """Switch turbo mode on or off for a single bucket.

    Args:
        bucket_id: Id of the target bucket.
        state: ``True`` calls the bucket's ``turbo_mode_on``; ``False`` calls
            ``turbo_mode_off``.

    Raises:
        HTTPException: 404 if no bucket has this id.
    """
    if bucket_id not in buckets:
        raise HTTPException(status_code=404, detail="Bucket not found")

    target = buckets[bucket_id]
    toggle = target.turbo_mode_on if state else target.turbo_mode_off
    toggle()
    return {"status": "success"}
178
+
179
+
180
@app.get("/bucket/{bucket_id}/status")
async def get_bucket_status(bucket_id: str):
    """Return a JSON-safe snapshot of one bucket's state, including its log.

    Args:
        bucket_id: Id of the bucket (``f"{bucket_name}_{bucket_type}"``).

    Returns:
        A dict of the bucket's counters and settings plus ``"log"``, a list of
        ``(timestamp, value)`` pairs. Non-finite floats are replaced by
        ``safe_float_for_json`` string markers.

    Raises:
        HTTPException: 404 if the bucket does not exist.
    """
    if bucket_id not in buckets:
        raise HTTPException(status_code=404, detail="Bucket not found")

    bucket = buckets[bucket_id]
    status = {
        "tokens": bucket.tokens,
        "capacity": bucket.capacity,
        "refill_rate": bucket.refill_rate,
        "turbo_mode": bucket.turbo_mode,
        "num_requests": bucket.num_requests,
        "num_released": bucket.num_released,
        "tokens_returned": bucket.tokens_returned,
        # Build a new list rather than rewriting bucket.log in place. The
        # response must not permanently replace the bucket's own log values
        # with JSON string markers.
        "log": [(ts, safe_float_for_json(value)) for ts, value in bucket.log],
    }
    for key, value in status.items():
        if isinstance(value, float):
            status[key] = safe_float_for_json(value)

    return status
206
+
207
+
208
if __name__ == "__main__":
    # Serve the API directly when this module is executed as a script.
    # NOTE(review): this binds all interfaces on port 8001, while
    # TokenBucketClient defaults to http://localhost:8000 — confirm which
    # port is intended, and that exposing this unauthenticated API is OK.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8001)
@@ -0,0 +1,191 @@
1
+ from typing import Union, Optional
2
+ import asyncio
3
+ import time
4
+ import aiohttp
5
+
6
+
7
class TokenBucketClient:
    """REST API client version of TokenBucket that maintains the same interface
    by delegating to a server running the original TokenBucket implementation.

    The server addresses the bucket as ``f"{bucket_name}_{bucket_type}"``.
    The synchronous members (the constructor, ``turbo_mode_on``/``turbo_mode_off``,
    ``tokens``, ``wait_time``, ``get_throughput``, ``visualize``) drive the
    async HTTP helpers through ``_run_sync``. They therefore also work when
    called from code that is already running inside an event loop.
    """

    def __init__(
        self,
        *,
        bucket_name: str,
        bucket_type: str,
        capacity: Union[int, float],
        refill_rate: Union[int, float],
        api_base_url: str = "http://localhost:8000",
    ):
        self.bucket_name = bucket_name
        self.bucket_type = bucket_type
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.api_base_url = api_base_url
        self.bucket_id = f"{bucket_name}_{bucket_type}"

        # Initialize the bucket on the server. If it already exists, the
        # server's capacity and refill rate replace the values given above.
        self._run_sync(self._create_bucket())

        # Cache some values locally
        self.creation_time = time.monotonic()
        self.turbo_mode = False

    @staticmethod
    def _run_sync(coro):
        """Run ``coro`` to completion from synchronous code and return its result.

        ``asyncio.run`` raises ``RuntimeError`` if the calling thread already
        has a running event loop. In that case the coroutine runs on a fresh
        loop in a short-lived worker thread instead. Either way the caller
        blocks until the coroutine finishes (in the second case this also
        blocks the caller's event loop for that time).
        """
        try:
            asyncio.get_running_loop()
            in_running_loop = True
        except RuntimeError:
            in_running_loop = False

        if not in_running_loop:
            return asyncio.run(coro)

        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()

    async def _create_bucket(self):
        """Create the bucket on the server, adopting its settings if it exists.

        Raises:
            ValueError: If the server responds with a non-200 status.
        """
        async with aiohttp.ClientSession() as session:
            payload = {
                "bucket_name": self.bucket_name,
                "bucket_type": self.bucket_type,
                "capacity": self.capacity,
                "refill_rate": self.refill_rate,
            }
            async with session.post(
                f"{self.api_base_url}/bucket",
                json=payload,
            ) as response:
                if response.status != 200:
                    raise ValueError(f"Unexpected error: {await response.text()}")

                result = await response.json()
                if result["status"] == "existing":
                    # Update our local values to match the existing bucket;
                    # float() also parses string markers such as "infinity".
                    self.capacity = float(result["bucket"]["capacity"])
                    self.refill_rate = float(result["bucket"]["refill_rate"])

    def turbo_mode_on(self):
        """Set the refill rate to infinity (on the server-side bucket)."""
        self._run_sync(self._set_turbo_mode(True))
        self.turbo_mode = True

    def turbo_mode_off(self):
        """Restore the refill rate to its original value (on the server-side bucket)."""
        self._run_sync(self._set_turbo_mode(False))
        self.turbo_mode = False

    async def add_tokens(self, amount: Union[int, float]):
        """Add tokens to the bucket.

        Raises:
            ValueError: If the server responds with a non-200 status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.api_base_url}/bucket/{self.bucket_id}/add_tokens",
                params={"amount": amount},
            ) as response:
                if response.status != 200:
                    raise ValueError(f"Failed to add tokens: {await response.text()}")

    async def _set_turbo_mode(self, state: bool):
        """POST the turbo-mode state ("true"/"false") to the server.

        Raises:
            ValueError: If the server responds with a non-200 status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.api_base_url}/bucket/{self.bucket_id}/turbo_mode/{str(state).lower()}"
            ) as response:
                if response.status != 200:
                    raise ValueError(
                        f"Failed to set turbo mode: {await response.text()}"
                    )

    async def get_tokens(
        self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
    ) -> None:
        """Request ``amount`` tokens; returns once the server call completes.

        Raises:
            ValueError: If the server responds with a non-200 status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.api_base_url}/bucket/{self.bucket_id}/get_tokens",
                params={
                    "amount": amount,
                    # Sent as 0/1 so the query string holds a plain integer.
                    "cheat_bucket_capacity": int(cheat_bucket_capacity),
                },
            ) as response:
                if response.status != 200:
                    raise ValueError(f"Failed to get tokens: {await response.text()}")

    def get_throughput(self, time_window: Optional[float] = None) -> float:
        """Return the server's ``num_released`` divided by elapsed minutes.

        Args:
            time_window: If given, elapsed time is measured over the last
                ``time_window`` seconds (but never from before this client was
                created); otherwise it is measured since the client was created.

        NOTE(review): ``num_released`` appears to be the server's cumulative
        counter, not one restricted to ``time_window``, so windowed results
        may overstate throughput — confirm intent.
        """
        status = self._run_sync(self._get_status())
        now = time.monotonic()

        if time_window is None:
            start_time = self.creation_time
        else:
            start_time = max(now - time_window, self.creation_time)

        # Avoid dividing by zero, and apply the same per-minute scaling in
        # every case.
        elapsed_time = (now - start_time) or 0.001

        return (status["num_released"] / elapsed_time) * 60

    async def _get_status(self) -> dict:
        """Fetch the bucket's status dict from the server.

        Raises:
            ValueError: If the server responds with a non-200 status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{self.api_base_url}/bucket/{self.bucket_id}/status"
            ) as response:
                if response.status != 200:
                    raise ValueError(
                        f"Failed to get bucket status: {await response.text()}"
                    )
                return await response.json()

    def __add__(self, other) -> "TokenBucketClient":
        """Combine two token buckets.

        NOTE(review): the result uses this client's name and type, i.e. the
        same server-side bucket id. If that bucket already exists,
        ``_create_bucket`` replaces the min() values below with the server's
        settings — confirm this is intended.
        """
        return TokenBucketClient(
            bucket_name=self.bucket_name,
            bucket_type=self.bucket_type,
            capacity=min(self.capacity, other.capacity),
            refill_rate=min(self.refill_rate, other.refill_rate),
            api_base_url=self.api_base_url,
        )

    @property
    def tokens(self) -> float:
        """Get the number of tokens remaining in the bucket (one server request)."""
        status = self._run_sync(self._get_status())
        return float(status["tokens"])

    def wait_time(self, requested_tokens: Union[float, int]) -> float:
        """Calculate the time to wait for the requested number of tokens.

        The token count is fetched from the server once, so the comparison and
        the arithmetic use the same snapshot. Two separate fetches could
        disagree and produce a negative wait.

        Raises:
            ValueError: If the wait cannot be computed (e.g. a zero refill rate).
        """
        available = self.tokens
        if available >= float(requested_tokens):
            return 0.0
        try:
            return (requested_tokens - available) / self.refill_rate
        except Exception as e:
            raise ValueError(f"Error calculating wait time: {e}") from e

    def visualize(self):
        """Plot the bucket's logged token values over time with matplotlib.

        Raises:
            ValueError: If the server has no log entries for this bucket yet.
        """
        status = self._run_sync(self._get_status())
        log = status["log"]
        if not log:
            raise ValueError(f"No log entries available for bucket {self.bucket_id}")
        times, tokens = zip(*log)
        start_time = times[0]
        times = [t - start_time for t in times]
        # The server encodes non-finite values as strings such as "infinity";
        # float() turns them back into numbers that can be plotted.
        tokens = [float(v) for v in tokens]

        from matplotlib import pyplot as plt

        plt.figure(figsize=(10, 6))
        plt.plot(times, tokens, label="Tokens Available")
        plt.xlabel("Time (seconds)", fontsize=12)
        plt.ylabel("Number of Tokens", fontsize=12)
        details = f"{self.bucket_name} ({self.bucket_type}) Bucket Usage Over Time\nCapacity: {self.capacity:.1f}, Refill Rate: {self.refill_rate:.1f}/second"
        plt.title(details, fontsize=14)
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()
177
+
178
+
179
if __name__ == "__main__":
    # Run any doctests in this module when it is executed directly.
    from doctest import testmod

    testmod()

    # Manual smoke test against a running bucket server (kept for reference):
    # bucket = TokenBucketClient(
    #     bucket_name="test", bucket_type="test", capacity=100, refill_rate=10
    # )
    # asyncio.run(bucket.get_tokens(50))
    # time.sleep(1)  # Wait for 1 second
    # asyncio.run(bucket.get_tokens(30))
    # throughput = bucket.get_throughput(1)
    # print(throughput)
    # bucket.visualize()
@@ -0,0 +1,85 @@
1
+ import warnings
2
+ from typing import TYPE_CHECKING
3
+
4
+ if TYPE_CHECKING:
5
+ from edsl.surveys.Survey import Survey
6
+ from edsl.scenarios.ScenarioList import ScenarioList
7
+
8
+
9
class CheckSurveyScenarioCompatibility:
    """Check that a survey's template parameters agree with a scenario list's keys."""

    def __init__(self, survey: "Survey", scenarios: "ScenarioList"):
        self.survey = survey
        self.scenarios = scenarios

    def check(self, strict: bool = False, warn: bool = False) -> None:
        """Check if the parameters in the survey and scenarios are consistent.

        A scenario key that equals a question name always raises ValueError.
        Parameters present on only one side raise ValueError when ``strict``
        is True, emit a ``UserWarning`` when ``warn`` is True, and are
        otherwise ignored. If the scenarios contain Jinja braces, a warning is
        emitted and ``self.scenarios`` is replaced with the converted copy.

        >>> from edsl.jobs.Jobs import Jobs
        >>> from edsl.questions.QuestionFreeText import QuestionFreeText
        >>> from edsl.surveys.Survey import Survey
        >>> from edsl.scenarios.Scenario import Scenario
        >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
        >>> j = Jobs(survey = Survey(questions=[q]))
        >>> cs = CheckSurveyScenarioCompatibility(j.survey, j.scenarios)
        >>> with warnings.catch_warnings(record=True) as w:
        ...     cs.check(warn = True)
        ...     assert len(w) == 1
        ...     assert issubclass(w[-1].category, UserWarning)
        ...     assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)

        >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
        >>> s = Scenario({'plop': "A", 'poo': "B"})
        >>> j = Jobs(survey = Survey(questions=[q])).by(s)
        >>> cs = CheckSurveyScenarioCompatibility(j.survey, j.scenarios)
        >>> cs.check(strict = True)
        Traceback (most recent call last):
        ...
        ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}

        >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
        >>> s = Scenario({'ugly_question': "B"})
        >>> from edsl.scenarios.ScenarioList import ScenarioList
        >>> cs = CheckSurveyScenarioCompatibility(Survey(questions=[q]), ScenarioList([s]))
        >>> cs.check()
        Traceback (most recent call last):
        ...
        ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
        """
        # Read each `parameters` property once and coerce to sets so the set
        # arithmetic below also works if a side yields another iterable.
        survey_parameters: set = set(self.survey.parameters)
        scenario_parameters: set = set(self.scenarios.parameters)

        # A scenario key that shadows a question name is always an error.
        if intersection := scenario_parameters & set(self.survey.question_names):
            raise ValueError(
                f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
            )

        msg1, msg2 = None, None
        if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
            msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
        if in_scenarios_but_not_in_survey := scenario_parameters - survey_parameters:
            msg2 = f"The following parameters are in the scenarios but not in the survey: {in_scenarios_but_not_in_survey}"

        if msg1 or msg2:
            message = "\n".join(filter(None, [msg1, msg2]))
            if strict:
                raise ValueError(message)
            if warn:
                # stacklevel=2 attributes the warning to the caller of check()
                warnings.warn(message, stacklevel=2)

        if self.scenarios.has_jinja_braces:
            warnings.warn(
                "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario.",
                stacklevel=2,
            )
            self.scenarios = self.scenarios._convert_jinja_braces()
80
+
81
+
82
if __name__ == "__main__":
    # Execute the examples embedded in this module's docstrings.
    from doctest import testmod

    testmod()
@@ -0,0 +1,120 @@
1
+ from typing import Optional, Literal
2
+ from dataclasses import dataclass, asdict
3
+
4
+ # from edsl.data_transfer_models import VisibilityType
5
+ from edsl.data.Cache import Cache
6
+ from edsl.jobs.buckets.BucketCollection import BucketCollection
7
+ from edsl.language_models.key_management.KeyLookup import KeyLookup
8
+ from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
9
+
10
# Visibility levels accepted by RunParameters.remote_inference_results_visibility.
VisibilityType = Literal["private", "public", "unlisted"]
from edsl.Base import Base
12
+
13
+
14
@dataclass
class RunEnvironment:
    """Bundle of optional run-time collaborators.

    Every slot defaults to ``None``; RunConfig's ``add_*`` helpers fill
    them in after construction.
    """

    # Field order is part of the positional constructor signature; keep it.
    cache: Optional[Cache] = None
    bucket_collection: Optional[BucketCollection] = None
    key_lookup: Optional[KeyLookup] = None
    jobs_runner_status: Optional["JobsRunnerStatus"] = None
20
+
21
+
22
@dataclass
class RunParameters(Base):
    """Flat, serializable set of options for a run.

    ``to_dict``/``from_dict`` round-trip the fields, and ``code`` renders a
    constructor expression for this class.
    """

    n: int = 1
    progress_bar: bool = False
    stop_on_exception: bool = False
    check_api_keys: bool = False
    verbose: bool = True
    print_exceptions: bool = True
    remote_cache_description: Optional[str] = None
    remote_inference_description: Optional[str] = None
    remote_inference_results_visibility: Optional[VisibilityType] = "unlisted"
    skip_retry: bool = False
    raise_validation_errors: bool = False
    disable_remote_cache: bool = False
    disable_remote_inference: bool = False
    job_uuid: Optional[str] = None

    def to_dict(self, add_edsl_version=False) -> dict:
        """Return the fields as a dict.

        Args:
            add_edsl_version: if True, also include ``edsl_version`` and
                ``edsl_class_name`` metadata keys.
        """
        d = asdict(self)
        if add_edsl_version:
            from edsl import __version__

            d["edsl_version"] = __version__
            # record the actual class so the metadata matches what from_dict rebuilds
            d["edsl_class_name"] = type(self).__name__
        return d

    @classmethod
    def from_dict(cls, data: dict) -> "RunParameters":
        """Rebuild an instance from ``to_dict`` output.

        The metadata keys added by ``to_dict(add_edsl_version=True)`` are not
        dataclass fields, so they are dropped before calling the constructor.
        """
        field_values = {
            key: value
            for key, value in data.items()
            if key not in ("edsl_version", "edsl_class_name")
        }
        return cls(**field_values)

    def code(self) -> str:
        """Return a Python expression that reconstructs this instance."""
        return f"{type(self).__name__}(**{self.to_dict()})"

    @classmethod
    def example(cls) -> "RunParameters":
        """Return an instance with all default values."""
        return cls()
58
+
59
+
60
@dataclass
class RunConfig:
    """Pair a RunEnvironment with a RunParameters.

    The ``add_*`` helpers mutate this object in place and return None.
    """

    environment: RunEnvironment
    parameters: RunParameters

    def add_environment(self, environment: RunEnvironment):
        """Swap in a whole new environment."""
        self.environment = environment

    def add_bucket_collection(self, bucket_collection: BucketCollection):
        """Store ``bucket_collection`` on the current environment."""
        env = self.environment
        env.bucket_collection = bucket_collection

    def add_cache(self, cache: Cache):
        """Store ``cache`` on the current environment."""
        env = self.environment
        env.cache = cache

    def add_key_lookup(self, key_lookup: KeyLookup):
        """Store ``key_lookup`` on the current environment."""
        env = self.environment
        env.key_lookup = key_lookup
76
+
77
+
78
+ """This module contains the Answers class, which is a helper class to hold the answers to a survey."""
79
+
80
+ from collections import UserDict
81
+ from edsl.data_transfer_models import EDSLResultObjectInput
82
+
83
+
84
class Answers(UserDict):
    """Helper class to hold the answers to a survey, keyed by question name."""

    def add_answer(
        self, response: "EDSLResultObjectInput", question: "QuestionBase"
    ) -> None:
        """Add a response to the answers dictionary.

        Stores ``response.answer`` under ``question.question_name``. When they
        are truthy, it also stores ``response.generated_tokens`` under
        ``<question_name>_generated_tokens`` and ``response.comment`` under
        ``<question_name>_comment``.
        """
        answer = response.answer
        comment = response.comment
        generated_tokens = response.generated_tokens
        name = question.question_name
        if generated_tokens:
            self[name + "_generated_tokens"] = generated_tokens
        self[name] = answer
        if comment:
            self[name + "_comment"] = comment

    def replace_missing_answers_with_none(self, survey: "Survey") -> None:
        """Replace missing answers with None. Answers can be missing if the agent skips a question."""
        for question_name in survey.question_names:
            if question_name not in self:
                self[question_name] = None

    def to_dict(self):
        """Return a dictionary of the answers.

        A shallow copy is returned so that mutating the result does not
        silently change this Answers object.
        """
        return dict(self.data)

    @classmethod
    def from_dict(cls, d):
        """Return an Answers object from a dictionary."""
        return cls(d)
115
+
116
+
117
if __name__ == "__main__":
    # Execute the examples embedded in this module's docstrings.
    from doctest import testmod

    testmod()
@@ -0,0 +1,35 @@
1
+ from functools import wraps
2
+ from threading import RLock
3
+ import inspect
4
+
5
+
6
def synchronized_class(wrapped_class):
    """Class decorator that makes all methods thread-safe.

    A single ``RLock`` is stored on the class as ``_lock``. Because it lives
    on the class, every instance shares it (unless an instance sets its own
    ``_lock``), and re-entrant calls from the same thread are allowed.

    Wrapped: plain functions found on the class, including inherited ones.
    Not wrapped: dunder methods other than ``__getitem__``, ``__setitem__``
    and ``__delitem__``; staticmethods (no instance to lock on); and
    classmethods (``getmembers`` reports them as bound methods).
    """

    def synchronize(method):
        # Wrap one instance method so it runs while holding the instance's lock.
        @wraps(method)
        def synchronized_method(*args, **kwargs):
            instance = args[0]  # first arg is self
            with instance._lock:
                return method(*args, **kwargs)

        return synchronized_method

    # Add a lock to the class
    setattr(wrapped_class, "_lock", RLock())

    for name, method in inspect.getmembers(wrapped_class, inspect.isfunction):
        if name.startswith("__") and name not in (
            "__getitem__",
            "__setitem__",
            "__delitem__",
        ):
            continue
        # getmembers() unwraps staticmethods into plain functions; re-setting
        # them as plain functions would turn them into instance methods and
        # break every call, so leave them untouched.
        if isinstance(inspect.getattr_static(wrapped_class, name), staticmethod):
            continue
        setattr(wrapped_class, name, synchronize(method))

    return wrapped_class