edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334) hide show
  1. edsl/Base.py +332 -385
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +49 -57
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +867 -1079
  7. edsl/agents/AgentList.py +413 -551
  8. edsl/agents/Invigilator.py +233 -285
  9. edsl/agents/InvigilatorBase.py +270 -254
  10. edsl/agents/PromptConstructor.py +354 -252
  11. edsl/agents/__init__.py +3 -2
  12. edsl/agents/descriptors.py +99 -99
  13. edsl/agents/prompt_helpers.py +129 -129
  14. edsl/auto/AutoStudy.py +117 -117
  15. edsl/auto/StageBase.py +230 -230
  16. edsl/auto/StageGenerateSurvey.py +178 -178
  17. edsl/auto/StageLabelQuestions.py +125 -125
  18. edsl/auto/StagePersona.py +61 -61
  19. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  20. edsl/auto/StagePersonaDimensionValues.py +74 -74
  21. edsl/auto/StagePersonaDimensions.py +69 -69
  22. edsl/auto/StageQuestions.py +73 -73
  23. edsl/auto/SurveyCreatorPipeline.py +21 -21
  24. edsl/auto/utilities.py +224 -224
  25. edsl/base/Base.py +279 -279
  26. edsl/config.py +157 -177
  27. edsl/conversation/Conversation.py +290 -290
  28. edsl/conversation/car_buying.py +58 -59
  29. edsl/conversation/chips.py +95 -95
  30. edsl/conversation/mug_negotiation.py +81 -81
  31. edsl/conversation/next_speaker_utilities.py +93 -93
  32. edsl/coop/PriceFetcher.py +54 -54
  33. edsl/coop/__init__.py +2 -2
  34. edsl/coop/coop.py +1028 -1090
  35. edsl/coop/utils.py +131 -131
  36. edsl/data/Cache.py +555 -562
  37. edsl/data/CacheEntry.py +233 -230
  38. edsl/data/CacheHandler.py +149 -170
  39. edsl/data/RemoteCacheSync.py +78 -78
  40. edsl/data/SQLiteDict.py +292 -292
  41. edsl/data/__init__.py +4 -5
  42. edsl/data/orm.py +10 -10
  43. edsl/data_transfer_models.py +73 -74
  44. edsl/enums.py +175 -195
  45. edsl/exceptions/BaseException.py +21 -21
  46. edsl/exceptions/__init__.py +54 -54
  47. edsl/exceptions/agents.py +42 -54
  48. edsl/exceptions/cache.py +5 -5
  49. edsl/exceptions/configuration.py +16 -16
  50. edsl/exceptions/coop.py +10 -10
  51. edsl/exceptions/data.py +14 -14
  52. edsl/exceptions/general.py +34 -34
  53. edsl/exceptions/jobs.py +33 -33
  54. edsl/exceptions/language_models.py +63 -63
  55. edsl/exceptions/prompts.py +15 -15
  56. edsl/exceptions/questions.py +91 -109
  57. edsl/exceptions/results.py +29 -29
  58. edsl/exceptions/scenarios.py +22 -29
  59. edsl/exceptions/surveys.py +37 -37
  60. edsl/inference_services/AnthropicService.py +87 -84
  61. edsl/inference_services/AwsBedrock.py +120 -118
  62. edsl/inference_services/AzureAI.py +217 -215
  63. edsl/inference_services/DeepInfraService.py +18 -18
  64. edsl/inference_services/GoogleService.py +148 -139
  65. edsl/inference_services/GroqService.py +20 -20
  66. edsl/inference_services/InferenceServiceABC.py +147 -80
  67. edsl/inference_services/InferenceServicesCollection.py +97 -122
  68. edsl/inference_services/MistralAIService.py +123 -120
  69. edsl/inference_services/OllamaService.py +18 -18
  70. edsl/inference_services/OpenAIService.py +224 -221
  71. edsl/inference_services/PerplexityService.py +163 -160
  72. edsl/inference_services/TestService.py +89 -92
  73. edsl/inference_services/TogetherAIService.py +170 -170
  74. edsl/inference_services/models_available_cache.py +118 -118
  75. edsl/inference_services/rate_limits_cache.py +25 -25
  76. edsl/inference_services/registry.py +41 -41
  77. edsl/inference_services/write_available.py +10 -10
  78. edsl/jobs/Answers.py +56 -43
  79. edsl/jobs/Jobs.py +898 -757
  80. edsl/jobs/JobsChecks.py +147 -172
  81. edsl/jobs/JobsPrompts.py +268 -270
  82. edsl/jobs/JobsRemoteInferenceHandler.py +239 -287
  83. edsl/jobs/__init__.py +1 -1
  84. edsl/jobs/buckets/BucketCollection.py +63 -104
  85. edsl/jobs/buckets/ModelBuckets.py +65 -65
  86. edsl/jobs/buckets/TokenBucket.py +251 -283
  87. edsl/jobs/interviews/Interview.py +661 -358
  88. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  89. edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
  90. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  91. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  92. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  93. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  94. edsl/jobs/interviews/ReportErrors.py +66 -66
  95. edsl/jobs/interviews/interview_status_enum.py +9 -9
  96. edsl/jobs/runners/JobsRunnerAsyncio.py +466 -421
  97. edsl/jobs/runners/JobsRunnerStatus.py +330 -330
  98. edsl/jobs/tasks/QuestionTaskCreator.py +242 -244
  99. edsl/jobs/tasks/TaskCreators.py +64 -64
  100. edsl/jobs/tasks/TaskHistory.py +450 -449
  101. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  102. edsl/jobs/tasks/task_status_enum.py +163 -161
  103. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  104. edsl/jobs/tokens/TokenUsage.py +34 -34
  105. edsl/language_models/KeyLookup.py +30 -0
  106. edsl/language_models/LanguageModel.py +668 -571
  107. edsl/language_models/ModelList.py +155 -153
  108. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  109. edsl/language_models/__init__.py +3 -2
  110. edsl/language_models/fake_openai_call.py +15 -15
  111. edsl/language_models/fake_openai_service.py +61 -61
  112. edsl/language_models/registry.py +190 -180
  113. edsl/language_models/repair.py +156 -156
  114. edsl/language_models/unused/ReplicateBase.py +83 -0
  115. edsl/language_models/utilities.py +64 -65
  116. edsl/notebooks/Notebook.py +258 -263
  117. edsl/notebooks/__init__.py +1 -1
  118. edsl/prompts/Prompt.py +362 -352
  119. edsl/prompts/__init__.py +2 -2
  120. edsl/questions/AnswerValidatorMixin.py +289 -334
  121. edsl/questions/QuestionBase.py +664 -509
  122. edsl/questions/QuestionBaseGenMixin.py +161 -165
  123. edsl/questions/QuestionBasePromptsMixin.py +217 -221
  124. edsl/questions/QuestionBudget.py +227 -227
  125. edsl/questions/QuestionCheckBox.py +359 -359
  126. edsl/questions/QuestionExtract.py +182 -182
  127. edsl/questions/QuestionFreeText.py +114 -113
  128. edsl/questions/QuestionFunctional.py +166 -166
  129. edsl/questions/QuestionList.py +231 -229
  130. edsl/questions/QuestionMultipleChoice.py +286 -330
  131. edsl/questions/QuestionNumerical.py +153 -151
  132. edsl/questions/QuestionRank.py +324 -314
  133. edsl/questions/Quick.py +41 -41
  134. edsl/questions/RegisterQuestionsMeta.py +71 -71
  135. edsl/questions/ResponseValidatorABC.py +174 -200
  136. edsl/questions/SimpleAskMixin.py +73 -74
  137. edsl/questions/__init__.py +26 -27
  138. edsl/questions/compose_questions.py +98 -98
  139. edsl/questions/decorators.py +21 -21
  140. edsl/questions/derived/QuestionLikertFive.py +76 -76
  141. edsl/questions/derived/QuestionLinearScale.py +87 -90
  142. edsl/questions/derived/QuestionTopK.py +93 -93
  143. edsl/questions/derived/QuestionYesNo.py +82 -82
  144. edsl/questions/descriptors.py +413 -427
  145. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  146. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  147. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  148. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  149. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  150. edsl/questions/prompt_templates/question_list.jinja +17 -17
  151. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  152. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  153. edsl/questions/question_registry.py +177 -177
  154. edsl/questions/settings.py +12 -12
  155. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  156. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  157. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  158. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  159. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  160. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  161. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  162. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  163. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  164. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  165. edsl/questions/templates/list/question_presentation.jinja +5 -5
  166. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  167. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  168. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  169. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  170. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  171. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  172. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  173. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  174. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  175. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  176. edsl/results/CSSParameterizer.py +108 -108
  177. edsl/results/Dataset.py +424 -587
  178. edsl/results/DatasetExportMixin.py +731 -653
  179. edsl/results/DatasetTree.py +275 -295
  180. edsl/results/Result.py +465 -451
  181. edsl/results/Results.py +1165 -1172
  182. edsl/results/ResultsDBMixin.py +238 -0
  183. edsl/results/ResultsExportMixin.py +43 -45
  184. edsl/results/ResultsFetchMixin.py +33 -33
  185. edsl/results/ResultsGGMixin.py +121 -121
  186. edsl/results/ResultsToolsMixin.py +98 -98
  187. edsl/results/Selector.py +135 -145
  188. edsl/results/TableDisplay.py +198 -125
  189. edsl/results/__init__.py +2 -2
  190. edsl/results/table_display.css +77 -77
  191. edsl/results/tree_explore.py +115 -115
  192. edsl/scenarios/FileStore.py +632 -511
  193. edsl/scenarios/Scenario.py +601 -498
  194. edsl/scenarios/ScenarioHtmlMixin.py +64 -65
  195. edsl/scenarios/ScenarioJoin.py +127 -131
  196. edsl/scenarios/ScenarioList.py +1287 -1430
  197. edsl/scenarios/ScenarioListExportMixin.py +52 -45
  198. edsl/scenarios/ScenarioListPdfMixin.py +261 -239
  199. edsl/scenarios/__init__.py +4 -3
  200. edsl/shared.py +1 -1
  201. edsl/study/ObjectEntry.py +173 -173
  202. edsl/study/ProofOfWork.py +113 -113
  203. edsl/study/SnapShot.py +80 -80
  204. edsl/study/Study.py +528 -521
  205. edsl/study/__init__.py +4 -4
  206. edsl/surveys/DAG.py +148 -148
  207. edsl/surveys/Memory.py +31 -31
  208. edsl/surveys/MemoryPlan.py +244 -244
  209. edsl/surveys/Rule.py +326 -327
  210. edsl/surveys/RuleCollection.py +387 -385
  211. edsl/surveys/Survey.py +1801 -1229
  212. edsl/surveys/SurveyCSS.py +261 -273
  213. edsl/surveys/SurveyExportMixin.py +259 -259
  214. edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +179 -181
  215. edsl/surveys/SurveyQualtricsImport.py +284 -284
  216. edsl/surveys/__init__.py +3 -5
  217. edsl/surveys/base.py +53 -53
  218. edsl/surveys/descriptors.py +56 -60
  219. edsl/surveys/instructions/ChangeInstruction.py +49 -48
  220. edsl/surveys/instructions/Instruction.py +65 -56
  221. edsl/surveys/instructions/InstructionCollection.py +77 -82
  222. edsl/templates/error_reporting/base.html +23 -23
  223. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  224. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  225. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  226. edsl/templates/error_reporting/interview_details.html +115 -115
  227. edsl/templates/error_reporting/interviews.html +19 -19
  228. edsl/templates/error_reporting/overview.html +4 -4
  229. edsl/templates/error_reporting/performance_plot.html +1 -1
  230. edsl/templates/error_reporting/report.css +73 -73
  231. edsl/templates/error_reporting/report.html +117 -117
  232. edsl/templates/error_reporting/report.js +25 -25
  233. edsl/tools/__init__.py +1 -1
  234. edsl/tools/clusters.py +192 -192
  235. edsl/tools/embeddings.py +27 -27
  236. edsl/tools/embeddings_plotting.py +118 -118
  237. edsl/tools/plotting.py +112 -112
  238. edsl/tools/summarize.py +18 -18
  239. edsl/utilities/SystemInfo.py +28 -28
  240. edsl/utilities/__init__.py +22 -22
  241. edsl/utilities/ast_utilities.py +25 -25
  242. edsl/utilities/data/Registry.py +6 -6
  243. edsl/utilities/data/__init__.py +1 -1
  244. edsl/utilities/data/scooter_results.json +1 -1
  245. edsl/utilities/decorators.py +77 -77
  246. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  247. edsl/utilities/interface.py +627 -627
  248. edsl/utilities/naming_utilities.py +263 -263
  249. edsl/utilities/repair_functions.py +28 -28
  250. edsl/utilities/restricted_python.py +70 -70
  251. edsl/utilities/utilities.py +424 -436
  252. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev3.dist-info}/LICENSE +21 -21
  253. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev3.dist-info}/METADATA +10 -12
  254. edsl-0.1.39.dev3.dist-info/RECORD +277 -0
  255. edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
  256. edsl/agents/QuestionOptionProcessor.py +0 -172
  257. edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
  258. edsl/coop/CoopFunctionsMixin.py +0 -15
  259. edsl/coop/ExpectedParrotKeyHandler.py +0 -125
  260. edsl/exceptions/inference_services.py +0 -5
  261. edsl/inference_services/AvailableModelCacheHandler.py +0 -184
  262. edsl/inference_services/AvailableModelFetcher.py +0 -209
  263. edsl/inference_services/ServiceAvailability.py +0 -135
  264. edsl/inference_services/data_structures.py +0 -62
  265. edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -188
  266. edsl/jobs/FetchInvigilator.py +0 -40
  267. edsl/jobs/InterviewTaskManager.py +0 -98
  268. edsl/jobs/InterviewsConstructor.py +0 -48
  269. edsl/jobs/JobsComponentConstructor.py +0 -189
  270. edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
  271. edsl/jobs/RequestTokenEstimator.py +0 -30
  272. edsl/jobs/buckets/TokenBucketAPI.py +0 -211
  273. edsl/jobs/buckets/TokenBucketClient.py +0 -191
  274. edsl/jobs/decorators.py +0 -35
  275. edsl/jobs/jobs_status_enums.py +0 -9
  276. edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
  277. edsl/language_models/ComputeCost.py +0 -63
  278. edsl/language_models/PriceManager.py +0 -127
  279. edsl/language_models/RawResponseHandler.py +0 -106
  280. edsl/language_models/ServiceDataSources.py +0 -0
  281. edsl/language_models/key_management/KeyLookup.py +0 -63
  282. edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
  283. edsl/language_models/key_management/KeyLookupCollection.py +0 -38
  284. edsl/language_models/key_management/__init__.py +0 -0
  285. edsl/language_models/key_management/models.py +0 -131
  286. edsl/notebooks/NotebookToLaTeX.py +0 -142
  287. edsl/questions/ExceptionExplainer.py +0 -77
  288. edsl/questions/HTMLQuestion.py +0 -103
  289. edsl/questions/LoopProcessor.py +0 -149
  290. edsl/questions/QuestionMatrix.py +0 -265
  291. edsl/questions/ResponseValidatorFactory.py +0 -28
  292. edsl/questions/templates/matrix/__init__.py +0 -1
  293. edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
  294. edsl/questions/templates/matrix/question_presentation.jinja +0 -20
  295. edsl/results/MarkdownToDocx.py +0 -122
  296. edsl/results/MarkdownToPDF.py +0 -111
  297. edsl/results/TextEditor.py +0 -50
  298. edsl/results/smart_objects.py +0 -96
  299. edsl/results/table_data_class.py +0 -12
  300. edsl/results/table_renderers.py +0 -118
  301. edsl/scenarios/ConstructDownloadLink.py +0 -109
  302. edsl/scenarios/DirectoryScanner.py +0 -96
  303. edsl/scenarios/DocumentChunker.py +0 -102
  304. edsl/scenarios/DocxScenario.py +0 -16
  305. edsl/scenarios/PdfExtractor.py +0 -40
  306. edsl/scenarios/ScenarioSelector.py +0 -156
  307. edsl/scenarios/file_methods.py +0 -85
  308. edsl/scenarios/handlers/__init__.py +0 -13
  309. edsl/scenarios/handlers/csv.py +0 -38
  310. edsl/scenarios/handlers/docx.py +0 -76
  311. edsl/scenarios/handlers/html.py +0 -37
  312. edsl/scenarios/handlers/json.py +0 -111
  313. edsl/scenarios/handlers/latex.py +0 -5
  314. edsl/scenarios/handlers/md.py +0 -51
  315. edsl/scenarios/handlers/pdf.py +0 -68
  316. edsl/scenarios/handlers/png.py +0 -39
  317. edsl/scenarios/handlers/pptx.py +0 -105
  318. edsl/scenarios/handlers/py.py +0 -294
  319. edsl/scenarios/handlers/sql.py +0 -313
  320. edsl/scenarios/handlers/sqlite.py +0 -149
  321. edsl/scenarios/handlers/txt.py +0 -33
  322. edsl/surveys/ConstructDAG.py +0 -92
  323. edsl/surveys/EditSurvey.py +0 -221
  324. edsl/surveys/InstructionHandler.py +0 -100
  325. edsl/surveys/MemoryManagement.py +0 -72
  326. edsl/surveys/RuleManager.py +0 -172
  327. edsl/surveys/Simulator.py +0 -75
  328. edsl/surveys/SurveyToApp.py +0 -141
  329. edsl/utilities/PrettyList.py +0 -56
  330. edsl/utilities/is_notebook.py +0 -18
  331. edsl/utilities/is_valid_variable_name.py +0 -11
  332. edsl/utilities/remove_edsl_version.py +0 -24
  333. edsl-0.1.39.dev2.dist-info/RECORD +0 -352
  334. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev3.dist-info}/WHEEL +0 -0
@@ -1,571 +1,668 @@
1
- """This module contains the LanguageModel class, which is an abstract base class for all language models.
2
-
3
- Terminology:
4
-
5
- raw_response: The JSON response from the model. This has all the model meta-data about the call.
6
-
7
- edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
8
- such as the cache key, token usage, etc.
9
-
10
- generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
11
- edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
12
-
13
- """
14
-
15
- from __future__ import annotations
16
- import warnings
17
- from functools import wraps
18
- import asyncio
19
- import json
20
- import os
21
- from typing import (
22
- Coroutine,
23
- Any,
24
- Type,
25
- Union,
26
- List,
27
- get_type_hints,
28
- TypedDict,
29
- Optional,
30
- TYPE_CHECKING,
31
- )
32
- from abc import ABC, abstractmethod
33
-
34
- from edsl.data_transfer_models import (
35
- ModelResponse,
36
- ModelInputs,
37
- EDSLOutput,
38
- AgentResponseDict,
39
- )
40
-
41
- if TYPE_CHECKING:
42
- from edsl.data.Cache import Cache
43
- from edsl.scenarios.FileStore import FileStore
44
- from edsl.questions.QuestionBase import QuestionBase
45
- from edsl.language_models.key_management.KeyLookup import KeyLookup
46
-
47
- from edsl.utilities.decorators import (
48
- sync_wrapper,
49
- jupyter_nb_handler,
50
- )
51
- from edsl.utilities.remove_edsl_version import remove_edsl_version
52
-
53
- from edsl.Base import PersistenceMixin, RepresentationMixin
54
- from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
55
-
56
- from edsl.language_models.key_management.KeyLookupCollection import (
57
- KeyLookupCollection,
58
- )
59
-
60
- from edsl.language_models.RawResponseHandler import RawResponseHandler
61
-
62
-
63
- def handle_key_error(func):
64
- """Handle KeyError exceptions."""
65
-
66
- @wraps(func)
67
- def wrapper(*args, **kwargs):
68
- try:
69
- return func(*args, **kwargs)
70
- assert True == False
71
- except KeyError as e:
72
- return f"""KeyError occurred: {e}. This is most likely because the model you are using
73
- returned a JSON object we were not expecting."""
74
-
75
- return wrapper
76
-
77
-
78
- class classproperty:
79
- def __init__(self, method):
80
- self.method = method
81
-
82
- def __get__(self, instance, cls):
83
- return self.method(cls)
84
-
85
-
86
- from edsl.Base import HashingMixin
87
-
88
-
89
- class LanguageModel(
90
- PersistenceMixin,
91
- RepresentationMixin,
92
- HashingMixin,
93
- ABC,
94
- metaclass=RegisterLanguageModelsMeta,
95
- ):
96
- """ABC for Language Models."""
97
-
98
- _model_ = None
99
- key_sequence = (
100
- None # This should be something like ["choices", 0, "message", "content"]
101
- )
102
-
103
- DEFAULT_RPM = 100
104
- DEFAULT_TPM = 1000
105
-
106
- @classproperty
107
- def response_handler(cls):
108
- key_sequence = cls.key_sequence
109
- usage_sequence = cls.usage_sequence if hasattr(cls, "usage_sequence") else None
110
- return RawResponseHandler(key_sequence, usage_sequence)
111
-
112
- def __init__(
113
- self,
114
- tpm: Optional[float] = None,
115
- rpm: Optional[float] = None,
116
- omit_system_prompt_if_empty_string: bool = True,
117
- key_lookup: Optional["KeyLookup"] = None,
118
- **kwargs,
119
- ):
120
- """Initialize the LanguageModel."""
121
- self.model = getattr(self, "_model_", None)
122
- default_parameters = getattr(self, "_parameters_", None)
123
- parameters = self._overide_default_parameters(kwargs, default_parameters)
124
- self.parameters = parameters
125
- self.remote = False
126
- self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
127
-
128
- self.key_lookup = self._set_key_lookup(key_lookup)
129
- self.model_info = self.key_lookup.get(self._inference_service_)
130
-
131
- if rpm is not None:
132
- self._rpm = rpm
133
-
134
- if tpm is not None:
135
- self._tpm = tpm
136
-
137
- for key, value in parameters.items():
138
- setattr(self, key, value)
139
-
140
- for key, value in kwargs.items():
141
- if key not in parameters:
142
- setattr(self, key, value)
143
-
144
- if kwargs.get("skip_api_key_check", False):
145
- # Skip the API key check. Sometimes this is useful for testing.
146
- self._api_token = None
147
-
148
- def _set_key_lookup(self, key_lookup: "KeyLookup") -> "KeyLookup":
149
- """Set the key lookup."""
150
- if key_lookup is not None:
151
- return key_lookup
152
- else:
153
- klc = KeyLookupCollection()
154
- klc.add_key_lookup(fetch_order=("config", "env"))
155
- return klc.get(("config", "env"))
156
-
157
- def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
158
- del self._api_token
159
- self.key_lookup = key_lookup
160
-
161
- def ask_question(self, question: "QuestionBase") -> str:
162
- """Ask a question and return the response.
163
-
164
- :param question: The question to ask.
165
- """
166
- user_prompt = question.get_instructions().render(question.data).text
167
- system_prompt = "You are a helpful agent pretending to be a human."
168
- return self.execute_model_call(user_prompt, system_prompt)
169
-
170
- @property
171
- def rpm(self):
172
- if not hasattr(self, "_rpm"):
173
- if self.model_info is None:
174
- self._rpm = self.DEFAULT_RPM
175
- else:
176
- self._rpm = self.model_info.rpm
177
- return self._rpm
178
-
179
- @property
180
- def tpm(self):
181
- if not hasattr(self, "_tpm"):
182
- if self.model_info is None:
183
- self._tpm = self.DEFAULT_TPM
184
- else:
185
- self._tpm = self.model_info.tpm
186
- return self._tpm
187
-
188
- # in case we want to override the default values
189
- @tpm.setter
190
- def tpm(self, value):
191
- self._tpm = value
192
-
193
- @rpm.setter
194
- def rpm(self, value):
195
- self._rpm = value
196
-
197
- @property
198
- def api_token(self) -> str:
199
- if not hasattr(self, "_api_token"):
200
- info = self.key_lookup.get(self._inference_service_, None)
201
- if info is None:
202
- raise ValueError(
203
- f"No key found for service '{self._inference_service_}'"
204
- )
205
- self._api_token = info.api_token
206
- return self._api_token
207
-
208
- def __getitem__(self, key):
209
- return getattr(self, key)
210
-
211
- def hello(self, verbose=False):
212
- """Runs a simple test to check if the model is working."""
213
- token = self.api_token
214
- masked = token[: min(8, len(token))] + "..."
215
- if verbose:
216
- print(f"Current key is {masked}")
217
- return self.execute_model_call(
218
- user_prompt="Hello, model!", system_prompt="You are a helpful agent."
219
- )
220
-
221
- def has_valid_api_key(self) -> bool:
222
- """Check if the model has a valid API key.
223
-
224
- >>> LanguageModel.example().has_valid_api_key() : # doctest: +SKIP
225
- True
226
-
227
- This method is used to check if the model has a valid API key.
228
- """
229
- from edsl.enums import service_to_api_keyname
230
-
231
- if self._model_ == "test":
232
- return True
233
-
234
- key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
235
- key_value = os.getenv(key_name)
236
- return key_value is not None
237
-
238
- def __hash__(self) -> str:
239
- """Allow the model to be used as a key in a dictionary.
240
-
241
- >>> m = LanguageModel.example()
242
- >>> hash(m)
243
- 1811901442659237949
244
- """
245
- from edsl.utilities.utilities import dict_hash
246
-
247
- return dict_hash(self.to_dict(add_edsl_version=False))
248
-
249
- def __eq__(self, other) -> bool:
250
- """Check is two models are the same.
251
-
252
- >>> m1 = LanguageModel.example()
253
- >>> m2 = LanguageModel.example()
254
- >>> m1 == m2
255
- True
256
-
257
- """
258
- return self.model == other.model and self.parameters == other.parameters
259
-
260
- @staticmethod
261
- def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
262
- """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
263
-
264
- >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9})
265
- {'temperature': 0.5}
266
- >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
267
- {'temperature': 0.5, 'max_tokens': 1000}
268
- """
269
- # this is the case when data is loaded from a dict after serialization
270
- if "parameters" in passed_parameter_dict:
271
- passed_parameter_dict = passed_parameter_dict["parameters"]
272
- return {
273
- parameter_name: passed_parameter_dict.get(parameter_name, default_value)
274
- for parameter_name, default_value in default_parameter_dict.items()
275
- }
276
-
277
- def __call__(self, user_prompt: str, system_prompt: str):
278
- return self.execute_model_call(user_prompt, system_prompt)
279
-
280
- @abstractmethod
281
- async def async_execute_model_call(user_prompt: str, system_prompt: str):
282
- """Execute the model call and returns a coroutine."""
283
- pass
284
-
285
- async def remote_async_execute_model_call(
286
- self, user_prompt: str, system_prompt: str
287
- ):
288
- """Execute the model call and returns the result as a coroutine, using Coop."""
289
- from edsl.coop import Coop
290
-
291
- client = Coop()
292
- response_data = await client.remote_async_execute_model_call(
293
- self.to_dict(), user_prompt, system_prompt
294
- )
295
- return response_data
296
-
297
- @jupyter_nb_handler
298
- def execute_model_call(self, *args, **kwargs) -> Coroutine:
299
- """Execute the model call and returns the result as a coroutine."""
300
-
301
- async def main():
302
- results = await asyncio.gather(
303
- self.async_execute_model_call(*args, **kwargs)
304
- )
305
- return results[0] # Since there's only one task, return its result
306
-
307
- return main()
308
-
309
- @classmethod
310
- def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
311
- """Return the generated token string from the raw response.
312
-
313
- >>> m = LanguageModel.example(test_model = True)
314
- >>> raw_response = m.execute_model_call("Hello, model!", "You are a helpful agent.")
315
- >>> m.get_generated_token_string(raw_response)
316
- 'Hello world'
317
-
318
- """
319
- return cls.response_handler.get_generated_token_string(raw_response)
320
-
321
- @classmethod
322
- def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
323
- """Return the usage dictionary from the raw response."""
324
- return cls.response_handler.get_usage_dict(raw_response)
325
-
326
- @classmethod
327
- def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
328
- """Parses the API response and returns the response text."""
329
- return cls.response_handler.parse_response(raw_response)
330
-
331
- async def _async_get_intended_model_call_outcome(
332
- self,
333
- user_prompt: str,
334
- system_prompt: str,
335
- cache: Cache,
336
- iteration: int = 0,
337
- files_list: Optional[List[FileStore]] = None,
338
- invigilator=None,
339
- ) -> ModelResponse:
340
- """Handle caching of responses.
341
-
342
- :param user_prompt: The user's prompt.
343
- :param system_prompt: The system's prompt.
344
- :param iteration: The iteration number.
345
- :param cache: The cache to use.
346
- :param files_list: The list of files to use.
347
- :param invigilator: The invigilator to use.
348
-
349
- If the cache isn't being used, it just returns a 'fresh' call to the LLM.
350
- But if cache is being used, it first checks the database to see if the response is already there.
351
- If it is, it returns the cached response, but again appends some tracking information.
352
- If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
353
-
354
- If self.use_cache is True, then attempts to retrieve the response from the database;
355
- if not in the DB, calls the LLM and writes the response to the DB.
356
-
357
- >>> from edsl import Cache
358
- >>> m = LanguageModel.example(test_model = True)
359
- >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
360
- ModelResponse(...)"""
361
-
362
- if files_list:
363
- files_hash = "+".join([str(hash(file)) for file in files_list])
364
- user_prompt_with_hashes = user_prompt + f" {files_hash}"
365
- else:
366
- user_prompt_with_hashes = user_prompt
367
-
368
- cache_call_params = {
369
- "model": str(self.model),
370
- "parameters": self.parameters,
371
- "system_prompt": system_prompt,
372
- "user_prompt": user_prompt_with_hashes,
373
- "iteration": iteration,
374
- }
375
- cached_response, cache_key = cache.fetch(**cache_call_params)
376
-
377
- if cache_used := cached_response is not None:
378
- response = json.loads(cached_response)
379
- else:
380
- f = (
381
- self.remote_async_execute_model_call
382
- if hasattr(self, "remote") and self.remote
383
- else self.async_execute_model_call
384
- )
385
- params = {
386
- "user_prompt": user_prompt,
387
- "system_prompt": system_prompt,
388
- "files_list": files_list,
389
- }
390
- from edsl.config import CONFIG
391
-
392
- TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
393
-
394
- response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
395
- new_cache_key = cache.store(
396
- **cache_call_params, response=response
397
- ) # store the response in the cache
398
- assert new_cache_key == cache_key # should be the same
399
-
400
- cost = self.cost(response)
401
- return ModelResponse(
402
- response=response,
403
- cache_used=cache_used,
404
- cache_key=cache_key,
405
- cached_response=cached_response,
406
- cost=cost,
407
- )
408
-
409
- _get_intended_model_call_outcome = sync_wrapper(
410
- _async_get_intended_model_call_outcome
411
- )
412
-
413
- def simple_ask(
414
- self,
415
- question: QuestionBase,
416
- system_prompt="You are a helpful agent pretending to be a human.",
417
- top_logprobs=2,
418
- ):
419
- """Ask a question and return the response."""
420
- self.logprobs = True
421
- self.top_logprobs = top_logprobs
422
- return self.execute_model_call(
423
- user_prompt=question.human_readable(), system_prompt=system_prompt
424
- )
425
-
426
    async def async_get_response(
        self,
        user_prompt: str,
        system_prompt: str,
        cache: Cache,
        iteration: int = 1,
        files_list: Optional[List[FileStore]] = None,
        **kwargs,
    ) -> dict:
        """Get response, parse, and return as string.

        :param user_prompt: The user's prompt.
        :param system_prompt: The system's prompt.
        :param cache: The cache to use.
        :param iteration: The iteration number.
        :param files_list: The list of files to use.
        :returns: An AgentResponseDict bundling the model inputs, the (possibly
            cached) model outputs, and the parsed EDSL answer dict.
        """
        params = {
            "user_prompt": user_prompt,
            "system_prompt": system_prompt,
            "iteration": iteration,
            "cache": cache,
            "files_list": files_list,
        }
        # Forward the invigilator only when one was supplied by the caller.
        if "invigilator" in kwargs:
            params.update({"invigilator": kwargs["invigilator"]})

        model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
        model_outputs: ModelResponse = (
            await self._async_get_intended_model_call_outcome(**params)
        )
        edsl_dict: EDSLOutput = self.parse_response(model_outputs.response)

        agent_response_dict = AgentResponseDict(
            model_inputs=model_inputs,
            model_outputs=model_outputs,
            edsl_dict=edsl_dict,
        )
        return agent_response_dict

    # Synchronous convenience wrapper around async_get_response.
    get_response = sync_wrapper(async_get_response)
468
-
469
    def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
        """Return the dollar cost of a raw response.

        :param raw_response: The raw response from the model.
        :returns: Whatever PriceManager.calculate_cost produces for this
            service/model and the token counts extracted from the response.
        """

        usage = self.get_usage_dict(raw_response)
        from edsl.language_models.PriceManager import PriceManager

        # NOTE(review): local name 'price_manger' is a typo for 'price_manager'.
        price_manger = PriceManager()
        return price_manger.calculate_cost(
            inference_service=self._inference_service_,
            model=self.model,
            usage=usage,
            input_token_name=self.input_token_name,
            output_token_name=self.output_token_name,
        )
486
-
487
    def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
        """Convert instance to a dictionary

        :param add_edsl_version: Whether to add the EDSL version to the dictionary.

        >>> m = LanguageModel.example()
        >>> m.to_dict()
        {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
        """
        d = {"model": self.model, "parameters": self.parameters}
        if add_edsl_version:
            from edsl import __version__

            # Version/class metadata lets from_dict route to the right subclass.
            d["edsl_version"] = __version__
            d["edsl_class_name"] = self.__class__.__name__
        return d
503
-
504
    @classmethod
    @remove_edsl_version
    def from_dict(cls, data: dict) -> Type[LanguageModel]:
        """Convert dictionary to a LanguageModel child instance.

        The 'model' entry selects the concrete subclass via the registry;
        the remaining entries are passed to that subclass's constructor.
        """
        from edsl.language_models.registry import get_model_class

        model_class = get_model_class(data["model"])
        return model_class(**data)
512
-
513
    def __repr__(self) -> str:
        """Return a representation of the object.

        Format: Model(model_name = '<name>', <param> = <value>, ...);
        the parameter list is omitted entirely when there are no parameters.
        """
        param_string = ", ".join(
            f"{key} = {value}" for key, value in self.parameters.items()
        )
        return (
            f"Model(model_name = '{self.model}'"
            + (f", {param_string}" if param_string else "")
            + ")"
        )
523
-
524
    def __add__(self, other_model: Type[LanguageModel]) -> Type[LanguageModel]:
        """Combine two models into a single model (other_model takes precedence over self)."""
        import warnings

        # NOTE(review): the f-prefix below is unnecessary (no placeholders).
        warnings.warn(
            f"""Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
            by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
        )
        # Falls back to self only if other_model is falsy.
        return other_model or self
533
-
534
    @classmethod
    def example(
        cls,
        test_model: bool = False,
        canned_response: str = "Hello world",
        throw_exception: bool = False,
    ) -> LanguageModel:
        """Return a default instance of the class.

        :param test_model: If True, return the built-in 'test' model.
        :param canned_response: The fixed reply the test model returns.
        :param throw_exception: If True, the test model raises when called.

        >>> from edsl.language_models import LanguageModel
        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
        >>> isinstance(m, LanguageModel)
        True
        >>> from edsl import QuestionFreeText
        >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
        >>> q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True).select('example').first()
        'WOWZA!'
        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!", throw_exception = True)
        >>> r = q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True, print_exceptions = True)
        Exception report saved to ...
        Also see: ...
        """
        from edsl.language_models.registry import Model

        if test_model:
            m = Model(
                "test", canned_response=canned_response, throw_exception=throw_exception
            )
            return m
        else:
            # Non-test default model; skip the key check so construction
            # succeeds without credentials.
            return Model(skip_api_key_check=True)
565
-
566
-
567
if __name__ == "__main__":
    """Run the module's test suite."""
    import doctest

    # ELLIPSIS lets expected doctest output elide variable parts with '...'.
    doctest.testmod(optionflags=doctest.ELLIPSIS)
1
+ """This module contains the LanguageModel class, which is an abstract base class for all language models.
2
+
3
+ Terminology:
4
+
5
+ raw_response: The JSON response from the model. This has all the model meta-data about the call.
6
+
7
+ edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
8
+ such as the cache key, token usage, etc.
9
+
10
+ generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
11
+ edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
12
+
13
+ """
14
+
15
+ from __future__ import annotations
16
+ import warnings
17
+ from functools import wraps
18
+ import asyncio
19
+ import json
20
+ import os
21
+ from typing import (
22
+ Coroutine,
23
+ Any,
24
+ Callable,
25
+ Type,
26
+ Union,
27
+ List,
28
+ get_type_hints,
29
+ TypedDict,
30
+ Optional,
31
+ TYPE_CHECKING,
32
+ )
33
+ from abc import ABC, abstractmethod
34
+
35
+ from json_repair import repair_json
36
+
37
+ from edsl.data_transfer_models import (
38
+ ModelResponse,
39
+ ModelInputs,
40
+ EDSLOutput,
41
+ AgentResponseDict,
42
+ )
43
+
44
+ if TYPE_CHECKING:
45
+ from edsl.data.Cache import Cache
46
+ from edsl.scenarios.FileStore import FileStore
47
+ from edsl.questions.QuestionBase import QuestionBase
48
+
49
+ from edsl.config import CONFIG
50
+ from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
51
+ from edsl.utilities.decorators import remove_edsl_version
52
+
53
+ from edsl.Base import PersistenceMixin
54
+ from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
55
+ from edsl.language_models.KeyLookup import KeyLookup
56
+ from edsl.exceptions.language_models import LanguageModelBadResponseError
57
+
58
+ TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
59
+
60
+
61
# You might be tempted to move this to be a static method of LanguageModel, but this doesn't work
# for reasons I don't understand. So leave it here.
63
def extract_item_from_raw_response(data, key_sequence):
    """Walk *key_sequence* (dict keys / sequence indexes) into *data*.

    :param data: The raw model response. A JSON string is decoded first; a
        string that fails to decode is returned unchanged.
    :param key_sequence: Ordered keys/indexes, e.g. ["choices", 0, "message", "content"].
    :returns: The addressed item; string results are returned stripped.
    :raises LanguageModelBadResponseError: if any step of the path fails.
    """
    if isinstance(data, str):
        try:
            data = json.loads(data)
        except json.JSONDecodeError:
            # Not JSON at all -- hand the raw string back unchanged.
            return data
    current_data = data
    for i, key in enumerate(key_sequence):
        try:
            if isinstance(current_data, (list, tuple)):
                if not isinstance(key, int):
                    raise TypeError(
                        f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
                    )
                if key < 0 or key >= len(current_data):
                    raise IndexError(
                        f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
                    )
            elif isinstance(current_data, dict):
                if key not in current_data:
                    raise KeyError(
                        f"Key '{key}' not found in dictionary at position {i}"
                    )
            else:
                raise TypeError(
                    f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
                )

            current_data = current_data[key]
        except Exception as e:
            path = " -> ".join(map(str, key_sequence[: i + 1]))
            # Prefer the service's own error message when the response carries
            # one. Fix: guard with isinstance -- `"error" in data` would raise
            # TypeError for non-container responses (e.g. a bare JSON number)
            # and do a substring match for string responses, masking the real
            # failure inside this handler.
            if isinstance(data, dict) and "error" in data:
                msg = data["error"]
            else:
                msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
            raise LanguageModelBadResponseError(message=msg, response_json=data)
    if isinstance(current_data, str):
        return current_data.strip()
    else:
        return current_data
103
+
104
+
105
def handle_key_error(func):
    """Decorator: turn a KeyError from *func* into an explanatory string.

    Intended for response-parsing helpers, where a missing key usually means
    the model returned JSON in an unexpected shape.
    (Fix: removed an unreachable `assert True == False` that followed the
    `return` statement in the try block.)
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except KeyError as e:
            return f"""KeyError occurred: {e}. This is most likely because the model you are using
            returned a JSON object we were not expecting."""

    return wrapper
118
+
119
+
120
class LanguageModel(
    PersistenceMixin,
    ABC,
    metaclass=RegisterLanguageModelsMeta,
):
    """ABC for Language Models.

    Concrete subclasses supply class-level metadata (``_model_``,
    ``key_sequence``, ``usage_sequence``, token-name attributes) and an
    implementation of :meth:`async_execute_model_call`.
    """

    _model_ = None
    key_sequence = (
        None  # This should be something like ["choices", 0, "message", "content"]
    )
    __rate_limits = None
    _safety_factor = 0.8

    def __init__(
        self,
        tpm: Optional[float] = None,
        rpm: Optional[float] = None,
        omit_system_prompt_if_empty_string: bool = True,
        key_lookup: Optional[KeyLookup] = None,
        **kwargs,
    ):
        """Initialize the LanguageModel.

        :param tpm: Optional tokens-per-minute override of the class default.
        :param rpm: Optional requests-per-minute override of the class default.
        :param omit_system_prompt_if_empty_string: If True, an empty system prompt is omitted from calls.
        :param key_lookup: Source of API credentials; defaults to the OS environment.
        :param kwargs: Model parameters (e.g. temperature) plus extra attributes.
        """
        self.model = getattr(self, "_model_", None)
        # Merge caller-supplied parameters over the subclass defaults.
        # Fix: `or {}` guards subclasses that declare no `_parameters_`
        # (previously this crashed on `None.items()`).
        default_parameters = getattr(self, "_parameters_", None) or {}
        parameters = self._overide_default_parameters(kwargs, default_parameters)
        self.parameters = parameters
        self.remote = False
        self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string

        # self._rpm / _tpm come from the class; only override when provided.
        if rpm is not None:
            self._rpm = rpm

        if tpm is not None:
            self._tpm = tpm

        # Expose each model parameter as an instance attribute.
        for key, value in parameters.items():
            setattr(self, key, value)

        # Extra kwargs that are not model parameters become plain attributes.
        for key, value in kwargs.items():
            if key not in parameters:
                setattr(self, key, value)

        if "use_cache" in kwargs:
            warnings.warn(
                "The use_cache parameter is deprecated. Use the Cache class instead."
            )

        if kwargs.get("skip_api_key_check", False):
            # Skip the API key check. Sometimes this is useful for testing.
            self._api_token = None

        if key_lookup is not None:
            self.key_lookup = key_lookup
        else:
            self.key_lookup = KeyLookup.from_os_environ()

    def ask_question(self, question):
        """Ask a question using its rendered instructions as the user prompt."""
        user_prompt = question.get_instructions().render(question.data).text
        system_prompt = "You are a helpful agent pretending to be a human."
        return self.execute_model_call(user_prompt, system_prompt)

    def set_key_lookup(self, key_lookup: KeyLookup) -> None:
        """Install a new key lookup, invalidating any cached API token.

        Fix: guard the delete -- ``_api_token`` only exists after the
        ``api_token`` property has been accessed (or when the API-key check
        was skipped), so an unconditional ``del`` raised AttributeError.
        """
        if hasattr(self, "_api_token"):
            del self._api_token
        self.key_lookup = key_lookup

    @property
    def api_token(self) -> str:
        """Lazily fetch and cache the API token for this model's service."""
        if not hasattr(self, "_api_token"):
            self._api_token = self.key_lookup.get_api_token(
                self._inference_service_, self.remote
            )
        return self._api_token

    def __getitem__(self, key):
        """Dictionary-style attribute access: m['model'] == m.model."""
        return getattr(self, key)

    def _repr_html_(self) -> str:
        """Render the model name and parameters as an HTML table (Jupyter)."""
        d = {"model": self.model}
        d.update(self.parameters)
        data = [[k, v] for k, v in d.items()]
        from tabulate import tabulate

        table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
        return f"<pre>{table}</pre>"

    def hello(self, verbose=False):
        """Runs a simple test to check if the model is working."""
        token = self.api_token
        # Fix: token is None when the API-key check was skipped; slicing None
        # would raise TypeError.
        masked = token[: min(8, len(token))] + "..." if token else "(no key)"
        if verbose:
            print(f"Current key is {masked}")
        return self.execute_model_call(
            user_prompt="Hello, model!", system_prompt="You are a helpful agent."
        )

    def has_valid_api_key(self) -> bool:
        """Check if the model has a valid API key.

        >>> LanguageModel.example().has_valid_api_key() : # doctest: +SKIP
        True

        This method is used to check if the model has a valid API key.
        """
        from edsl.enums import service_to_api_keyname

        # The built-in test model never needs a key.
        if self._model_ == "test":
            return True

        key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
        key_value = os.getenv(key_name)
        return key_value is not None

    def __hash__(self) -> int:
        """Allow the model to be used as a key in a dictionary.

        (Fix: return annotation corrected from ``str`` -- ``dict_hash``
        produces an integer, as ``__hash__`` must.)
        """
        from edsl.utilities.utilities import dict_hash

        return dict_hash(self.to_dict(add_edsl_version=False))

    def __eq__(self, other) -> bool:
        """Check if two models are the same.

        >>> m1 = LanguageModel.example()
        >>> m2 = LanguageModel.example()
        >>> m1 == m2
        True

        """
        # Let Python try the reflected comparison for non-models instead of
        # raising AttributeError.
        if not isinstance(other, LanguageModel):
            return NotImplemented
        return self.model == other.model and self.parameters == other.parameters

    def set_rate_limits(self, rpm=None, tpm=None) -> None:
        """Set the rate limits for the model.

        >>> m = LanguageModel.example()
        >>> m.set_rate_limits(rpm=100, tpm=1000)
        >>> m.RPM
        100
        """
        if rpm is not None:
            self._rpm = rpm
        if tpm is not None:
            self._tpm = tpm
        return None

    @property
    def RPM(self):
        """Model's requests-per-minute limit."""
        return self._rpm

    @property
    def TPM(self):
        """Model's tokens-per-minute limit."""
        return self._tpm

    @property
    def rpm(self):
        """Requests-per-minute limit (lower-case alias of RPM)."""
        return self._rpm

    @rpm.setter
    def rpm(self, value):
        self._rpm = value

    @property
    def tpm(self):
        """Tokens-per-minute limit (lower-case alias of TPM)."""
        return self._tpm

    @tpm.setter
    def tpm(self, value):
        self._tpm = value

    @staticmethod
    def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
        """Return a dictionary of parameters, with passed parameters taking precedence over defaults.

        >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9})
        {'temperature': 0.5}
        >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
        {'temperature': 0.5, 'max_tokens': 1000}
        """
        # this is the case when data is loaded from a dict after serialization
        if "parameters" in passed_parameter_dict:
            passed_parameter_dict = passed_parameter_dict["parameters"]
        return {
            parameter_name: passed_parameter_dict.get(parameter_name, default_value)
            for parameter_name, default_value in default_parameter_dict.items()
        }

    def __call__(self, user_prompt: str, system_prompt: str):
        """Calling the model directly executes a model call."""
        return self.execute_model_call(user_prompt, system_prompt)

    @abstractmethod
    async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
        """Execute the model call and returns a coroutine.

        (Fix: ``self`` added -- this is an instance method; the original
        abstract signature used ``user_prompt`` as the receiver.)

        >>> m = LanguageModel.example(test_model = True)
        >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
        >>> asyncio.run(test())
        {'message': [{'text': 'Hello world'}], ...}

        >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
        {'message': [{'text': 'Hello world'}], ...}
        """
        pass

    async def remote_async_execute_model_call(
        self, user_prompt: str, system_prompt: str
    ):
        """Execute the model call and returns the result as a coroutine, using Coop."""
        from edsl.coop import Coop

        client = Coop()
        response_data = await client.remote_async_execute_model_call(
            self.to_dict(), user_prompt, system_prompt
        )
        return response_data

    @jupyter_nb_handler
    def execute_model_call(self, *args, **kwargs) -> Coroutine:
        """Execute the model call and returns the result as a coroutine.

        >>> m = LanguageModel.example(test_model = True)
        >>> m.execute_model_call(user_prompt = "Hello, model!", system_prompt = "You are a helpful agent.")

        """

        async def main():
            results = await asyncio.gather(
                self.async_execute_model_call(*args, **kwargs)
            )
            return results[0]  # Since there's only one task, return its result

        return main()

    @classmethod
    def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
        """Return the generated token string from the raw response."""
        return extract_item_from_raw_response(raw_response, cls.key_sequence)

    @classmethod
    def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
        """Return the usage dictionary from the raw response."""
        if not hasattr(cls, "usage_sequence"):
            raise NotImplementedError(
                "This inference service does not have a usage_sequence."
            )
        return extract_item_from_raw_response(raw_response, cls.usage_sequence)

    @staticmethod
    def convert_answer(response_part):
        """Convert a raw answer string into a Python value.

        Returns None for the literal "None", the repaired-JSON value when the
        text parses, and the original string otherwise.
        (Fix: removed the redundant local ``import json`` -- json is imported
        at module level.)
        """
        response_part = response_part.strip()

        if response_part == "None":
            return None

        repaired = repair_json(response_part)
        if repaired == '""':
            # it was a literal string
            return response_part

        try:
            return json.loads(repaired)
        except json.JSONDecodeError:
            # last resort: hand back the raw text
            return response_part

    @classmethod
    def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
        """Parses the API response and returns the response text.

        The text after the final newline, when present, is treated as the
        model's comment; everything before it is the answer.
        """
        generated_token_string = cls.get_generated_token_string(raw_response)
        last_newline = generated_token_string.rfind("\n")

        if last_newline == -1:
            # There is no comment
            edsl_dict = {
                "answer": cls.convert_answer(generated_token_string),
                "generated_tokens": generated_token_string,
                "comment": None,
            }
        else:
            edsl_dict = {
                "answer": cls.convert_answer(generated_token_string[:last_newline]),
                "comment": generated_token_string[last_newline + 1 :].strip(),
                "generated_tokens": generated_token_string,
            }
        return EDSLOutput(**edsl_dict)

    async def _async_get_intended_model_call_outcome(
        self,
        user_prompt: str,
        system_prompt: str,
        cache: Cache,
        iteration: int = 0,
        files_list: Optional[List[FileStore]] = None,
        invigilator=None,
    ) -> ModelResponse:
        """Handle caching of responses.

        :param user_prompt: The user's prompt.
        :param system_prompt: The system's prompt.
        :param iteration: The iteration number.
        :param cache: The cache to use.
        :param files_list: Files whose hashes are folded into the cache key.
        :param invigilator: Accepted for interface compatibility; unused here.

        If the cache isn't being used, it just returns a 'fresh' call to the LLM.
        But if cache is being used, it first checks the database to see if the response is already there.
        If it is, it returns the cached response, but again appends some tracking information.
        If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.

        If self.use_cache is True, then attempts to retrieve the response from the database;
        if not in the DB, calls the LLM and writes the response to the DB.

        >>> from edsl import Cache
        >>> m = LanguageModel.example(test_model = True)
        >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
        ModelResponse(...)"""

        if files_list:
            # File contents change the effective prompt, so fold their hashes
            # into the cache key.
            files_hash = "+".join([str(hash(file)) for file in files_list])
            user_prompt_with_hashes = user_prompt + f" {files_hash}"
        else:
            user_prompt_with_hashes = user_prompt

        cache_call_params = {
            "model": str(self.model),
            "parameters": self.parameters,
            "system_prompt": system_prompt,
            "user_prompt": user_prompt_with_hashes,
            "iteration": iteration,
        }
        cached_response, cache_key = cache.fetch(**cache_call_params)

        if cache_used := cached_response is not None:
            response = json.loads(cached_response)
        else:
            f = (
                self.remote_async_execute_model_call
                if hasattr(self, "remote") and self.remote
                else self.async_execute_model_call
            )
            params = {
                "user_prompt": user_prompt,
                "system_prompt": system_prompt,
                "files_list": files_list,
            }
            response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
            new_cache_key = cache.store(
                **cache_call_params, response=response
            )  # store the response in the cache
            assert new_cache_key == cache_key  # should be the same

        cost = self.cost(response)

        return ModelResponse(
            response=response,
            cache_used=cache_used,
            cache_key=cache_key,
            cached_response=cached_response,
            cost=cost,
        )

    _get_intended_model_call_outcome = sync_wrapper(
        _async_get_intended_model_call_outcome
    )

    def simple_ask(
        self,
        question: QuestionBase,
        system_prompt="You are a helpful agent pretending to be a human.",
        top_logprobs=2,
    ):
        """Ask a question and return the response.

        Side effect: enables logprobs on this model instance before the call.
        """
        self.logprobs = True
        self.top_logprobs = top_logprobs
        return self.execute_model_call(
            user_prompt=question.human_readable(), system_prompt=system_prompt
        )

    async def async_get_response(
        self,
        user_prompt: str,
        system_prompt: str,
        cache: Cache,
        iteration: int = 1,
        files_list: Optional[List[FileStore]] = None,
        **kwargs,
    ) -> dict:
        """Get response, parse, and return as string.

        :param user_prompt: The user's prompt.
        :param system_prompt: The system's prompt.
        :param iteration: The iteration number.
        :param cache: The cache to use.
        :param files_list: Optional files to include with the call.
            (Docstring fix: the previous version documented a nonexistent
            ``encoded_image`` parameter.)
        :returns: An AgentResponseDict bundling the inputs, the (possibly
            cached) model outputs, and the parsed EDSL answer dict.
        """
        params = {
            "user_prompt": user_prompt,
            "system_prompt": system_prompt,
            "iteration": iteration,
            "cache": cache,
            "files_list": files_list,
        }
        # Forward the invigilator only when one was supplied by the caller.
        if "invigilator" in kwargs:
            params.update({"invigilator": kwargs["invigilator"]})

        model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
        model_outputs = await self._async_get_intended_model_call_outcome(**params)
        edsl_dict = self.parse_response(model_outputs.response)
        agent_response_dict = AgentResponseDict(
            model_inputs=model_inputs,
            model_outputs=model_outputs,
            edsl_dict=edsl_dict,
        )
        return agent_response_dict

    # Synchronous convenience wrapper around async_get_response.
    get_response = sync_wrapper(async_get_response)

    def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
        """Return the dollar cost of a raw response.

        :param raw_response: The raw response from the model.
        :returns: The cost in dollars, or an explanatory string when the
            price or token counts cannot be determined.
        """

        usage = self.get_usage_dict(raw_response)
        from edsl.coop import Coop

        c = Coop()
        price_lookup = c.fetch_prices()
        key = (self._inference_service_, self.model)
        if key not in price_lookup:
            return f"Could not find price for model {self.model} in the price lookup."

        relevant_prices = price_lookup[key]
        try:
            input_tokens = int(usage[self.input_token_name])
            output_tokens = int(usage[self.output_token_name])
        except Exception as e:
            return f"Could not fetch tokens from model response: {e}"

        try:
            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
        except Exception as e:
            if "output" not in relevant_prices:
                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
            if "input" not in relevant_prices:
                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
            return f"Could not fetch prices from {relevant_prices} - {e}"

        # "infinity" means tokens in that direction are free.
        if inverse_input_price == "infinity":
            input_cost = 0
        else:
            try:
                input_cost = input_tokens / float(inverse_input_price)
            except Exception as e:
                return f"Could not compute input price - {e}."

        if inverse_output_price == "infinity":
            output_cost = 0
        else:
            try:
                output_cost = output_tokens / float(inverse_output_price)
            except Exception as e:
                return f"Could not compute output price - {e}"

        return input_cost + output_cost

    def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
        """Convert instance to a dictionary

        :param add_edsl_version: Whether to add the EDSL version to the dictionary.

        >>> m = LanguageModel.example()
        >>> m.to_dict()
        {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
        """
        d = {"model": self.model, "parameters": self.parameters}
        if add_edsl_version:
            from edsl import __version__

            # Version/class metadata lets from_dict route to the right subclass.
            d["edsl_version"] = __version__
            d["edsl_class_name"] = self.__class__.__name__
        return d

    @classmethod
    @remove_edsl_version
    def from_dict(cls, data: dict) -> Type[LanguageModel]:
        """Convert dictionary to a LanguageModel child instance."""
        from edsl.language_models.registry import get_model_class

        model_class = get_model_class(data["model"])
        return model_class(**data)

    def __repr__(self) -> str:
        """Return a string representation of the object."""
        param_string = ", ".join(
            f"{key} = {value}" for key, value in self.parameters.items()
        )
        return (
            f"Model(model_name = '{self.model}'"
            + (f", {param_string}" if param_string else "")
            + ")"
        )

    def __add__(self, other_model: Type[LanguageModel]) -> Type[LanguageModel]:
        """Combine two models into a single model (other_model takes precedence over self)."""
        # Fixes: removed the redundant local `import warnings` (imported at
        # module level) and the needless f-prefix (no placeholders).
        warnings.warn(
            """Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
            by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
        )
        return other_model or self

    @classmethod
    def example(
        cls,
        test_model: bool = False,
        canned_response: str = "Hello world",
        throw_exception: bool = False,
    ) -> LanguageModel:
        """Return a default instance of the class.

        :param test_model: If True, return the built-in 'test' model.
        :param canned_response: The fixed reply the test model returns.
        :param throw_exception: If True, the test model raises when called.

        >>> from edsl.language_models import LanguageModel
        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
        >>> isinstance(m, LanguageModel)
        True
        >>> from edsl import QuestionFreeText
        >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
        >>> q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True).select('example').first()
        'WOWZA!'
        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!", throw_exception = True)
        >>> r = q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True, print_exceptions = True)
        Exception report saved to ...
        Also see: ...
        """
        from edsl import Model

        if test_model:
            m = Model(
                "test", canned_response=canned_response, throw_exception=throw_exception
            )
            return m
        else:
            # Non-test default model; skip the key check so construction
            # succeeds without credentials.
            return Model(skip_api_key_check=True)
662
+
663
+
664
if __name__ == "__main__":
    """Run the module's test suite."""
    import doctest

    # ELLIPSIS lets expected doctest output elide variable parts with '...'.
    doctest.testmod(optionflags=doctest.ELLIPSIS)