edsl 0.1.31.dev4__py3-none-any.whl → 0.1.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +9 -3
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -3
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +40 -8
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +136 -221
- edsl/agents/InvigilatorBase.py +148 -59
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +154 -85
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +48 -47
- edsl/conjure/Conjure.py +6 -0
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +50 -7
- edsl/data/Cache.py +35 -1
- edsl/data/CacheHandler.py +3 -4
- edsl/data_transfer_models.py +73 -38
- edsl/enums.py +8 -0
- edsl/exceptions/general.py +10 -8
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +112 -0
- edsl/inference_services/AzureAI.py +214 -0
- edsl/inference_services/DeepInfraService.py +4 -3
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +5 -4
- edsl/inference_services/InferenceServiceABC.py +58 -3
- edsl/inference_services/InferenceServicesCollection.py +13 -8
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +55 -56
- edsl/inference_services/TestService.py +80 -0
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/models_available_cache.py +25 -0
- edsl/inference_services/registry.py +19 -1
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +137 -41
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +105 -18
- edsl/jobs/interviews/Interview.py +393 -83
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +22 -18
- edsl/jobs/interviews/InterviewExceptionEntry.py +167 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +152 -160
- edsl/jobs/runners/JobsRunnerStatus.py +331 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskCreators.py +1 -1
- edsl/jobs/tasks/TaskHistory.py +205 -126
- edsl/language_models/LanguageModel.py +297 -177
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +25 -8
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/notebooks/Notebook.py +20 -2
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +330 -249
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +99 -42
- edsl/questions/QuestionCheckBox.py +227 -36
- edsl/questions/QuestionExtract.py +98 -28
- edsl/questions/QuestionFreeText.py +47 -31
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -23
- edsl/questions/QuestionMultipleChoice.py +159 -66
- edsl/questions/QuestionNumerical.py +88 -47
- edsl/questions/QuestionRank.py +182 -25
- edsl/questions/Quick.py +41 -0
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +170 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +15 -2
- edsl/questions/derived/QuestionTopK.py +10 -1
- edsl/questions/derived/QuestionYesNo.py +24 -3
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +58 -30
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +135 -46
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/FileStore.py +71 -10
- edsl/scenarios/Scenario.py +109 -24
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +546 -21
- edsl/scenarios/ScenarioListExportMixin.py +24 -4
- edsl/scenarios/ScenarioListPdfMixin.py +153 -4
- edsl/study/SnapShot.py +8 -1
- edsl/study/Study.py +32 -0
- edsl/surveys/Rule.py +15 -3
- edsl/surveys/RuleCollection.py +21 -5
- edsl/surveys/Survey.py +707 -298
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +284 -0
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/utilities/utilities.py +40 -1
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/METADATA +8 -2
- edsl-0.1.33.dist-info/RECORD +295 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -271
- edsl/jobs/interviews/retry_management.py +0 -37
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -303
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.31.dev4.dist-info/RECORD +0 -204
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
edsl/exceptions/questions.py
CHANGED
@@ -1,16 +1,73 @@
|
|
1
|
+
from typing import Any, SupportsIndex
|
2
|
+
from jinja2 import Template
|
3
|
+
import json
|
4
|
+
|
5
|
+
|
1
6
|
class QuestionErrors(Exception):
|
2
|
-
|
7
|
+
"""
|
8
|
+
Base exception class for question-related errors.
|
9
|
+
"""
|
3
10
|
|
11
|
+
def __init__(self, message="An error occurred with the question"):
|
12
|
+
self.message = message
|
13
|
+
super().__init__(self.message)
|
4
14
|
|
5
|
-
class QuestionCreationValidationError(QuestionErrors):
|
6
|
-
pass
|
7
15
|
|
16
|
+
class QuestionAnswerValidationError(QuestionErrors):
|
17
|
+
documentation = "https://docs.expectedparrot.com/en/latest/exceptions.html"
|
18
|
+
|
19
|
+
explanation = """This when the answer coming from the Language Model does not conform to the expectation for that question type.
|
20
|
+
For example, if the question is a multiple choice question, the answer should be drawn from the list of options provided.
|
21
|
+
"""
|
22
|
+
|
23
|
+
def __init__(self, message="Invalid answer.", data=None, model=None):
|
24
|
+
self.message = message
|
25
|
+
self.data = data
|
26
|
+
self.model = model
|
27
|
+
super().__init__(self.message)
|
28
|
+
|
29
|
+
def __str__(self):
|
30
|
+
return f"""{repr(self)}
|
31
|
+
Data being validated: {self.data}
|
32
|
+
Pydnantic Model: {self.model}.
|
33
|
+
Reported error: {self.message}."""
|
34
|
+
|
35
|
+
def to_html_dict(self):
|
36
|
+
return {
|
37
|
+
"error_type": ("Name of the exception", "p", "/p", self.__class__.__name__),
|
38
|
+
"explaination": ("Explanation", "p", "/p", self.explanation),
|
39
|
+
"edsl answer": (
|
40
|
+
"What model returned",
|
41
|
+
"pre",
|
42
|
+
"/pre",
|
43
|
+
json.dumps(self.data, indent=2),
|
44
|
+
),
|
45
|
+
"validating_model": (
|
46
|
+
"Pydantic model for answers",
|
47
|
+
"pre",
|
48
|
+
"/pre",
|
49
|
+
json.dumps(self.model.model_json_schema(), indent=2),
|
50
|
+
),
|
51
|
+
"error_message": (
|
52
|
+
"Error message Pydantic returned",
|
53
|
+
"p",
|
54
|
+
"/p",
|
55
|
+
self.message,
|
56
|
+
),
|
57
|
+
"documentation_url": (
|
58
|
+
"URL to EDSL docs",
|
59
|
+
f"a href='{self.documentation}'",
|
60
|
+
"/a",
|
61
|
+
self.documentation,
|
62
|
+
),
|
63
|
+
}
|
8
64
|
|
9
|
-
|
65
|
+
|
66
|
+
class QuestionCreationValidationError(QuestionErrors):
|
10
67
|
pass
|
11
68
|
|
12
69
|
|
13
|
-
class
|
70
|
+
class QuestionResponseValidationError(QuestionErrors):
|
14
71
|
pass
|
15
72
|
|
16
73
|
|
edsl/exceptions/results.py
CHANGED
@@ -11,6 +11,11 @@ class AnthropicService(InferenceServiceABC):
|
|
11
11
|
|
12
12
|
_inference_service_ = "anthropic"
|
13
13
|
_env_key_name_ = "ANTHROPIC_API_KEY"
|
14
|
+
key_sequence = ["content", 0, "text"] # ["content"][0]["text"]
|
15
|
+
usage_sequence = ["usage"]
|
16
|
+
input_token_name = "input_tokens"
|
17
|
+
output_token_name = "output_tokens"
|
18
|
+
model_exclude_list = []
|
14
19
|
|
15
20
|
@classmethod
|
16
21
|
def available(cls):
|
@@ -34,6 +39,11 @@ class AnthropicService(InferenceServiceABC):
|
|
34
39
|
Child class of LanguageModel for interacting with OpenAI models
|
35
40
|
"""
|
36
41
|
|
42
|
+
key_sequence = cls.key_sequence
|
43
|
+
usage_sequence = cls.usage_sequence
|
44
|
+
input_token_name = cls.input_token_name
|
45
|
+
output_token_name = cls.output_token_name
|
46
|
+
|
37
47
|
_inference_service_ = cls._inference_service_
|
38
48
|
_model_ = model_name
|
39
49
|
_parameters_ = {
|
@@ -46,6 +56,9 @@ class AnthropicService(InferenceServiceABC):
|
|
46
56
|
"top_logprobs": 3,
|
47
57
|
}
|
48
58
|
|
59
|
+
_tpm = cls.get_tpm(cls)
|
60
|
+
_rpm = cls.get_rpm(cls)
|
61
|
+
|
49
62
|
async def async_execute_model_call(
|
50
63
|
self, user_prompt: str, system_prompt: str = ""
|
51
64
|
) -> dict[str, Any]:
|
@@ -66,17 +79,6 @@ class AnthropicService(InferenceServiceABC):
|
|
66
79
|
)
|
67
80
|
return response.model_dump()
|
68
81
|
|
69
|
-
@staticmethod
|
70
|
-
def parse_response(raw_response: dict[str, Any]) -> str:
|
71
|
-
"""Parses the API response and returns the response text."""
|
72
|
-
response = raw_response["content"][0]["text"]
|
73
|
-
pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
|
74
|
-
match = re.match(pattern, response, re.DOTALL)
|
75
|
-
if match:
|
76
|
-
return match.group(1)
|
77
|
-
else:
|
78
|
-
return response
|
79
|
-
|
80
82
|
LLM.__name__ = model_class_name
|
81
83
|
|
82
84
|
return LLM
|
@@ -0,0 +1,112 @@
|
|
1
|
+
import os
|
2
|
+
from typing import Any
|
3
|
+
import re
|
4
|
+
import boto3
|
5
|
+
from botocore.exceptions import ClientError
|
6
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
7
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
8
|
+
import json
|
9
|
+
from edsl.utilities.utilities import fix_partial_correct_response
|
10
|
+
|
11
|
+
|
12
|
+
class AwsBedrockService(InferenceServiceABC):
|
13
|
+
"""AWS Bedrock service class."""
|
14
|
+
|
15
|
+
_inference_service_ = "bedrock"
|
16
|
+
_env_key_name_ = (
|
17
|
+
"AWS_ACCESS_KEY_ID" # or any other environment key for AWS credentials
|
18
|
+
)
|
19
|
+
key_sequence = ["output", "message", "content", 0, "text"]
|
20
|
+
input_token_name = "inputTokens"
|
21
|
+
output_token_name = "outputTokens"
|
22
|
+
usage_sequence = ["usage"]
|
23
|
+
model_exclude_list = [
|
24
|
+
"ai21.j2-grande-instruct",
|
25
|
+
"ai21.j2-jumbo-instruct",
|
26
|
+
"ai21.j2-mid",
|
27
|
+
"ai21.j2-mid-v1",
|
28
|
+
"ai21.j2-ultra",
|
29
|
+
"ai21.j2-ultra-v1",
|
30
|
+
]
|
31
|
+
|
32
|
+
@classmethod
|
33
|
+
def available(cls):
|
34
|
+
"""Fetch available models from AWS Bedrock."""
|
35
|
+
if not cls._models_list_cache:
|
36
|
+
client = boto3.client("bedrock", region_name="us-west-2")
|
37
|
+
all_models_ids = [
|
38
|
+
x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
|
39
|
+
]
|
40
|
+
else:
|
41
|
+
all_models_ids = cls._models_list_cache
|
42
|
+
|
43
|
+
return [m for m in all_models_ids if m not in cls.model_exclude_list]
|
44
|
+
|
45
|
+
@classmethod
|
46
|
+
def create_model(
|
47
|
+
cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
|
48
|
+
) -> LanguageModel:
|
49
|
+
if model_class_name is None:
|
50
|
+
model_class_name = cls.to_class_name(model_name)
|
51
|
+
|
52
|
+
class LLM(LanguageModel):
|
53
|
+
"""
|
54
|
+
Child class of LanguageModel for interacting with AWS Bedrock models.
|
55
|
+
"""
|
56
|
+
|
57
|
+
key_sequence = cls.key_sequence
|
58
|
+
usage_sequence = cls.usage_sequence
|
59
|
+
_inference_service_ = cls._inference_service_
|
60
|
+
_model_ = model_name
|
61
|
+
_parameters_ = {
|
62
|
+
"temperature": 0.5,
|
63
|
+
"max_tokens": 512,
|
64
|
+
"top_p": 0.9,
|
65
|
+
}
|
66
|
+
input_token_name = cls.input_token_name
|
67
|
+
output_token_name = cls.output_token_name
|
68
|
+
_rpm = cls.get_rpm(cls)
|
69
|
+
_tpm = cls.get_tpm(cls)
|
70
|
+
|
71
|
+
async def async_execute_model_call(
|
72
|
+
self, user_prompt: str, system_prompt: str = ""
|
73
|
+
) -> dict[str, Any]:
|
74
|
+
"""Calls the AWS Bedrock API and returns the API response."""
|
75
|
+
|
76
|
+
api_token = (
|
77
|
+
self.api_token
|
78
|
+
) # call to check the if env variables are set.
|
79
|
+
|
80
|
+
client = boto3.client("bedrock-runtime", region_name="us-west-2")
|
81
|
+
|
82
|
+
conversation = [
|
83
|
+
{
|
84
|
+
"role": "user",
|
85
|
+
"content": [{"text": user_prompt}],
|
86
|
+
}
|
87
|
+
]
|
88
|
+
system = [
|
89
|
+
{
|
90
|
+
"text": system_prompt,
|
91
|
+
}
|
92
|
+
]
|
93
|
+
try:
|
94
|
+
response = client.converse(
|
95
|
+
modelId=self._model_,
|
96
|
+
messages=conversation,
|
97
|
+
inferenceConfig={
|
98
|
+
"maxTokens": self.max_tokens,
|
99
|
+
"temperature": self.temperature,
|
100
|
+
"topP": self.top_p,
|
101
|
+
},
|
102
|
+
# system=system,
|
103
|
+
additionalModelRequestFields={},
|
104
|
+
)
|
105
|
+
return response
|
106
|
+
except (ClientError, Exception) as e:
|
107
|
+
print(e)
|
108
|
+
return {"error": str(e)}
|
109
|
+
|
110
|
+
LLM.__name__ = model_class_name
|
111
|
+
|
112
|
+
return LLM
|
@@ -0,0 +1,214 @@
|
|
1
|
+
import os
|
2
|
+
from typing import Any
|
3
|
+
import re
|
4
|
+
from openai import AsyncAzureOpenAI
|
5
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
6
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
7
|
+
|
8
|
+
from azure.ai.inference.aio import ChatCompletionsClient
|
9
|
+
from azure.core.credentials import AzureKeyCredential
|
10
|
+
from azure.ai.inference.models import SystemMessage, UserMessage
|
11
|
+
import asyncio
|
12
|
+
import json
|
13
|
+
from edsl.utilities.utilities import fix_partial_correct_response
|
14
|
+
|
15
|
+
|
16
|
+
def json_handle_none(value: Any) -> Any:
|
17
|
+
"""
|
18
|
+
Handle None values during JSON serialization.
|
19
|
+
- Return "null" if the value is None. Otherwise, don't return anything.
|
20
|
+
"""
|
21
|
+
if value is None:
|
22
|
+
return "null"
|
23
|
+
|
24
|
+
|
25
|
+
class AzureAIService(InferenceServiceABC):
|
26
|
+
"""Azure AI service class."""
|
27
|
+
|
28
|
+
# key_sequence = ["content", 0, "text"] # ["content"][0]["text"]
|
29
|
+
key_sequence = ["choices", 0, "message", "content"]
|
30
|
+
usage_sequence = ["usage"]
|
31
|
+
input_token_name = "prompt_tokens"
|
32
|
+
output_token_name = "completion_tokens"
|
33
|
+
|
34
|
+
_inference_service_ = "azure"
|
35
|
+
_env_key_name_ = (
|
36
|
+
"AZURE_ENDPOINT_URL_AND_KEY" # Environment variable for Azure API key
|
37
|
+
)
|
38
|
+
_model_id_to_endpoint_and_key = {}
|
39
|
+
model_exclude_list = [
|
40
|
+
"Cohere-command-r-plus-xncmg",
|
41
|
+
"Mistral-Nemo-klfsi",
|
42
|
+
"Mistral-large-2407-ojfld",
|
43
|
+
]
|
44
|
+
|
45
|
+
@classmethod
|
46
|
+
def available(cls):
|
47
|
+
out = []
|
48
|
+
azure_endpoints = os.getenv("AZURE_ENDPOINT_URL_AND_KEY", None)
|
49
|
+
if not azure_endpoints:
|
50
|
+
raise EnvironmentError(f"AZURE_ENDPOINT_URL_AND_KEY is not defined")
|
51
|
+
azure_endpoints = azure_endpoints.split(",")
|
52
|
+
for data in azure_endpoints:
|
53
|
+
try:
|
54
|
+
# data has this format for non openai models https://model_id.azure_endpoint:azure_key
|
55
|
+
_, endpoint, azure_endpoint_key = data.split(":")
|
56
|
+
if "openai" not in endpoint:
|
57
|
+
model_id = endpoint.split(".")[0].replace("/", "")
|
58
|
+
out.append(model_id)
|
59
|
+
cls._model_id_to_endpoint_and_key[model_id] = {
|
60
|
+
"endpoint": f"https:{endpoint}",
|
61
|
+
"azure_endpoint_key": azure_endpoint_key,
|
62
|
+
}
|
63
|
+
else:
|
64
|
+
# data has this format for openai models ,https://azure_project_id.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview:azure_key
|
65
|
+
if "/deployments/" in endpoint:
|
66
|
+
start_idx = endpoint.index("/deployments/") + len(
|
67
|
+
"/deployments/"
|
68
|
+
)
|
69
|
+
end_idx = (
|
70
|
+
endpoint.index("/", start_idx)
|
71
|
+
if "/" in endpoint[start_idx:]
|
72
|
+
else len(endpoint)
|
73
|
+
)
|
74
|
+
model_id = endpoint[start_idx:end_idx]
|
75
|
+
api_version_value = None
|
76
|
+
if "api-version=" in endpoint:
|
77
|
+
start_idx = endpoint.index("api-version=") + len(
|
78
|
+
"api-version="
|
79
|
+
)
|
80
|
+
end_idx = (
|
81
|
+
endpoint.index("&", start_idx)
|
82
|
+
if "&" in endpoint[start_idx:]
|
83
|
+
else len(endpoint)
|
84
|
+
)
|
85
|
+
api_version_value = endpoint[start_idx:end_idx]
|
86
|
+
|
87
|
+
cls._model_id_to_endpoint_and_key[f"azure:{model_id}"] = {
|
88
|
+
"endpoint": f"https:{endpoint}",
|
89
|
+
"azure_endpoint_key": azure_endpoint_key,
|
90
|
+
"api_version": api_version_value,
|
91
|
+
}
|
92
|
+
out.append(f"azure:{model_id}")
|
93
|
+
|
94
|
+
except Exception as e:
|
95
|
+
raise e
|
96
|
+
return [m for m in out if m not in cls.model_exclude_list]
|
97
|
+
|
98
|
+
@classmethod
|
99
|
+
def create_model(
|
100
|
+
cls, model_name: str = "azureai", model_class_name=None
|
101
|
+
) -> LanguageModel:
|
102
|
+
if model_class_name is None:
|
103
|
+
model_class_name = cls.to_class_name(model_name)
|
104
|
+
|
105
|
+
class LLM(LanguageModel):
|
106
|
+
"""
|
107
|
+
Child class of LanguageModel for interacting with Azure OpenAI models.
|
108
|
+
"""
|
109
|
+
|
110
|
+
key_sequence = cls.key_sequence
|
111
|
+
usage_sequence = cls.usage_sequence
|
112
|
+
input_token_name = cls.input_token_name
|
113
|
+
output_token_name = cls.output_token_name
|
114
|
+
_inference_service_ = cls._inference_service_
|
115
|
+
_model_ = model_name
|
116
|
+
_parameters_ = {
|
117
|
+
"temperature": 0.5,
|
118
|
+
"max_tokens": 512,
|
119
|
+
"top_p": 0.9,
|
120
|
+
}
|
121
|
+
_rpm = cls.get_rpm(cls)
|
122
|
+
_tpm = cls.get_tpm(cls)
|
123
|
+
|
124
|
+
async def async_execute_model_call(
|
125
|
+
self, user_prompt: str, system_prompt: str = ""
|
126
|
+
) -> dict[str, Any]:
|
127
|
+
"""Calls the Azure OpenAI API and returns the API response."""
|
128
|
+
|
129
|
+
try:
|
130
|
+
api_key = cls._model_id_to_endpoint_and_key[model_name][
|
131
|
+
"azure_endpoint_key"
|
132
|
+
]
|
133
|
+
except:
|
134
|
+
api_key = None
|
135
|
+
|
136
|
+
if not api_key:
|
137
|
+
raise EnvironmentError(
|
138
|
+
f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
|
139
|
+
)
|
140
|
+
|
141
|
+
try:
|
142
|
+
endpoint = cls._model_id_to_endpoint_and_key[model_name]["endpoint"]
|
143
|
+
except:
|
144
|
+
endpoint = None
|
145
|
+
|
146
|
+
if not endpoint:
|
147
|
+
raise EnvironmentError(
|
148
|
+
f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
|
149
|
+
)
|
150
|
+
|
151
|
+
if "openai" not in endpoint:
|
152
|
+
client = ChatCompletionsClient(
|
153
|
+
endpoint=endpoint,
|
154
|
+
credential=AzureKeyCredential(api_key),
|
155
|
+
temperature=self.temperature,
|
156
|
+
top_p=self.top_p,
|
157
|
+
max_tokens=self.max_tokens,
|
158
|
+
)
|
159
|
+
try:
|
160
|
+
response = await client.complete(
|
161
|
+
messages=[
|
162
|
+
SystemMessage(content=system_prompt),
|
163
|
+
UserMessage(content=user_prompt),
|
164
|
+
],
|
165
|
+
# model_extras={"safe_mode": True},
|
166
|
+
)
|
167
|
+
await client.close()
|
168
|
+
return response.as_dict()
|
169
|
+
except Exception as e:
|
170
|
+
await client.close()
|
171
|
+
return {"error": str(e)}
|
172
|
+
else:
|
173
|
+
api_version = cls._model_id_to_endpoint_and_key[model_name][
|
174
|
+
"api_version"
|
175
|
+
]
|
176
|
+
client = AsyncAzureOpenAI(
|
177
|
+
azure_endpoint=endpoint,
|
178
|
+
api_version=api_version,
|
179
|
+
api_key=api_key,
|
180
|
+
)
|
181
|
+
response = await client.chat.completions.create(
|
182
|
+
model=model_name,
|
183
|
+
messages=[
|
184
|
+
{
|
185
|
+
"role": "user",
|
186
|
+
"content": user_prompt, # Your question can go here
|
187
|
+
},
|
188
|
+
],
|
189
|
+
)
|
190
|
+
return response.model_dump()
|
191
|
+
|
192
|
+
# @staticmethod
|
193
|
+
# def parse_response(raw_response: dict[str, Any]) -> str:
|
194
|
+
# """Parses the API response and returns the response text."""
|
195
|
+
# if (
|
196
|
+
# raw_response
|
197
|
+
# and "choices" in raw_response
|
198
|
+
# and raw_response["choices"]
|
199
|
+
# ):
|
200
|
+
# response = raw_response["choices"][0]["message"]["content"]
|
201
|
+
# pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
|
202
|
+
# match = re.match(pattern, response, re.DOTALL)
|
203
|
+
# if match:
|
204
|
+
# return match.group(1)
|
205
|
+
# else:
|
206
|
+
# out = fix_partial_correct_response(response)
|
207
|
+
# if "error" not in out:
|
208
|
+
# response = out["extracted_json"]
|
209
|
+
# return response
|
210
|
+
# return "Error parsing response"
|
211
|
+
|
212
|
+
LLM.__name__ = model_class_name
|
213
|
+
|
214
|
+
return LLM
|
@@ -2,16 +2,17 @@ import aiohttp
|
|
2
2
|
import json
|
3
3
|
import requests
|
4
4
|
from typing import Any, List
|
5
|
-
|
5
|
+
|
6
|
+
# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
6
7
|
from edsl.language_models import LanguageModel
|
7
8
|
|
8
9
|
from edsl.inference_services.OpenAIService import OpenAIService
|
9
10
|
|
11
|
+
|
10
12
|
class DeepInfraService(OpenAIService):
|
11
13
|
"""DeepInfra service class."""
|
12
14
|
|
13
15
|
_inference_service_ = "deep_infra"
|
14
16
|
_env_key_name_ = "DEEP_INFRA_API_KEY"
|
15
|
-
_base_url_ = "https://api.deepinfra.com/v1/openai"
|
17
|
+
_base_url_ = "https://api.deepinfra.com/v1/openai"
|
16
18
|
_models_list_cache: List[str] = []
|
17
|
-
|
@@ -10,10 +10,16 @@ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
|
10
10
|
|
11
11
|
class GoogleService(InferenceServiceABC):
|
12
12
|
_inference_service_ = "google"
|
13
|
+
key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
|
14
|
+
usage_sequence = ["usageMetadata"]
|
15
|
+
input_token_name = "promptTokenCount"
|
16
|
+
output_token_name = "candidatesTokenCount"
|
17
|
+
|
18
|
+
model_exclude_list = []
|
13
19
|
|
14
20
|
@classmethod
|
15
21
|
def available(cls):
|
16
|
-
return ["gemini-pro"]
|
22
|
+
return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
|
17
23
|
|
18
24
|
@classmethod
|
19
25
|
def create_model(
|
@@ -24,7 +30,15 @@ class GoogleService(InferenceServiceABC):
|
|
24
30
|
|
25
31
|
class LLM(LanguageModel):
|
26
32
|
_model_ = model_name
|
33
|
+
key_sequence = cls.key_sequence
|
34
|
+
usage_sequence = cls.usage_sequence
|
35
|
+
input_token_name = cls.input_token_name
|
36
|
+
output_token_name = cls.output_token_name
|
27
37
|
_inference_service_ = cls._inference_service_
|
38
|
+
|
39
|
+
_tpm = cls.get_tpm(cls)
|
40
|
+
_rpm = cls.get_rpm(cls)
|
41
|
+
|
28
42
|
_parameters_ = {
|
29
43
|
"temperature": 0.5,
|
30
44
|
"topP": 1,
|
@@ -50,7 +64,7 @@ class GoogleService(InferenceServiceABC):
|
|
50
64
|
"stopSequences": self.stopSequences,
|
51
65
|
},
|
52
66
|
}
|
53
|
-
|
67
|
+
# print(combined_prompt)
|
54
68
|
async with aiohttp.ClientSession() as session:
|
55
69
|
async with session.post(
|
56
70
|
url, headers=headers, data=json.dumps(data)
|
@@ -58,16 +72,6 @@ class GoogleService(InferenceServiceABC):
|
|
58
72
|
raw_response_text = await response.text()
|
59
73
|
return json.loads(raw_response_text)
|
60
74
|
|
61
|
-
def parse_response(self, raw_response: dict[str, Any]) -> str:
|
62
|
-
data = raw_response
|
63
|
-
try:
|
64
|
-
return data["candidates"][0]["content"]["parts"][0]["text"]
|
65
|
-
except KeyError as e:
|
66
|
-
print(
|
67
|
-
f"The data return was {data}, which was missing the key 'candidates'"
|
68
|
-
)
|
69
|
-
raise e
|
70
|
-
|
71
75
|
LLM.__name__ = model_name
|
72
76
|
|
73
77
|
return LLM
|
@@ -10,10 +10,11 @@ class GroqService(OpenAIService):
|
|
10
10
|
_inference_service_ = "groq"
|
11
11
|
_env_key_name_ = "GROQ_API_KEY"
|
12
12
|
|
13
|
-
_sync_client_ =
|
14
|
-
_async_client_ = groq.AsyncGroq
|
13
|
+
_sync_client_ = groq.Groq
|
14
|
+
_async_client_ = groq.AsyncGroq
|
15
15
|
|
16
|
-
|
16
|
+
model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]
|
17
|
+
|
18
|
+
# _base_url_ = "https://api.deepinfra.com/v1/openai"
|
17
19
|
_base_url_ = None
|
18
20
|
_models_list_cache: List[str] = []
|
19
|
-
|
@@ -1,22 +1,77 @@
|
|
1
1
|
from abc import abstractmethod, ABC
|
2
|
-
|
2
|
+
import os
|
3
3
|
import re
|
4
|
+
from edsl.config import CONFIG
|
4
5
|
|
5
6
|
|
6
7
|
class InferenceServiceABC(ABC):
|
7
|
-
"""
|
8
|
+
"""
|
9
|
+
Abstract class for inference services.
|
10
|
+
Anthropic: https://docs.anthropic.com/en/api/rate-limits
|
11
|
+
"""
|
12
|
+
|
13
|
+
default_levels = {
|
14
|
+
"google": {"tpm": 2_000_000, "rpm": 15},
|
15
|
+
"openai": {"tpm": 2_000_000, "rpm": 10_000},
|
16
|
+
"anthropic": {"tpm": 2_000_000, "rpm": 500},
|
17
|
+
}
|
18
|
+
|
19
|
+
def __init_subclass__(cls):
|
20
|
+
"""
|
21
|
+
Check that the subclass has the required attributes.
|
22
|
+
- `key_sequence` attribute determines...
|
23
|
+
- `model_exclude_list` attribute determines...
|
24
|
+
"""
|
25
|
+
if not hasattr(cls, "key_sequence"):
|
26
|
+
raise NotImplementedError(
|
27
|
+
f"Class {cls.__name__} must have a 'key_sequence' attribute."
|
28
|
+
)
|
29
|
+
if not hasattr(cls, "model_exclude_list"):
|
30
|
+
raise NotImplementedError(
|
31
|
+
f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
|
32
|
+
)
|
33
|
+
|
34
|
+
@classmethod
|
35
|
+
def _get_limt(cls, limit_type: str) -> int:
|
36
|
+
key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
|
37
|
+
if key in os.environ:
|
38
|
+
return int(os.getenv(key))
|
39
|
+
|
40
|
+
if cls._inference_service_ in cls.default_levels:
|
41
|
+
return int(cls.default_levels[cls._inference_service_][limit_type])
|
42
|
+
|
43
|
+
return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
|
44
|
+
|
45
|
+
def get_tpm(cls) -> int:
|
46
|
+
"""
|
47
|
+
Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
|
48
|
+
"""
|
49
|
+
return cls._get_limt(limit_type="tpm")
|
50
|
+
|
51
|
+
def get_rpm(cls):
|
52
|
+
"""
|
53
|
+
Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
|
54
|
+
"""
|
55
|
+
return cls._get_limt(limit_type="rpm")
|
8
56
|
|
9
57
|
@abstractmethod
|
10
58
|
def available() -> list[str]:
|
59
|
+
"""
|
60
|
+
Returns a list of available models for the service.
|
61
|
+
"""
|
11
62
|
pass
|
12
63
|
|
13
64
|
@abstractmethod
|
14
65
|
def create_model():
|
66
|
+
"""
|
67
|
+
Returns a LanguageModel object.
|
68
|
+
"""
|
15
69
|
pass
|
16
70
|
|
17
71
|
@staticmethod
|
18
72
|
def to_class_name(s):
|
19
|
-
"""
|
73
|
+
"""
|
74
|
+
Converts a string to a valid class name.
|
20
75
|
|
21
76
|
>>> InferenceServiceABC.to_class_name("hello world")
|
22
77
|
'HelloWorld'
|
@@ -15,18 +15,19 @@ class InferenceServicesCollection:
|
|
15
15
|
cls.added_models[service_name].append(model_name)
|
16
16
|
|
17
17
|
@staticmethod
|
18
|
-
def _get_service_available(service) -> list[str]:
|
18
|
+
def _get_service_available(service, warn: bool = False) -> list[str]:
|
19
19
|
from_api = True
|
20
20
|
try:
|
21
21
|
service_models = service.available()
|
22
22
|
except Exception as e:
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
23
|
+
if warn:
|
24
|
+
warnings.warn(
|
25
|
+
f"""Error getting models for {service._inference_service_}.
|
26
|
+
Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
|
27
|
+
See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
|
28
|
+
Relying on cache.""",
|
29
|
+
UserWarning,
|
30
|
+
)
|
30
31
|
from edsl.inference_services.models_available_cache import models_available
|
31
32
|
|
32
33
|
service_models = models_available.get(service._inference_service_, [])
|
@@ -60,4 +61,8 @@ class InferenceServicesCollection:
|
|
60
61
|
if service_name is None or service_name == service._inference_service_:
|
61
62
|
return service.create_model(model_name)
|
62
63
|
|
64
|
+
# if model_name == "test":
|
65
|
+
# from edsl.language_models import LanguageModel
|
66
|
+
# return LanguageModel(test = True)
|
67
|
+
|
63
68
|
raise Exception(f"Model {model_name} not found in any of the services")
|