edsl 0.1.36.dev5__py3-none-any.whl → 0.1.36.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +47 -47
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +804 -804
- edsl/agents/AgentList.py +337 -337
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +294 -294
- edsl/agents/PromptConstructor.py +312 -312
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +86 -86
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +289 -289
- edsl/config.py +149 -149
- edsl/conjure/AgentConstructionMixin.py +152 -152
- edsl/conjure/Conjure.py +62 -62
- edsl/conjure/InputData.py +659 -659
- edsl/conjure/InputDataCSV.py +48 -48
- edsl/conjure/InputDataMixinQuestionStats.py +182 -182
- edsl/conjure/InputDataPyRead.py +91 -91
- edsl/conjure/InputDataSPSS.py +8 -8
- edsl/conjure/InputDataStata.py +8 -8
- edsl/conjure/QuestionOptionMixin.py +76 -76
- edsl/conjure/QuestionTypeMixin.py +23 -23
- edsl/conjure/RawQuestion.py +65 -65
- edsl/conjure/SurveyResponses.py +7 -7
- edsl/conjure/__init__.py +9 -9
- edsl/conjure/naming_utilities.py +263 -263
- edsl/conjure/utilities.py +201 -201
- edsl/conversation/Conversation.py +238 -238
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +849 -849
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +527 -527
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +83 -83
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +40 -40
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +26 -26
- edsl/exceptions/surveys.py +34 -34
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +115 -115
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +72 -68
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -94
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1112 -1112
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +248 -248
- edsl/jobs/interviews/Interview.py +651 -651
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +182 -182
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +337 -337
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +441 -441
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/LanguageModel.py +718 -718
- edsl/language_models/ModelList.py +102 -102
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +2 -2
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +259 -259
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +358 -358
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +616 -616
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +266 -266
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +113 -113
- edsl/questions/QuestionFunctional.py +159 -159
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +91 -91
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +418 -418
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +693 -693
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +433 -433
- edsl/results/Results.py +1158 -1158
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +118 -118
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +443 -443
- edsl/scenarios/Scenario.py +507 -507
- edsl/scenarios/ScenarioHtmlMixin.py +59 -59
- edsl/scenarios/ScenarioList.py +1101 -1101
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +2 -2
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +324 -324
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1772 -1772
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +47 -47
- edsl/surveys/instructions/Instruction.py +51 -51
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +391 -391
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev6.dist-info}/LICENSE +21 -21
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev6.dist-info}/METADATA +1 -1
- edsl-0.1.36.dev6.dist-info/RECORD +279 -0
- edsl-0.1.36.dev5.dist-info/RECORD +0 -279
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev6.dist-info}/WHEEL +0 -0
edsl/inference_services/OpenAIService.py @@ -1,224 +1,224 @@

Every line of this hunk is removed and re-added with identical content; the file is shown once below.

```python
from __future__ import annotations
from typing import Any, List, Optional
import os

import openai

from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
from edsl.language_models import LanguageModel
from edsl.inference_services.rate_limits_cache import rate_limits
from edsl.utilities.utilities import fix_partial_correct_response

from edsl.config import CONFIG


class OpenAIService(InferenceServiceABC):
    """OpenAI service class."""

    _inference_service_ = "openai"
    _env_key_name_ = "OPENAI_API_KEY"
    _base_url_ = None

    _sync_client_ = openai.OpenAI
    _async_client_ = openai.AsyncOpenAI

    _sync_client_instance = None
    _async_client_instance = None

    key_sequence = ["choices", 0, "message", "content"]
    usage_sequence = ["usage"]
    input_token_name = "prompt_tokens"
    output_token_name = "completion_tokens"

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # so subclasses have to create their own instances of the clients
        cls._sync_client_instance = None
        cls._async_client_instance = None

    @classmethod
    def sync_client(cls):
        if cls._sync_client_instance is None:
            cls._sync_client_instance = cls._sync_client_(
                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
            )
        return cls._sync_client_instance

    @classmethod
    def async_client(cls):
        if cls._async_client_instance is None:
            cls._async_client_instance = cls._async_client_(
                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
            )
        return cls._async_client_instance

    model_exclude_list = [
        "whisper-1",
        "davinci-002",
        "dall-e-2",
        "tts-1-hd-1106",
        "tts-1-hd",
        "dall-e-3",
        "tts-1",
        "babbage-002",
        "tts-1-1106",
        "text-embedding-3-large",
        "text-embedding-3-small",
        "text-embedding-ada-002",
        "ft:davinci-002:mit-horton-lab::8OfuHgoo",
        "gpt-3.5-turbo-instruct-0914",
        "gpt-3.5-turbo-instruct",
    ]
    _models_list_cache: List[str] = []

    @classmethod
    def get_model_list(cls):
        raw_list = cls.sync_client().models.list()
        if hasattr(raw_list, "data"):
            return raw_list.data
        else:
            return raw_list

    @classmethod
    def available(cls) -> List[str]:
        if not cls._models_list_cache:
            try:
                cls._models_list_cache = [
                    m.id
                    for m in cls.get_model_list()
                    if m.id not in cls.model_exclude_list
                ]
            except Exception as e:
                raise
        return cls._models_list_cache

    @classmethod
    def create_model(cls, model_name, model_class_name=None) -> LanguageModel:
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        class LLM(LanguageModel):
            """
            Child class of LanguageModel for interacting with OpenAI models
            """

            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name

            _rpm = cls.get_rpm(cls)
            _tpm = cls.get_tpm(cls)

            _inference_service_ = cls._inference_service_
            _model_ = model_name
            _parameters_ = {
                "temperature": 0.5,
                "max_tokens": 1000,
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
                "logprobs": False,
                "top_logprobs": 3,
            }

            def sync_client(self):
                return cls.sync_client()

            def async_client(self):
                return cls.async_client()

            @classmethod
            def available(cls) -> list[str]:
                return cls.sync_client().models.list()

            def get_headers(self) -> dict[str, Any]:
                client = self.sync_client()
                response = client.chat.completions.with_raw_response.create(
                    messages=[
                        {
                            "role": "user",
                            "content": "Say this is a test",
                        }
                    ],
                    model=self.model,
                )
                return dict(response.headers)

            def get_rate_limits(self) -> dict[str, Any]:
                try:
                    if "openai" in rate_limits:
                        headers = rate_limits["openai"]

                    else:
                        headers = self.get_headers()

                except Exception as e:
                    return {
                        "rpm": 10_000,
                        "tpm": 2_000_000,
                    }
                else:
                    return {
                        "rpm": int(headers["x-ratelimit-limit-requests"]),
                        "tpm": int(headers["x-ratelimit-limit-tokens"]),
                    }

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional[List["Files"]] = None,
                invigilator: Optional[
                    "InvigilatorAI"
                ] = None,  # TBD - can eventually be used for function-calling
            ) -> dict[str, Any]:
                """Calls the OpenAI API and returns the API response."""
                if files_list:
                    encoded_image = files_list[0].base64_string
                    content = [{"type": "text", "text": user_prompt}]
                    content.append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{encoded_image}"
                            },
                        }
                    )
                else:
                    content = user_prompt
                client = self.async_client()

                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": content},
                ]
                if (
                    system_prompt == "" and self.omit_system_prompt_if_empty
                ) or "o1" in self.model:
                    messages = messages[1:]

                params = {
                    "model": self.model,
                    "messages": messages,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                    "top_p": self.top_p,
                    "frequency_penalty": self.frequency_penalty,
                    "presence_penalty": self.presence_penalty,
                    "logprobs": self.logprobs,
                    "top_logprobs": self.top_logprobs if self.logprobs else None,
                }
                if "o1" in self.model:
                    params.pop("max_tokens")
                    params["max_completion_tokens"] = self.max_tokens
                    params["temperature"] = 1
                try:
                    response = await client.chat.completions.create(**params)
                except Exception as e:
                    print(e)
                return response.model_dump()

        LLM.__name__ = "LanguageModel"

        return LLM
```
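For orientation, a minimal sketch of how the factory above might be exercised. Only `available()`, `create_model()`, and `get_rate_limits()` come from the hunk; the model name and the no-argument instantiation are illustrative assumptions, not part of this diff.

```python
# Hypothetical usage sketch for the OpenAIService factory shown above.
# Assumes OPENAI_API_KEY is set in the environment.
from edsl.inference_services.OpenAIService import OpenAIService

print(OpenAIService.available()[:5])        # model ids minus model_exclude_list

LLM = OpenAIService.create_model("gpt-4o")  # model name is an assumption
model = LLM()                               # instantiation signature assumed
print(model.get_rate_limits())              # {"rpm": ..., "tpm": ...}
```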
edsl/inference_services/TestService.py @@ -1,89 +1,89 @@

As above, every line is removed and re-added with identical content; the file is shown once below.

```python
from typing import Any, List, Optional
import os
import asyncio
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
from edsl.language_models import LanguageModel
from edsl.inference_services.rate_limits_cache import rate_limits
from edsl.utilities.utilities import fix_partial_correct_response

from edsl.enums import InferenceServiceType
import random


class TestService(InferenceServiceABC):
    """OpenAI service class."""

    _inference_service_ = "test"
    _env_key_name_ = None
    _base_url_ = None

    _sync_client_ = None
    _async_client_ = None

    _sync_client_instance = None
    _async_client_instance = None

    key_sequence = None
    usage_sequence = None
    model_exclude_list = []
    input_token_name = "prompt_tokens"
    output_token_name = "completion_tokens"

    @classmethod
    def available(cls) -> list[str]:
        return ["test"]

    @classmethod
    def create_model(cls, model_name, model_class_name=None) -> LanguageModel:
        throw_exception = False

        class TestServiceLanguageModel(LanguageModel):
            _model_ = "test"
            _parameters_ = {"temperature": 0.5}
            _inference_service_ = InferenceServiceType.TEST.value
            usage_sequence = ["usage"]
            key_sequence = ["message", 0, "text"]
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name
            _rpm = 1000
            _tpm = 100000

            @property
            def _canned_response(self):
                if hasattr(self, "canned_response"):
                    return self.canned_response
                else:
                    return "Hello, world"

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str,
                # func: Optional[callable] = None,
                files_list: Optional[List["File"]] = None,
            ) -> dict[str, Any]:
                await asyncio.sleep(0.1)
                # return {"message": """{"answer": "Hello, world"}"""}

                if hasattr(self, "func"):
                    return {
                        "message": [
                            {"text": self.func(user_prompt, system_prompt, files_list)}
                        ],
                        "usage": {"prompt_tokens": 1, "completion_tokens": 1},
                    }

                if hasattr(self, "throw_exception") and self.throw_exception:
                    if hasattr(self, "exception_probability"):
                        p = self.exception_probability
                    else:
                        p = 1

                    if random.random() < p:
                        raise Exception("This is a test error")
                return {
                    "message": [{"text": f"{self._canned_response}"}],
                    "usage": {"prompt_tokens": 1, "completion_tokens": 1},
                }

        return TestServiceLanguageModel
```
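A minimal sketch of exercising the test service above: the generated model sleeps briefly and returns a canned payload with no network call. The no-argument instantiation is an assumption, not from this diff.

```python
# Hypothetical usage sketch for TestService shown above.
import asyncio
from edsl.inference_services.TestService import TestService

TestModel = TestService.create_model("test")
m = TestModel()  # assumed default construction
raw = asyncio.run(m.async_execute_model_call("Hi there", system_prompt=""))
print(raw["message"][0]["text"])  # -> "Hello, world" unless canned_response is set
```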