edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. edsl/Base.py +197 -116
  2. edsl/__init__.py +15 -7
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +351 -147
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +101 -50
  7. edsl/agents/InvigilatorBase.py +62 -70
  8. edsl/agents/PromptConstructor.py +143 -225
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  11. edsl/agents/__init__.py +0 -1
  12. edsl/agents/prompt_helpers.py +3 -3
  13. edsl/agents/question_option_processor.py +172 -0
  14. edsl/auto/AutoStudy.py +18 -5
  15. edsl/auto/StageBase.py +53 -40
  16. edsl/auto/StageQuestions.py +2 -1
  17. edsl/auto/utilities.py +0 -6
  18. edsl/config.py +22 -2
  19. edsl/conversation/car_buying.py +2 -1
  20. edsl/coop/CoopFunctionsMixin.py +15 -0
  21. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  22. edsl/coop/PriceFetcher.py +1 -1
  23. edsl/coop/coop.py +125 -47
  24. edsl/coop/utils.py +14 -14
  25. edsl/data/Cache.py +45 -27
  26. edsl/data/CacheEntry.py +12 -15
  27. edsl/data/CacheHandler.py +31 -12
  28. edsl/data/RemoteCacheSync.py +154 -46
  29. edsl/data/__init__.py +4 -3
  30. edsl/data_transfer_models.py +2 -1
  31. edsl/enums.py +27 -0
  32. edsl/exceptions/__init__.py +50 -50
  33. edsl/exceptions/agents.py +12 -0
  34. edsl/exceptions/inference_services.py +5 -0
  35. edsl/exceptions/questions.py +24 -6
  36. edsl/exceptions/scenarios.py +7 -0
  37. edsl/inference_services/AnthropicService.py +38 -19
  38. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  39. edsl/inference_services/AvailableModelFetcher.py +215 -0
  40. edsl/inference_services/AwsBedrock.py +0 -2
  41. edsl/inference_services/AzureAI.py +0 -2
  42. edsl/inference_services/GoogleService.py +7 -12
  43. edsl/inference_services/InferenceServiceABC.py +18 -85
  44. edsl/inference_services/InferenceServicesCollection.py +120 -79
  45. edsl/inference_services/MistralAIService.py +0 -3
  46. edsl/inference_services/OpenAIService.py +47 -35
  47. edsl/inference_services/PerplexityService.py +0 -3
  48. edsl/inference_services/ServiceAvailability.py +135 -0
  49. edsl/inference_services/TestService.py +11 -10
  50. edsl/inference_services/TogetherAIService.py +5 -3
  51. edsl/inference_services/data_structures.py +134 -0
  52. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  53. edsl/jobs/Answers.py +1 -14
  54. edsl/jobs/FetchInvigilator.py +47 -0
  55. edsl/jobs/InterviewTaskManager.py +98 -0
  56. edsl/jobs/InterviewsConstructor.py +50 -0
  57. edsl/jobs/Jobs.py +356 -431
  58. edsl/jobs/JobsChecks.py +35 -10
  59. edsl/jobs/JobsComponentConstructor.py +189 -0
  60. edsl/jobs/JobsPrompts.py +6 -4
  61. edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
  62. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  63. edsl/jobs/RequestTokenEstimator.py +30 -0
  64. edsl/jobs/async_interview_runner.py +138 -0
  65. edsl/jobs/buckets/BucketCollection.py +44 -3
  66. edsl/jobs/buckets/TokenBucket.py +53 -21
  67. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  68. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  69. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  70. edsl/jobs/data_structures.py +120 -0
  71. edsl/jobs/decorators.py +35 -0
  72. edsl/jobs/interviews/Interview.py +143 -408
  73. edsl/jobs/jobs_status_enums.py +9 -0
  74. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  75. edsl/jobs/results_exceptions_handler.py +98 -0
  76. edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
  77. edsl/jobs/runners/JobsRunnerStatus.py +133 -165
  78. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  79. edsl/jobs/tasks/TaskHistory.py +38 -18
  80. edsl/jobs/tasks/task_status_enum.py +0 -2
  81. edsl/language_models/ComputeCost.py +63 -0
  82. edsl/language_models/LanguageModel.py +194 -236
  83. edsl/language_models/ModelList.py +28 -19
  84. edsl/language_models/PriceManager.py +127 -0
  85. edsl/language_models/RawResponseHandler.py +106 -0
  86. edsl/language_models/ServiceDataSources.py +0 -0
  87. edsl/language_models/__init__.py +1 -2
  88. edsl/language_models/key_management/KeyLookup.py +63 -0
  89. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  90. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  91. edsl/language_models/key_management/__init__.py +0 -0
  92. edsl/language_models/key_management/models.py +131 -0
  93. edsl/language_models/model.py +256 -0
  94. edsl/language_models/repair.py +2 -2
  95. edsl/language_models/utilities.py +5 -4
  96. edsl/notebooks/Notebook.py +19 -14
  97. edsl/notebooks/NotebookToLaTeX.py +142 -0
  98. edsl/prompts/Prompt.py +29 -39
  99. edsl/questions/ExceptionExplainer.py +77 -0
  100. edsl/questions/HTMLQuestion.py +103 -0
  101. edsl/questions/QuestionBase.py +68 -214
  102. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  103. edsl/questions/QuestionBudget.py +1 -1
  104. edsl/questions/QuestionCheckBox.py +3 -3
  105. edsl/questions/QuestionExtract.py +5 -7
  106. edsl/questions/QuestionFreeText.py +2 -3
  107. edsl/questions/QuestionList.py +10 -18
  108. edsl/questions/QuestionMatrix.py +265 -0
  109. edsl/questions/QuestionMultipleChoice.py +67 -23
  110. edsl/questions/QuestionNumerical.py +2 -4
  111. edsl/questions/QuestionRank.py +7 -17
  112. edsl/questions/SimpleAskMixin.py +4 -3
  113. edsl/questions/__init__.py +2 -1
  114. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
  115. edsl/questions/data_structures.py +20 -0
  116. edsl/questions/derived/QuestionLinearScale.py +6 -3
  117. edsl/questions/derived/QuestionTopK.py +1 -1
  118. edsl/questions/descriptors.py +17 -3
  119. edsl/questions/loop_processor.py +149 -0
  120. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
  121. edsl/questions/question_registry.py +1 -1
  122. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
  123. edsl/questions/response_validator_factory.py +34 -0
  124. edsl/questions/templates/matrix/__init__.py +1 -0
  125. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  126. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  127. edsl/results/CSSParameterizer.py +1 -1
  128. edsl/results/Dataset.py +170 -7
  129. edsl/results/DatasetExportMixin.py +168 -305
  130. edsl/results/DatasetTree.py +28 -8
  131. edsl/results/MarkdownToDocx.py +122 -0
  132. edsl/results/MarkdownToPDF.py +111 -0
  133. edsl/results/Result.py +298 -206
  134. edsl/results/Results.py +149 -131
  135. edsl/results/ResultsExportMixin.py +2 -0
  136. edsl/results/TableDisplay.py +98 -171
  137. edsl/results/TextEditor.py +50 -0
  138. edsl/results/__init__.py +1 -1
  139. edsl/results/file_exports.py +252 -0
  140. edsl/results/{Selector.py → results_selector.py} +23 -13
  141. edsl/results/smart_objects.py +96 -0
  142. edsl/results/table_data_class.py +12 -0
  143. edsl/results/table_renderers.py +118 -0
  144. edsl/scenarios/ConstructDownloadLink.py +109 -0
  145. edsl/scenarios/DocumentChunker.py +102 -0
  146. edsl/scenarios/DocxScenario.py +16 -0
  147. edsl/scenarios/FileStore.py +150 -239
  148. edsl/scenarios/PdfExtractor.py +40 -0
  149. edsl/scenarios/Scenario.py +90 -193
  150. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  151. edsl/scenarios/ScenarioList.py +415 -244
  152. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  153. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  154. edsl/scenarios/__init__.py +1 -2
  155. edsl/scenarios/directory_scanner.py +96 -0
  156. edsl/scenarios/file_methods.py +85 -0
  157. edsl/scenarios/handlers/__init__.py +13 -0
  158. edsl/scenarios/handlers/csv.py +49 -0
  159. edsl/scenarios/handlers/docx.py +76 -0
  160. edsl/scenarios/handlers/html.py +37 -0
  161. edsl/scenarios/handlers/json.py +111 -0
  162. edsl/scenarios/handlers/latex.py +5 -0
  163. edsl/scenarios/handlers/md.py +51 -0
  164. edsl/scenarios/handlers/pdf.py +68 -0
  165. edsl/scenarios/handlers/png.py +39 -0
  166. edsl/scenarios/handlers/pptx.py +105 -0
  167. edsl/scenarios/handlers/py.py +294 -0
  168. edsl/scenarios/handlers/sql.py +313 -0
  169. edsl/scenarios/handlers/sqlite.py +149 -0
  170. edsl/scenarios/handlers/txt.py +33 -0
  171. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
  172. edsl/scenarios/scenario_selector.py +156 -0
  173. edsl/study/ObjectEntry.py +1 -1
  174. edsl/study/SnapShot.py +1 -1
  175. edsl/study/Study.py +5 -12
  176. edsl/surveys/ConstructDAG.py +92 -0
  177. edsl/surveys/EditSurvey.py +221 -0
  178. edsl/surveys/InstructionHandler.py +100 -0
  179. edsl/surveys/MemoryManagement.py +72 -0
  180. edsl/surveys/Rule.py +5 -4
  181. edsl/surveys/RuleCollection.py +25 -27
  182. edsl/surveys/RuleManager.py +172 -0
  183. edsl/surveys/Simulator.py +75 -0
  184. edsl/surveys/Survey.py +270 -791
  185. edsl/surveys/SurveyCSS.py +20 -8
  186. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  187. edsl/surveys/SurveyToApp.py +141 -0
  188. edsl/surveys/__init__.py +4 -2
  189. edsl/surveys/descriptors.py +6 -2
  190. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  191. edsl/surveys/instructions/Instruction.py +4 -13
  192. edsl/surveys/instructions/InstructionCollection.py +11 -6
  193. edsl/templates/error_reporting/interview_details.html +1 -1
  194. edsl/templates/error_reporting/report.html +1 -1
  195. edsl/tools/plotting.py +1 -1
  196. edsl/utilities/PrettyList.py +56 -0
  197. edsl/utilities/is_notebook.py +18 -0
  198. edsl/utilities/is_valid_variable_name.py +11 -0
  199. edsl/utilities/remove_edsl_version.py +24 -0
  200. edsl/utilities/utilities.py +35 -23
  201. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
  202. edsl-0.1.39.dist-info/RECORD +358 -0
  203. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
  204. edsl/language_models/KeyLookup.py +0 -30
  205. edsl/language_models/registry.py +0 -190
  206. edsl/language_models/unused/ReplicateBase.py +0 -83
  207. edsl/results/ResultsDBMixin.py +0 -238
  208. edsl-0.1.38.dev4.dist-info/RECORD +0 -277
  209. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  210. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  211. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  212. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
@@ -0,0 +1,63 @@
1
+ from typing import Any, Union
2
+
3
+
4
class ComputeCost:
    """Compute the dollar cost of a single language-model response.

    Everything model-specific is read from the wrapped ``language_model``:
    ``_inference_service_`` and ``model`` form the price-table key,
    ``input_token_name`` / ``output_token_name`` are the keys into the usage
    dict, and ``get_usage_dict`` extracts that usage dict from a raw response.
    """

    def __init__(self, language_model: "LanguageModel"):
        self.language_model = language_model
        # Lazily populated by ``price_lookup``; None means "not fetched yet".
        self._price_lookup = None

    @property
    def price_lookup(self):
        """Return the price table, fetching it from Coop on first access.

        The table is keyed by ``(inference_service, model)`` tuples and is
        cached on this instance, so repeated ``cost`` calls do not re-fetch it.
        """
        if self._price_lookup is None:
            from edsl.coop import Coop

            c = Coop()
            self._price_lookup = c.fetch_prices()
        return self._price_lookup

    def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
        """Return the dollar cost of a raw response.

        :param raw_response: The raw response returned by the model's API.
        :returns: The total cost in US dollars (input plus output tokens), or a
            human-readable message describing why the cost could not be
            computed (unknown model, missing token counts, malformed prices).
            Failures are reported as strings rather than raised.
        """
        lm = self.language_model
        usage = lm.get_usage_dict(raw_response)

        # Use the cached table instead of hitting Coop on every call.
        price_lookup = self.price_lookup
        key = (lm._inference_service_, lm.model)
        if key not in price_lookup:
            return f"Could not find price for model {lm.model} in the price lookup."

        relevant_prices = price_lookup[key]
        try:
            input_tokens = int(usage[lm.input_token_name])
            output_tokens = int(usage[lm.output_token_name])
        except Exception as e:
            return f"Could not fetch tokens from model response: {e}"

        try:
            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
        except Exception as e:
            if "output" not in relevant_prices:
                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
            if "input" not in relevant_prices:
                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
            return f"Could not fetch prices from {relevant_prices} - {e}"

        # "one_usd_buys" is tokens per US dollar, so tokens / one_usd_buys is
        # the dollar cost; the literal "infinity" is treated as free (cost 0).
        if inverse_input_price == "infinity":
            input_cost = 0
        else:
            try:
                input_cost = input_tokens / float(inverse_input_price)
            except Exception as e:
                return f"Could not compute input price - {e}."

        if inverse_output_price == "infinity":
            output_cost = 0
        else:
            try:
                output_cost = output_tokens / float(inverse_output_price)
            except Exception as e:
                return f"Could not compute output price - {e}"

        return input_cost + output_cost
@@ -21,7 +21,6 @@ import os
21
21
  from typing import (
22
22
  Coroutine,
23
23
  Any,
24
- Callable,
25
24
  Type,
26
25
  Union,
27
26
  List,
@@ -32,8 +31,6 @@ from typing import (
32
31
  )
33
32
  from abc import ABC, abstractmethod
34
33
 
35
- from json_repair import repair_json
36
-
37
34
  from edsl.data_transfer_models import (
38
35
  ModelResponse,
39
36
  ModelInputs,
@@ -45,61 +42,24 @@ if TYPE_CHECKING:
45
42
  from edsl.data.Cache import Cache
46
43
  from edsl.scenarios.FileStore import FileStore
47
44
  from edsl.questions.QuestionBase import QuestionBase
45
+ from edsl.language_models.key_management.KeyLookup import KeyLookup
48
46
 
49
- from edsl.config import CONFIG
50
- from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
51
- from edsl.utilities.decorators import remove_edsl_version
52
-
53
- from edsl.Base import PersistenceMixin
54
- from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
55
- from edsl.language_models.KeyLookup import KeyLookup
56
- from edsl.exceptions.language_models import LanguageModelBadResponseError
47
+ from edsl.enums import InferenceServiceType
57
48
 
58
- TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
49
+ from edsl.utilities.decorators import (
50
+ sync_wrapper,
51
+ jupyter_nb_handler,
52
+ )
53
+ from edsl.utilities.remove_edsl_version import remove_edsl_version
59
54
 
55
+ from edsl.Base import PersistenceMixin, RepresentationMixin
56
+ from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
60
57
 
61
- # you might be tempated to move this to be a static method of LanguageModel, but this doesn't work
62
- # for reasons I don't understand. So leave it here.
63
- def extract_item_from_raw_response(data, key_sequence):
64
- if isinstance(data, str):
65
- try:
66
- data = json.loads(data)
67
- except json.JSONDecodeError as e:
68
- return data
69
- current_data = data
70
- for i, key in enumerate(key_sequence):
71
- try:
72
- if isinstance(current_data, (list, tuple)):
73
- if not isinstance(key, int):
74
- raise TypeError(
75
- f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
76
- )
77
- if key < 0 or key >= len(current_data):
78
- raise IndexError(
79
- f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
80
- )
81
- elif isinstance(current_data, dict):
82
- if key not in current_data:
83
- raise KeyError(
84
- f"Key '{key}' not found in dictionary at position {i}"
85
- )
86
- else:
87
- raise TypeError(
88
- f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
89
- )
58
+ from edsl.language_models.key_management.KeyLookupCollection import (
59
+ KeyLookupCollection,
60
+ )
90
61
 
91
- current_data = current_data[key]
92
- except Exception as e:
93
- path = " -> ".join(map(str, key_sequence[: i + 1]))
94
- if "error" in data:
95
- msg = data["error"]
96
- else:
97
- msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
98
- raise LanguageModelBadResponseError(message=msg, response_json=data)
99
- if isinstance(current_data, str):
100
- return current_data.strip()
101
- else:
102
- return current_data
62
+ from edsl.language_models.RawResponseHandler import RawResponseHandler
103
63
 
104
64
 
105
65
  def handle_key_error(func):
@@ -117,8 +77,21 @@ def handle_key_error(func):
117
77
  return wrapper
118
78
 
119
79
 
80
class classproperty:
    """Descriptor exposing a function of the owning class as a read-only attribute.

    Like ``property``, but the wrapped method receives the class rather than an
    instance, so ``Cls.attr`` and ``Cls().attr`` both evaluate ``method(Cls)``.
    """

    def __init__(self, method):
        self.method = method
        # Surface the wrapped method's docstring on the descriptor, as ``property`` does.
        self.__doc__ = getattr(method, "__doc__", None)

    def __get__(self, instance, cls=None):
        # The descriptor protocol makes ``owner`` optional; fall back to the
        # instance's type so a direct ``__get__(obj)`` call still works.
        if cls is None:
            cls = type(instance)
        return self.method(cls)
86
+
87
+
88
+ from edsl.Base import HashingMixin
89
+
90
+
120
91
  class LanguageModel(
121
92
  PersistenceMixin,
93
+ RepresentationMixin,
94
+ HashingMixin,
122
95
  ABC,
123
96
  metaclass=RegisterLanguageModelsMeta,
124
97
  ):
@@ -128,15 +101,22 @@ class LanguageModel(
128
101
  key_sequence = (
129
102
  None # This should be something like ["choices", 0, "message", "content"]
130
103
  )
131
- __rate_limits = None
132
- _safety_factor = 0.8
104
+
105
+ DEFAULT_RPM = 100
106
+ DEFAULT_TPM = 1000
107
+
108
+ @classproperty
109
+ def response_handler(cls):
110
+ key_sequence = cls.key_sequence
111
+ usage_sequence = cls.usage_sequence if hasattr(cls, "usage_sequence") else None
112
+ return RawResponseHandler(key_sequence, usage_sequence)
133
113
 
134
114
  def __init__(
135
115
  self,
136
- tpm: float = None,
137
- rpm: float = None,
116
+ tpm: Optional[float] = None,
117
+ rpm: Optional[float] = None,
138
118
  omit_system_prompt_if_empty_string: bool = True,
139
- key_lookup: Optional[KeyLookup] = None,
119
+ key_lookup: Optional["KeyLookup"] = None,
140
120
  **kwargs,
141
121
  ):
142
122
  """Initialize the LanguageModel."""
@@ -147,7 +127,9 @@ class LanguageModel(
147
127
  self.remote = False
148
128
  self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
149
129
 
150
- # self._rpm / _tpm comes from the class
130
+ self.key_lookup = self._set_key_lookup(key_lookup)
131
+ self.model_info = self.key_lookup.get(self._inference_service_)
132
+
151
133
  if rpm is not None:
152
134
  self._rpm = rpm
153
135
 
@@ -161,49 +143,75 @@ class LanguageModel(
161
143
  if key not in parameters:
162
144
  setattr(self, key, value)
163
145
 
164
- if "use_cache" in kwargs:
165
- warnings.warn(
166
- "The use_cache parameter is deprecated. Use the Cache class instead."
167
- )
168
-
169
- if skip_api_key_check := kwargs.get("skip_api_key_check", False):
146
+ if kwargs.get("skip_api_key_check", False):
170
147
  # Skip the API key check. Sometimes this is useful for testing.
171
148
  self._api_token = None
172
149
 
150
+ def _set_key_lookup(self, key_lookup: "KeyLookup") -> "KeyLookup":
151
+ """Set the key lookup."""
173
152
  if key_lookup is not None:
174
- self.key_lookup = key_lookup
153
+ return key_lookup
175
154
  else:
176
- self.key_lookup = KeyLookup.from_os_environ()
155
+ klc = KeyLookupCollection()
156
+ klc.add_key_lookup(fetch_order=("config", "env"))
157
+ return klc.get(("config", "env"))
158
+
159
+ def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
160
+ """Set the key lookup, later"""
161
+ if hasattr(self, "_api_token"):
162
+ del self._api_token
163
+ self.key_lookup = key_lookup
164
+
165
+ def ask_question(self, question: "QuestionBase") -> str:
166
+ """Ask a question and return the response.
177
167
 
178
- def ask_question(self, question):
168
+ :param question: The question to ask.
169
+ """
179
170
  user_prompt = question.get_instructions().render(question.data).text
180
171
  system_prompt = "You are a helpful agent pretending to be a human."
181
172
  return self.execute_model_call(user_prompt, system_prompt)
182
173
 
183
- def set_key_lookup(self, key_lookup: KeyLookup) -> None:
184
- del self._api_token
185
- self.key_lookup = key_lookup
174
+ @property
175
+ def rpm(self):
176
+ if not hasattr(self, "_rpm"):
177
+ if self.model_info is None:
178
+ self._rpm = self.DEFAULT_RPM
179
+ else:
180
+ self._rpm = self.model_info.rpm
181
+ return self._rpm
182
+
183
+ @property
184
+ def tpm(self):
185
+ if not hasattr(self, "_tpm"):
186
+ if self.model_info is None:
187
+ self._tpm = self.DEFAULT_TPM
188
+ else:
189
+ self._tpm = self.model_info.tpm
190
+ return self._tpm
191
+
192
+ # in case we want to override the default values
193
+ @tpm.setter
194
+ def tpm(self, value):
195
+ self._tpm = value
196
+
197
+ @rpm.setter
198
+ def rpm(self, value):
199
+ self._rpm = value
186
200
 
187
201
  @property
188
202
  def api_token(self) -> str:
189
203
  if not hasattr(self, "_api_token"):
190
- self._api_token = self.key_lookup.get_api_token(
191
- self._inference_service_, self.remote
192
- )
204
+ info = self.key_lookup.get(self._inference_service_, None)
205
+ if info is None:
206
+ raise ValueError(
207
+ f"No key found for service '{self._inference_service_}'"
208
+ )
209
+ self._api_token = info.api_token
193
210
  return self._api_token
194
211
 
195
212
  def __getitem__(self, key):
196
213
  return getattr(self, key)
197
214
 
198
- def _repr_html_(self) -> str:
199
- d = {"model": self.model}
200
- d.update(self.parameters)
201
- data = [[k, v] for k, v in d.items()]
202
- from tabulate import tabulate
203
-
204
- table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
205
- return f"<pre>{table}</pre>"
206
-
207
215
  def hello(self, verbose=False):
208
216
  """Runs a simple test to check if the model is working."""
209
217
  token = self.api_token
@@ -232,7 +240,12 @@ class LanguageModel(
232
240
  return key_value is not None
233
241
 
234
242
  def __hash__(self) -> str:
235
- """Allow the model to be used as a key in a dictionary."""
243
+ """Allow the model to be used as a key in a dictionary.
244
+
245
+ >>> m = LanguageModel.example()
246
+ >>> hash(m)
247
+ 1811901442659237949
248
+ """
236
249
  from edsl.utilities.utilities import dict_hash
237
250
 
238
251
  return dict_hash(self.to_dict(add_edsl_version=False))
@@ -248,46 +261,6 @@ class LanguageModel(
248
261
  """
249
262
  return self.model == other.model and self.parameters == other.parameters
250
263
 
251
- def set_rate_limits(self, rpm=None, tpm=None) -> None:
252
- """Set the rate limits for the model.
253
-
254
- >>> m = LanguageModel.example()
255
- >>> m.set_rate_limits(rpm=100, tpm=1000)
256
- >>> m.RPM
257
- 100
258
- """
259
- if rpm is not None:
260
- self._rpm = rpm
261
- if tpm is not None:
262
- self._tpm = tpm
263
- return None
264
-
265
- @property
266
- def RPM(self):
267
- """Model's requests-per-minute limit."""
268
- return self._rpm
269
-
270
- @property
271
- def TPM(self):
272
- """Model's tokens-per-minute limit."""
273
- return self._tpm
274
-
275
- @property
276
- def rpm(self):
277
- return self._rpm
278
-
279
- @rpm.setter
280
- def rpm(self, value):
281
- self._rpm = value
282
-
283
- @property
284
- def tpm(self):
285
- return self._tpm
286
-
287
- @tpm.setter
288
- def tpm(self, value):
289
- self._tpm = value
290
-
291
264
  @staticmethod
292
265
  def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
293
266
  """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
@@ -310,16 +283,7 @@ class LanguageModel(
310
283
 
311
284
  @abstractmethod
312
285
  async def async_execute_model_call(user_prompt: str, system_prompt: str):
313
- """Execute the model call and returns a coroutine.
314
-
315
- >>> m = LanguageModel.example(test_model = True)
316
- >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
317
- >>> asyncio.run(test())
318
- {'message': [{'text': 'Hello world'}], ...}
319
-
320
- >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
321
- {'message': [{'text': 'Hello world'}], ...}
322
- """
286
+ """Execute the model call and returns a coroutine."""
323
287
  pass
324
288
 
325
289
  async def remote_async_execute_model_call(
@@ -336,12 +300,7 @@ class LanguageModel(
336
300
 
337
301
  @jupyter_nb_handler
338
302
  def execute_model_call(self, *args, **kwargs) -> Coroutine:
339
- """Execute the model call and returns the result as a coroutine.
340
-
341
- >>> m = LanguageModel.example(test_model = True)
342
- >>> m.execute_model_call(user_prompt = "Hello, model!", system_prompt = "You are a helpful agent.")
343
-
344
- """
303
+ """Execute the model call and returns the result as a coroutine."""
345
304
 
346
305
  async def main():
347
306
  results = await asyncio.gather(
@@ -353,58 +312,25 @@ class LanguageModel(
353
312
 
354
313
  @classmethod
355
314
  def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
356
- """Return the generated token string from the raw response."""
357
- return extract_item_from_raw_response(raw_response, cls.key_sequence)
315
+ """Return the generated token string from the raw response.
316
+
317
+ >>> m = LanguageModel.example(test_model = True)
318
+ >>> raw_response = m.execute_model_call("Hello, model!", "You are a helpful agent.")
319
+ >>> m.get_generated_token_string(raw_response)
320
+ 'Hello world'
321
+
322
+ """
323
+ return cls.response_handler.get_generated_token_string(raw_response)
358
324
 
359
325
  @classmethod
360
326
  def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
361
327
  """Return the usage dictionary from the raw response."""
362
- if not hasattr(cls, "usage_sequence"):
363
- raise NotImplementedError(
364
- "This inference service does not have a usage_sequence."
365
- )
366
- return extract_item_from_raw_response(raw_response, cls.usage_sequence)
367
-
368
- @staticmethod
369
- def convert_answer(response_part):
370
- import json
371
-
372
- response_part = response_part.strip()
373
-
374
- if response_part == "None":
375
- return None
376
-
377
- repaired = repair_json(response_part)
378
- if repaired == '""':
379
- # it was a literal string
380
- return response_part
381
-
382
- try:
383
- return json.loads(repaired)
384
- except json.JSONDecodeError as j:
385
- # last resort
386
- return response_part
328
+ return cls.response_handler.get_usage_dict(raw_response)
387
329
 
388
330
  @classmethod
389
331
  def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
390
332
  """Parses the API response and returns the response text."""
391
- generated_token_string = cls.get_generated_token_string(raw_response)
392
- last_newline = generated_token_string.rfind("\n")
393
-
394
- if last_newline == -1:
395
- # There is no comment
396
- edsl_dict = {
397
- "answer": cls.convert_answer(generated_token_string),
398
- "generated_tokens": generated_token_string,
399
- "comment": None,
400
- }
401
- else:
402
- edsl_dict = {
403
- "answer": cls.convert_answer(generated_token_string[:last_newline]),
404
- "comment": generated_token_string[last_newline + 1 :].strip(),
405
- "generated_tokens": generated_token_string,
406
- }
407
- return EDSLOutput(**edsl_dict)
333
+ return cls.response_handler.parse_response(raw_response)
408
334
 
409
335
  async def _async_get_intended_model_call_outcome(
410
336
  self,
@@ -421,6 +347,8 @@ class LanguageModel(
421
347
  :param system_prompt: The system's prompt.
422
348
  :param iteration: The iteration number.
423
349
  :param cache: The cache to use.
350
+ :param files_list: The list of files to use.
351
+ :param invigilator: The invigilator to use.
424
352
 
425
353
  If the cache isn't being used, it just returns a 'fresh' call to the LLM.
426
354
  But if cache is being used, it first checks the database to see if the response is already there.
@@ -463,6 +391,10 @@ class LanguageModel(
463
391
  "system_prompt": system_prompt,
464
392
  "files_list": files_list,
465
393
  }
394
+ from edsl.config import CONFIG
395
+
396
+ TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
397
+
466
398
  response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
467
399
  new_cache_key = cache.store(
468
400
  **cache_call_params, response=response
@@ -470,7 +402,6 @@ class LanguageModel(
470
402
  assert new_cache_key == cache_key # should be the same
471
403
 
472
404
  cost = self.cost(response)
473
-
474
405
  return ModelResponse(
475
406
  response=response,
476
407
  cache_used=cache_used,
@@ -509,9 +440,9 @@ class LanguageModel(
509
440
 
510
441
  :param user_prompt: The user's prompt.
511
442
  :param system_prompt: The system's prompt.
512
- :param iteration: The iteration number.
513
443
  :param cache: The cache to use.
514
- :param encoded_image: The encoded image to use.
444
+ :param iteration: The iteration number.
445
+ :param files_list: The list of files to use.
515
446
 
516
447
  """
517
448
  params = {
@@ -525,8 +456,11 @@ class LanguageModel(
525
456
  params.update({"invigilator": kwargs["invigilator"]})
526
457
 
527
458
  model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
528
- model_outputs = await self._async_get_intended_model_call_outcome(**params)
529
- edsl_dict = self.parse_response(model_outputs.response)
459
+ model_outputs: ModelResponse = (
460
+ await self._async_get_intended_model_call_outcome(**params)
461
+ )
462
+ edsl_dict: EDSLOutput = self.parse_response(model_outputs.response)
463
+
530
464
  agent_response_dict = AgentResponseDict(
531
465
  model_inputs=model_inputs,
532
466
  model_outputs=model_outputs,
@@ -537,60 +471,36 @@ class LanguageModel(
537
471
  get_response = sync_wrapper(async_get_response)
538
472
 
539
473
  def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
540
- """Return the dollar cost of a raw response."""
541
-
542
- usage = self.get_usage_dict(raw_response)
543
- from edsl.coop import Coop
474
+ """Return the dollar cost of a raw response.
544
475
 
545
- c = Coop()
546
- price_lookup = c.fetch_prices()
547
- key = (self._inference_service_, self.model)
548
- if key not in price_lookup:
549
- return f"Could not find price for model {self.model} in the price lookup."
550
-
551
- relevant_prices = price_lookup[key]
552
- try:
553
- input_tokens = int(usage[self.input_token_name])
554
- output_tokens = int(usage[self.output_token_name])
555
- except Exception as e:
556
- return f"Could not fetch tokens from model response: {e}"
557
-
558
- try:
559
- inverse_output_price = relevant_prices["output"]["one_usd_buys"]
560
- inverse_input_price = relevant_prices["input"]["one_usd_buys"]
561
- except Exception as e:
562
- if "output" not in relevant_prices:
563
- return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
564
- if "input" not in relevant_prices:
565
- return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
566
- return f"Could not fetch prices from {relevant_prices} - {e}"
567
-
568
- if inverse_input_price == "infinity":
569
- input_cost = 0
570
- else:
571
- try:
572
- input_cost = input_tokens / float(inverse_input_price)
573
- except Exception as e:
574
- return f"Could not compute input price - {e}."
575
-
576
- if inverse_output_price == "infinity":
577
- output_cost = 0
578
- else:
579
- try:
580
- output_cost = output_tokens / float(inverse_output_price)
581
- except Exception as e:
582
- return f"Could not compute output price - {e}"
476
+ :param raw_response: The raw response from the model.
477
+ """
583
478
 
584
- return input_cost + output_cost
479
+ usage = self.get_usage_dict(raw_response)
480
+ from edsl.language_models.PriceManager import PriceManager
481
+
482
+ price_manger = PriceManager()
483
+ return price_manger.calculate_cost(
484
+ inference_service=self._inference_service_,
485
+ model=self.model,
486
+ usage=usage,
487
+ input_token_name=self.input_token_name,
488
+ output_token_name=self.output_token_name,
489
+ )
585
490
 
586
491
  def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
587
492
  """Convert instance to a dictionary
588
493
 
494
+ :param add_edsl_version: Whether to add the EDSL version to the dictionary.
495
+
589
496
  >>> m = LanguageModel.example()
590
497
  >>> m.to_dict()
591
498
  {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
592
499
  """
593
- d = {"model": self.model, "parameters": self.parameters}
500
+ d = {
501
+ "model": self.model,
502
+ "parameters": self.parameters,
503
+ }
594
504
  if add_edsl_version:
595
505
  from edsl import __version__
596
506
 
@@ -602,13 +512,13 @@ class LanguageModel(
602
512
  @remove_edsl_version
603
513
  def from_dict(cls, data: dict) -> Type[LanguageModel]:
604
514
  """Convert dictionary to a LanguageModel child instance."""
605
- from edsl.language_models.registry import get_model_class
515
+ from edsl.language_models.model import get_model_class
606
516
 
607
517
  model_class = get_model_class(data["model"])
608
518
  return model_class(**data)
609
519
 
610
520
  def __repr__(self) -> str:
611
- """Return a string representation of the object."""
521
+ """Return a representation of the object."""
612
522
  param_string = ", ".join(
613
523
  f"{key} = {value}" for key, value in self.parameters.items()
614
524
  )
@@ -650,7 +560,7 @@ class LanguageModel(
650
560
  Exception report saved to ...
651
561
  Also see: ...
652
562
  """
653
- from edsl import Model
563
+ from edsl.language_models.model import Model
654
564
 
655
565
  if test_model:
656
566
  m = Model(
@@ -660,6 +570,54 @@ class LanguageModel(
660
570
  else:
661
571
  return Model(skip_api_key_check=True)
662
572
 
573
+ def from_cache(self, cache: "Cache") -> LanguageModel:
574
+
575
+ from copy import deepcopy
576
+ from types import MethodType
577
+ from edsl import Cache
578
+
579
+ new_instance = deepcopy(self)
580
+ print("Cache entries", len(cache))
581
+ new_instance.cache = Cache(
582
+ data={k: v for k, v in cache.items() if v.model == self.model}
583
+ )
584
+ print("Cache entries with same model", len(new_instance.cache))
585
+
586
+ new_instance.user_prompts = [
587
+ ce.user_prompt for ce in new_instance.cache.values()
588
+ ]
589
+ new_instance.system_prompts = [
590
+ ce.system_prompt for ce in new_instance.cache.values()
591
+ ]
592
+
593
+ async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
594
+ cache_call_params = {
595
+ "model": str(self.model),
596
+ "parameters": self.parameters,
597
+ "system_prompt": system_prompt,
598
+ "user_prompt": user_prompt,
599
+ "iteration": 1,
600
+ }
601
+ cached_response, cache_key = cache.fetch(**cache_call_params)
602
+ response = json.loads(cached_response)
603
+ cost = 0
604
+ return ModelResponse(
605
+ response=response,
606
+ cache_used=True,
607
+ cache_key=cache_key,
608
+ cached_response=cached_response,
609
+ cost=cost,
610
+ )
611
+
612
+ # Bind the new method to the copied instance
613
+ setattr(
614
+ new_instance,
615
+ "async_execute_model_call",
616
+ MethodType(async_execute_model_call, new_instance),
617
+ )
618
+
619
+ return new_instance
620
+
663
621
 
664
622
  if __name__ == "__main__":
665
623
  """Run the module's test suite."""