edsl 0.1.38.dev1__py3-none-any.whl → 0.1.38.dev3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +49 -48
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +858 -855
- edsl/agents/AgentList.py +362 -350
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +284 -284
- edsl/agents/PromptConstructor.py +353 -353
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +99 -99
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +279 -289
- edsl/config.py +149 -149
- edsl/conversation/Conversation.py +290 -290
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/chips.py +95 -95
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +961 -958
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +530 -527
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +97 -97
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/BaseException.py +21 -21
- edsl/exceptions/__init__.py +54 -54
- edsl/exceptions/agents.py +42 -38
- edsl/exceptions/cache.py +5 -0
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +29 -29
- edsl/exceptions/scenarios.py +22 -22
- edsl/exceptions/surveys.py +37 -37
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +120 -120
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +97 -97
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1358 -1347
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +251 -248
- edsl/jobs/interviews/Interview.py +661 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +361 -338
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +451 -442
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/KeyLookup.py +30 -30
- edsl/language_models/LanguageModel.py +708 -706
- edsl/language_models/ModelList.py +109 -102
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +3 -3
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +258 -259
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +357 -357
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +660 -656
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +217 -234
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +114 -114
- edsl/questions/QuestionFunctional.py +166 -159
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +93 -91
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +413 -413
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +717 -717
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +456 -450
- edsl/results/Results.py +1071 -1071
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +135 -135
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +458 -458
- edsl/scenarios/Scenario.py +544 -546
- edsl/scenarios/ScenarioHtmlMixin.py +64 -64
- edsl/scenarios/ScenarioList.py +1112 -1112
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -4
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +326 -330
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1787 -1795
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +49 -47
- edsl/surveys/instructions/Instruction.py +53 -51
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/{conjure → utilities}/naming_utilities.py +263 -263
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +409 -409
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev3.dist-info}/LICENSE +21 -21
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev3.dist-info}/METADATA +1 -1
- edsl-0.1.38.dev3.dist-info/RECORD +269 -0
- edsl/conjure/AgentConstructionMixin.py +0 -160
- edsl/conjure/Conjure.py +0 -62
- edsl/conjure/InputData.py +0 -659
- edsl/conjure/InputDataCSV.py +0 -48
- edsl/conjure/InputDataMixinQuestionStats.py +0 -182
- edsl/conjure/InputDataPyRead.py +0 -91
- edsl/conjure/InputDataSPSS.py +0 -8
- edsl/conjure/InputDataStata.py +0 -8
- edsl/conjure/QuestionOptionMixin.py +0 -76
- edsl/conjure/QuestionTypeMixin.py +0 -23
- edsl/conjure/RawQuestion.py +0 -65
- edsl/conjure/SurveyResponses.py +0 -7
- edsl/conjure/__init__.py +0 -9
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/utilities.py +0 -201
- edsl-0.1.38.dev1.dist-info/RECORD +0 -283
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev3.dist-info}/WHEEL +0 -0
@@ -1,706 +1,708 @@
+"""This module contains the LanguageModel class, which is an abstract base class for all language models.
+
+Terminology:
+
+raw_response: The JSON response from the model. This has all the model meta-data about the call.
+
+edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
+such as the cache key, token usage, etc.
+
+generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
+edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
+
+"""
+
+from __future__ import annotations
+import warnings
+from functools import wraps
+import asyncio
+import json
+import os
+from typing import (
+    Coroutine,
+    Any,
+    Callable,
+    Type,
+    Union,
+    List,
+    get_type_hints,
+    TypedDict,
+    Optional,
+    TYPE_CHECKING,
+)
+from abc import ABC, abstractmethod
+
+from json_repair import repair_json
+
+from edsl.data_transfer_models import (
+    ModelResponse,
+    ModelInputs,
+    EDSLOutput,
+    AgentResponseDict,
+)
+
+
+from edsl.config import CONFIG
+from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
+from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
+from edsl.language_models.repair import repair
+from edsl.enums import InferenceServiceType
+from edsl.Base import RichPrintingMixin, PersistenceMixin
+from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
+from edsl.exceptions.language_models import LanguageModelBadResponseError
+
+from edsl.language_models.KeyLookup import KeyLookup
+
+TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+
+# you might be tempated to move this to be a static method of LanguageModel, but this doesn't work
+# for reasons I don't understand. So leave it here.
+def extract_item_from_raw_response(data, key_sequence):
+    if isinstance(data, str):
+        try:
+            data = json.loads(data)
+        except json.JSONDecodeError as e:
+            return data
+    current_data = data
+    for i, key in enumerate(key_sequence):
+        try:
+            if isinstance(current_data, (list, tuple)):
+                if not isinstance(key, int):
+                    raise TypeError(
+                        f"Expected integer index for sequence at position {i}, got {type(key).__name__}"
+                    )
+                if key < 0 or key >= len(current_data):
+                    raise IndexError(
+                        f"Index {key} out of range for sequence of length {len(current_data)} at position {i}"
+                    )
+            elif isinstance(current_data, dict):
+                if key not in current_data:
+                    raise KeyError(
+                        f"Key '{key}' not found in dictionary at position {i}"
+                    )
+            else:
+                raise TypeError(
+                    f"Cannot index into {type(current_data).__name__} at position {i}. Full response is: {data} of type {type(data)}. Key sequence is: {key_sequence}"
+                )
+
+            current_data = current_data[key]
+        except Exception as e:
+            path = " -> ".join(map(str, key_sequence[: i + 1]))
+            if "error" in data:
+                msg = data["error"]
+            else:
+                msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
+            raise LanguageModelBadResponseError(message=msg, response_json=data)
+    if isinstance(current_data, str):
+        return current_data.strip()
+    else:
+        return current_data
+
+
+def handle_key_error(func):
+    """Handle KeyError exceptions."""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+            assert True == False
+        except KeyError as e:
+            return f"""KeyError occurred: {e}. This is most likely because the model you are using
+            returned a JSON object we were not expecting."""
+
+    return wrapper
+
+
+class LanguageModel(
+    RichPrintingMixin, PersistenceMixin, ABC, metaclass=RegisterLanguageModelsMeta
+):
+    """ABC for LLM subclasses.
+
+    TODO:
+
+    1) Need better, more descriptive names for functions
+
+    get_model_response_no_cache (currently called async_execute_model_call)
+
+    get_model_response (currently called async_get_raw_response; uses cache & adds tracking info)
+    Calls:
+        - async_execute_model_call
+        - _updated_model_response_with_tracking
+
+    get_answer (currently called async_get_response)
+    This parses out the answer block and does some error-handling.
+    Calls:
+        - async_get_raw_response
+        - parse_response
+
+
+    """
+
+    _model_ = None
+    key_sequence = (
+        None # This should be something like ["choices", 0, "message", "content"]
+    )
+    __rate_limits = None
+    _safety_factor = 0.8
+
+    def __init__(
+        self,
+        tpm: float = None,
+        rpm: float = None,
+        omit_system_prompt_if_empty_string: bool = True,
+        key_lookup: Optional[KeyLookup] = None,
+        **kwargs,
+    ):
+        """Initialize the LanguageModel."""
+        self.model = getattr(self, "_model_", None)
+        default_parameters = getattr(self, "_parameters_", None)
+        parameters = self._overide_default_parameters(kwargs, default_parameters)
+        self.parameters = parameters
+        self.remote = False
+        self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
+
+        # self._rpm / _tpm comes from the class
+        if rpm is not None:
+            self._rpm = rpm
+
+        if tpm is not None:
+            self._tpm = tpm
+
+        for key, value in parameters.items():
+            setattr(self, key, value)
+
+        for key, value in kwargs.items():
+            if key not in parameters:
+                setattr(self, key, value)
+
+        if "use_cache" in kwargs:
+            warnings.warn(
+                "The use_cache parameter is deprecated. Use the Cache class instead."
+            )
+
+        if skip_api_key_check := kwargs.get("skip_api_key_check", False):
+            # Skip the API key check. Sometimes this is useful for testing.
+            self._api_token = None
+
+        if key_lookup is not None:
+            self.key_lookup = key_lookup
+        else:
+            self.key_lookup = KeyLookup.from_os_environ()
+
+    def ask_question(self, question):
+        user_prompt = question.get_instructions().render(question.data).text
+        system_prompt = "You are a helpful agent pretending to be a human."
+        return self.execute_model_call(user_prompt, system_prompt)
+
+    def set_key_lookup(self, key_lookup: KeyLookup):
+        del self._api_token
+        self.key_lookup = key_lookup
+
+    @property
+    def api_token(self) -> str:
+        if not hasattr(self, "_api_token"):
+            self._api_token = self.key_lookup.get_api_token(
+                self._inference_service_, self.remote
+            )
+        return self._api_token
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def _repr_html_(self):
+        from edsl.utilities.utilities import data_to_html
+
+        return data_to_html(self.to_dict())
+
+    def hello(self, verbose=False):
+        """Runs a simple test to check if the model is working."""
+        token = self.api_token
+        masked = token[: min(8, len(token))] + "..."
+        if verbose:
+            print(f"Current key is {masked}")
+        return self.execute_model_call(
+            user_prompt="Hello, model!", system_prompt="You are a helpful agent."
+        )
+
+    def has_valid_api_key(self) -> bool:
+        """Check if the model has a valid API key.
+
+        >>> LanguageModel.example().has_valid_api_key() : # doctest: +SKIP
+        True
+
+        This method is used to check if the model has a valid API key.
+        """
+        from edsl.enums import service_to_api_keyname
+        import os
+
+        if self._model_ == "test":
+            return True
+
+        key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
+        key_value = os.getenv(key_name)
+        return key_value is not None
+
+    def __hash__(self) -> str:
+        """Allow the model to be used as a key in a dictionary."""
+        from edsl.utilities.utilities import dict_hash
+
+        return dict_hash(self.to_dict())
+
+    def __eq__(self, other):
+        """Check is two models are the same.
+
+        >>> m1 = LanguageModel.example()
+        >>> m2 = LanguageModel.example()
+        >>> m1 == m2
+        True
+
+        """
+        return self.model == other.model and self.parameters == other.parameters
+
+    def set_rate_limits(self, rpm=None, tpm=None) -> None:
+        """Set the rate limits for the model.
+
+        >>> m = LanguageModel.example()
+        >>> m.set_rate_limits(rpm=100, tpm=1000)
+        >>> m.RPM
+        100
+        """
+        if rpm is not None:
+            self._rpm = rpm
+        if tpm is not None:
+            self._tpm = tpm
+        return None
+
+    @property
+    def RPM(self):
+        """Model's requests-per-minute limit."""
+        # self._set_rate_limits()
+        # return self._safety_factor * self.__rate_limits["rpm"]
+        return self._rpm
+
+    @property
+    def TPM(self):
+        """Model's tokens-per-minute limit."""
+        # self._set_rate_limits()
+        # return self._safety_factor * self.__rate_limits["tpm"]
+        return self._tpm
+
+    @property
+    def rpm(self):
+        return self._rpm
+
+    @rpm.setter
+    def rpm(self, value):
+        self._rpm = value
+
+    @property
+    def tpm(self):
+        return self._tpm
+
+    @tpm.setter
+    def tpm(self, value):
+        self._tpm = value
+
+    @staticmethod
+    def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
+        """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
+
+        >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9})
+        {'temperature': 0.5}
+        >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
+        {'temperature': 0.5, 'max_tokens': 1000}
+        """
+        # parameters = dict({})
+
+        # this is the case when data is loaded from a dict after serialization
+        if "parameters" in passed_parameter_dict:
+            passed_parameter_dict = passed_parameter_dict["parameters"]
+        return {
+            parameter_name: passed_parameter_dict.get(parameter_name, default_value)
+            for parameter_name, default_value in default_parameter_dict.items()
+        }
+
+    def __call__(self, user_prompt: str, system_prompt: str):
+        return self.execute_model_call(user_prompt, system_prompt)
+
+    @abstractmethod
+    async def async_execute_model_call(user_prompt: str, system_prompt: str):
+        """Execute the model call and returns a coroutine.
+
+        >>> m = LanguageModel.example(test_model = True)
+        >>> async def test(): return await m.async_execute_model_call("Hello, model!", "You are a helpful agent.")
+        >>> asyncio.run(test())
+        {'message': [{'text': 'Hello world'}], ...}
+
+        >>> m.execute_model_call("Hello, model!", "You are a helpful agent.")
+        {'message': [{'text': 'Hello world'}], ...}
+        """
+        pass
+
+    async def remote_async_execute_model_call(
+        self, user_prompt: str, system_prompt: str
+    ):
+        """Execute the model call and returns the result as a coroutine, using Coop."""
+        from edsl.coop import Coop
+
+        client = Coop()
+        response_data = await client.remote_async_execute_model_call(
+            self.to_dict(), user_prompt, system_prompt
+        )
+        return response_data
+
+    @jupyter_nb_handler
+    def execute_model_call(self, *args, **kwargs) -> Coroutine:
+        """Execute the model call and returns the result as a coroutine.
+
+        >>> m = LanguageModel.example(test_model = True)
+        >>> m.execute_model_call(user_prompt = "Hello, model!", system_prompt = "You are a helpful agent.")
+
+        """
+
+        async def main():
+            results = await asyncio.gather(
+                self.async_execute_model_call(*args, **kwargs)
+            )
+            return results[0] # Since there's only one task, return its result
+
+        return main()
+
+    @classmethod
+    def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
+        """Return the generated token string from the raw response."""
+        return extract_item_from_raw_response(raw_response, cls.key_sequence)
+
+    @classmethod
+    def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
+        """Return the usage dictionary from the raw response."""
+        if not hasattr(cls, "usage_sequence"):
+            raise NotImplementedError(
+                "This inference service does not have a usage_sequence."
+            )
+        return extract_item_from_raw_response(raw_response, cls.usage_sequence)
+
+    @staticmethod
+    def convert_answer(response_part):
+        import json
+
+        response_part = response_part.strip()
+
+        if response_part == "None":
+            return None
+
+        repaired = repair_json(response_part)
+        if repaired == '""':
+            # it was a literal string
+            return response_part
+
+        try:
+            return json.loads(repaired)
+        except json.JSONDecodeError as j:
+            # last resort
+            return response_part
+
+    @classmethod
+    def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
+        """Parses the API response and returns the response text."""
+        generated_token_string = cls.get_generated_token_string(raw_response)
+        last_newline = generated_token_string.rfind("\n")
+
+        if last_newline == -1:
+            # There is no comment
+            edsl_dict = {
+                "answer": cls.convert_answer(generated_token_string),
+                "generated_tokens": generated_token_string,
+                "comment": None,
+            }
+        else:
+            edsl_dict = {
+                "answer": cls.convert_answer(generated_token_string[:last_newline]),
+                "comment": generated_token_string[last_newline + 1 :].strip(),
+                "generated_tokens": generated_token_string,
+            }
+        return EDSLOutput(**edsl_dict)
+
+    async def _async_get_intended_model_call_outcome(
+        self,
+        user_prompt: str,
+        system_prompt: str,
+        cache: "Cache",
+        iteration: int = 0,
+        files_list=None,
+    ) -> ModelResponse:
+        """Handle caching of responses.
+
+        :param user_prompt: The user's prompt.
+        :param system_prompt: The system's prompt.
+        :param iteration: The iteration number.
+        :param cache: The cache to use.
+
+        If the cache isn't being used, it just returns a 'fresh' call to the LLM.
+        But if cache is being used, it first checks the database to see if the response is already there.
+        If it is, it returns the cached response, but again appends some tracking information.
+        If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
+
+        If self.use_cache is True, then attempts to retrieve the response from the database;
+        if not in the DB, calls the LLM and writes the response to the DB.
+
+        >>> from edsl import Cache
+        >>> m = LanguageModel.example(test_model = True)
+        >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
+        ModelResponse(...)"""
+
+        if files_list:
+            files_hash = "+".join([str(hash(file)) for file in files_list])
+            # print(f"Files hash: {files_hash}")
+            user_prompt_with_hashes = user_prompt + f" {files_hash}"
+        else:
+            user_prompt_with_hashes = user_prompt
+
+        cache_call_params = {
+            "model": str(self.model),
+            "parameters": self.parameters,
+            "system_prompt": system_prompt,
+            "user_prompt": user_prompt_with_hashes,
+            "iteration": iteration,
+        }
+        cached_response, cache_key = cache.fetch(**cache_call_params)
+
+        if cache_used := cached_response is not None:
+            response = json.loads(cached_response)
+        else:
+            f = (
+                self.remote_async_execute_model_call
+                if hasattr(self, "remote") and self.remote
+                else self.async_execute_model_call
+            )
+            params = {
+                "user_prompt": user_prompt,
+                "system_prompt": system_prompt,
+                "files_list": files_list,
+                # **({"encoded_image": encoded_image} if encoded_image else {}),
+            }
+            # response = await f(**params)
+            response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
+            new_cache_key = cache.store(
+                **cache_call_params, response=response
+            ) # store the response in the cache
+            assert new_cache_key == cache_key # should be the same
+
+        cost = self.cost(response)
+
+        return ModelResponse(
+            response=response,
+            cache_used=cache_used,
+            cache_key=cache_key,
+            cached_response=cached_response,
+            cost=cost,
+        )
+
+    _get_intended_model_call_outcome = sync_wrapper(
+        _async_get_intended_model_call_outcome
+    )
+
+    # get_raw_response = sync_wrapper(async_get_raw_response)
+
+    def simple_ask(
+        self,
+        question: "QuestionBase",
+        system_prompt="You are a helpful agent pretending to be a human.",
+        top_logprobs=2,
+    ):
+        """Ask a question and return the response."""
+        self.logprobs = True
+        self.top_logprobs = top_logprobs
+        return self.execute_model_call(
+            user_prompt=question.human_readable(), system_prompt=system_prompt
+        )
+
+    async def async_get_response(
+        self,
+        user_prompt: str,
+        system_prompt: str,
+        cache: "Cache",
+        iteration: int = 1,
+        files_list: Optional[List["File"]] = None,
+    ) -> dict:
+        """Get response, parse, and return as string.
+
+        :param user_prompt: The user's prompt.
+        :param system_prompt: The system's prompt.
+        :param iteration: The iteration number.
+        :param cache: The cache to use.
+        :param encoded_image: The encoded image to use.
+
+        """
+        params = {
+            "user_prompt": user_prompt,
+            "system_prompt": system_prompt,
+            "iteration": iteration,
+            "cache": cache,
+            "files_list": files_list,
+        }
+        model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
+        model_outputs = await self._async_get_intended_model_call_outcome(**params)
+        edsl_dict = self.parse_response(model_outputs.response)
+        agent_response_dict = AgentResponseDict(
+            model_inputs=model_inputs,
+            model_outputs=model_outputs,
+            edsl_dict=edsl_dict,
+        )
+        return agent_response_dict
+
+        # return await self._async_prepare_response(model_call_outcome, cache=cache)
+
+    get_response = sync_wrapper(async_get_response)
+
+    def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
+        """Return the dollar cost of a raw response."""
+
+        usage = self.get_usage_dict(raw_response)
+        from edsl.coop import Coop
+
+        c = Coop()
+        price_lookup = c.fetch_prices()
+        key = (self._inference_service_, self.model)
+        if key not in price_lookup:
+            return f"Could not find price for model {self.model} in the price lookup."
+
+        relevant_prices = price_lookup[key]
+        try:
+            input_tokens = int(usage[self.input_token_name])
+            output_tokens = int(usage[self.output_token_name])
+        except Exception as e:
+            return f"Could not fetch tokens from model response: {e}"
+
+        try:
+            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+        except Exception as e:
+            if "output" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'output' key."
+            if "input" not in relevant_prices:
+                return f"Could not fetch prices from {relevant_prices} - {e}; Missing 'input' key."
+            return f"Could not fetch prices from {relevant_prices} - {e}"
+
+        if inverse_input_price == "infinity":
+            input_cost = 0
+        else:
+            try:
+                input_cost = input_tokens / float(inverse_input_price)
+            except Exception as e:
+                return f"Could not compute input price - {e}."
+
+        if inverse_output_price == "infinity":
+            output_cost = 0
+        else:
+            try:
+                output_cost = output_tokens / float(inverse_output_price)
+            except Exception as e:
+                return f"Could not compute output price - {e}"
+
+        return input_cost + output_cost
+
+    #######################
+    # SERIALIZATION METHODS
+    #######################
+    def to_dict(self, add_edsl_version=True) -> dict[str, Any]:
+        """Convert instance to a dictionary
+
+        >>> m = LanguageModel.example()
+        >>> m.to_dict()
+        {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
+        """
+        d = {"model": self.model, "parameters": self.parameters}
+        if add_edsl_version:
+            from edsl import __version__
+
+            d["edsl_version"] = __version__
+            d["edsl_class_name"] = self.__class__.__name__
+        return d
+
+    @classmethod
+    @remove_edsl_version
+    def from_dict(cls, data: dict) -> Type[LanguageModel]:
+        """Convert dictionary to a LanguageModel child instance."""
+        from edsl.language_models.registry import get_model_class
+
+        model_class = get_model_class(data["model"])
+        # data["use_cache"] = True
+        return model_class(**data)
+
+    #######################
+    # DUNDER METHODS
+    #######################
+    def print(self):
+        from rich import print_json
+        import json
+
+        print_json(json.dumps(self.to_dict()))
+
+    def __repr__(self) -> str:
+        """Return a string representation of the object."""
+        param_string = ", ".join(
+            f"{key} = {value}" for key, value in self.parameters.items()
+        )
+        return (
+            f"Model(model_name = '{self.model}'"
+            + (f", {param_string}" if param_string else "")
+            + ")"
+        )
+
+    def __add__(self, other_model: Type[LanguageModel]) -> Type[LanguageModel]:
+        """Combine two models into a single model (other_model takes precedence over self)."""
+        print(
+            f"""Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
+            by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
+        )
+        return other_model or self
+
+    def rich_print(self):
+        """Display an object as a table."""
+        from rich.table import Table
+
+        table = Table(title="Language Model")
+        table.add_column("Attribute", style="bold")
+        table.add_column("Value")
+
+        to_display = self.__dict__.copy()
+        for attr_name, attr_value in to_display.items():
+            table.add_row(attr_name, repr(attr_value))
+
+        return table
+
+    @classmethod
+    def example(
+        cls,
+        test_model: bool = False,
+        canned_response: str = "Hello world",
+        throw_exception: bool = False,
+    ):
+        """Return a default instance of the class.
+
+        >>> from edsl.language_models import LanguageModel
+        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
+        >>> isinstance(m, LanguageModel)
+        True
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
+        >>> q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True).select('example').first()
+        'WOWZA!'
+        """
+        from edsl import Model
+
+        if test_model:
+            m = Model("test", canned_response=canned_response)
+            return m
+        else:
+            return Model(skip_api_key_check=True)
+
+
+if __name__ == "__main__":
+    """Run the module's test suite."""
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)