edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (180)
  1. edsl/Base.py +3 -9
  2. edsl/__init__.py +3 -8
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +8 -40
  5. edsl/agents/AgentList.py +0 -43
  6. edsl/agents/Invigilator.py +219 -135
  7. edsl/agents/InvigilatorBase.py +59 -148
  8. edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
  9. edsl/agents/__init__.py +0 -1
  10. edsl/config.py +56 -47
  11. edsl/coop/coop.py +7 -50
  12. edsl/data/Cache.py +1 -35
  13. edsl/data_transfer_models.py +38 -73
  14. edsl/enums.py +0 -4
  15. edsl/exceptions/language_models.py +1 -25
  16. edsl/exceptions/questions.py +5 -62
  17. edsl/exceptions/results.py +0 -4
  18. edsl/inference_services/AnthropicService.py +11 -13
  19. edsl/inference_services/AwsBedrock.py +17 -19
  20. edsl/inference_services/AzureAI.py +20 -37
  21. edsl/inference_services/GoogleService.py +12 -16
  22. edsl/inference_services/GroqService.py +0 -2
  23. edsl/inference_services/InferenceServiceABC.py +3 -58
  24. edsl/inference_services/OpenAIService.py +54 -48
  25. edsl/inference_services/models_available_cache.py +6 -0
  26. edsl/inference_services/registry.py +0 -6
  27. edsl/jobs/Answers.py +12 -10
  28. edsl/jobs/Jobs.py +21 -36
  29. edsl/jobs/buckets/BucketCollection.py +15 -24
  30. edsl/jobs/buckets/TokenBucket.py +14 -93
  31. edsl/jobs/interviews/Interview.py +78 -366
  32. edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
  33. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
  34. edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
  35. edsl/jobs/interviews/retry_management.py +37 -0
  36. edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
  37. edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
  38. edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
  39. edsl/jobs/tasks/TaskHistory.py +213 -148
  40. edsl/language_models/LanguageModel.py +156 -261
  41. edsl/language_models/ModelList.py +2 -2
  42. edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
  43. edsl/language_models/registry.py +6 -23
  44. edsl/language_models/repair.py +19 -0
  45. edsl/prompts/Prompt.py +2 -52
  46. edsl/questions/AnswerValidatorMixin.py +26 -23
  47. edsl/questions/QuestionBase.py +249 -329
  48. edsl/questions/QuestionBudget.py +41 -99
  49. edsl/questions/QuestionCheckBox.py +35 -227
  50. edsl/questions/QuestionExtract.py +27 -98
  51. edsl/questions/QuestionFreeText.py +29 -52
  52. edsl/questions/QuestionFunctional.py +0 -7
  53. edsl/questions/QuestionList.py +22 -141
  54. edsl/questions/QuestionMultipleChoice.py +65 -159
  55. edsl/questions/QuestionNumerical.py +46 -88
  56. edsl/questions/QuestionRank.py +24 -182
  57. edsl/questions/RegisterQuestionsMeta.py +12 -31
  58. edsl/questions/__init__.py +4 -3
  59. edsl/questions/derived/QuestionLikertFive.py +5 -10
  60. edsl/questions/derived/QuestionLinearScale.py +2 -15
  61. edsl/questions/derived/QuestionTopK.py +1 -10
  62. edsl/questions/derived/QuestionYesNo.py +3 -24
  63. edsl/questions/descriptors.py +7 -43
  64. edsl/questions/question_registry.py +2 -6
  65. edsl/results/Dataset.py +0 -20
  66. edsl/results/DatasetExportMixin.py +48 -46
  67. edsl/results/Result.py +5 -32
  68. edsl/results/Results.py +46 -135
  69. edsl/results/ResultsDBMixin.py +3 -3
  70. edsl/scenarios/FileStore.py +10 -71
  71. edsl/scenarios/Scenario.py +25 -96
  72. edsl/scenarios/ScenarioImageMixin.py +2 -2
  73. edsl/scenarios/ScenarioList.py +39 -361
  74. edsl/scenarios/ScenarioListExportMixin.py +0 -9
  75. edsl/scenarios/ScenarioListPdfMixin.py +4 -150
  76. edsl/study/SnapShot.py +1 -8
  77. edsl/study/Study.py +0 -32
  78. edsl/surveys/Rule.py +1 -10
  79. edsl/surveys/RuleCollection.py +5 -21
  80. edsl/surveys/Survey.py +310 -636
  81. edsl/surveys/SurveyExportMixin.py +9 -71
  82. edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
  83. edsl/surveys/SurveyQualtricsImport.py +4 -75
  84. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  85. edsl/utilities/utilities.py +1 -9
  86. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
  87. edsl-0.1.33.dev1.dist-info/RECORD +209 -0
  88. edsl/TemplateLoader.py +0 -24
  89. edsl/auto/AutoStudy.py +0 -117
  90. edsl/auto/StageBase.py +0 -230
  91. edsl/auto/StageGenerateSurvey.py +0 -178
  92. edsl/auto/StageLabelQuestions.py +0 -125
  93. edsl/auto/StagePersona.py +0 -61
  94. edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
  95. edsl/auto/StagePersonaDimensionValues.py +0 -74
  96. edsl/auto/StagePersonaDimensions.py +0 -69
  97. edsl/auto/StageQuestions.py +0 -73
  98. edsl/auto/SurveyCreatorPipeline.py +0 -21
  99. edsl/auto/utilities.py +0 -224
  100. edsl/coop/PriceFetcher.py +0 -58
  101. edsl/inference_services/MistralAIService.py +0 -120
  102. edsl/inference_services/TestService.py +0 -80
  103. edsl/inference_services/TogetherAIService.py +0 -170
  104. edsl/jobs/FailedQuestion.py +0 -78
  105. edsl/jobs/runners/JobsRunnerStatus.py +0 -331
  106. edsl/language_models/fake_openai_call.py +0 -15
  107. edsl/language_models/fake_openai_service.py +0 -61
  108. edsl/language_models/utilities.py +0 -61
  109. edsl/questions/QuestionBaseGenMixin.py +0 -133
  110. edsl/questions/QuestionBasePromptsMixin.py +0 -266
  111. edsl/questions/Quick.py +0 -41
  112. edsl/questions/ResponseValidatorABC.py +0 -170
  113. edsl/questions/decorators.py +0 -21
  114. edsl/questions/prompt_templates/question_budget.jinja +0 -13
  115. edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
  116. edsl/questions/prompt_templates/question_extract.jinja +0 -11
  117. edsl/questions/prompt_templates/question_free_text.jinja +0 -3
  118. edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
  119. edsl/questions/prompt_templates/question_list.jinja +0 -17
  120. edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
  121. edsl/questions/prompt_templates/question_numerical.jinja +0 -37
  122. edsl/questions/templates/__init__.py +0 -0
  123. edsl/questions/templates/budget/__init__.py +0 -0
  124. edsl/questions/templates/budget/answering_instructions.jinja +0 -7
  125. edsl/questions/templates/budget/question_presentation.jinja +0 -7
  126. edsl/questions/templates/checkbox/__init__.py +0 -0
  127. edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
  128. edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
  129. edsl/questions/templates/extract/__init__.py +0 -0
  130. edsl/questions/templates/extract/answering_instructions.jinja +0 -7
  131. edsl/questions/templates/extract/question_presentation.jinja +0 -1
  132. edsl/questions/templates/free_text/__init__.py +0 -0
  133. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  134. edsl/questions/templates/free_text/question_presentation.jinja +0 -1
  135. edsl/questions/templates/likert_five/__init__.py +0 -0
  136. edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
  137. edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
  138. edsl/questions/templates/linear_scale/__init__.py +0 -0
  139. edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
  140. edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
  141. edsl/questions/templates/list/__init__.py +0 -0
  142. edsl/questions/templates/list/answering_instructions.jinja +0 -4
  143. edsl/questions/templates/list/question_presentation.jinja +0 -5
  144. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  145. edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
  146. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  147. edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
  148. edsl/questions/templates/numerical/__init__.py +0 -0
  149. edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
  150. edsl/questions/templates/numerical/question_presentation.jinja +0 -7
  151. edsl/questions/templates/rank/__init__.py +0 -0
  152. edsl/questions/templates/rank/answering_instructions.jinja +0 -11
  153. edsl/questions/templates/rank/question_presentation.jinja +0 -15
  154. edsl/questions/templates/top_k/__init__.py +0 -0
  155. edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
  156. edsl/questions/templates/top_k/question_presentation.jinja +0 -22
  157. edsl/questions/templates/yes_no/__init__.py +0 -0
  158. edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
  159. edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
  160. edsl/results/DatasetTree.py +0 -145
  161. edsl/results/Selector.py +0 -118
  162. edsl/results/tree_explore.py +0 -115
  163. edsl/surveys/instructions/ChangeInstruction.py +0 -47
  164. edsl/surveys/instructions/Instruction.py +0 -34
  165. edsl/surveys/instructions/InstructionCollection.py +0 -77
  166. edsl/surveys/instructions/__init__.py +0 -0
  167. edsl/templates/error_reporting/base.html +0 -24
  168. edsl/templates/error_reporting/exceptions_by_model.html +0 -35
  169. edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
  170. edsl/templates/error_reporting/exceptions_by_type.html +0 -17
  171. edsl/templates/error_reporting/interview_details.html +0 -116
  172. edsl/templates/error_reporting/interviews.html +0 -10
  173. edsl/templates/error_reporting/overview.html +0 -5
  174. edsl/templates/error_reporting/performance_plot.html +0 -2
  175. edsl/templates/error_reporting/report.css +0 -74
  176. edsl/templates/error_reporting/report.html +0 -118
  177. edsl/templates/error_reporting/report.js +0 -25
  178. edsl-0.1.33.dist-info/RECORD +0 -295
  179. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
  180. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
edsl/language_models/LanguageModel.py

@@ -1,16 +1,4 @@
- """This module contains the LanguageModel class, which is an abstract base class for all language models.
-
- Terminology:
-
- raw_response: The JSON response from the model. This has all the model meta-data about the call.
-
- edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
- such as the cache key, token usage, etc.
-
- generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
- edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
-
- """
+ """This module contains the LanguageModel class, which is an abstract base class for all language models."""

  from __future__ import annotations
  import warnings
@@ -20,103 +8,47 @@ import json
  import time
  import os
  import hashlib
- from typing import (
-     Coroutine,
-     Any,
-     Callable,
-     Type,
-     Union,
-     List,
-     get_type_hints,
-     TypedDict,
-     Optional,
- )
+ from typing import Coroutine, Any, Callable, Type, List, get_type_hints
  from abc import ABC, abstractmethod

- from json_repair import repair_json

- from edsl.data_transfer_models import (
-     ModelResponse,
-     ModelInputs,
-     EDSLOutput,
-     AgentResponseDict,
- )
+ class IntendedModelCallOutcome:
+     "This is a tuple-like class that holds the response, cache_used, and cache_key."
+
+     def __init__(self, response: dict, cache_used: bool, cache_key: str):
+         self.response = response
+         self.cache_used = cache_used
+         self.cache_key = cache_key
+
+     def __iter__(self):
+         """Iterate over the class attributes.
+
+         >>> a, b, c = IntendedModelCallOutcome({'answer': "yes"}, True, 'x1289')
+         >>> a
+         {'answer': 'yes'}
+         """
+         yield self.response
+         yield self.cache_used
+         yield self.cache_key
+
+     def __len__(self):
+         return 3
+
+     def __repr__(self):
+         return f"IntendedModelCallOutcome(response = {self.response}, cache_used = {self.cache_used}, cache_key = '{self.cache_key}')"


  from edsl.config import CONFIG
+
  from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
+
  from edsl.language_models.repair import repair
  from edsl.enums import InferenceServiceType
  from edsl.Base import RichPrintingMixin, PersistenceMixin
  from edsl.enums import service_to_api_keyname
  from edsl.exceptions import MissingAPIKeyError
  from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
- from edsl.exceptions.language_models import LanguageModelBadResponseError
-
- TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
-
-
- def convert_answer(response_part):
-     import json
-
-     response_part = response_part.strip()
-
-     if response_part == "None":
-         return None
-
-     repaired = repair_json(response_part)
-     if repaired == '""':
-         # it was a literal string
-         return response_part
-
-     try:
-         return json.loads(repaired)
-     except json.JSONDecodeError as j:
-         # last resort
-         return response_part
-
-
- def extract_item_from_raw_response(data, key_sequence):
-     if isinstance(data, str):
-         try:
-             data = json.loads(data)
-         except json.JSONDecodeError as e:
-             return data
-     current_data = data
-     for i, key in enumerate(key_sequence):
-         try:
-             if isinstance(current_data, (list, tuple)):
-                 if not isinstance(key, int):
-                     raise TypeError(
-                         f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
-                     )
-                 if key < 0 or key >= len(current_data):
-                     raise IndexError(
-                         f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
-                     )
-             elif isinstance(current_data, dict):
-                 if key not in current_data:
-                     raise KeyError(
-                         f"Key '{key}' not found in dictionary at position {i}"
-                     )
-             else:
-                 raise TypeError(
-                     f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
-                 )
-
-             current_data = current_data[key]
-         except Exception as e:
-             path = " -> ".join(map(str, key_sequence[: i + 1]))
-             if "error" in data:
-                 msg = data["error"]
-             else:
-                 msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
-             raise LanguageModelBadResponseError(message=msg, response_json=data)
-     if isinstance(current_data, str):
-         return current_data.strip()
-     else:
-         return current_data


  def handle_key_error(func):
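Note: `IntendedModelCallOutcome`, added on the dev1 side, stands in for the `ModelResponse`/`AgentResponseDict` transfer models removed above. Because it defines `__iter__` and `__len__`, it unpacks like a 3-tuple; a minimal sketch of that usage, following the class's own doctest:

    >>> outcome = IntendedModelCallOutcome({"answer": "yes"}, True, "x1289")
    >>> response, cache_used, cache_key = outcome
    >>> cache_used
    True
    >>> len(outcome)
    3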
@@ -160,29 +92,21 @@ class LanguageModel(
      """

      _model_ = None
-     key_sequence = (
-         None  # This should be something like ["choices", 0, "message", "content"]
-     )
+
      __rate_limits = None
+     __default_rate_limits = {
+         "rpm": 10_000,
+         "tpm": 2_000_000,
+     }  # TODO: Use the OpenAI Teir 1 rate limits
      _safety_factor = 0.8

-     def __init__(
-         self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
-     ):
+     def __init__(self, **kwargs):
          """Initialize the LanguageModel."""
          self.model = getattr(self, "_model_", None)
          default_parameters = getattr(self, "_parameters_", None)
          parameters = self._overide_default_parameters(kwargs, default_parameters)
          self.parameters = parameters
          self.remote = False
-         self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
-
-         # self._rpm / _tpm comes from the class
-         if rpm is not None:
-             self._rpm = rpm
-
-         if tpm is not None:
-             self._tpm = tpm

          for key, value in parameters.items():
              setattr(self, key, value)
@@ -209,6 +133,7 @@ class LanguageModel(
      def api_token(self) -> str:
          if not hasattr(self, "_api_token"):
              key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
+
              if self._inference_service_ == "bedrock":
                  self._api_token = [os.getenv(key_name[0]), os.getenv(key_name[1])]
                  # Check if any of the tokens are None
@@ -217,13 +142,13 @@
                  self._api_token = os.getenv(key_name)
                  missing_token = self._api_token is None
              if missing_token and self._inference_service_ != "test" and not self.remote:
-                 print("raising error")
+                 print("rainsing error")
                  raise MissingAPIKeyError(
                      f"""The key for service: `{self._inference_service_}` is not set.
                      Need a key with name {key_name} in your .env file."""
                  )

-         return self._api_token
+         return self._api_token

      def __getitem__(self, key):
          return getattr(self, key)
@@ -284,58 +209,40 @@
          >>> m = LanguageModel.example()
          >>> m.set_rate_limits(rpm=100, tpm=1000)
          >>> m.RPM
-         100
+         80.0
          """
-         if rpm is not None:
-             self._rpm = rpm
-         if tpm is not None:
-             self._tpm = tpm
-         return None
-         # self._set_rate_limits(rpm=rpm, tpm=tpm)
-
-     # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
-     #     """Set the rate limits for the model.
-     #
-     #     If the model does not have rate limits, use the default rate limits."""
-     #     if rpm is not None and tpm is not None:
-     #         self.__rate_limits = {"rpm": rpm, "tpm": tpm}
-     #         return
-     #
-     #     if self.__rate_limits is None:
-     #         if hasattr(self, "get_rate_limits"):
-     #             self.__rate_limits = self.get_rate_limits()
-     #         else:
-     #             self.__rate_limits = self.__default_rate_limits
+         self._set_rate_limits(rpm=rpm, tpm=tpm)
+
+     def _set_rate_limits(self, rpm=None, tpm=None) -> None:
+         """Set the rate limits for the model.
+
+         If the model does not have rate limits, use the default rate limits."""
+         if rpm is not None and tpm is not None:
+             self.__rate_limits = {"rpm": rpm, "tpm": tpm}
+             return
+
+         if self.__rate_limits is None:
+             if hasattr(self, "get_rate_limits"):
+                 self.__rate_limits = self.get_rate_limits()
+             else:
+                 self.__rate_limits = self.__default_rate_limits

      @property
      def RPM(self):
          """Model's requests-per-minute limit."""
-         # self._set_rate_limits()
-         # return self._safety_factor * self.__rate_limits["rpm"]
-         return self._rpm
+         self._set_rate_limits()
+         return self._safety_factor * self.__rate_limits["rpm"]

      @property
      def TPM(self):
-         """Model's tokens-per-minute limit."""
-         # self._set_rate_limits()
-         # return self._safety_factor * self.__rate_limits["tpm"]
-         return self._tpm
-
-     @property
-     def rpm(self):
-         return self._rpm
+         """Model's tokens-per-minute limit.

-     @rpm.setter
-     def rpm(self, value):
-         self._rpm = value
-
-     @property
-     def tpm(self):
-         return self._tpm
-
-     @tpm.setter
-     def tpm(self, value):
-         self._tpm = value
+         >>> m = LanguageModel.example()
+         >>> m.TPM > 0
+         True
+         """
+         self._set_rate_limits()
+         return self._safety_factor * self.__rate_limits["tpm"]

      @staticmethod
      def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
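Note: the doctest change above (`100` → `80.0`) follows from the dev1 implementation: `set_rate_limits(rpm=100, tpm=1000)` stores the raw limits, and the `RPM`/`TPM` properties scale them by `_safety_factor = 0.8` on read. A quick sketch of the arithmetic (the `TPM` value follows from the same code path):

    >>> m = LanguageModel.example()
    >>> m.set_rate_limits(rpm=100, tpm=1000)
    >>> m.RPM    # 0.8 * 100
    80.0
    >>> m.TPM    # 0.8 * 1000
    800.0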
@@ -363,10 +270,11 @@ class LanguageModel(
          >>> m = LanguageModel.example(test_model = True)
          >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
          >>> asyncio.run(test())
-         {'message': [{'text': 'Hello world'}], ...}
+         {'message': '{"answer": "Hello world"}'}

          >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
-         {'message': [{'text': 'Hello world'}], ...}
+         {'message': '{"answer": "Hello world"}'}
+
          """
          pass

@@ -399,40 +307,68 @@ class LanguageModel(

          return main()

-     @classmethod
-     def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
-         """Return the generated token string from the raw response."""
-         return extract_item_from_raw_response(raw_response, cls.key_sequence)
+     @abstractmethod
+     def parse_response(raw_response: dict[str, Any]) -> str:
+         """Parse the response and returns the response text.

-     @classmethod
-     def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
-         """Return the usage dictionary from the raw response."""
-         if not hasattr(cls, "usage_sequence"):
-             raise NotImplementedError(
-                 "This inference service does not have a usage_sequence."
+         >>> m = LanguageModel.example(test_model = True)
+         >>> m
+         Model(model_name = 'test', temperature = 0.5)
+
+         What is returned by the API is model-specific and often includes meta-data that we do not need.
+         For example, here is the results from a call to GPT-4:
+         To actually track the response, we need to grab
+         data["choices[0]"]["message"]["content"].
+         """
+         raise NotImplementedError
+
+     async def _async_prepare_response(
+         self, model_call_outcome: IntendedModelCallOutcome, cache: "Cache"
+     ) -> dict:
+         """Prepare the response for return."""
+
+         model_response = {
+             "cache_used": model_call_outcome.cache_used,
+             "cache_key": model_call_outcome.cache_key,
+             "usage": model_call_outcome.response.get("usage", {}),
+             "raw_model_response": model_call_outcome.response,
+         }
+
+         answer_portion = self.parse_response(model_call_outcome.response)
+         try:
+             answer_dict = json.loads(answer_portion)
+         except json.JSONDecodeError as e:
+             # TODO: Turn into logs to generate issues
+             answer_dict, success = await repair(
+                 bad_json=answer_portion, error_message=str(e), cache=cache
              )
-         return extract_item_from_raw_response(raw_response, cls.usage_sequence)
+             if not success:
+                 raise Exception(
+                     f"""Even the repair failed. The error was: {e}. The response was: {answer_portion}."""
+                 )

-     @classmethod
-     def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
-         """Parses the API response and returns the response text."""
-         generated_token_string = cls.get_generated_token_string(raw_response)
-         last_newline = generated_token_string.rfind("\n")
-
-         if last_newline == -1:
-             # There is no comment
-             edsl_dict = {
-                 "answer": convert_answer(generated_token_string),
-                 "generated_tokens": generated_token_string,
-                 "comment": None,
-             }
-         else:
-             edsl_dict = {
-                 "answer": convert_answer(generated_token_string[:last_newline]),
-                 "comment": generated_token_string[last_newline + 1 :].strip(),
-                 "generated_tokens": generated_token_string,
-             }
-         return EDSLOutput(**edsl_dict)
+         return {**model_response, **answer_dict}
+
+     async def async_get_raw_response(
+         self,
+         user_prompt: str,
+         system_prompt: str,
+         cache: "Cache",
+         iteration: int = 0,
+         encoded_image=None,
+     ) -> IntendedModelCallOutcome:
+         import warnings
+
+         warnings.warn(
+             "This method is deprecated. Use async_get_intended_model_call_outcome."
+         )
+         return await self._async_get_intended_model_call_outcome(
+             user_prompt=user_prompt,
+             system_prompt=system_prompt,
+             cache=cache,
+             iteration=iteration,
+             encoded_image=encoded_image,
+         )

      async def _async_get_intended_model_call_outcome(
          self,
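Note: on the dev1 side, `parse_response` is an abstract hook that each concrete model implements to pull the answer string out of its service-specific payload; `_async_prepare_response` then tries `json.loads` on that string and falls back to `repair(...)` on a `JSONDecodeError`. A minimal conforming subclass, modeled on the `TestLanguageModelGood` class further down this diff (the class name here is illustrative, not part of the package):

    class EchoModel(LanguageModel):  # hypothetical example subclass
        _model_ = "echo"
        _parameters_ = {"temperature": 0.5}
        _inference_service_ = InferenceServiceType.TEST.value

        async def async_execute_model_call(
            self, user_prompt: str, system_prompt: str
        ) -> dict[str, Any]:
            # canned payload in the same shape the test service returns
            return {"message": '{"answer": "Hello world"}'}

        def parse_response(self, raw_response: dict[str, Any]) -> str:
            # the answer string lives under the "message" key for this service
            return raw_response["message"]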
@@ -441,7 +377,7 @@ class LanguageModel(
          cache: "Cache",
          iteration: int = 0,
          encoded_image=None,
-     ) -> ModelResponse:
+     ) -> IntendedModelCallOutcome:
          """Handle caching of responses.

          :param user_prompt: The user's prompt.
@@ -460,18 +396,18 @@ class LanguageModel(
          >>> from edsl import Cache
          >>> m = LanguageModel.example(test_model = True)
          >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
-         ModelResponse(...)"""
+         IntendedModelCallOutcome(response = {'message': '{"answer": "Hello world"}'}, cache_used = False, cache_key = '24ff6ac2bc2f1729f817f261e0792577')
+         """

          if encoded_image:
              # the image has is appended to the user_prompt for hash-lookup purposes
              image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
-             user_prompt += f" {image_hash}"

          cache_call_params = {
              "model": str(self.model),
              "parameters": self.parameters,
              "system_prompt": system_prompt,
-             "user_prompt": user_prompt,
+             "user_prompt": user_prompt + "" if not encoded_image else f" {image_hash}",
              "iteration": iteration,
          }
          cached_response, cache_key = cache.fetch(**cache_call_params)
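Note: the dev1 line builds the cache key with a conditional expression, and Python's ternary binds looser than `+`, so it parses as `(user_prompt + "") if not encoded_image else f" {image_hash}"`: when an image is attached, only the image hash, not the prompt text, enters the cache key. The 0.1.33 side instead appends the hash to the prompt before building the key. A small demonstration of the precedence (values are illustrative):

    >>> user_prompt, image_hash, encoded_image = "Hello", "abc123", True
    >>> user_prompt + "" if not encoded_image else f" {image_hash}"
    ' abc123'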
@@ -489,28 +425,21 @@ class LanguageModel(
              "system_prompt": system_prompt,
              **({"encoded_image": encoded_image} if encoded_image else {}),
          }
-         # response = await f(**params)
-         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
+         response = await f(**params)
          new_cache_key = cache.store(
              **cache_call_params, response=response
          )  # store the response in the cache
          assert new_cache_key == cache_key  # should be the same

-         cost = self.cost(response)
-
-         return ModelResponse(
-             response=response,
-             cache_used=cache_used,
-             cache_key=cache_key,
-             cached_response=cached_response,
-             cost=cost,
+         return IntendedModelCallOutcome(
+             response=response, cache_used=cache_used, cache_key=cache_key
          )

      _get_intended_model_call_outcome = sync_wrapper(
          _async_get_intended_model_call_outcome
      )

-     # get_raw_response = sync_wrapper(async_get_raw_response)
+     get_raw_response = sync_wrapper(async_get_raw_response)

      def simple_ask(
          self,
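Note: this hunk relies on a cache contract in which `Cache.fetch` returns a `(cached_response, cache_key)` pair derived from the call parameters, and `Cache.store` must compute the identical key from the same parameters — which is what `assert new_cache_key == cache_key` checks. A hedged sketch of that round trip, assuming a miss yields a falsy `cached_response` (parameter values are illustrative):

    cache_call_params = {
        "model": "test",
        "parameters": {"temperature": 0.5},
        "system_prompt": "You are a helpful agent.",
        "user_prompt": "Hello, model!",
        "iteration": 0,
    }
    cached_response, cache_key = cache.fetch(**cache_call_params)
    if not cached_response:  # miss: call the model, then persist the response
        new_cache_key = cache.store(**cache_call_params, response=response)
        assert new_cache_key == cache_key  # same inputs -> same key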
@@ -549,66 +478,14 @@ class LanguageModel(
              "cache": cache,
              **({"encoded_image": encoded_image} if encoded_image else {}),
          }
-         model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
-         model_outputs = await self._async_get_intended_model_call_outcome(**params)
-         edsl_dict = self.parse_response(model_outputs.response)
-         agent_response_dict = AgentResponseDict(
-             model_inputs=model_inputs,
-             model_outputs=model_outputs,
-             edsl_dict=edsl_dict,
-         )
-         return agent_response_dict
-
-         # return await self._async_prepare_response(model_call_outcome, cache=cache)
+         model_call_outcome = await self._async_get_intended_model_call_outcome(**params)
+         return await self._async_prepare_response(model_call_outcome, cache=cache)

      get_response = sync_wrapper(async_get_response)

-     def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
+     def cost(self, raw_response: dict[str, Any]) -> float:
          """Return the dollar cost of a raw response."""
-
-         usage = self.get_usage_dict(raw_response)
-         from edsl.coop import Coop
-
-         c = Coop()
-         price_lookup = c.fetch_prices()
-         key = (self._inference_service_, self.model)
-         if key not in price_lookup:
-             return f"Could not find price for model {self.model} in the price lookup."
-
-         relevant_prices = price_lookup[key]
-         try:
-             input_tokens = int(usage[self.input_token_name])
-             output_tokens = int(usage[self.output_token_name])
-         except Exception as e:
-             return f"Could not fetch tokens from model response: {e}"
-
-         try:
-             inverse_output_price = relevant_prices["output"]["one_usd_buys"]
-             inverse_input_price = relevant_prices["input"]["one_usd_buys"]
-         except Exception as e:
-             if "output" not in relevant_prices:
-                 return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
-             if "input" not in relevant_prices:
-                 return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
-             return f"Could not fetch prices from {relevant_prices} - {e}"
-
-         if inverse_input_price == "infinity":
-             input_cost = 0
-         else:
-             try:
-                 input_cost = input_tokens / float(inverse_input_price)
-             except Exception as e:
-                 return f"Could not compute input price - {e}."
-
-         if inverse_output_price == "infinity":
-             output_cost = 0
-         else:
-             try:
-                 output_cost = output_tokens / float(inverse_output_price)
-             except Exception as e:
-                 return f"Could not compute output price - {e}"
-
-         return input_cost + output_cost
+         raise NotImplementedError

      #######################
      # SERIALIZATION METHODS
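Note: the `cost` method removed here (present in 0.1.33) prices a call from inverse prices: `one_usd_buys` is the number of tokens one dollar buys, so each side costs `tokens / one_usd_buys` and the total is their sum. A worked example with hypothetical prices:

    >>> input_tokens, output_tokens = 1_000, 2_000
    >>> one_usd_buys_input, one_usd_buys_output = 100_000.0, 50_000.0  # hypothetical
    >>> input_tokens / one_usd_buys_input + output_tokens / one_usd_buys_output
    0.05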
@@ -622,7 +499,7 @@ class LanguageModel(
          >>> m = LanguageModel.example()
          >>> m.to_dict()
-         {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
+         {'model': 'gpt-4-1106-preview', 'parameters': {'temperature': 0.5, 'max_tokens': 1000, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'logprobs': False, 'top_logprobs': 3}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
          """
          return self._to_dict()

@@ -698,8 +575,26 @@ class LanguageModel(
          """
          from edsl import Model

+         class TestLanguageModelGood(LanguageModel):
+             use_cache = False
+             _model_ = "test"
+             _parameters_ = {"temperature": 0.5}
+             _inference_service_ = InferenceServiceType.TEST.value
+
+             async def async_execute_model_call(
+                 self, user_prompt: str, system_prompt: str
+             ) -> dict[str, Any]:
+                 await asyncio.sleep(0.1)
+                 # return {"message": """{"answer": "Hello, world"}"""}
+                 if throw_exception:
+                     raise Exception("This is a test error")
+                 return {"message": f'{{"answer": "{canned_response}"}}'}
+
+             def parse_response(self, raw_response: dict[str, Any]) -> str:
+                 return raw_response["message"]
+
          if test_model:
-             m = Model("test", canned_response=canned_response)
+             m = TestLanguageModelGood()
              return m
          else:
              return Model(skip_api_key_check=True)
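Note: with this change, `LanguageModel.example(test_model=True)` on the dev1 side returns the inline `TestLanguageModelGood` rather than `Model("test", canned_response=...)`. A short sketch of what that example model does, following its body above:

    >>> m = LanguageModel.example(test_model=True)
    >>> m.parse_response({"message": '{"answer": "Hello world"}'})
    '{"answer": "Hello world"}'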
edsl/language_models/ModelList.py

@@ -40,8 +40,8 @@ class ModelList(Base, UserList):
      def __hash__(self):
          """Return a hash of the ModelList. This is used for comparison of ModelLists.

-         >>> isinstance(hash(Model()), int)
-         True
+         >>> hash(ModelList.example())
+         1423518243781418961

          """
          from edsl.utilities.utilities import dict_hash
edsl/language_models/RegisterLanguageModelsMeta.py

@@ -47,6 +47,13 @@ class RegisterLanguageModelsMeta(ABCMeta):
              must_be_async=True,
          )
          # LanguageModel children have to implement the parse_response method
+         RegisterLanguageModelsMeta.verify_method(
+             candidate_class=cls,
+             method_name="parse_response",
+             expected_return_type=str,
+             required_parameters=[("raw_response", dict[str, Any])],
+             must_be_async=False,
+         )
          RegisterLanguageModelsMeta._registry[model_name] = cls

      @classmethod
@@ -91,7 +98,7 @@ class RegisterLanguageModelsMeta(ABCMeta):

          required_parameters = required_parameters or []
          method = getattr(candidate_class, method_name)
-         # signature = inspect.signature(method)
+         signature = inspect.signature(method)

          RegisterLanguageModelsMeta._check_return_type(method, expected_return_type)

@@ -99,11 +106,11 @@
              RegisterLanguageModelsMeta._check_is_coroutine(method)

          # Check the parameters
-         # params = signature.parameters
-         # for param_name, param_type in required_parameters:
-         #     RegisterLanguageModelsMeta._verify_parameter(
-         #         params, param_name, param_type, method_name
-         #     )
+         params = signature.parameters
+         for param_name, param_type in required_parameters:
+             RegisterLanguageModelsMeta._verify_parameter(
+                 params, param_name, param_type, method_name
+             )

      @staticmethod
      def _check_method_defined(cls, method_name):
@@ -160,15 +167,23 @@
          Check if the return type of a method is as expected.

          Example:
+         >>> class M:
+         ...     async def f(self) -> str: pass
+         >>> RegisterLanguageModelsMeta._check_return_type(M.f, str)
+         >>> class N:
+         ...     async def f(self) -> int: pass
+         >>> RegisterLanguageModelsMeta._check_return_type(N.f, str)
+         Traceback (most recent call last):
+         ...
+         TypeError: Return type of f must be <class 'str'>. Got <class 'int'>.
          """
-         pass
-         # if inspect.isroutine(method):
-         #     # return_type = inspect.signature(method).return_annotation
-         #     return_type = get_type_hints(method)["return"]
-         #     if return_type != expected_return_type:
-         #         raise TypeError(
-         #             f"Return type of {method.__name__} must be {expected_return_type}. Got {return_type}."
-         #         )
+         if inspect.isroutine(method):
+             # return_type = inspect.signature(method).return_annotation
+             return_type = get_type_hints(method)["return"]
+             if return_type != expected_return_type:
+                 raise TypeError(
+                     f"Return type of {method.__name__} must be {expected_return_type}. Got {return_type}."
+                 )

      @classmethod
      def model_names_to_classes(cls):