edsl 0.1.15__py3-none-any.whl → 0.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (407) hide show
  1. edsl/Base.py +348 -38
  2. edsl/BaseDiff.py +260 -0
  3. edsl/TemplateLoader.py +24 -0
  4. edsl/__init__.py +45 -10
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +842 -144
  7. edsl/agents/AgentList.py +521 -25
  8. edsl/agents/Invigilator.py +250 -374
  9. edsl/agents/InvigilatorBase.py +257 -0
  10. edsl/agents/PromptConstructor.py +272 -0
  11. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  12. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  13. edsl/agents/descriptors.py +43 -13
  14. edsl/agents/prompt_helpers.py +129 -0
  15. edsl/agents/question_option_processor.py +172 -0
  16. edsl/auto/AutoStudy.py +130 -0
  17. edsl/auto/StageBase.py +243 -0
  18. edsl/auto/StageGenerateSurvey.py +178 -0
  19. edsl/auto/StageLabelQuestions.py +125 -0
  20. edsl/auto/StagePersona.py +61 -0
  21. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  22. edsl/auto/StagePersonaDimensionValues.py +74 -0
  23. edsl/auto/StagePersonaDimensions.py +69 -0
  24. edsl/auto/StageQuestions.py +74 -0
  25. edsl/auto/SurveyCreatorPipeline.py +21 -0
  26. edsl/auto/utilities.py +218 -0
  27. edsl/base/Base.py +279 -0
  28. edsl/config.py +115 -113
  29. edsl/conversation/Conversation.py +290 -0
  30. edsl/conversation/car_buying.py +59 -0
  31. edsl/conversation/chips.py +95 -0
  32. edsl/conversation/mug_negotiation.py +81 -0
  33. edsl/conversation/next_speaker_utilities.py +93 -0
  34. edsl/coop/CoopFunctionsMixin.py +15 -0
  35. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  36. edsl/coop/PriceFetcher.py +54 -0
  37. edsl/coop/__init__.py +1 -0
  38. edsl/coop/coop.py +1029 -134
  39. edsl/coop/utils.py +131 -0
  40. edsl/data/Cache.py +560 -89
  41. edsl/data/CacheEntry.py +230 -0
  42. edsl/data/CacheHandler.py +168 -0
  43. edsl/data/RemoteCacheSync.py +186 -0
  44. edsl/data/SQLiteDict.py +292 -0
  45. edsl/data/__init__.py +5 -3
  46. edsl/data/orm.py +6 -33
  47. edsl/data_transfer_models.py +74 -27
  48. edsl/enums.py +165 -8
  49. edsl/exceptions/BaseException.py +21 -0
  50. edsl/exceptions/__init__.py +52 -46
  51. edsl/exceptions/agents.py +33 -15
  52. edsl/exceptions/cache.py +5 -0
  53. edsl/exceptions/coop.py +8 -0
  54. edsl/exceptions/general.py +34 -0
  55. edsl/exceptions/inference_services.py +5 -0
  56. edsl/exceptions/jobs.py +15 -0
  57. edsl/exceptions/language_models.py +46 -1
  58. edsl/exceptions/questions.py +80 -5
  59. edsl/exceptions/results.py +16 -5
  60. edsl/exceptions/scenarios.py +29 -0
  61. edsl/exceptions/surveys.py +13 -10
  62. edsl/inference_services/AnthropicService.py +106 -0
  63. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  64. edsl/inference_services/AvailableModelFetcher.py +215 -0
  65. edsl/inference_services/AwsBedrock.py +118 -0
  66. edsl/inference_services/AzureAI.py +215 -0
  67. edsl/inference_services/DeepInfraService.py +18 -0
  68. edsl/inference_services/GoogleService.py +143 -0
  69. edsl/inference_services/GroqService.py +20 -0
  70. edsl/inference_services/InferenceServiceABC.py +80 -0
  71. edsl/inference_services/InferenceServicesCollection.py +138 -0
  72. edsl/inference_services/MistralAIService.py +120 -0
  73. edsl/inference_services/OllamaService.py +18 -0
  74. edsl/inference_services/OpenAIService.py +236 -0
  75. edsl/inference_services/PerplexityService.py +160 -0
  76. edsl/inference_services/ServiceAvailability.py +135 -0
  77. edsl/inference_services/TestService.py +90 -0
  78. edsl/inference_services/TogetherAIService.py +172 -0
  79. edsl/inference_services/data_structures.py +134 -0
  80. edsl/inference_services/models_available_cache.py +118 -0
  81. edsl/inference_services/rate_limits_cache.py +25 -0
  82. edsl/inference_services/registry.py +41 -0
  83. edsl/inference_services/write_available.py +10 -0
  84. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  85. edsl/jobs/Answers.py +21 -20
  86. edsl/jobs/FetchInvigilator.py +47 -0
  87. edsl/jobs/InterviewTaskManager.py +98 -0
  88. edsl/jobs/InterviewsConstructor.py +50 -0
  89. edsl/jobs/Jobs.py +684 -206
  90. edsl/jobs/JobsChecks.py +172 -0
  91. edsl/jobs/JobsComponentConstructor.py +189 -0
  92. edsl/jobs/JobsPrompts.py +270 -0
  93. edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
  94. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  95. edsl/jobs/RequestTokenEstimator.py +30 -0
  96. edsl/jobs/async_interview_runner.py +138 -0
  97. edsl/jobs/buckets/BucketCollection.py +104 -0
  98. edsl/jobs/buckets/ModelBuckets.py +65 -0
  99. edsl/jobs/buckets/TokenBucket.py +283 -0
  100. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  101. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  102. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  103. edsl/jobs/data_structures.py +120 -0
  104. edsl/jobs/decorators.py +35 -0
  105. edsl/jobs/interviews/Interview.py +392 -0
  106. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -0
  107. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -0
  108. edsl/jobs/interviews/InterviewStatistic.py +63 -0
  109. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -0
  110. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -0
  111. edsl/jobs/interviews/InterviewStatusLog.py +92 -0
  112. edsl/jobs/interviews/ReportErrors.py +66 -0
  113. edsl/jobs/interviews/interview_status_enum.py +9 -0
  114. edsl/jobs/jobs_status_enums.py +9 -0
  115. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  116. edsl/jobs/results_exceptions_handler.py +98 -0
  117. edsl/jobs/runners/JobsRunnerAsyncio.py +151 -110
  118. edsl/jobs/runners/JobsRunnerStatus.py +298 -0
  119. edsl/jobs/tasks/QuestionTaskCreator.py +244 -0
  120. edsl/jobs/tasks/TaskCreators.py +64 -0
  121. edsl/jobs/tasks/TaskHistory.py +470 -0
  122. edsl/jobs/tasks/TaskStatusLog.py +23 -0
  123. edsl/jobs/tasks/task_status_enum.py +161 -0
  124. edsl/jobs/tokens/InterviewTokenUsage.py +27 -0
  125. edsl/jobs/tokens/TokenUsage.py +34 -0
  126. edsl/language_models/ComputeCost.py +63 -0
  127. edsl/language_models/LanguageModel.py +507 -386
  128. edsl/language_models/ModelList.py +164 -0
  129. edsl/language_models/PriceManager.py +127 -0
  130. edsl/language_models/RawResponseHandler.py +106 -0
  131. edsl/language_models/RegisterLanguageModelsMeta.py +184 -0
  132. edsl/language_models/__init__.py +1 -8
  133. edsl/language_models/fake_openai_call.py +15 -0
  134. edsl/language_models/fake_openai_service.py +61 -0
  135. edsl/language_models/key_management/KeyLookup.py +63 -0
  136. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  137. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  138. edsl/language_models/key_management/__init__.py +0 -0
  139. edsl/language_models/key_management/models.py +131 -0
  140. edsl/language_models/model.py +256 -0
  141. edsl/language_models/repair.py +109 -41
  142. edsl/language_models/utilities.py +65 -0
  143. edsl/notebooks/Notebook.py +263 -0
  144. edsl/notebooks/NotebookToLaTeX.py +142 -0
  145. edsl/notebooks/__init__.py +1 -0
  146. edsl/prompts/Prompt.py +222 -93
  147. edsl/prompts/__init__.py +1 -1
  148. edsl/questions/ExceptionExplainer.py +77 -0
  149. edsl/questions/HTMLQuestion.py +103 -0
  150. edsl/questions/QuestionBase.py +518 -0
  151. edsl/questions/QuestionBasePromptsMixin.py +221 -0
  152. edsl/questions/QuestionBudget.py +164 -67
  153. edsl/questions/QuestionCheckBox.py +281 -62
  154. edsl/questions/QuestionDict.py +343 -0
  155. edsl/questions/QuestionExtract.py +136 -50
  156. edsl/questions/QuestionFreeText.py +79 -55
  157. edsl/questions/QuestionFunctional.py +138 -41
  158. edsl/questions/QuestionList.py +184 -57
  159. edsl/questions/QuestionMatrix.py +265 -0
  160. edsl/questions/QuestionMultipleChoice.py +293 -69
  161. edsl/questions/QuestionNumerical.py +109 -56
  162. edsl/questions/QuestionRank.py +244 -49
  163. edsl/questions/Quick.py +41 -0
  164. edsl/questions/SimpleAskMixin.py +74 -0
  165. edsl/questions/__init__.py +9 -6
  166. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +153 -38
  167. edsl/questions/compose_questions.py +13 -7
  168. edsl/questions/data_structures.py +20 -0
  169. edsl/questions/decorators.py +21 -0
  170. edsl/questions/derived/QuestionLikertFive.py +28 -26
  171. edsl/questions/derived/QuestionLinearScale.py +41 -28
  172. edsl/questions/derived/QuestionTopK.py +34 -26
  173. edsl/questions/derived/QuestionYesNo.py +40 -27
  174. edsl/questions/descriptors.py +228 -74
  175. edsl/questions/loop_processor.py +149 -0
  176. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  177. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  178. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  179. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  180. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  181. edsl/questions/prompt_templates/question_list.jinja +17 -0
  182. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  183. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  184. edsl/questions/question_base_gen_mixin.py +168 -0
  185. edsl/questions/question_registry.py +130 -46
  186. edsl/questions/register_questions_meta.py +71 -0
  187. edsl/questions/response_validator_abc.py +188 -0
  188. edsl/questions/response_validator_factory.py +34 -0
  189. edsl/questions/settings.py +5 -2
  190. edsl/questions/templates/__init__.py +0 -0
  191. edsl/questions/templates/budget/__init__.py +0 -0
  192. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  193. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  194. edsl/questions/templates/checkbox/__init__.py +0 -0
  195. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  196. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  197. edsl/questions/templates/dict/__init__.py +0 -0
  198. edsl/questions/templates/dict/answering_instructions.jinja +21 -0
  199. edsl/questions/templates/dict/question_presentation.jinja +1 -0
  200. edsl/questions/templates/extract/__init__.py +0 -0
  201. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  202. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  203. edsl/questions/templates/free_text/__init__.py +0 -0
  204. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  205. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  206. edsl/questions/templates/likert_five/__init__.py +0 -0
  207. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  208. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  209. edsl/questions/templates/linear_scale/__init__.py +0 -0
  210. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  211. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  212. edsl/questions/templates/list/__init__.py +0 -0
  213. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  214. edsl/questions/templates/list/question_presentation.jinja +5 -0
  215. edsl/questions/templates/matrix/__init__.py +1 -0
  216. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  217. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  218. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  219. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  220. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  221. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  222. edsl/questions/templates/numerical/__init__.py +0 -0
  223. edsl/questions/templates/numerical/answering_instructions.jinja +7 -0
  224. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  225. edsl/questions/templates/rank/__init__.py +0 -0
  226. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  227. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  228. edsl/questions/templates/top_k/__init__.py +0 -0
  229. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  230. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  231. edsl/questions/templates/yes_no/__init__.py +0 -0
  232. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  233. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  234. edsl/results/CSSParameterizer.py +108 -0
  235. edsl/results/Dataset.py +550 -19
  236. edsl/results/DatasetExportMixin.py +594 -0
  237. edsl/results/DatasetTree.py +295 -0
  238. edsl/results/MarkdownToDocx.py +122 -0
  239. edsl/results/MarkdownToPDF.py +111 -0
  240. edsl/results/Result.py +477 -173
  241. edsl/results/Results.py +987 -269
  242. edsl/results/ResultsExportMixin.py +28 -125
  243. edsl/results/ResultsGGMixin.py +83 -15
  244. edsl/results/TableDisplay.py +125 -0
  245. edsl/results/TextEditor.py +50 -0
  246. edsl/results/__init__.py +1 -1
  247. edsl/results/file_exports.py +252 -0
  248. edsl/results/results_fetch_mixin.py +33 -0
  249. edsl/results/results_selector.py +145 -0
  250. edsl/results/results_tools_mixin.py +98 -0
  251. edsl/results/smart_objects.py +96 -0
  252. edsl/results/table_data_class.py +12 -0
  253. edsl/results/table_display.css +78 -0
  254. edsl/results/table_renderers.py +118 -0
  255. edsl/results/tree_explore.py +115 -0
  256. edsl/scenarios/ConstructDownloadLink.py +109 -0
  257. edsl/scenarios/DocumentChunker.py +102 -0
  258. edsl/scenarios/DocxScenario.py +16 -0
  259. edsl/scenarios/FileStore.py +543 -0
  260. edsl/scenarios/PdfExtractor.py +40 -0
  261. edsl/scenarios/Scenario.py +431 -62
  262. edsl/scenarios/ScenarioHtmlMixin.py +65 -0
  263. edsl/scenarios/ScenarioList.py +1415 -45
  264. edsl/scenarios/ScenarioListExportMixin.py +45 -0
  265. edsl/scenarios/ScenarioListPdfMixin.py +239 -0
  266. edsl/scenarios/__init__.py +2 -0
  267. edsl/scenarios/directory_scanner.py +96 -0
  268. edsl/scenarios/file_methods.py +85 -0
  269. edsl/scenarios/handlers/__init__.py +13 -0
  270. edsl/scenarios/handlers/csv.py +49 -0
  271. edsl/scenarios/handlers/docx.py +76 -0
  272. edsl/scenarios/handlers/html.py +37 -0
  273. edsl/scenarios/handlers/json.py +111 -0
  274. edsl/scenarios/handlers/latex.py +5 -0
  275. edsl/scenarios/handlers/md.py +51 -0
  276. edsl/scenarios/handlers/pdf.py +68 -0
  277. edsl/scenarios/handlers/png.py +39 -0
  278. edsl/scenarios/handlers/pptx.py +105 -0
  279. edsl/scenarios/handlers/py.py +294 -0
  280. edsl/scenarios/handlers/sql.py +313 -0
  281. edsl/scenarios/handlers/sqlite.py +149 -0
  282. edsl/scenarios/handlers/txt.py +33 -0
  283. edsl/scenarios/scenario_join.py +131 -0
  284. edsl/scenarios/scenario_selector.py +156 -0
  285. edsl/shared.py +1 -0
  286. edsl/study/ObjectEntry.py +173 -0
  287. edsl/study/ProofOfWork.py +113 -0
  288. edsl/study/SnapShot.py +80 -0
  289. edsl/study/Study.py +521 -0
  290. edsl/study/__init__.py +4 -0
  291. edsl/surveys/ConstructDAG.py +92 -0
  292. edsl/surveys/DAG.py +92 -11
  293. edsl/surveys/EditSurvey.py +221 -0
  294. edsl/surveys/InstructionHandler.py +100 -0
  295. edsl/surveys/Memory.py +9 -4
  296. edsl/surveys/MemoryManagement.py +72 -0
  297. edsl/surveys/MemoryPlan.py +156 -35
  298. edsl/surveys/Rule.py +221 -74
  299. edsl/surveys/RuleCollection.py +241 -61
  300. edsl/surveys/RuleManager.py +172 -0
  301. edsl/surveys/Simulator.py +75 -0
  302. edsl/surveys/Survey.py +1079 -339
  303. edsl/surveys/SurveyCSS.py +273 -0
  304. edsl/surveys/SurveyExportMixin.py +235 -40
  305. edsl/surveys/SurveyFlowVisualization.py +181 -0
  306. edsl/surveys/SurveyQualtricsImport.py +284 -0
  307. edsl/surveys/SurveyToApp.py +141 -0
  308. edsl/surveys/__init__.py +4 -2
  309. edsl/surveys/base.py +19 -3
  310. edsl/surveys/descriptors.py +17 -6
  311. edsl/surveys/instructions/ChangeInstruction.py +48 -0
  312. edsl/surveys/instructions/Instruction.py +56 -0
  313. edsl/surveys/instructions/InstructionCollection.py +82 -0
  314. edsl/surveys/instructions/__init__.py +0 -0
  315. edsl/templates/error_reporting/base.html +24 -0
  316. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  317. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  318. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  319. edsl/templates/error_reporting/interview_details.html +116 -0
  320. edsl/templates/error_reporting/interviews.html +19 -0
  321. edsl/templates/error_reporting/overview.html +5 -0
  322. edsl/templates/error_reporting/performance_plot.html +2 -0
  323. edsl/templates/error_reporting/report.css +74 -0
  324. edsl/templates/error_reporting/report.html +118 -0
  325. edsl/templates/error_reporting/report.js +25 -0
  326. edsl/tools/__init__.py +1 -0
  327. edsl/tools/clusters.py +192 -0
  328. edsl/tools/embeddings.py +27 -0
  329. edsl/tools/embeddings_plotting.py +118 -0
  330. edsl/tools/plotting.py +112 -0
  331. edsl/tools/summarize.py +18 -0
  332. edsl/utilities/PrettyList.py +56 -0
  333. edsl/utilities/SystemInfo.py +5 -0
  334. edsl/utilities/__init__.py +21 -20
  335. edsl/utilities/ast_utilities.py +3 -0
  336. edsl/utilities/data/Registry.py +2 -0
  337. edsl/utilities/decorators.py +41 -0
  338. edsl/utilities/gcp_bucket/__init__.py +0 -0
  339. edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
  340. edsl/utilities/interface.py +310 -60
  341. edsl/utilities/is_notebook.py +18 -0
  342. edsl/utilities/is_valid_variable_name.py +11 -0
  343. edsl/utilities/naming_utilities.py +263 -0
  344. edsl/utilities/remove_edsl_version.py +24 -0
  345. edsl/utilities/repair_functions.py +28 -0
  346. edsl/utilities/restricted_python.py +70 -0
  347. edsl/utilities/utilities.py +203 -13
  348. edsl-0.1.40.dist-info/METADATA +111 -0
  349. edsl-0.1.40.dist-info/RECORD +362 -0
  350. {edsl-0.1.15.dist-info → edsl-0.1.40.dist-info}/WHEEL +1 -1
  351. edsl/agents/AgentListExportMixin.py +0 -24
  352. edsl/coop/old.py +0 -31
  353. edsl/data/Database.py +0 -141
  354. edsl/data/crud.py +0 -121
  355. edsl/jobs/Interview.py +0 -435
  356. edsl/jobs/JobsRunner.py +0 -63
  357. edsl/jobs/JobsRunnerStatusMixin.py +0 -115
  358. edsl/jobs/base.py +0 -47
  359. edsl/jobs/buckets.py +0 -178
  360. edsl/jobs/runners/JobsRunnerDryRun.py +0 -19
  361. edsl/jobs/runners/JobsRunnerStreaming.py +0 -54
  362. edsl/jobs/task_management.py +0 -215
  363. edsl/jobs/token_tracking.py +0 -78
  364. edsl/language_models/DeepInfra.py +0 -69
  365. edsl/language_models/OpenAI.py +0 -98
  366. edsl/language_models/model_interfaces/GeminiPro.py +0 -66
  367. edsl/language_models/model_interfaces/LanguageModelOpenAIFour.py +0 -8
  368. edsl/language_models/model_interfaces/LanguageModelOpenAIThreeFiveTurbo.py +0 -8
  369. edsl/language_models/model_interfaces/LlamaTwo13B.py +0 -21
  370. edsl/language_models/model_interfaces/LlamaTwo70B.py +0 -21
  371. edsl/language_models/model_interfaces/Mixtral8x7B.py +0 -24
  372. edsl/language_models/registry.py +0 -81
  373. edsl/language_models/schemas.py +0 -15
  374. edsl/language_models/unused/ReplicateBase.py +0 -83
  375. edsl/prompts/QuestionInstructionsBase.py +0 -6
  376. edsl/prompts/library/agent_instructions.py +0 -29
  377. edsl/prompts/library/agent_persona.py +0 -17
  378. edsl/prompts/library/question_budget.py +0 -26
  379. edsl/prompts/library/question_checkbox.py +0 -32
  380. edsl/prompts/library/question_extract.py +0 -19
  381. edsl/prompts/library/question_freetext.py +0 -14
  382. edsl/prompts/library/question_linear_scale.py +0 -20
  383. edsl/prompts/library/question_list.py +0 -22
  384. edsl/prompts/library/question_multiple_choice.py +0 -44
  385. edsl/prompts/library/question_numerical.py +0 -31
  386. edsl/prompts/library/question_rank.py +0 -21
  387. edsl/prompts/prompt_config.py +0 -33
  388. edsl/prompts/registry.py +0 -185
  389. edsl/questions/Question.py +0 -240
  390. edsl/report/InputOutputDataTypes.py +0 -134
  391. edsl/report/RegressionMixin.py +0 -28
  392. edsl/report/ReportOutputs.py +0 -1228
  393. edsl/report/ResultsFetchMixin.py +0 -106
  394. edsl/report/ResultsOutputMixin.py +0 -14
  395. edsl/report/demo.ipynb +0 -645
  396. edsl/results/ResultsDBMixin.py +0 -184
  397. edsl/surveys/SurveyFlowVisualizationMixin.py +0 -92
  398. edsl/trackers/Tracker.py +0 -91
  399. edsl/trackers/TrackerAPI.py +0 -196
  400. edsl/trackers/TrackerTasks.py +0 -70
  401. edsl/utilities/pastebin.py +0 -141
  402. edsl-0.1.15.dist-info/METADATA +0 -69
  403. edsl-0.1.15.dist-info/RECORD +0 -142
  404. /edsl/{language_models/model_interfaces → inference_services}/__init__.py +0 -0
  405. /edsl/{report/__init__.py → jobs/runners/JobsRunnerStatusData.py} +0 -0
  406. /edsl/{trackers/__init__.py → language_models/ServiceDataSources.py} +0 -0
  407. {edsl-0.1.15.dist-info → edsl-0.1.40.dist-info}/LICENSE +0 -0
@@ -1,310 +1,307 @@
1
- from __future__ import annotations
2
- from functools import wraps
3
- import io
4
- import asyncio
5
- import json
6
- import time
7
- import inspect
8
- from typing import Coroutine
9
- from abc import ABC, abstractmethod, ABCMeta
10
- from rich.console import Console
11
- from rich.table import Table
1
+ """This module contains the LanguageModel class, which is an abstract base class for all language models.
2
+
3
+ Terminology:
12
4
 
5
+ raw_response: The JSON response from the model. This has all the model meta-data about the call.
13
6
 
14
- from edsl.trackers.TrackerAPI import TrackerAPI
15
- from queue import Queue
16
- from typing import Any, Callable, Type, List
17
- from edsl.data import CRUDOperations, CRUD
18
- from edsl.exceptions import LanguageModelResponseNotJSONError
19
- from edsl.language_models.schemas import model_prices
20
- from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
7
+ edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
8
+ such as the cache key, token usage, etc.
21
9
 
22
- from edsl.language_models.repair import repair
23
- from typing import get_type_hints
10
+ generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
11
+ edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
24
12
 
25
- from edsl.exceptions.language_models import LanguageModelAttributeTypeError
26
- from edsl.enums import LanguageModelType, InferenceServiceType
13
+ """
27
14
 
28
- from edsl.Base import RichPrintingMixin, PersistenceMixin
15
+ from __future__ import annotations
16
+ import warnings
17
+ from functools import wraps
18
+ import asyncio
19
+ import json
20
+ import os
21
+ from typing import (
22
+ Coroutine,
23
+ Any,
24
+ Type,
25
+ Union,
26
+ List,
27
+ get_type_hints,
28
+ TypedDict,
29
+ Optional,
30
+ TYPE_CHECKING,
31
+ )
32
+ from abc import ABC, abstractmethod
33
+
34
+ from edsl.data_transfer_models import (
35
+ ModelResponse,
36
+ ModelInputs,
37
+ EDSLOutput,
38
+ AgentResponseDict,
39
+ )
40
+
41
+ if TYPE_CHECKING:
42
+ from edsl.data.Cache import Cache
43
+ from edsl.scenarios.FileStore import FileStore
44
+ from edsl.questions.QuestionBase import QuestionBase
45
+ from edsl.language_models.key_management.KeyLookup import KeyLookup
46
+
47
+ from edsl.enums import InferenceServiceType
48
+
49
+ from edsl.utilities.decorators import (
50
+ sync_wrapper,
51
+ jupyter_nb_handler,
52
+ )
53
+ from edsl.utilities.remove_edsl_version import remove_edsl_version
54
+
55
+ from edsl.Base import PersistenceMixin, RepresentationMixin
56
+ from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
57
+
58
+ from edsl.language_models.key_management.KeyLookupCollection import (
59
+ KeyLookupCollection,
60
+ )
61
+
62
+ from edsl.language_models.RawResponseHandler import RawResponseHandler
29
63
 
30
64
 
31
65
  def handle_key_error(func):
66
+ """Handle KeyError exceptions."""
67
+
32
68
  @wraps(func)
33
69
  def wrapper(*args, **kwargs):
34
70
  try:
35
71
  return func(*args, **kwargs)
36
72
  assert True == False
37
73
  except KeyError as e:
38
- # Handle the KeyError exception
39
74
  return f"""KeyError occurred: {e}. This is most likely because the model you are using
40
75
  returned a JSON object we were not expecting."""
41
76
 
42
77
  return wrapper
43
78
 
44
79
 
45
- class RegisterLanguageModelsMeta(ABCMeta):
46
- "Metaclass to register output elements in a registry i.e., those that have a parent"
47
- _registry = {} # Initialize the registry as a dictionary
48
- REQUIRED_CLASS_ATTRIBUTES = ["_model_", "_parameters_", "_inference_service_"]
80
+ class classproperty:
81
+ def __init__(self, method):
82
+ self.method = method
49
83
 
50
- def __init__(cls, name, bases, dct):
51
- super(RegisterLanguageModelsMeta, cls).__init__(name, bases, dct)
52
- # if name != "LanguageModel":
53
- if (model_name := getattr(cls, "_model_", None)) is not None:
54
- RegisterLanguageModelsMeta.check_required_class_variables(
55
- cls, RegisterLanguageModelsMeta.REQUIRED_CLASS_ATTRIBUTES
56
- )
84
+ def __get__(self, instance, cls):
85
+ return self.method(cls)
57
86
 
58
- ## Check that model name is valid
59
- if not LanguageModelType.is_value_valid(model_name):
60
- acceptable_values = [item.value for item in LanguageModelType]
61
- raise LanguageModelAttributeTypeError(
62
- f"""A LanguageModel's model must be one of {LanguageModelType} values, which are
63
- {acceptable_values}. You passed {model_name}."""
64
- )
65
87
 
66
- if not InferenceServiceType.is_value_valid(
67
- inference_service := getattr(cls, "_inference_service_", None)
68
- ):
69
- acceptable_values = [item.value for item in InferenceServiceType]
70
- raise LanguageModelAttributeTypeError(
71
- f"""A LanguageModel's model must have an _inference_service_ value from
72
- {acceptable_values}. You passed {inference_service}."""
73
- )
88
+ from edsl.Base import HashingMixin
74
89
 
75
- # LanguageModel children have to implement the async_execute_model_call method
76
- RegisterLanguageModelsMeta.verify_method(
77
- candidate_class=cls,
78
- method_name="async_execute_model_call",
79
- expected_return_type=dict[str, Any],
80
- required_parameters=[("user_prompt", str), ("system_prompt", str)],
81
- must_be_async=True,
82
- )
83
- # LanguageModel children have to implement the parse_response method
84
- RegisterLanguageModelsMeta.verify_method(
85
- candidate_class=cls,
86
- method_name="parse_response",
87
- expected_return_type=str,
88
- required_parameters=[("raw_response", dict[str, Any])],
89
- must_be_async=False,
90
- )
91
- RegisterLanguageModelsMeta._registry[model_name] = cls
92
90
 
93
- @classmethod
94
- def get_registered_classes(cls):
95
- return cls._registry
91
+ class LanguageModel(
92
+ PersistenceMixin,
93
+ RepresentationMixin,
94
+ HashingMixin,
95
+ ABC,
96
+ metaclass=RegisterLanguageModelsMeta,
97
+ ):
98
+ """ABC for Language Models."""
96
99
 
97
- @staticmethod
98
- def check_required_class_variables(
99
- candidate_class: LanguageModel, required_attributes: List[str] = None
100
+ _model_ = None
101
+ key_sequence = (
102
+ None # This should be something like ["choices", 0, "message", "content"]
103
+ )
104
+
105
+ DEFAULT_RPM = 100
106
+ DEFAULT_TPM = 1000
107
+
108
+ @classproperty
109
+ def response_handler(cls):
110
+ key_sequence = cls.key_sequence
111
+ usage_sequence = cls.usage_sequence if hasattr(cls, "usage_sequence") else None
112
+ return RawResponseHandler(key_sequence, usage_sequence)
113
+
114
+ def __init__(
115
+ self,
116
+ tpm: Optional[float] = None,
117
+ rpm: Optional[float] = None,
118
+ omit_system_prompt_if_empty_string: bool = True,
119
+ key_lookup: Optional["KeyLookup"] = None,
120
+ **kwargs,
100
121
  ):
101
- """Checks if a class has the required attributes
102
- >>> class M:
103
- ... _model_ = "m"
104
- ... _parameters_ = {}
105
- >>> RegisterLanguageModelsMeta.check_required_class_variables(M, ["_model_", "_parameters_"])
106
- >>> class M2:
107
- ... _model_ = "m"
108
- >>> RegisterLanguageModelsMeta.check_required_class_variables(M2, ["_model_", "_parameters_"])
109
- Traceback (most recent call last):
110
- ...
111
- Exception: Class M2 does not have required attribute _parameters_
112
- """
113
- required_attributes = required_attributes or []
114
- for attribute in required_attributes:
115
- if not hasattr(candidate_class, attribute):
116
- raise Exception(
117
- f"Class {candidate_class.__name__} does not have required attribute {attribute}"
118
- )
122
+ """Initialize the LanguageModel."""
123
+ self.model = getattr(self, "_model_", None)
124
+ default_parameters = getattr(self, "_parameters_", None)
125
+ parameters = self._overide_default_parameters(kwargs, default_parameters)
126
+ self.parameters = parameters
127
+ self.remote = False
128
+ self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
119
129
 
120
- @staticmethod
121
- def verify_method(
122
- candidate_class: LanguageModel,
123
- method_name: str,
124
- expected_return_type: Any,
125
- required_parameters: List[tuple[str, Any]] = None,
126
- must_be_async: bool = False,
127
- ):
128
- RegisterLanguageModelsMeta._check_method_defined(candidate_class, method_name)
130
+ self.key_lookup = self._set_key_lookup(key_lookup)
131
+ self.model_info = self.key_lookup.get(self._inference_service_)
129
132
 
130
- required_parameters = required_parameters or []
131
- method = getattr(candidate_class, method_name)
132
- signature = inspect.signature(method)
133
+ if rpm is not None:
134
+ self._rpm = rpm
133
135
 
134
- RegisterLanguageModelsMeta._check_return_type(method, expected_return_type)
136
+ if tpm is not None:
137
+ self._tpm = tpm
135
138
 
136
- if must_be_async:
137
- RegisterLanguageModelsMeta._check_is_coroutine(method)
139
+ for key, value in parameters.items():
140
+ setattr(self, key, value)
138
141
 
139
- # Check the parameters
140
- params = signature.parameters
141
- for param_name, param_type in required_parameters:
142
- RegisterLanguageModelsMeta._verify_parameter(
143
- params, param_name, param_type, method_name
144
- )
142
+ for key, value in kwargs.items():
143
+ if key not in parameters:
144
+ setattr(self, key, value)
145
145
 
146
- @staticmethod
147
- def _check_method_defined(cls, method_name):
148
- """Checks if a method is defined in a class
149
- >>> class M:
150
- ... def f(self): pass
151
- >>> RegisterLanguageModelsMeta._check_method_defined(M, "f")
152
- >>> RegisterLanguageModelsMeta._check_method_defined(M, "g")
153
- Traceback (most recent call last):
154
- ...
155
- NotImplementedError: g method must be implemented
156
- """
157
- if not hasattr(cls, method_name):
158
- raise NotImplementedError(f"{method_name} method must be implemented")
146
+ if kwargs.get("skip_api_key_check", False):
147
+ # Skip the API key check. Sometimes this is useful for testing.
148
+ self._api_token = None
159
149
 
160
- @staticmethod
161
- def _check_is_coroutine(func: Callable):
162
- """
163
- Checks to make sure it's a coroutine function
164
- >>> def f(): pass
165
- >>> RegisterLanguageModelsMeta._check_is_coroutine(f)
166
- Traceback (most recent call last):
167
- ...
168
- TypeError: A LangugeModel class with method f must be an asynchronous method
169
- """
170
- if not inspect.iscoroutinefunction(func):
171
- raise TypeError(
172
- f"A LangugeModel class with method {func.__name__} must be an asynchronous method"
173
- )
150
+ def _set_key_lookup(self, key_lookup: "KeyLookup") -> "KeyLookup":
151
+ """Set the key lookup."""
152
+ if key_lookup is not None:
153
+ return key_lookup
154
+ else:
155
+ klc = KeyLookupCollection()
156
+ klc.add_key_lookup(fetch_order=("config", "env"))
157
+ return klc.get(("config", "env"))
174
158
 
175
- @staticmethod
176
- def _verify_parameter(params, param_name, param_type, method_name):
177
- if param_name not in params:
178
- raise TypeError(
179
- f"""Parameter "{param_name}" of method "{method_name}" must be defined.
180
- """
181
- )
182
- if params[param_name].annotation != param_type:
183
- raise TypeError(
184
- f"""Parameter "{param_name}" of method "{method_name}" must be of type {param_type.__name__}.
185
- Got {params[param_name].annotation} instead.
186
- """
187
- )
159
+ def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
160
+ """Set the key lookup, later"""
161
+ if hasattr(self, "_api_token"):
162
+ del self._api_token
163
+ self.key_lookup = key_lookup
188
164
 
189
- @staticmethod
190
- def _check_return_type(method, expected_return_type):
191
- """
192
- Checks if the return type of a method is as expected
193
- >>> class M:
194
- ... async def f(self) -> str: pass
195
- >>> RegisterLanguageModelsMeta._check_return_type(M.f, str)
196
- >>> class N:
197
- ... async def f(self) -> int: pass
198
- >>> RegisterLanguageModelsMeta._check_return_type(N.f, str)
199
- Traceback (most recent call last):
200
- ...
201
- TypeError: Return type of f must be <class 'str'>. Got <class 'int'>
165
+ def ask_question(self, question: "QuestionBase") -> str:
166
+ """Ask a question and return the response.
167
+
168
+ :param question: The question to ask.
202
169
  """
203
- if inspect.isroutine(method):
204
- # return_type = inspect.signature(method).return_annotation
205
- return_type = get_type_hints(method)["return"]
206
- if return_type != expected_return_type:
207
- raise TypeError(
208
- f"Return type of {method.__name__} must be {expected_return_type}. Got {return_type}."
209
- )
170
+ user_prompt = question.get_instructions().render(question.data).text
171
+ system_prompt = "You are a helpful agent pretending to be a human."
172
+ return self.execute_model_call(user_prompt, system_prompt)
210
173
 
211
- @classmethod
212
- def model_names_to_classes(cls):
213
- d = {}
214
- for classname, cls in cls._registry.items():
215
- if hasattr(cls, "_model_"):
216
- d[cls._model_] = cls
174
+ @property
175
+ def rpm(self):
176
+ if not hasattr(self, "_rpm"):
177
+ if self.model_info is None:
178
+ self._rpm = self.DEFAULT_RPM
217
179
  else:
218
- raise Exception(
219
- f"Class {classname} does not have a _model_ class attribute"
220
- )
221
- return d
180
+ self._rpm = self.model_info.rpm
181
+ return self._rpm
222
182
 
183
+ @property
184
+ def tpm(self):
185
+ if not hasattr(self, "_tpm"):
186
+ if self.model_info is None:
187
+ self._tpm = self.DEFAULT_TPM
188
+ else:
189
+ self._tpm = self.model_info.tpm
190
+ return self._tpm
223
191
 
224
- class LanguageModel(
225
- RichPrintingMixin, PersistenceMixin, ABC, metaclass=RegisterLanguageModelsMeta
226
- ):
227
- """ABC for LLM subclasses."""
192
+ # in case we want to override the default values
193
+ @tpm.setter
194
+ def tpm(self, value):
195
+ self._tpm = value
228
196
 
229
- _model_ = None
197
+ @rpm.setter
198
+ def rpm(self, value):
199
+ self._rpm = value
230
200
 
231
- __rate_limits = None
232
- # TODO: Use the OpenAI Teir 1 rate limits
233
- __default_rate_limits = {"rpm": 10_000, "tpm": 2_000_000}
234
- _safety_factor = 0.8
201
+ @property
202
+ def api_token(self) -> str:
203
+ if not hasattr(self, "_api_token"):
204
+ info = self.key_lookup.get(self._inference_service_, None)
205
+ if info is None:
206
+ raise ValueError(
207
+ f"No key found for service '{self._inference_service_}'"
208
+ )
209
+ self._api_token = info.api_token
210
+ return self._api_token
211
+
212
+ def __getitem__(self, key):
213
+ return getattr(self, key)
214
+
215
+ def hello(self, verbose=False):
216
+ """Runs a simple test to check if the model is working."""
217
+ token = self.api_token
218
+ masked = token[: min(8, len(token))] + "..."
219
+ if verbose:
220
+ print(f"Current key is {masked}")
221
+ return self.execute_model_call(
222
+ user_prompt="Hello, model!", system_prompt="You are a helpful agent."
223
+ )
235
224
 
236
- def __init__(self, crud: CRUDOperations = CRUD, **kwargs):
237
- """
238
- Attributes:
239
- - all attributes inherited from subclasses
240
- - lock: lock for this model to ensure TODO
241
- - api_queue: queue that records messages about API calls the model makes. Used by `InterviewManager` to update details about state of model.
225
+ def has_valid_api_key(self) -> bool:
226
+ """Check if the model has a valid API key.
227
+
228
+ >>> LanguageModel.example().has_valid_api_key() : # doctest: +SKIP
229
+ True
230
+
231
+ This method is used to check if the model has a valid API key.
242
232
  """
243
- self.model = getattr(self, "_model_", None)
244
- default_parameters = getattr(self, "_parameters_", None)
245
- parameters = self._overide_default_parameters(kwargs, default_parameters)
246
- self.parameters = parameters
233
+ from edsl.enums import service_to_api_keyname
247
234
 
248
- for key, value in parameters.items():
249
- setattr(self, key, value)
235
+ if self._model_ == "test":
236
+ return True
250
237
 
251
- for key, value in kwargs.items():
252
- if key not in parameters:
253
- setattr(self, key, value)
238
+ key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
239
+ key_value = os.getenv(key_name)
240
+ return key_value is not None
254
241
 
255
- # TODO: This can very likely be removed
256
- self.api_queue = Queue()
257
- self.crud = crud
242
+ def __hash__(self) -> str:
243
+ """Allow the model to be used as a key in a dictionary.
258
244
 
259
- def __hash__(self):
260
- "Allows the model to be used as a key in a dictionary"
261
- return hash(self.model + str(self.parameters))
245
+ >>> m = LanguageModel.example()
246
+ >>> hash(m)
247
+ 1811901442659237949
248
+ """
249
+ from edsl.utilities.utilities import dict_hash
262
250
 
263
- def __eq__(self, other):
264
- return self.model == other.model and self.parameters == other.parameters
251
+ return dict_hash(self.to_dict(add_edsl_version=False))
265
252
 
266
- def _set_rate_limits(self) -> None:
267
- if self.__rate_limits is None:
268
- if hasattr(self, "get_rate_limits"):
269
- self.__rate_limits = self.get_rate_limits()
270
- else:
271
- self.__rate_limits = self.__default_rate_limits
253
+ def __eq__(self, other) -> bool:
254
+ """Check is two models are the same.
272
255
 
273
- @property
274
- def RPM(self):
275
- "Model's requests-per-minute limit"
276
- self._set_rate_limits()
277
- return self._safety_factor * self.__rate_limits["rpm"]
256
+ >>> m1 = LanguageModel.example()
257
+ >>> m2 = LanguageModel.example()
258
+ >>> m1 == m2
259
+ True
278
260
 
279
- @property
280
- def TPM(self):
281
- "Model's tokens-per-minute limit"
282
- self._set_rate_limits()
283
- return self._safety_factor * self.__rate_limits["tpm"]
261
+ """
262
+ return self.model == other.model and self.parameters == other.parameters
284
263
 
285
264
  @staticmethod
286
265
  def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
287
- """Returns a dictionary of parameters, with passed parameters taking precedence over defaults.
266
+ """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
288
267
 
289
268
  >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9})
290
269
  {'temperature': 0.5}
291
270
  >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
292
271
  {'temperature': 0.5, 'max_tokens': 1000}
293
272
  """
294
- parameters = dict({})
295
- for parameter, default_value in default_parameter_dict.items():
296
- if parameter in passed_parameter_dict:
297
- parameters[parameter] = passed_parameter_dict[parameter]
298
- else:
299
- parameters[parameter] = default_value
300
- return parameters
273
+ # this is the case when data is loaded from a dict after serialization
274
+ if "parameters" in passed_parameter_dict:
275
+ passed_parameter_dict = passed_parameter_dict["parameters"]
276
+ return {
277
+ parameter_name: passed_parameter_dict.get(parameter_name, default_value)
278
+ for parameter_name, default_value in default_parameter_dict.items()
279
+ }
280
+
281
+ def __call__(self, user_prompt: str, system_prompt: str):
282
+ return self.execute_model_call(user_prompt, system_prompt)
301
283
 
302
284
  @abstractmethod
303
- async def async_execute_model_call():
285
+ async def async_execute_model_call(user_prompt: str, system_prompt: str):
286
+ """Execute the model call and returns a coroutine."""
304
287
  pass
305
288
 
289
+ async def remote_async_execute_model_call(
290
+ self, user_prompt: str, system_prompt: str
291
+ ):
292
+ """Execute the model call and returns the result as a coroutine, using Coop."""
293
+ from edsl.coop import Coop
294
+
295
+ client = Coop()
296
+ response_data = await client.remote_async_execute_model_call(
297
+ self.to_dict(), user_prompt, system_prompt
298
+ )
299
+ return response_data
300
+
306
301
  @jupyter_nb_handler
307
302
  def execute_model_call(self, *args, **kwargs) -> Coroutine:
303
+ """Execute the model call and returns the result as a coroutine."""
304
+
308
305
  async def main():
309
306
  results = await asyncio.gather(
310
307
  self.async_execute_model_call(*args, **kwargs)
@@ -313,193 +310,317 @@ class LanguageModel(
313
310
 
314
311
  return main()
315
312
 
316
- @abstractmethod
317
- def parse_response(raw_response: dict[str, Any]) -> str:
318
- """Parses the API response and returns the response text.
319
- What is returned by the API is model-specific and often includes meta-data that we do not need.
320
- For example, here is the results from a call to GPT-4:
321
-
322
- {
323
- "id": "chatcmpl-8eORaeuVb4po9WQRjKEFY6w7v6cTm",
324
- "choices": [
325
- {
326
- "finish_reason": "stop",
327
- "index": 0,
328
- "logprobs": None,
329
- "message": {
330
- "content": "Hello! How can I assist you today? If you have any questions or need information on a particular topic, feel free to ask.",
331
- "role": "assistant",
332
- "function_call": None,
333
- "tool_calls": None,
334
- },
335
- }
336
- ],
337
- "created": 1704637774,
338
- "model": "gpt-4-1106-preview",
339
- "object": "chat.completion",
340
- "system_fingerprint": "fp_168383a679",
341
- "usage": {"completion_tokens": 27, "prompt_tokens": 13, "total_tokens": 40},
342
- }
313
+ @classmethod
314
+ def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
315
+ """Return the generated token string from the raw response.
316
+
317
+ >>> m = LanguageModel.example(test_model = True)
318
+ >>> raw_response = m.execute_model_call("Hello, model!", "You are a helpful agent.")
319
+ >>> m.get_generated_token_string(raw_response)
320
+ 'Hello world'
343
321
 
344
- To actually tract the response, we need to grab
345
- data["choices[0]"]["message"]["content"].
346
322
  """
347
- raise NotImplementedError
323
+ return cls.response_handler.get_generated_token_string(raw_response)
348
324
 
349
- def _update_response_with_tracking(
350
- self, response, start_time, cached_response=False
351
- ):
352
- end_time = time.time()
353
- response["elapsed_time"] = end_time - start_time
354
- response["timestamp"] = end_time
355
- self._post_tracker_event(response)
356
- response["cached_response"] = cached_response
357
- return response
358
-
359
- async def async_get_raw_response(
360
- self, user_prompt: str, system_prompt: str = ""
361
- ) -> dict[str, Any]:
362
- """This is some middle-ware that handles the caching of responses.
363
- If the cache isn't being used, it just returns a 'fresh' call to the LLM,
364
- but appends some tracking information to the response (using the _update_response_with_tracking method).
325
+ @classmethod
326
+ def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
327
+ """Return the usage dictionary from the raw response."""
328
+ return cls.response_handler.get_usage_dict(raw_response)
329
+
330
+ @classmethod
331
+ def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
332
+ """Parses the API response and returns the response text."""
333
+ return cls.response_handler.parse_response(raw_response)
334
+
335
+ async def _async_get_intended_model_call_outcome(
336
+ self,
337
+ user_prompt: str,
338
+ system_prompt: str,
339
+ cache: Cache,
340
+ iteration: int = 0,
341
+ files_list: Optional[List[FileStore]] = None,
342
+ invigilator=None,
343
+ ) -> ModelResponse:
344
+ """Handle caching of responses.
345
+
346
+ :param user_prompt: The user's prompt.
347
+ :param system_prompt: The system's prompt.
348
+ :param iteration: The iteration number.
349
+ :param cache: The cache to use.
350
+ :param files_list: The list of files to use.
351
+ :param invigilator: The invigilator to use.
352
+
353
+ If the cache isn't being used, it just returns a 'fresh' call to the LLM.
365
354
  But if cache is being used, it first checks the database to see if the response is already there.
366
355
  If it is, it returns the cached response, but again appends some tracking information.
367
356
  If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
368
357
 
369
358
  If self.use_cache is True, then attempts to retrieve the response from the database;
370
- if not in the DB, calls the LLM and writes the response to the DB."""
371
- start_time = time.time()
372
-
373
- if not self.use_cache:
374
- response = await self.async_execute_model_call(user_prompt, system_prompt)
375
- return self._update_response_with_tracking(response, start_time, False)
376
-
377
- cached_response = self.crud.get_LLMOutputData(
378
- model=str(self.model),
379
- parameters=str(self.parameters),
380
- system_prompt=system_prompt,
381
- prompt=user_prompt,
382
- )
359
+ if not in the DB, calls the LLM and writes the response to the DB.
360
+
361
+ >>> from edsl import Cache
362
+ >>> m = LanguageModel.example(test_model = True)
363
+ >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
364
+ ModelResponse(...)"""
383
365
 
384
- if cached_response:
366
+ if files_list:
367
+ files_hash = "+".join([str(hash(file)) for file in files_list])
368
+ user_prompt_with_hashes = user_prompt + f" {files_hash}"
369
+ else:
370
+ user_prompt_with_hashes = user_prompt
371
+
372
+ cache_call_params = {
373
+ "model": str(self.model),
374
+ "parameters": self.parameters,
375
+ "system_prompt": system_prompt,
376
+ "user_prompt": user_prompt_with_hashes,
377
+ "iteration": iteration,
378
+ }
379
+ cached_response, cache_key = cache.fetch(**cache_call_params)
380
+
381
+ if cache_used := cached_response is not None:
385
382
  response = json.loads(cached_response)
386
- cache_used = True
387
383
  else:
388
- response = await self.async_execute_model_call(user_prompt, system_prompt)
389
- self._save_response_to_db(user_prompt, system_prompt, response)
390
- cache_used = False
384
+ f = (
385
+ self.remote_async_execute_model_call
386
+ if hasattr(self, "remote") and self.remote
387
+ else self.async_execute_model_call
388
+ )
389
+ params = {
390
+ "user_prompt": user_prompt,
391
+ "system_prompt": system_prompt,
392
+ "files_list": files_list,
393
+ }
394
+ from edsl.config import CONFIG
395
+
396
+ TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
397
+
398
+ response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
399
+ new_cache_key = cache.store(
400
+ **cache_call_params, response=response
401
+ ) # store the response in the cache
402
+ assert new_cache_key == cache_key # should be the same
403
+
404
+ cost = self.cost(response)
405
+ return ModelResponse(
406
+ response=response,
407
+ cache_used=cache_used,
408
+ cache_key=cache_key,
409
+ cached_response=cached_response,
410
+ cost=cost,
411
+ )
391
412
 
392
- return self._update_response_with_tracking(response, start_time, cache_used)
413
+ _get_intended_model_call_outcome = sync_wrapper(
414
+ _async_get_intended_model_call_outcome
415
+ )
393
416
 
394
- get_raw_response = sync_wrapper(async_get_raw_response)
417
+ def simple_ask(
418
+ self,
419
+ question: QuestionBase,
420
+ system_prompt="You are a helpful agent pretending to be a human.",
421
+ top_logprobs=2,
422
+ ):
423
+ """Ask a question and return the response."""
424
+ self.logprobs = True
425
+ self.top_logprobs = top_logprobs
426
+ return self.execute_model_call(
427
+ user_prompt=question.human_readable(), system_prompt=system_prompt
428
+ )
395
429
 
396
- def _save_response_to_db(self, prompt, system_prompt, response):
397
- try:
398
- output = json.dumps(response)
399
- except json.JSONDecodeError:
400
- raise LanguageModelResponseNotJSONError
401
- self.crud.write_LLMOutputData(
402
- model=str(self.model),
403
- parameters=str(self.parameters),
404
- system_prompt=system_prompt,
405
- prompt=prompt,
406
- output=output,
430
+ async def async_get_response(
431
+ self,
432
+ user_prompt: str,
433
+ system_prompt: str,
434
+ cache: Cache,
435
+ iteration: int = 1,
436
+ files_list: Optional[List[FileStore]] = None,
437
+ **kwargs,
438
+ ) -> dict:
439
+ """Get response, parse, and return as string.
440
+
441
+ :param user_prompt: The user's prompt.
442
+ :param system_prompt: The system's prompt.
443
+ :param cache: The cache to use.
444
+ :param iteration: The iteration number.
445
+ :param files_list: The list of files to use.
446
+
447
+ """
448
+ params = {
449
+ "user_prompt": user_prompt,
450
+ "system_prompt": system_prompt,
451
+ "iteration": iteration,
452
+ "cache": cache,
453
+ "files_list": files_list,
454
+ }
455
+ if "invigilator" in kwargs:
456
+ params.update({"invigilator": kwargs["invigilator"]})
457
+
458
+ model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
459
+ model_outputs: ModelResponse = (
460
+ await self._async_get_intended_model_call_outcome(**params)
407
461
  )
462
+ edsl_dict: EDSLOutput = self.parse_response(model_outputs.response)
408
463
 
409
- async def async_get_response(self, user_prompt: str, system_prompt: str = ""):
410
- """Get response, parse, and return as string."""
411
- raw_response = await self.async_get_raw_response(user_prompt, system_prompt)
412
- response = self.parse_response(raw_response)
413
- try:
414
- dict_response = json.loads(response)
415
- except json.JSONDecodeError as e:
416
- # TODO: Turn into logs to generate issues
417
- dict_response, success = await repair(response, str(e))
418
- if not success:
419
- raise Exception("Even the repair failed.")
420
-
421
- dict_response["cached_response"] = raw_response["cached_response"]
422
- dict_response["usage"] = raw_response.get("usage", {})
423
- dict_response["raw_model_response"] = raw_response
424
- return dict_response
464
+ agent_response_dict = AgentResponseDict(
465
+ model_inputs=model_inputs,
466
+ model_outputs=model_outputs,
467
+ edsl_dict=edsl_dict,
468
+ )
469
+ return agent_response_dict
425
470
 
426
471
  get_response = sync_wrapper(async_get_response)
427
472
 
428
- #######################
429
- # USEFUL METHODS
430
- #######################
431
- def _post_tracker_event(self, raw_response: dict[str, Any]) -> None:
432
- """Parses the API response and sends usage details to the API Queue."""
433
- usage = raw_response.get("usage", {})
434
- usage.update(
435
- {
436
- "cached_response": raw_response.get("cached_response", None),
437
- "elapsed_time": raw_response.get("elapsed_time", None),
438
- "timestamp": raw_response.get("timestamp", None),
439
- }
473
+ def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
474
+ """Return the dollar cost of a raw response.
475
+
476
+ :param raw_response: The raw response from the model.
477
+ """
478
+
479
+ usage = self.get_usage_dict(raw_response)
480
+ from edsl.language_models.PriceManager import PriceManager
481
+
482
+ price_manger = PriceManager()
483
+ return price_manger.calculate_cost(
484
+ inference_service=self._inference_service_,
485
+ model=self.model,
486
+ usage=usage,
487
+ input_token_name=self.input_token_name,
488
+ output_token_name=self.output_token_name,
440
489
  )
441
- event = TrackerAPI.APICallDetails(details=usage)
442
- self.api_queue.put(event)
443
-
444
- def cost(self, raw_response: dict[str, Any]) -> float:
445
- """Returns the dollar cost of a raw response."""
446
- keys = raw_response["usage"].keys()
447
- prices = model_prices.get(self.model)
448
- return sum([prices.get(key, 0.0) * raw_response["usage"][key] for key in keys])
449
-
450
- #######################
451
- # SERIALIZATION METHODS
452
- #######################
453
- def to_dict(self) -> dict[str, Any]:
454
- """Converts instance to a dictionary."""
455
- return {"model": self.model, "parameters": self.parameters}
490
+
491
+ def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
492
+ """Convert instance to a dictionary
493
+
494
+ :param add_edsl_version: Whether to add the EDSL version to the dictionary.
495
+
496
+ >>> m = LanguageModel.example()
497
+ >>> m.to_dict()
498
+ {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
499
+ """
500
+ d = {
501
+ "model": self.model,
502
+ "parameters": self.parameters,
503
+ }
504
+ if add_edsl_version:
505
+ from edsl import __version__
506
+
507
+ d["edsl_version"] = __version__
508
+ d["edsl_class_name"] = self.__class__.__name__
509
+ return d
456
510
 
457
511
  @classmethod
512
+ @remove_edsl_version
458
513
  def from_dict(cls, data: dict) -> Type[LanguageModel]:
459
- """Converts dictionary to a LanguageModel child instance."""
460
- from edsl.language_models.registry import get_model_class
514
+ """Convert dictionary to a LanguageModel child instance."""
515
+ from edsl.language_models.model import get_model_class
461
516
 
462
517
  model_class = get_model_class(data["model"])
463
- data["use_cache"] = True
464
518
  return model_class(**data)
465
519
 
466
- #######################
467
- # DUNDER METHODS
468
- #######################
469
520
  def __repr__(self) -> str:
470
- return f"{self.__class__.__name__}(model = '{self.model}', parameters={self.parameters})"
521
+ """Return a representation of the object."""
522
+ param_string = ", ".join(
523
+ f"{key} = {value}" for key, value in self.parameters.items()
524
+ )
525
+ return (
526
+ f"Model(model_name = '{self.model}'"
527
+ + (f", {param_string}" if param_string else "")
528
+ + ")"
529
+ )
471
530
 
472
531
  def __add__(self, other_model: Type[LanguageModel]) -> Type[LanguageModel]:
473
- """Combine two models into a single model (other_model takes precedence over self)"""
474
- print(
532
+ """Combine two models into a single model (other_model takes precedence over self)."""
533
+ import warnings
534
+
535
+ warnings.warn(
475
536
  f"""Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
476
537
  by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
477
538
  )
478
539
  return other_model or self
479
540
 
480
- def rich_print(self):
481
- """Displays an object as a table."""
482
- table = Table(title="Language Model")
483
- table.add_column("Attribute", style="bold")
484
- table.add_column("Value")
541
+ @classmethod
542
+ def example(
543
+ cls,
544
+ test_model: bool = False,
545
+ canned_response: str = "Hello world",
546
+ throw_exception: bool = False,
547
+ ) -> LanguageModel:
548
+ """Return a default instance of the class.
549
+
550
+ >>> from edsl.language_models import LanguageModel
551
+ >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
552
+ >>> isinstance(m, LanguageModel)
553
+ True
554
+ >>> from edsl import QuestionFreeText
555
+ >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
556
+ >>> q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True).select('example').first()
557
+ 'WOWZA!'
558
+ >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!", throw_exception = True)
559
+ >>> r = q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True, print_exceptions = True)
560
+ Exception report saved to ...
561
+ Also see: ...
562
+ """
563
+ from edsl.language_models.model import Model
485
564
 
486
- to_display = self.__dict__.copy()
487
- for attr_name, attr_value in to_display.items():
488
- table.add_row(attr_name, repr(attr_value))
565
+ if test_model:
566
+ m = Model(
567
+ "test", canned_response=canned_response, throw_exception=throw_exception
568
+ )
569
+ return m
570
+ else:
571
+ return Model(skip_api_key_check=True)
489
572
 
490
- return table
573
+ def from_cache(self, cache: "Cache") -> LanguageModel:
491
574
 
492
- @classmethod
493
- def example(cls):
494
- "Returns a default instance of the class"
495
- from edsl import Model
575
+ from copy import deepcopy
576
+ from types import MethodType
577
+ from edsl import Cache
578
+
579
+ new_instance = deepcopy(self)
580
+ print("Cache entries", len(cache))
581
+ new_instance.cache = Cache(
582
+ data={k: v for k, v in cache.items() if v.model == self.model}
583
+ )
584
+ print("Cache entries with same model", len(new_instance.cache))
585
+
586
+ new_instance.user_prompts = [
587
+ ce.user_prompt for ce in new_instance.cache.values()
588
+ ]
589
+ new_instance.system_prompts = [
590
+ ce.system_prompt for ce in new_instance.cache.values()
591
+ ]
592
+
593
+ async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
594
+ cache_call_params = {
595
+ "model": str(self.model),
596
+ "parameters": self.parameters,
597
+ "system_prompt": system_prompt,
598
+ "user_prompt": user_prompt,
599
+ "iteration": 1,
600
+ }
601
+ cached_response, cache_key = cache.fetch(**cache_call_params)
602
+ response = json.loads(cached_response)
603
+ cost = 0
604
+ return ModelResponse(
605
+ response=response,
606
+ cache_used=True,
607
+ cache_key=cache_key,
608
+ cached_response=cached_response,
609
+ cost=cost,
610
+ )
611
+
612
+ # Bind the new method to the copied instance
613
+ setattr(
614
+ new_instance,
615
+ "async_execute_model_call",
616
+ MethodType(async_execute_model_call, new_instance),
617
+ )
496
618
 
497
- return Model(Model.available()[0])
619
+ return new_instance
498
620
 
499
621
 
500
622
  if __name__ == "__main__":
501
- # import doctest
502
- # doctest.testmod()
503
- from edsl.language_models import LanguageModel
623
+ """Run the module's test suite."""
624
+ import doctest
504
625
 
505
- print(LanguageModel.example())
626
+ doctest.testmod(optionflags=doctest.ELLIPSIS)