edsl 0.1.31.dev4__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (188)
  1. edsl/Base.py +9 -3
  2. edsl/TemplateLoader.py +24 -0
  3. edsl/__init__.py +8 -3
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +40 -8
  6. edsl/agents/AgentList.py +43 -0
  7. edsl/agents/Invigilator.py +136 -221
  8. edsl/agents/InvigilatorBase.py +148 -59
  9. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +154 -85
  10. edsl/agents/__init__.py +1 -0
  11. edsl/auto/AutoStudy.py +117 -0
  12. edsl/auto/StageBase.py +230 -0
  13. edsl/auto/StageGenerateSurvey.py +178 -0
  14. edsl/auto/StageLabelQuestions.py +125 -0
  15. edsl/auto/StagePersona.py +61 -0
  16. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  17. edsl/auto/StagePersonaDimensionValues.py +74 -0
  18. edsl/auto/StagePersonaDimensions.py +69 -0
  19. edsl/auto/StageQuestions.py +73 -0
  20. edsl/auto/SurveyCreatorPipeline.py +21 -0
  21. edsl/auto/utilities.py +224 -0
  22. edsl/config.py +48 -47
  23. edsl/conjure/Conjure.py +6 -0
  24. edsl/coop/PriceFetcher.py +58 -0
  25. edsl/coop/coop.py +50 -7
  26. edsl/data/Cache.py +35 -1
  27. edsl/data/CacheHandler.py +3 -4
  28. edsl/data_transfer_models.py +73 -38
  29. edsl/enums.py +8 -0
  30. edsl/exceptions/general.py +10 -8
  31. edsl/exceptions/language_models.py +25 -1
  32. edsl/exceptions/questions.py +62 -5
  33. edsl/exceptions/results.py +4 -0
  34. edsl/inference_services/AnthropicService.py +13 -11
  35. edsl/inference_services/AwsBedrock.py +112 -0
  36. edsl/inference_services/AzureAI.py +214 -0
  37. edsl/inference_services/DeepInfraService.py +4 -3
  38. edsl/inference_services/GoogleService.py +16 -12
  39. edsl/inference_services/GroqService.py +5 -4
  40. edsl/inference_services/InferenceServiceABC.py +58 -3
  41. edsl/inference_services/InferenceServicesCollection.py +13 -8
  42. edsl/inference_services/MistralAIService.py +120 -0
  43. edsl/inference_services/OllamaService.py +18 -0
  44. edsl/inference_services/OpenAIService.py +55 -56
  45. edsl/inference_services/TestService.py +80 -0
  46. edsl/inference_services/TogetherAIService.py +170 -0
  47. edsl/inference_services/models_available_cache.py +25 -0
  48. edsl/inference_services/registry.py +19 -1
  49. edsl/jobs/Answers.py +10 -12
  50. edsl/jobs/FailedQuestion.py +78 -0
  51. edsl/jobs/Jobs.py +137 -41
  52. edsl/jobs/buckets/BucketCollection.py +24 -15
  53. edsl/jobs/buckets/TokenBucket.py +105 -18
  54. edsl/jobs/interviews/Interview.py +393 -83
  55. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +22 -18
  56. edsl/jobs/interviews/InterviewExceptionEntry.py +167 -0
  57. edsl/jobs/runners/JobsRunnerAsyncio.py +152 -160
  58. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  59. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  60. edsl/jobs/tasks/TaskCreators.py +1 -1
  61. edsl/jobs/tasks/TaskHistory.py +205 -126
  62. edsl/language_models/LanguageModel.py +297 -177
  63. edsl/language_models/ModelList.py +2 -2
  64. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  65. edsl/language_models/fake_openai_call.py +15 -0
  66. edsl/language_models/fake_openai_service.py +61 -0
  67. edsl/language_models/registry.py +25 -8
  68. edsl/language_models/repair.py +0 -19
  69. edsl/language_models/utilities.py +61 -0
  70. edsl/notebooks/Notebook.py +20 -2
  71. edsl/prompts/Prompt.py +52 -2
  72. edsl/questions/AnswerValidatorMixin.py +23 -26
  73. edsl/questions/QuestionBase.py +330 -249
  74. edsl/questions/QuestionBaseGenMixin.py +133 -0
  75. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  76. edsl/questions/QuestionBudget.py +99 -42
  77. edsl/questions/QuestionCheckBox.py +227 -36
  78. edsl/questions/QuestionExtract.py +98 -28
  79. edsl/questions/QuestionFreeText.py +47 -31
  80. edsl/questions/QuestionFunctional.py +7 -0
  81. edsl/questions/QuestionList.py +141 -23
  82. edsl/questions/QuestionMultipleChoice.py +159 -66
  83. edsl/questions/QuestionNumerical.py +88 -47
  84. edsl/questions/QuestionRank.py +182 -25
  85. edsl/questions/Quick.py +41 -0
  86. edsl/questions/RegisterQuestionsMeta.py +31 -12
  87. edsl/questions/ResponseValidatorABC.py +170 -0
  88. edsl/questions/__init__.py +3 -4
  89. edsl/questions/decorators.py +21 -0
  90. edsl/questions/derived/QuestionLikertFive.py +10 -5
  91. edsl/questions/derived/QuestionLinearScale.py +15 -2
  92. edsl/questions/derived/QuestionTopK.py +10 -1
  93. edsl/questions/derived/QuestionYesNo.py +24 -3
  94. edsl/questions/descriptors.py +43 -7
  95. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  96. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  97. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  98. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  99. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  100. edsl/questions/prompt_templates/question_list.jinja +17 -0
  101. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  102. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  103. edsl/questions/question_registry.py +6 -2
  104. edsl/questions/templates/__init__.py +0 -0
  105. edsl/questions/templates/budget/__init__.py +0 -0
  106. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  107. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  108. edsl/questions/templates/checkbox/__init__.py +0 -0
  109. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  110. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  111. edsl/questions/templates/extract/__init__.py +0 -0
  112. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  113. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  114. edsl/questions/templates/free_text/__init__.py +0 -0
  115. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  116. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  117. edsl/questions/templates/likert_five/__init__.py +0 -0
  118. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  119. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  120. edsl/questions/templates/linear_scale/__init__.py +0 -0
  121. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  122. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  123. edsl/questions/templates/list/__init__.py +0 -0
  124. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  125. edsl/questions/templates/list/question_presentation.jinja +5 -0
  126. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  127. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  128. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  129. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  130. edsl/questions/templates/numerical/__init__.py +0 -0
  131. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  132. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  133. edsl/questions/templates/rank/__init__.py +0 -0
  134. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  135. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  136. edsl/questions/templates/top_k/__init__.py +0 -0
  137. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  138. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  139. edsl/questions/templates/yes_no/__init__.py +0 -0
  140. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  141. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  142. edsl/results/Dataset.py +20 -0
  143. edsl/results/DatasetExportMixin.py +58 -30
  144. edsl/results/DatasetTree.py +145 -0
  145. edsl/results/Result.py +32 -5
  146. edsl/results/Results.py +135 -46
  147. edsl/results/ResultsDBMixin.py +3 -3
  148. edsl/results/Selector.py +118 -0
  149. edsl/results/tree_explore.py +115 -0
  150. edsl/scenarios/FileStore.py +71 -10
  151. edsl/scenarios/Scenario.py +109 -24
  152. edsl/scenarios/ScenarioImageMixin.py +2 -2
  153. edsl/scenarios/ScenarioList.py +546 -21
  154. edsl/scenarios/ScenarioListExportMixin.py +24 -4
  155. edsl/scenarios/ScenarioListPdfMixin.py +153 -4
  156. edsl/study/SnapShot.py +8 -1
  157. edsl/study/Study.py +32 -0
  158. edsl/surveys/Rule.py +15 -3
  159. edsl/surveys/RuleCollection.py +21 -5
  160. edsl/surveys/Survey.py +707 -298
  161. edsl/surveys/SurveyExportMixin.py +71 -9
  162. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  163. edsl/surveys/SurveyQualtricsImport.py +284 -0
  164. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  165. edsl/surveys/instructions/Instruction.py +34 -0
  166. edsl/surveys/instructions/InstructionCollection.py +77 -0
  167. edsl/surveys/instructions/__init__.py +0 -0
  168. edsl/templates/error_reporting/base.html +24 -0
  169. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  170. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  171. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  172. edsl/templates/error_reporting/interview_details.html +116 -0
  173. edsl/templates/error_reporting/interviews.html +10 -0
  174. edsl/templates/error_reporting/overview.html +5 -0
  175. edsl/templates/error_reporting/performance_plot.html +2 -0
  176. edsl/templates/error_reporting/report.css +74 -0
  177. edsl/templates/error_reporting/report.html +118 -0
  178. edsl/templates/error_reporting/report.js +25 -0
  179. edsl/utilities/utilities.py +40 -1
  180. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/METADATA +8 -2
  181. edsl-0.1.33.dist-info/RECORD +295 -0
  182. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -271
  183. edsl/jobs/interviews/retry_management.py +0 -37
  184. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -303
  185. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  186. edsl-0.1.31.dev4.dist-info/RECORD +0 -204
  187. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
  188. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
edsl/language_models/LanguageModel.py

@@ -1,4 +1,16 @@
- """This module contains the LanguageModel class, which is an abstract base class for all language models."""
+ """This module contains the LanguageModel class, which is an abstract base class for all language models.
+
+ Terminology:
+
+ raw_response: The JSON response from the model. This has all the model meta-data about the call.
+
+ edsl_augmented_response: The JSON response from the model, augmented with EDSL-specific information,
+ such as the cache key, token usage, etc.
+
+ generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
+ edsl_answer_dict: The parsed JSON response from the model, either {'answer': ...} or {'answer': ..., 'comment': ...}
+
+ """

  from __future__ import annotations
  import warnings
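Note: the terminology above maps onto concrete values roughly as follows. A single hypothetical call, with all values invented for illustration:

    # raw_response: the verbatim JSON payload returned by the inference API
    raw_response = {"choices": [{"message": {"content": '{"answer": "yes"}\nBecause it fits.'}}]}

    # generated_tokens: only the text the model actually produced
    generated_tokens = '{"answer": "yes"}\nBecause it fits.'

    # edsl_answer_dict: the parsed answer, with the trailing line treated as a comment
    edsl_answer_dict = {"answer": "yes", "comment": "Because it fits."}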
@@ -8,45 +20,103 @@ import json
  import time
  import os
  import hashlib
- from typing import Coroutine, Any, Callable, Type, List, get_type_hints
+ from typing import (
+     Coroutine,
+     Any,
+     Callable,
+     Type,
+     Union,
+     List,
+     get_type_hints,
+     TypedDict,
+     Optional,
+ )
  from abc import ABC, abstractmethod

- class IntendedModelCallOutcome:
-     "This is a tuple-like class that holds the response, cache_used, and cache_key."
-
-     def __init__(self, response: dict, cache_used: bool, cache_key: str):
-         self.response = response
-         self.cache_used = cache_used
-         self.cache_key = cache_key
+ from json_repair import repair_json

-     def __iter__(self):
-         """Iterate over the class attributes.
-
-         >>> a, b, c = IntendedModelCallOutcome({'answer': "yes"}, True, 'x1289')
-         >>> a
-         {'answer': 'yes'}
-         """
-         yield self.response
-         yield self.cache_used
-         yield self.cache_key
+ from edsl.data_transfer_models import (
+     ModelResponse,
+     ModelInputs,
+     EDSLOutput,
+     AgentResponseDict,
+ )

-     def __len__(self):
-         return 3
-
-     def __repr__(self):
-         return f"IntendedModelCallOutcome(response = {self.response}, cache_used = {self.cache_used}, cache_key = '{self.cache_key}')"

  from edsl.config import CONFIG
-
  from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
-
  from edsl.language_models.repair import repair
  from edsl.enums import InferenceServiceType
  from edsl.Base import RichPrintingMixin, PersistenceMixin
  from edsl.enums import service_to_api_keyname
  from edsl.exceptions import MissingAPIKeyError
  from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
+ from edsl.exceptions.language_models import LanguageModelBadResponseError
+
+ TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+
+ def convert_answer(response_part):
+     import json
+
+     response_part = response_part.strip()
+
+     if response_part == "None":
+         return None
+
+     repaired = repair_json(response_part)
+     if repaired == '""':
+         # it was a literal string
+         return response_part
+
+     try:
+         return json.loads(repaired)
+     except json.JSONDecodeError as j:
+         # last resort
+         return response_part
+
+
+ def extract_item_from_raw_response(data, key_sequence):
+     if isinstance(data, str):
+         try:
+             data = json.loads(data)
+         except json.JSONDecodeError as e:
+             return data
+     current_data = data
+     for i, key in enumerate(key_sequence):
+         try:
+             if isinstance(current_data, (list, tuple)):
+                 if not isinstance(key, int):
+                     raise TypeError(
+                         f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
+                     )
+                 if key < 0 or key >= len(current_data):
+                     raise IndexError(
+                         f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
+                     )
+             elif isinstance(current_data, dict):
+                 if key not in current_data:
+                     raise KeyError(
+                         f"Key '{key}' not found in dictionary at position {i}"
+                     )
+             else:
+                 raise TypeError(
+                     f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
+                 )
+
+             current_data = current_data[key]
+         except Exception as e:
+             path = " -> ".join(map(str, key_sequence[: i + 1]))
+             if "error" in data:
+                 msg = data["error"]
+             else:
+                 msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
+             raise LanguageModelBadResponseError(message=msg, response_json=data)
+     if isinstance(current_data, str):
+         return current_data.strip()
+     else:
+         return current_data


  def handle_key_error(func):
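Note: the two helpers above do the heavy lifting for response parsing. convert_answer maps the literal string "None" to None, repairs and parses JSON via the third-party json_repair package, and otherwise falls back to the raw string. extract_item_from_raw_response walks a key path into the nested response; the walk, restated standalone (response shape invented):

    raw = {"choices": [{"message": {"content": "Hello"}}]}
    key_sequence = ["choices", 0, "message", "content"]

    value = raw
    for key in key_sequence:   # dict keys and list indices are handled alike
        value = value[key]

    assert value == "Hello"    # the real helper additionally strips strings and
                               # raises LanguageModelBadResponseError on a bad path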
@@ -90,21 +160,29 @@ class LanguageModel(
      """

      _model_ = None
-
+     key_sequence = (
+         None  # This should be something like ["choices", 0, "message", "content"]
+     )
      __rate_limits = None
-     __default_rate_limits = {
-         "rpm": 10_000,
-         "tpm": 2_000_000,
-     }  # TODO: Use the OpenAI Tier 1 rate limits
      _safety_factor = 0.8

-     def __init__(self, **kwargs):
+     def __init__(
+         self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
+     ):
          """Initialize the LanguageModel."""
          self.model = getattr(self, "_model_", None)
          default_parameters = getattr(self, "_parameters_", None)
          parameters = self._overide_default_parameters(kwargs, default_parameters)
          self.parameters = parameters
          self.remote = False
+         self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
+
+         # self._rpm / _tpm comes from the class
+         if rpm is not None:
+             self._rpm = rpm
+
+         if tpm is not None:
+             self._tpm = tpm

          for key, value in parameters.items():
              setattr(self, key, value)
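Note: the constructor now takes per-instance rate limits that override class-level defaults. A minimal standalone sketch of that pattern (names and values invented):

    class _RateLimitSketch:
        _rpm, _tpm = 10, 2_000_000          # class-level defaults, illustrative values

        def __init__(self, rpm=None, tpm=None):
            if rpm is not None:
                self._rpm = rpm             # instance override, as in the diff
            if tpm is not None:
                self._tpm = tpm

    m = _RateLimitSketch(rpm=100)
    assert (m._rpm, m._tpm) == (100, 2_000_000)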
@@ -131,17 +209,20 @@ class LanguageModel(
      def api_token(self) -> str:
          if not hasattr(self, "_api_token"):
              key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
-             self._api_token = os.getenv(key_name)
-             if (
-                 self._api_token is None
-                 and self._inference_service_ != "test"
-                 and not self.remote
-             ):
+             if self._inference_service_ == "bedrock":
+                 self._api_token = [os.getenv(key_name[0]), os.getenv(key_name[1])]
+                 # Check if any of the tokens are None
+                 missing_token = any(token is None for token in self._api_token)
+             else:
+                 self._api_token = os.getenv(key_name)
+                 missing_token = self._api_token is None
+             if missing_token and self._inference_service_ != "test" and not self.remote:
+                 print("raising error")
                  raise MissingAPIKeyError(
                      f"""The key for service: `{self._inference_service_}` is not set.
-                     Need a key with name {key_name} in your .env file.
-                     """
+                     Need a key with name {key_name} in your .env file."""
                  )
+
          return self._api_token

      def __getitem__(self, key):
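Note: for bedrock, key_name resolves to a pair of environment-variable names rather than a single one, so the missing-key check has to cover both. Restated standalone (the variable names below are a guess at AWS conventions, not taken from the diff):

    import os

    key_name = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")   # hypothetical pair
    api_token = [os.getenv(key_name[0]), os.getenv(key_name[1])]
    missing_token = any(token is None for token in api_token)   # True if either is unset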
@@ -159,8 +240,7 @@ class LanguageModel(
          if verbose:
              print(f"Current key is {masked}")
          return self.execute_model_call(
-             user_prompt="Hello, model!",
-             system_prompt="You are a helpful agent."
+             user_prompt="Hello, model!", system_prompt="You are a helpful agent."
          )

      def has_valid_api_key(self) -> bool:
@@ -204,42 +284,58 @@ class LanguageModel(
          >>> m = LanguageModel.example()
          >>> m.set_rate_limits(rpm=100, tpm=1000)
          >>> m.RPM
-         80.0
+         100
          """
-         self._set_rate_limits(rpm=rpm, tpm=tpm)
-
-
-
-     def _set_rate_limits(self, rpm=None, tpm=None) -> None:
-         """Set the rate limits for the model.
-
-         If the model does not have rate limits, use the default rate limits."""
-         if rpm is not None and tpm is not None:
-             self.__rate_limits = {"rpm": rpm, "tpm": tpm}
-             return
-
-         if self.__rate_limits is None:
-             if hasattr(self, "get_rate_limits"):
-                 self.__rate_limits = self.get_rate_limits()
-             else:
-                 self.__rate_limits = self.__default_rate_limits
+         if rpm is not None:
+             self._rpm = rpm
+         if tpm is not None:
+             self._tpm = tpm
+         return None
+         # self._set_rate_limits(rpm=rpm, tpm=tpm)
+
+     # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
+     #     """Set the rate limits for the model.
+
+     #     If the model does not have rate limits, use the default rate limits."""
+     #     if rpm is not None and tpm is not None:
+     #         self.__rate_limits = {"rpm": rpm, "tpm": tpm}
+     #         return
+
+     #     if self.__rate_limits is None:
+     #         if hasattr(self, "get_rate_limits"):
+     #             self.__rate_limits = self.get_rate_limits()
+     #         else:
+     #             self.__rate_limits = self.__default_rate_limits

      @property
      def RPM(self):
          """Model's requests-per-minute limit."""
-         self._set_rate_limits()
-         return self._safety_factor * self.__rate_limits["rpm"]
+         # self._set_rate_limits()
+         # return self._safety_factor * self.__rate_limits["rpm"]
+         return self._rpm

      @property
      def TPM(self):
-         """Model's tokens-per-minute limit.
+         """Model's tokens-per-minute limit."""
+         # self._set_rate_limits()
+         # return self._safety_factor * self.__rate_limits["tpm"]
+         return self._tpm

-         >>> m = LanguageModel.example()
-         >>> m.TPM > 0
-         True
-         """
-         self._set_rate_limits()
-         return self._safety_factor * self.__rate_limits["tpm"]
+     @property
+     def rpm(self):
+         return self._rpm
+
+     @rpm.setter
+     def rpm(self, value):
+         self._rpm = value
+
+     @property
+     def tpm(self):
+         return self._tpm
+
+     @tpm.setter
+     def tpm(self, value):
+         self._tpm = value

      @staticmethod
      def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
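Note: the doctest change (80.0 to 100) reflects a behavioral change, not just refactoring: RPM used to scale the stored limit by _safety_factor, and now returns the value verbatim. The arithmetic that disappeared:

    _safety_factor = 0.8
    rpm = 100

    old_RPM = _safety_factor * rpm   # 80.0, what the deleted code computed
    new_RPM = rpm                    # 100, what the property now returns
    assert (old_RPM, new_RPM) == (80.0, 100)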
@@ -250,14 +346,16 @@ class LanguageModel(
          >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
          {'temperature': 0.5, 'max_tokens': 1000}
          """
-         #parameters = dict({})
+         # parameters = dict({})
+
+         return {
+             parameter_name: passed_parameter_dict.get(parameter_name, default_value)
+             for parameter_name, default_value in default_parameter_dict.items()
+         }

-         return {parameter_name: passed_parameter_dict.get(parameter_name, default_value)
-             for parameter_name, default_value in default_parameter_dict.items()}
-
-     def __call__(self, user_prompt:str, system_prompt:str):
+     def __call__(self, user_prompt: str, system_prompt: str):
          return self.execute_model_call(user_prompt, system_prompt)
-
+
      @abstractmethod
      async def async_execute_model_call(user_prompt: str, system_prompt: str):
          """Execute the model call and returns a coroutine.
@@ -265,11 +363,10 @@ class LanguageModel(
          >>> m = LanguageModel.example(test_model = True)
          >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
          >>> asyncio.run(test())
-         {'message': '{"answer": "Hello world"}'}
+         {'message': [{'text': 'Hello world'}], ...}

          >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
-         {'message': '{"answer": "Hello world"}'}
-
+         {'message': [{'text': 'Hello world'}], ...}
          """
          pass

@@ -302,66 +399,40 @@ class LanguageModel(

          return main()

-     @abstractmethod
-     def parse_response(raw_response: dict[str, Any]) -> str:
-         """Parse the response and returns the response text.
-
-         >>> m = LanguageModel.example(test_model = True)
-         >>> m
-         Model(model_name = 'test', temperature = 0.5)
-
-         What is returned by the API is model-specific and often includes meta-data that we do not need.
-         For example, here is the results from a call to GPT-4:
-         To actually track the response, we need to grab
-         data["choices[0]"]["message"]["content"].
-         """
-         raise NotImplementedError
-
-     async def _async_prepare_response(self, model_call_outcome: IntendedModelCallOutcome, cache: "Cache") -> dict:
-         """Prepare the response for return."""
-
-         model_response = {
-             "cache_used": model_call_outcome.cache_used,
-             "cache_key": model_call_outcome.cache_key,
-             "usage": model_call_outcome.response.get("usage", {}),
-             "raw_model_response": model_call_outcome.response,
-         }
+     @classmethod
+     def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
+         """Return the generated token string from the raw response."""
+         return extract_item_from_raw_response(raw_response, cls.key_sequence)

-         answer_portion = self.parse_response(model_call_outcome.response)
-         try:
-             answer_dict = json.loads(answer_portion)
-         except json.JSONDecodeError as e:
-             # TODO: Turn into logs to generate issues
-             answer_dict, success = await repair(
-                 bad_json=answer_portion,
-                 error_message=str(e),
-                 cache=cache
+     @classmethod
+     def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
+         """Return the usage dictionary from the raw response."""
+         if not hasattr(cls, "usage_sequence"):
+             raise NotImplementedError(
+                 "This inference service does not have a usage_sequence."
              )
-         if not success:
-             raise Exception(
-                 f"""Even the repair failed. The error was: {e}. The response was: {answer_portion}."""
-             )
-
-         return {**model_response, **answer_dict}
-
-     async def async_get_raw_response(
-         self,
-         user_prompt: str,
-         system_prompt: str,
-         cache: "Cache",
-         iteration: int = 0,
-         encoded_image=None,
-     ) -> IntendedModelCallOutcome:
-         import warnings
-         warnings.warn("This method is deprecated. Use async_get_intended_model_call_outcome.")
-         return await self._async_get_intended_model_call_outcome(
-             user_prompt=user_prompt,
-             system_prompt=system_prompt,
-             cache=cache,
-             iteration=iteration,
-             encoded_image=encoded_image
-         )
+         return extract_item_from_raw_response(raw_response, cls.usage_sequence)

+     @classmethod
+     def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
+         """Parses the API response and returns the response text."""
+         generated_token_string = cls.get_generated_token_string(raw_response)
+         last_newline = generated_token_string.rfind("\n")
+
+         if last_newline == -1:
+             # There is no comment
+             edsl_dict = {
+                 "answer": convert_answer(generated_token_string),
+                 "generated_tokens": generated_token_string,
+                 "comment": None,
+             }
+         else:
+             edsl_dict = {
+                 "answer": convert_answer(generated_token_string[:last_newline]),
+                 "comment": generated_token_string[last_newline + 1 :].strip(),
+                 "generated_tokens": generated_token_string,
+             }
+         return EDSLOutput(**edsl_dict)

      async def _async_get_intended_model_call_outcome(
          self,
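Note: parse_response now splits the generated text at its last newline: everything before it is the answer, anything after it is an optional comment. The split, restated on an invented response:

    generated = '{"answer": "yes"}\nI picked yes because it fits.'

    last_newline = generated.rfind("\n")
    answer_part = generated[:last_newline]           # '{"answer": "yes"}'
    comment = generated[last_newline + 1:].strip()   # 'I picked yes because it fits.'

    assert answer_part == '{"answer": "yes"}'
    assert comment == "I picked yes because it fits."
    # convert_answer(answer_part) then yields {'answer': 'yes'}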
@@ -370,7 +441,7 @@ class LanguageModel(
          cache: "Cache",
          iteration: int = 0,
          encoded_image=None,
-     ) -> IntendedModelCallOutcome:
+     ) -> ModelResponse:
          """Handle caching of responses.

          :param user_prompt: The user's prompt.
@@ -389,23 +460,23 @@ class LanguageModel(
          >>> from edsl import Cache
          >>> m = LanguageModel.example(test_model = True)
          >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
-         IntendedModelCallOutcome(response = {'message': '{"answer": "Hello world"}'}, cache_used = False, cache_key = '24ff6ac2bc2f1729f817f261e0792577')
-         """
+         ModelResponse(...)"""

          if encoded_image:
              # the image hash is appended to the user_prompt for hash-lookup purposes
              image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
+             user_prompt += f" {image_hash}"

          cache_call_params = {
              "model": str(self.model),
              "parameters": self.parameters,
              "system_prompt": system_prompt,
-             "user_prompt": user_prompt + "" if not encoded_image else f" {image_hash}",
+             "user_prompt": user_prompt,
              "iteration": iteration,
          }
          cached_response, cache_key = cache.fetch(**cache_call_params)
-
-         if (cache_used := cached_response is not None):
+
+         if cache_used := cached_response is not None:
              response = json.loads(cached_response)
          else:
              f = (
@@ -413,18 +484,33 @@ class LanguageModel(
                  if hasattr(self, "remote") and self.remote
                  else self.async_execute_model_call
              )
-             params = {"user_prompt": user_prompt, "system_prompt": system_prompt,
-                 **({"encoded_image": encoded_image} if encoded_image else {})
+             params = {
+                 "user_prompt": user_prompt,
+                 "system_prompt": system_prompt,
+                 **({"encoded_image": encoded_image} if encoded_image else {}),
              }
-             response = await f(**params)
-             new_cache_key = cache.store(**cache_call_params, response=response)  # store the response in the cache
-             assert new_cache_key == cache_key  # should be the same
-
-         return IntendedModelCallOutcome(response = response, cache_used = cache_used, cache_key = cache_key)
+             # response = await f(**params)
+             response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
+             new_cache_key = cache.store(
+                 **cache_call_params, response=response
+             )  # store the response in the cache
+             assert new_cache_key == cache_key  # should be the same
+
+         cost = self.cost(response)
+
+         return ModelResponse(
+             response=response,
+             cache_used=cache_used,
+             cache_key=cache_key,
+             cached_response=cached_response,
+             cost=cost,
+         )

-     _get_intended_model_call_outcome = sync_wrapper(_async_get_intended_model_call_outcome)
+     _get_intended_model_call_outcome = sync_wrapper(
+         _async_get_intended_model_call_outcome
+     )

-     get_raw_response = sync_wrapper(async_get_raw_response)
+     # get_raw_response = sync_wrapper(async_get_raw_response)

      def simple_ask(
          self,
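Note: the cache-or-call control flow above, restated with a plain dict standing in for edsl's Cache (the real fetch/store API hashes the call parameters; this sketch mirrors only the flow and the new asyncio.wait_for timeout):

    import asyncio

    TIMEOUT = 60.0  # the real value comes from the EDSL_API_TIMEOUT config entry

    async def fake_model_call():
        await asyncio.sleep(0.01)
        return {"message": [{"text": "Hello world"}]}

    async def get_outcome(cache: dict, cache_key: str):
        if (cached := cache.get(cache_key)) is not None:
            return cached, True                                   # cache hit
        response = await asyncio.wait_for(fake_model_call(), timeout=TIMEOUT)
        cache[cache_key] = response                               # store for next call
        return response, False                                    # cache miss

    cache = {}
    print(asyncio.run(get_outcome(cache, "k1")))   # miss: invokes the model
    print(asyncio.run(get_outcome(cache, "k1")))   # hit: served from the cache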
@@ -443,7 +529,7 @@ class LanguageModel(
          self,
          user_prompt: str,
          system_prompt: str,
-         cache: 'Cache',
+         cache: "Cache",
          iteration: int = 1,
          encoded_image=None,
      ) -> dict:
@@ -461,16 +547,68 @@ class LanguageModel(
              "system_prompt": system_prompt,
              "iteration": iteration,
              "cache": cache,
-             **({"encoded_image": encoded_image} if encoded_image else {})
-         }
-         model_call_outcome = await self._async_get_intended_model_call_outcome(**params)
-         return await self._async_prepare_response(model_call_outcome, cache=cache)
+             **({"encoded_image": encoded_image} if encoded_image else {}),
+         }
+         model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
+         model_outputs = await self._async_get_intended_model_call_outcome(**params)
+         edsl_dict = self.parse_response(model_outputs.response)
+         agent_response_dict = AgentResponseDict(
+             model_inputs=model_inputs,
+             model_outputs=model_outputs,
+             edsl_dict=edsl_dict,
+         )
+         return agent_response_dict
+
+         # return await self._async_prepare_response(model_call_outcome, cache=cache)

      get_response = sync_wrapper(async_get_response)

-     def cost(self, raw_response: dict[str, Any]) -> float:
+     def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
          """Return the dollar cost of a raw response."""
-         raise NotImplementedError
+
+         usage = self.get_usage_dict(raw_response)
+         from edsl.coop import Coop
+
+         c = Coop()
+         price_lookup = c.fetch_prices()
+         key = (self._inference_service_, self.model)
+         if key not in price_lookup:
+             return f"Could not find price for model {self.model} in the price lookup."
+
+         relevant_prices = price_lookup[key]
+         try:
+             input_tokens = int(usage[self.input_token_name])
+             output_tokens = int(usage[self.output_token_name])
+         except Exception as e:
+             return f"Could not fetch tokens from model response: {e}"
+
+         try:
+             inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+             inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+         except Exception as e:
+             if "output" not in relevant_prices:
+                 return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
+             if "input" not in relevant_prices:
+                 return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
+             return f"Could not fetch prices from {relevant_prices} - {e}"
+
+         if inverse_input_price == "infinity":
+             input_cost = 0
+         else:
+             try:
+                 input_cost = input_tokens / float(inverse_input_price)
+             except Exception as e:
+                 return f"Could not compute input price - {e}."
+
+         if inverse_output_price == "infinity":
+             output_cost = 0
+         else:
+             try:
+                 output_cost = output_tokens / float(inverse_output_price)
+             except Exception as e:
+                 return f"Could not compute output price - {e}"
+
+         return input_cost + output_cost

      #######################
      # SERIALIZATION METHODS
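Note: the price table stores inverse prices (how many tokens one USD buys), so cost is tokens divided by that figure, summed over input and output. A worked example with invented prices:

    input_tokens, output_tokens = 1_000, 500
    one_usd_buys_input = 1_000_000.0    # hypothetical: $1 buys 1M input tokens
    one_usd_buys_output = 250_000.0     # hypothetical: $1 buys 250k output tokens

    cost = input_tokens / one_usd_buys_input + output_tokens / one_usd_buys_output
    assert round(cost, 6) == 0.003      # $0.001 input + $0.002 output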
@@ -484,7 +622,7 @@ class LanguageModel(

          >>> m = LanguageModel.example()
          >>> m.to_dict()
-         {'model': 'gpt-4-1106-preview', 'parameters': {'temperature': 0.5, 'max_tokens': 1000, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'logprobs': False, 'top_logprobs': 3}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
+         {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
          """
          return self._to_dict()

@@ -560,26 +698,8 @@ class LanguageModel(
          """
          from edsl import Model

-         class TestLanguageModelGood(LanguageModel):
-             use_cache = False
-             _model_ = "test"
-             _parameters_ = {"temperature": 0.5}
-             _inference_service_ = InferenceServiceType.TEST.value
-
-             async def async_execute_model_call(
-                 self, user_prompt: str, system_prompt: str
-             ) -> dict[str, Any]:
-                 await asyncio.sleep(0.1)
-                 # return {"message": """{"answer": "Hello, world"}"""}
-                 if throw_exception:
-                     raise Exception("This is a test error")
-                 return {"message": f'{{"answer": "{canned_response}"}}'}
-
-             def parse_response(self, raw_response: dict[str, Any]) -> str:
-                 return raw_response["message"]
-
          if test_model:
-             m = TestLanguageModelGood()
+             m = Model("test", canned_response=canned_response)
              return m
          else:
              return Model(skip_api_key_check=True)
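Note: with the inline TestLanguageModelGood class deleted, example(test_model=True) now delegates to the registry's test service. A usage sketch (assumes a working edsl install):

    from edsl import Model

    m = Model("test", canned_response="Hello world")
    # m.execute_model_call("Hello, model!", "You are a helpful agent.")
    # returns {'message': [{'text': 'Hello world'}], ...} per the updated doctest above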
edsl/language_models/ModelList.py

@@ -40,8 +40,8 @@ class ModelList(Base, UserList):
      def __hash__(self):
          """Return a hash of the ModelList. This is used for comparison of ModelLists.

-         >>> hash(ModelList.example())
-         1423518243781418961
+         >>> isinstance(hash(Model()), int)
+         True

          """
          from edsl.utilities.utilities import dict_hash