edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +3 -9
- edsl/__init__.py +3 -8
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -40
- edsl/agents/AgentList.py +0 -43
- edsl/agents/Invigilator.py +219 -135
- edsl/agents/InvigilatorBase.py +59 -148
- edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
- edsl/agents/__init__.py +0 -1
- edsl/config.py +56 -47
- edsl/coop/coop.py +7 -50
- edsl/data/Cache.py +1 -35
- edsl/data_transfer_models.py +38 -73
- edsl/enums.py +0 -4
- edsl/exceptions/language_models.py +1 -25
- edsl/exceptions/questions.py +5 -62
- edsl/exceptions/results.py +0 -4
- edsl/inference_services/AnthropicService.py +11 -13
- edsl/inference_services/AwsBedrock.py +17 -19
- edsl/inference_services/AzureAI.py +20 -37
- edsl/inference_services/GoogleService.py +12 -16
- edsl/inference_services/GroqService.py +0 -2
- edsl/inference_services/InferenceServiceABC.py +3 -58
- edsl/inference_services/OpenAIService.py +54 -48
- edsl/inference_services/models_available_cache.py +6 -0
- edsl/inference_services/registry.py +0 -6
- edsl/jobs/Answers.py +12 -10
- edsl/jobs/Jobs.py +21 -36
- edsl/jobs/buckets/BucketCollection.py +15 -24
- edsl/jobs/buckets/TokenBucket.py +14 -93
- edsl/jobs/interviews/Interview.py +78 -366
- edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
- edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
- edsl/jobs/interviews/retry_management.py +37 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
- edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
- edsl/jobs/tasks/TaskHistory.py +213 -148
- edsl/language_models/LanguageModel.py +156 -261
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
- edsl/language_models/registry.py +6 -23
- edsl/language_models/repair.py +19 -0
- edsl/prompts/Prompt.py +2 -52
- edsl/questions/AnswerValidatorMixin.py +26 -23
- edsl/questions/QuestionBase.py +249 -329
- edsl/questions/QuestionBudget.py +41 -99
- edsl/questions/QuestionCheckBox.py +35 -227
- edsl/questions/QuestionExtract.py +27 -98
- edsl/questions/QuestionFreeText.py +29 -52
- edsl/questions/QuestionFunctional.py +0 -7
- edsl/questions/QuestionList.py +22 -141
- edsl/questions/QuestionMultipleChoice.py +65 -159
- edsl/questions/QuestionNumerical.py +46 -88
- edsl/questions/QuestionRank.py +24 -182
- edsl/questions/RegisterQuestionsMeta.py +12 -31
- edsl/questions/__init__.py +4 -3
- edsl/questions/derived/QuestionLikertFive.py +5 -10
- edsl/questions/derived/QuestionLinearScale.py +2 -15
- edsl/questions/derived/QuestionTopK.py +1 -10
- edsl/questions/derived/QuestionYesNo.py +3 -24
- edsl/questions/descriptors.py +7 -43
- edsl/questions/question_registry.py +2 -6
- edsl/results/Dataset.py +0 -20
- edsl/results/DatasetExportMixin.py +48 -46
- edsl/results/Result.py +5 -32
- edsl/results/Results.py +46 -135
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/scenarios/FileStore.py +10 -71
- edsl/scenarios/Scenario.py +25 -96
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +39 -361
- edsl/scenarios/ScenarioListExportMixin.py +0 -9
- edsl/scenarios/ScenarioListPdfMixin.py +4 -150
- edsl/study/SnapShot.py +1 -8
- edsl/study/Study.py +0 -32
- edsl/surveys/Rule.py +1 -10
- edsl/surveys/RuleCollection.py +5 -21
- edsl/surveys/Survey.py +310 -636
- edsl/surveys/SurveyExportMixin.py +9 -71
- edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
- edsl/surveys/SurveyQualtricsImport.py +4 -75
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/utilities.py +1 -9
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
- edsl-0.1.33.dev1.dist-info/RECORD +209 -0
- edsl/TemplateLoader.py +0 -24
- edsl/auto/AutoStudy.py +0 -117
- edsl/auto/StageBase.py +0 -230
- edsl/auto/StageGenerateSurvey.py +0 -178
- edsl/auto/StageLabelQuestions.py +0 -125
- edsl/auto/StagePersona.py +0 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
- edsl/auto/StagePersonaDimensionValues.py +0 -74
- edsl/auto/StagePersonaDimensions.py +0 -69
- edsl/auto/StageQuestions.py +0 -73
- edsl/auto/SurveyCreatorPipeline.py +0 -21
- edsl/auto/utilities.py +0 -224
- edsl/coop/PriceFetcher.py +0 -58
- edsl/inference_services/MistralAIService.py +0 -120
- edsl/inference_services/TestService.py +0 -80
- edsl/inference_services/TogetherAIService.py +0 -170
- edsl/jobs/FailedQuestion.py +0 -78
- edsl/jobs/runners/JobsRunnerStatus.py +0 -331
- edsl/language_models/fake_openai_call.py +0 -15
- edsl/language_models/fake_openai_service.py +0 -61
- edsl/language_models/utilities.py +0 -61
- edsl/questions/QuestionBaseGenMixin.py +0 -133
- edsl/questions/QuestionBasePromptsMixin.py +0 -266
- edsl/questions/Quick.py +0 -41
- edsl/questions/ResponseValidatorABC.py +0 -170
- edsl/questions/decorators.py +0 -21
- edsl/questions/prompt_templates/question_budget.jinja +0 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
- edsl/questions/prompt_templates/question_extract.jinja +0 -11
- edsl/questions/prompt_templates/question_free_text.jinja +0 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
- edsl/questions/prompt_templates/question_list.jinja +0 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
- edsl/questions/prompt_templates/question_numerical.jinja +0 -37
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +0 -7
- edsl/questions/templates/budget/question_presentation.jinja +0 -7
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +0 -7
- edsl/questions/templates/extract/question_presentation.jinja +0 -1
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +0 -1
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +0 -4
- edsl/questions/templates/list/question_presentation.jinja +0 -5
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
- edsl/questions/templates/numerical/question_presentation.jinja +0 -7
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +0 -11
- edsl/questions/templates/rank/question_presentation.jinja +0 -15
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
- edsl/questions/templates/top_k/question_presentation.jinja +0 -22
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
- edsl/results/DatasetTree.py +0 -145
- edsl/results/Selector.py +0 -118
- edsl/results/tree_explore.py +0 -115
- edsl/surveys/instructions/ChangeInstruction.py +0 -47
- edsl/surveys/instructions/Instruction.py +0 -34
- edsl/surveys/instructions/InstructionCollection.py +0 -77
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +0 -24
- edsl/templates/error_reporting/exceptions_by_model.html +0 -35
- edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
- edsl/templates/error_reporting/exceptions_by_type.html +0 -17
- edsl/templates/error_reporting/interview_details.html +0 -116
- edsl/templates/error_reporting/interviews.html +0 -10
- edsl/templates/error_reporting/overview.html +0 -5
- edsl/templates/error_reporting/performance_plot.html +0 -2
- edsl/templates/error_reporting/report.css +0 -74
- edsl/templates/error_reporting/report.html +0 -118
- edsl/templates/error_reporting/report.js +0 -25
- edsl-0.1.33.dist-info/RECORD +0 -295
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
edsl/language_models/LanguageModel.py
@@ -1,16 +1,4 @@
-"""This module contains the LanguageModel class, which is an abstract base class for all language models.
-
-Terminology:
-
-raw_response: The JSON response from the model. This has all the model meta-data about the call.
-
-edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
-such as the cache key, token usage, etc.
-
-generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
-edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
-
-"""
+"""This module contains the LanguageModel class, which is an abstract base class for all language models."""
 
 from __future__ import annotations
 import warnings
@@ -20,103 +8,47 @@ import json
 import time
 import os
 import hashlib
-from typing import (
-    Coroutine,
-    Any,
-    Callable,
-    Type,
-    Union,
-    List,
-    get_type_hints,
-    TypedDict,
-    Optional,
-)
+from typing import Coroutine, Any, Callable, Type, List, get_type_hints
 from abc import ABC, abstractmethod
 
-from json_repair import repair_json
 
-
-
-
-
-
+class IntendedModelCallOutcome:
+    "This is a tuple-like class that holds the response, cache_used, and cache_key."
+
+    def __init__(self, response: dict, cache_used: bool, cache_key: str):
+        self.response = response
+        self.cache_used = cache_used
+        self.cache_key = cache_key
+
+    def __iter__(self):
+        """Iterate over the class attributes.
+
+        >>> a, b, c = IntendedModelCallOutcome({'answer': "yes"}, True, 'x1289')
+        >>> a
+        {'answer': 'yes'}
+        """
+        yield self.response
+        yield self.cache_used
+        yield self.cache_key
+
+    def __len__(self):
+        return 3
+
+    def __repr__(self):
+        return f"IntendedModelCallOutcome(response = {self.response}, cache_used = {self.cache_used}, cache_key = '{self.cache_key}')"
 
 
 from edsl.config import CONFIG
+
 from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
 from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
+
 from edsl.language_models.repair import repair
 from edsl.enums import InferenceServiceType
 from edsl.Base import RichPrintingMixin, PersistenceMixin
 from edsl.enums import service_to_api_keyname
 from edsl.exceptions import MissingAPIKeyError
 from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
-from edsl.exceptions.language_models import LanguageModelBadResponseError
-
-TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
-
-
-def convert_answer(response_part):
-    import json
-
-    response_part = response_part.strip()
-
-    if response_part == "None":
-        return None
-
-    repaired = repair_json(response_part)
-    if repaired == '""':
-        # it was a literal string
-        return response_part
-
-    try:
-        return json.loads(repaired)
-    except json.JSONDecodeError as j:
-        # last resort
-        return response_part
-
-
-def extract_item_from_raw_response(data, key_sequence):
-    if isinstance(data, str):
-        try:
-            data = json.loads(data)
-        except json.JSONDecodeError as e:
-            return data
-    current_data = data
-    for i, key in enumerate(key_sequence):
-        try:
-            if isinstance(current_data, (list, tuple)):
-                if not isinstance(key, int):
-                    raise TypeError(
-                        f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
-                    )
-                if key < 0 or key >= len(current_data):
-                    raise IndexError(
-                        f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
-                    )
-            elif isinstance(current_data, dict):
-                if key not in current_data:
-                    raise KeyError(
-                        f"Key '{key}' not found in dictionary at position {i}"
-                    )
-            else:
-                raise TypeError(
-                    f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
-                )
-
-            current_data = current_data[key]
-        except Exception as e:
-            path = " -> ".join(map(str, key_sequence[: i + 1]))
-            if "error" in data:
-                msg = data["error"]
-            else:
-                msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
-            raise LanguageModelBadResponseError(message=msg, response_json=data)
-    if isinstance(current_data, str):
-        return current_data.strip()
-    else:
-        return current_data
 
 
 def handle_key_error(func):
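Note: IntendedModelCallOutcome deliberately mimics a 3-tuple (it defines __iter__ and __len__), so call sites that unpacked the old (response, cache_used, cache_key) tuple keep working unchanged. Both access styles, using only the class as added above:

    outcome = IntendedModelCallOutcome(
        response={"message": '{"answer": "yes"}'},
        cache_used=False,
        cache_key="x1289",
    )

    print(outcome.response["message"])         # attribute access
    response, cache_used, cache_key = outcome  # tuple-style unpacking, as in the doctest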
@@ -160,29 +92,21 @@ class LanguageModel(
     """
 
     _model_ = None
-
-        None  # This should be something like ["choices", 0, "message", "content"]
-    )
+
     __rate_limits = None
+    __default_rate_limits = {
+        "rpm": 10_000,
+        "tpm": 2_000_000,
+    }  # TODO: Use the OpenAI Teir 1 rate limits
     _safety_factor = 0.8
 
-    def __init__(
-        self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
-    ):
+    def __init__(self, **kwargs):
         """Initialize the LanguageModel."""
         self.model = getattr(self, "_model_", None)
         default_parameters = getattr(self, "_parameters_", None)
         parameters = self._overide_default_parameters(kwargs, default_parameters)
         self.parameters = parameters
         self.remote = False
-        self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
-
-        # self._rpm / _tpm comes from the class
-        if rpm is not None:
-            self._rpm = rpm
-
-        if tpm is not None:
-            self._tpm = tpm
 
         for key, value in parameters.items():
             setattr(self, key, value)
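Note: the dev build's __init__ drops the rpm / tpm / omit_system_prompt_if_empty_string keyword arguments. Rate limits are instead resolved lazily: the class-level __default_rate_limits (10,000 rpm / 2,000,000 tpm) applies unless the subclass provides a get_rate_limits() hook or the caller overrides the limits explicitly, as the _set_rate_limits hunk below shows. A sketch of the resulting call pattern (names from this diff):

    m = LanguageModel.example(test_model=True)
    m.set_rate_limits(rpm=100, tpm=1000)  # explicit override
    # with no override, RPM/TPM fall back to the class defaults,
    # scaled by _safety_factor = 0.8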
@@ -209,6 +133,7 @@ class LanguageModel(
     def api_token(self) -> str:
         if not hasattr(self, "_api_token"):
             key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
+
             if self._inference_service_ == "bedrock":
                 self._api_token = [os.getenv(key_name[0]), os.getenv(key_name[1])]
                 # Check if any of the tokens are None
@@ -217,13 +142,13 @@ class LanguageModel(
                 self._api_token = os.getenv(key_name)
             missing_token = self._api_token is None
             if missing_token and self._inference_service_ != "test" and not self.remote:
-                print("
+                print("rainsing error")
                 raise MissingAPIKeyError(
                     f"""The key for service: `{self._inference_service_}` is not set.
                     Need a key with name {key_name} in your .env file."""
                 )
 
-
+        return self._api_token
 
     def __getitem__(self, key):
         return getattr(self, key)
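Note: the added print("rainsing error") looks like a stray debug print (a typo for "raising") left in just before MissingAPIKeyError is raised. For context, a standalone sketch of the env-var lookup this property performs, assuming service_to_api_keyname maps a service such as "openai" to "OPENAI_API_KEY":

    import os

    key_name = "OPENAI_API_KEY"  # assumed result of service_to_api_keyname.get("openai")
    api_token = os.getenv(key_name)  # None when the .env entry is missing
    if api_token is None:
        # edsl raises MissingAPIKeyError here; a plain Exception stands in for it
        raise Exception(f"Need a key with name {key_name} in your .env file.")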
@@ -284,58 +209,40 @@ class LanguageModel(
         >>> m = LanguageModel.example()
         >>> m.set_rate_limits(rpm=100, tpm=1000)
         >>> m.RPM
-
+        80.0
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # if self.__rate_limits is None:
-        #     if hasattr(self, "get_rate_limits"):
-        #         self.__rate_limits = self.get_rate_limits()
-        #     else:
-        #         self.__rate_limits = self.__default_rate_limits
+        self._set_rate_limits(rpm=rpm, tpm=tpm)
+
+    def _set_rate_limits(self, rpm=None, tpm=None) -> None:
+        """Set the rate limits for the model.
+
+        If the model does not have rate limits, use the default rate limits."""
+        if rpm is not None and tpm is not None:
+            self.__rate_limits = {"rpm": rpm, "tpm": tpm}
+            return
+
+        if self.__rate_limits is None:
+            if hasattr(self, "get_rate_limits"):
+                self.__rate_limits = self.get_rate_limits()
+            else:
+                self.__rate_limits = self.__default_rate_limits
 
     @property
     def RPM(self):
         """Model's requests-per-minute limit."""
-
-
-        return self._rpm
+        self._set_rate_limits()
+        return self._safety_factor * self.__rate_limits["rpm"]
 
     @property
     def TPM(self):
-        """Model's tokens-per-minute limit.
-        # self._set_rate_limits()
-        # return self._safety_factor * self.__rate_limits["tpm"]
-        return self._tpm
-
-    @property
-    def rpm(self):
-        return self._rpm
+        """Model's tokens-per-minute limit.
 
-
-
-
-
-
-
-        return self._tpm
-
-    @tpm.setter
-    def tpm(self, value):
-        self._tpm = value
+        >>> m = LanguageModel.example()
+        >>> m.TPM > 0
+        True
+        """
+        self._set_rate_limits()
+        return self._safety_factor * self.__rate_limits["tpm"]
 
     @staticmethod
     def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
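Note: both properties understate the configured limits by _safety_factor = 0.8, which is what the updated doctest encodes: set_rate_limits(rpm=100, tpm=1000) gives RPM == 0.8 * 100 == 80.0. A worked check against the code above:

    m = LanguageModel.example(test_model=True)
    m.set_rate_limits(rpm=100, tpm=1000)
    assert m.RPM == 80.0   # 0.8 * 100
    assert m.TPM == 800.0  # 0.8 * 1000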
@@ -363,10 +270,11 @@ class LanguageModel(
         >>> m = LanguageModel.example(test_model = True)
         >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
         >>> asyncio.run(test())
-        {'message':
+        {'message': '{"answer": "Hello world"}'}
 
         >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
-        {'message':
+        {'message': '{"answer": "Hello world"}'}
+
         """
         pass
 
@@ -399,40 +307,68 @@ class LanguageModel(
 
         return main()
 
-    @
-    def
-        """
-        return extract_item_from_raw_response(raw_response, cls.key_sequence)
+    @abstractmethod
+    def parse_response(raw_response: dict[str, Any]) -> str:
+        """Parse the response and returns the response text.
 
-
-
-
-
-
-
+        >>> m = LanguageModel.example(test_model = True)
+        >>> m
+        Model(model_name = 'test', temperature = 0.5)
+
+        What is returned by the API is model-specific and often includes meta-data that we do not need.
+        For example, here is the results from a call to GPT-4:
+        To actually track the response, we need to grab
+        data["choices[0]"]["message"]["content"].
+        """
+        raise NotImplementedError
+
+    async def _async_prepare_response(
+        self, model_call_outcome: IntendedModelCallOutcome, cache: "Cache"
+    ) -> dict:
+        """Prepare the response for return."""
+
+        model_response = {
+            "cache_used": model_call_outcome.cache_used,
+            "cache_key": model_call_outcome.cache_key,
+            "usage": model_call_outcome.response.get("usage", {}),
+            "raw_model_response": model_call_outcome.response,
+        }
+
+        answer_portion = self.parse_response(model_call_outcome.response)
+        try:
+            answer_dict = json.loads(answer_portion)
+        except json.JSONDecodeError as e:
+            # TODO: Turn into logs to generate issues
+            answer_dict, success = await repair(
+                bad_json=answer_portion, error_message=str(e), cache=cache
             )
-
+            if not success:
+                raise Exception(
+                    f"""Even the repair failed. The error was: {e}. The response was: {answer_portion}."""
+                )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return {**model_response, **answer_dict}
+
+    async def async_get_raw_response(
+        self,
+        user_prompt: str,
+        system_prompt: str,
+        cache: "Cache",
+        iteration: int = 0,
+        encoded_image=None,
+    ) -> IntendedModelCallOutcome:
+        import warnings
+
+        warnings.warn(
+            "This method is deprecated. Use async_get_intended_model_call_outcome."
+        )
+        return await self._async_get_intended_model_call_outcome(
+            user_prompt=user_prompt,
+            system_prompt=system_prompt,
+            cache=cache,
+            iteration=iteration,
+            encoded_image=encoded_image,
+        )
 
     async def _async_get_intended_model_call_outcome(
         self,
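Note: _async_prepare_response replaces the old key_sequence walk (extract_item_from_raw_response, deleted above): parse_response() extracts the text, json.loads() parses it, and on failure the repair() coroutine gets one chance before the hard exception. A runnable sketch of that control flow, with a stand-in for edsl.language_models.repair.repair:

    import asyncio
    import json

    async def fake_repair(bad_json: str, error_message: str) -> tuple[dict, bool]:
        # stand-in for edsl's repair(); wraps unparseable text instead of fixing it
        return {"answer": bad_json.strip()}, True

    async def prepare(answer_portion: str) -> dict:
        try:
            return json.loads(answer_portion)
        except json.JSONDecodeError as e:
            fixed, success = await fake_repair(answer_portion, str(e))
            if not success:
                raise Exception(f"Even the repair failed. The error was: {e}.")
            return fixed

    print(asyncio.run(prepare('{"answer": "yes"}')))  # parses cleanly
    print(asyncio.run(prepare("not json")))           # takes the repair path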
@@ -441,7 +377,7 @@ class LanguageModel(
         cache: "Cache",
         iteration: int = 0,
         encoded_image=None,
-    ) ->
+    ) -> IntendedModelCallOutcome:
         """Handle caching of responses.
 
         :param user_prompt: The user's prompt.
@@ -460,18 +396,18 @@ class LanguageModel(
         >>> from edsl import Cache
         >>> m = LanguageModel.example(test_model = True)
        >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
-
+        IntendedModelCallOutcome(response = {'message': '{"answer": "Hello world"}'}, cache_used = False, cache_key = '24ff6ac2bc2f1729f817f261e0792577')
+        """
 
         if encoded_image:
             # the image has is appended to the user_prompt for hash-lookup purposes
             image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
-            user_prompt += f" {image_hash}"
 
         cache_call_params = {
             "model": str(self.model),
             "parameters": self.parameters,
             "system_prompt": system_prompt,
-            "user_prompt": user_prompt,
+            "user_prompt": user_prompt + "" if not encoded_image else f" {image_hash}",
             "iteration": iteration,
         }
         cached_response, cache_key = cache.fetch(**cache_call_params)
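Note: the rewritten "user_prompt" cache-key entry parses as a conditional expression, (user_prompt + "") if not encoded_image else f" {image_hash}", because + binds tighter than if/else. When an image is attached, the cache key is therefore built from the image hash alone and the prompt text is dropped, whereas the old (removed) line appended the hash to the prompt. A quick demonstration of the precedence:

    user_prompt = "Hello"
    image_hash = "abc123"
    encoded_image = "base64data"  # truthy, so the else branch is taken

    key_prompt = user_prompt + "" if not encoded_image else f" {image_hash}"
    print(repr(key_prompt))  # ' abc123' -- the prompt text is gone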
@@ -489,28 +425,21 @@ class LanguageModel(
             "system_prompt": system_prompt,
             **({"encoded_image": encoded_image} if encoded_image else {}),
         }
-
-        response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
+        response = await f(**params)
         new_cache_key = cache.store(
             **cache_call_params, response=response
         )  # store the response in the cache
         assert new_cache_key == cache_key  # should be the same
 
-
-
-        return ModelResponse(
-            response=response,
-            cache_used=cache_used,
-            cache_key=cache_key,
-            cached_response=cached_response,
-            cost=cost,
+        return IntendedModelCallOutcome(
+            response=response, cache_used=cache_used, cache_key=cache_key
         )
 
     _get_intended_model_call_outcome = sync_wrapper(
         _async_get_intended_model_call_outcome
     )
 
-
+    get_raw_response = sync_wrapper(async_get_raw_response)
 
     def simple_ask(
         self,
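Note: this hunk also removes the asyncio.wait_for(..., timeout=TIMEOUT) guard (TIMEOUT, read from CONFIG's EDSL_API_TIMEOUT, was deleted in the import hunk), so a hung provider call is now awaited without bound. For reference, the removed guard bounded the call like this:

    import asyncio

    async def slow_call():
        await asyncio.sleep(10)  # stands in for f(**params)

    try:
        # old behaviour: give up after TIMEOUT seconds instead of hanging
        asyncio.run(asyncio.wait_for(slow_call(), timeout=0.1))
    except asyncio.TimeoutError:
        print("provider call timed out")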
@@ -549,66 +478,14 @@ class LanguageModel(
             "cache": cache,
             **({"encoded_image": encoded_image} if encoded_image else {}),
         }
-
-
-        edsl_dict = self.parse_response(model_outputs.response)
-        agent_response_dict = AgentResponseDict(
-            model_inputs=model_inputs,
-            model_outputs=model_outputs,
-            edsl_dict=edsl_dict,
-        )
-        return agent_response_dict
-
-        # return await self._async_prepare_response(model_call_outcome, cache=cache)
+        model_call_outcome = await self._async_get_intended_model_call_outcome(**params)
+        return await self._async_prepare_response(model_call_outcome, cache=cache)
 
     get_response = sync_wrapper(async_get_response)
 
-    def cost(self, raw_response: dict[str, Any]) ->
+    def cost(self, raw_response: dict[str, Any]) -> float:
         """Return the dollar cost of a raw response."""
-
-        usage = self.get_usage_dict(raw_response)
-        from edsl.coop import Coop
-
-        c = Coop()
-        price_lookup = c.fetch_prices()
-        key = (self._inference_service_, self.model)
-        if key not in price_lookup:
-            return f"Could not find price for model {self.model} in the price lookup."
-
-        relevant_prices = price_lookup[key]
-        try:
-            input_tokens = int(usage[self.input_token_name])
-            output_tokens = int(usage[self.output_token_name])
-        except Exception as e:
-            return f"Could not fetch tokens from model response: {e}"
-
-        try:
-            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
-            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
-        except Exception as e:
-            if "output" not in relevant_prices:
-                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
-            if "input" not in relevant_prices:
-                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
-            return f"Could not fetch prices from {relevant_prices} - {e}"
-
-        if inverse_input_price == "infinity":
-            input_cost = 0
-        else:
-            try:
-                input_cost = input_tokens / float(inverse_input_price)
-            except Exception as e:
-                return f"Could not compute input price - {e}."
-
-        if inverse_output_price == "infinity":
-            output_cost = 0
-        else:
-            try:
-                output_cost = output_tokens / float(inverse_output_price)
-            except Exception as e:
-                return f"Could not compute output price - {e}"
-
-        return input_cost + output_cost
+        raise NotImplementedError
 
     #######################
     # SERIALIZATION METHODS
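Note: cost() shrinks from the Coop-backed price lookup (fetch_prices(), "one_usd_buys", etc., all removed above) to a bare NotImplementedError, so the dev build has no pricing path. If a subclass wanted the behaviour back, the general shape would be roughly this (hypothetical override; the usage keys follow the common "prompt_tokens"/"completion_tokens" convention and are an assumption):

    def cost(self, raw_response: dict, usd_per_input_token: float,
             usd_per_output_token: float) -> float:
        usage = raw_response.get("usage", {})
        input_tokens = int(usage.get("prompt_tokens", 0))       # assumed key
        output_tokens = int(usage.get("completion_tokens", 0))  # assumed key
        return (input_tokens * usd_per_input_token
                + output_tokens * usd_per_output_token)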
@@ -622,7 +499,7 @@ class LanguageModel(
 
         >>> m = LanguageModel.example()
         >>> m.to_dict()
-        {'model': '
+        {'model': 'gpt-4-1106-preview', 'parameters': {'temperature': 0.5, 'max_tokens': 1000, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'logprobs': False, 'top_logprobs': 3}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
         """
         return self._to_dict()
 
@@ -698,8 +575,26 @@ class LanguageModel(
         """
         from edsl import Model
 
+        class TestLanguageModelGood(LanguageModel):
+            use_cache = False
+            _model_ = "test"
+            _parameters_ = {"temperature": 0.5}
+            _inference_service_ = InferenceServiceType.TEST.value
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str
+            ) -> dict[str, Any]:
+                await asyncio.sleep(0.1)
+                # return {"message": """{"answer": "Hello, world"}"""}
+                if throw_exception:
+                    raise Exception("This is a test error")
+                return {"message": f'{{"answer": "{canned_response}"}}'}
+
+            def parse_response(self, raw_response: dict[str, Any]) -> str:
+                return raw_response["message"]
+
         if test_model:
-            m =
+            m = TestLanguageModelGood()
             return m
         else:
             return Model(skip_api_key_check=True)
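Note: example(test_model=True) now builds TestLanguageModelGood inline. throw_exception and canned_response are free variables in async_execute_model_call, so they are presumably parameters of the enclosing example() classmethod that this hunk does not show. Expected usage, per the doctests earlier in this diff:

    m = LanguageModel.example(test_model=True)
    m.execute_model_call("Hello, model!", "You are a helpful agent.")
    # -> {'message': '{"answer": "Hello world"}'}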
edsl/language_models/ModelList.py
@@ -40,8 +40,8 @@ class ModelList(Base, UserList):
     def __hash__(self):
         """Return a hash of the ModelList. This is used for comparison of ModelLists.
 
-        >>>
-
+        >>> hash(ModelList.example())
+        1423518243781418961
 
         """
         from edsl.utilities.utilities import dict_hash

edsl/language_models/RegisterLanguageModelsMeta.py
@@ -47,6 +47,13 @@ class RegisterLanguageModelsMeta(ABCMeta):
             must_be_async=True,
         )
         # LanguageModel children have to implement the parse_response method
+        RegisterLanguageModelsMeta.verify_method(
+            candidate_class=cls,
+            method_name="parse_response",
+            expected_return_type=str,
+            required_parameters=[("raw_response", dict[str, Any])],
+            must_be_async=False,
+        )
         RegisterLanguageModelsMeta._registry[model_name] = cls
 
     @classmethod
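Note: the metaclass now also enforces that every registered LanguageModel subclass defines a synchronous parse_response(raw_response: dict[str, Any]) -> str alongside the async async_execute_model_call. A minimal subclass satisfying both checks, sketched against the signatures this diff enforces (the class attributes mirror TestLanguageModelGood above):

    from typing import Any

    class MyModel(LanguageModel):
        _model_ = "my-model"
        _parameters_ = {"temperature": 0.5}
        _inference_service_ = InferenceServiceType.TEST.value

        async def async_execute_model_call(
            self, user_prompt: str, system_prompt: str
        ) -> dict[str, Any]:
            return {"message": '{"answer": "ok"}'}

        def parse_response(self, raw_response: dict[str, Any]) -> str:
            return raw_response["message"]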
@@ -91,7 +98,7 @@ class RegisterLanguageModelsMeta(ABCMeta):
 
         required_parameters = required_parameters or []
         method = getattr(candidate_class, method_name)
-
+        signature = inspect.signature(method)
 
         RegisterLanguageModelsMeta._check_return_type(method, expected_return_type)
 
@@ -99,11 +106,11 @@ class RegisterLanguageModelsMeta(ABCMeta):
             RegisterLanguageModelsMeta._check_is_coroutine(method)
 
         # Check the parameters
-
-
-
-
-
+        params = signature.parameters
+        for param_name, param_type in required_parameters:
+            RegisterLanguageModelsMeta._verify_parameter(
+                params, param_name, param_type, method_name
+            )
 
     @staticmethod
     def _check_method_defined(cls, method_name):
@@ -160,15 +167,23 @@ class RegisterLanguageModelsMeta(ABCMeta):
         Check if the return type of a method is as expected.
 
         Example:
+        >>> class M:
+        ...    async def f(self) -> str: pass
+        >>> RegisterLanguageModelsMeta._check_return_type(M.f, str)
+        >>> class N:
+        ...    async def f(self) -> int: pass
+        >>> RegisterLanguageModelsMeta._check_return_type(N.f, str)
+        Traceback (most recent call last):
+        ...
+        TypeError: Return type of f must be <class 'str'>. Got <class 'int'>.
         """
-
-
-
-
-
-
-
-        # )
+        if inspect.isroutine(method):
+            # return_type = inspect.signature(method).return_annotation
+            return_type = get_type_hints(method)["return"]
+            if return_type != expected_return_type:
+                raise TypeError(
+                    f"Return type of {method.__name__} must be {expected_return_type}. Got {return_type}."
+                )
 
     @classmethod
     def model_names_to_classes(cls):