edsl 0.1.31.dev3__py3-none-any.whl → 0.1.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +7 -2
- edsl/agents/PromptConstructionMixin.py +35 -15
- edsl/config.py +15 -1
- edsl/conjure/Conjure.py +6 -0
- edsl/coop/coop.py +4 -0
- edsl/data/CacheHandler.py +3 -4
- edsl/enums.py +5 -0
- edsl/exceptions/general.py +10 -8
- edsl/inference_services/AwsBedrock.py +110 -0
- edsl/inference_services/AzureAI.py +197 -0
- edsl/inference_services/DeepInfraService.py +6 -91
- edsl/inference_services/GroqService.py +18 -0
- edsl/inference_services/InferenceServicesCollection.py +13 -8
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +68 -21
- edsl/inference_services/models_available_cache.py +31 -0
- edsl/inference_services/registry.py +14 -1
- edsl/jobs/Jobs.py +103 -21
- edsl/jobs/buckets/TokenBucket.py +12 -4
- edsl/jobs/interviews/Interview.py +31 -9
- edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +49 -33
- edsl/jobs/interviews/interview_exception_tracking.py +68 -10
- edsl/jobs/runners/JobsRunnerAsyncio.py +112 -81
- edsl/jobs/runners/JobsRunnerStatusData.py +0 -237
- edsl/jobs/runners/JobsRunnerStatusMixin.py +291 -35
- edsl/jobs/tasks/TaskCreators.py +8 -2
- edsl/jobs/tasks/TaskHistory.py +145 -1
- edsl/language_models/LanguageModel.py +62 -41
- edsl/language_models/registry.py +4 -0
- edsl/questions/QuestionBudget.py +0 -1
- edsl/questions/QuestionCheckBox.py +0 -1
- edsl/questions/QuestionExtract.py +0 -1
- edsl/questions/QuestionFreeText.py +2 -9
- edsl/questions/QuestionList.py +0 -1
- edsl/questions/QuestionMultipleChoice.py +1 -2
- edsl/questions/QuestionNumerical.py +0 -1
- edsl/questions/QuestionRank.py +0 -1
- edsl/results/DatasetExportMixin.py +33 -3
- edsl/scenarios/Scenario.py +14 -0
- edsl/scenarios/ScenarioList.py +216 -13
- edsl/scenarios/ScenarioListExportMixin.py +15 -4
- edsl/scenarios/ScenarioListPdfMixin.py +3 -0
- edsl/surveys/Rule.py +5 -2
- edsl/surveys/Survey.py +84 -1
- edsl/surveys/SurveyQualtricsImport.py +213 -0
- edsl/utilities/utilities.py +31 -0
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/METADATA +5 -1
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/RECORD +52 -46
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/WHEEL +0 -0
edsl/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.1.31.dev3"
+__version__ = "0.1.32"
edsl/agents/Invigilator.py
CHANGED
@@ -18,7 +18,12 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
     """An invigilator that uses an AI model to answer questions."""
 
     async def async_answer_question(self) -> AgentResponseDict:
-        """Answer a question using the AI model."""
+        """Answer a question using the AI model.
+
+        >>> i = InvigilatorAI.example()
+        >>> i.answer_question()
+        {'message': '{"answer": "SPAM!"}'}
+        """
         params = self.get_prompts() | {"iteration": self.iteration}
         raw_response = await self.async_get_response(**params)
         data = {
@@ -29,6 +34,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
             "raw_model_response": raw_response["raw_model_response"],
         }
         response = self._format_raw_response(**data)
+        # breakpoint()
         return AgentResponseDict(**response)
 
     async def async_get_response(
@@ -97,7 +103,6 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
         answer = question._translate_answer_code_to_answer(
             response["answer"], combined_dict
         )
-        # breakpoint()
         data = {
             "answer": answer,
             "comment": response.get(
edsl/agents/PromptConstructionMixin.py
CHANGED
@@ -214,6 +214,17 @@ class PromptConstructorMixin:
 
         return self._agent_persona_prompt
 
+    def prior_answers_dict(self) -> dict:
+        d = self.survey.question_names_to_questions()
+        for question, answer in self.current_answers.items():
+            if question in d:
+                d[question].answer = answer
+            else:
+                # adds a comment to the question
+                if (new_question := question.split("_comment")[0]) in d:
+                    d[new_question].comment = answer
+        return d
+
     @property
     def question_instructions_prompt(self) -> Prompt:
         """
@@ -266,29 +277,38 @@ class PromptConstructorMixin:
         question_prompt = self.question.get_instructions(model=self.model.model)
 
         # TODO: Try to populate the answers in the question object if they are available
-        d = self.survey.question_names_to_questions()
-        for question, answer in self.current_answers.items():
-            if question in d:
-                d[question].answer = answer
-            else:
-                # adds a comment to the question
-                if (new_question := question.split("_comment")[0]) in d:
-                    d[new_question].comment = answer
+        # d = self.survey.question_names_to_questions()
+        # for question, answer in self.current_answers.items():
+        #     if question in d:
+        #         d[question].answer = answer
+        #     else:
+        #         # adds a comment to the question
+        #         if (new_question := question.split("_comment")[0]) in d:
+        #             d[new_question].comment = answer
 
         question_data = self.question.data.copy()
 
-        # check to see if the
+        # check to see if the question_options is actually a string
+        # This is used when the user is using the question_options as a variable from a sceario
         if "question_options" in question_data:
             if isinstance(self.question.data["question_options"], str):
                 from jinja2 import Environment, meta
+
                 env = Environment()
-                parsed_content = env.parse(self.question.data[
-                question_option_key = list(
-
+                parsed_content = env.parse(self.question.data["question_options"])
+                question_option_key = list(
+                    meta.find_undeclared_variables(parsed_content)
+                )[0]
+                question_data["question_options"] = self.scenario.get(
+                    question_option_key
+                )
 
-        #breakpoint()
+        # breakpoint()
         rendered_instructions = question_prompt.render(
-            question_data
+            question_data
+            | self.scenario
+            | self.prior_answers_dict()
+            | {"agent": self.agent}
         )
 
         undefined_template_variables = (
@@ -321,7 +341,7 @@ class PromptConstructorMixin:
         if self.memory_plan is not None:
             memory_prompt += self.create_memory_prompt(
                 self.question.question_name
-            ).render(self.scenario)
+            ).render(self.scenario | self.prior_answers_dict())
             self._prior_question_memory_prompt = memory_prompt
         return self._prior_question_memory_prompt
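The doctest added to async_answer_question above doubles as the smallest usage example. A minimal sketch, assuming InvigilatorAI.example() and answer_question() behave as that doctest shows; the import path follows the file location in this diff:

from edsl.agents.Invigilator import InvigilatorAI

i = InvigilatorAI.example()
print(i.answer_question())
# {'message': '{"answer": "SPAM!"}'}   (per the doctest added above)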
edsl/config.py
CHANGED
@@ -65,6 +65,19 @@ CONFIG_MAP = {
     # "default": None,
     # "info": "This env var holds your Anthropic API key (https://www.anthropic.com/).",
     # },
+    # "GROQ_API_KEY": {
+    # "default": None,
+    # "info": "This env var holds your GROQ API key (https://console.groq.com/login).",
+    # },
+    # "AWS_ACCESS_KEY_ID" :
+    # "default": None,
+    # "info": "This env var holds your AWS access key ID.",
+    # "AWS_SECRET_ACCESS_KEY:
+    # "default": None,
+    # "info": "This env var holds your AWS secret access key.",
+    # "AZURE_ENDPOINT_URL_AND_KEY":
+    # "default": None,
+    # "info": "This env var holds your Azure endpoint URL and key (URL:key). You can have several comma-separated URL-key pairs (URL1:key1,URL2:key2).",
 }
 
 
@@ -109,8 +122,9 @@ class Config:
         """
         # for each env var in the CONFIG_MAP
         for env_var, config in CONFIG_MAP.items():
+            # we've set it already in _set_run_mode
             if env_var == "EDSL_RUN_MODE":
-                continue
+                continue
             value = os.getenv(env_var)
             default_value = config.get("default")
             # if the env var is set, set it as a CONFIG attribute
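The new CONFIG_MAP entries are commented out, but the services added in this release read their credentials straight from the environment. A minimal sketch of setting them from Python; every value below is a placeholder, and the Azure string follows the URL:key format (comma-separated for several pairs) described in the info text above:

import os

os.environ["GROQ_API_KEY"] = "gsk-..."            # Groq
os.environ["AWS_ACCESS_KEY_ID"] = "AKIA..."       # AWS Bedrock
os.environ["AWS_SECRET_ACCESS_KEY"] = "..."       # AWS Bedrock
os.environ["AZURE_ENDPOINT_URL_AND_KEY"] = (
    "https://my-model.eastus2.models.ai.azure.com:my-key"
)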
edsl/conjure/Conjure.py
CHANGED
@@ -35,6 +35,12 @@ class Conjure:
         # The __init__ method in Conjure won't be called because __new__ returns a different class instance.
         pass
 
+    @classmethod
+    def example(cls):
+        from edsl.conjure.InputData import InputDataABC
+
+        return InputDataABC.example()
+
 
 if __name__ == "__main__":
     pass
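A minimal sketch of the new classmethod; it simply defers to InputDataABC.example(), so the returned object is whatever that factory produces (not shown in this diff):

from edsl.conjure.Conjure import Conjure

example_input_data = Conjure.example()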
edsl/coop/coop.py
CHANGED
@@ -465,6 +465,7 @@ class Coop:
         description: Optional[str] = None,
         status: RemoteJobStatus = "queued",
         visibility: Optional[VisibilityType] = "unlisted",
+        iterations: Optional[int] = 1,
     ) -> dict:
         """
         Send a remote inference job to the server.
@@ -473,6 +474,7 @@ class Coop:
         :param optional description: A description for this entry in the remote cache.
         :param status: The status of the job. Should be 'queued', unless you are debugging.
         :param visibility: The visibility of the cache entry.
+        :param iterations: The number of times to run each interview.
 
         >>> job = Jobs.example()
         >>> coop.remote_inference_create(job=job, description="My job")
@@ -488,6 +490,7 @@ class Coop:
             ),
             "description": description,
             "status": status,
+            "iterations": iterations,
             "visibility": visibility,
             "version": self._edsl_version,
         },
@@ -498,6 +501,7 @@ class Coop:
             "uuid": response_json.get("jobs_uuid"),
             "description": response_json.get("description"),
             "status": response_json.get("status"),
+            "iterations": response_json.get("iterations"),
             "visibility": response_json.get("visibility"),
             "version": self._edsl_version,
         }
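A minimal sketch of the new iterations parameter, extending the docstring example above; the Coop and Jobs import paths and constructor arguments are assumptions, not shown in this diff:

from edsl.coop.coop import Coop
from edsl.jobs.Jobs import Jobs

coop = Coop()
job = Jobs.example()
info = coop.remote_inference_create(job=job, description="My job", iterations=3)
print(info["status"], info["iterations"])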
edsl/data/CacheHandler.py
CHANGED
@@ -41,7 +41,7 @@ class CacheHandler:
         old_data = self.from_old_sqlite_cache()
         self.cache.add_from_dict(old_data)
 
-    def create_cache_directory(self) -> None:
+    def create_cache_directory(self, notify=False) -> None:
         """
         Create the cache directory if one is required and it does not exist.
         """
@@ -49,9 +49,8 @@ class CacheHandler:
         dir_path = os.path.dirname(path)
         if dir_path and not os.path.exists(dir_path):
             os.makedirs(dir_path)
-
-
-            warnings.warn(f"Created cache directory: {dir_path}")
+            if notify:
+                print(f"Created cache directory: {dir_path}")
 
     def gen_cache(self) -> Cache:
         """
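A minimal sketch of the changed behaviour: directory creation no longer emits a warning and only prints when notify=True. The CacheHandler() construction details are assumptions, not shown in this diff:

from edsl.data.CacheHandler import CacheHandler

handler = CacheHandler()
handler.create_cache_directory(notify=True)  # prints the path only if it had to create the directory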
edsl/enums.py
CHANGED
@@ -59,6 +59,9 @@ class InferenceServiceType(EnumWithChecks):
     GOOGLE = "google"
     TEST = "test"
     ANTHROPIC = "anthropic"
+    GROQ = "groq"
+    AZURE = "azure"
+    OLLAMA = "ollama"
 
 
 service_to_api_keyname = {
@@ -69,6 +72,8 @@ service_to_api_keyname = {
     InferenceServiceType.GOOGLE.value: "GOOGLE_API_KEY",
     InferenceServiceType.TEST.value: "TBD",
     InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
+    InferenceServiceType.GROQ.value: "GROQ_API_KEY",
+    InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
 }
 
 
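A minimal sketch of how the extended mapping is looked up. Note that it references InferenceServiceType.BEDROCK, which is assumed to exist elsewhere in the enum (this hunk only shows GROQ, AZURE, and OLLAMA being added):

from edsl.enums import InferenceServiceType, service_to_api_keyname

print(service_to_api_keyname[InferenceServiceType.GROQ.value])
# 'GROQ_API_KEY'
# Bedrock maps to a list of required variables:
# ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY']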
edsl/exceptions/general.py
CHANGED
@@ -21,12 +21,14 @@ class GeneralErrors(Exception):
 
 
 class MissingAPIKeyError(GeneralErrors):
-    def __init__(self, model_name, inference_service):
-
-
-
-
-
-
-
+    def __init__(self, full_message=None, model_name=None, inference_service=None):
+        if model_name and inference_service:
+            full_message = dedent(
+                f"""
+                An API Key for model `{model_name}` is missing from the .env file.
+                This key is associated with the inference service `{inference_service}`.
+                Please see https://docs.expectedparrot.com/en/latest/api_keys.html for more information.
+                """
+            )
+
         super().__init__(full_message)
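A minimal sketch of the new keyword style of construction; the model and service names are placeholders:

from edsl.exceptions.general import MissingAPIKeyError

try:
    raise MissingAPIKeyError(model_name="gpt-4o", inference_service="openai")
except MissingAPIKeyError as e:
    print(e)  # dedented message pointing at the api_keys documentation page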
edsl/inference_services/AwsBedrock.py
ADDED
@@ -0,0 +1,110 @@
+import os
+from typing import Any
+import re
+import boto3
+from botocore.exceptions import ClientError
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models.LanguageModel import LanguageModel
+import json
+from edsl.utilities.utilities import fix_partial_correct_response
+
+
+class AwsBedrockService(InferenceServiceABC):
+    """AWS Bedrock service class."""
+
+    _inference_service_ = "bedrock"
+    _env_key_name_ = (
+        "AWS_ACCESS_KEY_ID"  # or any other environment key for AWS credentials
+    )
+
+    @classmethod
+    def available(cls):
+        """Fetch available models from AWS Bedrock."""
+        if not cls._models_list_cache:
+            client = boto3.client("bedrock", region_name="us-west-2")
+            all_models_ids = [
+                x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
+            ]
+        else:
+            all_models_ids = cls._models_list_cache
+
+        return all_models_ids
+
+    @classmethod
+    def create_model(
+        cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with AWS Bedrock models.
+            """
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 512,
+                "top_p": 0.9,
+            }
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str = ""
+            ) -> dict[str, Any]:
+                """Calls the AWS Bedrock API and returns the API response."""
+
+                api_token = (
+                    self.api_token
+                )  # call to check the if env variables are set.
+
+                client = boto3.client("bedrock-runtime", region_name="us-west-2")
+
+                conversation = [
+                    {
+                        "role": "user",
+                        "content": [{"text": user_prompt}],
+                    }
+                ]
+                system = [
+                    {
+                        "text": system_prompt,
+                    }
+                ]
+                try:
+                    response = client.converse(
+                        modelId=self._model_,
+                        messages=conversation,
+                        inferenceConfig={
+                            "maxTokens": self.max_tokens,
+                            "temperature": self.temperature,
+                            "topP": self.top_p,
+                        },
+                        # system=system,
+                        additionalModelRequestFields={},
+                    )
+                    return response
+                except (ClientError, Exception) as e:
+                    print(e)
+                    return {"error": str(e)}
+
+            @staticmethod
+            def parse_response(raw_response: dict[str, Any]) -> str:
+                """Parses the API response and returns the response text."""
+                if "output" in raw_response and "message" in raw_response["output"]:
+                    response = raw_response["output"]["message"]["content"][0]["text"]
+                    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+                    match = re.match(pattern, response, re.DOTALL)
+                    if match:
+                        return match.group(1)
+                    else:
+                        out = fix_partial_correct_response(response)
+                        if "error" not in out:
+                            response = out["extracted_json"]
+                        return response
+                return "Error parsing response"
+
+        LLM.__name__ = model_class_name
+
+        return LLM
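A minimal sketch of using the new service class directly; the model id is the default from create_model, the class name is passed explicitly to avoid assumptions about helpers on the base class, and actual calls require AWS credentials plus Bedrock access in region us-west-2:

from edsl.inference_services.AwsBedrock import AwsBedrockService

TitanLLM = AwsBedrockService.create_model(
    "amazon.titan-tg1-large", model_class_name="TitanLLM"
)
# TitanLLM is a dynamically created LanguageModel subclass; edsl normally
# instantiates it for you when a survey is run with this model.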
edsl/inference_services/AzureAI.py
ADDED
@@ -0,0 +1,197 @@
+import os
+from typing import Any
+import re
+from openai import AsyncAzureOpenAI
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models.LanguageModel import LanguageModel
+
+from azure.ai.inference.aio import ChatCompletionsClient
+from azure.core.credentials import AzureKeyCredential
+from azure.ai.inference.models import SystemMessage, UserMessage
+import asyncio
+import json
+from edsl.utilities.utilities import fix_partial_correct_response
+
+
+def json_handle_none(value: Any) -> Any:
+    """
+    Handle None values during JSON serialization.
+    - Return "null" if the value is None. Otherwise, don't return anything.
+    """
+    if value is None:
+        return "null"
+
+
+class AzureAIService(InferenceServiceABC):
+    """Azure AI service class."""
+
+    _inference_service_ = "azure"
+    _env_key_name_ = (
+        "AZURE_ENDPOINT_URL_AND_KEY"  # Environment variable for Azure API key
+    )
+    _model_id_to_endpoint_and_key = {}
+
+    @classmethod
+    def available(cls):
+        out = []
+        azure_endpoints = os.getenv("AZURE_ENDPOINT_URL_AND_KEY", None)
+        if not azure_endpoints:
+            raise EnvironmentError(f"AZURE_ENDPOINT_URL_AND_KEY is not defined")
+        azure_endpoints = azure_endpoints.split(",")
+        for data in azure_endpoints:
+            try:
+                # data has this format for non openai models https://model_id.azure_endpoint:azure_key
+                _, endpoint, azure_endpoint_key = data.split(":")
+                if "openai" not in endpoint:
+                    model_id = endpoint.split(".")[0].replace("/", "")
+                    out.append(model_id)
+                    cls._model_id_to_endpoint_and_key[model_id] = {
+                        "endpoint": f"https:{endpoint}",
+                        "azure_endpoint_key": azure_endpoint_key,
+                    }
+                else:
+                    # data has this format for openai models ,https://azure_project_id.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview:azure_key
+                    if "/deployments/" in endpoint:
+                        start_idx = endpoint.index("/deployments/") + len(
+                            "/deployments/"
+                        )
+                        end_idx = (
+                            endpoint.index("/", start_idx)
+                            if "/" in endpoint[start_idx:]
+                            else len(endpoint)
+                        )
+                        model_id = endpoint[start_idx:end_idx]
+                        api_version_value = None
+                        if "api-version=" in endpoint:
+                            start_idx = endpoint.index("api-version=") + len(
+                                "api-version="
+                            )
+                            end_idx = (
+                                endpoint.index("&", start_idx)
+                                if "&" in endpoint[start_idx:]
+                                else len(endpoint)
+                            )
+                            api_version_value = endpoint[start_idx:end_idx]
+
+                        cls._model_id_to_endpoint_and_key[f"azure:{model_id}"] = {
+                            "endpoint": f"https:{endpoint}",
+                            "azure_endpoint_key": azure_endpoint_key,
+                            "api_version": api_version_value,
+                        }
+                        out.append(f"azure:{model_id}")
+
+            except Exception as e:
+                raise e
+        return out
+
+    @classmethod
+    def create_model(
+        cls, model_name: str = "azureai", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with Azure OpenAI models.
+            """
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 512,
+                "top_p": 0.9,
+            }
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str = ""
+            ) -> dict[str, Any]:
+                """Calls the Azure OpenAI API and returns the API response."""
+
+                try:
+                    api_key = cls._model_id_to_endpoint_and_key[model_name][
+                        "azure_endpoint_key"
+                    ]
+                except:
+                    api_key = None
+
+                if not api_key:
+                    raise EnvironmentError(
+                        f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
+                    )
+
+                try:
+                    endpoint = cls._model_id_to_endpoint_and_key[model_name]["endpoint"]
+                except:
+                    endpoint = None
+
+                if not endpoint:
+                    raise EnvironmentError(
+                        f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
+                    )
+
+                if "openai" not in endpoint:
+                    client = ChatCompletionsClient(
+                        endpoint=endpoint,
+                        credential=AzureKeyCredential(api_key),
+                        temperature=self.temperature,
+                        top_p=self.top_p,
+                        max_tokens=self.max_tokens,
+                    )
+                    try:
+                        response = await client.complete(
+                            messages=[
+                                SystemMessage(content=system_prompt),
+                                UserMessage(content=user_prompt),
+                            ],
+                            # model_extras={"safe_mode": True},
+                        )
+                        await client.close()
+                        return response.as_dict()
+                    except Exception as e:
+                        await client.close()
+                        return {"error": str(e)}
+                else:
+                    api_version = cls._model_id_to_endpoint_and_key[model_name][
+                        "api_version"
+                    ]
+                    client = AsyncAzureOpenAI(
+                        azure_endpoint=endpoint,
+                        api_version=api_version,
+                        api_key=api_key,
+                    )
+                    response = await client.chat.completions.create(
+                        model=model_name,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": user_prompt,  # Your question can go here
+                            },
+                        ],
+                    )
+                    return response.model_dump()
+
+            @staticmethod
+            def parse_response(raw_response: dict[str, Any]) -> str:
+                """Parses the API response and returns the response text."""
+                if (
+                    raw_response
+                    and "choices" in raw_response
+                    and raw_response["choices"]
+                ):
+                    response = raw_response["choices"][0]["message"]["content"]
+                    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+                    match = re.match(pattern, response, re.DOTALL)
+                    if match:
+                        return match.group(1)
+                    else:
+                        out = fix_partial_correct_response(response)
+                        if "error" not in out:
+                            response = out["extracted_json"]
+                        return response
+                return "Error parsing response"
+
+        LLM.__name__ = model_class_name
+
+        return LLM
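A minimal sketch of how available() parses the endpoint string; the URLs and keys are placeholders following the two formats described in the comments above (a serverless endpoint and an Azure OpenAI deployment):

import os
from edsl.inference_services.AzureAI import AzureAIService

os.environ["AZURE_ENDPOINT_URL_AND_KEY"] = (
    "https://my-model.eastus2.models.ai.azure.com:my-key,"
    "https://my-project.openai.azure.com/openai/deployments/gpt-4o-mini"
    "/chat/completions?api-version=2023-03-15-preview:my-openai-key"
)
print(AzureAIService.available())
# ['my-model', 'azure:gpt-4o-mini']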
edsl/inference_services/DeepInfraService.py
CHANGED
@@ -2,102 +2,17 @@ import aiohttp
 import json
 import requests
 from typing import Any, List
-from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models import LanguageModel
 
+from edsl.inference_services.OpenAIService import OpenAIService
+
 
-class DeepInfraService(InferenceServiceABC):
+class DeepInfraService(OpenAIService):
     """DeepInfra service class."""
 
     _inference_service_ = "deep_infra"
     _env_key_name_ = "DEEP_INFRA_API_KEY"
-
+    _base_url_ = "https://api.deepinfra.com/v1/openai"
     _models_list_cache: List[str] = []
-
-    @classmethod
-    def available(cls):
-        text_models = cls.full_details_available()
-        return [m["model_name"] for m in text_models]
-
-    @classmethod
-    def full_details_available(cls, verbose=False):
-        if not cls._models_list_cache:
-            url = "https://api.deepinfra.com/models/list"
-            response = requests.get(url)
-            if response.status_code == 200:
-                text_generation_models = [
-                    r for r in response.json() if r["type"] == "text-generation"
-                ]
-                cls._models_list_cache = text_generation_models
-
-                from rich import print_json
-                import json
-
-                if verbose:
-                    print_json(json.dumps(text_generation_models))
-                return text_generation_models
-            else:
-                return f"Failed to fetch data: Status code {response.status_code}"
-        else:
-            return cls._models_list_cache
-
-    @classmethod
-    def create_model(cls, model_name: str, model_class_name=None) -> LanguageModel:
-        base_url = "https://api.deepinfra.com/v1/inference/"
-        if model_class_name is None:
-            model_class_name = cls.to_class_name(model_name)
-        url = f"{base_url}{model_name}"
-
-        class LLM(LanguageModel):
-            _inference_service_ = cls._inference_service_
-            _model_ = model_name
-            _parameters_ = {
-                "temperature": 0.7,
-                "top_p": 0.2,
-                "top_k": 0.1,
-                "max_new_tokens": 512,
-                "stopSequences": [],
-            }
-
-            async def async_execute_model_call(
-                self, user_prompt: str, system_prompt: str = ""
-            ) -> dict[str, Any]:
-                self.url = url
-                headers = {
-                    "Content-Type": "application/json",
-                    "Authorization": f"bearer {self.api_token}",
-                }
-                # don't mess w/ the newlines
-                data = {
-                    "input": f"""
-                [INST]<<SYS>>
-                {system_prompt}
-                <<SYS>>{user_prompt}[/INST]
-                """,
-                    "stream": False,
-                    "temperature": self.temperature,
-                    "top_p": self.top_p,
-                    "top_k": self.top_k,
-                    "max_new_tokens": self.max_new_tokens,
-                }
-                async with aiohttp.ClientSession() as session:
-                    async with session.post(
-                        self.url, headers=headers, data=json.dumps(data)
-                    ) as response:
-                        raw_response_text = await response.text()
-                        return json.loads(raw_response_text)
-
-            def parse_response(self, raw_response: dict[str, Any]) -> str:
-                if "results" not in raw_response:
-                    raise Exception(
-                        f"Deep Infra response does not contain 'results' key: {raw_response}"
-                    )
-                if "generated_text" not in raw_response["results"][0]:
-                    raise Exception(
-                        f"Deep Infra response does not contain 'generate_text' key: {raw_response['results'][0]}"
-                    )
-                return raw_response["results"][0]["generated_text"]
-
-        LLM.__name__ = model_class_name
-
-        return LLM
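The refactor above replaces roughly ninety lines of bespoke HTTP code with inheritance from OpenAIService: DeepInfra now only overrides the service name, API-key variable, and OpenAI-compatible base URL. A minimal sketch, assuming create_model is inherited unchanged from OpenAIService and using a placeholder model name:

from edsl.inference_services.DeepInfraService import DeepInfraService

print(DeepInfraService._base_url_)   # https://api.deepinfra.com/v1/openai
LlamaLLM = DeepInfraService.create_model(
    "meta-llama/Meta-Llama-3-8B-Instruct", model_class_name="LlamaLLM"
)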