edsl 0.1.15__py3-none-any.whl → 0.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (407)
  1. edsl/Base.py +348 -38
  2. edsl/BaseDiff.py +260 -0
  3. edsl/TemplateLoader.py +24 -0
  4. edsl/__init__.py +45 -10
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +842 -144
  7. edsl/agents/AgentList.py +521 -25
  8. edsl/agents/Invigilator.py +250 -374
  9. edsl/agents/InvigilatorBase.py +257 -0
  10. edsl/agents/PromptConstructor.py +272 -0
  11. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  12. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  13. edsl/agents/descriptors.py +43 -13
  14. edsl/agents/prompt_helpers.py +129 -0
  15. edsl/agents/question_option_processor.py +172 -0
  16. edsl/auto/AutoStudy.py +130 -0
  17. edsl/auto/StageBase.py +243 -0
  18. edsl/auto/StageGenerateSurvey.py +178 -0
  19. edsl/auto/StageLabelQuestions.py +125 -0
  20. edsl/auto/StagePersona.py +61 -0
  21. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  22. edsl/auto/StagePersonaDimensionValues.py +74 -0
  23. edsl/auto/StagePersonaDimensions.py +69 -0
  24. edsl/auto/StageQuestions.py +74 -0
  25. edsl/auto/SurveyCreatorPipeline.py +21 -0
  26. edsl/auto/utilities.py +218 -0
  27. edsl/base/Base.py +279 -0
  28. edsl/config.py +115 -113
  29. edsl/conversation/Conversation.py +290 -0
  30. edsl/conversation/car_buying.py +59 -0
  31. edsl/conversation/chips.py +95 -0
  32. edsl/conversation/mug_negotiation.py +81 -0
  33. edsl/conversation/next_speaker_utilities.py +93 -0
  34. edsl/coop/CoopFunctionsMixin.py +15 -0
  35. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  36. edsl/coop/PriceFetcher.py +54 -0
  37. edsl/coop/__init__.py +1 -0
  38. edsl/coop/coop.py +1029 -134
  39. edsl/coop/utils.py +131 -0
  40. edsl/data/Cache.py +560 -89
  41. edsl/data/CacheEntry.py +230 -0
  42. edsl/data/CacheHandler.py +168 -0
  43. edsl/data/RemoteCacheSync.py +186 -0
  44. edsl/data/SQLiteDict.py +292 -0
  45. edsl/data/__init__.py +5 -3
  46. edsl/data/orm.py +6 -33
  47. edsl/data_transfer_models.py +74 -27
  48. edsl/enums.py +165 -8
  49. edsl/exceptions/BaseException.py +21 -0
  50. edsl/exceptions/__init__.py +52 -46
  51. edsl/exceptions/agents.py +33 -15
  52. edsl/exceptions/cache.py +5 -0
  53. edsl/exceptions/coop.py +8 -0
  54. edsl/exceptions/general.py +34 -0
  55. edsl/exceptions/inference_services.py +5 -0
  56. edsl/exceptions/jobs.py +15 -0
  57. edsl/exceptions/language_models.py +46 -1
  58. edsl/exceptions/questions.py +80 -5
  59. edsl/exceptions/results.py +16 -5
  60. edsl/exceptions/scenarios.py +29 -0
  61. edsl/exceptions/surveys.py +13 -10
  62. edsl/inference_services/AnthropicService.py +106 -0
  63. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  64. edsl/inference_services/AvailableModelFetcher.py +215 -0
  65. edsl/inference_services/AwsBedrock.py +118 -0
  66. edsl/inference_services/AzureAI.py +215 -0
  67. edsl/inference_services/DeepInfraService.py +18 -0
  68. edsl/inference_services/GoogleService.py +143 -0
  69. edsl/inference_services/GroqService.py +20 -0
  70. edsl/inference_services/InferenceServiceABC.py +80 -0
  71. edsl/inference_services/InferenceServicesCollection.py +138 -0
  72. edsl/inference_services/MistralAIService.py +120 -0
  73. edsl/inference_services/OllamaService.py +18 -0
  74. edsl/inference_services/OpenAIService.py +236 -0
  75. edsl/inference_services/PerplexityService.py +160 -0
  76. edsl/inference_services/ServiceAvailability.py +135 -0
  77. edsl/inference_services/TestService.py +90 -0
  78. edsl/inference_services/TogetherAIService.py +172 -0
  79. edsl/inference_services/data_structures.py +134 -0
  80. edsl/inference_services/models_available_cache.py +118 -0
  81. edsl/inference_services/rate_limits_cache.py +25 -0
  82. edsl/inference_services/registry.py +41 -0
  83. edsl/inference_services/write_available.py +10 -0
  84. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  85. edsl/jobs/Answers.py +21 -20
  86. edsl/jobs/FetchInvigilator.py +47 -0
  87. edsl/jobs/InterviewTaskManager.py +98 -0
  88. edsl/jobs/InterviewsConstructor.py +50 -0
  89. edsl/jobs/Jobs.py +684 -206
  90. edsl/jobs/JobsChecks.py +172 -0
  91. edsl/jobs/JobsComponentConstructor.py +189 -0
  92. edsl/jobs/JobsPrompts.py +270 -0
  93. edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
  94. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  95. edsl/jobs/RequestTokenEstimator.py +30 -0
  96. edsl/jobs/async_interview_runner.py +138 -0
  97. edsl/jobs/buckets/BucketCollection.py +104 -0
  98. edsl/jobs/buckets/ModelBuckets.py +65 -0
  99. edsl/jobs/buckets/TokenBucket.py +283 -0
  100. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  101. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  102. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  103. edsl/jobs/data_structures.py +120 -0
  104. edsl/jobs/decorators.py +35 -0
  105. edsl/jobs/interviews/Interview.py +392 -0
  106. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -0
  107. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -0
  108. edsl/jobs/interviews/InterviewStatistic.py +63 -0
  109. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -0
  110. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -0
  111. edsl/jobs/interviews/InterviewStatusLog.py +92 -0
  112. edsl/jobs/interviews/ReportErrors.py +66 -0
  113. edsl/jobs/interviews/interview_status_enum.py +9 -0
  114. edsl/jobs/jobs_status_enums.py +9 -0
  115. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  116. edsl/jobs/results_exceptions_handler.py +98 -0
  117. edsl/jobs/runners/JobsRunnerAsyncio.py +151 -110
  118. edsl/jobs/runners/JobsRunnerStatus.py +298 -0
  119. edsl/jobs/tasks/QuestionTaskCreator.py +244 -0
  120. edsl/jobs/tasks/TaskCreators.py +64 -0
  121. edsl/jobs/tasks/TaskHistory.py +470 -0
  122. edsl/jobs/tasks/TaskStatusLog.py +23 -0
  123. edsl/jobs/tasks/task_status_enum.py +161 -0
  124. edsl/jobs/tokens/InterviewTokenUsage.py +27 -0
  125. edsl/jobs/tokens/TokenUsage.py +34 -0
  126. edsl/language_models/ComputeCost.py +63 -0
  127. edsl/language_models/LanguageModel.py +507 -386
  128. edsl/language_models/ModelList.py +164 -0
  129. edsl/language_models/PriceManager.py +127 -0
  130. edsl/language_models/RawResponseHandler.py +106 -0
  131. edsl/language_models/RegisterLanguageModelsMeta.py +184 -0
  132. edsl/language_models/__init__.py +1 -8
  133. edsl/language_models/fake_openai_call.py +15 -0
  134. edsl/language_models/fake_openai_service.py +61 -0
  135. edsl/language_models/key_management/KeyLookup.py +63 -0
  136. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  137. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  138. edsl/language_models/key_management/__init__.py +0 -0
  139. edsl/language_models/key_management/models.py +131 -0
  140. edsl/language_models/model.py +256 -0
  141. edsl/language_models/repair.py +109 -41
  142. edsl/language_models/utilities.py +65 -0
  143. edsl/notebooks/Notebook.py +263 -0
  144. edsl/notebooks/NotebookToLaTeX.py +142 -0
  145. edsl/notebooks/__init__.py +1 -0
  146. edsl/prompts/Prompt.py +222 -93
  147. edsl/prompts/__init__.py +1 -1
  148. edsl/questions/ExceptionExplainer.py +77 -0
  149. edsl/questions/HTMLQuestion.py +103 -0
  150. edsl/questions/QuestionBase.py +518 -0
  151. edsl/questions/QuestionBasePromptsMixin.py +221 -0
  152. edsl/questions/QuestionBudget.py +164 -67
  153. edsl/questions/QuestionCheckBox.py +281 -62
  154. edsl/questions/QuestionDict.py +343 -0
  155. edsl/questions/QuestionExtract.py +136 -50
  156. edsl/questions/QuestionFreeText.py +79 -55
  157. edsl/questions/QuestionFunctional.py +138 -41
  158. edsl/questions/QuestionList.py +184 -57
  159. edsl/questions/QuestionMatrix.py +265 -0
  160. edsl/questions/QuestionMultipleChoice.py +293 -69
  161. edsl/questions/QuestionNumerical.py +109 -56
  162. edsl/questions/QuestionRank.py +244 -49
  163. edsl/questions/Quick.py +41 -0
  164. edsl/questions/SimpleAskMixin.py +74 -0
  165. edsl/questions/__init__.py +9 -6
  166. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +153 -38
  167. edsl/questions/compose_questions.py +13 -7
  168. edsl/questions/data_structures.py +20 -0
  169. edsl/questions/decorators.py +21 -0
  170. edsl/questions/derived/QuestionLikertFive.py +28 -26
  171. edsl/questions/derived/QuestionLinearScale.py +41 -28
  172. edsl/questions/derived/QuestionTopK.py +34 -26
  173. edsl/questions/derived/QuestionYesNo.py +40 -27
  174. edsl/questions/descriptors.py +228 -74
  175. edsl/questions/loop_processor.py +149 -0
  176. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  177. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  178. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  179. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  180. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  181. edsl/questions/prompt_templates/question_list.jinja +17 -0
  182. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  183. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  184. edsl/questions/question_base_gen_mixin.py +168 -0
  185. edsl/questions/question_registry.py +130 -46
  186. edsl/questions/register_questions_meta.py +71 -0
  187. edsl/questions/response_validator_abc.py +188 -0
  188. edsl/questions/response_validator_factory.py +34 -0
  189. edsl/questions/settings.py +5 -2
  190. edsl/questions/templates/__init__.py +0 -0
  191. edsl/questions/templates/budget/__init__.py +0 -0
  192. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  193. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  194. edsl/questions/templates/checkbox/__init__.py +0 -0
  195. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  196. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  197. edsl/questions/templates/dict/__init__.py +0 -0
  198. edsl/questions/templates/dict/answering_instructions.jinja +21 -0
  199. edsl/questions/templates/dict/question_presentation.jinja +1 -0
  200. edsl/questions/templates/extract/__init__.py +0 -0
  201. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  202. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  203. edsl/questions/templates/free_text/__init__.py +0 -0
  204. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  205. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  206. edsl/questions/templates/likert_five/__init__.py +0 -0
  207. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  208. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  209. edsl/questions/templates/linear_scale/__init__.py +0 -0
  210. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  211. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  212. edsl/questions/templates/list/__init__.py +0 -0
  213. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  214. edsl/questions/templates/list/question_presentation.jinja +5 -0
  215. edsl/questions/templates/matrix/__init__.py +1 -0
  216. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  217. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  218. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  219. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  220. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  221. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  222. edsl/questions/templates/numerical/__init__.py +0 -0
  223. edsl/questions/templates/numerical/answering_instructions.jinja +7 -0
  224. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  225. edsl/questions/templates/rank/__init__.py +0 -0
  226. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  227. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  228. edsl/questions/templates/top_k/__init__.py +0 -0
  229. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  230. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  231. edsl/questions/templates/yes_no/__init__.py +0 -0
  232. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  233. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  234. edsl/results/CSSParameterizer.py +108 -0
  235. edsl/results/Dataset.py +550 -19
  236. edsl/results/DatasetExportMixin.py +594 -0
  237. edsl/results/DatasetTree.py +295 -0
  238. edsl/results/MarkdownToDocx.py +122 -0
  239. edsl/results/MarkdownToPDF.py +111 -0
  240. edsl/results/Result.py +477 -173
  241. edsl/results/Results.py +987 -269
  242. edsl/results/ResultsExportMixin.py +28 -125
  243. edsl/results/ResultsGGMixin.py +83 -15
  244. edsl/results/TableDisplay.py +125 -0
  245. edsl/results/TextEditor.py +50 -0
  246. edsl/results/__init__.py +1 -1
  247. edsl/results/file_exports.py +252 -0
  248. edsl/results/results_fetch_mixin.py +33 -0
  249. edsl/results/results_selector.py +145 -0
  250. edsl/results/results_tools_mixin.py +98 -0
  251. edsl/results/smart_objects.py +96 -0
  252. edsl/results/table_data_class.py +12 -0
  253. edsl/results/table_display.css +78 -0
  254. edsl/results/table_renderers.py +118 -0
  255. edsl/results/tree_explore.py +115 -0
  256. edsl/scenarios/ConstructDownloadLink.py +109 -0
  257. edsl/scenarios/DocumentChunker.py +102 -0
  258. edsl/scenarios/DocxScenario.py +16 -0
  259. edsl/scenarios/FileStore.py +543 -0
  260. edsl/scenarios/PdfExtractor.py +40 -0
  261. edsl/scenarios/Scenario.py +431 -62
  262. edsl/scenarios/ScenarioHtmlMixin.py +65 -0
  263. edsl/scenarios/ScenarioList.py +1415 -45
  264. edsl/scenarios/ScenarioListExportMixin.py +45 -0
  265. edsl/scenarios/ScenarioListPdfMixin.py +239 -0
  266. edsl/scenarios/__init__.py +2 -0
  267. edsl/scenarios/directory_scanner.py +96 -0
  268. edsl/scenarios/file_methods.py +85 -0
  269. edsl/scenarios/handlers/__init__.py +13 -0
  270. edsl/scenarios/handlers/csv.py +49 -0
  271. edsl/scenarios/handlers/docx.py +76 -0
  272. edsl/scenarios/handlers/html.py +37 -0
  273. edsl/scenarios/handlers/json.py +111 -0
  274. edsl/scenarios/handlers/latex.py +5 -0
  275. edsl/scenarios/handlers/md.py +51 -0
  276. edsl/scenarios/handlers/pdf.py +68 -0
  277. edsl/scenarios/handlers/png.py +39 -0
  278. edsl/scenarios/handlers/pptx.py +105 -0
  279. edsl/scenarios/handlers/py.py +294 -0
  280. edsl/scenarios/handlers/sql.py +313 -0
  281. edsl/scenarios/handlers/sqlite.py +149 -0
  282. edsl/scenarios/handlers/txt.py +33 -0
  283. edsl/scenarios/scenario_join.py +131 -0
  284. edsl/scenarios/scenario_selector.py +156 -0
  285. edsl/shared.py +1 -0
  286. edsl/study/ObjectEntry.py +173 -0
  287. edsl/study/ProofOfWork.py +113 -0
  288. edsl/study/SnapShot.py +80 -0
  289. edsl/study/Study.py +521 -0
  290. edsl/study/__init__.py +4 -0
  291. edsl/surveys/ConstructDAG.py +92 -0
  292. edsl/surveys/DAG.py +92 -11
  293. edsl/surveys/EditSurvey.py +221 -0
  294. edsl/surveys/InstructionHandler.py +100 -0
  295. edsl/surveys/Memory.py +9 -4
  296. edsl/surveys/MemoryManagement.py +72 -0
  297. edsl/surveys/MemoryPlan.py +156 -35
  298. edsl/surveys/Rule.py +221 -74
  299. edsl/surveys/RuleCollection.py +241 -61
  300. edsl/surveys/RuleManager.py +172 -0
  301. edsl/surveys/Simulator.py +75 -0
  302. edsl/surveys/Survey.py +1079 -339
  303. edsl/surveys/SurveyCSS.py +273 -0
  304. edsl/surveys/SurveyExportMixin.py +235 -40
  305. edsl/surveys/SurveyFlowVisualization.py +181 -0
  306. edsl/surveys/SurveyQualtricsImport.py +284 -0
  307. edsl/surveys/SurveyToApp.py +141 -0
  308. edsl/surveys/__init__.py +4 -2
  309. edsl/surveys/base.py +19 -3
  310. edsl/surveys/descriptors.py +17 -6
  311. edsl/surveys/instructions/ChangeInstruction.py +48 -0
  312. edsl/surveys/instructions/Instruction.py +56 -0
  313. edsl/surveys/instructions/InstructionCollection.py +82 -0
  314. edsl/surveys/instructions/__init__.py +0 -0
  315. edsl/templates/error_reporting/base.html +24 -0
  316. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  317. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  318. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  319. edsl/templates/error_reporting/interview_details.html +116 -0
  320. edsl/templates/error_reporting/interviews.html +19 -0
  321. edsl/templates/error_reporting/overview.html +5 -0
  322. edsl/templates/error_reporting/performance_plot.html +2 -0
  323. edsl/templates/error_reporting/report.css +74 -0
  324. edsl/templates/error_reporting/report.html +118 -0
  325. edsl/templates/error_reporting/report.js +25 -0
  326. edsl/tools/__init__.py +1 -0
  327. edsl/tools/clusters.py +192 -0
  328. edsl/tools/embeddings.py +27 -0
  329. edsl/tools/embeddings_plotting.py +118 -0
  330. edsl/tools/plotting.py +112 -0
  331. edsl/tools/summarize.py +18 -0
  332. edsl/utilities/PrettyList.py +56 -0
  333. edsl/utilities/SystemInfo.py +5 -0
  334. edsl/utilities/__init__.py +21 -20
  335. edsl/utilities/ast_utilities.py +3 -0
  336. edsl/utilities/data/Registry.py +2 -0
  337. edsl/utilities/decorators.py +41 -0
  338. edsl/utilities/gcp_bucket/__init__.py +0 -0
  339. edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
  340. edsl/utilities/interface.py +310 -60
  341. edsl/utilities/is_notebook.py +18 -0
  342. edsl/utilities/is_valid_variable_name.py +11 -0
  343. edsl/utilities/naming_utilities.py +263 -0
  344. edsl/utilities/remove_edsl_version.py +24 -0
  345. edsl/utilities/repair_functions.py +28 -0
  346. edsl/utilities/restricted_python.py +70 -0
  347. edsl/utilities/utilities.py +203 -13
  348. edsl-0.1.40.dist-info/METADATA +111 -0
  349. edsl-0.1.40.dist-info/RECORD +362 -0
  350. {edsl-0.1.15.dist-info → edsl-0.1.40.dist-info}/WHEEL +1 -1
  351. edsl/agents/AgentListExportMixin.py +0 -24
  352. edsl/coop/old.py +0 -31
  353. edsl/data/Database.py +0 -141
  354. edsl/data/crud.py +0 -121
  355. edsl/jobs/Interview.py +0 -435
  356. edsl/jobs/JobsRunner.py +0 -63
  357. edsl/jobs/JobsRunnerStatusMixin.py +0 -115
  358. edsl/jobs/base.py +0 -47
  359. edsl/jobs/buckets.py +0 -178
  360. edsl/jobs/runners/JobsRunnerDryRun.py +0 -19
  361. edsl/jobs/runners/JobsRunnerStreaming.py +0 -54
  362. edsl/jobs/task_management.py +0 -215
  363. edsl/jobs/token_tracking.py +0 -78
  364. edsl/language_models/DeepInfra.py +0 -69
  365. edsl/language_models/OpenAI.py +0 -98
  366. edsl/language_models/model_interfaces/GeminiPro.py +0 -66
  367. edsl/language_models/model_interfaces/LanguageModelOpenAIFour.py +0 -8
  368. edsl/language_models/model_interfaces/LanguageModelOpenAIThreeFiveTurbo.py +0 -8
  369. edsl/language_models/model_interfaces/LlamaTwo13B.py +0 -21
  370. edsl/language_models/model_interfaces/LlamaTwo70B.py +0 -21
  371. edsl/language_models/model_interfaces/Mixtral8x7B.py +0 -24
  372. edsl/language_models/registry.py +0 -81
  373. edsl/language_models/schemas.py +0 -15
  374. edsl/language_models/unused/ReplicateBase.py +0 -83
  375. edsl/prompts/QuestionInstructionsBase.py +0 -6
  376. edsl/prompts/library/agent_instructions.py +0 -29
  377. edsl/prompts/library/agent_persona.py +0 -17
  378. edsl/prompts/library/question_budget.py +0 -26
  379. edsl/prompts/library/question_checkbox.py +0 -32
  380. edsl/prompts/library/question_extract.py +0 -19
  381. edsl/prompts/library/question_freetext.py +0 -14
  382. edsl/prompts/library/question_linear_scale.py +0 -20
  383. edsl/prompts/library/question_list.py +0 -22
  384. edsl/prompts/library/question_multiple_choice.py +0 -44
  385. edsl/prompts/library/question_numerical.py +0 -31
  386. edsl/prompts/library/question_rank.py +0 -21
  387. edsl/prompts/prompt_config.py +0 -33
  388. edsl/prompts/registry.py +0 -185
  389. edsl/questions/Question.py +0 -240
  390. edsl/report/InputOutputDataTypes.py +0 -134
  391. edsl/report/RegressionMixin.py +0 -28
  392. edsl/report/ReportOutputs.py +0 -1228
  393. edsl/report/ResultsFetchMixin.py +0 -106
  394. edsl/report/ResultsOutputMixin.py +0 -14
  395. edsl/report/demo.ipynb +0 -645
  396. edsl/results/ResultsDBMixin.py +0 -184
  397. edsl/surveys/SurveyFlowVisualizationMixin.py +0 -92
  398. edsl/trackers/Tracker.py +0 -91
  399. edsl/trackers/TrackerAPI.py +0 -196
  400. edsl/trackers/TrackerTasks.py +0 -70
  401. edsl/utilities/pastebin.py +0 -141
  402. edsl-0.1.15.dist-info/METADATA +0 -69
  403. edsl-0.1.15.dist-info/RECORD +0 -142
  404. /edsl/{language_models/model_interfaces → inference_services}/__init__.py +0 -0
  405. /edsl/{report/__init__.py → jobs/runners/JobsRunnerStatusData.py} +0 -0
  406. /edsl/{trackers/__init__.py → language_models/ServiceDataSources.py} +0 -0
  407. {edsl-0.1.15.dist-info → edsl-0.1.40.dist-info}/LICENSE +0 -0
edsl/coop/coop.py CHANGED
@@ -1,32 +1,101 @@
1
+ import aiohttp
1
2
  import json
2
3
  import requests
3
- from typing import Any, Optional, Type, Union
4
- from edsl import CONFIG
5
- from edsl.questions import Question
6
- from edsl.surveys import Survey
7
4
 
5
+ from typing import Any, Optional, Union, Literal, TypedDict
6
+ from uuid import UUID
7
+ from collections import UserDict, defaultdict
8
8
 
9
- api_url = {
10
- "development": "http://127.0.0.1:8000",
11
- "production": "https://www.expectedparrot.com",
12
- }
9
+ import edsl
10
+ from pathlib import Path
13
11
 
12
+ from edsl.config import CONFIG
13
+ from edsl.data.CacheEntry import CacheEntry
14
+ from edsl.jobs.Jobs import Jobs
15
+ from edsl.surveys.Survey import Survey
14
16
 
15
- class Coop:
16
- def __init__(self, api_key: str = None, run_mode: str = None) -> None:
17
- self.api_key = api_key or CONFIG.EXPECTED_PARROT_API_KEY
18
- self.run_mode = run_mode or CONFIG.EDSL_RUN_MODE
17
+ from edsl.exceptions.coop import CoopNoUUIDError, CoopServerResponseError
18
+ from edsl.coop.utils import (
19
+ EDSLObject,
20
+ ObjectRegistry,
21
+ ObjectType,
22
+ RemoteJobStatus,
23
+ VisibilityType,
24
+ )
19
25
 
20
- def __repr__(self):
21
- return f"Client(api_key='{self.api_key}', run_mode='{self.run_mode}')"
26
+ from edsl.coop.CoopFunctionsMixin import CoopFunctionsMixin
27
+ from edsl.coop.ExpectedParrotKeyHandler import ExpectedParrotKeyHandler
28
+
29
+ from edsl.inference_services.data_structures import ServiceToModelsMapping
30
+
31
+
32
+ class RemoteInferenceResponse(TypedDict):
33
+ job_uuid: str
34
+ results_uuid: str
35
+ results_url: str
36
+ latest_error_report_uuid: str
37
+ latest_error_report_url: str
38
+ status: str
39
+ reason: str
40
+ credits_consumed: float
41
+ version: str
42
+
43
+
44
+ class RemoteInferenceCreationInfo(TypedDict):
45
+ uuid: str
46
+ description: str
47
+ status: str
48
+ iterations: int
49
+ visibility: str
50
+ version: str
22
51
 
23
- @property
24
- def headers(self) -> dict:
25
- return {"Authorization": f"Bearer {self.api_key}"}
26
52
 
53
+ class Coop(CoopFunctionsMixin):
54
+ """
55
+ Client for the Expected Parrot API.
56
+ """
57
+
58
+ def __init__(
59
+ self, api_key: Optional[str] = None, url: Optional[str] = None
60
+ ) -> None:
61
+ """
62
+ Initialize the client.
63
+ - Provide an API key directly, or through an env variable.
64
+ - Provide a URL directly, or use the default one.
65
+ """
66
+ self.ep_key_handler = ExpectedParrotKeyHandler()
67
+ self.api_key = api_key or self.ep_key_handler.get_ep_api_key()
68
+
69
+ self.url = url or CONFIG.EXPECTED_PARROT_URL
70
+ if self.url.endswith("/"):
71
+ self.url = self.url[:-1]
72
+ if "chick.expectedparrot" in self.url:
73
+ self.api_url = "https://chickapi.expectedparrot.com"
74
+ elif "expectedparrot" in self.url:
75
+ self.api_url = "https://api.expectedparrot.com"
76
+ elif "localhost:1234" in self.url:
77
+ self.api_url = "http://localhost:8000"
78
+ else:
79
+ self.api_url = self.url
80
+ self._edsl_version = edsl.__version__
81
+
82
+ def get_progress_bar_url(self):
83
+ return f"{CONFIG.EXPECTED_PARROT_URL}"
84
+
85
+ ################
86
+ # BASIC METHODS
87
+ ################
27
88
  @property
28
- def url(self) -> str:
29
- return api_url[self.run_mode]
89
+ def headers(self) -> dict:
90
+ """
91
+ Return the headers for the request.
92
+ """
93
+ headers = {}
94
+ if self.api_key:
95
+ headers["Authorization"] = f"Bearer {self.api_key}"
96
+ else:
97
+ headers["Authorization"] = f"Bearer None"
98
+ return headers
30
99
 
31
100
  def _send_server_request(
32
101
  self,
@@ -34,181 +103,1007 @@ class Coop:
34
103
  method: str,
35
104
  payload: Optional[dict[str, Any]] = None,
36
105
  params: Optional[dict[str, Any]] = None,
106
+ timeout: Optional[float] = 5,
37
107
  ) -> requests.Response:
38
- """Sends a request to the server and returns the response."""
39
- url = f"{self.url}/{uri}"
108
+ """
109
+ Send a request to the server and return the response.
110
+ """
111
+ url = f"{self.api_url}/{uri}"
112
+ method = method.upper()
113
+ if payload is None:
114
+ timeout = 40
115
+ elif (
116
+ method.upper() == "POST"
117
+ and "json_string" in payload
118
+ and payload.get("json_string") is not None
119
+ ):
120
+ timeout = max(40, (len(payload.get("json_string", "")) // (1024 * 1024)))
121
+ try:
122
+ if method in ["GET", "DELETE"]:
123
+ response = requests.request(
124
+ method, url, params=params, headers=self.headers, timeout=timeout
125
+ )
126
+ elif method in ["POST", "PATCH"]:
127
+ response = requests.request(
128
+ method,
129
+ url,
130
+ params=params,
131
+ json=payload,
132
+ headers=self.headers,
133
+ timeout=timeout,
134
+ )
135
+ else:
136
+ raise Exception(f"Invalid {method=}.")
137
+ except requests.ConnectionError:
138
+ raise requests.ConnectionError(f"Could not connect to the server at {url}.")
40
139
 
41
- if method.upper() in ["GET", "DELETE"]:
42
- response = requests.request(
43
- method, url, params=params, headers=self.headers
44
- )
140
+ return response
141
+
142
+ def _get_latest_stable_version(self, version: str) -> str:
143
+ """
144
+ Extract the latest stable PyPI version from a version string.
145
+
146
+ Examples:
147
+ - Decrement the patch number of a dev version: "0.1.38.dev1" -> "0.1.37"
148
+ - Return a stable version as is: "0.1.37" -> "0.1.37"
149
+ """
150
+ if "dev" not in version:
151
+ return version
45
152
  else:
46
- response = requests.request(method, url, json=payload, headers=self.headers)
153
+ # For 0.1.38.dev1, split into ["0", "1", "38", "dev1"]
154
+ major, minor, patch = version.split(".")[:3]
47
155
 
48
- return response
156
+ current_patch = int(patch)
157
+ latest_patch = current_patch - 1
158
+ return f"{major}.{minor}.{latest_patch}"
159
+
160
+ def _user_version_is_outdated(
161
+ self, user_version_str: str, server_version_str: str
162
+ ) -> bool:
163
+ """
164
+ Check if the user's EDSL version is outdated compared to the server's.
165
+ """
166
+ server_stable_version_str = self._get_latest_stable_version(server_version_str)
167
+ user_stable_version_str = self._get_latest_stable_version(user_version_str)
168
+
169
+ # Turn the version strings into tuples of ints for comparison
170
+ user_stable_version = tuple(map(int, user_stable_version_str.split(".")))
171
+ server_stable_version = tuple(map(int, server_stable_version_str.split(".")))
172
+
173
+ return user_stable_version < server_stable_version
174
+
175
+ def _resolve_server_response(
176
+ self, response: requests.Response, check_api_key: bool = True
177
+ ) -> None:
178
+ """
179
+ Check the response from the server and raise errors as appropriate.
180
+ """
181
+ # Get EDSL version from header
182
+ # breakpoint()
183
+ server_edsl_version = response.headers.get("X-EDSL-Version")
184
+
185
+ if server_edsl_version:
186
+ if self._user_version_is_outdated(
187
+ user_version_str=self._edsl_version,
188
+ server_version_str=server_edsl_version,
189
+ ):
190
+ print(
191
+ "Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip install --upgrade edsl`"
192
+ )
49
193
 
50
- def _resolve_server_response(self, response: requests.Response) -> None:
51
- """Checks the response from the server and raises appropriate errors."""
52
194
  if response.status_code >= 400:
53
- raise Exception(response.json().get("detail"))
195
+ try:
196
+ message = response.json().get("detail")
197
+ except json.JSONDecodeError:
198
+ raise CoopServerResponseError(
199
+ f"Server returned status code {response.status_code}."
200
+ "JSON response could not be decoded.",
201
+ "The server response was: " + response.text,
202
+ )
203
+ # print(response.text)
204
+ if "The API key you provided is invalid" in message and check_api_key:
205
+ import secrets
206
+ from edsl.utilities.utilities import write_api_key_to_env
207
+
208
+ edsl_auth_token = secrets.token_urlsafe(16)
209
+
210
+ print("Your Expected Parrot API key is invalid.")
211
+ self._display_login_url(
212
+ edsl_auth_token=edsl_auth_token,
213
+ link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
214
+ )
215
+ api_key = self._poll_for_api_key(edsl_auth_token)
216
+
217
+ if api_key is None:
218
+ print("\nTimed out waiting for login. Please try again.")
219
+ return
220
+
221
+ print("\n✨ API key retrieved.")
222
+
223
+ if stored_in_user_space := self.ep_key_handler.ask_to_store(api_key):
224
+ pass
225
+ else:
226
+ path_to_env = write_api_key_to_env(api_key)
227
+ print(
228
+ "\n✨ API key retrieved and written to .env file at the following path:"
229
+ )
230
+ print(f" {path_to_env}")
231
+ print("Rerun your code to try again with a valid API key.")
232
+ return
233
+
234
+ elif "Authorization" in message:
235
+ print(message)
236
+ message = "Please provide an Expected Parrot API key."
237
+
238
+ raise CoopServerResponseError(message)
239
+
240
+ def _poll_for_api_key(
241
+ self, edsl_auth_token: str, timeout: int = 120
242
+ ) -> Union[str, None]:
243
+ """
244
+ Allows the user to retrieve their Expected Parrot API key by logging in with an EDSL auth token.
245
+
246
+ :param edsl_auth_token: The EDSL auth token to use for login
247
+ :param timeout: Maximum time to wait for login, in seconds (default: 120)
248
+ """
249
+ import time
250
+ from datetime import datetime
251
+
252
+ start_poll_time = time.time()
253
+ waiting_for_login = True
254
+ while waiting_for_login:
255
+ elapsed_time = time.time() - start_poll_time
256
+ if elapsed_time > timeout:
257
+ # Timed out waiting for the user to log in
258
+ print("\r" + " " * 80 + "\r", end="")
259
+ return None
260
+
261
+ api_key = self._get_api_key(edsl_auth_token)
262
+ if api_key is not None:
263
+ print("\r" + " " * 80 + "\r", end="")
264
+ return api_key
265
+ else:
266
+ duration = 5
267
+ time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
268
+ frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
269
+ start_time = time.time()
270
+ i = 0
271
+ while time.time() - start_time < duration:
272
+ print(
273
+ f"\r{frames[i % len(frames)]} Waiting for login. Last checked: {time_checked}",
274
+ end="",
275
+ flush=True,
276
+ )
277
+ time.sleep(0.1)
278
+ i += 1
279
+
280
+ def _json_handle_none(self, value: Any) -> Any:
281
+ """
282
+ Handle None values during JSON serialization.
283
+ - Return "null" if the value is None. Otherwise, don't return anything.
284
+ """
285
+ if value is None:
286
+ return "null"
287
+
288
+ def _resolve_uuid(
289
+ self, uuid: Union[str, UUID] = None, url: str = None
290
+ ) -> Union[str, UUID]:
291
+ """
292
+ Resolve the uuid from a uuid or a url.
293
+ """
294
+ if not url and not uuid:
295
+ raise CoopNoUUIDError("No uuid or url provided for the object.")
296
+ if not uuid and url:
297
+ uuid = url.split("/")[-1]
298
+ return uuid
299
+
300
+ @property
301
+ def edsl_settings(self) -> dict:
302
+ """
303
+ Retrieve and return the EDSL settings stored on Coop.
304
+ If no response is received within 5 seconds, return an empty dict.
305
+ """
306
+ from requests.exceptions import Timeout
307
+
308
+ try:
309
+ response = self._send_server_request(
310
+ uri="api/v0/edsl-settings", method="GET", timeout=5
311
+ )
312
+ self._resolve_server_response(response, check_api_key=False)
313
+ return response.json()
314
+ except Timeout:
315
+ return {}
54
316
 
55
- # QUESTIONS METHODS
56
- def create_question(self, question: Type[Question], public: bool = False) -> dict:
317
+ ################
318
+ # Objects
319
+ ################
320
+ def create(
321
+ self,
322
+ object: EDSLObject,
323
+ description: Optional[str] = None,
324
+ alias: Optional[str] = None,
325
+ visibility: Optional[VisibilityType] = "unlisted",
326
+ ) -> dict:
57
327
  """
58
- Creates a Question object.
59
- - `question`: the EDSL Question to be sent.
60
- - `public`: whether the question should be public (defaults to False)
328
+ Create an EDSL object in the Coop server.
61
329
  """
330
+ object_type = ObjectRegistry.get_object_type_by_edsl_class(object)
62
331
  response = self._send_server_request(
63
- uri="api/v0/questions",
332
+ uri=f"api/v0/object",
64
333
  method="POST",
65
- payload={"json_string": json.dumps(question.to_dict()), "public": public},
334
+ payload={
335
+ "description": description,
336
+ "alias": alias,
337
+ "json_string": json.dumps(
338
+ object.to_dict(),
339
+ default=self._json_handle_none,
340
+ ),
341
+ "object_type": object_type,
342
+ "visibility": visibility,
343
+ "version": self._edsl_version,
344
+ },
66
345
  )
67
346
  self._resolve_server_response(response)
68
- return response.json()
347
+ response_json = response.json()
348
+ return {
349
+ "description": response_json.get("description"),
350
+ "object_type": object_type,
351
+ "url": f"{self.url}/content/{response_json.get('uuid')}",
352
+ "uuid": response_json.get("uuid"),
353
+ "version": self._edsl_version,
354
+ "visibility": response_json.get("visibility"),
355
+ }
356
+
357
+ def get(
358
+ self,
359
+ uuid: Union[str, UUID] = None,
360
+ url: str = None,
361
+ expected_object_type: Optional[ObjectType] = None,
362
+ ) -> EDSLObject:
363
+ """
364
+ Retrieve an EDSL object by its uuid or its url.
365
+ - If the object's visibility is private, the user must be the owner.
366
+ - Optionally, check if the retrieved object is of a certain type.
69
367
 
70
- def get_question(self, id: int) -> Type[Question]:
71
- """Retrieves a Question object by id."""
72
- response = self._send_server_request(uri=f"api/v0/questions/{id}", method="GET")
368
+ :param uuid: the uuid of the object either in str or UUID format.
369
+ :param url: the url of the object.
370
+ :param expected_object_type: the expected type of the object.
371
+
372
+ :return: the object instance.
373
+ """
374
+ uuid = self._resolve_uuid(uuid, url)
375
+ response = self._send_server_request(
376
+ uri=f"api/v0/object",
377
+ method="GET",
378
+ params={"uuid": uuid},
379
+ )
73
380
  self._resolve_server_response(response)
74
- return Question.from_dict(json.loads(response.json().get("json_string")))
381
+ json_string = response.json().get("json_string")
382
+ object_type = response.json().get("object_type")
383
+ if expected_object_type and object_type != expected_object_type:
384
+ raise Exception(f"Expected {expected_object_type=} but got {object_type=}")
385
+ edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
386
+ object = edsl_class.from_dict(json.loads(json_string))
387
+ return object
75
388
 
76
- @property
77
- def questions(self) -> list[dict[str, Union[int, Question]]]:
78
- """Retrieves all Questions."""
79
- response = self._send_server_request(uri="api/v0/questions", method="GET")
389
+ def get_all(self, object_type: ObjectType) -> list[dict[str, Any]]:
390
+ """
391
+ Retrieve all objects of a certain type associated with the user.
392
+ """
393
+ edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
394
+ response = self._send_server_request(
395
+ uri=f"api/v0/objects",
396
+ method="GET",
397
+ params={"type": object_type},
398
+ )
80
399
  self._resolve_server_response(response)
81
- questions = [
400
+ objects = [
82
401
  {
83
- "id": q.get("id"),
84
- "question": Question.from_dict(json.loads(q.get("json_string"))),
402
+ "object": edsl_class.from_dict(json.loads(o.get("json_string"))),
403
+ "uuid": o.get("uuid"),
404
+ "version": o.get("version"),
405
+ "description": o.get("description"),
406
+ "visibility": o.get("visibility"),
407
+ "url": f"{self.url}/content/{o.get('uuid')}",
85
408
  }
86
- for q in response.json()
409
+ for o in response.json()
87
410
  ]
88
- return questions
411
+ return objects
89
412
 
90
- def delete_question(self, id: int) -> dict:
91
- """Deletes a question from the coop."""
413
+ def delete(self, uuid: Union[str, UUID] = None, url: str = None) -> dict:
414
+ """
415
+ Delete an object from the server.
416
+ """
417
+ uuid = self._resolve_uuid(uuid, url)
92
418
  response = self._send_server_request(
93
- uri=f"api/v0/questions/{id}", method="DELETE"
419
+ uri=f"api/v0/object",
420
+ method="DELETE",
421
+ params={"uuid": uuid},
94
422
  )
95
423
  self._resolve_server_response(response)
96
424
  return response.json()
97
425
 
98
- # Surveys METHODS
99
- def create_survey(self, survey: Type[Survey], public: bool = False) -> dict:
426
+ def patch(
427
+ self,
428
+ uuid: Union[str, UUID] = None,
429
+ url: str = None,
430
+ description: Optional[str] = None,
431
+ alias: Optional[str] = None,
432
+ value: Optional[EDSLObject] = None,
433
+ visibility: Optional[VisibilityType] = None,
434
+ ) -> dict:
100
435
  """
101
- Creates a Question object.
102
- - `survey`: the EDSL Survey to be sent.
103
- - `public`: whether the survey should be public (defaults to False)
436
+ Change the attributes of an uploaded object
437
+ - Only supports visibility for now
104
438
  """
439
+ if description is None and visibility is None and value is None:
440
+ raise Exception("Nothing to patch.")
441
+ uuid = self._resolve_uuid(uuid, url)
105
442
  response = self._send_server_request(
106
- uri="api/v0/surveys",
107
- method="POST",
108
- payload={"json_string": json.dumps(survey.to_dict()), "public": public},
443
+ uri=f"api/v0/object",
444
+ method="PATCH",
445
+ params={"uuid": uuid},
446
+ payload={
447
+ "description": description,
448
+ "alias": alias,
449
+ "json_string": (
450
+ json.dumps(
451
+ value.to_dict(),
452
+ default=self._json_handle_none,
453
+ )
454
+ if value
455
+ else None
456
+ ),
457
+ "visibility": visibility,
458
+ },
109
459
  )
110
460
  self._resolve_server_response(response)
111
461
  return response.json()
112
462
 
113
- def get_survey(self, id: int) -> Type[Survey]:
114
- """Retrieves a Survey object by id."""
115
- response = self._send_server_request(uri=f"api/v0/surveys/{id}", method="GET")
116
- self._resolve_server_response(response)
117
- return Survey.from_dict(json.loads(response.json().get("json_string")))
463
+ ################
464
+ # Remote Cache
465
+ ################
466
+ def remote_cache_create(
467
+ self,
468
+ cache_entry: CacheEntry,
469
+ visibility: VisibilityType = "private",
470
+ description: Optional[str] = None,
471
+ ) -> dict:
472
+ """
473
+ Create a single remote cache entry.
474
+ If an entry with the same key already exists in the database, update it instead.
118
475
 
119
- @property
120
- def surveys(self) -> list[dict[str, Union[int, Survey]]]:
121
- """Retrieves all Surveys."""
122
- response = self._send_server_request(uri="api/v0/surveys", method="GET")
476
+ :param cache_entry: The cache entry to send to the server.
477
+ :param visibility: The visibility of the cache entry.
478
+ :param optional description: A description for this entry in the remote cache.
479
+
480
+ >>> entry = CacheEntry.example()
481
+ >>> coop.remote_cache_create(cache_entry=entry)
482
+ {'status': 'success', 'created_entry_count': 1, 'updated_entry_count': 0}
483
+ """
484
+ response = self._send_server_request(
485
+ uri="api/v0/remote-cache",
486
+ method="POST",
487
+ payload={
488
+ "json_string": json.dumps(cache_entry.to_dict()),
489
+ "version": self._edsl_version,
490
+ "visibility": visibility,
491
+ "description": description,
492
+ },
493
+ )
123
494
  self._resolve_server_response(response)
124
- surveys = [
495
+ response_json = response.json()
496
+ created_entry_count = response_json.get("created_entry_count", 0)
497
+ if created_entry_count > 0:
498
+ self.remote_cache_create_log(
499
+ response,
500
+ description="Upload new cache entries to server",
501
+ cache_entry_count=created_entry_count,
502
+ )
503
+ return response.json()
504
+
505
+ def remote_cache_create_many(
506
+ self,
507
+ cache_entries: list[CacheEntry],
508
+ visibility: VisibilityType = "private",
509
+ description: Optional[str] = None,
510
+ ) -> dict:
511
+ """
512
+ Create many remote cache entries.
513
+ If an entry with the same key already exists in the database, update it instead.
514
+
515
+ :param cache_entries: The list of cache entries to send to the server.
516
+ :param visibility: The visibility of the cache entries.
517
+ :param optional description: A description for these entries in the remote cache.
518
+
519
+ >>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
520
+ >>> coop.remote_cache_create_many(cache_entries=entries)
521
+ {'status': 'success', 'created_entry_count': 10, 'updated_entry_count': 0}
522
+ """
523
+ payload = [
125
524
  {
126
- "id": q.get("id"),
127
- "survey": Survey.from_dict(json.loads(q.get("json_string"))),
525
+ "json_string": json.dumps(c.to_dict()),
526
+ "version": self._edsl_version,
527
+ "visibility": visibility,
528
+ "description": description,
128
529
  }
129
- for q in response.json()
530
+ for c in cache_entries
531
+ ]
532
+ response = self._send_server_request(
533
+ uri="api/v0/remote-cache/many",
534
+ method="POST",
535
+ payload=payload,
536
+ timeout=40,
537
+ )
538
+ self._resolve_server_response(response)
539
+ response_json = response.json()
540
+ created_entry_count = response_json.get("created_entry_count", 0)
541
+ if created_entry_count > 0:
542
+ self.remote_cache_create_log(
543
+ response,
544
+ description="Upload new cache entries to server",
545
+ cache_entry_count=created_entry_count,
546
+ )
547
+ return response.json()
548
+
549
+ def remote_cache_get(
550
+ self,
551
+ exclude_keys: Optional[list[str]] = None,
552
+ ) -> list[CacheEntry]:
553
+ """
554
+ Get all remote cache entries.
555
+
556
+ :param optional exclude_keys: Exclude CacheEntry objects with these keys.
557
+
558
+ >>> coop.remote_cache_get()
559
+ [CacheEntry(...), CacheEntry(...), ...]
560
+ """
561
+ if exclude_keys is None:
562
+ exclude_keys = []
563
+ response = self._send_server_request(
564
+ uri="api/v0/remote-cache/get-many",
565
+ method="POST",
566
+ payload={"keys": exclude_keys},
567
+ timeout=40,
568
+ )
569
+ self._resolve_server_response(response)
570
+ return [
571
+ CacheEntry.from_dict(json.loads(v.get("json_string")))
572
+ for v in response.json()
130
573
  ]
131
- return surveys
132
574
 
133
- def delete_survey(self, id: int) -> dict:
134
- """Deletes a Survey from the coop."""
575
+ def remote_cache_get_diff(
576
+ self,
577
+ client_cacheentry_keys: list[str],
578
+ ) -> dict:
579
+ """
580
+ Get the difference between local and remote cache entries for a user.
581
+ """
582
+ response = self._send_server_request(
583
+ uri="api/v0/remote-cache/get-diff",
584
+ method="POST",
585
+ payload={"keys": client_cacheentry_keys},
586
+ timeout=40,
587
+ )
588
+ self._resolve_server_response(response)
589
+ response_json = response.json()
590
+ response_dict = {
591
+ "client_missing_cacheentries": [
592
+ CacheEntry.from_dict(json.loads(c.get("json_string")))
593
+ for c in response_json.get("client_missing_cacheentries", [])
594
+ ],
595
+ "server_missing_cacheentry_keys": response_json.get(
596
+ "server_missing_cacheentry_keys", []
597
+ ),
598
+ }
599
+ downloaded_entry_count = len(response_dict["client_missing_cacheentries"])
600
+ if downloaded_entry_count > 0:
601
+ self.remote_cache_create_log(
602
+ response,
603
+ description="Download missing cache entries to client",
604
+ cache_entry_count=downloaded_entry_count,
605
+ )
606
+ return response_dict
607
+
608
+ def remote_cache_clear(self) -> dict:
609
+ """
610
+ Clear all remote cache entries.
611
+
612
+ >>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
613
+ >>> coop.remote_cache_create_many(cache_entries=entries)
614
+ >>> coop.remote_cache_clear()
615
+ {'status': 'success', 'deleted_entry_count': 10}
616
+ """
135
617
  response = self._send_server_request(
136
- uri=f"api/v0/surveys/{id}", method="DELETE"
618
+ uri="api/v0/remote-cache/delete-all",
619
+ method="DELETE",
137
620
  )
138
621
  self._resolve_server_response(response)
622
+ response_json = response.json()
623
+ deleted_entry_count = response_json.get("deleted_entry_count", 0)
624
+ if deleted_entry_count > 0:
625
+ self.remote_cache_create_log(
626
+ response,
627
+ description="Clear cache entries",
628
+ cache_entry_count=deleted_entry_count,
629
+ )
139
630
  return response.json()
140
631
 
632
+ def remote_cache_create_log(
633
+ self, response: requests.Response, description: str, cache_entry_count: int
634
+ ) -> Union[dict, None]:
635
+ """
636
+ If a remote cache action has been completed successfully,
637
+ log the action.
638
+ """
639
+ if 200 <= response.status_code < 300:
640
+ log_response = self._send_server_request(
641
+ uri="api/v0/remote-cache-log",
642
+ method="POST",
643
+ payload={
644
+ "description": description,
645
+ "cache_entry_count": cache_entry_count,
646
+ },
647
+ )
648
+ self._resolve_server_response(log_response)
649
+ return response.json()
141
650
 
142
- if __name__ == "__main__":
143
- from edsl.coop import Coop
651
+ def remote_cache_clear_log(self) -> dict:
652
+ """
653
+ Clear all remote cache log entries.
144
654
 
145
- API_KEY = "p-llmNVgNM8pnzCWZQ6-sDCdlMgRgithISctb_9yzqU"
146
- RUN_MODE = "development"
147
- coop = Coop(api_key=API_KEY, run_mode=RUN_MODE)
655
+ >>> coop.remote_cache_clear_log()
656
+ {'status': 'success'}
657
+ """
658
+ response = self._send_server_request(
659
+ uri="api/v0/remote-cache-log/delete-all",
660
+ method="DELETE",
661
+ )
662
+ self._resolve_server_response(response)
663
+ return response.json()
148
664
 
149
- # basics
150
- coop
151
- coop.headers
152
- coop.url
665
+ def remote_inference_create(
666
+ self,
667
+ job: Jobs,
668
+ description: Optional[str] = None,
669
+ status: RemoteJobStatus = "queued",
670
+ visibility: Optional[VisibilityType] = "unlisted",
671
+ initial_results_visibility: Optional[VisibilityType] = "unlisted",
672
+ iterations: Optional[int] = 1,
673
+ ) -> RemoteInferenceCreationInfo:
674
+ """
675
+ Send a remote inference job to the server.
153
676
 
154
- ##############
155
- # A. QUESTIONS
156
- ##############
157
- from edsl.questions import QuestionMultipleChoice
158
- from edsl.questions import QuestionCheckBox
159
- from edsl.questions import QuestionFreeText
677
+ :param job: The EDSL job to send to the server.
678
+ :param optional description: A description for this entry in the remote cache.
679
+ :param status: The status of the job. Should be 'queued', unless you are debugging.
680
+ :param visibility: The visibility of the cache entry.
681
+ :param iterations: The number of times to run each interview.
160
682
 
161
- # check questions on server (should be an empty list)
162
- coop.questions
163
- for question in coop.questions:
164
- coop.delete_question(question.get("id"))
683
+ >>> job = Jobs.example()
684
+ >>> coop.remote_inference_create(job=job, description="My job")
685
+ {'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'iterations': None, 'visibility': 'unlisted', 'version': '0.1.38.dev1'}
686
+ """
687
+ response = self._send_server_request(
688
+ uri="api/v0/remote-inference",
689
+ method="POST",
690
+ payload={
691
+ "json_string": json.dumps(
692
+ job.to_dict(),
693
+ default=self._json_handle_none,
694
+ ),
695
+ "description": description,
696
+ "status": status,
697
+ "iterations": iterations,
698
+ "visibility": visibility,
699
+ "version": self._edsl_version,
700
+ "initial_results_visibility": initial_results_visibility,
701
+ },
702
+ )
703
+ self._resolve_server_response(response)
704
+ response_json = response.json()
165
705
 
166
- # get a question that does not exist (should return None)
167
- coop.get_question(id=1)
706
+ return RemoteInferenceCreationInfo(
707
+ **{
708
+ "uuid": response_json.get("job_uuid"),
709
+ "description": response_json.get("description"),
710
+ "status": response_json.get("status"),
711
+ "iterations": response_json.get("iterations"),
712
+ "visibility": response_json.get("visibility"),
713
+ "version": self._edsl_version,
714
+ }
715
+ )
168
716
 
169
- # now post a Question
170
- coop.create_question(QuestionMultipleChoice.example())
171
- coop.create_question(QuestionCheckBox.example(), public=False)
172
- coop.create_question(QuestionFreeText.example(), public=True)
717
+ def remote_inference_get(
718
+ self, job_uuid: Optional[str] = None, results_uuid: Optional[str] = None
719
+ ) -> RemoteInferenceResponse:
720
+ """
721
+ Get the details of a remote inference job.
722
+ You can pass either the job uuid or the results uuid as a parameter.
723
+ If you pass both, the job uuid will be prioritized.
173
724
 
174
- # check all questions
175
- coop.questions
725
+ :param job_uuid: The UUID of the EDSL job.
726
+ :param results_uuid: The UUID of the results associated with the EDSL job.
176
727
 
177
- # or get question by id
178
- coop.get_question(id=1)
728
+ >>> coop.remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
729
+ {'job_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'latest_error_report_uuid': None, 'latest_error_report_url': None, 'status': 'completed', 'reason': None, 'credits_consumed': 0.35, 'version': '0.1.38.dev1'}
730
+ """
731
+ if job_uuid is None and results_uuid is None:
732
+ raise ValueError("Either job_uuid or results_uuid must be provided.")
733
+ elif job_uuid is not None:
734
+ params = {"job_uuid": job_uuid}
735
+ else:
736
+ params = {"results_uuid": results_uuid}
179
737
 
180
- # delete the question
181
- coop.delete_question(id=1)
738
+ response = self._send_server_request(
739
+ uri="api/v0/remote-inference",
740
+ method="GET",
741
+ params=params,
742
+ )
743
+ self._resolve_server_response(response)
744
+ data = response.json()
182
745
 
183
- # check all questions
184
- coop.questions
746
+ results_uuid = data.get("results_uuid")
747
+ latest_error_report_uuid = data.get("latest_error_report_uuid")
185
748
 
186
- ##############
187
- # B. Surveys
188
- ##############
189
- from edsl.surveys import Survey
749
+ if results_uuid is None:
750
+ results_url = None
751
+ else:
752
+ results_url = f"{self.url}/content/{results_uuid}"
753
+
754
+ if latest_error_report_uuid is None:
755
+ latest_error_report_url = None
756
+ else:
757
+ latest_error_report_url = (
758
+ f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
759
+ )
760
+
761
+ return RemoteInferenceResponse(
762
+ **{
763
+ "job_uuid": data.get("job_uuid"),
764
+ "results_uuid": results_uuid,
765
+ "results_url": results_url,
766
+ "latest_error_report_uuid": latest_error_report_uuid,
767
+ "latest_error_report_url": latest_error_report_url,
768
+ "status": data.get("status"),
769
+ "reason": data.get("reason"),
770
+ "credits_consumed": data.get("price"),
771
+ "version": data.get("version"),
772
+ }
773
+ )
774
+
775
+ def get_running_jobs(self) -> list[str]:
776
+ """
777
+ Get a list of currently running job IDs.
778
+
779
+ Returns:
780
+ list[str]: List of running job UUIDs
781
+ """
782
+ response = self._send_server_request(uri="jobs/status", method="GET")
783
+ self._resolve_server_response(response)
784
+ return response.json().get("running_jobs", [])
785
+
786
+ def remote_inference_cost(
787
+ self, input: Union[Jobs, Survey], iterations: int = 1
788
+ ) -> int:
789
+ """
790
+ Get the cost of a remote inference job.
791
+
792
+ :param input: The EDSL job to send to the server.
793
+
794
+ >>> job = Jobs.example()
795
+ >>> coop.remote_inference_cost(input=job)
796
+ {'credits': 0.77, 'usd': 0.0076950000000000005}
797
+ """
798
+ if isinstance(input, Jobs):
799
+ job = input
800
+ elif isinstance(input, Survey):
801
+ job = Jobs(survey=input)
802
+ else:
803
+ raise TypeError("Input must be either a Job or a Survey.")
804
+
805
+ response = self._send_server_request(
806
+ uri="api/v0/remote-inference/cost",
807
+ method="POST",
808
+ payload={
809
+ "json_string": json.dumps(
810
+ job.to_dict(),
811
+ default=self._json_handle_none,
812
+ ),
813
+ "iterations": iterations,
814
+ },
815
+ )
816
+ self._resolve_server_response(response)
817
+ response_json = response.json()
818
+ return {
819
+ "credits": response_json.get("cost_in_credits"),
820
+ "usd": response_json.get("cost_in_usd"),
821
+ }
822
+
823
+ ################
824
+ # DUNDER METHODS
825
+ ################
826
+ def __repr__(self):
827
+ """Return a string representation of the client."""
828
+ return f"Client(api_key='{self.api_key}', url='{self.url}')"
829
+
830
+ ################
831
+ # EXPERIMENTAL
832
+ ################
833
+ async def remote_async_execute_model_call(
834
+ self, model_dict: dict, user_prompt: str, system_prompt: str
835
+ ) -> dict:
836
+ url = self.api_url + "/inference/"
837
+ # print("Now using url: ", url)
838
+ data = {
839
+ "model_dict": model_dict,
840
+ "user_prompt": user_prompt,
841
+ "system_prompt": system_prompt,
842
+ }
843
+ # Use aiohttp to send a POST request asynchronously
844
+ async with aiohttp.ClientSession() as session:
845
+ async with session.post(url, json=data) as response:
846
+ response_data = await response.json()
847
+ return response_data
190
848
 
191
- # check surveys on server (should be an empty list)
192
- coop.surveys
193
- for survey in coop.surveys:
194
- coop.delete_survey(survey.get("id"))
849
+ def web(
850
+ self,
851
+ survey: dict,
852
+ platform: Literal[
853
+ "google_forms", "lime_survey", "survey_monkey"
854
+ ] = "lime_survey",
855
+ email=None,
856
+ ):
857
+ url = f"{self.api_url}/api/v0/export_to_{platform}"
858
+ if email:
859
+ data = {"json_string": json.dumps({"survey": survey, "email": email})}
860
+ else:
861
+ data = {"json_string": json.dumps({"survey": survey, "email": ""})}
862
+
863
+ response_json = requests.post(url, headers=self.headers, data=json.dumps(data))
864
+
865
+ return response_json
866
+
867
+ def fetch_prices(self) -> dict:
868
+ """
869
+ Fetch model prices from Coop. If the request fails, return an empty dict.
870
+ """
871
+
872
+ from edsl.coop.PriceFetcher import PriceFetcher
195
873
 
196
- # get a survey that does not exist (should return None)
197
- coop.get_survey(id=1)
874
+ from edsl.config import CONFIG
198
875
 
199
- # now post a Survey
200
- coop.create_survey(Survey.example())
201
- coop.create_survey(Survey.example(), public=False)
202
- coop.create_survey(Survey.example(), public=True)
876
+ if CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "True":
877
+ price_fetcher = PriceFetcher()
878
+ return price_fetcher.fetch_prices()
879
+ elif CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "False":
880
+ return {}
881
+ else:
882
+ raise ValueError(
883
+ "Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
884
+ )
203
885
 
204
- # check all surveys
205
- coop.surveys
886
+ def fetch_models(self) -> ServiceToModelsMapping:
887
+ """
888
+ Fetch a dict of available models from Coop.
206
889
 
207
- # or get survey by id
208
- coop.get_survey(id=1)
890
+ Each key in the dict is an inference service, and each value is a list of models from that service.
891
+ """
892
+ response = self._send_server_request(uri="api/v0/models", method="GET")
893
+ self._resolve_server_response(response)
894
+ data = response.json()
895
+ return ServiceToModelsMapping(data)
209
896
 
210
- # delete the survey
211
- coop.delete_survey(id=1)
897
+ def fetch_rate_limit_config_vars(self) -> dict:
898
+ """
899
+ Fetch a dict of rate limit config vars from Coop.
212
900
 
213
- # check all surveys
214
- coop.surveys
901
+ The dict keys are RPM and TPM variables like EDSL_SERVICE_RPM_OPENAI.
902
+ """
903
+ response = self._send_server_request(
904
+ uri="api/v0/config-vars",
905
+ method="GET",
906
+ )
907
+ self._resolve_server_response(response)
908
+ data = response.json()
909
+ return data
910
+
911
+ def _display_login_url(
912
+ self, edsl_auth_token: str, link_description: Optional[str] = None
913
+ ):
914
+ """
915
+ Uses rich.print to display a login URL.
916
+
917
+ - We need this function because URL detection with print() does not work alongside animations in VSCode.
918
+ """
919
+ from rich import print as rich_print
920
+
921
+ url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
922
+
923
+ if link_description:
924
+ rich_print(
925
+ f"{link_description}\n [#38bdf8][link={url}]{url}[/link][/#38bdf8]"
926
+ )
927
+ else:
928
+ rich_print(f" [#38bdf8][link={url}]{url}[/link][/#38bdf8]")
929
+
930
+ def _get_api_key(self, edsl_auth_token: str):
931
+ """
932
+ Given an EDSL auth token, find the corresponding user's API key.
933
+ """
934
+
935
+ response = self._send_server_request(
936
+ uri="api/v0/get-api-key",
937
+ method="POST",
938
+ payload={
939
+ "edsl_auth_token": edsl_auth_token,
940
+ },
941
+ )
942
+ data = response.json()
943
+ api_key = data.get("api_key")
944
+ return api_key
945
+
946
+ def login(self):
947
+ """
948
+ Starts the EDSL auth token login flow.
949
+ """
950
+ import secrets
951
+ from dotenv import load_dotenv
952
+ from edsl.utilities.utilities import write_api_key_to_env
953
+
954
+ edsl_auth_token = secrets.token_urlsafe(16)
955
+
956
+ self._display_login_url(
957
+ edsl_auth_token=edsl_auth_token,
958
+ link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
959
+ )
960
+ api_key = self._poll_for_api_key(edsl_auth_token)
961
+
962
+ if api_key is None:
963
+ raise Exception("Timed out waiting for login. Please try again.")
964
+
965
+ path_to_env = write_api_key_to_env(api_key)
966
+ print("\n✨ API key retrieved and written to .env file at the following path:")
967
+ print(f" {path_to_env}")
968
+
969
+ # Add API key to environment
970
+ load_dotenv()
971
+
972
+
973
+ def main():
974
+ """
975
+ A simple example for the coop client
976
+ """
977
+ from uuid import uuid4
978
+ from edsl import (
979
+ Agent,
980
+ AgentList,
981
+ Cache,
982
+ Notebook,
983
+ QuestionFreeText,
984
+ QuestionMultipleChoice,
985
+ Results,
986
+ Scenario,
987
+ ScenarioList,
988
+ Survey,
989
+ )
990
+ from edsl.coop import Coop
991
+ from edsl.data.CacheEntry import CacheEntry
992
+ from edsl.jobs import Jobs
993
+
994
+ # init & basics
995
+ API_KEY = "b"
996
+ coop = Coop(api_key=API_KEY)
997
+ coop
998
+ coop.edsl_settings
999
+
1000
+ ##############
1001
+ # A. A simple example
1002
+ ##############
1003
+ # .. create and manipulate an object through the Coop client
1004
+ response = coop.create(QuestionMultipleChoice.example())
1005
+ coop.get(uuid=response.get("uuid"))
1006
+ coop.get(uuid=response.get("uuid"), expected_object_type="question")
1007
+ coop.get(url=response.get("url"))
1008
+ coop.create(QuestionMultipleChoice.example())
1009
+ coop.get_all("question")
1010
+ coop.patch(uuid=response.get("uuid"), visibility="private")
1011
+ coop.patch(uuid=response.get("uuid"), description="hey")
1012
+ coop.patch(uuid=response.get("uuid"), value=QuestionFreeText.example())
1013
+ # coop.patch(uuid=response.get("uuid"), value=Survey.example()) - should throw error
1014
+ coop.get(uuid=response.get("uuid"))
1015
+ coop.delete(uuid=response.get("uuid"))
1016
+
1017
+ # .. create and manipulate an object through the class
1018
+ response = QuestionMultipleChoice.example().push()
1019
+ QuestionMultipleChoice.pull(uuid=response.get("uuid"))
1020
+ QuestionMultipleChoice.pull(url=response.get("url"))
1021
+ QuestionMultipleChoice.patch(uuid=response.get("uuid"), visibility="private")
1022
+ QuestionMultipleChoice.patch(uuid=response.get("uuid"), description="hey")
1023
+ QuestionMultipleChoice.patch(
1024
+ uuid=response.get("uuid"), value=QuestionFreeText.example()
1025
+ )
1026
+ QuestionMultipleChoice.pull(response.get("uuid"))
1027
+ QuestionMultipleChoice.delete(response.get("uuid"))
1028
+
1029
+ ##############
1030
+ # B. Examples with all objects
1031
+ ##############
1032
+ OBJECTS = [
1033
+ ("agent", Agent),
1034
+ ("agent_list", AgentList),
1035
+ ("cache", Cache),
1036
+ ("notebook", Notebook),
1037
+ ("question", QuestionMultipleChoice),
1038
+ ("results", Results),
1039
+ ("scenario", Scenario),
1040
+ ("scenario_list", ScenarioList),
1041
+ ("survey", Survey),
1042
+ ]
1043
+ for object_type, cls in OBJECTS:
1044
+ print(f"Testing {object_type} objects")
1045
+ # 1. Delete existing objects
1046
+ existing_objects = coop.get_all(object_type)
1047
+ for item in existing_objects:
1048
+ coop.delete(uuid=item.get("uuid"))
1049
+ # 2. Create new objects
1050
+ example = cls.example()
1051
+ response_1 = coop.create(example)
1052
+ response_2 = coop.create(cls.example(), visibility="private")
1053
+ response_3 = coop.create(cls.example(), visibility="public")
1054
+ response_4 = coop.create(
1055
+ cls.example(), visibility="unlisted", description="hey"
1056
+ )
1057
+ # 3. Retrieve all objects
1058
+ objects = coop.get_all(object_type)
1059
+ assert len(objects) == 4
1060
+ # 4. Try to retrieve an item that does not exist
1061
+ try:
1062
+ coop.get(uuid=uuid4())
1063
+ except Exception as e:
1064
+ print(e)
1065
+ # 5. Try to retrieve all test objects by their uuids
1066
+ for response in [response_1, response_2, response_3, response_4]:
1067
+ coop.get(uuid=response.get("uuid"))
1068
+ # 6. Change visibility of all objects
1069
+ for item in objects:
1070
+ coop.patch(uuid=item.get("uuid"), visibility="private")
1071
+ # 6. Change description of all objects
1072
+ for item in objects:
1073
+ coop.patch(uuid=item.get("uuid"), description="hey")
1074
+ # 7. Delete all objects
1075
+ for item in objects:
1076
+ coop.delete(uuid=item.get("uuid"))
1077
+ assert len(coop.get_all(object_type)) == 0
1078
+
1079
+ ##############
1080
+ # C. Remote Cache
1081
+ ##############
1082
+ # clear
1083
+ coop.remote_cache_clear()
1084
+ assert coop.remote_cache_get() == []
1085
+ # create one remote cache entry
1086
+ cache_entry = CacheEntry.example()
1087
+ cache_entry.to_dict()
1088
+ coop.remote_cache_create(cache_entry)
1089
+ # create many remote cache entries
1090
+ cache_entries = [CacheEntry.example(randomize=True) for _ in range(10)]
1091
+ coop.remote_cache_create_many(cache_entries)
1092
+ # get all remote cache entries
1093
+ coop.remote_cache_get()
1094
+ coop.remote_cache_get(exclude_keys=[])
1095
+ coop.remote_cache_get(exclude_keys=["a"])
1096
+ exclude_keys = [cache_entry.key for cache_entry in cache_entries]
1097
+ coop.remote_cache_get(exclude_keys)
1098
+ # clear
1099
+ coop.remote_cache_clear()
1100
+ coop.remote_cache_get()
1101
+
1102
+ ##############
1103
+ # D. Remote Inference
1104
+ ##############
1105
+ job = Jobs.example()
1106
+ coop.remote_inference_cost(job)
1107
+ job_coop_object = coop.remote_inference_create(job)
1108
+ job_coop_results = coop.remote_inference_get(job_coop_object.get("uuid"))
1109
+ coop.get(uuid=job_coop_results.get("results_uuid"))