edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. edsl/Base.py +116 -197
  2. edsl/__init__.py +7 -15
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +147 -351
  5. edsl/agents/AgentList.py +73 -211
  6. edsl/agents/Invigilator.py +50 -101
  7. edsl/agents/InvigilatorBase.py +70 -62
  8. edsl/agents/PromptConstructor.py +225 -143
  9. edsl/agents/__init__.py +1 -0
  10. edsl/agents/prompt_helpers.py +3 -3
  11. edsl/auto/AutoStudy.py +5 -18
  12. edsl/auto/StageBase.py +40 -53
  13. edsl/auto/StageQuestions.py +1 -2
  14. edsl/auto/utilities.py +6 -0
  15. edsl/config.py +2 -22
  16. edsl/conversation/car_buying.py +1 -2
  17. edsl/coop/PriceFetcher.py +1 -1
  18. edsl/coop/coop.py +47 -125
  19. edsl/coop/utils.py +14 -14
  20. edsl/data/Cache.py +27 -45
  21. edsl/data/CacheEntry.py +15 -12
  22. edsl/data/CacheHandler.py +12 -31
  23. edsl/data/RemoteCacheSync.py +46 -154
  24. edsl/data/__init__.py +3 -4
  25. edsl/data_transfer_models.py +1 -2
  26. edsl/enums.py +0 -27
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +0 -12
  29. edsl/exceptions/questions.py +6 -24
  30. edsl/exceptions/scenarios.py +0 -7
  31. edsl/inference_services/AnthropicService.py +19 -38
  32. edsl/inference_services/AwsBedrock.py +2 -0
  33. edsl/inference_services/AzureAI.py +2 -0
  34. edsl/inference_services/GoogleService.py +12 -7
  35. edsl/inference_services/InferenceServiceABC.py +85 -18
  36. edsl/inference_services/InferenceServicesCollection.py +79 -120
  37. edsl/inference_services/MistralAIService.py +3 -0
  38. edsl/inference_services/OpenAIService.py +35 -47
  39. edsl/inference_services/PerplexityService.py +3 -0
  40. edsl/inference_services/TestService.py +10 -11
  41. edsl/inference_services/TogetherAIService.py +3 -5
  42. edsl/jobs/Answers.py +14 -1
  43. edsl/jobs/Jobs.py +431 -356
  44. edsl/jobs/JobsChecks.py +10 -35
  45. edsl/jobs/JobsPrompts.py +4 -6
  46. edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
  47. edsl/jobs/buckets/BucketCollection.py +3 -44
  48. edsl/jobs/buckets/TokenBucket.py +21 -53
  49. edsl/jobs/interviews/Interview.py +408 -143
  50. edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
  51. edsl/jobs/runners/JobsRunnerStatus.py +165 -133
  52. edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
  53. edsl/jobs/tasks/TaskHistory.py +18 -38
  54. edsl/jobs/tasks/task_status_enum.py +2 -0
  55. edsl/language_models/KeyLookup.py +30 -0
  56. edsl/language_models/LanguageModel.py +236 -194
  57. edsl/language_models/ModelList.py +19 -28
  58. edsl/language_models/__init__.py +2 -1
  59. edsl/language_models/registry.py +190 -0
  60. edsl/language_models/repair.py +2 -2
  61. edsl/language_models/unused/ReplicateBase.py +83 -0
  62. edsl/language_models/utilities.py +4 -5
  63. edsl/notebooks/Notebook.py +14 -19
  64. edsl/prompts/Prompt.py +39 -29
  65. edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
  66. edsl/questions/QuestionBase.py +214 -68
  67. edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
  68. edsl/questions/QuestionBasePromptsMixin.py +3 -7
  69. edsl/questions/QuestionBudget.py +1 -1
  70. edsl/questions/QuestionCheckBox.py +3 -3
  71. edsl/questions/QuestionExtract.py +7 -5
  72. edsl/questions/QuestionFreeText.py +3 -2
  73. edsl/questions/QuestionList.py +18 -10
  74. edsl/questions/QuestionMultipleChoice.py +23 -67
  75. edsl/questions/QuestionNumerical.py +4 -2
  76. edsl/questions/QuestionRank.py +17 -7
  77. edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
  78. edsl/questions/SimpleAskMixin.py +3 -4
  79. edsl/questions/__init__.py +1 -2
  80. edsl/questions/derived/QuestionLinearScale.py +3 -6
  81. edsl/questions/derived/QuestionTopK.py +1 -1
  82. edsl/questions/descriptors.py +3 -17
  83. edsl/questions/question_registry.py +1 -1
  84. edsl/results/CSSParameterizer.py +1 -1
  85. edsl/results/Dataset.py +7 -170
  86. edsl/results/DatasetExportMixin.py +305 -168
  87. edsl/results/DatasetTree.py +8 -28
  88. edsl/results/Result.py +206 -298
  89. edsl/results/Results.py +131 -149
  90. edsl/results/ResultsDBMixin.py +238 -0
  91. edsl/results/ResultsExportMixin.py +0 -2
  92. edsl/results/{results_selector.py → Selector.py} +13 -23
  93. edsl/results/TableDisplay.py +171 -98
  94. edsl/results/__init__.py +1 -1
  95. edsl/scenarios/FileStore.py +239 -150
  96. edsl/scenarios/Scenario.py +193 -90
  97. edsl/scenarios/ScenarioHtmlMixin.py +3 -4
  98. edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
  99. edsl/scenarios/ScenarioList.py +244 -415
  100. edsl/scenarios/ScenarioListExportMixin.py +7 -0
  101. edsl/scenarios/ScenarioListPdfMixin.py +37 -15
  102. edsl/scenarios/__init__.py +2 -1
  103. edsl/study/ObjectEntry.py +1 -1
  104. edsl/study/SnapShot.py +1 -1
  105. edsl/study/Study.py +12 -5
  106. edsl/surveys/Rule.py +4 -5
  107. edsl/surveys/RuleCollection.py +27 -25
  108. edsl/surveys/Survey.py +791 -270
  109. edsl/surveys/SurveyCSS.py +8 -20
  110. edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
  111. edsl/surveys/__init__.py +2 -4
  112. edsl/surveys/descriptors.py +2 -6
  113. edsl/surveys/instructions/ChangeInstruction.py +2 -1
  114. edsl/surveys/instructions/Instruction.py +13 -4
  115. edsl/surveys/instructions/InstructionCollection.py +6 -11
  116. edsl/templates/error_reporting/interview_details.html +1 -1
  117. edsl/templates/error_reporting/report.html +1 -1
  118. edsl/tools/plotting.py +1 -1
  119. edsl/utilities/utilities.py +23 -35
  120. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
  121. edsl-0.1.39.dev1.dist-info/RECORD +277 -0
  122. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
  123. edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
  124. edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
  125. edsl/agents/question_option_processor.py +0 -172
  126. edsl/coop/CoopFunctionsMixin.py +0 -15
  127. edsl/coop/ExpectedParrotKeyHandler.py +0 -125
  128. edsl/exceptions/inference_services.py +0 -5
  129. edsl/inference_services/AvailableModelCacheHandler.py +0 -184
  130. edsl/inference_services/AvailableModelFetcher.py +0 -215
  131. edsl/inference_services/ServiceAvailability.py +0 -135
  132. edsl/inference_services/data_structures.py +0 -134
  133. edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
  134. edsl/jobs/FetchInvigilator.py +0 -47
  135. edsl/jobs/InterviewTaskManager.py +0 -98
  136. edsl/jobs/InterviewsConstructor.py +0 -50
  137. edsl/jobs/JobsComponentConstructor.py +0 -189
  138. edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
  139. edsl/jobs/RequestTokenEstimator.py +0 -30
  140. edsl/jobs/async_interview_runner.py +0 -138
  141. edsl/jobs/buckets/TokenBucketAPI.py +0 -211
  142. edsl/jobs/buckets/TokenBucketClient.py +0 -191
  143. edsl/jobs/check_survey_scenario_compatibility.py +0 -85
  144. edsl/jobs/data_structures.py +0 -120
  145. edsl/jobs/decorators.py +0 -35
  146. edsl/jobs/jobs_status_enums.py +0 -9
  147. edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
  148. edsl/jobs/results_exceptions_handler.py +0 -98
  149. edsl/language_models/ComputeCost.py +0 -63
  150. edsl/language_models/PriceManager.py +0 -127
  151. edsl/language_models/RawResponseHandler.py +0 -106
  152. edsl/language_models/ServiceDataSources.py +0 -0
  153. edsl/language_models/key_management/KeyLookup.py +0 -63
  154. edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
  155. edsl/language_models/key_management/KeyLookupCollection.py +0 -38
  156. edsl/language_models/key_management/__init__.py +0 -0
  157. edsl/language_models/key_management/models.py +0 -131
  158. edsl/language_models/model.py +0 -256
  159. edsl/notebooks/NotebookToLaTeX.py +0 -142
  160. edsl/questions/ExceptionExplainer.py +0 -77
  161. edsl/questions/HTMLQuestion.py +0 -103
  162. edsl/questions/QuestionMatrix.py +0 -265
  163. edsl/questions/data_structures.py +0 -20
  164. edsl/questions/loop_processor.py +0 -149
  165. edsl/questions/response_validator_factory.py +0 -34
  166. edsl/questions/templates/matrix/__init__.py +0 -1
  167. edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
  168. edsl/questions/templates/matrix/question_presentation.jinja +0 -20
  169. edsl/results/MarkdownToDocx.py +0 -122
  170. edsl/results/MarkdownToPDF.py +0 -111
  171. edsl/results/TextEditor.py +0 -50
  172. edsl/results/file_exports.py +0 -252
  173. edsl/results/smart_objects.py +0 -96
  174. edsl/results/table_data_class.py +0 -12
  175. edsl/results/table_renderers.py +0 -118
  176. edsl/scenarios/ConstructDownloadLink.py +0 -109
  177. edsl/scenarios/DocumentChunker.py +0 -102
  178. edsl/scenarios/DocxScenario.py +0 -16
  179. edsl/scenarios/PdfExtractor.py +0 -40
  180. edsl/scenarios/directory_scanner.py +0 -96
  181. edsl/scenarios/file_methods.py +0 -85
  182. edsl/scenarios/handlers/__init__.py +0 -13
  183. edsl/scenarios/handlers/csv.py +0 -49
  184. edsl/scenarios/handlers/docx.py +0 -76
  185. edsl/scenarios/handlers/html.py +0 -37
  186. edsl/scenarios/handlers/json.py +0 -111
  187. edsl/scenarios/handlers/latex.py +0 -5
  188. edsl/scenarios/handlers/md.py +0 -51
  189. edsl/scenarios/handlers/pdf.py +0 -68
  190. edsl/scenarios/handlers/png.py +0 -39
  191. edsl/scenarios/handlers/pptx.py +0 -105
  192. edsl/scenarios/handlers/py.py +0 -294
  193. edsl/scenarios/handlers/sql.py +0 -313
  194. edsl/scenarios/handlers/sqlite.py +0 -149
  195. edsl/scenarios/handlers/txt.py +0 -33
  196. edsl/scenarios/scenario_selector.py +0 -156
  197. edsl/surveys/ConstructDAG.py +0 -92
  198. edsl/surveys/EditSurvey.py +0 -221
  199. edsl/surveys/InstructionHandler.py +0 -100
  200. edsl/surveys/MemoryManagement.py +0 -72
  201. edsl/surveys/RuleManager.py +0 -172
  202. edsl/surveys/Simulator.py +0 -75
  203. edsl/surveys/SurveyToApp.py +0 -141
  204. edsl/utilities/PrettyList.py +0 -56
  205. edsl/utilities/is_notebook.py +0 -18
  206. edsl/utilities/is_valid_variable_name.py +0 -11
  207. edsl/utilities/remove_edsl_version.py +0 -24
  208. edsl-0.1.39.dist-info/RECORD +0 -358
  209. /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
  210. /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
  211. /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
  212. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
edsl/language_models/KeyLookup.py
@@ -0,0 +1,30 @@
+import os
+from collections import UserDict
+
+from edsl.enums import service_to_api_keyname
+from edsl.exceptions import MissingAPIKeyError
+
+
+class KeyLookup(UserDict):
+    @classmethod
+    def from_os_environ(cls):
+        """Create an instance of KeyLookup with keys from os.environ"""
+        return cls({key: value for key, value in os.environ.items()})
+
+    def get_api_token(self, service: str, remote: bool = False):
+        key_name = service_to_api_keyname.get(service, "NOT FOUND")
+
+        if service == "bedrock":
+            api_token = [self.get(key_name[0]), self.get(key_name[1])]
+            missing_token = any(token is None for token in api_token)
+        else:
+            api_token = self.get(key_name)
+            missing_token = api_token is None
+
+        if missing_token and service != "test" and not remote:
+            raise MissingAPIKeyError(
+                f"""The key for service: `{service}` is not set.
+                Need a key with name {key_name} in your .env file."""
+            )
+
+        return api_token
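
The new KeyLookup is a UserDict snapshot of credential values keyed by environment-variable name. A hypothetical usage sketch (the OPENAI_API_KEY name comes from the edsl.enums.service_to_api_keyname mapping; the token value is a placeholder):

    import os
    from edsl.language_models.KeyLookup import KeyLookup

    os.environ["OPENAI_API_KEY"] = "sk-placeholder"   # placeholder, not a real key
    lookup = KeyLookup.from_os_environ()              # snapshots os.environ
    print(lookup.get_api_token("openai"))             # -> "sk-placeholder"
    # For "bedrock" the mapped key name is a pair, so a two-element list of
    # tokens is returned; for service "test" a missing key returns None
    # instead of raising MissingAPIKeyError.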
edsl/language_models/LanguageModel.py
@@ -21,6 +21,7 @@ import os
 from typing import (
     Coroutine,
     Any,
+    Callable,
     Type,
     Union,
     List,
@@ -31,6 +32,8 @@ from typing import (
 )
 from abc import ABC, abstractmethod

+from json_repair import repair_json
+
 from edsl.data_transfer_models import (
     ModelResponse,
     ModelInputs,
@@ -42,24 +45,61 @@ if TYPE_CHECKING:
     from edsl.data.Cache import Cache
     from edsl.scenarios.FileStore import FileStore
     from edsl.questions.QuestionBase import QuestionBase
-    from edsl.language_models.key_management.KeyLookup import KeyLookup
-
-from edsl.enums import InferenceServiceType

-from edsl.utilities.decorators import (
-    sync_wrapper,
-    jupyter_nb_handler,
-)
-from edsl.utilities.remove_edsl_version import remove_edsl_version
+from edsl.config import CONFIG
+from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
+from edsl.utilities.decorators import remove_edsl_version

-from edsl.Base import PersistenceMixin, RepresentationMixin
+from edsl.Base import PersistenceMixin
 from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
+from edsl.language_models.KeyLookup import KeyLookup
+from edsl.exceptions.language_models import LanguageModelBadResponseError

-from edsl.language_models.key_management.KeyLookupCollection import (
-    KeyLookupCollection,
-)
+TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+
+# you might be tempted to move this to be a static method of LanguageModel, but this doesn't work
+# for reasons I don't understand. So leave it here.
+def extract_item_from_raw_response(data, key_sequence):
+    if isinstance(data, str):
+        try:
+            data = json.loads(data)
+        except json.JSONDecodeError as e:
+            return data
+    current_data = data
+    for i, key in enumerate(key_sequence):
+        try:
+            if isinstance(current_data, (list, tuple)):
+                if not isinstance(key, int):
+                    raise TypeError(
+                        f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
+                    )
+                if key < 0 or key >= len(current_data):
+                    raise IndexError(
+                        f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
+                    )
+            elif isinstance(current_data, dict):
+                if key not in current_data:
+                    raise KeyError(
+                        f"Key '{key}' not found in dictionary at position {i}"
+                    )
+            else:
+                raise TypeError(
+                    f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
+                )

-from edsl.language_models.RawResponseHandler import RawResponseHandler
+            current_data = current_data[key]
+        except Exception as e:
+            path = " -> ".join(map(str, key_sequence[: i + 1]))
+            if "error" in data:
+                msg = data["error"]
+            else:
+                msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
+            raise LanguageModelBadResponseError(message=msg, response_json=data)
+    if isinstance(current_data, str):
+        return current_data.strip()
+    else:
+        return current_data


 def handle_key_error(func):
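
In short, extract_item_from_raw_response walks key_sequence through the nested response, indexing dicts by key and lists/tuples by integer position, and raises LanguageModelBadResponseError with the failing path on a miss. A simplified sketch of the traversal (error handling omitted), on an OpenAI-shaped response:

    raw_response = {"choices": [{"message": {"content": "Hello world"}}]}
    key_sequence = ["choices", 0, "message", "content"]

    current = raw_response
    for key in key_sequence:  # dict keys are strings, sequence indices are ints
        current = current[key]
    print(current.strip())    # -> "Hello world"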
@@ -77,21 +117,8 @@ def handle_key_error(func):
     return wrapper


-class classproperty:
-    def __init__(self, method):
-        self.method = method
-
-    def __get__(self, instance, cls):
-        return self.method(cls)
-
-
-from edsl.Base import HashingMixin
-
-
 class LanguageModel(
     PersistenceMixin,
-    RepresentationMixin,
-    HashingMixin,
     ABC,
     metaclass=RegisterLanguageModelsMeta,
 ):
@@ -101,22 +128,15 @@ class LanguageModel(
     key_sequence = (
         None  # This should be something like ["choices", 0, "message", "content"]
     )
-
-    DEFAULT_RPM = 100
-    DEFAULT_TPM = 1000
-
-    @classproperty
-    def response_handler(cls):
-        key_sequence = cls.key_sequence
-        usage_sequence = cls.usage_sequence if hasattr(cls, "usage_sequence") else None
-        return RawResponseHandler(key_sequence, usage_sequence)
+    __rate_limits = None
+    _safety_factor = 0.8

     def __init__(
         self,
-        tpm: Optional[float] = None,
-        rpm: Optional[float] = None,
+        tpm: float = None,
+        rpm: float = None,
         omit_system_prompt_if_empty_string: bool = True,
-        key_lookup: Optional["KeyLookup"] = None,
+        key_lookup: Optional[KeyLookup] = None,
         **kwargs,
     ):
         """Initialize the LanguageModel."""
@@ -127,9 +147,7 @@ class LanguageModel(
         self.remote = False
         self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string

-        self.key_lookup = self._set_key_lookup(key_lookup)
-        self.model_info = self.key_lookup.get(self._inference_service_)
-
+        # self._rpm / _tpm comes from the class
         if rpm is not None:
             self._rpm = rpm

@@ -143,75 +161,49 @@ class LanguageModel(
             if key not in parameters:
                 setattr(self, key, value)

-        if kwargs.get("skip_api_key_check", False):
+        if "use_cache" in kwargs:
+            warnings.warn(
+                "The use_cache parameter is deprecated. Use the Cache class instead."
+            )
+
+        if skip_api_key_check := kwargs.get("skip_api_key_check", False):
             # Skip the API key check. Sometimes this is useful for testing.
             self._api_token = None

-    def _set_key_lookup(self, key_lookup: "KeyLookup") -> "KeyLookup":
-        """Set the key lookup."""
         if key_lookup is not None:
-            return key_lookup
+            self.key_lookup = key_lookup
         else:
-            klc = KeyLookupCollection()
-            klc.add_key_lookup(fetch_order=("config", "env"))
-            return klc.get(("config", "env"))
-
-    def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
-        """Set the key lookup, later"""
-        if hasattr(self, "_api_token"):
-            del self._api_token
-        self.key_lookup = key_lookup
-
-    def ask_question(self, question: "QuestionBase") -> str:
-        """Ask a question and return the response.
+            self.key_lookup = KeyLookup.from_os_environ()

-        :param question: The question to ask.
-        """
+    def ask_question(self, question):
         user_prompt = question.get_instructions().render(question.data).text
         system_prompt = "You are a helpful agent pretending to be a human."
         return self.execute_model_call(user_prompt, system_prompt)

-    @property
-    def rpm(self):
-        if not hasattr(self, "_rpm"):
-            if self.model_info is None:
-                self._rpm = self.DEFAULT_RPM
-            else:
-                self._rpm = self.model_info.rpm
-        return self._rpm
-
-    @property
-    def tpm(self):
-        if not hasattr(self, "_tpm"):
-            if self.model_info is None:
-                self._tpm = self.DEFAULT_TPM
-            else:
-                self._tpm = self.model_info.tpm
-        return self._tpm
-
-    # in case we want to override the default values
-    @tpm.setter
-    def tpm(self, value):
-        self._tpm = value
-
-    @rpm.setter
-    def rpm(self, value):
-        self._rpm = value
+    def set_key_lookup(self, key_lookup: KeyLookup) -> None:
+        del self._api_token
+        self.key_lookup = key_lookup

     @property
     def api_token(self) -> str:
         if not hasattr(self, "_api_token"):
-            info = self.key_lookup.get(self._inference_service_, None)
-            if info is None:
-                raise ValueError(
-                    f"No key found for service '{self._inference_service_}'"
-                )
-            self._api_token = info.api_token
+            self._api_token = self.key_lookup.get_api_token(
+                self._inference_service_, self.remote
+            )
         return self._api_token

     def __getitem__(self, key):
         return getattr(self, key)

+    def _repr_html_(self) -> str:
+        d = {"model": self.model}
+        d.update(self.parameters)
+        data = [[k, v] for k, v in d.items()]
+        from tabulate import tabulate
+
+        table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
+        return f"<pre>{table}</pre>"
+
     def hello(self, verbose=False):
         """Runs a simple test to check if the model is working."""
         token = self.api_token
@@ -240,12 +232,7 @@ class LanguageModel(
         return key_value is not None

     def __hash__(self) -> str:
-        """Allow the model to be used as a key in a dictionary.
-
-        >>> m = LanguageModel.example()
-        >>> hash(m)
-        1811901442659237949
-        """
+        """Allow the model to be used as a key in a dictionary."""
         from edsl.utilities.utilities import dict_hash

         return dict_hash(self.to_dict(add_edsl_version=False))
@@ -261,6 +248,46 @@ class LanguageModel(
         """
         return self.model == other.model and self.parameters == other.parameters

+    def set_rate_limits(self, rpm=None, tpm=None) -> None:
+        """Set the rate limits for the model.
+
+        >>> m = LanguageModel.example()
+        >>> m.set_rate_limits(rpm=100, tpm=1000)
+        >>> m.RPM
+        100
+        """
+        if rpm is not None:
+            self._rpm = rpm
+        if tpm is not None:
+            self._tpm = tpm
+        return None
+
+    @property
+    def RPM(self):
+        """Model's requests-per-minute limit."""
+        return self._rpm
+
+    @property
+    def TPM(self):
+        """Model's tokens-per-minute limit."""
+        return self._tpm
+
+    @property
+    def rpm(self):
+        return self._rpm
+
+    @rpm.setter
+    def rpm(self, value):
+        self._rpm = value
+
+    @property
+    def tpm(self):
+        return self._tpm
+
+    @tpm.setter
+    def tpm(self, value):
+        self._tpm = value
+
     @staticmethod
     def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
         """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
@@ -283,7 +310,16 @@

     @abstractmethod
     async def async_execute_model_call(user_prompt: str, system_prompt: str):
-        """Execute the model call and returns a coroutine."""
+        """Execute the model call and returns a coroutine.
+
+        >>> m = LanguageModel.example(test_model = True)
+        >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
+        >>> asyncio.run(test())
+        {'message': [{'text': 'Hello world'}], ...}
+
+        >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
+        {'message': [{'text': 'Hello world'}], ...}
+        """
         pass

     async def remote_async_execute_model_call(
@@ -300,7 +336,12 @@

     @jupyter_nb_handler
     def execute_model_call(self, *args, **kwargs) -> Coroutine:
-        """Execute the model call and returns the result as a coroutine."""
+        """Execute the model call and returns the result as a coroutine.
+
+        >>> m = LanguageModel.example(test_model = True)
+        >>> m.execute_model_call(user_prompt = "Hello, model!", system_prompt = "You are a helpful agent.")
+
+        """

         async def main():
             results = await asyncio.gather(
@@ -312,25 +353,58 @@

     @classmethod
     def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
-        """Return the generated token string from the raw response.
-
-        >>> m = LanguageModel.example(test_model = True)
-        >>> raw_response = m.execute_model_call("Hello, model!", "You are a helpful agent.")
-        >>> m.get_generated_token_string(raw_response)
-        'Hello world'
-
-        """
-        return cls.response_handler.get_generated_token_string(raw_response)
+        """Return the generated token string from the raw response."""
+        return extract_item_from_raw_response(raw_response, cls.key_sequence)

     @classmethod
     def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
         """Return the usage dictionary from the raw response."""
-        return cls.response_handler.get_usage_dict(raw_response)
+        if not hasattr(cls, "usage_sequence"):
+            raise NotImplementedError(
+                "This inference service does not have a usage_sequence."
+            )
+        return extract_item_from_raw_response(raw_response, cls.usage_sequence)
+
+    @staticmethod
+    def convert_answer(response_part):
+        import json
+
+        response_part = response_part.strip()
+
+        if response_part == "None":
+            return None
+
+        repaired = repair_json(response_part)
+        if repaired == '""':
+            # it was a literal string
+            return response_part
+
+        try:
+            return json.loads(repaired)
+        except json.JSONDecodeError as j:
+            # last resort
+            return response_part

     @classmethod
     def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
         """Parses the API response and returns the response text."""
-        return cls.response_handler.parse_response(raw_response)
+        generated_token_string = cls.get_generated_token_string(raw_response)
+        last_newline = generated_token_string.rfind("\n")
+
+        if last_newline == -1:
+            # There is no comment
+            edsl_dict = {
+                "answer": cls.convert_answer(generated_token_string),
+                "generated_tokens": generated_token_string,
+                "comment": None,
+            }
+        else:
+            edsl_dict = {
+                "answer": cls.convert_answer(generated_token_string[:last_newline]),
+                "comment": generated_token_string[last_newline + 1 :].strip(),
+                "generated_tokens": generated_token_string,
+            }
+        return EDSLOutput(**edsl_dict)

     async def _async_get_intended_model_call_outcome(
         self,
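
parse_response treats everything after the last newline of the generated text as a free-text comment and feeds the remainder to convert_answer, which uses repair_json to coerce near-JSON answers. A small sketch of the split:

    generated = "42\nBecause the question asked for a single number."
    last_newline = generated.rfind("\n")

    answer_part = generated[:last_newline]           # "42"; convert_answer yields 42
    comment = generated[last_newline + 1 :].strip()  # "Because the question ..."
    print(answer_part, "|", comment)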
@@ -347,8 +421,6 @@
         :param system_prompt: The system's prompt.
         :param iteration: The iteration number.
         :param cache: The cache to use.
-        :param files_list: The list of files to use.
-        :param invigilator: The invigilator to use.

         If the cache isn't being used, it just returns a 'fresh' call to the LLM.
         But if cache is being used, it first checks the database to see if the response is already there.
@@ -391,10 +463,6 @@
             "system_prompt": system_prompt,
             "files_list": files_list,
         }
-        from edsl.config import CONFIG
-
-        TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
-
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
         new_cache_key = cache.store(
             **cache_call_params, response=response
@@ -402,6 +470,7 @@
         assert new_cache_key == cache_key  # should be the same

         cost = self.cost(response)
+
         return ModelResponse(
             response=response,
             cache_used=cache_used,
@@ -440,9 +509,9 @@

         :param user_prompt: The user's prompt.
         :param system_prompt: The system's prompt.
-        :param cache: The cache to use.
         :param iteration: The iteration number.
-        :param files_list: The list of files to use.
+        :param cache: The cache to use.
+        :param encoded_image: The encoded image to use.

         """
         params = {
@@ -456,11 +525,8 @@
             params.update({"invigilator": kwargs["invigilator"]})

         model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
-        model_outputs: ModelResponse = (
-            await self._async_get_intended_model_call_outcome(**params)
-        )
-        edsl_dict: EDSLOutput = self.parse_response(model_outputs.response)
-
+        model_outputs = await self._async_get_intended_model_call_outcome(**params)
+        edsl_dict = self.parse_response(model_outputs.response)
         agent_response_dict = AgentResponseDict(
             model_inputs=model_inputs,
             model_outputs=model_outputs,
@@ -471,36 +537,60 @@
     get_response = sync_wrapper(async_get_response)

     def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
-        """Return the dollar cost of a raw response.
-
-        :param raw_response: The raw response from the model.
-        """
+        """Return the dollar cost of a raw response."""

         usage = self.get_usage_dict(raw_response)
-        from edsl.language_models.PriceManager import PriceManager
-
-        price_manger = PriceManager()
-        return price_manger.calculate_cost(
-            inference_service=self._inference_service_,
-            model=self.model,
-            usage=usage,
-            input_token_name=self.input_token_name,
-            output_token_name=self.output_token_name,
-        )
+        from edsl.coop import Coop
+
+        c = Coop()
+        price_lookup = c.fetch_prices()
+        key = (self._inference_service_, self.model)
+        if key not in price_lookup:
+            return f"Could not find price for model {self.model} in the price lookup."
+
+        relevant_prices = price_lookup[key]
+        try:
+            input_tokens = int(usage[self.input_token_name])
+            output_tokens = int(usage[self.output_token_name])
+        except Exception as e:
+            return f"Could not fetch tokens from model response: {e}"
+
+        try:
+            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+        except Exception as e:
+            if "output" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
+            if "input" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
+            return f"Could not fetch prices from {relevant_prices} - {e}"
+
+        if inverse_input_price == "infinity":
+            input_cost = 0
+        else:
+            try:
+                input_cost = input_tokens / float(inverse_input_price)
+            except Exception as e:
+                return f"Could not compute input price - {e}."
+
+        if inverse_output_price == "infinity":
+            output_cost = 0
+        else:
+            try:
+                output_cost = output_tokens / float(inverse_output_price)
+            except Exception as e:
+                return f"Could not compute output price - {e}"
+
+        return input_cost + output_cost

     def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
         """Convert instance to a dictionary

-        :param add_edsl_version: Whether to add the EDSL version to the dictionary.
-
        >>> m = LanguageModel.example()
        >>> m.to_dict()
        {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
        """
-        d = {
-            "model": self.model,
-            "parameters": self.parameters,
-        }
+        d = {"model": self.model, "parameters": self.parameters}
         if add_edsl_version:
             from edsl import __version__

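
The rewritten cost() prices a call from Coop's price lookup, where prices are stored inversely as one_usd_buys (how many tokens one US dollar buys). A worked example with made-up prices:

    input_tokens, output_tokens = 2_000, 500
    one_usd_buys_input = 1_000_000.0  # made-up: $1 buys 1M input tokens
    one_usd_buys_output = 250_000.0   # made-up: $1 buys 250k output tokens

    cost = input_tokens / one_usd_buys_input + output_tokens / one_usd_buys_output
    print(f"${cost:.6f}")             # -> $0.004000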
@@ -512,13 +602,13 @@
     @remove_edsl_version
     def from_dict(cls, data: dict) -> Type[LanguageModel]:
         """Convert dictionary to a LanguageModel child instance."""
-        from edsl.language_models.model import get_model_class
+        from edsl.language_models.registry import get_model_class

         model_class = get_model_class(data["model"])
         return model_class(**data)

     def __repr__(self) -> str:
-        """Return a representation of the object."""
+        """Return a string representation of the object."""
         param_string = ", ".join(
             f"{key} = {value}" for key, value in self.parameters.items()
         )
@@ -560,7 +650,7 @@
         Exception report saved to ...
         Also see: ...
         """
-        from edsl.language_models.model import Model
+        from edsl import Model

         if test_model:
             m = Model(
@@ -570,54 +660,6 @@
         else:
             return Model(skip_api_key_check=True)

-    def from_cache(self, cache: "Cache") -> LanguageModel:
-
-        from copy import deepcopy
-        from types import MethodType
-        from edsl import Cache
-
-        new_instance = deepcopy(self)
-        print("Cache entries", len(cache))
-        new_instance.cache = Cache(
-            data={k: v for k, v in cache.items() if v.model == self.model}
-        )
-        print("Cache entries with same model", len(new_instance.cache))
-
-        new_instance.user_prompts = [
-            ce.user_prompt for ce in new_instance.cache.values()
-        ]
-        new_instance.system_prompts = [
-            ce.system_prompt for ce in new_instance.cache.values()
-        ]
-
-        async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
-            cache_call_params = {
-                "model": str(self.model),
-                "parameters": self.parameters,
-                "system_prompt": system_prompt,
-                "user_prompt": user_prompt,
-                "iteration": 1,
-            }
-            cached_response, cache_key = cache.fetch(**cache_call_params)
-            response = json.loads(cached_response)
-            cost = 0
-            return ModelResponse(
-                response=response,
-                cache_used=True,
-                cache_key=cache_key,
-                cached_response=cached_response,
-                cost=cost,
-            )
-
-        # Bind the new method to the copied instance
-        setattr(
-            new_instance,
-            "async_execute_model_call",
-            MethodType(async_execute_model_call, new_instance),
-        )
-
-        return new_instance
-

 if __name__ == "__main__":
     """Run the module's test suite."""