edsl 0.1.31.dev4__py3-none-any.whl → 0.1.33__py3-none-any.whl
This diff compares the contents of two package versions as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- edsl/Base.py +9 -3
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -3
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +40 -8
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +136 -221
- edsl/agents/InvigilatorBase.py +148 -59
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +154 -85
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +48 -47
- edsl/conjure/Conjure.py +6 -0
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +50 -7
- edsl/data/Cache.py +35 -1
- edsl/data/CacheHandler.py +3 -4
- edsl/data_transfer_models.py +73 -38
- edsl/enums.py +8 -0
- edsl/exceptions/general.py +10 -8
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +112 -0
- edsl/inference_services/AzureAI.py +214 -0
- edsl/inference_services/DeepInfraService.py +4 -3
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +5 -4
- edsl/inference_services/InferenceServiceABC.py +58 -3
- edsl/inference_services/InferenceServicesCollection.py +13 -8
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +55 -56
- edsl/inference_services/TestService.py +80 -0
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/models_available_cache.py +25 -0
- edsl/inference_services/registry.py +19 -1
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +137 -41
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +105 -18
- edsl/jobs/interviews/Interview.py +393 -83
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +22 -18
- edsl/jobs/interviews/InterviewExceptionEntry.py +167 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +152 -160
- edsl/jobs/runners/JobsRunnerStatus.py +331 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskCreators.py +1 -1
- edsl/jobs/tasks/TaskHistory.py +205 -126
- edsl/language_models/LanguageModel.py +297 -177
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +25 -8
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/notebooks/Notebook.py +20 -2
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +330 -249
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +99 -42
- edsl/questions/QuestionCheckBox.py +227 -36
- edsl/questions/QuestionExtract.py +98 -28
- edsl/questions/QuestionFreeText.py +47 -31
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -23
- edsl/questions/QuestionMultipleChoice.py +159 -66
- edsl/questions/QuestionNumerical.py +88 -47
- edsl/questions/QuestionRank.py +182 -25
- edsl/questions/Quick.py +41 -0
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +170 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +15 -2
- edsl/questions/derived/QuestionTopK.py +10 -1
- edsl/questions/derived/QuestionYesNo.py +24 -3
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +58 -30
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +135 -46
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/FileStore.py +71 -10
- edsl/scenarios/Scenario.py +109 -24
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +546 -21
- edsl/scenarios/ScenarioListExportMixin.py +24 -4
- edsl/scenarios/ScenarioListPdfMixin.py +153 -4
- edsl/study/SnapShot.py +8 -1
- edsl/study/Study.py +32 -0
- edsl/surveys/Rule.py +15 -3
- edsl/surveys/RuleCollection.py +21 -5
- edsl/surveys/Survey.py +707 -298
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +284 -0
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/utilities/utilities.py +40 -1
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/METADATA +8 -2
- edsl-0.1.33.dist-info/RECORD +295 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -271
- edsl/jobs/interviews/retry_management.py +0 -37
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -303
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.31.dev4.dist-info/RECORD +0 -204
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
edsl/inference_services/MistralAIService.py ADDED
@@ -0,0 +1,120 @@
+import os
+from typing import Any, List
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models.LanguageModel import LanguageModel
+import asyncio
+from mistralai import Mistral
+
+from edsl.exceptions.language_models import LanguageModelBadResponseError
+
+
+class MistralAIService(InferenceServiceABC):
+    """Mistral AI service class."""
+
+    key_sequence = ["choices", 0, "message", "content"]
+    usage_sequence = ["usage"]
+
+    _inference_service_ = "mistral"
+    _env_key_name_ = "MISTRAL_API_KEY"  # Environment variable for Mistral API key
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    _sync_client_instance = None
+    _async_client_instance = None
+
+    _sync_client = Mistral
+    _async_client = Mistral
+
+    _models_list_cache: List[str] = []
+    model_exclude_list = []
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        # so subclasses have to create their own instances of the clients
+        cls._sync_client_instance = None
+        cls._async_client_instance = None
+
+    @classmethod
+    def sync_client(cls):
+        if cls._sync_client_instance is None:
+            cls._sync_client_instance = cls._sync_client(
+                api_key=os.getenv(cls._env_key_name_)
+            )
+        return cls._sync_client_instance
+
+    @classmethod
+    def async_client(cls):
+        if cls._async_client_instance is None:
+            cls._async_client_instance = cls._async_client(
+                api_key=os.getenv(cls._env_key_name_)
+            )
+        return cls._async_client_instance
+
+    @classmethod
+    def available(cls) -> list[str]:
+        if not cls._models_list_cache:
+            cls._models_list_cache = [
+                m.id for m in cls.sync_client().models.list().data
+            ]
+
+        return cls._models_list_cache
+
+    @classmethod
+    def create_model(
+        cls, model_name: str = "mistral", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with Mistral models.
+            """
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 512,
+                "top_p": 0.9,
+            }
+
+            _tpm = cls.get_tpm(cls)
+            _rpm = cls.get_rpm(cls)
+
+            def sync_client(self):
+                return cls.sync_client()
+
+            def async_client(self):
+                return cls.async_client()
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str = ""
+            ) -> dict[str, Any]:
+                """Calls the Mistral API and returns the API response."""
+                s = self.async_client()
+
+                try:
+                    res = await s.chat.complete_async(
+                        model=model_name,
+                        messages=[
+                            {
+                                "content": user_prompt,
+                                "role": "user",
+                            },
+                        ],
+                    )
+                except Exception as e:
+                    raise LanguageModelBadResponseError(f"Error with Mistral API: {e}")

+                return res.model_dump()
+
+        LLM.__name__ = model_class_name
+
+        return LLM
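A detail worth flagging in the new service: clients are cached at class level, and `__init_subclass__` resets the cache so a subclass never reuses its parent's client. A minimal standalone sketch of that pattern (hypothetical `ServiceBase`/`ChildService` names, not edsl classes):

# Minimal standalone sketch of the class-level client caching used above.
# ServiceBase and ChildService are hypothetical stand-ins, not edsl classes.
class ServiceBase:
    _client_instance = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._client_instance = None  # force each subclass to build its own client

    @classmethod
    def client(cls):
        if cls._client_instance is None:
            cls._client_instance = object()  # stand-in for Mistral(api_key=...)
        return cls._client_instance


class ChildService(ServiceBase):
    pass


assert ServiceBase.client() is ServiceBase.client()       # cached per class
assert ChildService.client() is not ServiceBase.client()  # not inherited

Without the reset in `__init_subclass__`, a subclass whose parent had already built a client would read the inherited non-None attribute and silently share the parent's instance.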
edsl/inference_services/OllamaService.py ADDED
@@ -0,0 +1,18 @@
+import aiohttp
+import json
+import requests
+from typing import Any, List
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+
+from edsl.inference_services.OpenAIService import OpenAIService
+
+
+class OllamaService(OpenAIService):
+    """DeepInfra service class."""
+
+    _inference_service_ = "ollama"
+    _env_key_name_ = "DEEP_INFRA_API_KEY"
+    _base_url_ = "http://localhost:11434/v1"
+    _models_list_cache: List[str] = []
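Context for the constants above: Ollama exposes an OpenAI-compatible API on localhost port 11434, which is why the service can subclass `OpenAIService` and only override `_base_url_`. A sketch of what that implies, assuming a locally running Ollama server (Ollama does not check the API key, but the client requires one):

# Hypothetical sketch: any OpenAI-compatible client pointed at the base URL
# above reaches a local Ollama server. Requires `ollama serve` to be running.
import openai

client = openai.OpenAI(
    base_url="http://localhost:11434/v1",
    api_key="unused",  # Ollama ignores the key, but the client requires one
)
# print([m.id for m in client.models.list()])  # lists locally pulled models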
edsl/inference_services/OpenAIService.py CHANGED
@@ -1,12 +1,15 @@
-from
-import
+from __future__ import annotations
+from typing import Any, List, Optional
 import os
-
+
 import openai
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models import LanguageModel
 from edsl.inference_services.rate_limits_cache import rate_limits
+from edsl.utilities.utilities import fix_partial_correct_response
+
+from edsl.config import CONFIG
 
 
 class OpenAIService(InferenceServiceABC):
@@ -18,20 +21,37 @@ class OpenAIService(InferenceServiceABC):
 
     _sync_client_ = openai.OpenAI
     _async_client_ = openai.AsyncOpenAI
-
+
+    _sync_client_instance = None
+    _async_client_instance = None
+
+    key_sequence = ["choices", 0, "message", "content"]
+    usage_sequence = ["usage"]
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        # so subclasses have to create their own instances of the clients
+        cls._sync_client_instance = None
+        cls._async_client_instance = None
+
     @classmethod
     def sync_client(cls):
-
-
-
-
+        if cls._sync_client_instance is None:
+            cls._sync_client_instance = cls._sync_client_(
+                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+            )
+        return cls._sync_client_instance
+
     @classmethod
     def async_client(cls):
-
-
-
+        if cls._async_client_instance is None:
+            cls._async_client_instance = cls._async_client_(
+                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+            )
+        return cls._async_client_instance
 
-    # TODO: Make this a coop call
     model_exclude_list = [
         "whisper-1",
         "davinci-002",
@@ -46,6 +66,8 @@ class OpenAIService(InferenceServiceABC):
         "text-embedding-3-small",
         "text-embedding-ada-002",
         "ft:davinci-002:mit-horton-lab::8OfuHgoo",
+        "gpt-3.5-turbo-instruct-0914",
+        "gpt-3.5-turbo-instruct",
     ]
     _models_list_cache: List[str] = []
 
@@ -59,27 +81,15 @@ class OpenAIService(InferenceServiceABC):
 
     @classmethod
     def available(cls) -> List[str]:
-        # from openai import OpenAI
-
        if not cls._models_list_cache:
             try:
-                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                 cls._models_list_cache = [
                     m.id
-                    for m in cls.get_model_list()
+                    for m in cls.get_model_list()
                     if m.id not in cls.model_exclude_list
                 ]
             except Exception as e:
                 raise
-                # print(
-                #     f"""Error retrieving models: {e}.
-                #     See instructions about storing your API keys: https://docs.expectedparrot.com/en/latest/api_keys.html"""
-                # )
-                # cls._models_list_cache = [
-                #     "gpt-3.5-turbo",
-                #     "gpt-4-1106-preview",
-                #     "gpt-4",
-                # ] # Fallback list
         return cls._models_list_cache
 
     @classmethod
@@ -92,6 +102,14 @@ class OpenAIService(InferenceServiceABC):
             Child class of LanguageModel for interacting with OpenAI models
             """
 
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
+            _rpm = cls.get_rpm(cls)
+            _tpm = cls.get_tpm(cls)
+
             _inference_service_ = cls._inference_service_
             _model_ = model_name
             _parameters_ = {
@@ -106,21 +124,15 @@ class OpenAIService(InferenceServiceABC):
 
             def sync_client(self):
                 return cls.sync_client()
-
+
             def async_client(self):
                 return cls.async_client()
 
             @classmethod
             def available(cls) -> list[str]:
-                # import openai
-                # client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
-                # return client.models.list()
                 return cls.sync_client().models.list()
-
-            def get_headers(self) -> dict[str, Any]:
-                # from openai import OpenAI
 
-
+            def get_headers(self) -> dict[str, Any]:
                 client = self.sync_client()
                 response = client.chat.completions.with_raw_response.create(
                     messages=[
@@ -157,6 +169,9 @@ class OpenAIService(InferenceServiceABC):
                 user_prompt: str,
                 system_prompt: str = "",
                 encoded_image=None,
+                invigilator: Optional[
+                    "InvigilatorAI"
+                ] = None,  # TBD - can eventually be used for function-calling
             ) -> dict[str, Any]:
                 """Calls the OpenAI API and returns the API response."""
                 if encoded_image:
@@ -171,17 +186,16 @@ class OpenAIService(InferenceServiceABC):
                     )
                 else:
                     content = user_prompt
-                # self.client = AsyncOpenAI(
-                #     api_key = os.getenv(cls._env_key_name_),
-                #     base_url = cls._base_url_
-                # )
                 client = self.async_client()
+                messages = [
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": content},
+                ]
+                if system_prompt == "" and self.omit_system_prompt_if_empty:
+                    messages = messages[1:]
                 params = {
                     "model": self.model,
-                    "messages": [
-                        {"role": "system", "content": system_prompt},
-                        {"role": "user", "content": content},
-                    ],
+                    "messages": messages,
                     "temperature": self.temperature,
                     "max_tokens": self.max_tokens,
                     "top_p": self.top_p,
@@ -193,21 +207,6 @@ class OpenAIService(InferenceServiceABC):
                 response = await client.chat.completions.create(**params)
                 return response.model_dump()
 
-            @staticmethod
-            def parse_response(raw_response: dict[str, Any]) -> str:
-                """Parses the API response and returns the response text."""
-                try:
-                    response = raw_response["choices"][0]["message"]["content"]
-                except KeyError:
-                    print("Tried to parse response but failed:")
-                    print(raw_response)
-                pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
-                match = re.match(pattern, response, re.DOTALL)
-                if match:
-                    return match.group(1)
-                else:
-                    return response
-
         LLM.__name__ = "LanguageModel"
 
         return LLM
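The behavioral change in `async_execute_model_call` above is that the system message is now dropped when `system_prompt` is empty and the model opts in via `omit_system_prompt_if_empty`. A standalone sketch of that logic (the `build_messages` helper is illustrative, not a package function):

# Illustrative helper mirroring the message-building change above.
def build_messages(
    user_content: str,
    system_prompt: str = "",
    omit_system_prompt_if_empty: bool = True,
) -> list[dict]:
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]
    if system_prompt == "" and omit_system_prompt_if_empty:
        messages = messages[1:]  # drop the empty system message
    return messages


assert build_messages("Hi") == [{"role": "user", "content": "Hi"}]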
edsl/inference_services/TestService.py ADDED
@@ -0,0 +1,80 @@
+from typing import Any, List
+import os
+import asyncio
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+from edsl.inference_services.rate_limits_cache import rate_limits
+from edsl.utilities.utilities import fix_partial_correct_response
+
+from edsl.enums import InferenceServiceType
+import random
+
+
+class TestService(InferenceServiceABC):
+    """OpenAI service class."""
+
+    _inference_service_ = "test"
+    _env_key_name_ = None
+    _base_url_ = None
+
+    _sync_client_ = None
+    _async_client_ = None
+
+    _sync_client_instance = None
+    _async_client_instance = None
+
+    key_sequence = None
+    usage_sequence = None
+    model_exclude_list = []
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    @classmethod
+    def available(cls) -> list[str]:
+        return ["test"]
+
+    @classmethod
+    def create_model(cls, model_name, model_class_name=None) -> LanguageModel:
+        throw_exception = False
+
+        class TestServiceLanguageModel(LanguageModel):
+            _model_ = "test"
+            _parameters_ = {"temperature": 0.5}
+            _inference_service_ = InferenceServiceType.TEST.value
+            usage_sequence = ["usage"]
+            key_sequence = ["message", 0, "text"]
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+            _rpm = 1000
+            _tpm = 100000
+
+            @property
+            def _canned_response(self):
+                if hasattr(self, "canned_response"):
+                    return self.canned_response
+                else:
+                    return "Hello, world"
+
+            async def async_execute_model_call(
+                self,
+                user_prompt: str,
+                system_prompt: str,
+                encoded_image=None,
+            ) -> dict[str, Any]:
+                await asyncio.sleep(0.1)
+                # return {"message": """{"answer": "Hello, world"}"""}
+
+                if hasattr(self, "throw_exception") and self.throw_exception:
+                    if hasattr(self, "exception_probability"):
+                        p = self.exception_probability
+                    else:
+                        p = 1
+
+                    if random.random() < p:
+                        raise Exception("This is a test error")
+                return {
+                    "message": [{"text": f"{self._canned_response}"}],
+                    "usage": {"prompt_tokens": 1, "completion_tokens": 1},
+                }
+
+        return TestServiceLanguageModel
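The `key_sequence`/`usage_sequence` attributes used across these services describe a path into the raw API response; the code that walks them lives in `LanguageModel` and is not part of this diff. A sketch of how such a path would plausibly be walked (the `extract` helper is an assumption, not the package's function):

# Sketch of path-based extraction: index dicts by key and lists by position.
from typing import Any, Sequence

def extract(response: Any, key_sequence: Sequence) -> Any:
    for key in key_sequence:
        response = response[key]  # works for both dict keys and list indices
    return response

raw = {"message": [{"text": "Hello, world"}],
       "usage": {"prompt_tokens": 1, "completion_tokens": 1}}
assert extract(raw, ["message", 0, "text"]) == "Hello, world"  # key_sequence
assert extract(raw, ["usage"])["prompt_tokens"] == 1           # usage_sequence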
edsl/inference_services/TogetherAIService.py ADDED
@@ -0,0 +1,170 @@
+import aiohttp
+import json
+import requests
+from typing import Any, List
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+
+from edsl.inference_services.OpenAIService import OpenAIService
+import openai
+
+
+class TogetherAIService(OpenAIService):
+    """DeepInfra service class."""
+
+    _inference_service_ = "together"
+    _env_key_name_ = "TOGETHER_API_KEY"
+    _base_url_ = "https://api.together.xyz/v1"
+    _models_list_cache: List[str] = []
+
+    # These are non-serverless models. There was no api param to filter them
+    model_exclude_list = [
+        "EleutherAI/llemma_7b",
+        "HuggingFaceH4/zephyr-7b-beta",
+        "Nexusflow/NexusRaven-V2-13B",
+        "NousResearch/Hermes-2-Theta-Llama-3-70B",
+        "NousResearch/Nous-Capybara-7B-V1p9",
+        "NousResearch/Nous-Hermes-13b",
+        "NousResearch/Nous-Hermes-2-Mistral-7B-DPO",
+        "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT",
+        "NousResearch/Nous-Hermes-Llama2-13b",
+        "NousResearch/Nous-Hermes-Llama2-70b",
+        "NousResearch/Nous-Hermes-llama-2-7b",
+        "NumbersStation/nsql-llama-2-7B",
+        "Open-Orca/Mistral-7B-OpenOrca",
+        "Phind/Phind-CodeLlama-34B-Python-v1",
+        "Phind/Phind-CodeLlama-34B-v2",
+        "Qwen/Qwen1.5-0.5B",
+        "Qwen/Qwen1.5-0.5B-Chat",
+        "Qwen/Qwen1.5-1.8B",
+        "Qwen/Qwen1.5-1.8B-Chat",
+        "Qwen/Qwen1.5-14B",
+        "Qwen/Qwen1.5-14B-Chat",
+        "Qwen/Qwen1.5-32B",
+        "Qwen/Qwen1.5-32B-Chat",
+        "Qwen/Qwen1.5-4B",
+        "Qwen/Qwen1.5-4B-Chat",
+        "Qwen/Qwen1.5-72B",
+        "Qwen/Qwen1.5-7B",
+        "Qwen/Qwen1.5-7B-Chat",
+        "Qwen/Qwen2-1.5B",
+        "Qwen/Qwen2-1.5B-Instruct",
+        "Qwen/Qwen2-72B",
+        "Qwen/Qwen2-7B",
+        "Qwen/Qwen2-7B-Instruct",
+        "SG161222/Realistic_Vision_V3.0_VAE",
+        "Snowflake/snowflake-arctic-instruct",
+        "Undi95/ReMM-SLERP-L2-13B",
+        "Undi95/Toppy-M-7B",
+        "WizardLM/WizardCoder-Python-34B-V1.0",
+        "WizardLM/WizardLM-13B-V1.2",
+        "WizardLM/WizardLM-70B-V1.0",
+        "allenai/OLMo-7B",
+        "allenai/OLMo-7B-Instruct",
+        "bert-base-uncased",
+        "codellama/CodeLlama-13b-Instruct-hf",
+        "codellama/CodeLlama-13b-Python-hf",
+        "codellama/CodeLlama-13b-hf",
+        "codellama/CodeLlama-34b-Python-hf",
+        "codellama/CodeLlama-34b-hf",
+        "codellama/CodeLlama-70b-Instruct-hf",
+        "codellama/CodeLlama-70b-Python-hf",
+        "codellama/CodeLlama-70b-hf",
+        "codellama/CodeLlama-7b-Instruct-hf",
+        "codellama/CodeLlama-7b-Python-hf",
+        "codellama/CodeLlama-7b-hf",
+        "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
+        "deepseek-ai/deepseek-coder-33b-instruct",
+        "garage-bAInd/Platypus2-70B-instruct",
+        "google/gemma-2b",
+        "google/gemma-7b",
+        "google/gemma-7b-it",
+        "gradientai/Llama-3-70B-Instruct-Gradient-1048k",
+        "hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1",
+        "huggyllama/llama-13b",
+        "huggyllama/llama-30b",
+        "huggyllama/llama-65b",
+        "huggyllama/llama-7b",
+        "lmsys/vicuna-13b-v1.3",
+        "lmsys/vicuna-13b-v1.5",
+        "lmsys/vicuna-13b-v1.5-16k",
+        "lmsys/vicuna-7b-v1.3",
+        "lmsys/vicuna-7b-v1.5",
+        "meta-llama/Llama-2-13b-hf",
+        "meta-llama/Llama-2-70b-chat-hf",
+        "meta-llama/Llama-2-7b-hf",
+        "meta-llama/Llama-3-70b-hf",
+        "meta-llama/Llama-3-8b-hf",
+        "meta-llama/Meta-Llama-3-70B",
+        "meta-llama/Meta-Llama-3-70B-Instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Meta-Llama-3.1-70B-Instruct-Reference",
+        "meta-llama/Meta-Llama-3.1-70B-Reference",
+        "meta-llama/Meta-Llama-3.1-8B-Reference",
+        "microsoft/phi-2",
+        "mistralai/Mixtral-8x22B",
+        "openchat/openchat-3.5-1210",
+        "prompthero/openjourney",
+        "runwayml/stable-diffusion-v1-5",
+        "sentence-transformers/msmarco-bert-base-dot-v5",
+        "snorkelai/Snorkel-Mistral-PairRM-DPO",
+        "stabilityai/stable-diffusion-2-1",
+        "teknium/OpenHermes-2-Mistral-7B",
+        "teknium/OpenHermes-2p5-Mistral-7B",
+        "togethercomputer/CodeLlama-13b-Instruct",
+        "togethercomputer/CodeLlama-13b-Python",
+        "togethercomputer/CodeLlama-34b",
+        "togethercomputer/CodeLlama-34b-Python",
+        "togethercomputer/CodeLlama-7b-Instruct",
+        "togethercomputer/CodeLlama-7b-Python",
+        "togethercomputer/Koala-13B",
+        "togethercomputer/Koala-7B",
+        "togethercomputer/LLaMA-2-7B-32K",
+        "togethercomputer/SOLAR-10.7B-Instruct-v1.0-int4",
+        "togethercomputer/StripedHyena-Hessian-7B",
+        "togethercomputer/alpaca-7b",
+        "togethercomputer/evo-1-131k-base",
+        "togethercomputer/evo-1-8k-base",
+        "togethercomputer/guanaco-13b",
+        "togethercomputer/guanaco-33b",
+        "togethercomputer/guanaco-65b",
+        "togethercomputer/guanaco-7b",
+        "togethercomputer/llama-2-13b",
+        "togethercomputer/llama-2-70b-chat",
+        "togethercomputer/llama-2-7b",
+        "wavymulder/Analog-Diffusion",
+        "zero-one-ai/Yi-34B",
+        "zero-one-ai/Yi-34B-Chat",
+        "zero-one-ai/Yi-6B",
+    ]
+
+    _sync_client_ = openai.OpenAI
+    _async_client_ = openai.AsyncOpenAI
+
+    @classmethod
+    def get_model_list(cls):
+        # Togheter.ai has a different response in model list then openai
+        # and the OpenAI class returns an error when calling .models.list()
+        import requests
+        import os
+
+        url = "https://api.together.xyz/v1/models?filter=serverless"
+        token = os.getenv(cls._env_key_name_)
+        headers = {"accept": "application/json", "authorization": f"Bearer {token}"}
+
+        response = requests.get(url, headers=headers)
+        return response.json()
+
+    @classmethod
+    def available(cls) -> List[str]:
+        if not cls._models_list_cache:
+            try:
+                cls._models_list_cache = [
+                    m["id"]
+                    for m in cls.get_model_list()
+                    if m["id"] not in cls.model_exclude_list
+                ]
+            except Exception as e:
+                raise
+        return cls._models_list_cache
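The `get_model_list` override above bypasses the OpenAI client because Together's `/models` response has a different shape (a plain list of dicts rather than an OpenAI model page). The same call as a standalone function, with the URL and headers copied from the diff (assumes `TOGETHER_API_KEY` is set in the environment):

import os
import requests

def list_serverless_models() -> list[str]:
    """Fetch Together's serverless model ids, as get_model_list above does."""
    url = "https://api.together.xyz/v1/models?filter=serverless"
    headers = {
        "accept": "application/json",
        "authorization": f"Bearer {os.getenv('TOGETHER_API_KEY')}",
    }
    return [m["id"] for m in requests.get(url, headers=headers).json()]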
edsl/inference_services/models_available_cache.py CHANGED
@@ -66,4 +66,29 @@ models_available = {
         "openchat/openchat_3.5",
     ],
     "google": ["gemini-pro"],
+    "bedrock": [
+        "amazon.titan-tg1-large",
+        "amazon.titan-text-lite-v1",
+        "amazon.titan-text-express-v1",
+        "anthropic.claude-instant-v1",
+        "anthropic.claude-v2:1",
+        "anthropic.claude-v2",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-3-haiku-20240307-v1:0",
+        "anthropic.claude-3-opus-20240229-v1:0",
+        "anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "cohere.command-text-v14",
+        "cohere.command-r-v1:0",
+        "cohere.command-r-plus-v1:0",
+        "cohere.command-light-text-v14",
+        "meta.llama3-8b-instruct-v1:0",
+        "meta.llama3-70b-instruct-v1:0",
+        "meta.llama3-1-8b-instruct-v1:0",
+        "meta.llama3-1-70b-instruct-v1:0",
+        "meta.llama3-1-405b-instruct-v1:0",
+        "mistral.mistral-7b-instruct-v0:2",
+        "mistral.mixtral-8x7b-instruct-v0:1",
+        "mistral.mistral-large-2402-v1:0",
+        "mistral.mistral-large-2407-v1:0",
+    ],
 }
edsl/inference_services/registry.py CHANGED
@@ -7,7 +7,25 @@ from edsl.inference_services.AnthropicService import AnthropicService
 from edsl.inference_services.DeepInfraService import DeepInfraService
 from edsl.inference_services.GoogleService import GoogleService
 from edsl.inference_services.GroqService import GroqService
+from edsl.inference_services.AwsBedrock import AwsBedrockService
+from edsl.inference_services.AzureAI import AzureAIService
+from edsl.inference_services.OllamaService import OllamaService
+from edsl.inference_services.TestService import TestService
+from edsl.inference_services.MistralAIService import MistralAIService
+from edsl.inference_services.TogetherAIService import TogetherAIService
 
 default = InferenceServicesCollection(
-    [
+    [
+        OpenAIService,
+        AnthropicService,
+        DeepInfraService,
+        GoogleService,
+        GroqService,
+        AwsBedrockService,
+        AzureAIService,
+        OllamaService,
+        TestService,
+        MistralAIService,
+        TogetherAIService,
+    ]
 )
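How `InferenceServicesCollection` consumes this list is not shown in this diff. A purely hypothetical sketch of the plausible lookup, keyed by each service's `_inference_service_` name (not the package's actual class):

# Hypothetical sketch of a services collection keyed by service name.
class ServicesCollectionSketch:
    def __init__(self, services: list):
        self._by_name = {s._inference_service_: s for s in services}

    def service(self, name: str):
        return self._by_name[name]  # e.g. "mistral" -> MistralAIService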
edsl/jobs/Answers.py
CHANGED
@@ -2,24 +2,22 @@
 
 from collections import UserDict
 from rich.table import Table
+from edsl.data_transfer_models import EDSLResultObjectInput
 
 
 class Answers(UserDict):
     """Helper class to hold the answers to a survey."""
 
-    def add_answer(
-
-
-
-
-
-
-        >>> answers[q.question_name]
-        'yes'
-        """
-        answer = response.get("answer")
-        comment = response.pop("comment", None)
+    def add_answer(
+        self, response: EDSLResultObjectInput, question: "QuestionBase"
+    ) -> None:
+        """Add a response to the answers dictionary."""
+        answer = response.answer
+        comment = response.comment
+        generated_tokens = response.generated_tokens
         # record the answer
+        if generated_tokens:
+            self[question.question_name + "_generated_tokens"] = generated_tokens
         self[question.question_name] = answer
         if comment:
             self[question.question_name + "_comment"] = comment
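The rewrite above replaces dict-style access (`response.get("answer")`) with attribute access on the new `EDSLResultObjectInput` transfer object and additionally records generated tokens under a `*_generated_tokens` key. A runnable sketch of the resulting behavior (the stand-in types are hypothetical; the real method takes a `QuestionBase`, not a question name):

# Runnable sketch of the new add_answer behavior with stand-in types.
from collections import UserDict
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeResponse:  # stand-in for EDSLResultObjectInput
    answer: str
    comment: Optional[str] = None
    generated_tokens: Optional[str] = None

class AnswersSketch(UserDict):
    def add_answer(self, response: FakeResponse, question_name: str) -> None:
        if response.generated_tokens:
            self[question_name + "_generated_tokens"] = response.generated_tokens
        self[question_name] = response.answer
        if response.comment:
            self[question_name + "_comment"] = response.comment

a = AnswersSketch()
a.add_answer(FakeResponse("yes", comment="Sure."), "q0")
assert a["q0"] == "yes" and a["q0_comment"] == "Sure."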