edsl 0.1.32__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. edsl/Base.py +9 -3
  2. edsl/TemplateLoader.py +24 -0
  3. edsl/__init__.py +8 -3
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +40 -8
  6. edsl/agents/AgentList.py +43 -0
  7. edsl/agents/Invigilator.py +135 -219
  8. edsl/agents/InvigilatorBase.py +148 -59
  9. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +138 -89
  10. edsl/agents/__init__.py +1 -0
  11. edsl/auto/AutoStudy.py +117 -0
  12. edsl/auto/StageBase.py +230 -0
  13. edsl/auto/StageGenerateSurvey.py +178 -0
  14. edsl/auto/StageLabelQuestions.py +125 -0
  15. edsl/auto/StagePersona.py +61 -0
  16. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  17. edsl/auto/StagePersonaDimensionValues.py +74 -0
  18. edsl/auto/StagePersonaDimensions.py +69 -0
  19. edsl/auto/StageQuestions.py +73 -0
  20. edsl/auto/SurveyCreatorPipeline.py +21 -0
  21. edsl/auto/utilities.py +224 -0
  22. edsl/config.py +47 -56
  23. edsl/coop/PriceFetcher.py +58 -0
  24. edsl/coop/coop.py +50 -7
  25. edsl/data/Cache.py +35 -1
  26. edsl/data_transfer_models.py +73 -38
  27. edsl/enums.py +4 -0
  28. edsl/exceptions/language_models.py +25 -1
  29. edsl/exceptions/questions.py +62 -5
  30. edsl/exceptions/results.py +4 -0
  31. edsl/inference_services/AnthropicService.py +13 -11
  32. edsl/inference_services/AwsBedrock.py +19 -17
  33. edsl/inference_services/AzureAI.py +37 -20
  34. edsl/inference_services/GoogleService.py +16 -12
  35. edsl/inference_services/GroqService.py +2 -0
  36. edsl/inference_services/InferenceServiceABC.py +58 -3
  37. edsl/inference_services/MistralAIService.py +120 -0
  38. edsl/inference_services/OpenAIService.py +48 -54
  39. edsl/inference_services/TestService.py +80 -0
  40. edsl/inference_services/TogetherAIService.py +170 -0
  41. edsl/inference_services/models_available_cache.py +0 -6
  42. edsl/inference_services/registry.py +6 -0
  43. edsl/jobs/Answers.py +10 -12
  44. edsl/jobs/FailedQuestion.py +78 -0
  45. edsl/jobs/Jobs.py +37 -22
  46. edsl/jobs/buckets/BucketCollection.py +24 -15
  47. edsl/jobs/buckets/TokenBucket.py +93 -14
  48. edsl/jobs/interviews/Interview.py +366 -78
  49. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +14 -68
  50. edsl/jobs/interviews/InterviewExceptionEntry.py +85 -19
  51. edsl/jobs/runners/JobsRunnerAsyncio.py +146 -175
  52. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  53. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  54. edsl/jobs/tasks/TaskHistory.py +148 -213
  55. edsl/language_models/LanguageModel.py +261 -156
  56. edsl/language_models/ModelList.py +2 -2
  57. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  58. edsl/language_models/fake_openai_call.py +15 -0
  59. edsl/language_models/fake_openai_service.py +61 -0
  60. edsl/language_models/registry.py +23 -6
  61. edsl/language_models/repair.py +0 -19
  62. edsl/language_models/utilities.py +61 -0
  63. edsl/notebooks/Notebook.py +20 -2
  64. edsl/prompts/Prompt.py +52 -2
  65. edsl/questions/AnswerValidatorMixin.py +23 -26
  66. edsl/questions/QuestionBase.py +330 -249
  67. edsl/questions/QuestionBaseGenMixin.py +133 -0
  68. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  69. edsl/questions/QuestionBudget.py +99 -41
  70. edsl/questions/QuestionCheckBox.py +227 -35
  71. edsl/questions/QuestionExtract.py +98 -27
  72. edsl/questions/QuestionFreeText.py +52 -29
  73. edsl/questions/QuestionFunctional.py +7 -0
  74. edsl/questions/QuestionList.py +141 -22
  75. edsl/questions/QuestionMultipleChoice.py +159 -65
  76. edsl/questions/QuestionNumerical.py +88 -46
  77. edsl/questions/QuestionRank.py +182 -24
  78. edsl/questions/Quick.py +41 -0
  79. edsl/questions/RegisterQuestionsMeta.py +31 -12
  80. edsl/questions/ResponseValidatorABC.py +170 -0
  81. edsl/questions/__init__.py +3 -4
  82. edsl/questions/decorators.py +21 -0
  83. edsl/questions/derived/QuestionLikertFive.py +10 -5
  84. edsl/questions/derived/QuestionLinearScale.py +15 -2
  85. edsl/questions/derived/QuestionTopK.py +10 -1
  86. edsl/questions/derived/QuestionYesNo.py +24 -3
  87. edsl/questions/descriptors.py +43 -7
  88. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  89. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  90. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  91. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  92. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  93. edsl/questions/prompt_templates/question_list.jinja +17 -0
  94. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  95. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  96. edsl/questions/question_registry.py +6 -2
  97. edsl/questions/templates/__init__.py +0 -0
  98. edsl/questions/templates/budget/__init__.py +0 -0
  99. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  100. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  101. edsl/questions/templates/checkbox/__init__.py +0 -0
  102. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  103. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  104. edsl/questions/templates/extract/__init__.py +0 -0
  105. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  106. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  107. edsl/questions/templates/free_text/__init__.py +0 -0
  108. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  109. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  110. edsl/questions/templates/likert_five/__init__.py +0 -0
  111. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  112. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  113. edsl/questions/templates/linear_scale/__init__.py +0 -0
  114. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  115. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  116. edsl/questions/templates/list/__init__.py +0 -0
  117. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  118. edsl/questions/templates/list/question_presentation.jinja +5 -0
  119. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  120. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  121. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  122. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  123. edsl/questions/templates/numerical/__init__.py +0 -0
  124. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  125. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  126. edsl/questions/templates/rank/__init__.py +0 -0
  127. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  128. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  129. edsl/questions/templates/top_k/__init__.py +0 -0
  130. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  131. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  132. edsl/questions/templates/yes_no/__init__.py +0 -0
  133. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  134. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  135. edsl/results/Dataset.py +20 -0
  136. edsl/results/DatasetExportMixin.py +46 -48
  137. edsl/results/DatasetTree.py +145 -0
  138. edsl/results/Result.py +32 -5
  139. edsl/results/Results.py +135 -46
  140. edsl/results/ResultsDBMixin.py +3 -3
  141. edsl/results/Selector.py +118 -0
  142. edsl/results/tree_explore.py +115 -0
  143. edsl/scenarios/FileStore.py +71 -10
  144. edsl/scenarios/Scenario.py +96 -25
  145. edsl/scenarios/ScenarioImageMixin.py +2 -2
  146. edsl/scenarios/ScenarioList.py +361 -39
  147. edsl/scenarios/ScenarioListExportMixin.py +9 -0
  148. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  149. edsl/study/SnapShot.py +8 -1
  150. edsl/study/Study.py +32 -0
  151. edsl/surveys/Rule.py +10 -1
  152. edsl/surveys/RuleCollection.py +21 -5
  153. edsl/surveys/Survey.py +637 -311
  154. edsl/surveys/SurveyExportMixin.py +71 -9
  155. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  156. edsl/surveys/SurveyQualtricsImport.py +75 -4
  157. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  158. edsl/surveys/instructions/Instruction.py +34 -0
  159. edsl/surveys/instructions/InstructionCollection.py +77 -0
  160. edsl/surveys/instructions/__init__.py +0 -0
  161. edsl/templates/error_reporting/base.html +24 -0
  162. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  163. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  164. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  165. edsl/templates/error_reporting/interview_details.html +116 -0
  166. edsl/templates/error_reporting/interviews.html +10 -0
  167. edsl/templates/error_reporting/overview.html +5 -0
  168. edsl/templates/error_reporting/performance_plot.html +2 -0
  169. edsl/templates/error_reporting/report.css +74 -0
  170. edsl/templates/error_reporting/report.html +118 -0
  171. edsl/templates/error_reporting/report.js +25 -0
  172. edsl/utilities/utilities.py +9 -1
  173. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/METADATA +5 -2
  174. edsl-0.1.33.dist-info/RECORD +295 -0
  175. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
  176. edsl/jobs/interviews/retry_management.py +0 -37
  177. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  178. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  179. edsl-0.1.32.dist-info/RECORD +0 -209
  180. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
  181. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
@@ -1,4 +1,16 @@
1
- """This module contains the LanguageModel class, which is an abstract base class for all language models."""
1
+ """This module contains the LanguageModel class, which is an abstract base class for all language models.
2
+
3
+ Terminology:
4
+
5
+ raw_response: The JSON response from the model. This has all the model meta-data about the call.
6
+
7
+ edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
8
+ such as the cache key, token usage, etc.
9
+
10
+ generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
11
+ edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
12
+
13
+ """
2
14
 
3
15
  from __future__ import annotations
4
16
  import warnings
@@ -8,47 +20,103 @@ import json
8
20
  import time
9
21
  import os
10
22
  import hashlib
11
- from typing import Coroutine, Any, Callable, Type, List, get_type_hints
23
+ from typing import (
24
+ Coroutine,
25
+ Any,
26
+ Callable,
27
+ Type,
28
+ Union,
29
+ List,
30
+ get_type_hints,
31
+ TypedDict,
32
+ Optional,
33
+ )
12
34
  from abc import ABC, abstractmethod
13
35
 
36
+ from json_repair import repair_json
14
37
 
15
- class IntendedModelCallOutcome:
16
- "This is a tuple-like class that holds the response, cache_used, and cache_key."
17
-
18
- def __init__(self, response: dict, cache_used: bool, cache_key: str):
19
- self.response = response
20
- self.cache_used = cache_used
21
- self.cache_key = cache_key
22
-
23
- def __iter__(self):
24
- """Iterate over the class attributes.
25
-
26
- >>> a, b, c = IntendedModelCallOutcome({'answer': "yes"}, True, 'x1289')
27
- >>> a
28
- {'answer': 'yes'}
29
- """
30
- yield self.response
31
- yield self.cache_used
32
- yield self.cache_key
33
-
34
- def __len__(self):
35
- return 3
36
-
37
- def __repr__(self):
38
- return f"IntendedModelCallOutcome(response = {self.response}, cache_used = {self.cache_used}, cache_key = '{self.cache_key}')"
38
+ from edsl.data_transfer_models import (
39
+ ModelResponse,
40
+ ModelInputs,
41
+ EDSLOutput,
42
+ AgentResponseDict,
43
+ )
39
44
 
40
45
 
41
46
  from edsl.config import CONFIG
42
-
43
47
  from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
44
48
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
45
-
46
49
  from edsl.language_models.repair import repair
47
50
  from edsl.enums import InferenceServiceType
48
51
  from edsl.Base import RichPrintingMixin, PersistenceMixin
49
52
  from edsl.enums import service_to_api_keyname
50
53
  from edsl.exceptions import MissingAPIKeyError
51
54
  from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
55
+ from edsl.exceptions.language_models import LanguageModelBadResponseError
56
+
57
+ TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
58
+
59
+
60
+ def convert_answer(response_part):
61
+ import json
62
+
63
+ response_part = response_part.strip()
64
+
65
+ if response_part == "None":
66
+ return None
67
+
68
+ repaired = repair_json(response_part)
69
+ if repaired == '""':
70
+ # it was a literal string
71
+ return response_part
72
+
73
+ try:
74
+ return json.loads(repaired)
75
+ except json.JSONDecodeError as j:
76
+ # last resort
77
+ return response_part
78
+
79
+
80
+ def extract_item_from_raw_response(data, key_sequence):
81
+ if isinstance(data, str):
82
+ try:
83
+ data = json.loads(data)
84
+ except json.JSONDecodeError as e:
85
+ return data
86
+ current_data = data
87
+ for i, key in enumerate(key_sequence):
88
+ try:
89
+ if isinstance(current_data, (list, tuple)):
90
+ if not isinstance(key, int):
91
+ raise TypeError(
92
+ f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
93
+ )
94
+ if key < 0 or key >= len(current_data):
95
+ raise IndexError(
96
+ f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
97
+ )
98
+ elif isinstance(current_data, dict):
99
+ if key not in current_data:
100
+ raise KeyError(
101
+ f"Key '{key}' not found in dictionary at position {i}"
102
+ )
103
+ else:
104
+ raise TypeError(
105
+ f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
106
+ )
107
+
108
+ current_data = current_data[key]
109
+ except Exception as e:
110
+ path = " -> ".join(map(str, key_sequence[: i + 1]))
111
+ if "error" in data:
112
+ msg = data["error"]
113
+ else:
114
+ msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
115
+ raise LanguageModelBadResponseError(message=msg, response_json=data)
116
+ if isinstance(current_data, str):
117
+ return current_data.strip()
118
+ else:
119
+ return current_data
52
120
 
53
121
 
54
122
  def handle_key_error(func):
@@ -92,21 +160,29 @@ class LanguageModel(
92
160
  """
93
161
 
94
162
  _model_ = None
95
-
163
+ key_sequence = (
164
+ None # This should be something like ["choices", 0, "message", "content"]
165
+ )
96
166
  __rate_limits = None
97
- __default_rate_limits = {
98
- "rpm": 10_000,
99
- "tpm": 2_000_000,
100
- } # TODO: Use the OpenAI Teir 1 rate limits
101
167
  _safety_factor = 0.8
102
168
 
103
- def __init__(self, **kwargs):
169
+ def __init__(
170
+ self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
171
+ ):
104
172
  """Initialize the LanguageModel."""
105
173
  self.model = getattr(self, "_model_", None)
106
174
  default_parameters = getattr(self, "_parameters_", None)
107
175
  parameters = self._overide_default_parameters(kwargs, default_parameters)
108
176
  self.parameters = parameters
109
177
  self.remote = False
178
+ self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
179
+
180
+ # self._rpm / _tpm comes from the class
181
+ if rpm is not None:
182
+ self._rpm = rpm
183
+
184
+ if tpm is not None:
185
+ self._tpm = tpm
110
186
 
111
187
  for key, value in parameters.items():
112
188
  setattr(self, key, value)
@@ -133,7 +209,6 @@ class LanguageModel(
133
209
  def api_token(self) -> str:
134
210
  if not hasattr(self, "_api_token"):
135
211
  key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
136
-
137
212
  if self._inference_service_ == "bedrock":
138
213
  self._api_token = [os.getenv(key_name[0]), os.getenv(key_name[1])]
139
214
  # Check if any of the tokens are None
@@ -142,13 +217,13 @@ class LanguageModel(
142
217
  self._api_token = os.getenv(key_name)
143
218
  missing_token = self._api_token is None
144
219
  if missing_token and self._inference_service_ != "test" and not self.remote:
145
- print("rainsing error")
220
+ print("raising error")
146
221
  raise MissingAPIKeyError(
147
222
  f"""The key for service: `{self._inference_service_}` is not set.
148
223
  Need a key with name {key_name} in your .env file."""
149
224
  )
150
225
 
151
- return self._api_token
226
+ return self._api_token
152
227
 
153
228
  def __getitem__(self, key):
154
229
  return getattr(self, key)
@@ -209,40 +284,58 @@ class LanguageModel(
209
284
  >>> m = LanguageModel.example()
210
285
  >>> m.set_rate_limits(rpm=100, tpm=1000)
211
286
  >>> m.RPM
212
- 80.0
287
+ 100
213
288
  """
214
- self._set_rate_limits(rpm=rpm, tpm=tpm)
215
-
216
- def _set_rate_limits(self, rpm=None, tpm=None) -> None:
217
- """Set the rate limits for the model.
218
-
219
- If the model does not have rate limits, use the default rate limits."""
220
- if rpm is not None and tpm is not None:
221
- self.__rate_limits = {"rpm": rpm, "tpm": tpm}
222
- return
223
-
224
- if self.__rate_limits is None:
225
- if hasattr(self, "get_rate_limits"):
226
- self.__rate_limits = self.get_rate_limits()
227
- else:
228
- self.__rate_limits = self.__default_rate_limits
289
+ if rpm is not None:
290
+ self._rpm = rpm
291
+ if tpm is not None:
292
+ self._tpm = tpm
293
+ return None
294
+ # self._set_rate_limits(rpm=rpm, tpm=tpm)
295
+
296
+ # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
297
+ # """Set the rate limits for the model.
298
+
299
+ # If the model does not have rate limits, use the default rate limits."""
300
+ # if rpm is not None and tpm is not None:
301
+ # self.__rate_limits = {"rpm": rpm, "tpm": tpm}
302
+ # return
303
+
304
+ # if self.__rate_limits is None:
305
+ # if hasattr(self, "get_rate_limits"):
306
+ # self.__rate_limits = self.get_rate_limits()
307
+ # else:
308
+ # self.__rate_limits = self.__default_rate_limits
229
309
 
230
310
  @property
231
311
  def RPM(self):
232
312
  """Model's requests-per-minute limit."""
233
- self._set_rate_limits()
234
- return self._safety_factor * self.__rate_limits["rpm"]
313
+ # self._set_rate_limits()
314
+ # return self._safety_factor * self.__rate_limits["rpm"]
315
+ return self._rpm
235
316
 
236
317
  @property
237
318
  def TPM(self):
238
- """Model's tokens-per-minute limit.
319
+ """Model's tokens-per-minute limit."""
320
+ # self._set_rate_limits()
321
+ # return self._safety_factor * self.__rate_limits["tpm"]
322
+ return self._tpm
239
323
 
240
- >>> m = LanguageModel.example()
241
- >>> m.TPM > 0
242
- True
243
- """
244
- self._set_rate_limits()
245
- return self._safety_factor * self.__rate_limits["tpm"]
324
+ @property
325
+ def rpm(self):
326
+ return self._rpm
327
+
328
+ @rpm.setter
329
+ def rpm(self, value):
330
+ self._rpm = value
331
+
332
+ @property
333
+ def tpm(self):
334
+ return self._tpm
335
+
336
+ @tpm.setter
337
+ def tpm(self, value):
338
+ self._tpm = value
246
339
 
247
340
  @staticmethod
248
341
  def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
@@ -270,11 +363,10 @@ class LanguageModel(
270
363
  >>> m = LanguageModel.example(test_model = True)
271
364
  >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
272
365
  >>> asyncio.run(test())
273
- {'message': '{"answer": "Hello world"}'}
366
+ {'message': [{'text': 'Hello world'}], ...}
274
367
 
275
368
  >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
276
- {'message': '{"answer": "Hello world"}'}
277
-
369
+ {'message': [{'text': 'Hello world'}], ...}
278
370
  """
279
371
  pass
280
372
 
@@ -307,68 +399,40 @@ class LanguageModel(
307
399
 
308
400
  return main()
309
401
 
310
- @abstractmethod
311
- def parse_response(raw_response: dict[str, Any]) -> str:
312
- """Parse the response and returns the response text.
313
-
314
- >>> m = LanguageModel.example(test_model = True)
315
- >>> m
316
- Model(model_name = 'test', temperature = 0.5)
317
-
318
- What is returned by the API is model-specific and often includes meta-data that we do not need.
319
- For example, here is the results from a call to GPT-4:
320
- To actually track the response, we need to grab
321
- data["choices[0]"]["message"]["content"].
322
- """
323
- raise NotImplementedError
324
-
325
- async def _async_prepare_response(
326
- self, model_call_outcome: IntendedModelCallOutcome, cache: "Cache"
327
- ) -> dict:
328
- """Prepare the response for return."""
329
-
330
- model_response = {
331
- "cache_used": model_call_outcome.cache_used,
332
- "cache_key": model_call_outcome.cache_key,
333
- "usage": model_call_outcome.response.get("usage", {}),
334
- "raw_model_response": model_call_outcome.response,
335
- }
402
+ @classmethod
403
+ def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
404
+ """Return the generated token string from the raw response."""
405
+ return extract_item_from_raw_response(raw_response, cls.key_sequence)
336
406
 
337
- answer_portion = self.parse_response(model_call_outcome.response)
338
- try:
339
- answer_dict = json.loads(answer_portion)
340
- except json.JSONDecodeError as e:
341
- # TODO: Turn into logs to generate issues
342
- answer_dict, success = await repair(
343
- bad_json=answer_portion, error_message=str(e), cache=cache
407
+ @classmethod
408
+ def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
409
+ """Return the usage dictionary from the raw response."""
410
+ if not hasattr(cls, "usage_sequence"):
411
+ raise NotImplementedError(
412
+ "This inference service does not have a usage_sequence."
344
413
  )
345
- if not success:
346
- raise Exception(
347
- f"""Even the repair failed. The error was: {e}. The response was: {answer_portion}."""
348
- )
349
-
350
- return {**model_response, **answer_dict}
351
-
352
- async def async_get_raw_response(
353
- self,
354
- user_prompt: str,
355
- system_prompt: str,
356
- cache: "Cache",
357
- iteration: int = 0,
358
- encoded_image=None,
359
- ) -> IntendedModelCallOutcome:
360
- import warnings
414
+ return extract_item_from_raw_response(raw_response, cls.usage_sequence)
361
415
 
362
- warnings.warn(
363
- "This method is deprecated. Use async_get_intended_model_call_outcome."
364
- )
365
- return await self._async_get_intended_model_call_outcome(
366
- user_prompt=user_prompt,
367
- system_prompt=system_prompt,
368
- cache=cache,
369
- iteration=iteration,
370
- encoded_image=encoded_image,
371
- )
416
+ @classmethod
417
+ def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
418
+ """Parses the API response and returns the response text."""
419
+ generated_token_string = cls.get_generated_token_string(raw_response)
420
+ last_newline = generated_token_string.rfind("\n")
421
+
422
+ if last_newline == -1:
423
+ # There is no comment
424
+ edsl_dict = {
425
+ "answer": convert_answer(generated_token_string),
426
+ "generated_tokens": generated_token_string,
427
+ "comment": None,
428
+ }
429
+ else:
430
+ edsl_dict = {
431
+ "answer": convert_answer(generated_token_string[:last_newline]),
432
+ "comment": generated_token_string[last_newline + 1 :].strip(),
433
+ "generated_tokens": generated_token_string,
434
+ }
435
+ return EDSLOutput(**edsl_dict)
372
436
 
373
437
  async def _async_get_intended_model_call_outcome(
374
438
  self,
@@ -377,7 +441,7 @@ class LanguageModel(
377
441
  cache: "Cache",
378
442
  iteration: int = 0,
379
443
  encoded_image=None,
380
- ) -> IntendedModelCallOutcome:
444
+ ) -> ModelResponse:
381
445
  """Handle caching of responses.
382
446
 
383
447
  :param user_prompt: The user's prompt.
@@ -396,18 +460,18 @@ class LanguageModel(
396
460
  >>> from edsl import Cache
397
461
  >>> m = LanguageModel.example(test_model = True)
398
462
  >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
399
- IntendedModelCallOutcome(response = {'message': '{"answer": "Hello world"}'}, cache_used = False, cache_key = '24ff6ac2bc2f1729f817f261e0792577')
400
- """
463
+ ModelResponse(...)"""
401
464
 
402
465
  if encoded_image:
403
466
  # the image has is appended to the user_prompt for hash-lookup purposes
404
467
  image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
468
+ user_prompt += f" {image_hash}"
405
469
 
406
470
  cache_call_params = {
407
471
  "model": str(self.model),
408
472
  "parameters": self.parameters,
409
473
  "system_prompt": system_prompt,
410
- "user_prompt": user_prompt + "" if not encoded_image else f" {image_hash}",
474
+ "user_prompt": user_prompt,
411
475
  "iteration": iteration,
412
476
  }
413
477
  cached_response, cache_key = cache.fetch(**cache_call_params)
@@ -425,21 +489,28 @@ class LanguageModel(
425
489
  "system_prompt": system_prompt,
426
490
  **({"encoded_image": encoded_image} if encoded_image else {}),
427
491
  }
428
- response = await f(**params)
492
+ # response = await f(**params)
493
+ response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
429
494
  new_cache_key = cache.store(
430
495
  **cache_call_params, response=response
431
496
  ) # store the response in the cache
432
497
  assert new_cache_key == cache_key # should be the same
433
498
 
434
- return IntendedModelCallOutcome(
435
- response=response, cache_used=cache_used, cache_key=cache_key
499
+ cost = self.cost(response)
500
+
501
+ return ModelResponse(
502
+ response=response,
503
+ cache_used=cache_used,
504
+ cache_key=cache_key,
505
+ cached_response=cached_response,
506
+ cost=cost,
436
507
  )
437
508
 
438
509
  _get_intended_model_call_outcome = sync_wrapper(
439
510
  _async_get_intended_model_call_outcome
440
511
  )
441
512
 
442
- get_raw_response = sync_wrapper(async_get_raw_response)
513
+ # get_raw_response = sync_wrapper(async_get_raw_response)
443
514
 
444
515
  def simple_ask(
445
516
  self,
@@ -478,14 +549,66 @@ class LanguageModel(
478
549
  "cache": cache,
479
550
  **({"encoded_image": encoded_image} if encoded_image else {}),
480
551
  }
481
- model_call_outcome = await self._async_get_intended_model_call_outcome(**params)
482
- return await self._async_prepare_response(model_call_outcome, cache=cache)
552
+ model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
553
+ model_outputs = await self._async_get_intended_model_call_outcome(**params)
554
+ edsl_dict = self.parse_response(model_outputs.response)
555
+ agent_response_dict = AgentResponseDict(
556
+ model_inputs=model_inputs,
557
+ model_outputs=model_outputs,
558
+ edsl_dict=edsl_dict,
559
+ )
560
+ return agent_response_dict
561
+
562
+ # return await self._async_prepare_response(model_call_outcome, cache=cache)
483
563
 
484
564
  get_response = sync_wrapper(async_get_response)
485
565
 
486
- def cost(self, raw_response: dict[str, Any]) -> float:
566
+ def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
487
567
  """Return the dollar cost of a raw response."""
488
- raise NotImplementedError
568
+
569
+ usage = self.get_usage_dict(raw_response)
570
+ from edsl.coop import Coop
571
+
572
+ c = Coop()
573
+ price_lookup = c.fetch_prices()
574
+ key = (self._inference_service_, self.model)
575
+ if key not in price_lookup:
576
+ return f"Could not find price for model {self.model} in the price lookup."
577
+
578
+ relevant_prices = price_lookup[key]
579
+ try:
580
+ input_tokens = int(usage[self.input_token_name])
581
+ output_tokens = int(usage[self.output_token_name])
582
+ except Exception as e:
583
+ return f"Could not fetch tokens from model response: {e}"
584
+
585
+ try:
586
+ inverse_output_price = relevant_prices["output"]["one_usd_buys"]
587
+ inverse_input_price = relevant_prices["input"]["one_usd_buys"]
588
+ except Exception as e:
589
+ if "output" not in relevant_prices:
590
+ return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
591
+ if "input" not in relevant_prices:
592
+ return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
593
+ return f"Could not fetch prices from {relevant_prices} - {e}"
594
+
595
+ if inverse_input_price == "infinity":
596
+ input_cost = 0
597
+ else:
598
+ try:
599
+ input_cost = input_tokens / float(inverse_input_price)
600
+ except Exception as e:
601
+ return f"Could not compute input price - {e}."
602
+
603
+ if inverse_output_price == "infinity":
604
+ output_cost = 0
605
+ else:
606
+ try:
607
+ output_cost = output_tokens / float(inverse_output_price)
608
+ except Exception as e:
609
+ return f"Could not compute output price - {e}"
610
+
611
+ return input_cost + output_cost
489
612
 
490
613
  #######################
491
614
  # SERIALIZATION METHODS
@@ -499,7 +622,7 @@ class LanguageModel(
499
622
 
500
623
  >>> m = LanguageModel.example()
501
624
  >>> m.to_dict()
502
- {'model': 'gpt-4-1106-preview', 'parameters': {'temperature': 0.5, 'max_tokens': 1000, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'logprobs': False, 'top_logprobs': 3}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
625
+ {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
503
626
  """
504
627
  return self._to_dict()
505
628
 
@@ -575,26 +698,8 @@ class LanguageModel(
575
698
  """
576
699
  from edsl import Model
577
700
 
578
- class TestLanguageModelGood(LanguageModel):
579
- use_cache = False
580
- _model_ = "test"
581
- _parameters_ = {"temperature": 0.5}
582
- _inference_service_ = InferenceServiceType.TEST.value
583
-
584
- async def async_execute_model_call(
585
- self, user_prompt: str, system_prompt: str
586
- ) -> dict[str, Any]:
587
- await asyncio.sleep(0.1)
588
- # return {"message": """{"answer": "Hello, world"}"""}
589
- if throw_exception:
590
- raise Exception("This is a test error")
591
- return {"message": f'{{"answer": "{canned_response}"}}'}
592
-
593
- def parse_response(self, raw_response: dict[str, Any]) -> str:
594
- return raw_response["message"]
595
-
596
701
  if test_model:
597
- m = TestLanguageModelGood()
702
+ m = Model("test", canned_response=canned_response)
598
703
  return m
599
704
  else:
600
705
  return Model(skip_api_key_check=True)
@@ -40,8 +40,8 @@ class ModelList(Base, UserList):
40
40
  def __hash__(self):
41
41
  """Return a hash of the ModelList. This is used for comparison of ModelLists.
42
42
 
43
- >>> hash(ModelList.example())
44
- 1423518243781418961
43
+ >>> isinstance(hash(Model()), int)
44
+ True
45
45
 
46
46
  """
47
47
  from edsl.utilities.utilities import dict_hash
@@ -47,13 +47,6 @@ class RegisterLanguageModelsMeta(ABCMeta):
47
47
  must_be_async=True,
48
48
  )
49
49
  # LanguageModel children have to implement the parse_response method
50
- RegisterLanguageModelsMeta.verify_method(
51
- candidate_class=cls,
52
- method_name="parse_response",
53
- expected_return_type=str,
54
- required_parameters=[("raw_response", dict[str, Any])],
55
- must_be_async=False,
56
- )
57
50
  RegisterLanguageModelsMeta._registry[model_name] = cls
58
51
 
59
52
  @classmethod
@@ -98,7 +91,7 @@ class RegisterLanguageModelsMeta(ABCMeta):
98
91
 
99
92
  required_parameters = required_parameters or []
100
93
  method = getattr(candidate_class, method_name)
101
- signature = inspect.signature(method)
94
+ # signature = inspect.signature(method)
102
95
 
103
96
  RegisterLanguageModelsMeta._check_return_type(method, expected_return_type)
104
97
 
@@ -106,11 +99,11 @@ class RegisterLanguageModelsMeta(ABCMeta):
106
99
  RegisterLanguageModelsMeta._check_is_coroutine(method)
107
100
 
108
101
  # Check the parameters
109
- params = signature.parameters
110
- for param_name, param_type in required_parameters:
111
- RegisterLanguageModelsMeta._verify_parameter(
112
- params, param_name, param_type, method_name
113
- )
102
+ # params = signature.parameters
103
+ # for param_name, param_type in required_parameters:
104
+ # RegisterLanguageModelsMeta._verify_parameter(
105
+ # params, param_name, param_type, method_name
106
+ # )
114
107
 
115
108
  @staticmethod
116
109
  def _check_method_defined(cls, method_name):
@@ -167,23 +160,15 @@ class RegisterLanguageModelsMeta(ABCMeta):
167
160
  Check if the return type of a method is as expected.
168
161
 
169
162
  Example:
170
- >>> class M:
171
- ... async def f(self) -> str: pass
172
- >>> RegisterLanguageModelsMeta._check_return_type(M.f, str)
173
- >>> class N:
174
- ... async def f(self) -> int: pass
175
- >>> RegisterLanguageModelsMeta._check_return_type(N.f, str)
176
- Traceback (most recent call last):
177
- ...
178
- TypeError: Return type of f must be <class 'str'>. Got <class 'int'>.
179
163
  """
180
- if inspect.isroutine(method):
181
- # return_type = inspect.signature(method).return_annotation
182
- return_type = get_type_hints(method)["return"]
183
- if return_type != expected_return_type:
184
- raise TypeError(
185
- f"Return type of {method.__name__} must be {expected_return_type}. Got {return_type}."
186
- )
164
+ pass
165
+ # if inspect.isroutine(method):
166
+ # # return_type = inspect.signature(method).return_annotation
167
+ # return_type = get_type_hints(method)["return"]
168
+ # if return_type != expected_return_type:
169
+ # raise TypeError(
170
+ # f"Return type of {method.__name__} must be {expected_return_type}. Got {return_type}."
171
+ # )
187
172
 
188
173
  @classmethod
189
174
  def model_names_to_classes(cls):
@@ -0,0 +1,15 @@
1
+ from openai import AsyncOpenAI
2
+ import asyncio
3
+
4
+ client = AsyncOpenAI(base_url="http://127.0.0.1:8000/v1", api_key="fake_key")
5
+
6
+
7
+ async def main():
8
+ response = await client.chat.completions.create(
9
+ model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Question XX42"}]
10
+ )
11
+ print(response)
12
+
13
+
14
+ if __name__ == "__main__":
15
+ asyncio.run(main())