edsl 0.1.38.dev3__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Files changed (341)
  1. edsl/Base.py +413 -303
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +57 -49
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +1071 -858
  7. edsl/agents/AgentList.py +551 -362
  8. edsl/agents/Invigilator.py +284 -222
  9. edsl/agents/InvigilatorBase.py +257 -284
  10. edsl/agents/PromptConstructor.py +272 -353
  11. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  12. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  13. edsl/agents/__init__.py +2 -3
  14. edsl/agents/descriptors.py +99 -99
  15. edsl/agents/prompt_helpers.py +129 -129
  16. edsl/agents/question_option_processor.py +172 -0
  17. edsl/auto/AutoStudy.py +130 -117
  18. edsl/auto/StageBase.py +243 -230
  19. edsl/auto/StageGenerateSurvey.py +178 -178
  20. edsl/auto/StageLabelQuestions.py +125 -125
  21. edsl/auto/StagePersona.py +61 -61
  22. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  23. edsl/auto/StagePersonaDimensionValues.py +74 -74
  24. edsl/auto/StagePersonaDimensions.py +69 -69
  25. edsl/auto/StageQuestions.py +74 -73
  26. edsl/auto/SurveyCreatorPipeline.py +21 -21
  27. edsl/auto/utilities.py +218 -224
  28. edsl/base/Base.py +279 -279
  29. edsl/config.py +177 -149
  30. edsl/conversation/Conversation.py +290 -290
  31. edsl/conversation/car_buying.py +59 -58
  32. edsl/conversation/chips.py +95 -95
  33. edsl/conversation/mug_negotiation.py +81 -81
  34. edsl/conversation/next_speaker_utilities.py +93 -93
  35. edsl/coop/CoopFunctionsMixin.py +15 -0
  36. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  37. edsl/coop/PriceFetcher.py +54 -54
  38. edsl/coop/__init__.py +2 -2
  39. edsl/coop/coop.py +1106 -961
  40. edsl/coop/utils.py +131 -131
  41. edsl/data/Cache.py +573 -530
  42. edsl/data/CacheEntry.py +230 -228
  43. edsl/data/CacheHandler.py +168 -149
  44. edsl/data/RemoteCacheSync.py +186 -97
  45. edsl/data/SQLiteDict.py +292 -292
  46. edsl/data/__init__.py +5 -4
  47. edsl/data/orm.py +10 -10
  48. edsl/data_transfer_models.py +74 -73
  49. edsl/enums.py +202 -173
  50. edsl/exceptions/BaseException.py +21 -21
  51. edsl/exceptions/__init__.py +54 -54
  52. edsl/exceptions/agents.py +54 -42
  53. edsl/exceptions/cache.py +5 -5
  54. edsl/exceptions/configuration.py +16 -16
  55. edsl/exceptions/coop.py +10 -10
  56. edsl/exceptions/data.py +14 -14
  57. edsl/exceptions/general.py +34 -34
  58. edsl/exceptions/inference_services.py +5 -0
  59. edsl/exceptions/jobs.py +33 -33
  60. edsl/exceptions/language_models.py +63 -63
  61. edsl/exceptions/prompts.py +15 -15
  62. edsl/exceptions/questions.py +109 -91
  63. edsl/exceptions/results.py +29 -29
  64. edsl/exceptions/scenarios.py +29 -22
  65. edsl/exceptions/surveys.py +37 -37
  66. edsl/inference_services/AnthropicService.py +106 -87
  67. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  68. edsl/inference_services/AvailableModelFetcher.py +215 -0
  69. edsl/inference_services/AwsBedrock.py +118 -120
  70. edsl/inference_services/AzureAI.py +215 -217
  71. edsl/inference_services/DeepInfraService.py +18 -18
  72. edsl/inference_services/GoogleService.py +143 -156
  73. edsl/inference_services/GroqService.py +20 -20
  74. edsl/inference_services/InferenceServiceABC.py +80 -147
  75. edsl/inference_services/InferenceServicesCollection.py +138 -97
  76. edsl/inference_services/MistralAIService.py +120 -123
  77. edsl/inference_services/OllamaService.py +18 -18
  78. edsl/inference_services/OpenAIService.py +236 -224
  79. edsl/inference_services/PerplexityService.py +160 -0
  80. edsl/inference_services/ServiceAvailability.py +135 -0
  81. edsl/inference_services/TestService.py +90 -89
  82. edsl/inference_services/TogetherAIService.py +172 -170
  83. edsl/inference_services/data_structures.py +134 -0
  84. edsl/inference_services/models_available_cache.py +118 -118
  85. edsl/inference_services/rate_limits_cache.py +25 -25
  86. edsl/inference_services/registry.py +41 -39
  87. edsl/inference_services/write_available.py +10 -10
  88. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  89. edsl/jobs/Answers.py +43 -56
  90. edsl/jobs/FetchInvigilator.py +47 -0
  91. edsl/jobs/InterviewTaskManager.py +98 -0
  92. edsl/jobs/InterviewsConstructor.py +50 -0
  93. edsl/jobs/Jobs.py +823 -1358
  94. edsl/jobs/JobsChecks.py +172 -0
  95. edsl/jobs/JobsComponentConstructor.py +189 -0
  96. edsl/jobs/JobsPrompts.py +270 -0
  97. edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
  98. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  99. edsl/jobs/RequestTokenEstimator.py +30 -0
  100. edsl/jobs/__init__.py +1 -1
  101. edsl/jobs/async_interview_runner.py +138 -0
  102. edsl/jobs/buckets/BucketCollection.py +104 -63
  103. edsl/jobs/buckets/ModelBuckets.py +65 -65
  104. edsl/jobs/buckets/TokenBucket.py +283 -251
  105. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  106. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  107. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  108. edsl/jobs/data_structures.py +120 -0
  109. edsl/jobs/decorators.py +35 -0
  110. edsl/jobs/interviews/Interview.py +396 -661
  111. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  112. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  113. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  114. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  115. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  116. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  117. edsl/jobs/interviews/ReportErrors.py +66 -66
  118. edsl/jobs/interviews/interview_status_enum.py +9 -9
  119. edsl/jobs/jobs_status_enums.py +9 -0
  120. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  121. edsl/jobs/results_exceptions_handler.py +98 -0
  122. edsl/jobs/runners/JobsRunnerAsyncio.py +151 -361
  123. edsl/jobs/runners/JobsRunnerStatus.py +298 -332
  124. edsl/jobs/tasks/QuestionTaskCreator.py +244 -242
  125. edsl/jobs/tasks/TaskCreators.py +64 -64
  126. edsl/jobs/tasks/TaskHistory.py +470 -451
  127. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  128. edsl/jobs/tasks/task_status_enum.py +161 -163
  129. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  130. edsl/jobs/tokens/TokenUsage.py +34 -34
  131. edsl/language_models/ComputeCost.py +63 -0
  132. edsl/language_models/LanguageModel.py +626 -708
  133. edsl/language_models/ModelList.py +164 -109
  134. edsl/language_models/PriceManager.py +127 -0
  135. edsl/language_models/RawResponseHandler.py +106 -0
  136. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  137. edsl/language_models/ServiceDataSources.py +0 -0
  138. edsl/language_models/__init__.py +2 -3
  139. edsl/language_models/fake_openai_call.py +15 -15
  140. edsl/language_models/fake_openai_service.py +61 -61
  141. edsl/language_models/key_management/KeyLookup.py +63 -0
  142. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  143. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  144. edsl/language_models/key_management/__init__.py +0 -0
  145. edsl/language_models/key_management/models.py +131 -0
  146. edsl/language_models/model.py +256 -0
  147. edsl/language_models/repair.py +156 -156
  148. edsl/language_models/utilities.py +65 -64
  149. edsl/notebooks/Notebook.py +263 -258
  150. edsl/notebooks/NotebookToLaTeX.py +142 -0
  151. edsl/notebooks/__init__.py +1 -1
  152. edsl/prompts/Prompt.py +352 -357
  153. edsl/prompts/__init__.py +2 -2
  154. edsl/questions/ExceptionExplainer.py +77 -0
  155. edsl/questions/HTMLQuestion.py +103 -0
  156. edsl/questions/QuestionBase.py +518 -660
  157. edsl/questions/QuestionBasePromptsMixin.py +221 -217
  158. edsl/questions/QuestionBudget.py +227 -227
  159. edsl/questions/QuestionCheckBox.py +359 -359
  160. edsl/questions/QuestionExtract.py +180 -183
  161. edsl/questions/QuestionFreeText.py +113 -114
  162. edsl/questions/QuestionFunctional.py +166 -166
  163. edsl/questions/QuestionList.py +223 -231
  164. edsl/questions/QuestionMatrix.py +265 -0
  165. edsl/questions/QuestionMultipleChoice.py +330 -286
  166. edsl/questions/QuestionNumerical.py +151 -153
  167. edsl/questions/QuestionRank.py +314 -324
  168. edsl/questions/Quick.py +41 -41
  169. edsl/questions/SimpleAskMixin.py +74 -73
  170. edsl/questions/__init__.py +27 -26
  171. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +334 -289
  172. edsl/questions/compose_questions.py +98 -98
  173. edsl/questions/data_structures.py +20 -0
  174. edsl/questions/decorators.py +21 -21
  175. edsl/questions/derived/QuestionLikertFive.py +76 -76
  176. edsl/questions/derived/QuestionLinearScale.py +90 -87
  177. edsl/questions/derived/QuestionTopK.py +93 -93
  178. edsl/questions/derived/QuestionYesNo.py +82 -82
  179. edsl/questions/descriptors.py +427 -413
  180. edsl/questions/loop_processor.py +149 -0
  181. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  182. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  183. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  184. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  185. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  186. edsl/questions/prompt_templates/question_list.jinja +17 -17
  187. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  188. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  189. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +168 -161
  190. edsl/questions/question_registry.py +177 -147
  191. edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +71 -71
  192. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +188 -174
  193. edsl/questions/response_validator_factory.py +34 -0
  194. edsl/questions/settings.py +12 -12
  195. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  196. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  197. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  198. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  199. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  200. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  201. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  202. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  203. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  204. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  205. edsl/questions/templates/list/question_presentation.jinja +5 -5
  206. edsl/questions/templates/matrix/__init__.py +1 -0
  207. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  208. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  209. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  210. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  211. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  212. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  213. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  214. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  215. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  216. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  217. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  218. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  219. edsl/results/CSSParameterizer.py +108 -0
  220. edsl/results/Dataset.py +587 -293
  221. edsl/results/DatasetExportMixin.py +594 -717
  222. edsl/results/DatasetTree.py +295 -145
  223. edsl/results/MarkdownToDocx.py +122 -0
  224. edsl/results/MarkdownToPDF.py +111 -0
  225. edsl/results/Result.py +557 -456
  226. edsl/results/Results.py +1183 -1071
  227. edsl/results/ResultsExportMixin.py +45 -43
  228. edsl/results/ResultsGGMixin.py +121 -121
  229. edsl/results/TableDisplay.py +125 -0
  230. edsl/results/TextEditor.py +50 -0
  231. edsl/results/__init__.py +2 -2
  232. edsl/results/file_exports.py +252 -0
  233. edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +33 -33
  234. edsl/results/{Selector.py → results_selector.py} +145 -135
  235. edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +98 -98
  236. edsl/results/smart_objects.py +96 -0
  237. edsl/results/table_data_class.py +12 -0
  238. edsl/results/table_display.css +78 -0
  239. edsl/results/table_renderers.py +118 -0
  240. edsl/results/tree_explore.py +115 -115
  241. edsl/scenarios/ConstructDownloadLink.py +109 -0
  242. edsl/scenarios/DocumentChunker.py +102 -0
  243. edsl/scenarios/DocxScenario.py +16 -0
  244. edsl/scenarios/FileStore.py +543 -458
  245. edsl/scenarios/PdfExtractor.py +40 -0
  246. edsl/scenarios/Scenario.py +498 -544
  247. edsl/scenarios/ScenarioHtmlMixin.py +65 -64
  248. edsl/scenarios/ScenarioList.py +1458 -1112
  249. edsl/scenarios/ScenarioListExportMixin.py +45 -52
  250. edsl/scenarios/ScenarioListPdfMixin.py +239 -261
  251. edsl/scenarios/__init__.py +3 -4
  252. edsl/scenarios/directory_scanner.py +96 -0
  253. edsl/scenarios/file_methods.py +85 -0
  254. edsl/scenarios/handlers/__init__.py +13 -0
  255. edsl/scenarios/handlers/csv.py +49 -0
  256. edsl/scenarios/handlers/docx.py +76 -0
  257. edsl/scenarios/handlers/html.py +37 -0
  258. edsl/scenarios/handlers/json.py +111 -0
  259. edsl/scenarios/handlers/latex.py +5 -0
  260. edsl/scenarios/handlers/md.py +51 -0
  261. edsl/scenarios/handlers/pdf.py +68 -0
  262. edsl/scenarios/handlers/png.py +39 -0
  263. edsl/scenarios/handlers/pptx.py +105 -0
  264. edsl/scenarios/handlers/py.py +294 -0
  265. edsl/scenarios/handlers/sql.py +313 -0
  266. edsl/scenarios/handlers/sqlite.py +149 -0
  267. edsl/scenarios/handlers/txt.py +33 -0
  268. edsl/scenarios/scenario_join.py +131 -0
  269. edsl/scenarios/scenario_selector.py +156 -0
  270. edsl/shared.py +1 -1
  271. edsl/study/ObjectEntry.py +173 -173
  272. edsl/study/ProofOfWork.py +113 -113
  273. edsl/study/SnapShot.py +80 -80
  274. edsl/study/Study.py +521 -528
  275. edsl/study/__init__.py +4 -4
  276. edsl/surveys/ConstructDAG.py +92 -0
  277. edsl/surveys/DAG.py +148 -148
  278. edsl/surveys/EditSurvey.py +221 -0
  279. edsl/surveys/InstructionHandler.py +100 -0
  280. edsl/surveys/Memory.py +31 -31
  281. edsl/surveys/MemoryManagement.py +72 -0
  282. edsl/surveys/MemoryPlan.py +244 -244
  283. edsl/surveys/Rule.py +327 -326
  284. edsl/surveys/RuleCollection.py +385 -387
  285. edsl/surveys/RuleManager.py +172 -0
  286. edsl/surveys/Simulator.py +75 -0
  287. edsl/surveys/Survey.py +1280 -1787
  288. edsl/surveys/SurveyCSS.py +273 -261
  289. edsl/surveys/SurveyExportMixin.py +259 -259
  290. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +181 -121
  291. edsl/surveys/SurveyQualtricsImport.py +284 -284
  292. edsl/surveys/SurveyToApp.py +141 -0
  293. edsl/surveys/__init__.py +5 -3
  294. edsl/surveys/base.py +53 -53
  295. edsl/surveys/descriptors.py +60 -56
  296. edsl/surveys/instructions/ChangeInstruction.py +48 -49
  297. edsl/surveys/instructions/Instruction.py +56 -53
  298. edsl/surveys/instructions/InstructionCollection.py +82 -77
  299. edsl/templates/error_reporting/base.html +23 -23
  300. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  301. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  302. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  303. edsl/templates/error_reporting/interview_details.html +115 -115
  304. edsl/templates/error_reporting/interviews.html +19 -10
  305. edsl/templates/error_reporting/overview.html +4 -4
  306. edsl/templates/error_reporting/performance_plot.html +1 -1
  307. edsl/templates/error_reporting/report.css +73 -73
  308. edsl/templates/error_reporting/report.html +117 -117
  309. edsl/templates/error_reporting/report.js +25 -25
  310. edsl/tools/__init__.py +1 -1
  311. edsl/tools/clusters.py +192 -192
  312. edsl/tools/embeddings.py +27 -27
  313. edsl/tools/embeddings_plotting.py +118 -118
  314. edsl/tools/plotting.py +112 -112
  315. edsl/tools/summarize.py +18 -18
  316. edsl/utilities/PrettyList.py +56 -0
  317. edsl/utilities/SystemInfo.py +28 -28
  318. edsl/utilities/__init__.py +22 -22
  319. edsl/utilities/ast_utilities.py +25 -25
  320. edsl/utilities/data/Registry.py +6 -6
  321. edsl/utilities/data/__init__.py +1 -1
  322. edsl/utilities/data/scooter_results.json +1 -1
  323. edsl/utilities/decorators.py +77 -77
  324. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  325. edsl/utilities/interface.py +627 -627
  326. edsl/utilities/is_notebook.py +18 -0
  327. edsl/utilities/is_valid_variable_name.py +11 -0
  328. edsl/utilities/naming_utilities.py +263 -263
  329. edsl/utilities/remove_edsl_version.py +24 -0
  330. edsl/utilities/repair_functions.py +28 -28
  331. edsl/utilities/restricted_python.py +70 -70
  332. edsl/utilities/utilities.py +436 -409
  333. {edsl-0.1.38.dev3.dist-info → edsl-0.1.39.dist-info}/LICENSE +21 -21
  334. {edsl-0.1.38.dev3.dist-info → edsl-0.1.39.dist-info}/METADATA +13 -10
  335. edsl-0.1.39.dist-info/RECORD +358 -0
  336. {edsl-0.1.38.dev3.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
  337. edsl/language_models/KeyLookup.py +0 -30
  338. edsl/language_models/registry.py +0 -137
  339. edsl/language_models/unused/ReplicateBase.py +0 -83
  340. edsl/results/ResultsDBMixin.py +0 -238
  341. edsl-0.1.38.dev3.dist-info/RECORD +0 -269
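The largest refactor in this release is the reorganization of edsl/language_models: KeyLookup.py (entry 337) and registry.py (entry 338) are removed, replaced by the new model.py (entry 146) and the key_management package (entries 141-145). Code importing from the old paths must switch to the new ones; a minimal before/after sketch, based only on the import statements visible in the diff below (other call sites may differ):

    # 0.1.38.dev3
    from edsl.language_models.registry import get_model_class
    from edsl.language_models.KeyLookup import KeyLookup

    # 0.1.39
    from edsl.language_models.model import get_model_class
    from edsl.language_models.key_management.KeyLookup import KeyLookup

The single hunk below is the full diff of edsl/language_models/LanguageModel.py (entry 132, +626 -708).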
@@ -1,708 +1,626 @@
- """This module contains the LanguageModel class, which is an abstract base class for all language models.
-
- Terminology:
-
- raw_response: The JSON response from the model. This has all the model meta-data about the call.
-
- edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
- such as the cache key, token usage, etc.
-
- generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
- edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
-
- """
-
- from __future__ import annotations
- import warnings
- from functools import wraps
- import asyncio
- import json
- import os
- from typing import (
-     Coroutine,
-     Any,
-     Callable,
-     Type,
-     Union,
-     List,
-     get_type_hints,
-     TypedDict,
-     Optional,
-     TYPE_CHECKING,
- )
- from abc import ABC, abstractmethod
-
- from json_repair import repair_json
-
- from edsl.data_transfer_models import (
-     ModelResponse,
-     ModelInputs,
-     EDSLOutput,
-     AgentResponseDict,
- )
-
-
- from edsl.config import CONFIG
- from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
- from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
- from edsl.language_models.repair import repair
- from edsl.enums import InferenceServiceType
- from edsl.Base import RichPrintingMixin, PersistenceMixin
- from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
- from edsl.exceptions.language_models import LanguageModelBadResponseError
-
- from edsl.language_models.KeyLookup import KeyLookup
-
- TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
-
-
- # you might be tempated to move this to be a static method of LanguageModel, but this doesn't work
- # for reasons I don't understand. So leave it here.
- def extract_item_from_raw_response(data, key_sequence):
-     if isinstance(data, str):
-         try:
-             data = json.loads(data)
-         except json.JSONDecodeError as e:
-             return data
-     current_data = data
-     for i, key in enumerate(key_sequence):
-         try:
-             if isinstance(current_data, (list, tuple)):
-                 if not isinstance(key, int):
-                     raise TypeError(
-                         f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
-                     )
-                 if key < 0 or key >= len(current_data):
-                     raise IndexError(
-                         f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
-                     )
-             elif isinstance(current_data, dict):
-                 if key not in current_data:
-                     raise KeyError(
-                         f"Key '{key}' not found in dictionary at position {i}"
-                     )
-             else:
-                 raise TypeError(
-                     f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
-                 )
-
-             current_data = current_data[key]
-         except Exception as e:
-             path = " -> ".join(map(str, key_sequence[: i + 1]))
-             if "error" in data:
-                 msg = data["error"]
-             else:
-                 msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
-             raise LanguageModelBadResponseError(message=msg, response_json=data)
-     if isinstance(current_data, str):
-         return current_data.strip()
-     else:
-         return current_data
-
-
- def handle_key_error(func):
-     """Handle KeyError exceptions."""
-
-     @wraps(func)
-     def wrapper(*args, **kwargs):
-         try:
-             return func(*args, **kwargs)
-             assert True == False
-         except KeyError as e:
-             return f"""KeyError occurred: {e}. This is most likely because the model you are using
-             returned a JSON object we were not expecting."""
-
-     return wrapper
-
-
- class LanguageModel(
-     RichPrintingMixin, PersistenceMixin, ABC, metaclass=RegisterLanguageModelsMeta
- ):
-     """ABC for LLM subclasses.
-
-     TODO:
-
-     1) Need better, more descriptive names for functions
-
-     get_model_response_no_cache (currently called async_execute_model_call)
-
-     get_model_response (currently called async_get_raw_response; uses cache & adds tracking info)
-     Calls:
-     - async_execute_model_call
-     - _updated_model_response_with_tracking
-
-     get_answer (currently called async_get_response)
-     This parses out the answer block and does some error-handling.
-     Calls:
-     - async_get_raw_response
-     - parse_response
-
-
-     """
-
-     _model_ = None
-     key_sequence = (
-         None # This should be something like ["choices", 0, "message", "content"]
-     )
-     __rate_limits = None
-     _safety_factor = 0.8
-
-     def __init__(
-         self,
-         tpm: float = None,
-         rpm: float = None,
-         omit_system_prompt_if_empty_string: bool = True,
-         key_lookup: Optional[KeyLookup] = None,
-         **kwargs,
-     ):
-         """Initialize the LanguageModel."""
-         self.model = getattr(self, "_model_", None)
-         default_parameters = getattr(self, "_parameters_", None)
-         parameters = self._overide_default_parameters(kwargs, default_parameters)
-         self.parameters = parameters
-         self.remote = False
-         self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
-
-         # self._rpm / _tpm comes from the class
-         if rpm is not None:
-             self._rpm = rpm
-
-         if tpm is not None:
-             self._tpm = tpm
-
-         for key, value in parameters.items():
-             setattr(self, key, value)
-
-         for key, value in kwargs.items():
-             if key not in parameters:
-                 setattr(self, key, value)
-
-         if "use_cache" in kwargs:
-             warnings.warn(
-                 "The use_cache parameter is deprecated. Use the Cache class instead."
-             )
-
-         if skip_api_key_check := kwargs.get("skip_api_key_check", False):
-             # Skip the API key check. Sometimes this is useful for testing.
-             self._api_token = None
-
-         if key_lookup is not None:
-             self.key_lookup = key_lookup
-         else:
-             self.key_lookup = KeyLookup.from_os_environ()
-
-     def ask_question(self, question):
-         user_prompt = question.get_instructions().render(question.data).text
-         system_prompt = "You are a helpful agent pretending to be a human."
-         return self.execute_model_call(user_prompt, system_prompt)
-
-     def set_key_lookup(self, key_lookup: KeyLookup):
-         del self._api_token
-         self.key_lookup = key_lookup
-
-     @property
-     def api_token(self) -> str:
-         if not hasattr(self, "_api_token"):
-             self._api_token = self.key_lookup.get_api_token(
-                 self._inference_service_, self.remote
-             )
-         return self._api_token
-
-     def __getitem__(self, key):
-         return getattr(self, key)
-
-     def _repr_html_(self):
-         from edsl.utilities.utilities import data_to_html
-
-         return data_to_html(self.to_dict())
-
-     def hello(self, verbose=False):
-         """Runs a simple test to check if the model is working."""
-         token = self.api_token
-         masked = token[: min(8, len(token))] + "..."
-         if verbose:
-             print(f"Current key is {masked}")
-         return self.execute_model_call(
-             user_prompt="Hello, model!", system_prompt="You are a helpful agent."
-         )
-
-     def has_valid_api_key(self) -> bool:
-         """Check if the model has a valid API key.
-
-         >>> LanguageModel.example().has_valid_api_key() : # doctest: +SKIP
-         True
-
-         This method is used to check if the model has a valid API key.
-         """
-         from edsl.enums import service_to_api_keyname
-         import os
-
-         if self._model_ == "test":
-             return True
-
-         key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
-         key_value = os.getenv(key_name)
-         return key_value is not None
-
-     def __hash__(self) -> str:
-         """Allow the model to be used as a key in a dictionary."""
-         from edsl.utilities.utilities import dict_hash
-
-         return dict_hash(self.to_dict())
-
-     def __eq__(self, other):
-         """Check is two models are the same.
-
-         >>> m1 = LanguageModel.example()
-         >>> m2 = LanguageModel.example()
-         >>> m1 == m2
-         True
-
-         """
-         return self.model == other.model and self.parameters == other.parameters
-
-     def set_rate_limits(self, rpm=None, tpm=None) -> None:
-         """Set the rate limits for the model.
-
-         >>> m = LanguageModel.example()
-         >>> m.set_rate_limits(rpm=100, tpm=1000)
-         >>> m.RPM
-         100
-         """
-         if rpm is not None:
-             self._rpm = rpm
-         if tpm is not None:
-             self._tpm = tpm
-         return None
-
-     @property
-     def RPM(self):
-         """Model's requests-per-minute limit."""
-         # self._set_rate_limits()
-         # return self._safety_factor * self.__rate_limits["rpm"]
-         return self._rpm
-
-     @property
-     def TPM(self):
-         """Model's tokens-per-minute limit."""
-         # self._set_rate_limits()
-         # return self._safety_factor * self.__rate_limits["tpm"]
-         return self._tpm
-
-     @property
-     def rpm(self):
-         return self._rpm
-
-     @rpm.setter
-     def rpm(self, value):
-         self._rpm = value
-
-     @property
-     def tpm(self):
-         return self._tpm
-
-     @tpm.setter
-     def tpm(self, value):
-         self._tpm = value
-
-     @staticmethod
-     def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
-         """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
-
-         >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9})
-         {'temperature': 0.5}
-         >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
-         {'temperature': 0.5, 'max_tokens': 1000}
-         """
-         # parameters = dict({})
-
-         # this is the case when data is loaded from a dict after serialization
-         if "parameters" in passed_parameter_dict:
-             passed_parameter_dict = passed_parameter_dict["parameters"]
-         return {
-             parameter_name: passed_parameter_dict.get(parameter_name, default_value)
-             for parameter_name, default_value in default_parameter_dict.items()
-         }
-
-     def __call__(self, user_prompt: str, system_prompt: str):
-         return self.execute_model_call(user_prompt, system_prompt)
-
-     @abstractmethod
-     async def async_execute_model_call(user_prompt: str, system_prompt: str):
-         """Execute the model call and returns a coroutine.
-
-         >>> m = LanguageModel.example(test_model = True)
-         >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
-         >>> asyncio.run(test())
-         {'message': [{'text': 'Hello world'}], ...}
-
-         >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
-         {'message': [{'text': 'Hello world'}], ...}
-         """
-         pass
-
-     async def remote_async_execute_model_call(
-         self, user_prompt: str, system_prompt: str
-     ):
-         """Execute the model call and returns the result as a coroutine, using Coop."""
-         from edsl.coop import Coop
-
-         client = Coop()
-         response_data = await client.remote_async_execute_model_call(
-             self.to_dict(), user_prompt, system_prompt
-         )
-         return response_data
-
-     @jupyter_nb_handler
-     def execute_model_call(self, *args, **kwargs) -> Coroutine:
-         """Execute the model call and returns the result as a coroutine.
-
-         >>> m = LanguageModel.example(test_model = True)
-         >>> m.execute_model_call(user_prompt = "Hello, model!", system_prompt = "You are a helpful agent.")
-
-         """
-
-         async def main():
-             results = await asyncio.gather(
-                 self.async_execute_model_call(*args, **kwargs)
-             )
-             return results[0] # Since there's only one task, return its result
-
-         return main()
-
-     @classmethod
-     def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
-         """Return the generated token string from the raw response."""
-         return extract_item_from_raw_response(raw_response, cls.key_sequence)
-
-     @classmethod
-     def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
-         """Return the usage dictionary from the raw response."""
-         if not hasattr(cls, "usage_sequence"):
-             raise NotImplementedError(
-                 "This inference service does not have a usage_sequence."
-             )
-         return extract_item_from_raw_response(raw_response, cls.usage_sequence)
-
-     @staticmethod
-     def convert_answer(response_part):
-         import json
-
-         response_part = response_part.strip()
-
-         if response_part == "None":
-             return None
-
-         repaired = repair_json(response_part)
-         if repaired == '""':
-             # it was a literal string
-             return response_part
-
-         try:
-             return json.loads(repaired)
-         except json.JSONDecodeError as j:
-             # last resort
-             return response_part
-
-     @classmethod
-     def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
-         """Parses the API response and returns the response text."""
-         generated_token_string = cls.get_generated_token_string(raw_response)
-         last_newline = generated_token_string.rfind("\n")
-
-         if last_newline == -1:
-             # There is no comment
-             edsl_dict = {
-                 "answer": cls.convert_answer(generated_token_string),
-                 "generated_tokens": generated_token_string,
-                 "comment": None,
-             }
-         else:
-             edsl_dict = {
-                 "answer": cls.convert_answer(generated_token_string[:last_newline]),
-                 "comment": generated_token_string[last_newline + 1 :].strip(),
-                 "generated_tokens": generated_token_string,
-             }
-         return EDSLOutput(**edsl_dict)
-
-     async def _async_get_intended_model_call_outcome(
-         self,
-         user_prompt: str,
-         system_prompt: str,
-         cache: "Cache",
-         iteration: int = 0,
-         files_list=None,
-     ) -> ModelResponse:
-         """Handle caching of responses.
-
-         :param user_prompt: The user's prompt.
-         :param system_prompt: The system's prompt.
-         :param iteration: The iteration number.
-         :param cache: The cache to use.
-
-         If the cache isn't being used, it just returns a 'fresh' call to the LLM.
-         But if cache is being used, it first checks the database to see if the response is already there.
-         If it is, it returns the cached response, but again appends some tracking information.
-         If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
-
-         If self.use_cache is True, then attempts to retrieve the response from the database;
-         if not in the DB, calls the LLM and writes the response to the DB.
-
-         >>> from edsl import Cache
-         >>> m = LanguageModel.example(test_model = True)
-         >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
-         ModelResponse(...)"""
-
-         if files_list:
-             files_hash = "+".join([str(hash(file)) for file in files_list])
-             # print(f"Files hash: {files_hash}")
-             user_prompt_with_hashes = user_prompt + f" {files_hash}"
-         else:
-             user_prompt_with_hashes = user_prompt
-
-         cache_call_params = {
-             "model": str(self.model),
-             "parameters": self.parameters,
-             "system_prompt": system_prompt,
-             "user_prompt": user_prompt_with_hashes,
-             "iteration": iteration,
-         }
-         cached_response, cache_key = cache.fetch(**cache_call_params)
-
-         if cache_used := cached_response is not None:
-             response = json.loads(cached_response)
-         else:
-             f = (
-                 self.remote_async_execute_model_call
-                 if hasattr(self, "remote") and self.remote
-                 else self.async_execute_model_call
-             )
-             params = {
-                 "user_prompt": user_prompt,
-                 "system_prompt": system_prompt,
-                 "files_list": files_list,
-                 # **({"encoded_image": encoded_image} if encoded_image else {}),
-             }
-             # response = await f(**params)
-             response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
-             new_cache_key = cache.store(
-                 **cache_call_params, response=response
-             ) # store the response in the cache
-             assert new_cache_key == cache_key # should be the same
-
-         cost = self.cost(response)
-
-         return ModelResponse(
-             response=response,
-             cache_used=cache_used,
-             cache_key=cache_key,
-             cached_response=cached_response,
-             cost=cost,
-         )
-
-     _get_intended_model_call_outcome = sync_wrapper(
-         _async_get_intended_model_call_outcome
-     )
-
-     # get_raw_response = sync_wrapper(async_get_raw_response)
-
-     def simple_ask(
-         self,
-         question: "QuestionBase",
-         system_prompt="You are a helpful agent pretending to be a human.",
-         top_logprobs=2,
-     ):
-         """Ask a question and return the response."""
-         self.logprobs = True
-         self.top_logprobs = top_logprobs
-         return self.execute_model_call(
-             user_prompt=question.human_readable(), system_prompt=system_prompt
-         )
-
-     async def async_get_response(
-         self,
-         user_prompt: str,
-         system_prompt: str,
-         cache: "Cache",
-         iteration: int = 1,
-         files_list: Optional[List["File"]] = None,
-     ) -> dict:
-         """Get response, parse, and return as string.
-
-         :param user_prompt: The user's prompt.
-         :param system_prompt: The system's prompt.
-         :param iteration: The iteration number.
-         :param cache: The cache to use.
-         :param encoded_image: The encoded image to use.
-
-         """
-         params = {
-             "user_prompt": user_prompt,
-             "system_prompt": system_prompt,
-             "iteration": iteration,
-             "cache": cache,
-             "files_list": files_list,
-         }
-         model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
-         model_outputs = await self._async_get_intended_model_call_outcome(**params)
-         edsl_dict = self.parse_response(model_outputs.response)
-         agent_response_dict = AgentResponseDict(
-             model_inputs=model_inputs,
-             model_outputs=model_outputs,
-             edsl_dict=edsl_dict,
-         )
-         return agent_response_dict
-
-         # return await self._async_prepare_response(model_call_outcome, cache=cache)
-
-     get_response = sync_wrapper(async_get_response)
-
-     def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
-         """Return the dollar cost of a raw response."""
-
-         usage = self.get_usage_dict(raw_response)
-         from edsl.coop import Coop
-
-         c = Coop()
-         price_lookup = c.fetch_prices()
-         key = (self._inference_service_, self.model)
-         if key not in price_lookup:
-             return f"Could not find price for model {self.model} in the price lookup."
-
-         relevant_prices = price_lookup[key]
-         try:
-             input_tokens = int(usage[self.input_token_name])
-             output_tokens = int(usage[self.output_token_name])
-         except Exception as e:
-             return f"Could not fetch tokens from model response: {e}"
-
-         try:
-             inverse_output_price = relevant_prices["output"]["one_usd_buys"]
-             inverse_input_price = relevant_prices["input"]["one_usd_buys"]
-         except Exception as e:
-             if "output" not in relevant_prices:
-                 return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
-             if "input" not in relevant_prices:
-                 return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
-             return f"Could not fetch prices from {relevant_prices} - {e}"
-
-         if inverse_input_price == "infinity":
-             input_cost = 0
-         else:
-             try:
-                 input_cost = input_tokens / float(inverse_input_price)
-             except Exception as e:
-                 return f"Could not compute input price - {e}."
-
-         if inverse_output_price == "infinity":
-             output_cost = 0
-         else:
-             try:
-                 output_cost = output_tokens / float(inverse_output_price)
-             except Exception as e:
-                 return f"Could not compute output price - {e}"
-
-         return input_cost + output_cost
-
-     #######################
-     # SERIALIZATION METHODS
-     #######################
-     def to_dict(self, add_edsl_version=True) -> dict[str, Any]:
-         """Convert instance to a dictionary
-
-         >>> m = LanguageModel.example()
-         >>> m.to_dict()
-         {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
-         """
-         d = {"model": self.model, "parameters": self.parameters}
-         if add_edsl_version:
-             from edsl import __version__
-
-             d["edsl_version"] = __version__
-             d["edsl_class_name"] = self.__class__.__name__
-         return d
-
-     @classmethod
-     @remove_edsl_version
-     def from_dict(cls, data: dict) -> Type[LanguageModel]:
-         """Convert dictionary to a LanguageModel child instance."""
-         from edsl.language_models.registry import get_model_class
-
-         model_class = get_model_class(data["model"])
-         # data["use_cache"] = True
-         return model_class(**data)
-
-     #######################
-     # DUNDER METHODS
-     #######################
-     def print(self):
-         from rich import print_json
-         import json
-
-         print_json(json.dumps(self.to_dict()))
-
-     def __repr__(self) -> str:
-         """Return a string representation of the object."""
-         param_string = ", ".join(
-             f"{key} = {value}" for key, value in self.parameters.items()
-         )
-         return (
-             f"Model(model_name = '{self.model}'"
-             + (f", {param_string}" if param_string else "")
-             + ")"
-         )
-
-     def __add__(self, other_model: Type[LanguageModel]) -> Type[LanguageModel]:
-         """Combine two models into a single model (other_model takes precedence over self)."""
-         print(
-             f"""Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
-             by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
-         )
-         return other_model or self
-
-     def rich_print(self):
-         """Display an object as a table."""
-         from rich.table import Table
-
-         table = Table(title="Language Model")
-         table.add_column("Attribute", style="bold")
-         table.add_column("Value")
-
-         to_display = self.__dict__.copy()
-         for attr_name, attr_value in to_display.items():
-             table.add_row(attr_name, repr(attr_value))
-
-         return table
-
-     @classmethod
-     def example(
-         cls,
-         test_model: bool = False,
-         canned_response: str = "Hello world",
-         throw_exception: bool = False,
-     ):
-         """Return a default instance of the class.
-
-         >>> from edsl.language_models import LanguageModel
-         >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
-         >>> isinstance(m, LanguageModel)
-         True
-         >>> from edsl import QuestionFreeText
-         >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
-         >>> q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True).select('example').first()
-         'WOWZA!'
-         """
-         from edsl import Model
-
-         if test_model:
-             m = Model("test", canned_response=canned_response)
-             return m
-         else:
-             return Model(skip_api_key_check=True)
-
-
- if __name__ == "__main__":
-     """Run the module's test suite."""
-     import doctest
-
-     doctest.testmod(optionflags=doctest.ELLIPSIS)
+ """This module contains the LanguageModel class, which is an abstract base class for all language models.
+
+ Terminology:
+
+ raw_response: The JSON response from the model. This has all the model meta-data about the call.
+
+ edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
+ such as the cache key, token usage, etc.
+
+ generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
+ edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
+
+ """
+
+ from __future__ import annotations
+ import warnings
+ from functools import wraps
+ import asyncio
+ import json
+ import os
+ from typing import (
+     Coroutine,
+     Any,
+     Type,
+     Union,
+     List,
+     get_type_hints,
+     TypedDict,
+     Optional,
+     TYPE_CHECKING,
+ )
+ from abc import ABC, abstractmethod
+
+ from edsl.data_transfer_models import (
+     ModelResponse,
+     ModelInputs,
+     EDSLOutput,
+     AgentResponseDict,
+ )
+
+ if TYPE_CHECKING:
+     from edsl.data.Cache import Cache
+     from edsl.scenarios.FileStore import FileStore
+     from edsl.questions.QuestionBase import QuestionBase
+     from edsl.language_models.key_management.KeyLookup import KeyLookup
+
+ from edsl.enums import InferenceServiceType
+
+ from edsl.utilities.decorators import (
+     sync_wrapper,
+     jupyter_nb_handler,
+ )
+ from edsl.utilities.remove_edsl_version import remove_edsl_version
+
+ from edsl.Base import PersistenceMixin, RepresentationMixin
+ from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
+
+ from edsl.language_models.key_management.KeyLookupCollection import (
+     KeyLookupCollection,
+ )
+
+ from edsl.language_models.RawResponseHandler import RawResponseHandler
+
+
+ def handle_key_error(func):
+     """Handle KeyError exceptions."""
+
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         try:
+             return func(*args, **kwargs)
+             assert True == False
+         except KeyError as e:
+             return f"""KeyError occurred: {e}. This is most likely because the model you are using
+             returned a JSON object we were not expecting."""
+
+     return wrapper
+
+
+ class classproperty:
+     def __init__(self, method):
+         self.method = method
+
+     def __get__(self, instance, cls):
+         return self.method(cls)
+
+
+ from edsl.Base import HashingMixin
+
+
+ class LanguageModel(
+     PersistenceMixin,
+     RepresentationMixin,
+     HashingMixin,
+     ABC,
+     metaclass=RegisterLanguageModelsMeta,
+ ):
+     """ABC for Language Models."""
+
+     _model_ = None
+     key_sequence = (
+         None # This should be something like ["choices", 0, "message", "content"]
+     )
+
+     DEFAULT_RPM = 100
+     DEFAULT_TPM = 1000
+
+     @classproperty
+     def response_handler(cls):
+         key_sequence = cls.key_sequence
+         usage_sequence = cls.usage_sequence if hasattr(cls, "usage_sequence") else None
+         return RawResponseHandler(key_sequence, usage_sequence)
+
+     def __init__(
+         self,
+         tpm: Optional[float] = None,
+         rpm: Optional[float] = None,
+         omit_system_prompt_if_empty_string: bool = True,
+         key_lookup: Optional["KeyLookup"] = None,
+         **kwargs,
+     ):
+         """Initialize the LanguageModel."""
+         self.model = getattr(self, "_model_", None)
+         default_parameters = getattr(self, "_parameters_", None)
+         parameters = self._overide_default_parameters(kwargs, default_parameters)
+         self.parameters = parameters
+         self.remote = False
+         self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
+
+         self.key_lookup = self._set_key_lookup(key_lookup)
+         self.model_info = self.key_lookup.get(self._inference_service_)
+
+         if rpm is not None:
+             self._rpm = rpm
+
+         if tpm is not None:
+             self._tpm = tpm
+
+         for key, value in parameters.items():
+             setattr(self, key, value)
+
+         for key, value in kwargs.items():
+             if key not in parameters:
+                 setattr(self, key, value)
+
+         if kwargs.get("skip_api_key_check", False):
+             # Skip the API key check. Sometimes this is useful for testing.
+             self._api_token = None
+
+     def _set_key_lookup(self, key_lookup: "KeyLookup") -> "KeyLookup":
+         """Set the key lookup."""
+         if key_lookup is not None:
+             return key_lookup
+         else:
+             klc = KeyLookupCollection()
+             klc.add_key_lookup(fetch_order=("config", "env"))
+             return klc.get(("config", "env"))
+
+     def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
+         """Set the key lookup, later"""
+         if hasattr(self, "_api_token"):
+             del self._api_token
+         self.key_lookup = key_lookup
+
+     def ask_question(self, question: "QuestionBase") -> str:
+         """Ask a question and return the response.
+
+         :param question: The question to ask.
+         """
+         user_prompt = question.get_instructions().render(question.data).text
+         system_prompt = "You are a helpful agent pretending to be a human."
+         return self.execute_model_call(user_prompt, system_prompt)
+
+     @property
+     def rpm(self):
+         if not hasattr(self, "_rpm"):
+             if self.model_info is None:
+                 self._rpm = self.DEFAULT_RPM
+             else:
+                 self._rpm = self.model_info.rpm
+         return self._rpm
+
+     @property
+     def tpm(self):
+         if not hasattr(self, "_tpm"):
+             if self.model_info is None:
+                 self._tpm = self.DEFAULT_TPM
+             else:
+                 self._tpm = self.model_info.tpm
+         return self._tpm
+
+     # in case we want to override the default values
+     @tpm.setter
+     def tpm(self, value):
+         self._tpm = value
+
+     @rpm.setter
+     def rpm(self, value):
+         self._rpm = value
+
+     @property
+     def api_token(self) -> str:
+         if not hasattr(self, "_api_token"):
+             info = self.key_lookup.get(self._inference_service_, None)
+             if info is None:
+                 raise ValueError(
+                     f"No key found for service '{self._inference_service_}'"
+                 )
+             self._api_token = info.api_token
+         return self._api_token
+
+     def __getitem__(self, key):
+         return getattr(self, key)
+
+     def hello(self, verbose=False):
+         """Runs a simple test to check if the model is working."""
+         token = self.api_token
+         masked = token[: min(8, len(token))] + "..."
+         if verbose:
+             print(f"Current key is {masked}")
+         return self.execute_model_call(
+             user_prompt="Hello, model!", system_prompt="You are a helpful agent."
+         )
+
+     def has_valid_api_key(self) -> bool:
+         """Check if the model has a valid API key.
+
+         >>> LanguageModel.example().has_valid_api_key() : # doctest: +SKIP
+         True
+
+         This method is used to check if the model has a valid API key.
+         """
+         from edsl.enums import service_to_api_keyname
+
+         if self._model_ == "test":
+             return True
+
+         key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
+         key_value = os.getenv(key_name)
+         return key_value is not None
+
+     def __hash__(self) -> str:
+         """Allow the model to be used as a key in a dictionary.
+
+         >>> m = LanguageModel.example()
+         >>> hash(m)
+         1811901442659237949
+         """
+         from edsl.utilities.utilities import dict_hash
+
+         return dict_hash(self.to_dict(add_edsl_version=False))
+
+     def __eq__(self, other) -> bool:
+         """Check is two models are the same.
+
+         >>> m1 = LanguageModel.example()
+         >>> m2 = LanguageModel.example()
+         >>> m1 == m2
+         True
+
+         """
+         return self.model == other.model and self.parameters == other.parameters
+
+     @staticmethod
+     def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
+         """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
+
+         >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9})
+         {'temperature': 0.5}
+         >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
+         {'temperature': 0.5, 'max_tokens': 1000}
+         """
+         # this is the case when data is loaded from a dict after serialization
+         if "parameters" in passed_parameter_dict:
+             passed_parameter_dict = passed_parameter_dict["parameters"]
+         return {
+             parameter_name: passed_parameter_dict.get(parameter_name, default_value)
+             for parameter_name, default_value in default_parameter_dict.items()
+         }
+
+     def __call__(self, user_prompt: str, system_prompt: str):
+         return self.execute_model_call(user_prompt, system_prompt)
+
+     @abstractmethod
+     async def async_execute_model_call(user_prompt: str, system_prompt: str):
+         """Execute the model call and returns a coroutine."""
+         pass
+
+     async def remote_async_execute_model_call(
+         self, user_prompt: str, system_prompt: str
+     ):
+         """Execute the model call and returns the result as a coroutine, using Coop."""
+         from edsl.coop import Coop
+
+         client = Coop()
+         response_data = await client.remote_async_execute_model_call(
+             self.to_dict(), user_prompt, system_prompt
+         )
+         return response_data
+
+     @jupyter_nb_handler
+     def execute_model_call(self, *args, **kwargs) -> Coroutine:
+         """Execute the model call and returns the result as a coroutine."""
+
+         async def main():
+             results = await asyncio.gather(
+                 self.async_execute_model_call(*args, **kwargs)
+             )
+             return results[0] # Since there's only one task, return its result
+
+         return main()
+
+     @classmethod
+     def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
+         """Return the generated token string from the raw response.
+
+         >>> m = LanguageModel.example(test_model = True)
+         >>> raw_response = m.execute_model_call("Hello, model!", "You are a helpful agent.")
+         >>> m.get_generated_token_string(raw_response)
+         'Hello world'
+
+         """
+         return cls.response_handler.get_generated_token_string(raw_response)
+
+     @classmethod
+     def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
+         """Return the usage dictionary from the raw response."""
+         return cls.response_handler.get_usage_dict(raw_response)
+
+     @classmethod
+     def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
+         """Parses the API response and returns the response text."""
+         return cls.response_handler.parse_response(raw_response)
+
+     async def _async_get_intended_model_call_outcome(
+         self,
+         user_prompt: str,
+         system_prompt: str,
+         cache: Cache,
+         iteration: int = 0,
+         files_list: Optional[List[FileStore]] = None,
+         invigilator=None,
+     ) -> ModelResponse:
+         """Handle caching of responses.
+
+         :param user_prompt: The user's prompt.
+         :param system_prompt: The system's prompt.
+         :param iteration: The iteration number.
+         :param cache: The cache to use.
+         :param files_list: The list of files to use.
+         :param invigilator: The invigilator to use.
+
+         If the cache isn't being used, it just returns a 'fresh' call to the LLM.
+         But if cache is being used, it first checks the database to see if the response is already there.
+         If it is, it returns the cached response, but again appends some tracking information.
+         If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
+
+         If self.use_cache is True, then attempts to retrieve the response from the database;
+         if not in the DB, calls the LLM and writes the response to the DB.
+
+         >>> from edsl import Cache
+         >>> m = LanguageModel.example(test_model = True)
+         >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
+         ModelResponse(...)"""
+
+         if files_list:
+             files_hash = "+".join([str(hash(file)) for file in files_list])
+             user_prompt_with_hashes = user_prompt + f" {files_hash}"
+         else:
+             user_prompt_with_hashes = user_prompt
+
+         cache_call_params = {
+             "model": str(self.model),
+             "parameters": self.parameters,
+             "system_prompt": system_prompt,
+             "user_prompt": user_prompt_with_hashes,
+             "iteration": iteration,
+         }
+         cached_response, cache_key = cache.fetch(**cache_call_params)
+
+         if cache_used := cached_response is not None:
+             response = json.loads(cached_response)
+         else:
+             f = (
+                 self.remote_async_execute_model_call
+                 if hasattr(self, "remote") and self.remote
+                 else self.async_execute_model_call
+             )
+             params = {
+                 "user_prompt": user_prompt,
+                 "system_prompt": system_prompt,
+                 "files_list": files_list,
+             }
+             from edsl.config import CONFIG
+
+             TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+             response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
+             new_cache_key = cache.store(
+                 **cache_call_params, response=response
+             ) # store the response in the cache
+             assert new_cache_key == cache_key # should be the same
+
+         cost = self.cost(response)
+         return ModelResponse(
+             response=response,
+             cache_used=cache_used,
+             cache_key=cache_key,
+             cached_response=cached_response,
+             cost=cost,
+         )
+
+     _get_intended_model_call_outcome = sync_wrapper(
+         _async_get_intended_model_call_outcome
+     )
+
+     def simple_ask(
+         self,
+         question: QuestionBase,
+         system_prompt="You are a helpful agent pretending to be a human.",
+         top_logprobs=2,
+     ):
+         """Ask a question and return the response."""
+         self.logprobs = True
+         self.top_logprobs = top_logprobs
+         return self.execute_model_call(
+             user_prompt=question.human_readable(), system_prompt=system_prompt
+         )
+
+     async def async_get_response(
+         self,
+         user_prompt: str,
+         system_prompt: str,
+         cache: Cache,
+         iteration: int = 1,
+         files_list: Optional[List[FileStore]] = None,
+         **kwargs,
+     ) -> dict:
+         """Get response, parse, and return as string.
+
+         :param user_prompt: The user's prompt.
+         :param system_prompt: The system's prompt.
+         :param cache: The cache to use.
+         :param iteration: The iteration number.
+         :param files_list: The list of files to use.
+
+         """
+         params = {
+             "user_prompt": user_prompt,
+             "system_prompt": system_prompt,
+             "iteration": iteration,
+             "cache": cache,
+             "files_list": files_list,
+         }
+         if "invigilator" in kwargs:
+             params.update({"invigilator": kwargs["invigilator"]})
+
+         model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
+         model_outputs: ModelResponse = (
+             await self._async_get_intended_model_call_outcome(**params)
+         )
+         edsl_dict: EDSLOutput = self.parse_response(model_outputs.response)
+
+         agent_response_dict = AgentResponseDict(
+             model_inputs=model_inputs,
+             model_outputs=model_outputs,
+             edsl_dict=edsl_dict,
+         )
+         return agent_response_dict
+
+     get_response = sync_wrapper(async_get_response)
+
+     def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
+         """Return the dollar cost of a raw response.
+
+         :param raw_response: The raw response from the model.
+         """
+
+         usage = self.get_usage_dict(raw_response)
+         from edsl.language_models.PriceManager import PriceManager
+
+         price_manger = PriceManager()
+         return price_manger.calculate_cost(
+             inference_service=self._inference_service_,
+             model=self.model,
+             usage=usage,
+             input_token_name=self.input_token_name,
+             output_token_name=self.output_token_name,
+         )
+
+     def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
+         """Convert instance to a dictionary
+
+         :param add_edsl_version: Whether to add the EDSL version to the dictionary.
+
+         >>> m = LanguageModel.example()
+         >>> m.to_dict()
+         {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
+         """
+         d = {
+             "model": self.model,
+             "parameters": self.parameters,
+         }
+         if add_edsl_version:
+             from edsl import __version__
+
+             d["edsl_version"] = __version__
+             d["edsl_class_name"] = self.__class__.__name__
+         return d
+
+     @classmethod
+     @remove_edsl_version
+     def from_dict(cls, data: dict) -> Type[LanguageModel]:
+         """Convert dictionary to a LanguageModel child instance."""
+         from edsl.language_models.model import get_model_class
+
+         model_class = get_model_class(data["model"])
+         return model_class(**data)
+
+     def __repr__(self) -> str:
+         """Return a representation of the object."""
+         param_string = ", ".join(
+             f"{key} = {value}" for key, value in self.parameters.items()
+         )
+         return (
+             f"Model(model_name = '{self.model}'"
+             + (f", {param_string}" if param_string else "")
+             + ")"
+         )
+
+     def __add__(self, other_model: Type[LanguageModel]) -> Type[LanguageModel]:
+         """Combine two models into a single model (other_model takes precedence over self)."""
+         import warnings
+
+         warnings.warn(
+             f"""Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
+             by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
+         )
+         return other_model or self
+
+     @classmethod
+     def example(
+         cls,
+         test_model: bool = False,
+         canned_response: str = "Hello world",
+         throw_exception: bool = False,
+     ) -> LanguageModel:
+         """Return a default instance of the class.
+
+         >>> from edsl.language_models import LanguageModel
+         >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
+         >>> isinstance(m, LanguageModel)
+         True
+         >>> from edsl import QuestionFreeText
+         >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
+         >>> q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True).select('example').first()
+         'WOWZA!'
+         >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!", throw_exception = True)
+         >>> r = q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True, print_exceptions = True)
+         Exception report saved to ...
+         Also see: ...
+         """
+         from edsl.language_models.model import Model
+
+         if test_model:
+             m = Model(
+                 "test", canned_response=canned_response, throw_exception=throw_exception
+             )
+             return m
+         else:
+             return Model(skip_api_key_check=True)
+
+     def from_cache(self, cache: "Cache") -> LanguageModel:
+
+         from copy import deepcopy
+         from types import MethodType
+         from edsl import Cache
+
+         new_instance = deepcopy(self)
+         print("Cache entries", len(cache))
+         new_instance.cache = Cache(
+             data={k: v for k, v in cache.items() if v.model == self.model}
+         )
+         print("Cache entries with same model", len(new_instance.cache))
+
+         new_instance.user_prompts = [
+             ce.user_prompt for ce in new_instance.cache.values()
+         ]
+         new_instance.system_prompts = [
+             ce.system_prompt for ce in new_instance.cache.values()
+         ]
+
+         async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
+             cache_call_params = {
+                 "model": str(self.model),
+                 "parameters": self.parameters,
+                 "system_prompt": system_prompt,
+                 "user_prompt": user_prompt,
+                 "iteration": 1,
+             }
+             cached_response, cache_key = cache.fetch(**cache_call_params)
+             response = json.loads(cached_response)
+             cost = 0
+             return ModelResponse(
+                 response=response,
+                 cache_used=True,
+                 cache_key=cache_key,
+                 cached_response=cached_response,
+                 cost=cost,
+             )
+
+         # Bind the new method to the copied instance
+         setattr(
+             new_instance,
+             "async_execute_model_call",
+             MethodType(async_execute_model_call, new_instance),
+         )
+
+         return new_instance
+
+
+ if __name__ == "__main__":
+     """Run the module's test suite."""
+     import doctest
+
+     doctest.testmod(optionflags=doctest.ELLIPSIS)
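
The reworked class keeps its doctest-friendly test mode. A doctest-style sketch assembled from the doctests visible in the hunk above (the rate-limit lines replace the removed set_rate_limits/RPM/TPM API; the values shown are illustrative, not authoritative):

    >>> from edsl.language_models import LanguageModel
    >>> m = LanguageModel.example(test_model=True, canned_response="Hello world")
    >>> raw = m.execute_model_call("Hello, model!", "You are a helpful agent.")
    >>> m.get_generated_token_string(raw)  # parsing now delegates to RawResponseHandler
    'Hello world'
    >>> m.rpm = 100  # rpm/tpm are now plain properties with DEFAULT_RPM/DEFAULT_TPM fallbacks
    >>> m.rpm
    100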