edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- edsl/Base.py +116 -197
- edsl/__init__.py +7 -15
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +147 -351
- edsl/agents/AgentList.py +73 -211
- edsl/agents/Invigilator.py +50 -101
- edsl/agents/InvigilatorBase.py +70 -62
- edsl/agents/PromptConstructor.py +225 -143
- edsl/agents/__init__.py +1 -0
- edsl/agents/prompt_helpers.py +3 -3
- edsl/auto/AutoStudy.py +5 -18
- edsl/auto/StageBase.py +40 -53
- edsl/auto/StageQuestions.py +1 -2
- edsl/auto/utilities.py +6 -0
- edsl/config.py +2 -22
- edsl/conversation/car_buying.py +1 -2
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +47 -125
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +27 -45
- edsl/data/CacheEntry.py +15 -12
- edsl/data/CacheHandler.py +12 -31
- edsl/data/RemoteCacheSync.py +46 -154
- edsl/data/__init__.py +3 -4
- edsl/data_transfer_models.py +1 -2
- edsl/enums.py +0 -27
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +0 -12
- edsl/exceptions/questions.py +6 -24
- edsl/exceptions/scenarios.py +0 -7
- edsl/inference_services/AnthropicService.py +19 -38
- edsl/inference_services/AwsBedrock.py +2 -0
- edsl/inference_services/AzureAI.py +2 -0
- edsl/inference_services/GoogleService.py +12 -7
- edsl/inference_services/InferenceServiceABC.py +85 -18
- edsl/inference_services/InferenceServicesCollection.py +79 -120
- edsl/inference_services/MistralAIService.py +3 -0
- edsl/inference_services/OpenAIService.py +35 -47
- edsl/inference_services/PerplexityService.py +3 -0
- edsl/inference_services/TestService.py +10 -11
- edsl/inference_services/TogetherAIService.py +3 -5
- edsl/jobs/Answers.py +14 -1
- edsl/jobs/Jobs.py +431 -356
- edsl/jobs/JobsChecks.py +10 -35
- edsl/jobs/JobsPrompts.py +4 -6
- edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
- edsl/jobs/buckets/BucketCollection.py +3 -44
- edsl/jobs/buckets/TokenBucket.py +21 -53
- edsl/jobs/interviews/Interview.py +408 -143
- edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
- edsl/jobs/runners/JobsRunnerStatus.py +165 -133
- edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
- edsl/jobs/tasks/TaskHistory.py +18 -38
- edsl/jobs/tasks/task_status_enum.py +2 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +236 -194
- edsl/language_models/ModelList.py +19 -28
- edsl/language_models/__init__.py +2 -1
- edsl/language_models/registry.py +190 -0
- edsl/language_models/repair.py +2 -2
- edsl/language_models/unused/ReplicateBase.py +83 -0
- edsl/language_models/utilities.py +4 -5
- edsl/notebooks/Notebook.py +14 -19
- edsl/prompts/Prompt.py +39 -29
- edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
- edsl/questions/QuestionBase.py +214 -68
- edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
- edsl/questions/QuestionBasePromptsMixin.py +3 -7
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +3 -3
- edsl/questions/QuestionExtract.py +7 -5
- edsl/questions/QuestionFreeText.py +3 -2
- edsl/questions/QuestionList.py +18 -10
- edsl/questions/QuestionMultipleChoice.py +23 -67
- edsl/questions/QuestionNumerical.py +4 -2
- edsl/questions/QuestionRank.py +17 -7
- edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
- edsl/questions/SimpleAskMixin.py +3 -4
- edsl/questions/__init__.py +1 -2
- edsl/questions/derived/QuestionLinearScale.py +3 -6
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +3 -17
- edsl/questions/question_registry.py +1 -1
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +7 -170
- edsl/results/DatasetExportMixin.py +305 -168
- edsl/results/DatasetTree.py +8 -28
- edsl/results/Result.py +206 -298
- edsl/results/Results.py +131 -149
- edsl/results/ResultsDBMixin.py +238 -0
- edsl/results/ResultsExportMixin.py +0 -2
- edsl/results/{results_selector.py → Selector.py} +13 -23
- edsl/results/TableDisplay.py +171 -98
- edsl/results/__init__.py +1 -1
- edsl/scenarios/FileStore.py +239 -150
- edsl/scenarios/Scenario.py +193 -90
- edsl/scenarios/ScenarioHtmlMixin.py +3 -4
- edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
- edsl/scenarios/ScenarioList.py +244 -415
- edsl/scenarios/ScenarioListExportMixin.py +7 -0
- edsl/scenarios/ScenarioListPdfMixin.py +37 -15
- edsl/scenarios/__init__.py +2 -1
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +12 -5
- edsl/surveys/Rule.py +4 -5
- edsl/surveys/RuleCollection.py +27 -25
- edsl/surveys/Survey.py +791 -270
- edsl/surveys/SurveyCSS.py +8 -20
- edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
- edsl/surveys/__init__.py +2 -4
- edsl/surveys/descriptors.py +2 -6
- edsl/surveys/instructions/ChangeInstruction.py +2 -1
- edsl/surveys/instructions/Instruction.py +13 -4
- edsl/surveys/instructions/InstructionCollection.py +6 -11
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/utilities.py +23 -35
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
- edsl-0.1.39.dev1.dist-info/RECORD +277 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
- edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
- edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
- edsl/agents/question_option_processor.py +0 -172
- edsl/coop/CoopFunctionsMixin.py +0 -15
- edsl/coop/ExpectedParrotKeyHandler.py +0 -125
- edsl/exceptions/inference_services.py +0 -5
- edsl/inference_services/AvailableModelCacheHandler.py +0 -184
- edsl/inference_services/AvailableModelFetcher.py +0 -215
- edsl/inference_services/ServiceAvailability.py +0 -135
- edsl/inference_services/data_structures.py +0 -134
- edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
- edsl/jobs/FetchInvigilator.py +0 -47
- edsl/jobs/InterviewTaskManager.py +0 -98
- edsl/jobs/InterviewsConstructor.py +0 -50
- edsl/jobs/JobsComponentConstructor.py +0 -189
- edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
- edsl/jobs/RequestTokenEstimator.py +0 -30
- edsl/jobs/async_interview_runner.py +0 -138
- edsl/jobs/buckets/TokenBucketAPI.py +0 -211
- edsl/jobs/buckets/TokenBucketClient.py +0 -191
- edsl/jobs/check_survey_scenario_compatibility.py +0 -85
- edsl/jobs/data_structures.py +0 -120
- edsl/jobs/decorators.py +0 -35
- edsl/jobs/jobs_status_enums.py +0 -9
- edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
- edsl/jobs/results_exceptions_handler.py +0 -98
- edsl/language_models/ComputeCost.py +0 -63
- edsl/language_models/PriceManager.py +0 -127
- edsl/language_models/RawResponseHandler.py +0 -106
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/key_management/KeyLookup.py +0 -63
- edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
- edsl/language_models/key_management/KeyLookupCollection.py +0 -38
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +0 -131
- edsl/language_models/model.py +0 -256
- edsl/notebooks/NotebookToLaTeX.py +0 -142
- edsl/questions/ExceptionExplainer.py +0 -77
- edsl/questions/HTMLQuestion.py +0 -103
- edsl/questions/QuestionMatrix.py +0 -265
- edsl/questions/data_structures.py +0 -20
- edsl/questions/loop_processor.py +0 -149
- edsl/questions/response_validator_factory.py +0 -34
- edsl/questions/templates/matrix/__init__.py +0 -1
- edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
- edsl/questions/templates/matrix/question_presentation.jinja +0 -20
- edsl/results/MarkdownToDocx.py +0 -122
- edsl/results/MarkdownToPDF.py +0 -111
- edsl/results/TextEditor.py +0 -50
- edsl/results/file_exports.py +0 -252
- edsl/results/smart_objects.py +0 -96
- edsl/results/table_data_class.py +0 -12
- edsl/results/table_renderers.py +0 -118
- edsl/scenarios/ConstructDownloadLink.py +0 -109
- edsl/scenarios/DocumentChunker.py +0 -102
- edsl/scenarios/DocxScenario.py +0 -16
- edsl/scenarios/PdfExtractor.py +0 -40
- edsl/scenarios/directory_scanner.py +0 -96
- edsl/scenarios/file_methods.py +0 -85
- edsl/scenarios/handlers/__init__.py +0 -13
- edsl/scenarios/handlers/csv.py +0 -49
- edsl/scenarios/handlers/docx.py +0 -76
- edsl/scenarios/handlers/html.py +0 -37
- edsl/scenarios/handlers/json.py +0 -111
- edsl/scenarios/handlers/latex.py +0 -5
- edsl/scenarios/handlers/md.py +0 -51
- edsl/scenarios/handlers/pdf.py +0 -68
- edsl/scenarios/handlers/png.py +0 -39
- edsl/scenarios/handlers/pptx.py +0 -105
- edsl/scenarios/handlers/py.py +0 -294
- edsl/scenarios/handlers/sql.py +0 -313
- edsl/scenarios/handlers/sqlite.py +0 -149
- edsl/scenarios/handlers/txt.py +0 -33
- edsl/scenarios/scenario_selector.py +0 -156
- edsl/surveys/ConstructDAG.py +0 -92
- edsl/surveys/EditSurvey.py +0 -221
- edsl/surveys/InstructionHandler.py +0 -100
- edsl/surveys/MemoryManagement.py +0 -72
- edsl/surveys/RuleManager.py +0 -172
- edsl/surveys/Simulator.py +0 -75
- edsl/surveys/SurveyToApp.py +0 -141
- edsl/utilities/PrettyList.py +0 -56
- edsl/utilities/is_notebook.py +0 -18
- edsl/utilities/is_valid_variable_name.py +0 -11
- edsl/utilities/remove_edsl_version.py +0 -24
- edsl-0.1.39.dist-info/RECORD +0 -358
- /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
- /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
- /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
--- /dev/null
+++ b/edsl/language_models/KeyLookup.py
@@ -0,0 +1,30 @@
+import os
+from collections import UserDict
+
+from edsl.enums import service_to_api_keyname
+from edsl.exceptions import MissingAPIKeyError
+
+
+class KeyLookup(UserDict):
+    @classmethod
+    def from_os_environ(cls):
+        """Create an instance of KeyLookupAPI with keys from os.environ"""
+        return cls({key: value for key, value in os.environ.items()})
+
+    def get_api_token(self, service: str, remote: bool = False):
+        key_name = service_to_api_keyname.get(service, "NOT FOUND")
+
+        if service == "bedrock":
+            api_token = [self.get(key_name[0]), self.get(key_name[1])]
+            missing_token = any(token is None for token in api_token)
+        else:
+            api_token = self.get(key_name)
+            missing_token = api_token is None
+
+        if missing_token and service != "test" and not remote:
+            raise MissingAPIKeyError(
+                f"""The key for service: `{service}` is not set.
+                Need a key with name {key_name} in your .env file."""
+            )
+
+        return api_token
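For orientation, a minimal usage sketch of the new `KeyLookup` (assumes edsl is installed and that `service_to_api_keyname` in `edsl/enums.py` maps `"openai"` to `OPENAI_API_KEY`; both names come from the hunk above):

```python
# Sketch, not edsl documentation: assumes OPENAI_API_KEY is set in the environment.
from edsl.language_models.KeyLookup import KeyLookup

lookup = KeyLookup.from_os_environ()    # snapshot of os.environ as a UserDict
token = lookup.get_api_token("openai")  # raises MissingAPIKeyError if the key is unset
```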
--- a/edsl/language_models/LanguageModel.py
+++ b/edsl/language_models/LanguageModel.py
@@ -21,6 +21,7 @@ import os
 from typing import (
     Coroutine,
     Any,
+    Callable,
     Type,
     Union,
     List,
@@ -31,6 +32,8 @@ from typing import (
 )
 from abc import ABC, abstractmethod
 
+from json_repair import repair_json
+
 from edsl.data_transfer_models import (
     ModelResponse,
     ModelInputs,
@@ -42,24 +45,61 @@ if TYPE_CHECKING:
     from edsl.data.Cache import Cache
     from edsl.scenarios.FileStore import FileStore
     from edsl.questions.QuestionBase import QuestionBase
-    from edsl.language_models.key_management.KeyLookup import KeyLookup
-
-from edsl.enums import InferenceServiceType
 
-from edsl.
-
-
-)
-from edsl.utilities.remove_edsl_version import remove_edsl_version
+from edsl.config import CONFIG
+from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
+from edsl.utilities.decorators import remove_edsl_version
 
-from edsl.Base import PersistenceMixin
+from edsl.Base import PersistenceMixin
 from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
+from edsl.language_models.KeyLookup import KeyLookup
+from edsl.exceptions.language_models import LanguageModelBadResponseError
 
-
-
-
+TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+
+# you might be tempated to move this to be a static method of LanguageModel, but this doesn't work
+# for reasons I don't understand. So leave it here.
+def extract_item_from_raw_response(data, key_sequence):
+    if isinstance(data, str):
+        try:
+            data = json.loads(data)
+        except json.JSONDecodeError as e:
+            return data
+    current_data = data
+    for i, key in enumerate(key_sequence):
+        try:
+            if isinstance(current_data, (list, tuple)):
+                if not isinstance(key, int):
+                    raise TypeError(
+                        f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
+                    )
+                if key < 0 or key >= len(current_data):
+                    raise IndexError(
+                        f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
+                    )
+            elif isinstance(current_data, dict):
+                if key not in current_data:
+                    raise KeyError(
+                        f"Key '{key}' not found in dictionary at position {i}"
+                    )
+            else:
+                raise TypeError(
+                    f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
+                )
 
-
+            current_data = current_data[key]
+        except Exception as e:
+            path = " -> ".join(map(str, key_sequence[: i + 1]))
+            if "error" in data:
+                msg = data["error"]
+            else:
+                msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
+            raise LanguageModelBadResponseError(message=msg, response_json=data)
+    if isinstance(current_data, str):
+        return current_data.strip()
+    else:
+        return current_data
 
 
 def handle_key_error(func):
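A standalone illustration of how a `key_sequence` drives `extract_item_from_raw_response`; the payload below is invented to mirror the `["choices", 0, "message", "content"]` shape mentioned in the class body, and the error handling is omitted:

```python
# Simplified mirror of the traversal above; not edsl code.
raw = {"choices": [{"message": {"content": "  Hello world\n"}}]}
key_sequence = ["choices", 0, "message", "content"]

current = raw
for key in key_sequence:
    current = current[key]  # dict keys and list indices interleave

print(current.strip())      # -> Hello world (final strings are stripped)
```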
@@ -77,21 +117,8 @@ def handle_key_error(func):
     return wrapper
 
 
-class classproperty:
-    def __init__(self, method):
-        self.method = method
-
-    def __get__(self, instance, cls):
-        return self.method(cls)
-
-
-from edsl.Base import HashingMixin
-
-
 class LanguageModel(
     PersistenceMixin,
-    RepresentationMixin,
-    HashingMixin,
     ABC,
     metaclass=RegisterLanguageModelsMeta,
 ):
@@ -101,22 +128,15 @@ class LanguageModel(
     key_sequence = (
         None  # This should be something like ["choices", 0, "message", "content"]
     )
-
-
-    DEFAULT_TPM = 1000
-
-    @classproperty
-    def response_handler(cls):
-        key_sequence = cls.key_sequence
-        usage_sequence = cls.usage_sequence if hasattr(cls, "usage_sequence") else None
-        return RawResponseHandler(key_sequence, usage_sequence)
+    __rate_limits = None
+    _safety_factor = 0.8
 
     def __init__(
         self,
-        tpm:
-        rpm:
+        tpm: float = None,
+        rpm: float = None,
         omit_system_prompt_if_empty_string: bool = True,
-        key_lookup: Optional[
+        key_lookup: Optional[KeyLookup] = None,
         **kwargs,
     ):
         """Initialize the LanguageModel."""
@@ -127,9 +147,7 @@ class LanguageModel(
         self.remote = False
         self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
 
-        self.
-        self.model_info = self.key_lookup.get(self._inference_service_)
-
+        # self._rpm / _tpm comes from the class
         if rpm is not None:
             self._rpm = rpm
 
@@ -143,75 +161,49 @@ class LanguageModel(
             if key not in parameters:
                 setattr(self, key, value)
 
-        if
+        if "use_cache" in kwargs:
+            warnings.warn(
+                "The use_cache parameter is deprecated. Use the Cache class instead."
+            )
+
+        if skip_api_key_check := kwargs.get("skip_api_key_check", False):
             # Skip the API key check. Sometimes this is useful for testing.
             self._api_token = None
 
-    def _set_key_lookup(self, key_lookup: "KeyLookup") -> "KeyLookup":
-        """Set the key lookup."""
         if key_lookup is not None:
-
+            self.key_lookup = key_lookup
         else:
-
-            klc.add_key_lookup(fetch_order=("config", "env"))
-            return klc.get(("config", "env"))
-
-    def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
-        """Set the key lookup, later"""
-        if hasattr(self, "_api_token"):
-            del self._api_token
-        self.key_lookup = key_lookup
-
-    def ask_question(self, question: "QuestionBase") -> str:
-        """Ask a question and return the response.
+            self.key_lookup = KeyLookup.from_os_environ()
 
-
-        """
+    def ask_question(self, question):
         user_prompt = question.get_instructions().render(question.data).text
         system_prompt = "You are a helpful agent pretending to be a human."
         return self.execute_model_call(user_prompt, system_prompt)
 
-
-
-
-        if self.model_info is None:
-            self._rpm = self.DEFAULT_RPM
-        else:
-            self._rpm = self.model_info.rpm
-        return self._rpm
-
-    @property
-    def tpm(self):
-        if not hasattr(self, "_tpm"):
-            if self.model_info is None:
-                self._tpm = self.DEFAULT_TPM
-            else:
-                self._tpm = self.model_info.tpm
-        return self._tpm
-
-    # in case we want to override the default values
-    @tpm.setter
-    def tpm(self, value):
-        self._tpm = value
-
-    @rpm.setter
-    def rpm(self, value):
-        self._rpm = value
+    def set_key_lookup(self, key_lookup: KeyLookup) -> None:
+        del self._api_token
+        self.key_lookup = key_lookup
 
     @property
     def api_token(self) -> str:
         if not hasattr(self, "_api_token"):
-
-
-
-                f"No key found for service '{self._inference_service_}'"
-            )
-            self._api_token = info.api_token
+            self._api_token = self.key_lookup.get_api_token(
+                self._inference_service_, self.remote
+            )
         return self._api_token
 
     def __getitem__(self, key):
         return getattr(self, key)
 
+    def _repr_html_(self) -> str:
+        d = {"model": self.model}
+        d.update(self.parameters)
+        data = [[k, v] for k, v in d.items()]
+        from tabulate import tabulate
+
+        table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
+        return f"<pre>{table}</pre>"
+
     def hello(self, verbose=False):
         """Runs a simple test to check if the model is working."""
         token = self.api_token
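The `api_token` property resolves the key once via `key_lookup.get_api_token(...)` and then caches it on the instance; a self-contained mirror of that pattern (both classes below are invented stand-ins, not edsl code):

```python
# Illustrative stand-ins mirroring the api_token property above.
class StubLookup(dict):
    def get_api_token(self, service, remote=False):
        return self.get(f"{service.upper()}_API_KEY")  # hypothetical key naming

class ServiceModel:
    _inference_service_ = "openai"  # assumed service name
    remote = False

    def __init__(self, key_lookup):
        self.key_lookup = key_lookup

    @property
    def api_token(self):
        if not hasattr(self, "_api_token"):  # resolve once...
            self._api_token = self.key_lookup.get_api_token(
                self._inference_service_, self.remote
            )
        return self._api_token               # ...then serve the cached value

m = ServiceModel(StubLookup(OPENAI_API_KEY="sk-test"))
print(m.api_token)  # sk-test
```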
@@ -240,12 +232,7 @@ class LanguageModel(
         return key_value is not None
 
     def __hash__(self) -> str:
-        """Allow the model to be used as a key in a dictionary.
-
-        >>> m = LanguageModel.example()
-        >>> hash(m)
-        1811901442659237949
-        """
+        """Allow the model to be used as a key in a dictionary."""
         from edsl.utilities.utilities import dict_hash
 
         return dict_hash(self.to_dict(add_edsl_version=False))
@@ -261,6 +248,46 @@ class LanguageModel(
         """
         return self.model == other.model and self.parameters == other.parameters
 
+    def set_rate_limits(self, rpm=None, tpm=None) -> None:
+        """Set the rate limits for the model.
+
+        >>> m = LanguageModel.example()
+        >>> m.set_rate_limits(rpm=100, tpm=1000)
+        >>> m.RPM
+        100
+        """
+        if rpm is not None:
+            self._rpm = rpm
+        if tpm is not None:
+            self._tpm = tpm
+        return None
+
+    @property
+    def RPM(self):
+        """Model's requests-per-minute limit."""
+        return self._rpm
+
+    @property
+    def TPM(self):
+        """Model's tokens-per-minute limit."""
+        return self._tpm
+
+    @property
+    def rpm(self):
+        return self._rpm
+
+    @rpm.setter
+    def rpm(self, value):
+        self._rpm = value
+
+    @property
+    def tpm(self):
+        return self._tpm
+
+    @tpm.setter
+    def tpm(self, value):
+        self._tpm = value
+
     @staticmethod
     def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
         """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
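The setters and `set_rate_limits` are interchangeable ways to override the limits; a short sketch, assuming edsl is importable (`LanguageModel.example()` appears in the doctests above):

```python
# Sketch mirroring the set_rate_limits doctest in the hunk above.
from edsl.language_models.LanguageModel import LanguageModel

m = LanguageModel.example()
m.set_rate_limits(rpm=100, tpm=1000)  # equivalent to: m.rpm = 100; m.tpm = 1000
assert m.RPM == 100 and m.TPM == 1000
```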
@@ -283,7 +310,16 @@ class LanguageModel(
 
     @abstractmethod
     async def async_execute_model_call(user_prompt: str, system_prompt: str):
-        """Execute the model call and returns a coroutine.
+        """Execute the model call and returns a coroutine.
+
+        >>> m = LanguageModel.example(test_model = True)
+        >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
+        >>> asyncio.run(test())
+        {'message': [{'text': 'Hello world'}], ...}
+
+        >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
+        {'message': [{'text': 'Hello world'}], ...}
+        """
         pass
 
     async def remote_async_execute_model_call(
@@ -300,7 +336,12 @@ class LanguageModel(
 
     @jupyter_nb_handler
     def execute_model_call(self, *args, **kwargs) -> Coroutine:
-        """Execute the model call and returns the result as a coroutine.
+        """Execute the model call and returns the result as a coroutine.
+
+        >>> m = LanguageModel.example(test_model = True)
+        >>> m.execute_model_call(user_prompt = "Hello, model!", system_prompt = "You are a helpful agent.")
+
+        """
 
         async def main():
             results = await asyncio.gather(
@@ -312,25 +353,58 @@ class LanguageModel(
 
     @classmethod
     def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
-        """Return the generated token string from the raw response.
-
-        >>> m = LanguageModel.example(test_model = True)
-        >>> raw_response = m.execute_model_call("Hello, model!", "You are a helpful agent.")
-        >>> m.get_generated_token_string(raw_response)
-        'Hello world'
-
-        """
-        return cls.response_handler.get_generated_token_string(raw_response)
+        """Return the generated token string from the raw response."""
+        return extract_item_from_raw_response(raw_response, cls.key_sequence)
 
     @classmethod
     def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
         """Return the usage dictionary from the raw response."""
-
+        if not hasattr(cls, "usage_sequence"):
+            raise NotImplementedError(
+                "This inference service does not have a usage_sequence."
+            )
+        return extract_item_from_raw_response(raw_response, cls.usage_sequence)
+
+    @staticmethod
+    def convert_answer(response_part):
+        import json
+
+        response_part = response_part.strip()
+
+        if response_part == "None":
+            return None
+
+        repaired = repair_json(response_part)
+        if repaired == '""':
+            # it was a literal string
+            return response_part
+
+        try:
+            return json.loads(repaired)
+        except json.JSONDecodeError as j:
+            # last resort
+            return response_part
 
     @classmethod
     def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
         """Parses the API response and returns the response text."""
-
+        generated_token_string = cls.get_generated_token_string(raw_response)
+        last_newline = generated_token_string.rfind("\n")
+
+        if last_newline == -1:
+            # There is no comment
+            edsl_dict = {
+                "answer": cls.convert_answer(generated_token_string),
+                "generated_tokens": generated_token_string,
+                "comment": None,
+            }
+        else:
+            edsl_dict = {
+                "answer": cls.convert_answer(generated_token_string[:last_newline]),
+                "comment": generated_token_string[last_newline + 1 :].strip(),
+                "generated_tokens": generated_token_string,
+            }
+        return EDSLOutput(**edsl_dict)
 
     async def _async_get_intended_model_call_outcome(
         self,
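The answer/comment split in `parse_response` is a partition at the last newline of the generated text; a runnable illustration with an invented string:

```python
# Mirrors the last-newline split in parse_response above; no edsl required.
generated = "42\nBecause the answer to everything is 42."
last_newline = generated.rfind("\n")

answer_part = generated[:last_newline]           # fed to convert_answer -> 42
comment = generated[last_newline + 1 :].strip()  # trailing commentary
print(answer_part, "|", comment)
```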
@@ -347,8 +421,6 @@ class LanguageModel(
         :param system_prompt: The system's prompt.
         :param iteration: The iteration number.
         :param cache: The cache to use.
-        :param files_list: The list of files to use.
-        :param invigilator: The invigilator to use.
 
         If the cache isn't being used, it just returns a 'fresh' call to the LLM.
         But if cache is being used, it first checks the database to see if the response is already there.
@@ -391,10 +463,6 @@ class LanguageModel(
             "system_prompt": system_prompt,
             "files_list": files_list,
         }
-        from edsl.config import CONFIG
-
-        TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
-
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
         new_cache_key = cache.store(
             **cache_call_params, response=response
@@ -402,6 +470,7 @@ class LanguageModel(
         assert new_cache_key == cache_key  # should be the same
 
         cost = self.cost(response)
+
         return ModelResponse(
             response=response,
             cache_used=cache_used,
@@ -440,9 +509,9 @@ class LanguageModel(
 
         :param user_prompt: The user's prompt.
         :param system_prompt: The system's prompt.
-        :param cache: The cache to use.
         :param iteration: The iteration number.
-        :param
+        :param cache: The cache to use.
+        :param encoded_image: The encoded image to use.
 
         """
         params = {
@@ -456,11 +525,8 @@ class LanguageModel(
             params.update({"invigilator": kwargs["invigilator"]})
 
         model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
-        model_outputs
-
-        )
-        edsl_dict: EDSLOutput = self.parse_response(model_outputs.response)
-
+        model_outputs = await self._async_get_intended_model_call_outcome(**params)
+        edsl_dict = self.parse_response(model_outputs.response)
         agent_response_dict = AgentResponseDict(
             model_inputs=model_inputs,
             model_outputs=model_outputs,
@@ -471,36 +537,60 @@ class LanguageModel(
     get_response = sync_wrapper(async_get_response)
 
     def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
-        """Return the dollar cost of a raw response.
-
-        :param raw_response: The raw response from the model.
-        """
+        """Return the dollar cost of a raw response."""
 
         usage = self.get_usage_dict(raw_response)
-        from edsl.
-
-
-
-
-
-
-
-
-
+        from edsl.coop import Coop
+
+        c = Coop()
+        price_lookup = c.fetch_prices()
+        key = (self._inference_service_, self.model)
+        if key not in price_lookup:
+            return f"Could not find price for model {self.model} in the price lookup."
+
+        relevant_prices = price_lookup[key]
+        try:
+            input_tokens = int(usage[self.input_token_name])
+            output_tokens = int(usage[self.output_token_name])
+        except Exception as e:
+            return f"Could not fetch tokens from model response: {e}"
+
+        try:
+            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+        except Exception as e:
+            if "output" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
+            if "input" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
+            return f"Could not fetch prices from {relevant_prices} - {e}"
+
+        if inverse_input_price == "infinity":
+            input_cost = 0
+        else:
+            try:
+                input_cost = input_tokens / float(inverse_input_price)
+            except Exception as e:
+                return f"Could not compute input price - {e}."
+
+        if inverse_output_price == "infinity":
+            output_cost = 0
+        else:
+            try:
+                output_cost = output_tokens / float(inverse_output_price)
+            except Exception as e:
+                return f"Could not compute output price - {e}"
+
+        return input_cost + output_cost
 
     def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
         """Convert instance to a dictionary
 
-        :param add_edsl_version: Whether to add the EDSL version to the dictionary.
-
        >>> m = LanguageModel.example()
        >>> m.to_dict()
        {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
        """
-        d = {
-            "model": self.model,
-            "parameters": self.parameters,
-        }
+        d = {"model": self.model, "parameters": self.parameters}
         if add_edsl_version:
             from edsl import __version__
 
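Prices in the Coop lookup are stored as `one_usd_buys` (tokens per USD), so each side of the cost is `tokens / one_usd_buys`; a worked example with made-up rates:

```python
# Made-up rates; mirrors the arithmetic in cost() above.
input_tokens, output_tokens = 1_000, 250
one_usd_buys_input = 2_000_000.0   # 2M input tokens per USD (hypothetical)
one_usd_buys_output = 500_000.0    # 0.5M output tokens per USD (hypothetical)

cost = input_tokens / one_usd_buys_input + output_tokens / one_usd_buys_output
print(f"${cost:.6f}")              # $0.001000
```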
@@ -512,13 +602,13 @@ class LanguageModel(
     @remove_edsl_version
     def from_dict(cls, data: dict) -> Type[LanguageModel]:
         """Convert dictionary to a LanguageModel child instance."""
-        from edsl.language_models.
+        from edsl.language_models.registry import get_model_class
 
         model_class = get_model_class(data["model"])
         return model_class(**data)
 
     def __repr__(self) -> str:
-        """Return a representation of the object."""
+        """Return a string representation of the object."""
         param_string = ", ".join(
             f"{key} = {value}" for key, value in self.parameters.items()
         )
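A serialization round-trip sketch, assuming edsl is importable; equality is by model name and parameters, per the `__eq__` shown in an earlier hunk:

```python
# Sketch: from_dict resolves the class via the registry shown above.
from edsl.language_models.LanguageModel import LanguageModel

m = LanguageModel.example()
m2 = LanguageModel.from_dict(m.to_dict())  # @remove_edsl_version strips the version key
assert m == m2
```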
@@ -560,7 +650,7 @@ class LanguageModel(
         Exception report saved to ...
         Also see: ...
         """
-        from edsl
+        from edsl import Model
 
         if test_model:
             m = Model(
@@ -570,54 +660,6 @@ class LanguageModel(
         else:
             return Model(skip_api_key_check=True)
 
-    def from_cache(self, cache: "Cache") -> LanguageModel:
-
-        from copy import deepcopy
-        from types import MethodType
-        from edsl import Cache
-
-        new_instance = deepcopy(self)
-        print("Cache entries", len(cache))
-        new_instance.cache = Cache(
-            data={k: v for k, v in cache.items() if v.model == self.model}
-        )
-        print("Cache entries with same model", len(new_instance.cache))
-
-        new_instance.user_prompts = [
-            ce.user_prompt for ce in new_instance.cache.values()
-        ]
-        new_instance.system_prompts = [
-            ce.system_prompt for ce in new_instance.cache.values()
-        ]
-
-        async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
-            cache_call_params = {
-                "model": str(self.model),
-                "parameters": self.parameters,
-                "system_prompt": system_prompt,
-                "user_prompt": user_prompt,
-                "iteration": 1,
-            }
-            cached_response, cache_key = cache.fetch(**cache_call_params)
-            response = json.loads(cached_response)
-            cost = 0
-            return ModelResponse(
-                response=response,
-                cache_used=True,
-                cache_key=cache_key,
-                cached_response=cached_response,
-                cost=cost,
-            )
-
-        # Bind the new method to the copied instance
-        setattr(
-            new_instance,
-            "async_execute_model_call",
-            MethodType(async_execute_model_call, new_instance),
-        )
-
-        return new_instance
-
 
 if __name__ == "__main__":
     """Run the module's test suite."""