edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. edsl/Base.py +169 -116
  2. edsl/__init__.py +14 -6
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +358 -146
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +88 -36
  7. edsl/agents/InvigilatorBase.py +59 -70
  8. edsl/agents/PromptConstructor.py +117 -219
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionOptionProcessor.py +172 -0
  11. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  12. edsl/agents/__init__.py +0 -1
  13. edsl/agents/prompt_helpers.py +3 -3
  14. edsl/config.py +22 -2
  15. edsl/conversation/car_buying.py +2 -1
  16. edsl/coop/CoopFunctionsMixin.py +15 -0
  17. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  18. edsl/coop/PriceFetcher.py +1 -1
  19. edsl/coop/coop.py +104 -42
  20. edsl/coop/utils.py +14 -14
  21. edsl/data/Cache.py +21 -14
  22. edsl/data/CacheEntry.py +12 -15
  23. edsl/data/CacheHandler.py +33 -12
  24. edsl/data/__init__.py +4 -3
  25. edsl/data_transfer_models.py +2 -1
  26. edsl/enums.py +20 -0
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +12 -0
  29. edsl/exceptions/inference_services.py +5 -0
  30. edsl/exceptions/questions.py +24 -6
  31. edsl/exceptions/scenarios.py +7 -0
  32. edsl/inference_services/AnthropicService.py +0 -3
  33. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  34. edsl/inference_services/AvailableModelFetcher.py +209 -0
  35. edsl/inference_services/AwsBedrock.py +0 -2
  36. edsl/inference_services/AzureAI.py +0 -2
  37. edsl/inference_services/GoogleService.py +2 -11
  38. edsl/inference_services/InferenceServiceABC.py +18 -85
  39. edsl/inference_services/InferenceServicesCollection.py +105 -80
  40. edsl/inference_services/MistralAIService.py +0 -3
  41. edsl/inference_services/OpenAIService.py +1 -4
  42. edsl/inference_services/PerplexityService.py +0 -3
  43. edsl/inference_services/ServiceAvailability.py +135 -0
  44. edsl/inference_services/TestService.py +11 -8
  45. edsl/inference_services/data_structures.py +62 -0
  46. edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
  47. edsl/jobs/Answers.py +1 -14
  48. edsl/jobs/FetchInvigilator.py +40 -0
  49. edsl/jobs/InterviewTaskManager.py +98 -0
  50. edsl/jobs/InterviewsConstructor.py +48 -0
  51. edsl/jobs/Jobs.py +102 -243
  52. edsl/jobs/JobsChecks.py +35 -10
  53. edsl/jobs/JobsComponentConstructor.py +189 -0
  54. edsl/jobs/JobsPrompts.py +5 -3
  55. edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
  56. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  57. edsl/jobs/RequestTokenEstimator.py +30 -0
  58. edsl/jobs/buckets/BucketCollection.py +44 -3
  59. edsl/jobs/buckets/TokenBucket.py +53 -21
  60. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  61. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  62. edsl/jobs/decorators.py +35 -0
  63. edsl/jobs/interviews/Interview.py +77 -380
  64. edsl/jobs/jobs_status_enums.py +9 -0
  65. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  66. edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
  67. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  68. edsl/jobs/tasks/TaskHistory.py +14 -15
  69. edsl/jobs/tasks/task_status_enum.py +0 -2
  70. edsl/language_models/ComputeCost.py +63 -0
  71. edsl/language_models/LanguageModel.py +137 -234
  72. edsl/language_models/ModelList.py +11 -13
  73. edsl/language_models/PriceManager.py +127 -0
  74. edsl/language_models/RawResponseHandler.py +106 -0
  75. edsl/language_models/ServiceDataSources.py +0 -0
  76. edsl/language_models/__init__.py +0 -1
  77. edsl/language_models/key_management/KeyLookup.py +63 -0
  78. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  79. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  80. edsl/language_models/key_management/__init__.py +0 -0
  81. edsl/language_models/key_management/models.py +131 -0
  82. edsl/language_models/registry.py +49 -59
  83. edsl/language_models/repair.py +2 -2
  84. edsl/language_models/utilities.py +5 -4
  85. edsl/notebooks/Notebook.py +19 -14
  86. edsl/notebooks/NotebookToLaTeX.py +142 -0
  87. edsl/prompts/Prompt.py +29 -39
  88. edsl/questions/AnswerValidatorMixin.py +47 -2
  89. edsl/questions/ExceptionExplainer.py +77 -0
  90. edsl/questions/HTMLQuestion.py +103 -0
  91. edsl/questions/LoopProcessor.py +149 -0
  92. edsl/questions/QuestionBase.py +37 -192
  93. edsl/questions/QuestionBaseGenMixin.py +52 -48
  94. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  95. edsl/questions/QuestionCheckBox.py +1 -1
  96. edsl/questions/QuestionExtract.py +1 -1
  97. edsl/questions/QuestionFreeText.py +1 -2
  98. edsl/questions/QuestionList.py +3 -5
  99. edsl/questions/QuestionMatrix.py +265 -0
  100. edsl/questions/QuestionMultipleChoice.py +66 -22
  101. edsl/questions/QuestionNumerical.py +1 -3
  102. edsl/questions/QuestionRank.py +6 -16
  103. edsl/questions/ResponseValidatorABC.py +37 -11
  104. edsl/questions/ResponseValidatorFactory.py +28 -0
  105. edsl/questions/SimpleAskMixin.py +4 -3
  106. edsl/questions/__init__.py +1 -0
  107. edsl/questions/derived/QuestionLinearScale.py +6 -3
  108. edsl/questions/derived/QuestionTopK.py +1 -1
  109. edsl/questions/descriptors.py +17 -3
  110. edsl/questions/question_registry.py +1 -1
  111. edsl/questions/templates/matrix/__init__.py +1 -0
  112. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  113. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  114. edsl/results/CSSParameterizer.py +1 -1
  115. edsl/results/Dataset.py +170 -7
  116. edsl/results/DatasetExportMixin.py +224 -302
  117. edsl/results/DatasetTree.py +28 -8
  118. edsl/results/MarkdownToDocx.py +122 -0
  119. edsl/results/MarkdownToPDF.py +111 -0
  120. edsl/results/Result.py +192 -206
  121. edsl/results/Results.py +120 -113
  122. edsl/results/ResultsExportMixin.py +2 -0
  123. edsl/results/Selector.py +23 -13
  124. edsl/results/TableDisplay.py +98 -171
  125. edsl/results/TextEditor.py +50 -0
  126. edsl/results/__init__.py +1 -1
  127. edsl/results/smart_objects.py +96 -0
  128. edsl/results/table_data_class.py +12 -0
  129. edsl/results/table_renderers.py +118 -0
  130. edsl/scenarios/ConstructDownloadLink.py +109 -0
  131. edsl/scenarios/DirectoryScanner.py +96 -0
  132. edsl/scenarios/DocumentChunker.py +102 -0
  133. edsl/scenarios/DocxScenario.py +16 -0
  134. edsl/scenarios/FileStore.py +118 -239
  135. edsl/scenarios/PdfExtractor.py +40 -0
  136. edsl/scenarios/Scenario.py +90 -193
  137. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  138. edsl/scenarios/ScenarioJoin.py +10 -6
  139. edsl/scenarios/ScenarioList.py +383 -240
  140. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  141. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  142. edsl/scenarios/ScenarioSelector.py +156 -0
  143. edsl/scenarios/__init__.py +1 -2
  144. edsl/scenarios/file_methods.py +85 -0
  145. edsl/scenarios/handlers/__init__.py +13 -0
  146. edsl/scenarios/handlers/csv.py +38 -0
  147. edsl/scenarios/handlers/docx.py +76 -0
  148. edsl/scenarios/handlers/html.py +37 -0
  149. edsl/scenarios/handlers/json.py +111 -0
  150. edsl/scenarios/handlers/latex.py +5 -0
  151. edsl/scenarios/handlers/md.py +51 -0
  152. edsl/scenarios/handlers/pdf.py +68 -0
  153. edsl/scenarios/handlers/png.py +39 -0
  154. edsl/scenarios/handlers/pptx.py +105 -0
  155. edsl/scenarios/handlers/py.py +294 -0
  156. edsl/scenarios/handlers/sql.py +313 -0
  157. edsl/scenarios/handlers/sqlite.py +149 -0
  158. edsl/scenarios/handlers/txt.py +33 -0
  159. edsl/study/ObjectEntry.py +1 -1
  160. edsl/study/SnapShot.py +1 -1
  161. edsl/study/Study.py +5 -12
  162. edsl/surveys/ConstructDAG.py +92 -0
  163. edsl/surveys/EditSurvey.py +221 -0
  164. edsl/surveys/InstructionHandler.py +100 -0
  165. edsl/surveys/MemoryManagement.py +72 -0
  166. edsl/surveys/Rule.py +5 -4
  167. edsl/surveys/RuleCollection.py +25 -27
  168. edsl/surveys/RuleManager.py +172 -0
  169. edsl/surveys/Simulator.py +75 -0
  170. edsl/surveys/Survey.py +199 -771
  171. edsl/surveys/SurveyCSS.py +20 -8
  172. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  173. edsl/surveys/SurveyToApp.py +141 -0
  174. edsl/surveys/__init__.py +4 -2
  175. edsl/surveys/descriptors.py +6 -2
  176. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  177. edsl/surveys/instructions/Instruction.py +4 -13
  178. edsl/surveys/instructions/InstructionCollection.py +11 -6
  179. edsl/templates/error_reporting/interview_details.html +1 -1
  180. edsl/templates/error_reporting/report.html +1 -1
  181. edsl/tools/plotting.py +1 -1
  182. edsl/utilities/PrettyList.py +56 -0
  183. edsl/utilities/is_notebook.py +18 -0
  184. edsl/utilities/is_valid_variable_name.py +11 -0
  185. edsl/utilities/remove_edsl_version.py +24 -0
  186. edsl/utilities/utilities.py +35 -23
  187. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
  188. edsl-0.1.39.dev2.dist-info/RECORD +352 -0
  189. edsl/language_models/KeyLookup.py +0 -30
  190. edsl/language_models/unused/ReplicateBase.py +0 -83
  191. edsl/results/ResultsDBMixin.py +0 -238
  192. edsl-0.1.39.dev1.dist-info/RECORD +0 -277
  193. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
  194. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
@@ -21,7 +21,6 @@ import os
21
21
  from typing import (
22
22
  Coroutine,
23
23
  Any,
24
- Callable,
25
24
  Type,
26
25
  Union,
27
26
  List,
@@ -32,8 +31,6 @@ from typing import (
32
31
  )
33
32
  from abc import ABC, abstractmethod
34
33
 
35
- from json_repair import repair_json
36
-
37
34
  from edsl.data_transfer_models import (
38
35
  ModelResponse,
39
36
  ModelInputs,
@@ -45,61 +42,22 @@ if TYPE_CHECKING:
45
42
  from edsl.data.Cache import Cache
46
43
  from edsl.scenarios.FileStore import FileStore
47
44
  from edsl.questions.QuestionBase import QuestionBase
45
+ from edsl.language_models.key_management.KeyLookup import KeyLookup
48
46
 
49
- from edsl.config import CONFIG
50
- from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
51
- from edsl.utilities.decorators import remove_edsl_version
47
+ from edsl.utilities.decorators import (
48
+ sync_wrapper,
49
+ jupyter_nb_handler,
50
+ )
51
+ from edsl.utilities.remove_edsl_version import remove_edsl_version
52
52
 
53
- from edsl.Base import PersistenceMixin
53
+ from edsl.Base import PersistenceMixin, RepresentationMixin
54
54
  from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
55
- from edsl.language_models.KeyLookup import KeyLookup
56
- from edsl.exceptions.language_models import LanguageModelBadResponseError
57
-
58
- TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
59
-
60
55
 
61
- # you might be tempated to move this to be a static method of LanguageModel, but this doesn't work
62
- # for reasons I don't understand. So leave it here.
63
- def extract_item_from_raw_response(data, key_sequence):
64
- if isinstance(data, str):
65
- try:
66
- data = json.loads(data)
67
- except json.JSONDecodeError as e:
68
- return data
69
- current_data = data
70
- for i, key in enumerate(key_sequence):
71
- try:
72
- if isinstance(current_data, (list, tuple)):
73
- if not isinstance(key, int):
74
- raise TypeError(
75
- f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
76
- )
77
- if key < 0 or key >= len(current_data):
78
- raise IndexError(
79
- f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
80
- )
81
- elif isinstance(current_data, dict):
82
- if key not in current_data:
83
- raise KeyError(
84
- f"Key '{key}' not found in dictionary at position {i}"
85
- )
86
- else:
87
- raise TypeError(
88
- f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
89
- )
56
+ from edsl.language_models.key_management.KeyLookupCollection import (
57
+ KeyLookupCollection,
58
+ )
90
59
 
91
- current_data = current_data[key]
92
- except Exception as e:
93
- path = " -> ".join(map(str, key_sequence[: i + 1]))
94
- if "error" in data:
95
- msg = data["error"]
96
- else:
97
- msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
98
- raise LanguageModelBadResponseError(message=msg, response_json=data)
99
- if isinstance(current_data, str):
100
- return current_data.strip()
101
- else:
102
- return current_data
60
+ from edsl.language_models.RawResponseHandler import RawResponseHandler
103
61
 
104
62
 
105
63
  def handle_key_error(func):
@@ -117,8 +75,21 @@ def handle_key_error(func):
117
75
  return wrapper
118
76
 
119
77
 
78
+ class classproperty:
79
+ def __init__(self, method):
80
+ self.method = method
81
+
82
+ def __get__(self, instance, cls):
83
+ return self.method(cls)
84
+
85
+
86
+ from edsl.Base import HashingMixin
87
+
88
+
120
89
  class LanguageModel(
121
90
  PersistenceMixin,
91
+ RepresentationMixin,
92
+ HashingMixin,
122
93
  ABC,
123
94
  metaclass=RegisterLanguageModelsMeta,
124
95
  ):
@@ -128,15 +99,22 @@ class LanguageModel(
128
99
  key_sequence = (
129
100
  None # This should be something like ["choices", 0, "message", "content"]
130
101
  )
131
- __rate_limits = None
132
- _safety_factor = 0.8
102
+
103
+ DEFAULT_RPM = 100
104
+ DEFAULT_TPM = 1000
105
+
106
+ @classproperty
107
+ def response_handler(cls):
108
+ key_sequence = cls.key_sequence
109
+ usage_sequence = cls.usage_sequence if hasattr(cls, "usage_sequence") else None
110
+ return RawResponseHandler(key_sequence, usage_sequence)
133
111
 
134
112
  def __init__(
135
113
  self,
136
- tpm: float = None,
137
- rpm: float = None,
114
+ tpm: Optional[float] = None,
115
+ rpm: Optional[float] = None,
138
116
  omit_system_prompt_if_empty_string: bool = True,
139
- key_lookup: Optional[KeyLookup] = None,
117
+ key_lookup: Optional["KeyLookup"] = None,
140
118
  **kwargs,
141
119
  ):
142
120
  """Initialize the LanguageModel."""
@@ -147,7 +125,9 @@ class LanguageModel(
147
125
  self.remote = False
148
126
  self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
149
127
 
150
- # self._rpm / _tpm comes from the class
128
+ self.key_lookup = self._set_key_lookup(key_lookup)
129
+ self.model_info = self.key_lookup.get(self._inference_service_)
130
+
151
131
  if rpm is not None:
152
132
  self._rpm = rpm
153
133
 
@@ -161,49 +141,73 @@ class LanguageModel(
161
141
  if key not in parameters:
162
142
  setattr(self, key, value)
163
143
 
164
- if "use_cache" in kwargs:
165
- warnings.warn(
166
- "The use_cache parameter is deprecated. Use the Cache class instead."
167
- )
168
-
169
- if skip_api_key_check := kwargs.get("skip_api_key_check", False):
144
+ if kwargs.get("skip_api_key_check", False):
170
145
  # Skip the API key check. Sometimes this is useful for testing.
171
146
  self._api_token = None
172
147
 
148
+ def _set_key_lookup(self, key_lookup: "KeyLookup") -> "KeyLookup":
149
+ """Set the key lookup."""
173
150
  if key_lookup is not None:
174
- self.key_lookup = key_lookup
151
+ return key_lookup
175
152
  else:
176
- self.key_lookup = KeyLookup.from_os_environ()
153
+ klc = KeyLookupCollection()
154
+ klc.add_key_lookup(fetch_order=("config", "env"))
155
+ return klc.get(("config", "env"))
156
+
157
+ def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
158
+ del self._api_token
159
+ self.key_lookup = key_lookup
160
+
161
+ def ask_question(self, question: "QuestionBase") -> str:
162
+ """Ask a question and return the response.
177
163
 
178
- def ask_question(self, question):
164
+ :param question: The question to ask.
165
+ """
179
166
  user_prompt = question.get_instructions().render(question.data).text
180
167
  system_prompt = "You are a helpful agent pretending to be a human."
181
168
  return self.execute_model_call(user_prompt, system_prompt)
182
169
 
183
- def set_key_lookup(self, key_lookup: KeyLookup) -> None:
184
- del self._api_token
185
- self.key_lookup = key_lookup
170
+ @property
171
+ def rpm(self):
172
+ if not hasattr(self, "_rpm"):
173
+ if self.model_info is None:
174
+ self._rpm = self.DEFAULT_RPM
175
+ else:
176
+ self._rpm = self.model_info.rpm
177
+ return self._rpm
178
+
179
+ @property
180
+ def tpm(self):
181
+ if not hasattr(self, "_tpm"):
182
+ if self.model_info is None:
183
+ self._tpm = self.DEFAULT_TPM
184
+ else:
185
+ self._tpm = self.model_info.tpm
186
+ return self._tpm
187
+
188
+ # in case we want to override the default values
189
+ @tpm.setter
190
+ def tpm(self, value):
191
+ self._tpm = value
192
+
193
+ @rpm.setter
194
+ def rpm(self, value):
195
+ self._rpm = value
186
196
 
187
197
  @property
188
198
  def api_token(self) -> str:
189
199
  if not hasattr(self, "_api_token"):
190
- self._api_token = self.key_lookup.get_api_token(
191
- self._inference_service_, self.remote
192
- )
200
+ info = self.key_lookup.get(self._inference_service_, None)
201
+ if info is None:
202
+ raise ValueError(
203
+ f"No key found for service '{self._inference_service_}'"
204
+ )
205
+ self._api_token = info.api_token
193
206
  return self._api_token
194
207
 
195
208
  def __getitem__(self, key):
196
209
  return getattr(self, key)
197
210
 
198
- def _repr_html_(self) -> str:
199
- d = {"model": self.model}
200
- d.update(self.parameters)
201
- data = [[k, v] for k, v in d.items()]
202
- from tabulate import tabulate
203
-
204
- table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
205
- return f"<pre>{table}</pre>"
206
-
207
211
  def hello(self, verbose=False):
208
212
  """Runs a simple test to check if the model is working."""
209
213
  token = self.api_token
@@ -232,7 +236,12 @@ class LanguageModel(
232
236
  return key_value is not None
233
237
 
234
238
  def __hash__(self) -> str:
235
- """Allow the model to be used as a key in a dictionary."""
239
+ """Allow the model to be used as a key in a dictionary.
240
+
241
+ >>> m = LanguageModel.example()
242
+ >>> hash(m)
243
+ 1811901442659237949
244
+ """
236
245
  from edsl.utilities.utilities import dict_hash
237
246
 
238
247
  return dict_hash(self.to_dict(add_edsl_version=False))
@@ -248,46 +257,6 @@ class LanguageModel(
248
257
  """
249
258
  return self.model == other.model and self.parameters == other.parameters
250
259
 
251
- def set_rate_limits(self, rpm=None, tpm=None) -> None:
252
- """Set the rate limits for the model.
253
-
254
- >>> m = LanguageModel.example()
255
- >>> m.set_rate_limits(rpm=100, tpm=1000)
256
- >>> m.RPM
257
- 100
258
- """
259
- if rpm is not None:
260
- self._rpm = rpm
261
- if tpm is not None:
262
- self._tpm = tpm
263
- return None
264
-
265
- @property
266
- def RPM(self):
267
- """Model's requests-per-minute limit."""
268
- return self._rpm
269
-
270
- @property
271
- def TPM(self):
272
- """Model's tokens-per-minute limit."""
273
- return self._tpm
274
-
275
- @property
276
- def rpm(self):
277
- return self._rpm
278
-
279
- @rpm.setter
280
- def rpm(self, value):
281
- self._rpm = value
282
-
283
- @property
284
- def tpm(self):
285
- return self._tpm
286
-
287
- @tpm.setter
288
- def tpm(self, value):
289
- self._tpm = value
290
-
291
260
  @staticmethod
292
261
  def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
293
262
  """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
@@ -310,16 +279,7 @@ class LanguageModel(
310
279
 
311
280
  @abstractmethod
312
281
  async def async_execute_model_call(user_prompt: str, system_prompt: str):
313
- """Execute the model call and returns a coroutine.
314
-
315
- >>> m = LanguageModel.example(test_model = True)
316
- >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
317
- >>> asyncio.run(test())
318
- {'message': [{'text': 'Hello world'}], ...}
319
-
320
- >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
321
- {'message': [{'text': 'Hello world'}], ...}
322
- """
282
+ """Execute the model call and returns a coroutine."""
323
283
  pass
324
284
 
325
285
  async def remote_async_execute_model_call(
@@ -336,12 +296,7 @@ class LanguageModel(
336
296
 
337
297
  @jupyter_nb_handler
338
298
  def execute_model_call(self, *args, **kwargs) -> Coroutine:
339
- """Execute the model call and returns the result as a coroutine.
340
-
341
- >>> m = LanguageModel.example(test_model = True)
342
- >>> m.execute_model_call(user_prompt = "Hello, model!", system_prompt = "You are a helpful agent.")
343
-
344
- """
299
+ """Execute the model call and returns the result as a coroutine."""
345
300
 
346
301
  async def main():
347
302
  results = await asyncio.gather(
@@ -353,58 +308,25 @@ class LanguageModel(
353
308
 
354
309
  @classmethod
355
310
  def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
356
- """Return the generated token string from the raw response."""
357
- return extract_item_from_raw_response(raw_response, cls.key_sequence)
311
+ """Return the generated token string from the raw response.
312
+
313
+ >>> m = LanguageModel.example(test_model = True)
314
+ >>> raw_response = m.execute_model_call("Hello, model!", "You are a helpful agent.")
315
+ >>> m.get_generated_token_string(raw_response)
316
+ 'Hello world'
317
+
318
+ """
319
+ return cls.response_handler.get_generated_token_string(raw_response)
358
320
 
359
321
  @classmethod
360
322
  def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
361
323
  """Return the usage dictionary from the raw response."""
362
- if not hasattr(cls, "usage_sequence"):
363
- raise NotImplementedError(
364
- "This inference service does not have a usage_sequence."
365
- )
366
- return extract_item_from_raw_response(raw_response, cls.usage_sequence)
367
-
368
- @staticmethod
369
- def convert_answer(response_part):
370
- import json
371
-
372
- response_part = response_part.strip()
373
-
374
- if response_part == "None":
375
- return None
376
-
377
- repaired = repair_json(response_part)
378
- if repaired == '""':
379
- # it was a literal string
380
- return response_part
381
-
382
- try:
383
- return json.loads(repaired)
384
- except json.JSONDecodeError as j:
385
- # last resort
386
- return response_part
324
+ return cls.response_handler.get_usage_dict(raw_response)
387
325
 
388
326
  @classmethod
389
327
  def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
390
328
  """Parses the API response and returns the response text."""
391
- generated_token_string = cls.get_generated_token_string(raw_response)
392
- last_newline = generated_token_string.rfind("\n")
393
-
394
- if last_newline == -1:
395
- # There is no comment
396
- edsl_dict = {
397
- "answer": cls.convert_answer(generated_token_string),
398
- "generated_tokens": generated_token_string,
399
- "comment": None,
400
- }
401
- else:
402
- edsl_dict = {
403
- "answer": cls.convert_answer(generated_token_string[:last_newline]),
404
- "comment": generated_token_string[last_newline + 1 :].strip(),
405
- "generated_tokens": generated_token_string,
406
- }
407
- return EDSLOutput(**edsl_dict)
329
+ return cls.response_handler.parse_response(raw_response)
408
330
 
409
331
  async def _async_get_intended_model_call_outcome(
410
332
  self,
@@ -421,6 +343,8 @@ class LanguageModel(
421
343
  :param system_prompt: The system's prompt.
422
344
  :param iteration: The iteration number.
423
345
  :param cache: The cache to use.
346
+ :param files_list: The list of files to use.
347
+ :param invigilator: The invigilator to use.
424
348
 
425
349
  If the cache isn't being used, it just returns a 'fresh' call to the LLM.
426
350
  But if cache is being used, it first checks the database to see if the response is already there.
@@ -463,6 +387,10 @@ class LanguageModel(
463
387
  "system_prompt": system_prompt,
464
388
  "files_list": files_list,
465
389
  }
390
+ from edsl.config import CONFIG
391
+
392
+ TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
393
+
466
394
  response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
467
395
  new_cache_key = cache.store(
468
396
  **cache_call_params, response=response
@@ -470,7 +398,6 @@ class LanguageModel(
470
398
  assert new_cache_key == cache_key # should be the same
471
399
 
472
400
  cost = self.cost(response)
473
-
474
401
  return ModelResponse(
475
402
  response=response,
476
403
  cache_used=cache_used,
@@ -509,9 +436,9 @@ class LanguageModel(
509
436
 
510
437
  :param user_prompt: The user's prompt.
511
438
  :param system_prompt: The system's prompt.
512
- :param iteration: The iteration number.
513
439
  :param cache: The cache to use.
514
- :param encoded_image: The encoded image to use.
440
+ :param iteration: The iteration number.
441
+ :param files_list: The list of files to use.
515
442
 
516
443
  """
517
444
  params = {
@@ -525,8 +452,11 @@ class LanguageModel(
525
452
  params.update({"invigilator": kwargs["invigilator"]})
526
453
 
527
454
  model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
528
- model_outputs = await self._async_get_intended_model_call_outcome(**params)
529
- edsl_dict = self.parse_response(model_outputs.response)
455
+ model_outputs: ModelResponse = (
456
+ await self._async_get_intended_model_call_outcome(**params)
457
+ )
458
+ edsl_dict: EDSLOutput = self.parse_response(model_outputs.response)
459
+
530
460
  agent_response_dict = AgentResponseDict(
531
461
  model_inputs=model_inputs,
532
462
  model_outputs=model_outputs,
@@ -537,55 +467,28 @@ class LanguageModel(
537
467
  get_response = sync_wrapper(async_get_response)
538
468
 
539
469
  def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
540
- """Return the dollar cost of a raw response."""
541
-
542
- usage = self.get_usage_dict(raw_response)
543
- from edsl.coop import Coop
470
+ """Return the dollar cost of a raw response.
544
471
 
545
- c = Coop()
546
- price_lookup = c.fetch_prices()
547
- key = (self._inference_service_, self.model)
548
- if key not in price_lookup:
549
- return f"Could not find price for model {self.model} in the price lookup."
550
-
551
- relevant_prices = price_lookup[key]
552
- try:
553
- input_tokens = int(usage[self.input_token_name])
554
- output_tokens = int(usage[self.output_token_name])
555
- except Exception as e:
556
- return f"Could not fetch tokens from model response: {e}"
557
-
558
- try:
559
- inverse_output_price = relevant_prices["output"]["one_usd_buys"]
560
- inverse_input_price = relevant_prices["input"]["one_usd_buys"]
561
- except Exception as e:
562
- if "output" not in relevant_prices:
563
- return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
564
- if "input" not in relevant_prices:
565
- return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
566
- return f"Could not fetch prices from {relevant_prices} - {e}"
567
-
568
- if inverse_input_price == "infinity":
569
- input_cost = 0
570
- else:
571
- try:
572
- input_cost = input_tokens / float(inverse_input_price)
573
- except Exception as e:
574
- return f"Could not compute input price - {e}."
575
-
576
- if inverse_output_price == "infinity":
577
- output_cost = 0
578
- else:
579
- try:
580
- output_cost = output_tokens / float(inverse_output_price)
581
- except Exception as e:
582
- return f"Could not compute output price - {e}"
472
+ :param raw_response: The raw response from the model.
473
+ """
583
474
 
584
- return input_cost + output_cost
475
+ usage = self.get_usage_dict(raw_response)
476
+ from edsl.language_models.PriceManager import PriceManager
477
+
478
+ price_manger = PriceManager()
479
+ return price_manger.calculate_cost(
480
+ inference_service=self._inference_service_,
481
+ model=self.model,
482
+ usage=usage,
483
+ input_token_name=self.input_token_name,
484
+ output_token_name=self.output_token_name,
485
+ )
585
486
 
586
487
  def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
587
488
  """Convert instance to a dictionary
588
489
 
490
+ :param add_edsl_version: Whether to add the EDSL version to the dictionary.
491
+
589
492
  >>> m = LanguageModel.example()
590
493
  >>> m.to_dict()
591
494
  {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
@@ -608,7 +511,7 @@ class LanguageModel(
608
511
  return model_class(**data)
609
512
 
610
513
  def __repr__(self) -> str:
611
- """Return a string representation of the object."""
514
+ """Return a representation of the object."""
612
515
  param_string = ", ".join(
613
516
  f"{key} = {value}" for key, value in self.parameters.items()
614
517
  )
@@ -650,7 +553,7 @@ class LanguageModel(
650
553
  Exception report saved to ...
651
554
  Also see: ...
652
555
  """
653
- from edsl import Model
556
+ from edsl.language_models.registry import Model
654
557
 
655
558
  if test_model:
656
559
  m = Model(
@@ -1,12 +1,12 @@
1
1
  from typing import Optional, List
2
2
  from collections import UserList
3
- from edsl import Model
4
3
 
5
- from edsl.language_models import LanguageModel
6
4
  from edsl.Base import Base
7
- from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
8
- from edsl.utilities.utilities import is_valid_variable_name
9
- from edsl.utilities.utilities import dict_hash
5
+ from edsl.language_models.registry import Model
6
+
7
+ # from edsl.language_models import LanguageModel
8
+ from edsl.utilities.remove_edsl_version import remove_edsl_version
9
+ from edsl.utilities.is_valid_variable_name import is_valid_variable_name
10
10
 
11
11
 
12
12
  class ModelList(Base, UserList):
@@ -40,7 +40,7 @@ class ModelList(Base, UserList):
40
40
  return f"ModelList({super().__repr__()})"
41
41
 
42
42
  def _summary(self):
43
- return {"EDSL Class": "ModelList", "Number of Models": len(self)}
43
+ return {"models": len(self)}
44
44
 
45
45
  def __hash__(self):
46
46
  """Return a hash of the ModelList. This is used for comparison of ModelLists.
@@ -54,7 +54,8 @@ class ModelList(Base, UserList):
54
54
  return dict_hash(self.to_dict(sort=True, add_edsl_version=False))
55
55
 
56
56
  def to_scenario_list(self):
57
- from edsl import ScenarioList, Scenario
57
+ from edsl.scenarios.ScenarioList import ScenarioList
58
+ from edsl.scenarios.Scenario import Scenario
58
59
 
59
60
  sl = ScenarioList()
60
61
  for model in self:
@@ -73,7 +74,7 @@ class ModelList(Base, UserList):
73
74
  pretty_labels: Optional[dict] = None,
74
75
  ):
75
76
  """
76
- >>> ModelList.example().table("model")
77
+ >>> ModelList.example().table('model')
77
78
  model
78
79
  -------
79
80
  gpt-4o
@@ -112,11 +113,6 @@ class ModelList(Base, UserList):
112
113
 
113
114
  return d
114
115
 
115
- def _repr_html_(self):
116
- """Return an HTML representation of the ModelList."""
117
- footer = f"<a href={self.__documentation__}>(docs)</a>"
118
- return str(self.summary(format="html")) + footer
119
-
120
116
  @classmethod
121
117
  def from_names(self, *args, **kwargs):
122
118
  """A a model list from a list of names"""
@@ -133,6 +129,8 @@ class ModelList(Base, UserList):
133
129
  >>> newm = ModelList.from_dict(ModelList.example().to_dict())
134
130
  >>> assert ModelList.example() == newm
135
131
  """
132
+ from edsl.language_models.LanguageModel import LanguageModel
133
+
136
134
  return cls(data=[LanguageModel.from_dict(model) for model in data["models"]])
137
135
 
138
136
  def code(self):