edsl 0.1.32__py3-none-any.whl → 0.1.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +9 -3
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -3
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +40 -8
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +135 -219
- edsl/agents/InvigilatorBase.py +148 -59
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +138 -89
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +47 -56
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +50 -7
- edsl/data/Cache.py +35 -1
- edsl/data_transfer_models.py +73 -38
- edsl/enums.py +4 -0
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +19 -17
- edsl/inference_services/AzureAI.py +37 -20
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +2 -0
- edsl/inference_services/InferenceServiceABC.py +58 -3
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OpenAIService.py +48 -54
- edsl/inference_services/TestService.py +80 -0
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/models_available_cache.py +0 -6
- edsl/inference_services/registry.py +6 -0
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +37 -22
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +93 -14
- edsl/jobs/interviews/Interview.py +366 -78
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +14 -68
- edsl/jobs/interviews/InterviewExceptionEntry.py +85 -19
- edsl/jobs/runners/JobsRunnerAsyncio.py +146 -175
- edsl/jobs/runners/JobsRunnerStatus.py +331 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskHistory.py +148 -213
- edsl/language_models/LanguageModel.py +261 -156
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +23 -6
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/notebooks/Notebook.py +20 -2
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +330 -249
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +99 -41
- edsl/questions/QuestionCheckBox.py +227 -35
- edsl/questions/QuestionExtract.py +98 -27
- edsl/questions/QuestionFreeText.py +52 -29
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -22
- edsl/questions/QuestionMultipleChoice.py +159 -65
- edsl/questions/QuestionNumerical.py +88 -46
- edsl/questions/QuestionRank.py +182 -24
- edsl/questions/Quick.py +41 -0
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +170 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +15 -2
- edsl/questions/derived/QuestionTopK.py +10 -1
- edsl/questions/derived/QuestionYesNo.py +24 -3
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +46 -48
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +135 -46
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/FileStore.py +71 -10
- edsl/scenarios/Scenario.py +96 -25
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +361 -39
- edsl/scenarios/ScenarioListExportMixin.py +9 -0
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/SnapShot.py +8 -1
- edsl/study/Study.py +32 -0
- edsl/surveys/Rule.py +10 -1
- edsl/surveys/RuleCollection.py +21 -5
- edsl/surveys/Survey.py +637 -311
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +75 -4
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/METADATA +5 -2
- edsl-0.1.33.dist-info/RECORD +295 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
- edsl/jobs/interviews/retry_management.py +0 -37
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.32.dist-info/RECORD +0 -209
- {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
- {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
edsl/language_models/LanguageModel.py

@@ -1,4 +1,16 @@
-"""This module contains the LanguageModel class, which is an abstract base class for all language models.
+"""This module contains the LanguageModel class, which is an abstract base class for all language models.
+
+Terminology:
+
+raw_response: The JSON response from the model. This has all the model meta-data about the call.
+
+edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
+such as the cache key, token usage, etc.
+
+generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
+edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
+
+"""

 from __future__ import annotations
 import warnings
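To make the new terminology concrete, a minimal sketch (OpenAI-style shapes, illustrative only; real payloads differ by service):

    # Illustrative only: an OpenAI-style raw_response; exact shapes vary by service.
    raw_response = {
        "id": "chatcmpl-abc123",  # model meta-data about the call
        "choices": [{"message": {"content": "yes\nBecause it fits."}}],
        "usage": {"prompt_tokens": 12, "completion_tokens": 7},
    }

    # generated_tokens: just the text the model produced
    generated_tokens = raw_response["choices"][0]["message"]["content"]

    # edsl_answer_dict: the parsed form, an answer plus an optional comment
    edsl_answer_dict = {"answer": "yes", "comment": "Because it fits."}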
@@ -8,47 +20,103 @@ import json
 import time
 import os
 import hashlib
-from typing import
+from typing import (
+    Coroutine,
+    Any,
+    Callable,
+    Type,
+    Union,
+    List,
+    get_type_hints,
+    TypedDict,
+    Optional,
+)
 from abc import ABC, abstractmethod

+from json_repair import repair_json

-
-
-
-
-
-
-        self.cache_key = cache_key
-
-    def __iter__(self):
-        """Iterate over the class attributes.
-
-        >>> a, b, c = IntendedModelCallOutcome({'answer': "yes"}, True, 'x1289')
-        >>> a
-        {'answer': 'yes'}
-        """
-        yield self.response
-        yield self.cache_used
-        yield self.cache_key
-
-    def __len__(self):
-        return 3
-
-    def __repr__(self):
-        return f"IntendedModelCallOutcome(response = {self.response}, cache_used = {self.cache_used}, cache_key = '{self.cache_key}')"
+from edsl.data_transfer_models import (
+    ModelResponse,
+    ModelInputs,
+    EDSLOutput,
+    AgentResponseDict,
+)


 from edsl.config import CONFIG
-
 from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
 from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
-
 from edsl.language_models.repair import repair
 from edsl.enums import InferenceServiceType
 from edsl.Base import RichPrintingMixin, PersistenceMixin
 from edsl.enums import service_to_api_keyname
 from edsl.exceptions import MissingAPIKeyError
 from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
+from edsl.exceptions.language_models import LanguageModelBadResponseError
+
+TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+
+def convert_answer(response_part):
+    import json
+
+    response_part = response_part.strip()
+
+    if response_part == "None":
+        return None
+
+    repaired = repair_json(response_part)
+    if repaired == '""':
+        # it was a literal string
+        return response_part
+
+    try:
+        return json.loads(repaired)
+    except json.JSONDecodeError as j:
+        # last resort
+        return response_part
+
+
+def extract_item_from_raw_response(data, key_sequence):
+    if isinstance(data, str):
+        try:
+            data = json.loads(data)
+        except json.JSONDecodeError as e:
+            return data
+    current_data = data
+    for i, key in enumerate(key_sequence):
+        try:
+            if isinstance(current_data, (list, tuple)):
+                if not isinstance(key, int):
+                    raise TypeError(
+                        f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
+                    )
+                if key < 0 or key >= len(current_data):
+                    raise IndexError(
+                        f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
+                    )
+            elif isinstance(current_data, dict):
+                if key not in current_data:
+                    raise KeyError(
+                        f"Key '{key}' not found in dictionary at position {i}"
+                    )
+            else:
+                raise TypeError(
+                    f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
+                )
+
+            current_data = current_data[key]
+        except Exception as e:
+            path = " -> ".join(map(str, key_sequence[: i + 1]))
+            if "error" in data:
+                msg = data["error"]
+            else:
+                msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
+            raise LanguageModelBadResponseError(message=msg, response_json=data)
+    if isinstance(current_data, str):
+        return current_data.strip()
+    else:
+        return current_data


 def handle_key_error(func):
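These two helpers carry the parsing rewrite: `extract_item_from_raw_response` walks a key path into the provider payload, and `convert_answer` coerces the extracted text. A usage sketch (the response dict is hypothetical; the key path mirrors the `key_sequence` comment in the class-attributes hunk below):

    raw = {"choices": [{"message": {"content": '  {"answer": "yes"}  '}}]}
    extract_item_from_raw_response(raw, ["choices", 0, "message", "content"])
    # -> '{"answer": "yes"}' (strings come back stripped)

    convert_answer("None")               # -> None (special-cased)
    convert_answer('{"answer": "yes"}')  # -> {'answer': 'yes'} (parsed as JSON after repair)

A key path that cannot be walked raises LanguageModelBadResponseError carrying the failing path and the full payload.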
@@ -92,21 +160,29 @@ class LanguageModel(
     """

     _model_ = None
-
+    key_sequence = (
+        None  # This should be something like ["choices", 0, "message", "content"]
+    )
     __rate_limits = None
-    __default_rate_limits = {
-        "rpm": 10_000,
-        "tpm": 2_000_000,
-    }  # TODO: Use the OpenAI Teir 1 rate limits
     _safety_factor = 0.8

-    def __init__(
+    def __init__(
+        self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
+    ):
         """Initialize the LanguageModel."""
         self.model = getattr(self, "_model_", None)
         default_parameters = getattr(self, "_parameters_", None)
         parameters = self._overide_default_parameters(kwargs, default_parameters)
         self.parameters = parameters
         self.remote = False
+        self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
+
+        # self._rpm / _tpm comes from the class
+        if rpm is not None:
+            self._rpm = rpm
+
+        if tpm is not None:
+            self._tpm = tpm

         for key, value in parameters.items():
             setattr(self, key, value)
@@ -133,7 +209,6 @@ class LanguageModel(
     def api_token(self) -> str:
         if not hasattr(self, "_api_token"):
             key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
-
             if self._inference_service_ == "bedrock":
                 self._api_token = [os.getenv(key_name[0]), os.getenv(key_name[1])]
                 # Check if any of the tokens are None
@@ -142,13 +217,13 @@ class LanguageModel(
                self._api_token = os.getenv(key_name)
            missing_token = self._api_token is None
            if missing_token and self._inference_service_ != "test" and not self.remote:
-                print("
+                print("raising error")
                raise MissingAPIKeyError(
                    f"""The key for service: `{self._inference_service_}` is not set.
                    Need a key with name {key_name} in your .env file."""
                )

-
+        return self._api_token

    def __getitem__(self, key):
        return getattr(self, key)
@@ -209,40 +284,58 @@ class LanguageModel(
         >>> m = LanguageModel.example()
         >>> m.set_rate_limits(rpm=100, tpm=1000)
         >>> m.RPM
-
+        100
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if rpm is not None:
+            self._rpm = rpm
+        if tpm is not None:
+            self._tpm = tpm
+        return None
+        # self._set_rate_limits(rpm=rpm, tpm=tpm)
+
+    # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
+    #     """Set the rate limits for the model.
+
+    #     If the model does not have rate limits, use the default rate limits."""
+    #     if rpm is not None and tpm is not None:
+    #         self.__rate_limits = {"rpm": rpm, "tpm": tpm}
+    #         return
+
+    #     if self.__rate_limits is None:
+    #         if hasattr(self, "get_rate_limits"):
+    #             self.__rate_limits = self.get_rate_limits()
+    #         else:
+    #             self.__rate_limits = self.__default_rate_limits

     @property
     def RPM(self):
         """Model's requests-per-minute limit."""
-        self._set_rate_limits()
-        return self._safety_factor * self.__rate_limits["rpm"]
+        # self._set_rate_limits()
+        # return self._safety_factor * self.__rate_limits["rpm"]
+        return self._rpm

     @property
     def TPM(self):
-        """Model's tokens-per-minute limit.
+        """Model's tokens-per-minute limit."""
+        # self._set_rate_limits()
+        # return self._safety_factor * self.__rate_limits["tpm"]
+        return self._tpm

-
-
-
-
-
-
+    @property
+    def rpm(self):
+        return self._rpm
+
+    @rpm.setter
+    def rpm(self, value):
+        self._rpm = value
+
+    @property
+    def tpm(self):
+        return self._tpm
+
+    @tpm.setter
+    def tpm(self, value):
+        self._tpm = value

     @staticmethod
     def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
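Rate limits become plain instance state. Per the doctest above, a sketch of the new surface (the old RPM property scaled a class-level default by _safety_factor; the new one returns _rpm as set):

    m = LanguageModel.example()
    m.set_rate_limits(rpm=100, tpm=1000)
    m.RPM        # -> 100, exactly as set
    m.rpm = 50   # the new lowercase setter writes _rpm directly
    m.rpm        # -> 50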
@@ -270,11 +363,10 @@ class LanguageModel(
         >>> m = LanguageModel.example(test_model = True)
         >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
         >>> asyncio.run(test())
-        {'message': '
+        {'message': [{'text': 'Hello world'}], ...}

         >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
-        {'message': '
-
+        {'message': [{'text': 'Hello world'}], ...}
         """
         pass

@@ -307,68 +399,40 @@ class LanguageModel(

         return main()

-    @
-    def
-        """
-
-        >>> m = LanguageModel.example(test_model = True)
-        >>> m
-        Model(model_name = 'test', temperature = 0.5)
-
-        What is returned by the API is model-specific and often includes meta-data that we do not need.
-        For example, here is the results from a call to GPT-4:
-        To actually track the response, we need to grab
-        data["choices[0]"]["message"]["content"].
-        """
-        raise NotImplementedError
-
-    async def _async_prepare_response(
-        self, model_call_outcome: IntendedModelCallOutcome, cache: "Cache"
-    ) -> dict:
-        """Prepare the response for return."""
-
-        model_response = {
-            "cache_used": model_call_outcome.cache_used,
-            "cache_key": model_call_outcome.cache_key,
-            "usage": model_call_outcome.response.get("usage", {}),
-            "raw_model_response": model_call_outcome.response,
-        }
+    @classmethod
+    def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
+        """Return the generated token string from the raw response."""
+        return extract_item_from_raw_response(raw_response, cls.key_sequence)

-
-
-
-
-
-
-                    bad_json=answer_portion, error_message=str(e), cache=cache
+    @classmethod
+    def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
+        """Return the usage dictionary from the raw response."""
+        if not hasattr(cls, "usage_sequence"):
+            raise NotImplementedError(
+                "This inference service does not have a usage_sequence."
            )
-
-        raise Exception(
-            f"""Even the repair failed. The error was: {e}. The response was: {answer_portion}."""
-        )
-
-        return {**model_response, **answer_dict}
-
-    async def async_get_raw_response(
-        self,
-        user_prompt: str,
-        system_prompt: str,
-        cache: "Cache",
-        iteration: int = 0,
-        encoded_image=None,
-    ) -> IntendedModelCallOutcome:
-        import warnings
+        return extract_item_from_raw_response(raw_response, cls.usage_sequence)

-
-
-
-
-
-
-
-
-
+    @classmethod
+    def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
+        """Parses the API response and returns the response text."""
+        generated_token_string = cls.get_generated_token_string(raw_response)
+        last_newline = generated_token_string.rfind("\n")
+
+        if last_newline == -1:
+            # There is no comment
+            edsl_dict = {
+                "answer": convert_answer(generated_token_string),
+                "generated_tokens": generated_token_string,
+                "comment": None,
+            }
+        else:
+            edsl_dict = {
+                "answer": convert_answer(generated_token_string[:last_newline]),
+                "comment": generated_token_string[last_newline + 1 :].strip(),
+                "generated_tokens": generated_token_string,
+            }
+        return EDSLOutput(**edsl_dict)

     async def _async_get_intended_model_call_outcome(
         self,
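`parse_response` is now concrete on the base class: everything after the last newline of the generated text is treated as the comment. A sketch of the mechanics, using the helpers from the import hunk (values illustrative):

    text = "42\nBecause the question asked for a number."
    last = text.rfind("\n")
    convert_answer(text[:last])   # -> 42
    text[last + 1 :].strip()      # -> 'Because the question asked for a number.'
    # With no newline at all, the whole string becomes the answer and comment is None.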
@@ -377,7 +441,7 @@ class LanguageModel(
         cache: "Cache",
         iteration: int = 0,
         encoded_image=None,
-    ) ->
+    ) -> ModelResponse:
         """Handle caching of responses.

         :param user_prompt: The user's prompt.
@@ -396,18 +460,18 @@ class LanguageModel(
         >>> from edsl import Cache
         >>> m = LanguageModel.example(test_model = True)
         >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
-
-        """
+        ModelResponse(...)"""

         if encoded_image:
             # the image has is appended to the user_prompt for hash-lookup purposes
             image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
+            user_prompt += f" {image_hash}"

         cache_call_params = {
             "model": str(self.model),
             "parameters": self.parameters,
             "system_prompt": system_prompt,
-            "user_prompt": user_prompt
+            "user_prompt": user_prompt,
             "iteration": iteration,
         }
         cached_response, cache_key = cache.fetch(**cache_call_params)
@@ -425,21 +489,28 @@ class LanguageModel(
             "system_prompt": system_prompt,
             **({"encoded_image": encoded_image} if encoded_image else {}),
         }
-        response = await f(**params)
+        # response = await f(**params)
+        response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
         new_cache_key = cache.store(
             **cache_call_params, response=response
         )  # store the response in the cache
         assert new_cache_key == cache_key  # should be the same

-
-
+        cost = self.cost(response)
+
+        return ModelResponse(
+            response=response,
+            cache_used=cache_used,
+            cache_key=cache_key,
+            cached_response=cached_response,
+            cost=cost,
         )

     _get_intended_model_call_outcome = sync_wrapper(
         _async_get_intended_model_call_outcome
     )

-    get_raw_response = sync_wrapper(async_get_raw_response)
+    # get_raw_response = sync_wrapper(async_get_raw_response)

     def simple_ask(
         self,
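Model calls are now bounded by the `EDSL_API_TIMEOUT` config value via `asyncio.wait_for` (see the `TIMEOUT` constant in the import hunk), so a hung provider fails fast instead of stalling an interview. A self-contained sketch of that failure mode (timeout value illustrative):

    import asyncio

    async def hung_provider():
        await asyncio.sleep(10)  # stands in for a provider call that never returns

    async def demo():
        try:
            await asyncio.wait_for(hung_provider(), timeout=0.1)
        except asyncio.TimeoutError:
            print("model call timed out")  # what the new guard surfaces

    asyncio.run(demo())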
@@ -478,14 +549,66 @@ class LanguageModel(
             "cache": cache,
             **({"encoded_image": encoded_image} if encoded_image else {}),
         }
-
-
+        model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
+        model_outputs = await self._async_get_intended_model_call_outcome(**params)
+        edsl_dict = self.parse_response(model_outputs.response)
+        agent_response_dict = AgentResponseDict(
+            model_inputs=model_inputs,
+            model_outputs=model_outputs,
+            edsl_dict=edsl_dict,
+        )
+        return agent_response_dict
+
+        # return await self._async_prepare_response(model_call_outcome, cache=cache)

     get_response = sync_wrapper(async_get_response)

-    def cost(self, raw_response: dict[str, Any]) -> float:
+    def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
         """Return the dollar cost of a raw response."""
-
+
+        usage = self.get_usage_dict(raw_response)
+        from edsl.coop import Coop
+
+        c = Coop()
+        price_lookup = c.fetch_prices()
+        key = (self._inference_service_, self.model)
+        if key not in price_lookup:
+            return f"Could not find price for model {self.model} in the price lookup."
+
+        relevant_prices = price_lookup[key]
+        try:
+            input_tokens = int(usage[self.input_token_name])
+            output_tokens = int(usage[self.output_token_name])
+        except Exception as e:
+            return f"Could not fetch tokens from model response: {e}"
+
+        try:
+            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+        except Exception as e:
+            if "output" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
+            if "input" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
+            return f"Could not fetch prices from {relevant_prices} - {e}"
+
+        if inverse_input_price == "infinity":
+            input_cost = 0
+        else:
+            try:
+                input_cost = input_tokens / float(inverse_input_price)
+            except Exception as e:
+                return f"Could not compute input price - {e}."
+
+        if inverse_output_price == "infinity":
+            output_cost = 0
+        else:
+            try:
+                output_cost = output_tokens / float(inverse_output_price)
+            except Exception as e:
+                return f"Could not compute output price - {e}"
+
+        return input_cost + output_cost

     #######################
     # SERIALIZATION METHODS
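Note that `cost` returns a float on success and an explanatory string on failure, and that prices arrive inverted: `one_usd_buys` is tokens per dollar, so cost is tokens divided by that figure. A worked sketch with made-up prices:

    # Hypothetical prices: $1 buys 1,000,000 input tokens and 500,000 output tokens.
    one_usd_buys_input, one_usd_buys_output = 1_000_000, 500_000
    input_tokens, output_tokens = 2_000, 1_000

    input_cost = input_tokens / one_usd_buys_input      # 0.002
    output_cost = output_tokens / one_usd_buys_output   # 0.002
    total_cost = input_cost + output_cost               # 0.004 (dollars)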
@@ -499,7 +622,7 @@ class LanguageModel(

         >>> m = LanguageModel.example()
         >>> m.to_dict()
-        {'model': '
+        {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
         """
         return self._to_dict()

@@ -575,26 +698,8 @@ class LanguageModel(
         """
         from edsl import Model

-        class TestLanguageModelGood(LanguageModel):
-            use_cache = False
-            _model_ = "test"
-            _parameters_ = {"temperature": 0.5}
-            _inference_service_ = InferenceServiceType.TEST.value
-
-            async def async_execute_model_call(
-                self, user_prompt: str, system_prompt: str
-            ) -> dict[str, Any]:
-                await asyncio.sleep(0.1)
-                # return {"message": """{"answer": "Hello, world"}"""}
-                if throw_exception:
-                    raise Exception("This is a test error")
-                return {"message": f'{{"answer": "{canned_response}"}}'}
-
-            def parse_response(self, raw_response: dict[str, Any]) -> str:
-                return raw_response["message"]
-
         if test_model:
-            m =
+            m = Model("test", canned_response=canned_response)
             return m
         else:
             return Model(skip_api_key_check=True)
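`example()` now builds its test model through the registry (note the new `edsl/inference_services/TestService.py` in the file list) rather than an inline subclass:

    from edsl import Model

    # Replaces the deleted inline TestLanguageModelGood class; the "test"
    # inference service is exempt from the API-key check (see the api_token hunk above).
    m = Model("test", canned_response="Hello world")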
edsl/language_models/ModelList.py

@@ -40,8 +40,8 @@ class ModelList(Base, UserList):
     def __hash__(self):
         """Return a hash of the ModelList. This is used for comparison of ModelLists.

-        >>> hash(
-
+        >>> isinstance(hash(Model()), int)
+        True

         """
         from edsl.utilities.utilities import dict_hash
edsl/language_models/RegisterLanguageModelsMeta.py

@@ -47,13 +47,6 @@ class RegisterLanguageModelsMeta(ABCMeta):
             must_be_async=True,
         )
         # LanguageModel children have to implement the parse_response method
-        RegisterLanguageModelsMeta.verify_method(
-            candidate_class=cls,
-            method_name="parse_response",
-            expected_return_type=str,
-            required_parameters=[("raw_response", dict[str, Any])],
-            must_be_async=False,
-        )
         RegisterLanguageModelsMeta._registry[model_name] = cls

     @classmethod
@@ -98,7 +91,7 @@ class RegisterLanguageModelsMeta(ABCMeta):

         required_parameters = required_parameters or []
         method = getattr(candidate_class, method_name)
-        signature = inspect.signature(method)
+        # signature = inspect.signature(method)

         RegisterLanguageModelsMeta._check_return_type(method, expected_return_type)

@@ -106,11 +99,11 @@ class RegisterLanguageModelsMeta(ABCMeta):
             RegisterLanguageModelsMeta._check_is_coroutine(method)

         # Check the parameters
-        params = signature.parameters
-        for param_name, param_type in required_parameters:
-
-
-
+        # params = signature.parameters
+        # for param_name, param_type in required_parameters:
+        #     RegisterLanguageModelsMeta._verify_parameter(
+        #         params, param_name, param_type, method_name
+        #     )

     @staticmethod
     def _check_method_defined(cls, method_name):
@@ -167,23 +160,15 @@ class RegisterLanguageModelsMeta(ABCMeta):
         Check if the return type of a method is as expected.

         Example:
-        >>> class M:
-        ...     async def f(self) -> str: pass
-        >>> RegisterLanguageModelsMeta._check_return_type(M.f, str)
-        >>> class N:
-        ...     async def f(self) -> int: pass
-        >>> RegisterLanguageModelsMeta._check_return_type(N.f, str)
-        Traceback (most recent call last):
-        ...
-        TypeError: Return type of f must be <class 'str'>. Got <class 'int'>.
         """
-
-
-
-
-
-
-
+        pass
+        # if inspect.isroutine(method):
+        #     # return_type = inspect.signature(method).return_annotation
+        #     return_type = get_type_hints(method)["return"]
+        #     if return_type != expected_return_type:
+        #         raise TypeError(
+        #             f"Return type of {method.__name__} must be {expected_return_type}. Got {return_type}."
+        #         )

     @classmethod
     def model_names_to_classes(cls):
edsl/language_models/fake_openai_call.py (new file)

@@ -0,0 +1,15 @@
+from openai import AsyncOpenAI
+import asyncio
+
+client = AsyncOpenAI(base_url="http://127.0.0.1:8000/v1", api_key="fake_key")
+
+
+async def main():
+    response = await client.chat.completions.create(
+        model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Question XX42"}]
+    )
+    print(response)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())