edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +24 -14
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +6 -6
- edsl/agents/Invigilator.py +28 -6
- edsl/agents/InvigilatorBase.py +8 -27
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
- edsl/agents/prompt_helpers.py +129 -0
- edsl/config.py +26 -34
- edsl/coop/coop.py +14 -4
- edsl/data_transfer_models.py +26 -73
- edsl/enums.py +2 -0
- edsl/inference_services/AnthropicService.py +5 -2
- edsl/inference_services/AwsBedrock.py +5 -2
- edsl/inference_services/AzureAI.py +5 -2
- edsl/inference_services/GoogleService.py +108 -33
- edsl/inference_services/InferenceServiceABC.py +44 -13
- edsl/inference_services/MistralAIService.py +5 -2
- edsl/inference_services/OpenAIService.py +10 -6
- edsl/inference_services/TestService.py +34 -16
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +109 -18
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +64 -10
- edsl/jobs/interviews/Interview.py +130 -49
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
- edsl/jobs/runners/JobsRunnerStatus.py +332 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
- edsl/jobs/tasks/TaskHistory.py +17 -0
- edsl/language_models/LanguageModel.py +36 -38
- edsl/language_models/registry.py +13 -9
- edsl/language_models/utilities.py +5 -2
- edsl/questions/QuestionBase.py +74 -16
- edsl/questions/QuestionBaseGenMixin.py +28 -0
- edsl/questions/QuestionBudget.py +93 -41
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionFreeText.py +6 -0
- edsl/questions/QuestionMultipleChoice.py +13 -24
- edsl/questions/QuestionNumerical.py +5 -4
- edsl/questions/Quick.py +41 -0
- edsl/questions/ResponseValidatorABC.py +11 -6
- edsl/questions/derived/QuestionLinearScale.py +4 -1
- edsl/questions/derived/QuestionTopK.py +4 -1
- edsl/questions/derived/QuestionYesNo.py +8 -2
- edsl/questions/descriptors.py +12 -11
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
- edsl/results/DatasetExportMixin.py +5 -1
- edsl/results/Result.py +1 -1
- edsl/results/Results.py +4 -1
- edsl/scenarios/FileStore.py +178 -34
- edsl/scenarios/Scenario.py +76 -37
- edsl/scenarios/ScenarioList.py +19 -2
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/Study.py +32 -0
- edsl/surveys/DAG.py +62 -0
- edsl/surveys/MemoryPlan.py +26 -0
- edsl/surveys/Rule.py +34 -1
- edsl/surveys/RuleCollection.py +55 -5
- edsl/surveys/Survey.py +189 -10
- edsl/surveys/base.py +4 -0
- edsl/templates/error_reporting/interview_details.html +6 -1
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
- edsl/jobs/interviews/retry_management.py +0 -39
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- edsl/scenarios/ScenarioImageMixin.py +0 -100
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/config.py
CHANGED
@@ -1,73 +1,65 @@
 """This module provides a Config class that loads environment variables from a .env file and sets them as class attributes."""
 
 import os
+from dotenv import load_dotenv, find_dotenv
 from edsl.exceptions import (
     InvalidEnvironmentVariableError,
     MissingEnvironmentVariableError,
 )
-from dotenv import load_dotenv, find_dotenv
 
 # valid values for EDSL_RUN_MODE
-EDSL_RUN_MODES = [
+EDSL_RUN_MODES = [
+    "development",
+    "development-testrun",
+    "production",
+]
 
 # `default` is used to impute values only in "production" mode
 # `info` gives a brief description of the env var
 CONFIG_MAP = {
     "EDSL_RUN_MODE": {
         "default": "production",
-        "info": "This
-    },
-    "EDSL_DATABASE_PATH": {
-        "default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
-        "info": "This env var determines the path to the cache file.",
-    },
-    "EDSL_LOGGING_PATH": {
-        "default": f"{os.path.join(os.getcwd(), 'interview.log')}",
-        "info": "This env var determines the path to the log file.",
+        "info": "This config var determines the run mode of the application.",
     },
     "EDSL_API_TIMEOUT": {
         "default": "60",
-        "info": "This
+        "info": "This config var determines the maximum number of seconds to wait for an API call to return.",
     },
     "EDSL_BACKOFF_START_SEC": {
         "default": "1",
-        "info": "This
+        "info": "This config var determines the number of seconds to wait before retrying a failed API call.",
     },
-    "
+    "EDSL_BACKOFF_MAX_SEC": {
         "default": "60",
-        "info": "This
+        "info": "This config var determines the maximum number of seconds to wait before retrying a failed API call.",
     },
-    "
-        "default": "
-        "info": "This
+    "EDSL_DATABASE_PATH": {
+        "default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
+        "info": "This config var determines the path to the cache file.",
     },
     "EDSL_DEFAULT_MODEL": {
         "default": "gpt-4o",
-        "info": "This
+        "info": "This config var holds the default model that will be used if a model is not explicitly passed.",
     },
-    "
-        "default": "
-        "info": "This
+    "EDSL_FETCH_TOKEN_PRICES": {
+        "default": "True",
+        "info": "This config var determines whether to fetch prices for tokens used in remote inference",
+    },
+    "EDSL_MAX_ATTEMPTS": {
+        "default": "5",
+        "info": "This config var determines the maximum number of times to retry a failed API call.",
     },
     "EDSL_SERVICE_RPM_BASELINE": {
         "default": "100",
-        "info": "This
+        "info": "This config var holds the maximum number of requests per minute. Model-specific values provided in env vars such as EDSL_SERVICE_RPM_OPENAI will override this. value for the corresponding model",
     },
-    "
+    "EDSL_SERVICE_TPM_BASELINE": {
         "default": "2000000",
-        "info": "This
-    },
-    "EDSL_SERVICE_RPM_OPENAI": {
-        "default": "100",
-        "info": "This env var holds the maximum number of requests per minute for OpenAI.",
-    },
-    "EDSL_FETCH_TOKEN_PRICES": {
-        "default": "True",
-        "info": "Whether to fetch the prices for tokens",
+        "info": "This config var holds the maximum number of tokens per minute for all models. Model-specific values provided in env vars such as EDSL_SERVICE_TPM_OPENAI will override this value for the corresponding model.",
     },
     "EXPECTED_PARROT_URL": {
         "default": "https://www.expectedparrot.com",
-        "info": "This
+        "info": "This config var holds the URL of the Expected Parrot API.",
     },
 }
 
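Note on the retry settings: downstream code reads these values through CONFIG.get, which imputes the "default" shown above when the env var is unset in production mode. A minimal sketch of how the new EDSL_BACKOFF_* and EDSL_MAX_ATTEMPTS values could combine into a capped exponential backoff schedule (the schedule shape is an assumption for illustration, not code from the package):

from edsl.config import CONFIG

# Defaults per CONFIG_MAP: start=1s, max=60s, attempts=5 (all stored as strings)
backoff_start = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
backoff_max = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
max_attempts = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))

# Exponential backoff, each delay capped at backoff_max
delays = [min(backoff_start * 2**i, backoff_max) for i in range(max_attempts)]
print(delays)  # [1.0, 2.0, 4.0, 8.0, 16.0]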
edsl/coop/coop.py
CHANGED
@@ -59,8 +59,16 @@ class Coop:
         Send a request to the server and return the response.
         """
         url = f"{self.url}/{uri}"
+        method = method.upper()
+        if payload is None:
+            timeout = 20
+        elif (
+            method.upper() == "POST"
+            and "json_string" in payload
+            and payload.get("json_string") is not None
+        ):
+            timeout = max(20, (len(payload.get("json_string", "")) // (1024 * 1024)))
         try:
-            method = method.upper()
             if method in ["GET", "DELETE"]:
                 response = requests.request(
                     method, url, params=params, headers=self.headers, timeout=timeout
@@ -77,7 +85,7 @@ class Coop:
             else:
                 raise Exception(f"Invalid {method=}.")
         except requests.ConnectionError:
-            raise requests.ConnectionError("Could not connect to the server.")
+            raise requests.ConnectionError(f"Could not connect to the server at {url}.")
 
         return response
 
@@ -87,6 +95,7 @@ class Coop:
         """
         if response.status_code >= 400:
             message = response.json().get("detail")
+            # print(response.text)
             if "Authorization" in message:
                 print(message)
                 message = "Please provide an Expected Parrot API key."
@@ -794,8 +803,9 @@ def main():
     ##############
     job = Jobs.example()
     coop.remote_inference_cost(job)
-
-    coop.remote_inference_get(
+    job_coop_object = coop.remote_inference_create(job)
+    job_coop_results = coop.remote_inference_get(job_coop_object.get("uuid"))
+    coop.get(uuid=job_coop_results.get("results_uuid"))
 
     ##############
     # E. Errors
edsl/data_transfer_models.py
CHANGED
@@ -1,4 +1,6 @@
 from typing import NamedTuple, Dict, List, Optional, Any
+from dataclasses import dataclass, fields
+import reprlib
 
 
 class ModelInputs(NamedTuple):
@@ -45,76 +47,27 @@ class EDSLResultObjectInput(NamedTuple):
     cost: Optional[float] = None
 
 
-
-
-
-
-
-
-
-
-
-
-    #
-
-    #
-
-
-
-
-
-
-    #
-
-
-    #
-
-    # raise ValueError("generated_tokens must be provided")
-    # self.data = {
-    #     "answer": answer,
-    #     "comment": comment,
-    #     "question_name": question_name,
-    #     "prompts": prompts,
-    #     "usage": usage,
-    #     "cached_response": cached_response,
-    #     "raw_model_response": raw_model_response,
-    #     "simple_model_raw_response": simple_model_raw_response,
-    #     "cache_used": cache_used,
-    #     "cache_key": cache_key,
-    #     "generated_tokens": generated_tokens,
-    # }
-
-    # @property
-    # def data(self):
-    #     return self._data
-
-    # @data.setter
-    # def data(self, value):
-    #     self._data = value
-
-    # def __getitem__(self, key):
-    #     return self.data.get(key, None)
-
-    # def __setitem__(self, key, value):
-    #     self.data[key] = value
-
-    # def __delitem__(self, key):
-    #     del self.data[key]
-
-    # def __iter__(self):
-    #     return iter(self.data)
-
-    # def __len__(self):
-    #     return len(self.data)
-
-    # def keys(self):
-    #     return self.data.keys()
-
-    # def values(self):
-    #     return self.data.values()
-
-    # def items(self):
-    #     return self.data.items()
-
-    # def is_this_same_model(self):
-    #     return True
+@dataclass
+class ImageInfo:
+    file_path: str
+    file_name: str
+    image_format: str
+    file_size: int
+    encoded_image: str
+
+    def __repr__(self):
+        reprlib_instance = reprlib.Repr()
+        reprlib_instance.maxstring = 30  # Limit the string length for the encoded image
+
+        # Get all fields except encoded_image
+        field_reprs = [
+            f"{f.name}={getattr(self, f.name)!r}"
+            for f in fields(self)
+            if f.name != "encoded_image"
+        ]
+
+        # Add the reprlib-restricted encoded_image field
+        field_reprs.append(f"encoded_image={reprlib_instance.repr(self.encoded_image)}")
+
+        # Join everything to create the repr
+        return f"{self.__class__.__name__}({', '.join(field_reprs)})"
edsl/enums.py
CHANGED
@@ -63,6 +63,7 @@ class InferenceServiceType(EnumWithChecks):
     AZURE = "azure"
     OLLAMA = "ollama"
     MISTRAL = "mistral"
+    TOGETHER = "together"
 
 
 service_to_api_keyname = {
@@ -76,6 +77,7 @@ service_to_api_keyname = {
     InferenceServiceType.GROQ.value: "GROQ_API_KEY",
     InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
     InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
+    InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
 }
 
 
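With both the enum member and the key-name mapping in place, callers can resolve the credential required for the new Together service the same way as for any other provider:

from edsl.enums import InferenceServiceType, service_to_api_keyname

service = InferenceServiceType.TOGETHER.value  # "together"
print(service_to_api_keyname[service])         # "TOGETHER_API_KEY"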
edsl/inference_services/AnthropicService.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Any
+from typing import Any, Optional, List
 import re
 from anthropic import AsyncAnthropic
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -60,7 +60,10 @@ class AnthropicService(InferenceServiceABC):
             _rpm = cls.get_rpm(cls)
 
             async def async_execute_model_call(
-                self,
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional[List["Files"]] = None,
             ) -> dict[str, Any]:
                 """Calls the OpenAI API and returns the API response."""
 
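This release threads an optional files_list parameter through every service's async_execute_model_call; the same three added parameters appear in the AWS Bedrock, Azure, Mistral, and OpenAI hunks below. Because the parameter defaults to None, existing call sites keep working. A sketch of a direct call, assuming a valid ANTHROPIC_API_KEY is set in the environment:

import asyncio
from edsl import Model

async def main():
    m = Model("claude-3-opus-20240229")
    # files_list is omitted, so the call behaves exactly as before this release
    response = await m.async_execute_model_call(
        user_prompt="Say hello.",
        system_prompt="You are terse.",
    )
    print(response)

asyncio.run(main())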
edsl/inference_services/AwsBedrock.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Any
+from typing import Any, List, Optional
 import re
 import boto3
 from botocore.exceptions import ClientError
@@ -69,7 +69,10 @@ class AwsBedrockService(InferenceServiceABC):
             _tpm = cls.get_tpm(cls)
 
             async def async_execute_model_call(
-                self,
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional[List["FileStore"]] = None,
             ) -> dict[str, Any]:
                 """Calls the AWS Bedrock API and returns the API response."""
 
edsl/inference_services/AzureAI.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Any
+from typing import Any, Optional, List
 import re
 from openai import AsyncAzureOpenAI
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -122,7 +122,10 @@ class AzureAIService(InferenceServiceABC):
             _tpm = cls.get_tpm(cls)
 
             async def async_execute_model_call(
-                self,
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional[List["FileStore"]] = None,
             ) -> dict[str, Any]:
                 """Calls the Azure OpenAI API and returns the API response."""
 
edsl/inference_services/GoogleService.py
CHANGED
@@ -1,25 +1,54 @@
 import os
-import
-import
-
+from typing import Any, Dict, List, Optional
+import google
+import google.generativeai as genai
+from google.generativeai.types import GenerationConfig
+from google.api_core.exceptions import InvalidArgument
+
 from edsl.exceptions import MissingAPIKeyError
 from edsl.language_models.LanguageModel import LanguageModel
-
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 
+safety_settings = [
+    {
+        "category": "HARM_CATEGORY_HARASSMENT",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_HATE_SPEECH",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+        "threshold": "BLOCK_NONE",
+    },
+]
+
 
 class GoogleService(InferenceServiceABC):
     _inference_service_ = "google"
     key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
-    usage_sequence = ["
-    input_token_name = "
-    output_token_name = "
+    usage_sequence = ["usage_metadata"]
+    input_token_name = "prompt_token_count"
+    output_token_name = "candidates_token_count"
 
     model_exclude_list = []
 
+    # @classmethod
+    # def available(cls) -> List[str]:
+    #     return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
+
     @classmethod
-    def available(cls):
-
+    def available(cls) -> List[str]:
+        model_list = []
+        for m in genai.list_models():
+            if "generateContent" in m.supported_generation_methods:
+                model_list.append(m.name.split("/")[-1])
+        return model_list
 
     @classmethod
     def create_model(
@@ -47,33 +76,79 @@ class GoogleService(InferenceServiceABC):
                 "stopSequences": [],
             }
 
+            api_token = None
+            model = None
+
+            @classmethod
+            def initialize(cls):
+                if cls.api_token is None:
+                    cls.api_token = os.getenv("GOOGLE_API_KEY")
+                    if not cls.api_token:
+                        raise MissingAPIKeyError(
+                            "GOOGLE_API_KEY environment variable is not set"
+                        )
+                    genai.configure(api_key=cls.api_token)
+                    cls.generative_model = genai.GenerativeModel(
+                        cls._model_, safety_settings=safety_settings
+                    )
+
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.initialize()
+
+            def get_generation_config(self) -> GenerationConfig:
+                return GenerationConfig(
+                    temperature=self.temperature,
+                    top_p=self.topP,
+                    top_k=self.topK,
+                    max_output_tokens=self.maxOutputTokens,
+                    stop_sequences=self.stopSequences,
+                )
+
             async def async_execute_model_call(
-                self,
-
-
-
-
-
-                data = {
-                    "contents": [{"parts": [{"text": combined_prompt}]}],
-                    "generationConfig": {
-                        "temperature": self.temperature,
-                        "topK": self.topK,
-                        "topP": self.topP,
-                        "maxOutputTokens": self.maxOutputTokens,
-                        "stopSequences": self.stopSequences,
-                    },
-                }
-                print(combined_prompt)
-                async with aiohttp.ClientSession() as session:
-                    async with session.post(
-                        url, headers=headers, data=json.dumps(data)
-                    ) as response:
-                        raw_response_text = await response.text()
-                        return json.loads(raw_response_text)
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional["Files"] = None,
+            ) -> Dict[str, Any]:
+                generation_config = self.get_generation_config()
 
-
+                if files_list is None:
+                    files_list = []
+
+                if (
+                    system_prompt is not None
+                    and system_prompt != ""
+                    and self._model_ != "gemini-pro"
+                ):
+                    try:
+                        self.generative_model = genai.GenerativeModel(
+                            self._model_,
+                            safety_settings=safety_settings,
+                            system_instruction=system_prompt,
+                        )
+                    except InvalidArgument as e:
+                        print(
+                            f"This model, {self._model_}, does not support system_instruction"
+                        )
+                        print("Will add system_prompt to user_prompt")
+                        user_prompt = f"{system_prompt}\n{user_prompt}"
 
+                combined_prompt = [user_prompt]
+                for file in files_list:
+                    if "google" not in file.external_locations:
+                        _ = file.upload_google()
+                    gen_ai_file = google.generativeai.types.file_types.File(
+                        file.external_locations["google"]
+                    )
+                    combined_prompt.append(gen_ai_file)
+
+                response = await self.generative_model.generate_content_async(
+                    combined_prompt, generation_config=generation_config
+                )
+                return response.to_dict()
+
+        LLM.__name__ = model_name
         return LLM
 
 
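Not every Gemini model accepts a system_instruction: when constructing the GenerativeModel raises InvalidArgument, the code above falls back to folding the system prompt into the user prompt. That fallback in isolation (a sketch of the rule, not package code):

def fold_system_prompt(system_prompt: str, user_prompt: str) -> str:
    # Used when a model (e.g. gemini-pro) rejects system_instruction
    if system_prompt:
        return f"{system_prompt}\n{user_prompt}"
    return user_prompt

print(fold_system_prompt("Answer in French.", "Name the capital of France."))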
edsl/inference_services/InferenceServiceABC.py
CHANGED
@@ -1,14 +1,27 @@
 from abc import abstractmethod, ABC
-
+import os
 import re
 from edsl.config import CONFIG
 
 
 class InferenceServiceABC(ABC):
-    """
+    """
+    Abstract class for inference services.
+    Anthropic: https://docs.anthropic.com/en/api/rate-limits
+    """
+
+    default_levels = {
+        "google": {"tpm": 2_000_000, "rpm": 15},
+        "openai": {"tpm": 2_000_000, "rpm": 10_000},
+        "anthropic": {"tpm": 2_000_000, "rpm": 500},
+    }
 
-    # check if child class has cls attribute "key_sequence"
     def __init_subclass__(cls):
+        """
+        Check that the subclass has the required attributes.
+        - `key_sequence` attribute determines...
+        - `model_exclude_list` attribute determines...
+        """
         if not hasattr(cls, "key_sequence"):
             raise NotImplementedError(
                 f"Class {cls.__name__} must have a 'key_sequence' attribute."
@@ -18,29 +31,47 @@ class InferenceServiceABC(ABC):
                 f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
             )
 
-
-
-
-
-
+    @classmethod
+    def _get_limt(cls, limit_type: str) -> int:
+        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
+        if key in os.environ:
+            return int(os.getenv(key))
+
+        if cls._inference_service_ in cls.default_levels:
+            return int(cls.default_levels[cls._inference_service_][limit_type])
+
+        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
+
+    def get_tpm(cls) -> int:
+        """
+        Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
+        """
+        return cls._get_limt(limit_type="tpm")
 
     def get_rpm(cls):
-
-
-
-        return
+        """
+        Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
+        """
+        return cls._get_limt(limit_type="rpm")
 
     @abstractmethod
     def available() -> list[str]:
+        """
+        Returns a list of available models for the service.
+        """
        pass
 
     @abstractmethod
     def create_model():
+        """
+        Returns a LanguageModel object.
+        """
        pass
 
     @staticmethod
     def to_class_name(s):
-        """
+        """
+        Converts a string to a valid class name.
 
         >>> InferenceServiceABC.to_class_name("hello world")
         'HelloWorld'
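_get_limt (so spelled in the source) resolves a rate limit in three steps: a service-specific env var such as EDSL_SERVICE_RPM_OPENAI wins; otherwise the hard-coded default_levels entry; otherwise the EDSL_SERVICE_*_BASELINE config value. The same precedence as a standalone sketch, with the baseline passed in instead of read from CONFIG:

import os

DEFAULT_LEVELS = {
    "google": {"tpm": 2_000_000, "rpm": 15},
    "openai": {"tpm": 2_000_000, "rpm": 10_000},
    "anthropic": {"tpm": 2_000_000, "rpm": 500},
}

def resolve_limit(service: str, limit_type: str, baseline: int) -> int:
    key = f"EDSL_SERVICE_{limit_type.upper()}_{service.upper()}"
    if key in os.environ:              # 1. explicit env var wins
        return int(os.environ[key])
    if service in DEFAULT_LEVELS:      # 2. per-service default
        return DEFAULT_LEVELS[service][limit_type]
    return baseline                    # 3. global baseline

os.environ["EDSL_SERVICE_RPM_OPENAI"] = "50"
print(resolve_limit("openai", "rpm", baseline=100))   # 50
print(resolve_limit("mistral", "rpm", baseline=100))  # 100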
edsl/inference_services/MistralAIService.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Any, List
+from typing import Any, List, Optional
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models.LanguageModel import LanguageModel
 import asyncio
@@ -95,7 +95,10 @@ class MistralAIService(InferenceServiceABC):
                 return cls.async_client()
 
             async def async_execute_model_call(
-                self,
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional[List["FileStore"]] = None,
             ) -> dict[str, Any]:
                 """Calls the Mistral API and returns the API response."""
                 s = self.async_client()
edsl/inference_services/OpenAIService.py
CHANGED
@@ -168,13 +168,14 @@ class OpenAIService(InferenceServiceABC):
                 self,
                 user_prompt: str,
                 system_prompt: str = "",
-
+                files_list: Optional[List["Files"]] = None,
                 invigilator: Optional[
                     "InvigilatorAI"
                 ] = None,  # TBD - can eventually be used for function-calling
             ) -> dict[str, Any]:
                 """Calls the OpenAI API and returns the API response."""
-                if
+                if files_list:
+                    encoded_image = files_list[0].base64_string
                     content = [{"type": "text", "text": user_prompt}]
                     content.append(
                         {
@@ -187,12 +188,15 @@ class OpenAIService(InferenceServiceABC):
                 else:
                     content = user_prompt
                 client = self.async_client()
+                messages = [
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": content},
+                ]
+                if system_prompt == "" and self.omit_system_prompt_if_empty:
+                    messages = messages[1:]
                 params = {
                     "model": self.model,
-                    "messages":
-                        {"role": "system", "content": system_prompt},
-                        {"role": "user", "content": content},
-                    ],
+                    "messages": messages,
                     "temperature": self.temperature,
                     "max_tokens": self.max_tokens,
                     "top_p": self.top_p,