edsl 0.1.31.dev4__py3-none-any.whl → 0.1.32__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +3 -4
- edsl/agents/PromptConstructionMixin.py +35 -15
- edsl/config.py +11 -1
- edsl/conjure/Conjure.py +6 -0
- edsl/data/CacheHandler.py +3 -4
- edsl/enums.py +4 -0
- edsl/exceptions/general.py +10 -8
- edsl/inference_services/AwsBedrock.py +110 -0
- edsl/inference_services/AzureAI.py +197 -0
- edsl/inference_services/DeepInfraService.py +4 -3
- edsl/inference_services/GroqService.py +3 -4
- edsl/inference_services/InferenceServicesCollection.py +13 -8
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +23 -18
- edsl/inference_services/models_available_cache.py +31 -0
- edsl/inference_services/registry.py +13 -1
- edsl/jobs/Jobs.py +100 -19
- edsl/jobs/buckets/TokenBucket.py +12 -4
- edsl/jobs/interviews/Interview.py +31 -9
- edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +49 -34
- edsl/jobs/interviews/interview_exception_tracking.py +68 -10
- edsl/jobs/runners/JobsRunnerAsyncio.py +36 -15
- edsl/jobs/runners/JobsRunnerStatusMixin.py +81 -51
- edsl/jobs/tasks/TaskCreators.py +1 -1
- edsl/jobs/tasks/TaskHistory.py +145 -1
- edsl/language_models/LanguageModel.py +58 -43
- edsl/language_models/registry.py +2 -2
- edsl/questions/QuestionBudget.py +0 -1
- edsl/questions/QuestionCheckBox.py +0 -1
- edsl/questions/QuestionExtract.py +0 -1
- edsl/questions/QuestionFreeText.py +2 -9
- edsl/questions/QuestionList.py +0 -1
- edsl/questions/QuestionMultipleChoice.py +1 -2
- edsl/questions/QuestionNumerical.py +0 -1
- edsl/questions/QuestionRank.py +0 -1
- edsl/results/DatasetExportMixin.py +33 -3
- edsl/scenarios/Scenario.py +14 -0
- edsl/scenarios/ScenarioList.py +216 -13
- edsl/scenarios/ScenarioListExportMixin.py +15 -4
- edsl/scenarios/ScenarioListPdfMixin.py +3 -0
- edsl/surveys/Rule.py +5 -2
- edsl/surveys/Survey.py +84 -1
- edsl/surveys/SurveyQualtricsImport.py +213 -0
- edsl/utilities/utilities.py +31 -0
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/METADATA +4 -1
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/RECORD +50 -45
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/WHEEL +0 -0
edsl/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.1.31.dev4"
+__version__ = "0.1.32"
edsl/agents/Invigilator.py
CHANGED
@@ -19,7 +19,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):

     async def async_answer_question(self) -> AgentResponseDict:
         """Answer a question using the AI model.
-
+
         >>> i = InvigilatorAI.example()
         >>> i.answer_question()
         {'message': '{"answer": "SPAM!"}'}
@@ -34,7 +34,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
             "raw_model_response": raw_response["raw_model_response"],
         }
         response = self._format_raw_response(**data)
-        #breakpoint()
+        # breakpoint()
         return AgentResponseDict(**response)

     async def async_get_response(
@@ -44,8 +44,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
         iteration: int = 0,
         encoded_image=None,
     ) -> dict:
-        """Call the LLM and gets a response. Used in the `answer_question` method.
-        """
+        """Call the LLM and gets a response. Used in the `answer_question` method."""
         try:
             params = {
                 "user_prompt": user_prompt.text,
edsl/agents/PromptConstructionMixin.py
CHANGED
@@ -214,6 +214,17 @@ class PromptConstructorMixin:

         return self._agent_persona_prompt

+    def prior_answers_dict(self) -> dict:
+        d = self.survey.question_names_to_questions()
+        for question, answer in self.current_answers.items():
+            if question in d:
+                d[question].answer = answer
+            else:
+                # adds a comment to the question
+                if (new_question := question.split("_comment")[0]) in d:
+                    d[new_question].comment = answer
+        return d
+
     @property
     def question_instructions_prompt(self) -> Prompt:
         """
@@ -266,29 +277,38 @@ class PromptConstructorMixin:
         question_prompt = self.question.get_instructions(model=self.model.model)

         # TODO: Try to populate the answers in the question object if they are available
-        d = self.survey.question_names_to_questions()
-        for question, answer in self.current_answers.items():
-            if question in d:
-                d[question].answer = answer
-            else:
-                # adds a comment to the question
-                if (new_question := question.split("_comment")[0]) in d:
-                    d[new_question].comment = answer
+        # d = self.survey.question_names_to_questions()
+        # for question, answer in self.current_answers.items():
+        #     if question in d:
+        #         d[question].answer = answer
+        #     else:
+        #         # adds a comment to the question
+        #         if (new_question := question.split("_comment")[0]) in d:
+        #             d[new_question].comment = answer

         question_data = self.question.data.copy()

-        # check to see if the
+        # check to see if the question_options is actually a string
+        # This is used when the user is using the question_options as a variable from a sceario
         if "question_options" in question_data:
             if isinstance(self.question.data["question_options"], str):
                 from jinja2 import Environment, meta
+
                 env = Environment()
-                parsed_content = env.parse(self.question.data["question_options"])
-                question_option_key = list(meta.find_undeclared_variables(parsed_content))[0]
-                question_data["question_options"] = self.scenario.get(question_option_key)
+                parsed_content = env.parse(self.question.data["question_options"])
+                question_option_key = list(
+                    meta.find_undeclared_variables(parsed_content)
+                )[0]
+                question_data["question_options"] = self.scenario.get(
+                    question_option_key
+                )

-        #breakpoint()
+        # breakpoint()
         rendered_instructions = question_prompt.render(
-            question_data
+            question_data
+            | self.scenario
+            | self.prior_answers_dict()
+            | {"agent": self.agent}
         )

         undefined_template_variables = (
@@ -321,7 +341,7 @@ class PromptConstructorMixin:
         if self.memory_plan is not None:
             memory_prompt += self.create_memory_prompt(
                 self.question.question_name
-            ).render(self.scenario)
+            ).render(self.scenario | self.prior_answers_dict())
             self._prior_question_memory_prompt = memory_prompt
         return self._prior_question_memory_prompt

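The net effect of prior_answers_dict() is that question templates can now reference earlier answers and scenario fields directly. A minimal sketch of the two Jinja2 mechanisms used above, outside of edsl (the AnsweredQuestion class and the variable names are illustrative stand-ins, not edsl's API):

from jinja2 import Environment, meta

env = Environment()

# 1. Detecting which scenario variable a templated question_options string
#    needs, as the code above does with meta.find_undeclared_variables:
parsed = env.parse("{{ my_options }}")
print(meta.find_undeclared_variables(parsed))  # {'my_options'}

# 2. Merging scenario, prior answers, and agent into one render context with
#    dict union (Python 3.9+), mirroring question_prompt.render(...):
class AnsweredQuestion:  # stand-in for an edsl question carrying its answer
    def __init__(self, answer):
        self.answer = answer

scenario = {"my_options": ["yes", "no"]}
prior_answers = {"q0": AnsweredQuestion("yes")}
context = {"question_text": "..."} | scenario | prior_answers | {"agent": None}

template = env.from_string("Options: {{ my_options }}; earlier you said {{ q0.answer }}.")
print(template.render(context))
# Options: ['yes', 'no']; earlier you said yes.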
edsl/config.py
CHANGED
@@ -69,6 +69,15 @@ CONFIG_MAP = {
     # "default": None,
     # "info": "This env var holds your GROQ API key (https://console.groq.com/login).",
     # },
+    # "AWS_ACCESS_KEY_ID" :
+    # "default": None,
+    # "info": "This env var holds your AWS access key ID.",
+    # "AWS_SECRET_ACCESS_KEY:
+    # "default": None,
+    # "info": "This env var holds your AWS secret access key.",
+    # "AZURE_ENDPOINT_URL_AND_KEY":
+    # "default": None,
+    # "info": "This env var holds your Azure endpoint URL and key (URL:key). You can have several comma-separated URL-key pairs (URL1:key1,URL2:key2).",
 }


@@ -113,8 +122,9 @@ class Config:
         """
         # for each env var in the CONFIG_MAP
         for env_var, config in CONFIG_MAP.items():
+            # we've set it already in _set_run_mode
             if env_var == "EDSL_RUN_MODE":
-                continue
+                continue
             value = os.getenv(env_var)
             default_value = config.get("default")
             # if the env var is set, set it as a CONFIG attribute
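For context, each CONFIG_MAP entry pairs an env var name with an optional default and an info string, and the loop above walks that map. A minimal sketch of the pattern under those assumptions (the two keys shown are illustrative, not the full edsl CONFIG_MAP):

import os

CONFIG_MAP = {
    "EDSL_RUN_MODE": {"default": "production", "info": "Run mode."},
    "AWS_ACCESS_KEY_ID": {"default": None, "info": "AWS access key ID."},
}

settings = {}
for env_var, config in CONFIG_MAP.items():
    if env_var == "EDSL_RUN_MODE":
        continue  # handled earlier, as in the real class's _set_run_mode
    value = os.getenv(env_var)
    default_value = config.get("default")
    # prefer the environment, then the default; unset vars with no default stay absent
    if value is not None:
        settings[env_var] = value
    elif default_value is not None:
        settings[env_var] = default_value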
edsl/conjure/Conjure.py
CHANGED
@@ -35,6 +35,12 @@ class Conjure:
         # The __init__ method in Conjure won't be called because __new__ returns a different class instance.
         pass

+    @classmethod
+    def example(cls):
+        from edsl.conjure.InputData import InputDataABC
+
+        return InputDataABC.example()
+

 if __name__ == "__main__":
     pass
edsl/data/CacheHandler.py
CHANGED
@@ -41,7 +41,7 @@ class CacheHandler:
         old_data = self.from_old_sqlite_cache()
         self.cache.add_from_dict(old_data)

-    def create_cache_directory(self) -> None:
+    def create_cache_directory(self, notify=False) -> None:
         """
         Create the cache directory if one is required and it does not exist.
         """
@@ -49,9 +49,8 @@ class CacheHandler:
         dir_path = os.path.dirname(path)
         if dir_path and not os.path.exists(dir_path):
             os.makedirs(dir_path)
-
-
-            warnings.warn(f"Created cache directory: {dir_path}")
+            if notify:
+                print(f"Created cache directory: {dir_path}")

     def gen_cache(self) -> Cache:
         """
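The behavioral change above is small: directory creation used to emit a warning unconditionally, and now it announces itself only on request. A standalone sketch of the resulting behavior (the path is illustrative):

import os

def create_cache_directory(path: str = "~/.edsl/cache.db", notify: bool = False) -> None:
    """Create the cache directory if it does not exist, optionally announcing it."""
    dir_path = os.path.dirname(os.path.expanduser(path))
    if dir_path and not os.path.exists(dir_path):
        os.makedirs(dir_path)
        if notify:
            print(f"Created cache directory: {dir_path}")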
edsl/enums.py
CHANGED
@@ -60,6 +60,9 @@ class InferenceServiceType(EnumWithChecks):
     TEST = "test"
     ANTHROPIC = "anthropic"
     GROQ = "groq"
+    AZURE = "azure"
+    OLLAMA = "ollama"
+

 service_to_api_keyname = {
     InferenceServiceType.BEDROCK.value: "TBD",
@@ -70,6 +73,7 @@ service_to_api_keyname = {
     InferenceServiceType.TEST.value: "TBD",
     InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
     InferenceServiceType.GROQ.value: "GROQ_API_KEY",
+    InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
 }


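Note that service_to_api_keyname already maps InferenceServiceType.BEDROCK.value to "TBD" near the top of the literal (visible as context in the first hunk), so this addition creates a duplicate key. In a Python dict literal the later entry silently wins:

# Duplicate keys in a dict literal are legal; the last value is kept.
d = {
    "bedrock": "TBD",
    "bedrock": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
}
print(d["bedrock"])  # ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY']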
edsl/exceptions/general.py
CHANGED
@@ -21,12 +21,14 @@ class GeneralErrors(Exception):


 class MissingAPIKeyError(GeneralErrors):
-    def __init__(self, model_name, inference_service):
-
-
-
-
-
-
-
+    def __init__(self, full_message=None, model_name=None, inference_service=None):
+        if model_name and inference_service:
+            full_message = dedent(
+                f"""
+                An API Key for model `{model_name}` is missing from the .env file.
+                This key is associated with the inference service `{inference_service}`.
+                Please see https://docs.expectedparrot.com/en/latest/api_keys.html for more information.
+                """
+            )
+
         super().__init__(full_message)
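With the reworked signature, callers can either pass a prebuilt message or let the model/service pair build one (the model and service names below are illustrative; dedent comes from textwrap, which the module is assumed to import):

try:
    raise MissingAPIKeyError(model_name="gpt-4", inference_service="openai")
except MissingAPIKeyError as e:
    print(e)  # multi-line hint pointing at the api_keys docs

# Equivalent, with a caller-supplied message:
err = MissingAPIKeyError(full_message="No key found for my custom service.")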
edsl/inference_services/AwsBedrock.py
ADDED
@@ -0,0 +1,110 @@
+import os
+from typing import Any
+import re
+import boto3
+from botocore.exceptions import ClientError
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models.LanguageModel import LanguageModel
+import json
+from edsl.utilities.utilities import fix_partial_correct_response
+
+
+class AwsBedrockService(InferenceServiceABC):
+    """AWS Bedrock service class."""
+
+    _inference_service_ = "bedrock"
+    _env_key_name_ = (
+        "AWS_ACCESS_KEY_ID"  # or any other environment key for AWS credentials
+    )
+
+    @classmethod
+    def available(cls):
+        """Fetch available models from AWS Bedrock."""
+        if not cls._models_list_cache:
+            client = boto3.client("bedrock", region_name="us-west-2")
+            all_models_ids = [
+                x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
+            ]
+        else:
+            all_models_ids = cls._models_list_cache
+
+        return all_models_ids
+
+    @classmethod
+    def create_model(
+        cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with AWS Bedrock models.
+            """
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 512,
+                "top_p": 0.9,
+            }
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str = ""
+            ) -> dict[str, Any]:
+                """Calls the AWS Bedrock API and returns the API response."""
+
+                api_token = (
+                    self.api_token
+                )  # call to check the if env variables are set.
+
+                client = boto3.client("bedrock-runtime", region_name="us-west-2")
+
+                conversation = [
+                    {
+                        "role": "user",
+                        "content": [{"text": user_prompt}],
+                    }
+                ]
+                system = [
+                    {
+                        "text": system_prompt,
+                    }
+                ]
+                try:
+                    response = client.converse(
+                        modelId=self._model_,
+                        messages=conversation,
+                        inferenceConfig={
+                            "maxTokens": self.max_tokens,
+                            "temperature": self.temperature,
+                            "topP": self.top_p,
+                        },
+                        # system=system,
+                        additionalModelRequestFields={},
+                    )
+                    return response
+                except (ClientError, Exception) as e:
+                    print(e)
+                    return {"error": str(e)}
+
+            @staticmethod
+            def parse_response(raw_response: dict[str, Any]) -> str:
+                """Parses the API response and returns the response text."""
+                if "output" in raw_response and "message" in raw_response["output"]:
+                    response = raw_response["output"]["message"]["content"][0]["text"]
+                    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+                    match = re.match(pattern, response, re.DOTALL)
+                    if match:
+                        return match.group(1)
+                    else:
+                        out = fix_partial_correct_response(response)
+                        if "error" not in out:
+                            response = out["extracted_json"]
+                        return response
+                return "Error parsing response"
+
+        LLM.__name__ = model_class_name
+
+        return LLM
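The parse_response pattern above (shared with the Azure service below) strips a ```json fence from a model reply; the (?:\\n|\n) alternation tolerates either a real newline or a literal backslash-n sequence. A quick check of the regex on its own:

import re

pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"

reply = '```json\n{"answer": "SPAM!"}\n```'
match = re.match(pattern, reply, re.DOTALL)
print(match.group(1))  # {"answer": "SPAM!"}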
edsl/inference_services/AzureAI.py
ADDED
@@ -0,0 +1,197 @@
+import os
+from typing import Any
+import re
+from openai import AsyncAzureOpenAI
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models.LanguageModel import LanguageModel
+
+from azure.ai.inference.aio import ChatCompletionsClient
+from azure.core.credentials import AzureKeyCredential
+from azure.ai.inference.models import SystemMessage, UserMessage
+import asyncio
+import json
+from edsl.utilities.utilities import fix_partial_correct_response
+
+
+def json_handle_none(value: Any) -> Any:
+    """
+    Handle None values during JSON serialization.
+    - Return "null" if the value is None. Otherwise, don't return anything.
+    """
+    if value is None:
+        return "null"
+
+
+class AzureAIService(InferenceServiceABC):
+    """Azure AI service class."""
+
+    _inference_service_ = "azure"
+    _env_key_name_ = (
+        "AZURE_ENDPOINT_URL_AND_KEY"  # Environment variable for Azure API key
+    )
+    _model_id_to_endpoint_and_key = {}
+
+    @classmethod
+    def available(cls):
+        out = []
+        azure_endpoints = os.getenv("AZURE_ENDPOINT_URL_AND_KEY", None)
+        if not azure_endpoints:
+            raise EnvironmentError(f"AZURE_ENDPOINT_URL_AND_KEY is not defined")
+        azure_endpoints = azure_endpoints.split(",")
+        for data in azure_endpoints:
+            try:
+                # data has this format for non openai models https://model_id.azure_endpoint:azure_key
+                _, endpoint, azure_endpoint_key = data.split(":")
+                if "openai" not in endpoint:
+                    model_id = endpoint.split(".")[0].replace("/", "")
+                    out.append(model_id)
+                    cls._model_id_to_endpoint_and_key[model_id] = {
+                        "endpoint": f"https:{endpoint}",
+                        "azure_endpoint_key": azure_endpoint_key,
+                    }
+                else:
+                    # data has this format for openai models ,https://azure_project_id.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview:azure_key
+                    if "/deployments/" in endpoint:
+                        start_idx = endpoint.index("/deployments/") + len(
+                            "/deployments/"
+                        )
+                        end_idx = (
+                            endpoint.index("/", start_idx)
+                            if "/" in endpoint[start_idx:]
+                            else len(endpoint)
+                        )
+                        model_id = endpoint[start_idx:end_idx]
+                        api_version_value = None
+                        if "api-version=" in endpoint:
+                            start_idx = endpoint.index("api-version=") + len(
+                                "api-version="
+                            )
+                            end_idx = (
+                                endpoint.index("&", start_idx)
+                                if "&" in endpoint[start_idx:]
+                                else len(endpoint)
+                            )
+                            api_version_value = endpoint[start_idx:end_idx]
+
+                        cls._model_id_to_endpoint_and_key[f"azure:{model_id}"] = {
+                            "endpoint": f"https:{endpoint}",
+                            "azure_endpoint_key": azure_endpoint_key,
+                            "api_version": api_version_value,
+                        }
+                        out.append(f"azure:{model_id}")
+
+            except Exception as e:
+                raise e
+        return out
+
+    @classmethod
+    def create_model(
+        cls, model_name: str = "azureai", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with Azure OpenAI models.
+            """
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 512,
+                "top_p": 0.9,
+            }
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str = ""
+            ) -> dict[str, Any]:
+                """Calls the Azure OpenAI API and returns the API response."""
+
+                try:
+                    api_key = cls._model_id_to_endpoint_and_key[model_name][
+                        "azure_endpoint_key"
+                    ]
+                except:
+                    api_key = None
+
+                if not api_key:
+                    raise EnvironmentError(
+                        f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
+                    )
+
+                try:
+                    endpoint = cls._model_id_to_endpoint_and_key[model_name]["endpoint"]
+                except:
+                    endpoint = None
+
+                if not endpoint:
+                    raise EnvironmentError(
+                        f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
+                    )
+
+                if "openai" not in endpoint:
+                    client = ChatCompletionsClient(
+                        endpoint=endpoint,
+                        credential=AzureKeyCredential(api_key),
+                        temperature=self.temperature,
+                        top_p=self.top_p,
+                        max_tokens=self.max_tokens,
+                    )
+                    try:
+                        response = await client.complete(
+                            messages=[
+                                SystemMessage(content=system_prompt),
+                                UserMessage(content=user_prompt),
+                            ],
+                            # model_extras={"safe_mode": True},
+                        )
+                        await client.close()
+                        return response.as_dict()
+                    except Exception as e:
+                        await client.close()
+                        return {"error": str(e)}
+                else:
+                    api_version = cls._model_id_to_endpoint_and_key[model_name][
+                        "api_version"
+                    ]
+                    client = AsyncAzureOpenAI(
+                        azure_endpoint=endpoint,
+                        api_version=api_version,
+                        api_key=api_key,
+                    )
+                    response = await client.chat.completions.create(
+                        model=model_name,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": user_prompt,  # Your question can go here
+                            },
+                        ],
+                    )
+                    return response.model_dump()
+
+            @staticmethod
+            def parse_response(raw_response: dict[str, Any]) -> str:
+                """Parses the API response and returns the response text."""
+                if (
+                    raw_response
+                    and "choices" in raw_response
+                    and raw_response["choices"]
+                ):
+                    response = raw_response["choices"][0]["message"]["content"]
+                    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+                    match = re.match(pattern, response, re.DOTALL)
+                    if match:
+                        return match.group(1)
+                    else:
+                        out = fix_partial_correct_response(response)
+                        if "error" not in out:
+                            response = out["extracted_json"]
+                        return response
+                return "Error parsing response"
+
+        LLM.__name__ = model_class_name
+
+        return LLM
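The available() parser above relies on str.split(":") yielding exactly three parts for an https://host:key entry, which is also why the endpoint is reassembled as f"https:{endpoint}". Traced on an illustrative (not real) entry:

data = "https://my-model.eastus2.inference.ai.azure.com:abc123"  # illustrative
_, endpoint, azure_endpoint_key = data.split(":")
print(endpoint)             # //my-model.eastus2.inference.ai.azure.com
print(f"https:{endpoint}")  # https://my-model.eastus2.inference.ai.azure.com
model_id = endpoint.split(".")[0].replace("/", "")
print(model_id, azure_endpoint_key)  # my-model abc123

Note the three-way unpack raises ValueError for any entry whose key or URL contains extra colons, which is presumably why the loop wraps each entry in try/except.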
edsl/inference_services/DeepInfraService.py
CHANGED
@@ -2,16 +2,17 @@ import aiohttp
 import json
 import requests
 from typing import Any, List
-
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models import LanguageModel

 from edsl.inference_services.OpenAIService import OpenAIService

+
 class DeepInfraService(OpenAIService):
     """DeepInfra service class."""

     _inference_service_ = "deep_infra"
     _env_key_name_ = "DEEP_INFRA_API_KEY"
-    _base_url_ = "https://api.deepinfra.com/v1/openai"
+    _base_url_ = "https://api.deepinfra.com/v1/openai"
     _models_list_cache: List[str] = []
-
edsl/inference_services/GroqService.py
CHANGED
@@ -10,10 +10,9 @@ class GroqService(OpenAIService):
     _inference_service_ = "groq"
     _env_key_name_ = "GROQ_API_KEY"

-    _sync_client_ =
-    _async_client_ = groq.AsyncGroq
+    _sync_client_ = groq.Groq
+    _async_client_ = groq.AsyncGroq

-    #_base_url_ = "https://api.deepinfra.com/v1/openai"
+    # _base_url_ = "https://api.deepinfra.com/v1/openai"
     _base_url_ = None
     _models_list_cache: List[str] = []
-
edsl/inference_services/InferenceServicesCollection.py
CHANGED
@@ -15,18 +15,19 @@ class InferenceServicesCollection:
         cls.added_models[service_name].append(model_name)

     @staticmethod
-    def _get_service_available(service) -> list[str]:
+    def _get_service_available(service, warn: bool = False) -> list[str]:
         from_api = True
         try:
             service_models = service.available()
         except Exception as e:
-
-
-
-
-
-
-
+            if warn:
+                warnings.warn(
+                    f"""Error getting models for {service._inference_service_}.
+                    Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
+                    See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
+                    Relying on cache.""",
+                    UserWarning,
+                )
             from edsl.inference_services.models_available_cache import models_available

             service_models = models_available.get(service._inference_service_, [])
@@ -60,4 +61,8 @@ class InferenceServicesCollection:
             if service_name is None or service_name == service._inference_service_:
                 return service.create_model(model_name)

+        # if model_name == "test":
+        #     from edsl.language_models import LanguageModel
+        #     return LanguageModel(test = True)
+
         raise Exception(f"Model {model_name} not found in any of the services")
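The shape of _get_service_available after this change is "try the live API, optionally warn, fall back to the static cache". A condensed sketch of that control flow (the models_available dict stands in for the models_available_cache module):

import warnings

models_available = {"openai": ["gpt-4"]}  # stand-in for models_available_cache

def get_service_available(service, warn: bool = False) -> list:
    try:
        return service.available()
    except Exception:
        if warn:
            warnings.warn("Error getting models; relying on cache.", UserWarning)
        return models_available.get(service._inference_service_, [])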
edsl/inference_services/OllamaService.py
ADDED
@@ -0,0 +1,18 @@
+import aiohttp
+import json
+import requests
+from typing import Any, List
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+
+from edsl.inference_services.OpenAIService import OpenAIService
+
+
+class OllamaService(OpenAIService):
+    """DeepInfra service class."""
+
+    _inference_service_ = "ollama"
+    _env_key_name_ = "DEEP_INFRA_API_KEY"
+    _base_url_ = "http://localhost:11434/v1"
+    _models_list_cache: List[str] = []
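Because Ollama exposes an OpenAI-compatible endpoint, subclassing OpenAIService with only a _base_url_ override is enough (the docstring and DEEP_INFRA_API_KEY env name above appear carried over from DeepInfraService). Pointing a stock OpenAI client at the same URL illustrates the idea; the model name is whatever you have pulled locally, and Ollama ignores the dummy API key:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
response = client.chat.completions.create(
    model="llama3",  # any locally pulled Ollama model
    messages=[{"role": "user", "content": "Say hi."}],
)
print(response.choices[0].message.content)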