edsl 0.1.36.dev2__py3-none-any.whl → 0.1.36.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257)
  1. edsl/Base.py +303 -298
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +47 -47
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +804 -800
  7. edsl/agents/AgentList.py +337 -337
  8. edsl/agents/Invigilator.py +222 -222
  9. edsl/agents/InvigilatorBase.py +294 -294
  10. edsl/agents/PromptConstructor.py +312 -311
  11. edsl/agents/__init__.py +3 -3
  12. edsl/agents/descriptors.py +86 -86
  13. edsl/agents/prompt_helpers.py +129 -129
  14. edsl/auto/AutoStudy.py +117 -117
  15. edsl/auto/StageBase.py +230 -230
  16. edsl/auto/StageGenerateSurvey.py +178 -178
  17. edsl/auto/StageLabelQuestions.py +125 -125
  18. edsl/auto/StagePersona.py +61 -61
  19. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  20. edsl/auto/StagePersonaDimensionValues.py +74 -74
  21. edsl/auto/StagePersonaDimensions.py +69 -69
  22. edsl/auto/StageQuestions.py +73 -73
  23. edsl/auto/SurveyCreatorPipeline.py +21 -21
  24. edsl/auto/utilities.py +224 -224
  25. edsl/base/Base.py +289 -289
  26. edsl/config.py +149 -149
  27. edsl/conjure/AgentConstructionMixin.py +152 -152
  28. edsl/conjure/Conjure.py +62 -62
  29. edsl/conjure/InputData.py +659 -659
  30. edsl/conjure/InputDataCSV.py +48 -48
  31. edsl/conjure/InputDataMixinQuestionStats.py +182 -182
  32. edsl/conjure/InputDataPyRead.py +91 -91
  33. edsl/conjure/InputDataSPSS.py +8 -8
  34. edsl/conjure/InputDataStata.py +8 -8
  35. edsl/conjure/QuestionOptionMixin.py +76 -76
  36. edsl/conjure/QuestionTypeMixin.py +23 -23
  37. edsl/conjure/RawQuestion.py +65 -65
  38. edsl/conjure/SurveyResponses.py +7 -7
  39. edsl/conjure/__init__.py +9 -9
  40. edsl/conjure/naming_utilities.py +263 -263
  41. edsl/conjure/utilities.py +201 -201
  42. edsl/conversation/Conversation.py +238 -238
  43. edsl/conversation/car_buying.py +58 -58
  44. edsl/conversation/mug_negotiation.py +81 -81
  45. edsl/conversation/next_speaker_utilities.py +93 -93
  46. edsl/coop/PriceFetcher.py +54 -58
  47. edsl/coop/__init__.py +2 -2
  48. edsl/coop/coop.py +849 -815
  49. edsl/coop/utils.py +131 -131
  50. edsl/data/Cache.py +527 -527
  51. edsl/data/CacheEntry.py +228 -228
  52. edsl/data/CacheHandler.py +149 -149
  53. edsl/data/RemoteCacheSync.py +84 -0
  54. edsl/data/SQLiteDict.py +292 -292
  55. edsl/data/__init__.py +4 -4
  56. edsl/data/orm.py +10 -10
  57. edsl/data_transfer_models.py +73 -73
  58. edsl/enums.py +173 -173
  59. edsl/exceptions/__init__.py +50 -50
  60. edsl/exceptions/agents.py +40 -40
  61. edsl/exceptions/configuration.py +16 -16
  62. edsl/exceptions/coop.py +10 -2
  63. edsl/exceptions/data.py +14 -14
  64. edsl/exceptions/general.py +34 -34
  65. edsl/exceptions/jobs.py +33 -33
  66. edsl/exceptions/language_models.py +63 -63
  67. edsl/exceptions/prompts.py +15 -15
  68. edsl/exceptions/questions.py +91 -91
  69. edsl/exceptions/results.py +26 -26
  70. edsl/exceptions/surveys.py +34 -34
  71. edsl/inference_services/AnthropicService.py +87 -87
  72. edsl/inference_services/AwsBedrock.py +115 -115
  73. edsl/inference_services/AzureAI.py +217 -217
  74. edsl/inference_services/DeepInfraService.py +18 -18
  75. edsl/inference_services/GoogleService.py +156 -156
  76. edsl/inference_services/GroqService.py +20 -20
  77. edsl/inference_services/InferenceServiceABC.py +147 -119
  78. edsl/inference_services/InferenceServicesCollection.py +72 -68
  79. edsl/inference_services/MistralAIService.py +123 -123
  80. edsl/inference_services/OllamaService.py +18 -18
  81. edsl/inference_services/OpenAIService.py +224 -224
  82. edsl/inference_services/TestService.py +89 -89
  83. edsl/inference_services/TogetherAIService.py +170 -170
  84. edsl/inference_services/models_available_cache.py +118 -94
  85. edsl/inference_services/rate_limits_cache.py +25 -25
  86. edsl/inference_services/registry.py +39 -39
  87. edsl/inference_services/write_available.py +10 -10
  88. edsl/jobs/Answers.py +56 -56
  89. edsl/jobs/Jobs.py +1112 -1089
  90. edsl/jobs/__init__.py +1 -1
  91. edsl/jobs/buckets/BucketCollection.py +63 -63
  92. edsl/jobs/buckets/ModelBuckets.py +65 -65
  93. edsl/jobs/buckets/TokenBucket.py +248 -248
  94. edsl/jobs/interviews/Interview.py +651 -633
  95. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -90
  96. edsl/jobs/interviews/InterviewExceptionEntry.py +182 -164
  97. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  98. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  99. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  100. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  101. edsl/jobs/interviews/ReportErrors.py +66 -66
  102. edsl/jobs/interviews/interview_status_enum.py +9 -9
  103. edsl/jobs/runners/JobsRunnerAsyncio.py +337 -343
  104. edsl/jobs/runners/JobsRunnerStatus.py +332 -332
  105. edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
  106. edsl/jobs/tasks/TaskCreators.py +64 -64
  107. edsl/jobs/tasks/TaskHistory.py +441 -425
  108. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  109. edsl/jobs/tasks/task_status_enum.py +163 -163
  110. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  111. edsl/jobs/tokens/TokenUsage.py +34 -34
  112. edsl/language_models/LanguageModel.py +718 -718
  113. edsl/language_models/ModelList.py +102 -102
  114. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  115. edsl/language_models/__init__.py +2 -2
  116. edsl/language_models/fake_openai_call.py +15 -15
  117. edsl/language_models/fake_openai_service.py +61 -61
  118. edsl/language_models/registry.py +137 -137
  119. edsl/language_models/repair.py +156 -156
  120. edsl/language_models/unused/ReplicateBase.py +83 -83
  121. edsl/language_models/utilities.py +64 -64
  122. edsl/notebooks/Notebook.py +259 -259
  123. edsl/notebooks/__init__.py +1 -1
  124. edsl/prompts/Prompt.py +358 -358
  125. edsl/prompts/__init__.py +2 -2
  126. edsl/questions/AnswerValidatorMixin.py +289 -289
  127. edsl/questions/QuestionBase.py +616 -616
  128. edsl/questions/QuestionBaseGenMixin.py +161 -161
  129. edsl/questions/QuestionBasePromptsMixin.py +266 -266
  130. edsl/questions/QuestionBudget.py +227 -227
  131. edsl/questions/QuestionCheckBox.py +359 -359
  132. edsl/questions/QuestionExtract.py +183 -183
  133. edsl/questions/QuestionFreeText.py +113 -113
  134. edsl/questions/QuestionFunctional.py +159 -155
  135. edsl/questions/QuestionList.py +231 -231
  136. edsl/questions/QuestionMultipleChoice.py +286 -286
  137. edsl/questions/QuestionNumerical.py +153 -153
  138. edsl/questions/QuestionRank.py +324 -324
  139. edsl/questions/Quick.py +41 -41
  140. edsl/questions/RegisterQuestionsMeta.py +71 -71
  141. edsl/questions/ResponseValidatorABC.py +174 -174
  142. edsl/questions/SimpleAskMixin.py +73 -73
  143. edsl/questions/__init__.py +26 -26
  144. edsl/questions/compose_questions.py +98 -98
  145. edsl/questions/decorators.py +21 -21
  146. edsl/questions/derived/QuestionLikertFive.py +76 -76
  147. edsl/questions/derived/QuestionLinearScale.py +87 -87
  148. edsl/questions/derived/QuestionTopK.py +91 -91
  149. edsl/questions/derived/QuestionYesNo.py +82 -82
  150. edsl/questions/descriptors.py +418 -418
  151. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  152. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  153. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  154. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  155. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  156. edsl/questions/prompt_templates/question_list.jinja +17 -17
  157. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  158. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  159. edsl/questions/question_registry.py +147 -147
  160. edsl/questions/settings.py +12 -12
  161. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  162. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  163. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  164. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  165. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  166. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  167. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  168. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  169. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  170. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  171. edsl/questions/templates/list/question_presentation.jinja +5 -5
  172. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  173. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  174. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  175. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  176. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  177. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  178. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  179. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  180. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  181. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  182. edsl/results/Dataset.py +293 -281
  183. edsl/results/DatasetExportMixin.py +693 -693
  184. edsl/results/DatasetTree.py +145 -145
  185. edsl/results/Result.py +433 -431
  186. edsl/results/Results.py +1158 -1146
  187. edsl/results/ResultsDBMixin.py +238 -238
  188. edsl/results/ResultsExportMixin.py +43 -43
  189. edsl/results/ResultsFetchMixin.py +33 -33
  190. edsl/results/ResultsGGMixin.py +121 -121
  191. edsl/results/ResultsToolsMixin.py +98 -98
  192. edsl/results/Selector.py +118 -118
  193. edsl/results/__init__.py +2 -2
  194. edsl/results/tree_explore.py +115 -115
  195. edsl/scenarios/FileStore.py +443 -443
  196. edsl/scenarios/Scenario.py +507 -496
  197. edsl/scenarios/ScenarioHtmlMixin.py +59 -59
  198. edsl/scenarios/ScenarioList.py +1101 -1101
  199. edsl/scenarios/ScenarioListExportMixin.py +52 -52
  200. edsl/scenarios/ScenarioListPdfMixin.py +261 -261
  201. edsl/scenarios/__init__.py +2 -2
  202. edsl/shared.py +1 -1
  203. edsl/study/ObjectEntry.py +173 -173
  204. edsl/study/ProofOfWork.py +113 -113
  205. edsl/study/SnapShot.py +80 -80
  206. edsl/study/Study.py +528 -528
  207. edsl/study/__init__.py +4 -4
  208. edsl/surveys/DAG.py +148 -148
  209. edsl/surveys/Memory.py +31 -31
  210. edsl/surveys/MemoryPlan.py +244 -244
  211. edsl/surveys/Rule.py +324 -324
  212. edsl/surveys/RuleCollection.py +387 -387
  213. edsl/surveys/Survey.py +1772 -1769
  214. edsl/surveys/SurveyCSS.py +261 -261
  215. edsl/surveys/SurveyExportMixin.py +259 -259
  216. edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
  217. edsl/surveys/SurveyQualtricsImport.py +284 -284
  218. edsl/surveys/__init__.py +3 -3
  219. edsl/surveys/base.py +53 -53
  220. edsl/surveys/descriptors.py +56 -56
  221. edsl/surveys/instructions/ChangeInstruction.py +47 -47
  222. edsl/surveys/instructions/Instruction.py +51 -34
  223. edsl/surveys/instructions/InstructionCollection.py +77 -77
  224. edsl/templates/error_reporting/base.html +23 -23
  225. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  226. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  227. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  228. edsl/templates/error_reporting/interview_details.html +115 -115
  229. edsl/templates/error_reporting/interviews.html +9 -9
  230. edsl/templates/error_reporting/overview.html +4 -4
  231. edsl/templates/error_reporting/performance_plot.html +1 -1
  232. edsl/templates/error_reporting/report.css +73 -73
  233. edsl/templates/error_reporting/report.html +117 -117
  234. edsl/templates/error_reporting/report.js +25 -25
  235. edsl/tools/__init__.py +1 -1
  236. edsl/tools/clusters.py +192 -192
  237. edsl/tools/embeddings.py +27 -27
  238. edsl/tools/embeddings_plotting.py +118 -118
  239. edsl/tools/plotting.py +112 -112
  240. edsl/tools/summarize.py +18 -18
  241. edsl/utilities/SystemInfo.py +28 -28
  242. edsl/utilities/__init__.py +22 -22
  243. edsl/utilities/ast_utilities.py +25 -25
  244. edsl/utilities/data/Registry.py +6 -6
  245. edsl/utilities/data/__init__.py +1 -1
  246. edsl/utilities/data/scooter_results.json +1 -1
  247. edsl/utilities/decorators.py +77 -77
  248. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  249. edsl/utilities/interface.py +627 -627
  250. edsl/utilities/repair_functions.py +28 -28
  251. edsl/utilities/restricted_python.py +70 -70
  252. edsl/utilities/utilities.py +391 -391
  253. {edsl-0.1.36.dev2.dist-info → edsl-0.1.36.dev6.dist-info}/LICENSE +21 -21
  254. {edsl-0.1.36.dev2.dist-info → edsl-0.1.36.dev6.dist-info}/METADATA +1 -1
  255. edsl-0.1.36.dev6.dist-info/RECORD +279 -0
  256. edsl-0.1.36.dev2.dist-info/RECORD +0 -278
  257. {edsl-0.1.36.dev2.dist-info → edsl-0.1.36.dev6.dist-info}/WHEEL +0 -0
--- a/edsl/language_models/LanguageModel.py
+++ b/edsl/language_models/LanguageModel.py
@@ -1,718 +1,718 @@
-"""This module contains the LanguageModel class, which is an abstract base class for all language models.
-
-Terminology:
-
-raw_response: The JSON response from the model. This has all the model meta-data about the call.
-
-edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
-such as the cache key, token usage, etc.
-
-generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
-edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
-
-"""
-
-from __future__ import annotations
-import warnings
-from functools import wraps
-import asyncio
-import json
-import time
-import os
-import hashlib
-from typing import (
-    Coroutine,
-    Any,
-    Callable,
-    Type,
-    Union,
-    List,
-    get_type_hints,
-    TypedDict,
-    Optional,
-)
-from abc import ABC, abstractmethod
-
-from json_repair import repair_json
-
-from edsl.data_transfer_models import (
-    ModelResponse,
-    ModelInputs,
-    EDSLOutput,
-    AgentResponseDict,
-)
-
-
-from edsl.config import CONFIG
-from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
-from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
-from edsl.language_models.repair import repair
-from edsl.enums import InferenceServiceType
-from edsl.Base import RichPrintingMixin, PersistenceMixin
-from edsl.enums import service_to_api_keyname
-from edsl.exceptions import MissingAPIKeyError
-from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
-from edsl.exceptions.language_models import LanguageModelBadResponseError
-
-TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
-
-
-def convert_answer(response_part):
-    import json
-
-    response_part = response_part.strip()
-
-    if response_part == "None":
-        return None
-
-    repaired = repair_json(response_part)
-    if repaired == '""':
-        # it was a literal string
-        return response_part
-
-    try:
-        return json.loads(repaired)
-    except json.JSONDecodeError as j:
-        # last resort
-        return response_part
-
-
-def extract_item_from_raw_response(data, key_sequence):
-    if isinstance(data, str):
-        try:
-            data = json.loads(data)
-        except json.JSONDecodeError as e:
-            return data
-    current_data = data
-    for i, key in enumerate(key_sequence):
-        try:
-            if isinstance(current_data, (list, tuple)):
-                if not isinstance(key, int):
-                    raise TypeError(
-                        f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
-                    )
-                if key < 0 or key >= len(current_data):
-                    raise IndexError(
-                        f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
-                    )
-            elif isinstance(current_data, dict):
-                if key not in current_data:
-                    raise KeyError(
-                        f"Key '{key}' not found in dictionary at position {i}"
-                    )
-            else:
-                raise TypeError(
-                    f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
-                )
-
-            current_data = current_data[key]
-        except Exception as e:
-            path = " -> ".join(map(str, key_sequence[: i + 1]))
-            if "error" in data:
-                msg = data["error"]
-            else:
-                msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
-            raise LanguageModelBadResponseError(message=msg, response_json=data)
-    if isinstance(current_data, str):
-        return current_data.strip()
-    else:
-        return current_data
-
-
-def handle_key_error(func):
-    """Handle KeyError exceptions."""
-
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        try:
-            return func(*args, **kwargs)
-            assert True == False
-        except KeyError as e:
-            return f"""KeyError occurred: {e}. This is most likely because the model you are using
-            returned a JSON object we were not expecting."""
-
-    return wrapper
-
-
-class LanguageModel(
-    RichPrintingMixin, PersistenceMixin, ABC, metaclass=RegisterLanguageModelsMeta
-):
-    """ABC for LLM subclasses.
-
-    TODO:
-
-    1) Need better, more descriptive names for functions
-
-    get_model_response_no_cache (currently called async_execute_model_call)
-
-    get_model_response (currently called async_get_raw_response; uses cache & adds tracking info)
-    Calls:
-        - async_execute_model_call
-        - _updated_model_response_with_tracking
-
-    get_answer (currently called async_get_response)
-    This parses out the answer block and does some error-handling.
-    Calls:
-        - async_get_raw_response
-        - parse_response
-
-
-    """
-
-    _model_ = None
-    key_sequence = (
-        None  # This should be something like ["choices", 0, "message", "content"]
-    )
-    __rate_limits = None
-    _safety_factor = 0.8
-
-    def __init__(
-        self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
-    ):
-        """Initialize the LanguageModel."""
-        self.model = getattr(self, "_model_", None)
-        default_parameters = getattr(self, "_parameters_", None)
-        parameters = self._overide_default_parameters(kwargs, default_parameters)
-        self.parameters = parameters
-        self.remote = False
-        self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
-
-        # self._rpm / _tpm comes from the class
-        if rpm is not None:
-            self._rpm = rpm
-
-        if tpm is not None:
-            self._tpm = tpm
-
-        for key, value in parameters.items():
-            setattr(self, key, value)
-
-        for key, value in kwargs.items():
-            if key not in parameters:
-                setattr(self, key, value)
-
-        if "use_cache" in kwargs:
-            warnings.warn(
-                "The use_cache parameter is deprecated. Use the Cache class instead."
-            )
-
-        if skip_api_key_check := kwargs.get("skip_api_key_check", False):
-            # Skip the API key check. Sometimes this is useful for testing.
-            self._api_token = None
-
-    def ask_question(self, question):
-        user_prompt = question.get_instructions().render(question.data).text
-        system_prompt = "You are a helpful agent pretending to be a human."
-        return self.execute_model_call(user_prompt, system_prompt)
-
-    @property
-    def api_token(self) -> str:
-        if not hasattr(self, "_api_token"):
-            key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
-            if self._inference_service_ == "bedrock":
-                self._api_token = [os.getenv(key_name[0]), os.getenv(key_name[1])]
-                # Check if any of the tokens are None
-                missing_token = any(token is None for token in self._api_token)
-            else:
-                self._api_token = os.getenv(key_name)
-                missing_token = self._api_token is None
-            if missing_token and self._inference_service_ != "test" and not self.remote:
-                print("raising error")
-                raise MissingAPIKeyError(
-                    f"""The key for service: `{self._inference_service_}` is not set.
-                    Need a key with name {key_name} in your .env file."""
-                )
-
-        return self._api_token
-
-    def __getitem__(self, key):
-        return getattr(self, key)
-
-    def _repr_html_(self):
-        from edsl.utilities.utilities import data_to_html
-
-        return data_to_html(self.to_dict())
-
-    def hello(self, verbose=False):
-        """Runs a simple test to check if the model is working."""
-        token = self.api_token
-        masked = token[: min(8, len(token))] + "..."
-        if verbose:
-            print(f"Current key is {masked}")
-        return self.execute_model_call(
-            user_prompt="Hello, model!", system_prompt="You are a helpful agent."
-        )
-
-    def has_valid_api_key(self) -> bool:
-        """Check if the model has a valid API key.
-
-        >>> LanguageModel.example().has_valid_api_key() : # doctest: +SKIP
-        True
-
-        This method is used to check if the model has a valid API key.
-        """
-        from edsl.enums import service_to_api_keyname
-        import os
-
-        if self._model_ == "test":
-            return True
-
-        key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
-        key_value = os.getenv(key_name)
-        return key_value is not None
-
-    def __hash__(self) -> str:
-        """Allow the model to be used as a key in a dictionary."""
-        from edsl.utilities.utilities import dict_hash
-
-        return dict_hash(self.to_dict())
-
-    def __eq__(self, other):
-        """Check is two models are the same.
-
-        >>> m1 = LanguageModel.example()
-        >>> m2 = LanguageModel.example()
-        >>> m1 == m2
-        True
-
-        """
-        return self.model == other.model and self.parameters == other.parameters
-
-    def set_rate_limits(self, rpm=None, tpm=None) -> None:
-        """Set the rate limits for the model.
-
-        >>> m = LanguageModel.example()
-        >>> m.set_rate_limits(rpm=100, tpm=1000)
-        >>> m.RPM
-        100
-        """
-        if rpm is not None:
-            self._rpm = rpm
-        if tpm is not None:
-            self._tpm = tpm
-        return None
-        # self._set_rate_limits(rpm=rpm, tpm=tpm)
-
-    # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
-    #     """Set the rate limits for the model.
-
-    #     If the model does not have rate limits, use the default rate limits."""
-    #     if rpm is not None and tpm is not None:
-    #         self.__rate_limits = {"rpm": rpm, "tpm": tpm}
-    #         return
-
-    #     if self.__rate_limits is None:
-    #         if hasattr(self, "get_rate_limits"):
-    #             self.__rate_limits = self.get_rate_limits()
-    #         else:
-    #             self.__rate_limits = self.__default_rate_limits
-
-    @property
-    def RPM(self):
-        """Model's requests-per-minute limit."""
-        # self._set_rate_limits()
-        # return self._safety_factor * self.__rate_limits["rpm"]
-        return self._rpm
-
-    @property
-    def TPM(self):
-        """Model's tokens-per-minute limit."""
-        # self._set_rate_limits()
-        # return self._safety_factor * self.__rate_limits["tpm"]
-        return self._tpm
-
-    @property
-    def rpm(self):
-        return self._rpm
-
-    @rpm.setter
-    def rpm(self, value):
-        self._rpm = value
-
-    @property
-    def tpm(self):
-        return self._tpm
-
-    @tpm.setter
-    def tpm(self, value):
-        self._tpm = value
-
-    @staticmethod
-    def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
-        """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
-
-        >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9})
-        {'temperature': 0.5}
-        >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
-        {'temperature': 0.5, 'max_tokens': 1000}
-        """
-        # parameters = dict({})
-
-        # this is the case when data is loaded from a dict after serialization
-        if "parameters" in passed_parameter_dict:
-            passed_parameter_dict = passed_parameter_dict["parameters"]
-        return {
-            parameter_name: passed_parameter_dict.get(parameter_name, default_value)
-            for parameter_name, default_value in default_parameter_dict.items()
-        }
-
-    def __call__(self, user_prompt: str, system_prompt: str):
-        return self.execute_model_call(user_prompt, system_prompt)
-
-    @abstractmethod
-    async def async_execute_model_call(user_prompt: str, system_prompt: str):
-        """Execute the model call and returns a coroutine.
-
-        >>> m = LanguageModel.example(test_model = True)
-        >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
-        >>> asyncio.run(test())
-        {'message': [{'text': 'Hello world'}], ...}
-
-        >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
-        {'message': [{'text': 'Hello world'}], ...}
-        """
-        pass
-
-    async def remote_async_execute_model_call(
-        self, user_prompt: str, system_prompt: str
-    ):
-        """Execute the model call and returns the result as a coroutine, using Coop."""
-        from edsl.coop import Coop
-
-        client = Coop()
-        response_data = await client.remote_async_execute_model_call(
-            self.to_dict(), user_prompt, system_prompt
-        )
-        return response_data
-
-    @jupyter_nb_handler
-    def execute_model_call(self, *args, **kwargs) -> Coroutine:
-        """Execute the model call and returns the result as a coroutine.
-
-        >>> m = LanguageModel.example(test_model = True)
-        >>> m.execute_model_call(user_prompt = "Hello, model!", system_prompt = "You are a helpful agent.")
-
-        """
-
-        async def main():
-            results = await asyncio.gather(
-                self.async_execute_model_call(*args, **kwargs)
-            )
-            return results[0]  # Since there's only one task, return its result
-
-        return main()
-
-    @classmethod
-    def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
-        """Return the generated token string from the raw response."""
-        return extract_item_from_raw_response(raw_response, cls.key_sequence)
-
-    @classmethod
-    def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
-        """Return the usage dictionary from the raw response."""
-        if not hasattr(cls, "usage_sequence"):
-            raise NotImplementedError(
-                "This inference service does not have a usage_sequence."
-            )
-        return extract_item_from_raw_response(raw_response, cls.usage_sequence)
-
-    @classmethod
-    def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
-        """Parses the API response and returns the response text."""
-        generated_token_string = cls.get_generated_token_string(raw_response)
-        last_newline = generated_token_string.rfind("\n")
-
-        if last_newline == -1:
-            # There is no comment
-            edsl_dict = {
-                "answer": convert_answer(generated_token_string),
-                "generated_tokens": generated_token_string,
-                "comment": None,
-            }
-        else:
-            edsl_dict = {
-                "answer": convert_answer(generated_token_string[:last_newline]),
-                "comment": generated_token_string[last_newline + 1 :].strip(),
-                "generated_tokens": generated_token_string,
-            }
-        return EDSLOutput(**edsl_dict)
-
-    async def _async_get_intended_model_call_outcome(
-        self,
-        user_prompt: str,
-        system_prompt: str,
-        cache: "Cache",
-        iteration: int = 0,
-        files_list=None,
-    ) -> ModelResponse:
-        """Handle caching of responses.
-
-        :param user_prompt: The user's prompt.
-        :param system_prompt: The system's prompt.
-        :param iteration: The iteration number.
-        :param cache: The cache to use.
-
-        If the cache isn't being used, it just returns a 'fresh' call to the LLM.
-        But if cache is being used, it first checks the database to see if the response is already there.
-        If it is, it returns the cached response, but again appends some tracking information.
-        If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
-
-        If self.use_cache is True, then attempts to retrieve the response from the database;
-        if not in the DB, calls the LLM and writes the response to the DB.
-
-        >>> from edsl import Cache
-        >>> m = LanguageModel.example(test_model = True)
-        >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
-        ModelResponse(...)"""
-
-        if files_list:
-            files_hash = "+".join([str(hash(file)) for file in files_list])
-            # print(f"Files hash: {files_hash}")
-            user_prompt_with_hashes = user_prompt + f" {files_hash}"
-        else:
-            user_prompt_with_hashes = user_prompt
-
-        cache_call_params = {
-            "model": str(self.model),
-            "parameters": self.parameters,
-            "system_prompt": system_prompt,
-            "user_prompt": user_prompt_with_hashes,
-            "iteration": iteration,
-        }
-        cached_response, cache_key = cache.fetch(**cache_call_params)
-
-        if cache_used := cached_response is not None:
-            response = json.loads(cached_response)
-        else:
-            f = (
-                self.remote_async_execute_model_call
-                if hasattr(self, "remote") and self.remote
-                else self.async_execute_model_call
-            )
-            params = {
-                "user_prompt": user_prompt,
-                "system_prompt": system_prompt,
-                "files_list": files_list
-                # **({"encoded_image": encoded_image} if encoded_image else {}),
-            }
-            # response = await f(**params)
-            response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
-            new_cache_key = cache.store(
-                **cache_call_params, response=response
-            )  # store the response in the cache
-            assert new_cache_key == cache_key  # should be the same
-
-        cost = self.cost(response)
-
-        return ModelResponse(
-            response=response,
-            cache_used=cache_used,
-            cache_key=cache_key,
-            cached_response=cached_response,
-            cost=cost,
-        )
-
-    _get_intended_model_call_outcome = sync_wrapper(
-        _async_get_intended_model_call_outcome
-    )
-
-    # get_raw_response = sync_wrapper(async_get_raw_response)
-
-    def simple_ask(
-        self,
-        question: "QuestionBase",
-        system_prompt="You are a helpful agent pretending to be a human.",
-        top_logprobs=2,
-    ):
-        """Ask a question and return the response."""
-        self.logprobs = True
-        self.top_logprobs = top_logprobs
-        return self.execute_model_call(
-            user_prompt=question.human_readable(), system_prompt=system_prompt
-        )
-
-    async def async_get_response(
-        self,
-        user_prompt: str,
-        system_prompt: str,
-        cache: "Cache",
-        iteration: int = 1,
-        files_list: Optional[List["File"]] = None,
-    ) -> dict:
-        """Get response, parse, and return as string.
-
-        :param user_prompt: The user's prompt.
-        :param system_prompt: The system's prompt.
-        :param iteration: The iteration number.
-        :param cache: The cache to use.
-        :param encoded_image: The encoded image to use.
-
-        """
-        params = {
-            "user_prompt": user_prompt,
-            "system_prompt": system_prompt,
-            "iteration": iteration,
-            "cache": cache,
-            "files_list": files_list,
-        }
-        model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
-        model_outputs = await self._async_get_intended_model_call_outcome(**params)
-        edsl_dict = self.parse_response(model_outputs.response)
-        agent_response_dict = AgentResponseDict(
-            model_inputs=model_inputs,
-            model_outputs=model_outputs,
-            edsl_dict=edsl_dict,
-        )
-        return agent_response_dict
-
-        # return await self._async_prepare_response(model_call_outcome, cache=cache)
-
-    get_response = sync_wrapper(async_get_response)
-
-    def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
-        """Return the dollar cost of a raw response."""
-
-        usage = self.get_usage_dict(raw_response)
-        from edsl.coop import Coop
-
-        c = Coop()
-        price_lookup = c.fetch_prices()
-        key = (self._inference_service_, self.model)
-        if key not in price_lookup:
-            return f"Could not find price for model {self.model} in the price lookup."
-
-        relevant_prices = price_lookup[key]
-        try:
-            input_tokens = int(usage[self.input_token_name])
-            output_tokens = int(usage[self.output_token_name])
-        except Exception as e:
-            return f"Could not fetch tokens from model response: {e}"
-
-        try:
-            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
-            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
-        except Exception as e:
-            if "output" not in relevant_prices:
-                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
-            if "input" not in relevant_prices:
-                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
-            return f"Could not fetch prices from {relevant_prices} - {e}"
-
-        if inverse_input_price == "infinity":
-            input_cost = 0
-        else:
-            try:
-                input_cost = input_tokens / float(inverse_input_price)
-            except Exception as e:
-                return f"Could not compute input price - {e}."
-
-        if inverse_output_price == "infinity":
-            output_cost = 0
-        else:
-            try:
-                output_cost = output_tokens / float(inverse_output_price)
-            except Exception as e:
-                return f"Could not compute output price - {e}"
-
-        return input_cost + output_cost
-
-    #######################
-    # SERIALIZATION METHODS
-    #######################
-    def _to_dict(self) -> dict[str, Any]:
-        return {"model": self.model, "parameters": self.parameters}
-
-    @add_edsl_version
-    def to_dict(self) -> dict[str, Any]:
-        """Convert instance to a dictionary.
-
-        >>> m = LanguageModel.example()
-        >>> m.to_dict()
-        {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
-        """
-        return self._to_dict()
-
-    @classmethod
-    @remove_edsl_version
-    def from_dict(cls, data: dict) -> Type[LanguageModel]:
-        """Convert dictionary to a LanguageModel child instance."""
-        from edsl.language_models.registry import get_model_class
-
-        model_class = get_model_class(data["model"])
-        # data["use_cache"] = True
-        return model_class(**data)
-
-    #######################
-    # DUNDER METHODS
-    #######################
-    def print(self):
-        from rich import print_json
-        import json
-
-        print_json(json.dumps(self.to_dict()))
-
-    def __repr__(self) -> str:
-        """Return a string representation of the object."""
-        param_string = ", ".join(
-            f"{key} = {value}" for key, value in self.parameters.items()
-        )
-        return (
-            f"Model(model_name = '{self.model}'"
-            + (f", {param_string}" if param_string else "")
-            + ")"
-        )
-
-    def __add__(self, other_model: Type[LanguageModel]) -> Type[LanguageModel]:
-        """Combine two models into a single model (other_model takes precedence over self)."""
-        print(
-            f"""Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
-            by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
-        )
-        return other_model or self
-
-    def rich_print(self):
-        """Display an object as a table."""
-        from rich.table import Table
-
-        table = Table(title="Language Model")
-        table.add_column("Attribute", style="bold")
-        table.add_column("Value")
-
-        to_display = self.__dict__.copy()
-        for attr_name, attr_value in to_display.items():
-            table.add_row(attr_name, repr(attr_value))
-
-        return table
-
-    @classmethod
-    def example(
-        cls,
-        test_model: bool = False,
-        canned_response: str = "Hello world",
-        throw_exception: bool = False,
-    ):
-        """Return a default instance of the class.
-
-        >>> from edsl.language_models import LanguageModel
-        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
-        >>> isinstance(m, LanguageModel)
-        True
-        >>> from edsl import QuestionFreeText
-        >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
-        >>> q.by(m).run(cache = False).select('example').first()
-        'WOWZA!'
-        """
-        from edsl import Model
-
-        if test_model:
-            m = Model("test", canned_response=canned_response)
-            return m
-        else:
-            return Model(skip_api_key_check=True)
-
-
-if __name__ == "__main__":
-    """Run the module's test suite."""
-    import doctest
-
-    doctest.testmod(optionflags=doctest.ELLIPSIS)
+"""This module contains the LanguageModel class, which is an abstract base class for all language models.
+
+Terminology:
+
+raw_response: The JSON response from the model. This has all the model meta-data about the call.
+
+edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
+such as the cache key, token usage, etc.
+
+generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
+edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
+
+"""
+
+from __future__ import annotations
+import warnings
+from functools import wraps
+import asyncio
+import json
+import time
+import os
+import hashlib
+from typing import (
+    Coroutine,
+    Any,
+    Callable,
+    Type,
+    Union,
+    List,
+    get_type_hints,
+    TypedDict,
+    Optional,
+)
+from abc import ABC, abstractmethod
+
+from json_repair import repair_json
+
+from edsl.data_transfer_models import (
+    ModelResponse,
+    ModelInputs,
+    EDSLOutput,
+    AgentResponseDict,
+)
+
+
+from edsl.config import CONFIG
+from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
+from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
+from edsl.language_models.repair import repair
+from edsl.enums import InferenceServiceType
+from edsl.Base import RichPrintingMixin, PersistenceMixin
+from edsl.enums import service_to_api_keyname
+from edsl.exceptions import MissingAPIKeyError
+from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
+from edsl.exceptions.language_models import LanguageModelBadResponseError
+
+TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+
+def convert_answer(response_part):
+    import json
+
+    response_part = response_part.strip()
+
+    if response_part == "None":
+        return None
+
+    repaired = repair_json(response_part)
+    if repaired == '""':
+        # it was a literal string
+        return response_part
+
+    try:
+        return json.loads(repaired)
+    except json.JSONDecodeError as j:
+        # last resort
+        return response_part
+
+
+def extract_item_from_raw_response(data, key_sequence):
+    if isinstance(data, str):
+        try:
+            data = json.loads(data)
+        except json.JSONDecodeError as e:
+            return data
+    current_data = data
+    for i, key in enumerate(key_sequence):
+        try:
+            if isinstance(current_data, (list, tuple)):
+                if not isinstance(key, int):
+                    raise TypeError(
+                        f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
+                    )
+                if key < 0 or key >= len(current_data):
+                    raise IndexError(
+                        f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
+                    )
+            elif isinstance(current_data, dict):
+                if key not in current_data:
+                    raise KeyError(
+                        f"Key '{key}' not found in dictionary at position {i}"
+                    )
+            else:
+                raise TypeError(
+                    f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
+                )
+
+            current_data = current_data[key]
+        except Exception as e:
+            path = " -> ".join(map(str, key_sequence[: i + 1]))
+            if "error" in data:
+                msg = data["error"]
+            else:
+                msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
+            raise LanguageModelBadResponseError(message=msg, response_json=data)
+    if isinstance(current_data, str):
+        return current_data.strip()
+    else:
+        return current_data
+
+
+def handle_key_error(func):
+    """Handle KeyError exceptions."""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+            assert True == False
+        except KeyError as e:
+            return f"""KeyError occurred: {e}. This is most likely because the model you are using
+            returned a JSON object we were not expecting."""
+
+    return wrapper
+
+
+class LanguageModel(
+    RichPrintingMixin, PersistenceMixin, ABC, metaclass=RegisterLanguageModelsMeta
+):
+    """ABC for LLM subclasses.
+
+    TODO:
+
+    1) Need better, more descriptive names for functions
+
+    get_model_response_no_cache (currently called async_execute_model_call)
+
+    get_model_response (currently called async_get_raw_response; uses cache & adds tracking info)
+    Calls:
+        - async_execute_model_call
+        - _updated_model_response_with_tracking
+
+    get_answer (currently called async_get_response)
+    This parses out the answer block and does some error-handling.
+    Calls:
+        - async_get_raw_response
+        - parse_response
+
+
+    """
+
+    _model_ = None
+    key_sequence = (
+        None  # This should be something like ["choices", 0, "message", "content"]
+    )
+    __rate_limits = None
+    _safety_factor = 0.8
+
+    def __init__(
+        self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
+    ):
+        """Initialize the LanguageModel."""
+        self.model = getattr(self, "_model_", None)
+        default_parameters = getattr(self, "_parameters_", None)
+        parameters = self._overide_default_parameters(kwargs, default_parameters)
+        self.parameters = parameters
+        self.remote = False
+        self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
+
+        # self._rpm / _tpm comes from the class
+        if rpm is not None:
+            self._rpm = rpm
+
+        if tpm is not None:
+            self._tpm = tpm
+
+        for key, value in parameters.items():
+            setattr(self, key, value)
+
+        for key, value in kwargs.items():
+            if key not in parameters:
+                setattr(self, key, value)
+
+        if "use_cache" in kwargs:
+            warnings.warn(
+                "The use_cache parameter is deprecated. Use the Cache class instead."
+            )
+
+        if skip_api_key_check := kwargs.get("skip_api_key_check", False):
+            # Skip the API key check. Sometimes this is useful for testing.
+            self._api_token = None
+
+    def ask_question(self, question):
+        user_prompt = question.get_instructions().render(question.data).text
+        system_prompt = "You are a helpful agent pretending to be a human."
+        return self.execute_model_call(user_prompt, system_prompt)
+
+    @property
+    def api_token(self) -> str:
+        if not hasattr(self, "_api_token"):
+            key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
+            if self._inference_service_ == "bedrock":
+                self._api_token = [os.getenv(key_name[0]), os.getenv(key_name[1])]
+                # Check if any of the tokens are None
+                missing_token = any(token is None for token in self._api_token)
+            else:
+                self._api_token = os.getenv(key_name)
+                missing_token = self._api_token is None
+            if missing_token and self._inference_service_ != "test" and not self.remote:
+                print("raising error")
+                raise MissingAPIKeyError(
+                    f"""The key for service: `{self._inference_service_}` is not set.
+                    Need a key with name {key_name} in your .env file."""
+                )
+
+        return self._api_token
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def _repr_html_(self):
+        from edsl.utilities.utilities import data_to_html
+
+        return data_to_html(self.to_dict())
+
+    def hello(self, verbose=False):
+        """Runs a simple test to check if the model is working."""
+        token = self.api_token
+        masked = token[: min(8, len(token))] + "..."
+        if verbose:
+            print(f"Current key is {masked}")
+        return self.execute_model_call(
+            user_prompt="Hello, model!", system_prompt="You are a helpful agent."
+        )
+
+    def has_valid_api_key(self) -> bool:
+        """Check if the model has a valid API key.
+
+        >>> LanguageModel.example().has_valid_api_key() : # doctest: +SKIP
+        True
+
+        This method is used to check if the model has a valid API key.
+        """
+        from edsl.enums import service_to_api_keyname
+        import os
+
+        if self._model_ == "test":
+            return True
+
+        key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
+        key_value = os.getenv(key_name)
+        return key_value is not None
+
+    def __hash__(self) -> str:
+        """Allow the model to be used as a key in a dictionary."""
+        from edsl.utilities.utilities import dict_hash
+
+        return dict_hash(self.to_dict())
+
+    def __eq__(self, other):
+        """Check is two models are the same.
+
+        >>> m1 = LanguageModel.example()
+        >>> m2 = LanguageModel.example()
+        >>> m1 == m2
+        True
+
+        """
+        return self.model == other.model and self.parameters == other.parameters
+
+    def set_rate_limits(self, rpm=None, tpm=None) -> None:
+        """Set the rate limits for the model.
+
+        >>> m = LanguageModel.example()
+        >>> m.set_rate_limits(rpm=100, tpm=1000)
+        >>> m.RPM
+        100
+        """
+        if rpm is not None:
+            self._rpm = rpm
+        if tpm is not None:
+            self._tpm = tpm
+        return None
+        # self._set_rate_limits(rpm=rpm, tpm=tpm)
+
+    # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
+    #     """Set the rate limits for the model.
+
+    #     If the model does not have rate limits, use the default rate limits."""
+    #     if rpm is not None and tpm is not None:
+    #         self.__rate_limits = {"rpm": rpm, "tpm": tpm}
+    #         return
+
+    #     if self.__rate_limits is None:
+    #         if hasattr(self, "get_rate_limits"):
+    #             self.__rate_limits = self.get_rate_limits()
+    #         else:
+    #             self.__rate_limits = self.__default_rate_limits
+
+    @property
+    def RPM(self):
+        """Model's requests-per-minute limit."""
+        # self._set_rate_limits()
+        # return self._safety_factor * self.__rate_limits["rpm"]
+        return self._rpm
+
+    @property
+    def TPM(self):
+        """Model's tokens-per-minute limit."""
+        # self._set_rate_limits()
+        # return self._safety_factor * self.__rate_limits["tpm"]
+        return self._tpm
+
+    @property
+    def rpm(self):
+        return self._rpm
+
+    @rpm.setter
+    def rpm(self, value):
+        self._rpm = value
+
+    @property
+    def tpm(self):
+        return self._tpm
+
+    @tpm.setter
+    def tpm(self, value):
+        self._tpm = value
+
+    @staticmethod
+    def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
+        """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
+
+        >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9})
+        {'temperature': 0.5}
+        >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
+        {'temperature': 0.5, 'max_tokens': 1000}
+        """
+        # parameters = dict({})
+
+        # this is the case when data is loaded from a dict after serialization
+        if "parameters" in passed_parameter_dict:
+            passed_parameter_dict = passed_parameter_dict["parameters"]
+        return {
+            parameter_name: passed_parameter_dict.get(parameter_name, default_value)
+            for parameter_name, default_value in default_parameter_dict.items()
+        }
+
+    def __call__(self, user_prompt: str, system_prompt: str):
+        return self.execute_model_call(user_prompt, system_prompt)
+
+    @abstractmethod
+    async def async_execute_model_call(user_prompt: str, system_prompt: str):
+        """Execute the model call and returns a coroutine.
+
+        >>> m = LanguageModel.example(test_model = True)
+        >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
+        >>> asyncio.run(test())
+        {'message': [{'text': 'Hello world'}], ...}
+
+        >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
+        {'message': [{'text': 'Hello world'}], ...}
+        """
+        pass
+
+    async def remote_async_execute_model_call(
+        self, user_prompt: str, system_prompt: str
+    ):
+        """Execute the model call and returns the result as a coroutine, using Coop."""
+        from edsl.coop import Coop
+
+        client = Coop()
+        response_data = await client.remote_async_execute_model_call(
+            self.to_dict(), user_prompt, system_prompt
+        )
+        return response_data
+
+    @jupyter_nb_handler
+    def execute_model_call(self, *args, **kwargs) -> Coroutine:
+        """Execute the model call and returns the result as a coroutine.
+
+        >>> m = LanguageModel.example(test_model = True)
+        >>> m.execute_model_call(user_prompt = "Hello, model!", system_prompt = "You are a helpful agent.")
+
+        """
+
+        async def main():
+            results = await asyncio.gather(
+                self.async_execute_model_call(*args, **kwargs)
+            )
+            return results[0]  # Since there's only one task, return its result
+
+        return main()
+
+    @classmethod
+    def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
+        """Return the generated token string from the raw response."""
+        return extract_item_from_raw_response(raw_response, cls.key_sequence)
+
+    @classmethod
+    def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
+        """Return the usage dictionary from the raw response."""
+        if not hasattr(cls, "usage_sequence"):
+            raise NotImplementedError(
+                "This inference service does not have a usage_sequence."
+            )
+        return extract_item_from_raw_response(raw_response, cls.usage_sequence)
+
+    @classmethod
+    def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
+        """Parses the API response and returns the response text."""
+        generated_token_string = cls.get_generated_token_string(raw_response)
+        last_newline = generated_token_string.rfind("\n")
+
+        if last_newline == -1:
+            # There is no comment
+            edsl_dict = {
+                "answer": convert_answer(generated_token_string),
+                "generated_tokens": generated_token_string,
+                "comment": None,
+            }
+        else:
+            edsl_dict = {
+                "answer": convert_answer(generated_token_string[:last_newline]),
+                "comment": generated_token_string[last_newline + 1 :].strip(),
+                "generated_tokens": generated_token_string,
+            }
+        return EDSLOutput(**edsl_dict)
+
+    async def _async_get_intended_model_call_outcome(
+        self,
+        user_prompt: str,
+        system_prompt: str,
+        cache: "Cache",
+        iteration: int = 0,
+        files_list=None,
+    ) -> ModelResponse:
+        """Handle caching of responses.
+
+        :param user_prompt: The user's prompt.
+        :param system_prompt: The system's prompt.
+        :param iteration: The iteration number.
+        :param cache: The cache to use.
+
+        If the cache isn't being used, it just returns a 'fresh' call to the LLM.
+        But if cache is being used, it first checks the database to see if the response is already there.
+        If it is, it returns the cached response, but again appends some tracking information.
+        If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
+
+        If self.use_cache is True, then attempts to retrieve the response from the database;
+        if not in the DB, calls the LLM and writes the response to the DB.
+
+        >>> from edsl import Cache
+        >>> m = LanguageModel.example(test_model = True)
+        >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
+        ModelResponse(...)"""
+
+        if files_list:
+            files_hash = "+".join([str(hash(file)) for file in files_list])
+            # print(f"Files hash: {files_hash}")
+            user_prompt_with_hashes = user_prompt + f" {files_hash}"
+        else:
+            user_prompt_with_hashes = user_prompt
+
+        cache_call_params = {
+            "model": str(self.model),
+            "parameters": self.parameters,
+            "system_prompt": system_prompt,
+            "user_prompt": user_prompt_with_hashes,
+            "iteration": iteration,
+        }
+        cached_response, cache_key = cache.fetch(**cache_call_params)
+
+        if cache_used := cached_response is not None:
+            response = json.loads(cached_response)
+        else:
+            f = (
+                self.remote_async_execute_model_call
+                if hasattr(self, "remote") and self.remote
+                else self.async_execute_model_call
+            )
+            params = {
+                "user_prompt": user_prompt,
+                "system_prompt": system_prompt,
+                "files_list": files_list
+                # **({"encoded_image": encoded_image} if encoded_image else {}),
+            }
+            # response = await f(**params)
+            response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
+            new_cache_key = cache.store(
+                **cache_call_params, response=response
+            )  # store the response in the cache
+            assert new_cache_key == cache_key  # should be the same
+
+        cost = self.cost(response)
+
+        return ModelResponse(
+            response=response,
+            cache_used=cache_used,
+            cache_key=cache_key,
+            cached_response=cached_response,
+            cost=cost,
+        )
+
+    _get_intended_model_call_outcome = sync_wrapper(
+        _async_get_intended_model_call_outcome
+    )
+
+    # get_raw_response = sync_wrapper(async_get_raw_response)
+
+    def simple_ask(
+        self,
+        question: "QuestionBase",
+        system_prompt="You are a helpful agent pretending to be a human.",
+        top_logprobs=2,
+    ):
+        """Ask a question and return the response."""
+        self.logprobs = True
+        self.top_logprobs = top_logprobs
+        return self.execute_model_call(
+            user_prompt=question.human_readable(), system_prompt=system_prompt
+        )
+
+    async def async_get_response(
+        self,
+        user_prompt: str,
+        system_prompt: str,
+        cache: "Cache",
+        iteration: int = 1,
+        files_list: Optional[List["File"]] = None,
+    ) -> dict:
+        """Get response, parse, and return as string.
+
+        :param user_prompt: The user's prompt.
+        :param system_prompt: The system's prompt.
+        :param iteration: The iteration number.
+        :param cache: The cache to use.
+        :param encoded_image: The encoded image to use.
+
+        """
+        params = {
+            "user_prompt": user_prompt,
+            "system_prompt": system_prompt,
+            "iteration": iteration,
+            "cache": cache,
+            "files_list": files_list,
+        }
+        model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
+        model_outputs = await self._async_get_intended_model_call_outcome(**params)
+        edsl_dict = self.parse_response(model_outputs.response)
+        agent_response_dict = AgentResponseDict(
+            model_inputs=model_inputs,
+            model_outputs=model_outputs,
+            edsl_dict=edsl_dict,
+        )
+        return agent_response_dict
+
+        # return await self._async_prepare_response(model_call_outcome, cache=cache)
+
+    get_response = sync_wrapper(async_get_response)
+
+    def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
+        """Return the dollar cost of a raw response."""
+
+        usage = self.get_usage_dict(raw_response)
+        from edsl.coop import Coop
+
+        c = Coop()
+        price_lookup = c.fetch_prices()
+        key = (self._inference_service_, self.model)
+        if key not in price_lookup:
+            return f"Could not find price for model {self.model} in the price lookup."
+
+        relevant_prices = price_lookup[key]
+        try:
+            input_tokens = int(usage[self.input_token_name])
+            output_tokens = int(usage[self.output_token_name])
+        except Exception as e:
+            return f"Could not fetch tokens from model response: {e}"
+
+        try:
+            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+        except Exception as e:
+            if "output" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
+            if "input" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
+            return f"Could not fetch prices from {relevant_prices} - {e}"
+
+        if inverse_input_price == "infinity":
+            input_cost = 0
+        else:
+            try:
+                input_cost = input_tokens / float(inverse_input_price)
+            except Exception as e:
+                return f"Could not compute input price - {e}."
+
+        if inverse_output_price == "infinity":
+            output_cost = 0
+        else:
+            try:
+                output_cost = output_tokens / float(inverse_output_price)
+            except Exception as e:
+                return f"Could not compute output price - {e}"
+
+        return input_cost + output_cost
+
+    #######################
+    # SERIALIZATION METHODS
+    #######################
+    def _to_dict(self) -> dict[str, Any]:
+        return {"model": self.model, "parameters": self.parameters}
+
+    @add_edsl_version
+    def to_dict(self) -> dict[str, Any]:
+        """Convert instance to a dictionary.
+
+        >>> m = LanguageModel.example()
+        >>> m.to_dict()
+        {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
+        """
+        return self._to_dict()
+
+    @classmethod
+    @remove_edsl_version
+    def from_dict(cls, data: dict) -> Type[LanguageModel]:
+        """Convert dictionary to a LanguageModel child instance."""
+        from edsl.language_models.registry import get_model_class
+
+        model_class = get_model_class(data["model"])
+        # data["use_cache"] = True
+        return model_class(**data)
+
+    #######################
+    # DUNDER METHODS
+    #######################
+    def print(self):
+        from rich import print_json
+        import json
+
+        print_json(json.dumps(self.to_dict()))
+
+    def __repr__(self) -> str:
+        """Return a string representation of the object."""
+        param_string = ", ".join(
+            f"{key} = {value}" for key, value in self.parameters.items()
+        )
+        return (
+            f"Model(model_name = '{self.model}'"
+            + (f", {param_string}" if param_string else "")
+            + ")"
+        )
+
+    def __add__(self, other_model: Type[LanguageModel]) -> Type[LanguageModel]:
+        """Combine two models into a single model (other_model takes precedence over self)."""
+        print(
+            f"""Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
+            by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
+        )
+        return other_model or self
+
+    def rich_print(self):
+        """Display an object as a table."""
+        from rich.table import Table
+
+        table = Table(title="Language Model")
+        table.add_column("Attribute", style="bold")
+        table.add_column("Value")
+
+        to_display = self.__dict__.copy()
+        for attr_name, attr_value in to_display.items():
+            table.add_row(attr_name, repr(attr_value))
+
+        return table
+
+    @classmethod
+    def example(
+        cls,
+        test_model: bool = False,
+        canned_response: str = "Hello world",
+        throw_exception: bool = False,
+    ):
+        """Return a default instance of the class.
+
+        >>> from edsl.language_models import LanguageModel
+        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
+        >>> isinstance(m, LanguageModel)
+        True
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
+        >>> q.by(m).run(cache = False).select('example').first()
+        'WOWZA!'
+        """
+        from edsl import Model
+
+        if test_model:
+            m = Model("test", canned_response=canned_response)
+            return m
+        else:
+            return Model(skip_api_key_check=True)
+
+
+if __name__ == "__main__":
+    """Run the module's test suite."""
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)