edsl 0.1.32__py3-none-any.whl → 0.1.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +9 -3
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -3
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +40 -8
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +135 -219
- edsl/agents/InvigilatorBase.py +148 -59
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +138 -89
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +47 -56
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +50 -7
- edsl/data/Cache.py +35 -1
- edsl/data_transfer_models.py +73 -38
- edsl/enums.py +4 -0
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +19 -17
- edsl/inference_services/AzureAI.py +37 -20
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +2 -0
- edsl/inference_services/InferenceServiceABC.py +58 -3
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OpenAIService.py +48 -54
- edsl/inference_services/TestService.py +80 -0
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/models_available_cache.py +0 -6
- edsl/inference_services/registry.py +6 -0
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +37 -22
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +93 -14
- edsl/jobs/interviews/Interview.py +366 -78
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +14 -68
- edsl/jobs/interviews/InterviewExceptionEntry.py +85 -19
- edsl/jobs/runners/JobsRunnerAsyncio.py +146 -175
- edsl/jobs/runners/JobsRunnerStatus.py +331 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskHistory.py +148 -213
- edsl/language_models/LanguageModel.py +261 -156
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +23 -6
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/notebooks/Notebook.py +20 -2
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +330 -249
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +99 -41
- edsl/questions/QuestionCheckBox.py +227 -35
- edsl/questions/QuestionExtract.py +98 -27
- edsl/questions/QuestionFreeText.py +52 -29
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -22
- edsl/questions/QuestionMultipleChoice.py +159 -65
- edsl/questions/QuestionNumerical.py +88 -46
- edsl/questions/QuestionRank.py +182 -24
- edsl/questions/Quick.py +41 -0
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +170 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +15 -2
- edsl/questions/derived/QuestionTopK.py +10 -1
- edsl/questions/derived/QuestionYesNo.py +24 -3
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +46 -48
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +135 -46
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/FileStore.py +71 -10
- edsl/scenarios/Scenario.py +96 -25
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +361 -39
- edsl/scenarios/ScenarioListExportMixin.py +9 -0
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/SnapShot.py +8 -1
- edsl/study/Study.py +32 -0
- edsl/surveys/Rule.py +10 -1
- edsl/surveys/RuleCollection.py +21 -5
- edsl/surveys/Survey.py +637 -311
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +75 -4
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/METADATA +5 -2
- edsl-0.1.33.dist-info/RECORD +295 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
- edsl/jobs/interviews/retry_management.py +0 -37
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.32.dist-info/RECORD +0 -209
- {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
- {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
edsl/inference_services/AnthropicService.py

@@ -11,6 +11,11 @@ class AnthropicService(InferenceServiceABC):
 
     _inference_service_ = "anthropic"
     _env_key_name_ = "ANTHROPIC_API_KEY"
+    key_sequence = ["content", 0, "text"]  # ["content"][0]["text"]
+    usage_sequence = ["usage"]
+    input_token_name = "input_tokens"
+    output_token_name = "output_tokens"
+    model_exclude_list = []
 
     @classmethod
     def available(cls):
@@ -34,6 +39,11 @@ class AnthropicService(InferenceServiceABC):
             Child class of LanguageModel for interacting with OpenAI models
             """
 
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
             _inference_service_ = cls._inference_service_
             _model_ = model_name
             _parameters_ = {
@@ -46,6 +56,9 @@ class AnthropicService(InferenceServiceABC):
                 "top_logprobs": 3,
             }
 
+            _tpm = cls.get_tpm(cls)
+            _rpm = cls.get_rpm(cls)
+
             async def async_execute_model_call(
                 self, user_prompt: str, system_prompt: str = ""
             ) -> dict[str, Any]:
@@ -66,17 +79,6 @@ class AnthropicService(InferenceServiceABC):
                 )
                 return response.model_dump()
 
-            @staticmethod
-            def parse_response(raw_response: dict[str, Any]) -> str:
-                """Parses the API response and returns the response text."""
-                response = raw_response["content"][0]["text"]
-                pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
-                match = re.match(pattern, response, re.DOTALL)
-                if match:
-                    return match.group(1)
-                else:
-                    return response
-
         LLM.__name__ = model_class_name
 
         return LLM
edsl/inference_services/AwsBedrock.py

@@ -16,6 +16,18 @@ class AwsBedrockService(InferenceServiceABC):
     _env_key_name_ = (
         "AWS_ACCESS_KEY_ID"  # or any other environment key for AWS credentials
     )
+    key_sequence = ["output", "message", "content", 0, "text"]
+    input_token_name = "inputTokens"
+    output_token_name = "outputTokens"
+    usage_sequence = ["usage"]
+    model_exclude_list = [
+        "ai21.j2-grande-instruct",
+        "ai21.j2-jumbo-instruct",
+        "ai21.j2-mid",
+        "ai21.j2-mid-v1",
+        "ai21.j2-ultra",
+        "ai21.j2-ultra-v1",
+    ]
 
     @classmethod
     def available(cls):
@@ -28,7 +40,7 @@ class AwsBedrockService(InferenceServiceABC):
         else:
             all_models_ids = cls._models_list_cache
 
-        return all_models_ids
+        return [m for m in all_models_ids if m not in cls.model_exclude_list]
 
     @classmethod
     def create_model(
@@ -42,6 +54,8 @@ class AwsBedrockService(InferenceServiceABC):
             Child class of LanguageModel for interacting with AWS Bedrock models.
             """
 
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
             _inference_service_ = cls._inference_service_
             _model_ = model_name
             _parameters_ = {
@@ -49,6 +63,10 @@ class AwsBedrockService(InferenceServiceABC):
                 "max_tokens": 512,
                 "top_p": 0.9,
             }
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+            _rpm = cls.get_rpm(cls)
+            _tpm = cls.get_tpm(cls)
 
             async def async_execute_model_call(
                 self, user_prompt: str, system_prompt: str = ""
@@ -89,22 +107,6 @@ class AwsBedrockService(InferenceServiceABC):
                 print(e)
                 return {"error": str(e)}
 
-            @staticmethod
-            def parse_response(raw_response: dict[str, Any]) -> str:
-                """Parses the API response and returns the response text."""
-                if "output" in raw_response and "message" in raw_response["output"]:
-                    response = raw_response["output"]["message"]["content"][0]["text"]
-                    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
-                    match = re.match(pattern, response, re.DOTALL)
-                    if match:
-                        return match.group(1)
-                    else:
-                        out = fix_partial_correct_response(response)
-                        if "error" not in out:
-                            response = out["extracted_json"]
-                        return response
-                return "Error parsing response"
-
         LLM.__name__ = model_class_name
 
         return LLM
edsl/inference_services/AzureAI.py

@@ -25,11 +25,22 @@ def json_handle_none(value: Any) -> Any:
 class AzureAIService(InferenceServiceABC):
     """Azure AI service class."""
 
+    # key_sequence = ["content", 0, "text"]  # ["content"][0]["text"]
+    key_sequence = ["choices", 0, "message", "content"]
+    usage_sequence = ["usage"]
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
     _inference_service_ = "azure"
     _env_key_name_ = (
         "AZURE_ENDPOINT_URL_AND_KEY"  # Environment variable for Azure API key
     )
     _model_id_to_endpoint_and_key = {}
+    model_exclude_list = [
+        "Cohere-command-r-plus-xncmg",
+        "Mistral-Nemo-klfsi",
+        "Mistral-large-2407-ojfld",
+    ]
 
     @classmethod
     def available(cls):
@@ -82,7 +93,7 @@ class AzureAIService(InferenceServiceABC):
 
         except Exception as e:
             raise e
-        return out
+        return [m for m in out if m not in cls.model_exclude_list]
 
     @classmethod
     def create_model(
@@ -96,6 +107,10 @@ class AzureAIService(InferenceServiceABC):
             Child class of LanguageModel for interacting with Azure OpenAI models.
             """
 
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
             _inference_service_ = cls._inference_service_
             _model_ = model_name
             _parameters_ = {
@@ -103,6 +118,8 @@ class AzureAIService(InferenceServiceABC):
                 "max_tokens": 512,
                 "top_p": 0.9,
             }
+            _rpm = cls.get_rpm(cls)
+            _tpm = cls.get_tpm(cls)
 
             async def async_execute_model_call(
                 self, user_prompt: str, system_prompt: str = ""
@@ -172,25 +189,25 @@ class AzureAIService(InferenceServiceABC):
             )
             return response.model_dump()
 
-            @staticmethod
-            def parse_response(raw_response: dict[str, Any]) -> str:
-                """Parses the API response and returns the response text."""
-                if (
-                    raw_response
-                    and "choices" in raw_response
-                    and raw_response["choices"]
-                ):
-                    response = raw_response["choices"][0]["message"]["content"]
-                    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
-                    match = re.match(pattern, response, re.DOTALL)
-                    if match:
-                        return match.group(1)
-                    else:
-                        out = fix_partial_correct_response(response)
-                        if "error" not in out:
-                            response = out["extracted_json"]
-                        return response
-                return "Error parsing response"
+            # @staticmethod
+            # def parse_response(raw_response: dict[str, Any]) -> str:
+            #     """Parses the API response and returns the response text."""
+            #     if (
+            #         raw_response
+            #         and "choices" in raw_response
+            #         and raw_response["choices"]
+            #     ):
+            #         response = raw_response["choices"][0]["message"]["content"]
+            #         pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+            #         match = re.match(pattern, response, re.DOTALL)
+            #         if match:
+            #             return match.group(1)
+            #         else:
+            #             out = fix_partial_correct_response(response)
+            #             if "error" not in out:
+            #                 response = out["extracted_json"]
+            #             return response
+            #     return "Error parsing response"
 
         LLM.__name__ = model_class_name
 
edsl/inference_services/GoogleService.py

@@ -10,10 +10,16 @@ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 
 class GoogleService(InferenceServiceABC):
     _inference_service_ = "google"
+    key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
+    usage_sequence = ["usageMetadata"]
+    input_token_name = "promptTokenCount"
+    output_token_name = "candidatesTokenCount"
+
+    model_exclude_list = []
 
     @classmethod
     def available(cls):
-        return ["gemini-pro"]
+        return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
 
     @classmethod
     def create_model(
@@ -24,7 +30,15 @@ class GoogleService(InferenceServiceABC):
 
         class LLM(LanguageModel):
             _model_ = model_name
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
             _inference_service_ = cls._inference_service_
+
+            _tpm = cls.get_tpm(cls)
+            _rpm = cls.get_rpm(cls)
+
             _parameters_ = {
                 "temperature": 0.5,
                 "topP": 1,
@@ -50,7 +64,7 @@ class GoogleService(InferenceServiceABC):
                         "stopSequences": self.stopSequences,
                     },
                 }
-
+                # print(combined_prompt)
                 async with aiohttp.ClientSession() as session:
                     async with session.post(
                         url, headers=headers, data=json.dumps(data)
@@ -58,16 +72,6 @@ class GoogleService(InferenceServiceABC):
                         raw_response_text = await response.text()
                         return json.loads(raw_response_text)
 
-            def parse_response(self, raw_response: dict[str, Any]) -> str:
-                data = raw_response
-                try:
-                    return data["candidates"][0]["content"]["parts"][0]["text"]
-                except KeyError as e:
-                    print(
-                        f"The data return was {data}, which was missing the key 'candidates'"
-                    )
-                    raise e
-
         LLM.__name__ = model_name
 
         return LLM
edsl/inference_services/GroqService.py

@@ -13,6 +13,8 @@ class GroqService(OpenAIService):
     _sync_client_ = groq.Groq
     _async_client_ = groq.AsyncGroq
 
+    model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]
+
     # _base_url_ = "https://api.deepinfra.com/v1/openai"
     _base_url_ = None
     _models_list_cache: List[str] = []
edsl/inference_services/InferenceServiceABC.py

@@ -1,22 +1,77 @@
 from abc import abstractmethod, ABC
-
+import os
 import re
+from edsl.config import CONFIG
 
 
 class InferenceServiceABC(ABC):
-    """
+    """
+    Abstract class for inference services.
+    Anthropic: https://docs.anthropic.com/en/api/rate-limits
+    """
+
+    default_levels = {
+        "google": {"tpm": 2_000_000, "rpm": 15},
+        "openai": {"tpm": 2_000_000, "rpm": 10_000},
+        "anthropic": {"tpm": 2_000_000, "rpm": 500},
+    }
+
+    def __init_subclass__(cls):
+        """
+        Check that the subclass has the required attributes.
+        - `key_sequence` attribute determines...
+        - `model_exclude_list` attribute determines...
+        """
+        if not hasattr(cls, "key_sequence"):
+            raise NotImplementedError(
+                f"Class {cls.__name__} must have a 'key_sequence' attribute."
+            )
+        if not hasattr(cls, "model_exclude_list"):
+            raise NotImplementedError(
+                f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
+            )
+
+    @classmethod
+    def _get_limt(cls, limit_type: str) -> int:
+        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
+        if key in os.environ:
+            return int(os.getenv(key))
+
+        if cls._inference_service_ in cls.default_levels:
+            return int(cls.default_levels[cls._inference_service_][limit_type])
+
+        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
+
+    def get_tpm(cls) -> int:
+        """
+        Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
+        """
+        return cls._get_limt(limit_type="tpm")
+
+    def get_rpm(cls):
+        """
+        Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
+        """
+        return cls._get_limt(limit_type="rpm")
 
     @abstractmethod
     def available() -> list[str]:
+        """
+        Returns a list of available models for the service.
+        """
        pass
 
     @abstractmethod
     def create_model():
+        """
+        Returns a LanguageModel object.
+        """
        pass
 
     @staticmethod
     def to_class_name(s):
-        """
+        """
+        Converts a string to a valid class name.
 
         >>> InferenceServiceABC.to_class_name("hello world")
         'HelloWorld'
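
The new _get_limt / get_tpm / get_rpm methods resolve a service's tokens-per-minute and requests-per-minute limits in a fixed order: a service-specific EDSL_SERVICE_{TPM|RPM}_{SERVICE} environment variable wins, then the hard-coded default_levels entry, then the EDSL_SERVICE_*_BASELINE value from CONFIG. A self-contained sketch of that resolution order (resolve_limit and BASELINE are illustrative stand-ins for edsl's internals):

import os

DEFAULT_LEVELS = {
    "google": {"tpm": 2_000_000, "rpm": 15},
    "openai": {"tpm": 2_000_000, "rpm": 10_000},
    "anthropic": {"tpm": 2_000_000, "rpm": 500},
}
BASELINE = {"tpm": 2_000_000, "rpm": 100}  # stand-in for CONFIG's EDSL_SERVICE_*_BASELINE values

def resolve_limit(service: str, limit_type: str) -> int:
    env_key = f"EDSL_SERVICE_{limit_type.upper()}_{service.upper()}"
    if env_key in os.environ:              # 1. explicit environment override
        return int(os.environ[env_key])
    if service in DEFAULT_LEVELS:          # 2. built-in per-service default
        return int(DEFAULT_LEVELS[service][limit_type])
    return int(BASELINE[limit_type])       # 3. global baseline from config

os.environ["EDSL_SERVICE_RPM_ANTHROPIC"] = "50"
assert resolve_limit("anthropic", "rpm") == 50  # env var beats the default of 500
assert resolve_limit("google", "rpm") == 15     # per-service default applies
assert resolve_limit("mistral", "rpm") == 100   # unknown service falls to the baseline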
edsl/inference_services/MistralAIService.py (new file)

@@ -0,0 +1,120 @@
+import os
+from typing import Any, List
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models.LanguageModel import LanguageModel
+import asyncio
+from mistralai import Mistral
+
+from edsl.exceptions.language_models import LanguageModelBadResponseError
+
+
+class MistralAIService(InferenceServiceABC):
+    """Mistral AI service class."""
+
+    key_sequence = ["choices", 0, "message", "content"]
+    usage_sequence = ["usage"]
+
+    _inference_service_ = "mistral"
+    _env_key_name_ = "MISTRAL_API_KEY"  # Environment variable for Mistral API key
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    _sync_client_instance = None
+    _async_client_instance = None
+
+    _sync_client = Mistral
+    _async_client = Mistral
+
+    _models_list_cache: List[str] = []
+    model_exclude_list = []
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        # so subclasses have to create their own instances of the clients
+        cls._sync_client_instance = None
+        cls._async_client_instance = None
+
+    @classmethod
+    def sync_client(cls):
+        if cls._sync_client_instance is None:
+            cls._sync_client_instance = cls._sync_client(
+                api_key=os.getenv(cls._env_key_name_)
+            )
+        return cls._sync_client_instance
+
+    @classmethod
+    def async_client(cls):
+        if cls._async_client_instance is None:
+            cls._async_client_instance = cls._async_client(
+                api_key=os.getenv(cls._env_key_name_)
+            )
+        return cls._async_client_instance
+
+    @classmethod
+    def available(cls) -> list[str]:
+        if not cls._models_list_cache:
+            cls._models_list_cache = [
+                m.id for m in cls.sync_client().models.list().data
+            ]
+
+        return cls._models_list_cache
+
+    @classmethod
+    def create_model(
+        cls, model_name: str = "mistral", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with Mistral models.
+            """
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 512,
+                "top_p": 0.9,
+            }
+
+            _tpm = cls.get_tpm(cls)
+            _rpm = cls.get_rpm(cls)
+
+            def sync_client(self):
+                return cls.sync_client()
+
+            def async_client(self):
+                return cls.async_client()
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str = ""
+            ) -> dict[str, Any]:
+                """Calls the Mistral API and returns the API response."""
+                s = self.async_client()
+
+                try:
+                    res = await s.chat.complete_async(
+                        model=model_name,
+                        messages=[
+                            {
+                                "content": user_prompt,
+                                "role": "user",
+                            },
+                        ],
+                    )
+                except Exception as e:
+                    raise LanguageModelBadResponseError(f"Error with Mistral API: {e}")
+
+                return res.model_dump()
+
+        LLM.__name__ = model_class_name
+
+        return LLM
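
This new service and the reworked OpenAIService below share the same client-caching pattern: one lazily constructed client per service class, with __init_subclass__ resetting the cache so subclasses (such as GroqService, which inherits from OpenAIService) build their own clients instead of reusing the parent's. A reduced sketch of the pattern (FakeClient stands in for Mistral / openai.OpenAI):

import os

class FakeClient:
    """Stand-in for a real API client such as openai.OpenAI or Mistral."""
    def __init__(self, api_key=None):
        self.api_key = api_key

class CachingService:
    _sync_client = FakeClient
    _sync_client_instance = None
    _env_key_name_ = "FAKE_API_KEY"

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._sync_client_instance = None  # each subclass gets its own cache slot

    @classmethod
    def sync_client(cls):
        if cls._sync_client_instance is None:  # construct once, then reuse
            cls._sync_client_instance = cls._sync_client(
                api_key=os.getenv(cls._env_key_name_)
            )
        return cls._sync_client_instance

class ChildService(CachingService):
    pass

a = CachingService.sync_client()
assert CachingService.sync_client() is a    # cached on the parent class
assert ChildService.sync_client() is not a  # subclass builds its own client

Without the __init_subclass__ reset, ChildService would inherit the parent's already-populated _sync_client_instance and silently share its client.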
edsl/inference_services/OpenAIService.py

@@ -1,8 +1,7 @@
-from 
-import 
+from __future__ import annotations
+from typing import Any, List, Optional
 import os
 
-# from openai import AsyncOpenAI
 import openai
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -10,6 +9,8 @@ from edsl.language_models import LanguageModel
 from edsl.inference_services.rate_limits_cache import rate_limits
 from edsl.utilities.utilities import fix_partial_correct_response
 
+from edsl.config import CONFIG
+
 
 class OpenAIService(InferenceServiceABC):
     """OpenAI service class."""
@@ -21,19 +22,36 @@ class OpenAIService(InferenceServiceABC):
     _sync_client_ = openai.OpenAI
     _async_client_ = openai.AsyncOpenAI
 
+    _sync_client_instance = None
+    _async_client_instance = None
+
+    key_sequence = ["choices", 0, "message", "content"]
+    usage_sequence = ["usage"]
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        # so subclasses have to create their own instances of the clients
+        cls._sync_client_instance = None
+        cls._async_client_instance = None
+
     @classmethod
     def sync_client(cls):
-        return cls._sync_client_(
-            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-        )
+        if cls._sync_client_instance is None:
+            cls._sync_client_instance = cls._sync_client_(
+                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+            )
+        return cls._sync_client_instance
 
     @classmethod
     def async_client(cls):
-        return cls._async_client_(
-            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-        )
+        if cls._async_client_instance is None:
+            cls._async_client_instance = cls._async_client_(
+                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+            )
+        return cls._async_client_instance
 
-    # TODO: Make this a coop call
     model_exclude_list = [
         "whisper-1",
         "davinci-002",
@@ -48,6 +66,8 @@ class OpenAIService(InferenceServiceABC):
         "text-embedding-3-small",
         "text-embedding-ada-002",
         "ft:davinci-002:mit-horton-lab::8OfuHgoo",
+        "gpt-3.5-turbo-instruct-0914",
+        "gpt-3.5-turbo-instruct",
     ]
     _models_list_cache: List[str] = []
 
@@ -61,11 +81,8 @@ class OpenAIService(InferenceServiceABC):
 
     @classmethod
     def available(cls) -> List[str]:
-        # from openai import OpenAI
-
         if not cls._models_list_cache:
             try:
-                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                 cls._models_list_cache = [
                     m.id
                     for m in cls.get_model_list()
@@ -73,15 +90,6 @@ class OpenAIService(InferenceServiceABC):
                 ]
             except Exception as e:
                 raise
-                # print(
-                #     f"""Error retrieving models: {e}.
-                #     See instructions about storing your API keys: https://docs.expectedparrot.com/en/latest/api_keys.html"""
-                # )
-                # cls._models_list_cache = [
-                #     "gpt-3.5-turbo",
-                #     "gpt-4-1106-preview",
-                #     "gpt-4",
-                # ]  # Fallback list
         return cls._models_list_cache
 
     @classmethod
@@ -94,6 +102,14 @@ class OpenAIService(InferenceServiceABC):
             Child class of LanguageModel for interacting with OpenAI models
             """
 
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
+            _rpm = cls.get_rpm(cls)
+            _tpm = cls.get_tpm(cls)
+
             _inference_service_ = cls._inference_service_
             _model_ = model_name
             _parameters_ = {
@@ -114,15 +130,9 @@ class OpenAIService(InferenceServiceABC):
 
             @classmethod
             def available(cls) -> list[str]:
-                # import openai
-                # client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
-                # return client.models.list()
                 return cls.sync_client().models.list()
 
             def get_headers(self) -> dict[str, Any]:
-                # from openai import OpenAI
-
-                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                 client = self.sync_client()
                 response = client.chat.completions.with_raw_response.create(
                     messages=[
@@ -159,6 +169,9 @@ class OpenAIService(InferenceServiceABC):
                 user_prompt: str,
                 system_prompt: str = "",
                 encoded_image=None,
+                invigilator: Optional[
+                    "InvigilatorAI"
+                ] = None,  # TBD - can eventually be used for function-calling
             ) -> dict[str, Any]:
                 """Calls the OpenAI API and returns the API response."""
                 if encoded_image:
@@ -173,17 +186,16 @@ class OpenAIService(InferenceServiceABC):
                     )
                 else:
                     content = user_prompt
-                # self.client = AsyncOpenAI(
-                #     api_key = os.getenv(cls._env_key_name_),
-                #     base_url = cls._base_url_
-                # )
                 client = self.async_client()
+                messages = [
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": content},
+                ]
+                if system_prompt == "" and self.omit_system_prompt_if_empty:
+                    messages = messages[1:]
                 params = {
                     "model": self.model,
-                    "messages": [
-                        {"role": "system", "content": system_prompt},
-                        {"role": "user", "content": content},
-                    ],
+                    "messages": messages,
                     "temperature": self.temperature,
                     "max_tokens": self.max_tokens,
                     "top_p": self.top_p,
@@ -195,24 +207,6 @@ class OpenAIService(InferenceServiceABC):
                 response = await client.chat.completions.create(**params)
                 return response.model_dump()
 
-            @staticmethod
-            def parse_response(raw_response: dict[str, Any]) -> str:
-                """Parses the API response and returns the response text."""
-                try:
-                    response = raw_response["choices"][0]["message"]["content"]
-                except KeyError:
-                    print("Tried to parse response but failed:")
-                    print(raw_response)
-                pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
-                match = re.match(pattern, response, re.DOTALL)
-                if match:
-                    return match.group(1)
-                else:
-                    out = fix_partial_correct_response(response)
-                    if "error" not in out:
-                        response = out["extracted_json"]
-                    return response
-
         LLM.__name__ = "LanguageModel"
 
         return LLM
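
The message-construction change in async_execute_model_call is easiest to see in isolation: the system message is always built, then dropped when it is empty and the model's omit_system_prompt_if_empty flag is set. A sketch under that reading (build_messages is an illustrative helper, not an edsl function):

def build_messages(user_prompt: str, system_prompt: str = "",
                   omit_system_prompt_if_empty: bool = True) -> list:
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    if system_prompt == "" and omit_system_prompt_if_empty:
        messages = messages[1:]  # send only the user message
    return messages

assert build_messages("Hi") == [{"role": "user", "content": "Hi"}]
assert build_messages("Hi", "Be terse")[0] == {"role": "system", "content": "Be terse"}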