edsl 0.1.31.dev4__py3-none-any.whl → 0.1.33__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- edsl/Base.py +9 -3
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -3
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +40 -8
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +136 -221
- edsl/agents/InvigilatorBase.py +148 -59
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +154 -85
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +48 -47
- edsl/conjure/Conjure.py +6 -0
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +50 -7
- edsl/data/Cache.py +35 -1
- edsl/data/CacheHandler.py +3 -4
- edsl/data_transfer_models.py +73 -38
- edsl/enums.py +8 -0
- edsl/exceptions/general.py +10 -8
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +112 -0
- edsl/inference_services/AzureAI.py +214 -0
- edsl/inference_services/DeepInfraService.py +4 -3
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +5 -4
- edsl/inference_services/InferenceServiceABC.py +58 -3
- edsl/inference_services/InferenceServicesCollection.py +13 -8
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +55 -56
- edsl/inference_services/TestService.py +80 -0
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/models_available_cache.py +25 -0
- edsl/inference_services/registry.py +19 -1
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +137 -41
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +105 -18
- edsl/jobs/interviews/Interview.py +393 -83
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +22 -18
- edsl/jobs/interviews/InterviewExceptionEntry.py +167 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +152 -160
- edsl/jobs/runners/JobsRunnerStatus.py +331 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskCreators.py +1 -1
- edsl/jobs/tasks/TaskHistory.py +205 -126
- edsl/language_models/LanguageModel.py +297 -177
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +25 -8
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/notebooks/Notebook.py +20 -2
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +330 -249
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +99 -42
- edsl/questions/QuestionCheckBox.py +227 -36
- edsl/questions/QuestionExtract.py +98 -28
- edsl/questions/QuestionFreeText.py +47 -31
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -23
- edsl/questions/QuestionMultipleChoice.py +159 -66
- edsl/questions/QuestionNumerical.py +88 -47
- edsl/questions/QuestionRank.py +182 -25
- edsl/questions/Quick.py +41 -0
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +170 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +15 -2
- edsl/questions/derived/QuestionTopK.py +10 -1
- edsl/questions/derived/QuestionYesNo.py +24 -3
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +58 -30
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +135 -46
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/FileStore.py +71 -10
- edsl/scenarios/Scenario.py +109 -24
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +546 -21
- edsl/scenarios/ScenarioListExportMixin.py +24 -4
- edsl/scenarios/ScenarioListPdfMixin.py +153 -4
- edsl/study/SnapShot.py +8 -1
- edsl/study/Study.py +32 -0
- edsl/surveys/Rule.py +15 -3
- edsl/surveys/RuleCollection.py +21 -5
- edsl/surveys/Survey.py +707 -298
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +284 -0
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/utilities/utilities.py +40 -1
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/METADATA +8 -2
- edsl-0.1.33.dist-info/RECORD +295 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -271
- edsl/jobs/interviews/retry_management.py +0 -37
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -303
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.31.dev4.dist-info/RECORD +0 -204
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
edsl/language_models/LanguageModel.py

@@ -1,4 +1,16 @@
-"""This module contains the LanguageModel class, which is an abstract base class for all language models.
+"""This module contains the LanguageModel class, which is an abstract base class for all language models.
+
+Terminology:
+
+raw_response: The JSON response from the model. This has all the model meta-data about the call.
+
+edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
+such as the cache key, token usage, etc.
+
+generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
+edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
+
+"""
 
 from __future__ import annotations
 import warnings
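The new module docstring names four layers of response data. A rough sketch of how they relate, using a hypothetical OpenAI-style payload (the dict shapes below are illustrative, not from any specific provider):

    # Hypothetical raw payload from a provider API
    raw_response = {
        "choices": [{"message": {"content": '{"answer": "yes"}\nBecause...'}}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 7},
    }
    # generated_tokens: the text pulled out of raw_response via a key sequence
    generated_tokens = '{"answer": "yes"}\nBecause...'
    # edsl_answer_dict: the parsed form EDSL hands back to the user
    edsl_answer_dict = {"answer": "yes", "comment": "Because..."}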
@@ -8,45 +20,103 @@ import json
 import time
 import os
 import hashlib
-from typing import
+from typing import (
+    Coroutine,
+    Any,
+    Callable,
+    Type,
+    Union,
+    List,
+    get_type_hints,
+    TypedDict,
+    Optional,
+)
 from abc import ABC, abstractmethod
 
-
-    "This is a tuple-like class that holds the response, cache_used, and cache_key."
-
-    def __init__(self, response: dict, cache_used: bool, cache_key: str):
-        self.response = response
-        self.cache_used = cache_used
-        self.cache_key = cache_key
+from json_repair import repair_json
 
-
-
-
-
-
-
-        """
-        yield self.response
-        yield self.cache_used
-        yield self.cache_key
+from edsl.data_transfer_models import (
+    ModelResponse,
+    ModelInputs,
+    EDSLOutput,
+    AgentResponseDict,
+)
 
-    def __len__(self):
-        return 3
-
-    def __repr__(self):
-        return f"IntendedModelCallOutcome(response = {self.response}, cache_used = {self.cache_used}, cache_key = '{self.cache_key}')"
 
 from edsl.config import CONFIG
-
 from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
 from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
-
 from edsl.language_models.repair import repair
 from edsl.enums import InferenceServiceType
 from edsl.Base import RichPrintingMixin, PersistenceMixin
 from edsl.enums import service_to_api_keyname
 from edsl.exceptions import MissingAPIKeyError
 from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
+from edsl.exceptions.language_models import LanguageModelBadResponseError
+
+TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+
+def convert_answer(response_part):
+    import json
+
+    response_part = response_part.strip()
+
+    if response_part == "None":
+        return None
+
+    repaired = repair_json(response_part)
+    if repaired == '""':
+        # it was a literal string
+        return response_part
+
+    try:
+        return json.loads(repaired)
+    except json.JSONDecodeError as j:
+        # last resort
+        return response_part
+
+
+def extract_item_from_raw_response(data, key_sequence):
+    if isinstance(data, str):
+        try:
+            data = json.loads(data)
+        except json.JSONDecodeError as e:
+            return data
+    current_data = data
+    for i, key in enumerate(key_sequence):
+        try:
+            if isinstance(current_data, (list, tuple)):
+                if not isinstance(key, int):
+                    raise TypeError(
+                        f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
+                    )
+                if key < 0 or key >= len(current_data):
+                    raise IndexError(
+                        f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
+                    )
+            elif isinstance(current_data, dict):
+                if key not in current_data:
+                    raise KeyError(
+                        f"Key '{key}' not found in dictionary at position {i}"
+                    )
+            else:
+                raise TypeError(
+                    f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
+                )
+
+            current_data = current_data[key]
+        except Exception as e:
+            path = " -> ".join(map(str, key_sequence[: i + 1]))
+            if "error" in data:
+                msg = data["error"]
+            else:
+                msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
+            raise LanguageModelBadResponseError(message=msg, response_json=data)
+    if isinstance(current_data, str):
+        return current_data.strip()
+    else:
+        return current_data
 
 
 def handle_key_error(func):
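Together, convert_answer and extract_item_from_raw_response let each inference service declare where its answer text lives and get back a parsed value. A minimal sketch of how the two helpers compose, assuming the definitions above (the response dict is hypothetical):

    raw = {"choices": [{"message": {"content": ' {"answer": "42"} '}}]}
    key_sequence = ["choices", 0, "message", "content"]

    # Walk the key sequence: dict key, list index, dict key, dict key.
    text = extract_item_from_raw_response(raw, key_sequence)  # '{"answer": "42"}'

    # repair_json + json.loads turn the (possibly malformed) JSON text into a value.
    answer = convert_answer(text)  # {'answer': '42'}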
@@ -90,21 +160,29 @@ class LanguageModel(
     """
 
     _model_ = None
-
+    key_sequence = (
+        None  # This should be something like ["choices", 0, "message", "content"]
+    )
     __rate_limits = None
-    __default_rate_limits = {
-        "rpm": 10_000,
-        "tpm": 2_000_000,
-    }  # TODO: Use the OpenAI Teir 1 rate limits
     _safety_factor = 0.8
 
-    def __init__(
+    def __init__(
+        self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
+    ):
         """Initialize the LanguageModel."""
         self.model = getattr(self, "_model_", None)
         default_parameters = getattr(self, "_parameters_", None)
         parameters = self._overide_default_parameters(kwargs, default_parameters)
         self.parameters = parameters
         self.remote = False
+        self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
+
+        # self._rpm / _tpm comes from the class
+        if rpm is not None:
+            self._rpm = rpm
+
+        if tpm is not None:
+            self._tpm = tpm
 
         for key, value in parameters.items():
             setattr(self, key, value)
@@ -131,17 +209,20 @@ class LanguageModel(
     def api_token(self) -> str:
         if not hasattr(self, "_api_token"):
             key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
-            self.
-
-
-
-
-
+            if self._inference_service_ == "bedrock":
+                self._api_token = [os.getenv(key_name[0]), os.getenv(key_name[1])]
+                # Check if any of the tokens are None
+                missing_token = any(token is None for token in self._api_token)
+            else:
+                self._api_token = os.getenv(key_name)
+                missing_token = self._api_token is None
+            if missing_token and self._inference_service_ != "test" and not self.remote:
+                print("raising error")
                 raise MissingAPIKeyError(
                     f"""The key for service: `{self._inference_service_}` is not set.
-
-                    """
+                    Need a key with name {key_name} in your .env file."""
                 )
+
         return self._api_token
 
     def __getitem__(self, key):
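The api_token property now special-cases Bedrock, which authenticates with a pair of environment variables rather than one. A sketch of the lookup with an assumed key-name mapping (the real names come from service_to_api_keyname):

    import os

    # Assumed mapping for illustration; most services map to a single name.
    key_name = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")

    api_token = [os.getenv(key_name[0]), os.getenv(key_name[1])]
    missing_token = any(token is None for token in api_token)  # either key absent -> error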
@@ -159,8 +240,7 @@ class LanguageModel(
         if verbose:
             print(f"Current key is {masked}")
         return self.execute_model_call(
-            user_prompt="Hello, model!",
-            system_prompt="You are a helpful agent."
+            user_prompt="Hello, model!", system_prompt="You are a helpful agent."
         )
 
     def has_valid_api_key(self) -> bool:
@@ -204,42 +284,58 @@ class LanguageModel(
         >>> m = LanguageModel.example()
         >>> m.set_rate_limits(rpm=100, tpm=1000)
         >>> m.RPM
-
+        100
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if rpm is not None:
+            self._rpm = rpm
+        if tpm is not None:
+            self._tpm = tpm
+        return None
+        # self._set_rate_limits(rpm=rpm, tpm=tpm)
+
+    # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
+    #     """Set the rate limits for the model.
+
+    #     If the model does not have rate limits, use the default rate limits."""
+    #     if rpm is not None and tpm is not None:
+    #         self.__rate_limits = {"rpm": rpm, "tpm": tpm}
+    #         return
+
+    #     if self.__rate_limits is None:
+    #         if hasattr(self, "get_rate_limits"):
+    #             self.__rate_limits = self.get_rate_limits()
+    #         else:
+    #             self.__rate_limits = self.__default_rate_limits
 
     @property
     def RPM(self):
         """Model's requests-per-minute limit."""
-        self._set_rate_limits()
-        return self._safety_factor * self.__rate_limits["rpm"]
+        # self._set_rate_limits()
+        # return self._safety_factor * self.__rate_limits["rpm"]
+        return self._rpm
 
     @property
     def TPM(self):
-        """Model's tokens-per-minute limit.
+        """Model's tokens-per-minute limit."""
+        # self._set_rate_limits()
+        # return self._safety_factor * self.__rate_limits["tpm"]
+        return self._tpm
 
-
-
-
-
-
-
+    @property
+    def rpm(self):
+        return self._rpm
+
+    @rpm.setter
+    def rpm(self, value):
+        self._rpm = value
+
+    @property
+    def tpm(self):
+        return self._tpm
+
+    @tpm.setter
+    def tpm(self, value):
+        self._tpm = value
 
     @staticmethod
     def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
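Rate limiting is now plain per-instance state: the constructor accepts rpm/tpm, set_rate_limits overwrites them, and the RPM/TPM properties simply echo the stored values (the old safety-factor scaling is commented out). A small usage sketch, mirroring the doctest in the hunk above:

    m = LanguageModel.example()
    m.set_rate_limits(rpm=100, tpm=1_000)
    assert m.RPM == 100 and m.TPM == 1_000  # uppercase properties echo _rpm/_tpm
    m.rpm = 30                              # lowercase setters write the same fields
    assert m.RPM == 30

Per the new __init__ signature, the same fields can also be seeded at construction time with rpm= and tpm= keyword arguments.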
@@ -250,14 +346,16 @@ class LanguageModel(
         >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
         {'temperature': 0.5, 'max_tokens': 1000}
         """
-        #parameters = dict({})
+        # parameters = dict({})
+
+        return {
+            parameter_name: passed_parameter_dict.get(parameter_name, default_value)
+            for parameter_name, default_value in default_parameter_dict.items()
+        }
 
-
-            for parameter_name, default_value in default_parameter_dict.items()}
-
-    def __call__(self, user_prompt:str, system_prompt:str):
+    def __call__(self, user_prompt: str, system_prompt: str):
         return self.execute_model_call(user_prompt, system_prompt)
-
+
     @abstractmethod
     async def async_execute_model_call(user_prompt: str, system_prompt: str):
         """Execute the model call and returns a coroutine.
@@ -265,11 +363,10 @@ class LanguageModel(
         >>> m = LanguageModel.example(test_model = True)
         >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
        >>> asyncio.run(test())
-        {'message': '
+        {'message': [{'text': 'Hello world'}], ...}
 
         >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
-        {'message': '
-
+        {'message': [{'text': 'Hello world'}], ...}
         """
         pass
 
@@ -302,66 +399,40 @@ class LanguageModel(
 
         return main()
 
-    @
-    def
-        """
-
-        >>> m = LanguageModel.example(test_model = True)
-        >>> m
-        Model(model_name = 'test', temperature = 0.5)
-
-        What is returned by the API is model-specific and often includes meta-data that we do not need.
-        For example, here is the results from a call to GPT-4:
-        To actually track the response, we need to grab
-        data["choices[0]"]["message"]["content"].
-        """
-        raise NotImplementedError
-
-    async def _async_prepare_response(self, model_call_outcome: IntendedModelCallOutcome, cache: "Cache") -> dict:
-        """Prepare the response for return."""
-
-        model_response = {
-            "cache_used": model_call_outcome.cache_used,
-            "cache_key": model_call_outcome.cache_key,
-            "usage": model_call_outcome.response.get("usage", {}),
-            "raw_model_response": model_call_outcome.response,
-        }
+    @classmethod
+    def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
+        """Return the generated token string from the raw response."""
+        return extract_item_from_raw_response(raw_response, cls.key_sequence)
 
-
-
-
-
-
-
-                bad_json=answer_portion,
-                error_message=str(e),
-                cache=cache
+    @classmethod
+    def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
+        """Return the usage dictionary from the raw response."""
+        if not hasattr(cls, "usage_sequence"):
+            raise NotImplementedError(
+                "This inference service does not have a usage_sequence."
             )
-
-            raise Exception(
-                f"""Even the repair failed. The error was: {e}. The response was: {answer_portion}."""
-            )
-
-        return {**model_response, **answer_dict}
-
-    async def async_get_raw_response(
-        self,
-        user_prompt: str,
-        system_prompt: str,
-        cache: "Cache",
-        iteration: int = 0,
-        encoded_image=None,
-    ) -> IntendedModelCallOutcome:
-        import warnings
-        warnings.warn("This method is deprecated. Use async_get_intended_model_call_outcome.")
-        return await self._async_get_intended_model_call_outcome(
-            user_prompt=user_prompt,
-            system_prompt=system_prompt,
-            cache=cache,
-            iteration=iteration,
-            encoded_image=encoded_image
-        )
+        return extract_item_from_raw_response(raw_response, cls.usage_sequence)
 
+    @classmethod
+    def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
+        """Parses the API response and returns the response text."""
+        generated_token_string = cls.get_generated_token_string(raw_response)
+        last_newline = generated_token_string.rfind("\n")
+
+        if last_newline == -1:
+            # There is no comment
+            edsl_dict = {
+                "answer": convert_answer(generated_token_string),
+                "generated_tokens": generated_token_string,
+                "comment": None,
+            }
+        else:
+            edsl_dict = {
+                "answer": convert_answer(generated_token_string[:last_newline]),
+                "comment": generated_token_string[last_newline + 1 :].strip(),
+                "generated_tokens": generated_token_string,
+            }
+        return EDSLOutput(**edsl_dict)
 
     async def _async_get_intended_model_call_outcome(
         self,
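parse_response encodes a simple convention: everything after the last newline in the generated text is an optional free-text comment, and everything before it is the answer. A worked sketch on a made-up string:

    generated = '{"answer": "yes"}\nGood question.'
    last_newline = generated.rfind("\n")             # 17

    answer_part = generated[:last_newline]           # '{"answer": "yes"}'
    comment = generated[last_newline + 1:].strip()   # 'Good question.'
    # With no newline at all, the whole string is the answer and comment is None.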
@@ -370,7 +441,7 @@ class LanguageModel(
         cache: "Cache",
         iteration: int = 0,
         encoded_image=None,
-    ) ->
+    ) -> ModelResponse:
         """Handle caching of responses.
 
         :param user_prompt: The user's prompt.
@@ -389,23 +460,23 @@ class LanguageModel(
         >>> from edsl import Cache
         >>> m = LanguageModel.example(test_model = True)
         >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
-
-        """
+        ModelResponse(...)"""
 
         if encoded_image:
             # the image has is appended to the user_prompt for hash-lookup purposes
             image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
+            user_prompt += f" {image_hash}"
 
         cache_call_params = {
             "model": str(self.model),
             "parameters": self.parameters,
             "system_prompt": system_prompt,
-            "user_prompt": user_prompt
+            "user_prompt": user_prompt,
             "iteration": iteration,
         }
         cached_response, cache_key = cache.fetch(**cache_call_params)
-
-        if
+
+        if cache_used := cached_response is not None:
             response = json.loads(cached_response)
         else:
             f = (
@@ -413,18 +484,33 @@ class LanguageModel(
                 if hasattr(self, "remote") and self.remote
                 else self.async_execute_model_call
             )
-            params = {
-
+            params = {
+                "user_prompt": user_prompt,
+                "system_prompt": system_prompt,
+                **({"encoded_image": encoded_image} if encoded_image else {}),
             }
-            response = await f(**params)
-
-
-
-
+            # response = await f(**params)
+            response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
+            new_cache_key = cache.store(
+                **cache_call_params, response=response
+            )  # store the response in the cache
+            assert new_cache_key == cache_key  # should be the same
+
+        cost = self.cost(response)
+
+        return ModelResponse(
+            response=response,
+            cache_used=cache_used,
+            cache_key=cache_key,
+            cached_response=cached_response,
+            cost=cost,
+        )
 
-    _get_intended_model_call_outcome = sync_wrapper(
+    _get_intended_model_call_outcome = sync_wrapper(
+        _async_get_intended_model_call_outcome
+    )
 
-    get_raw_response = sync_wrapper(async_get_raw_response)
+    # get_raw_response = sync_wrapper(async_get_raw_response)
 
     def simple_ask(
         self,
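One subtlety in the rewritten cache path is the walrus expression: cache_used := cached_response is not None binds the result of the comparison, not the cached payload, so cache_used is always a plain bool. A self-contained sketch of the control flow (the coroutine and timeout value are stand-ins for the real model call and EDSL_API_TIMEOUT):

    import asyncio

    async def demo(cached_response):
        if cache_used := cached_response is not None:
            result = cached_response  # cache hit: replay the stored value
        else:
            # cache miss: real call, bounded the way the hunk above bounds it
            result = await asyncio.wait_for(asyncio.sleep(0, result="fresh"), timeout=1.0)
        return cache_used, result

    print(asyncio.run(demo(None)))      # (False, 'fresh')
    print(asyncio.run(demo("stored")))  # (True, 'stored')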
@@ -443,7 +529,7 @@ class LanguageModel(
         self,
         user_prompt: str,
         system_prompt: str,
-        cache:
+        cache: "Cache",
         iteration: int = 1,
         encoded_image=None,
     ) -> dict:
@@ -461,16 +547,68 @@ class LanguageModel(
             "system_prompt": system_prompt,
             "iteration": iteration,
             "cache": cache,
-            **({"encoded_image": encoded_image} if encoded_image else {})
-        }
-
-
+            **({"encoded_image": encoded_image} if encoded_image else {}),
+        }
+        model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
+        model_outputs = await self._async_get_intended_model_call_outcome(**params)
+        edsl_dict = self.parse_response(model_outputs.response)
+        agent_response_dict = AgentResponseDict(
+            model_inputs=model_inputs,
+            model_outputs=model_outputs,
+            edsl_dict=edsl_dict,
+        )
+        return agent_response_dict
+
+        # return await self._async_prepare_response(model_call_outcome, cache=cache)
 
     get_response = sync_wrapper(async_get_response)
 
-    def cost(self, raw_response: dict[str, Any]) -> float:
+    def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
         """Return the dollar cost of a raw response."""
-
+
+        usage = self.get_usage_dict(raw_response)
+        from edsl.coop import Coop
+
+        c = Coop()
+        price_lookup = c.fetch_prices()
+        key = (self._inference_service_, self.model)
+        if key not in price_lookup:
+            return f"Could not find price for model {self.model} in the price lookup."
+
+        relevant_prices = price_lookup[key]
+        try:
+            input_tokens = int(usage[self.input_token_name])
+            output_tokens = int(usage[self.output_token_name])
+        except Exception as e:
+            return f"Could not fetch tokens from model response: {e}"
+
+        try:
+            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+        except Exception as e:
+            if "output" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
+            if "input" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
+            return f"Could not fetch prices from {relevant_prices} - {e}"
+
+        if inverse_input_price == "infinity":
+            input_cost = 0
+        else:
+            try:
+                input_cost = input_tokens / float(inverse_input_price)
+            except Exception as e:
+                return f"Could not compute input price - {e}."
+
+        if inverse_output_price == "infinity":
+            output_cost = 0
+        else:
+            try:
+                output_cost = output_tokens / float(inverse_output_price)
+            except Exception as e:
+                return f"Could not compute output price - {e}"
+
+        return input_cost + output_cost
 
     #######################
     # SERIALIZATION METHODS
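The new cost method prices calls from inverse rates: one_usd_buys is how many tokens one dollar buys, so the dollar cost is tokens divided by that figure (and "infinity" marks a free tier). A worked sketch with made-up prices:

    # Hypothetical price entry: $1 buys 1,000,000 input or 500,000 output tokens.
    relevant_prices = {
        "input": {"one_usd_buys": 1_000_000},
        "output": {"one_usd_buys": 500_000},
    }
    input_tokens, output_tokens = 2_000, 1_000

    cost = (
        input_tokens / float(relevant_prices["input"]["one_usd_buys"])
        + output_tokens / float(relevant_prices["output"]["one_usd_buys"])
    )
    print(cost)  # 0.004 -> $0.002 for input plus $0.002 for output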
@@ -484,7 +622,7 @@ class LanguageModel(
 
         >>> m = LanguageModel.example()
         >>> m.to_dict()
-        {'model': '
+        {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
         """
         return self._to_dict()
 
@@ -560,26 +698,8 @@ class LanguageModel(
         """
         from edsl import Model
 
-        class TestLanguageModelGood(LanguageModel):
-            use_cache = False
-            _model_ = "test"
-            _parameters_ = {"temperature": 0.5}
-            _inference_service_ = InferenceServiceType.TEST.value
-
-            async def async_execute_model_call(
-                self, user_prompt: str, system_prompt: str
-            ) -> dict[str, Any]:
-                await asyncio.sleep(0.1)
-                # return {"message": """{"answer": "Hello, world"}"""}
-                if throw_exception:
-                    raise Exception("This is a test error")
-                return {"message": f'{{"answer": "{canned_response}"}}'}
-
-            def parse_response(self, raw_response: dict[str, Any]) -> str:
-                return raw_response["message"]
-
         if test_model:
-            m =
+            m = Model("test", canned_response=canned_response)
             return m
         else:
             return Model(skip_api_key_check=True)
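The inline TestLanguageModelGood class is gone; example() now delegates to the registry's "test" service (see the new edsl/inference_services/TestService.py in the file list above). A usage sketch:

    from edsl import Model
    from edsl.language_models.LanguageModel import LanguageModel

    m = LanguageModel.example(test_model=True, canned_response="Hello world")
    # equivalent, per the hunk above, to:
    m = Model("test", canned_response="Hello world")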
edsl/language_models/ModelList.py

@@ -40,8 +40,8 @@ class ModelList(Base, UserList):
     def __hash__(self):
         """Return a hash of the ModelList. This is used for comparison of ModelLists.
 
-        >>> hash(
-
+        >>> isinstance(hash(Model()), int)
+        True
 
         """
         from edsl.utilities.utilities import dict_hash