edsl 0.1.33.dev1__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -4
- edsl/agents/Agent.py +46 -14
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +125 -212
- edsl/agents/InvigilatorBase.py +140 -32
- edsl/agents/PromptConstructionMixin.py +43 -66
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +38 -39
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +39 -5
- edsl/data/Cache.py +35 -1
- edsl/data_transfer_models.py +120 -38
- edsl/enums.py +2 -0
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +19 -17
- edsl/inference_services/AzureAI.py +37 -20
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +2 -0
- edsl/inference_services/InferenceServiceABC.py +24 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OpenAIService.py +41 -50
- edsl/inference_services/TestService.py +71 -0
- edsl/inference_services/models_available_cache.py +0 -6
- edsl/inference_services/registry.py +4 -0
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +18 -13
- edsl/jobs/buckets/TokenBucket.py +39 -14
- edsl/jobs/interviews/Interview.py +297 -77
- edsl/jobs/interviews/InterviewExceptionEntry.py +83 -19
- edsl/jobs/interviews/interview_exception_tracking.py +0 -70
- edsl/jobs/interviews/retry_management.py +3 -1
- edsl/jobs/runners/JobsRunnerAsyncio.py +116 -70
- edsl/jobs/runners/JobsRunnerStatusMixin.py +1 -1
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskHistory.py +131 -213
- edsl/language_models/LanguageModel.py +239 -129
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +15 -2
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +273 -242
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +6 -0
- edsl/questions/QuestionCheckBox.py +227 -35
- edsl/questions/QuestionExtract.py +98 -27
- edsl/questions/QuestionFreeText.py +46 -29
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -22
- edsl/questions/QuestionMultipleChoice.py +173 -64
- edsl/questions/QuestionNumerical.py +87 -46
- edsl/questions/QuestionRank.py +182 -24
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +169 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +11 -1
- edsl/questions/derived/QuestionTopK.py +6 -0
- edsl/questions/derived/QuestionYesNo.py +16 -1
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +41 -47
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +131 -45
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/Scenario.py +10 -4
- edsl/scenarios/ScenarioList.py +348 -39
- edsl/scenarios/ScenarioListExportMixin.py +9 -0
- edsl/study/SnapShot.py +8 -1
- edsl/surveys/RuleCollection.py +2 -2
- edsl/surveys/Survey.py +634 -315
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +75 -4
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +111 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +4 -2
- edsl-0.1.33.dev2.dist-info/RECORD +289 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.33.dev1.dist-info/RECORD +0 -209
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
edsl/inference_services/GoogleService.py CHANGED
```diff
@@ -10,10 +10,16 @@ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 
 
 class GoogleService(InferenceServiceABC):
     _inference_service_ = "google"
+    key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
+    usage_sequence = ["usageMetadata"]
+    input_token_name = "promptTokenCount"
+    output_token_name = "candidatesTokenCount"
+
+    model_exclude_list = []
 
     @classmethod
     def available(cls):
-        return ["gemini-pro"]
+        return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
 
     @classmethod
     def create_model(
@@ -24,7 +30,15 @@ class GoogleService(InferenceServiceABC):
 
         class LLM(LanguageModel):
             _model_ = model_name
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
             _inference_service_ = cls._inference_service_
+
+            _tpm = cls.get_tpm(cls)
+            _rpm = cls.get_rpm(cls)
+
             _parameters_ = {
                 "temperature": 0.5,
                 "topP": 1,
@@ -50,7 +64,7 @@ class GoogleService(InferenceServiceABC):
                         "stopSequences": self.stopSequences,
                     },
                 }
-
+                print(combined_prompt)
                 async with aiohttp.ClientSession() as session:
                     async with session.post(
                         url, headers=headers, data=json.dumps(data)
@@ -58,16 +72,6 @@ class GoogleService(InferenceServiceABC):
                     raw_response_text = await response.text()
                     return json.loads(raw_response_text)
 
-            def parse_response(self, raw_response: dict[str, Any]) -> str:
-                data = raw_response
-                try:
-                    return data["candidates"][0]["content"]["parts"][0]["text"]
-                except KeyError as e:
-                    print(
-                        f"The data return was {data}, which was missing the key 'candidates'"
-                    )
-                    raise e
-
         LLM.__name__ = model_name
 
         return LLM
```
edsl/inference_services/GroqService.py CHANGED
```diff
@@ -13,6 +13,8 @@ class GroqService(OpenAIService):
     _sync_client_ = groq.Groq
     _async_client_ = groq.AsyncGroq
 
+    model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]
+
     # _base_url_ = "https://api.deepinfra.com/v1/openai"
     _base_url_ = None
     _models_list_cache: List[str] = []
```
edsl/inference_services/InferenceServiceABC.py CHANGED
```diff
@@ -1,11 +1,35 @@
 from abc import abstractmethod, ABC
 from typing import Any
 import re
+from edsl.config import CONFIG
 
 
 class InferenceServiceABC(ABC):
     """Abstract class for inference services."""
 
+    # check if child class has cls attribute "key_sequence"
+    def __init_subclass__(cls):
+        if not hasattr(cls, "key_sequence"):
+            raise NotImplementedError(
+                f"Class {cls.__name__} must have a 'key_sequence' attribute."
+            )
+        if not hasattr(cls, "model_exclude_list"):
+            raise NotImplementedError(
+                f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
+            )
+
+    def get_tpm(cls):
+        key = f"EDSL_SERVICE_TPM_{cls._inference_service_.upper()}"
+        if key not in CONFIG:
+            key = "EDSL_SERVICE_TPM_BASELINE"
+        return int(CONFIG.get(key))
+
+    def get_rpm(cls):
+        key = f"EDSL_SERVICE_RPM_{cls._inference_service_.upper()}"
+        if key not in CONFIG:
+            key = "EDSL_SERVICE_RPM_BASELINE"
+        return int(CONFIG.get(key))
+
     @abstractmethod
     def available() -> list[str]:
         pass
```
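The new `get_tpm`/`get_rpm` helpers look for a service-specific config key first and fall back to a baseline key. The same lookup in isolation, with a plain dict standing in for edsl's `CONFIG` object and made-up values:

```python
# Stand-in for edsl.config.CONFIG; the key names mirror get_tpm/get_rpm above,
# the values are hypothetical.
CONFIG = {
    "EDSL_SERVICE_RPM_BASELINE": "100",
    "EDSL_SERVICE_TPM_BASELINE": "2000000",
    "EDSL_SERVICE_RPM_GOOGLE": "60",  # hypothetical per-service override
}


def resolve_limit(kind: str, service: str) -> int:
    """Return a rate limit, preferring the service-specific key."""
    key = f"EDSL_SERVICE_{kind}_{service.upper()}"
    if key not in CONFIG:
        key = f"EDSL_SERVICE_{kind}_BASELINE"
    return int(CONFIG[key])


print(resolve_limit("RPM", "google"))  # 60: the override wins
print(resolve_limit("TPM", "google"))  # 2000000: falls back to the baseline
```

Note that `get_tpm`/`get_rpm` take `cls` explicitly and are called as `cls.get_tpm(cls)` in the service classes elsewhere in this diff, rather than being decorated as classmethods.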
edsl/inference_services/MistralAIService.py ADDED
```diff
@@ -0,0 +1,120 @@
+import os
+from typing import Any, List
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models.LanguageModel import LanguageModel
+import asyncio
+from mistralai import Mistral
+
+from edsl.exceptions.language_models import LanguageModelBadResponseError
+
+
+class MistralAIService(InferenceServiceABC):
+    """Mistral AI service class."""
+
+    key_sequence = ["choices", 0, "message", "content"]
+    usage_sequence = ["usage"]
+
+    _inference_service_ = "mistral"
+    _env_key_name_ = "MISTRAL_API_KEY"  # Environment variable for Mistral API key
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    _sync_client_instance = None
+    _async_client_instance = None
+
+    _sync_client = Mistral
+    _async_client = Mistral
+
+    _models_list_cache: List[str] = []
+    model_exclude_list = []
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        # so subclasses have to create their own instances of the clients
+        cls._sync_client_instance = None
+        cls._async_client_instance = None
+
+    @classmethod
+    def sync_client(cls):
+        if cls._sync_client_instance is None:
+            cls._sync_client_instance = cls._sync_client(
+                api_key=os.getenv(cls._env_key_name_)
+            )
+        return cls._sync_client_instance
+
+    @classmethod
+    def async_client(cls):
+        if cls._async_client_instance is None:
+            cls._async_client_instance = cls._async_client(
+                api_key=os.getenv(cls._env_key_name_)
+            )
+        return cls._async_client_instance
+
+    @classmethod
+    def available(cls) -> list[str]:
+        if not cls._models_list_cache:
+            cls._models_list_cache = [
+                m.id for m in cls.sync_client().models.list().data
+            ]
+
+        return cls._models_list_cache
+
+    @classmethod
+    def create_model(
+        cls, model_name: str = "mistral", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with Mistral models.
+            """
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 512,
+                "top_p": 0.9,
+            }
+
+            _tpm = cls.get_tpm(cls)
+            _rpm = cls.get_rpm(cls)
+
+            def sync_client(self):
+                return cls.sync_client()
+
+            def async_client(self):
+                return cls.async_client()
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str = ""
+            ) -> dict[str, Any]:
+                """Calls the Mistral API and returns the API response."""
+                s = self.async_client()
+
+                try:
+                    res = await s.chat.complete_async(
+                        model=model_name,
+                        messages=[
+                            {
+                                "content": user_prompt,
+                                "role": "user",
+                            },
+                        ],
+                    )
+                except Exception as e:
+                    raise LanguageModelBadResponseError(f"Error with Mistral API: {e}")
+
+                return res.model_dump()
+
+        LLM.__name__ = model_class_name
+
+        return LLM
```
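Like OpenAIService below, this service caches a single client object per class and uses `__init_subclass__` to reset the cache so subclasses build their own clients instead of reusing the parent's. The pattern in isolation, a sketch with `object()` standing in for `Mistral(api_key=...)`:

```python
class ClientCache:
    _client = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._client = None  # each subclass starts with its own empty slot

    @classmethod
    def client(cls):
        if cls._client is None:
            cls._client = object()  # stands in for Mistral(api_key=...)
        return cls._client


class Sub(ClientCache):
    pass


assert ClientCache.client() is ClientCache.client()  # built lazily, once per class
assert Sub.client() is not ClientCache.client()      # not shared with the parent
```

Without the `__init_subclass__` reset, `Sub.client()` could find the parent's already-built client through ordinary attribute inheritance and silently reuse it.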
edsl/inference_services/OpenAIService.py CHANGED
```diff
@@ -1,8 +1,7 @@
-from
-import
+from __future__ import annotations
+from typing import Any, List, Optional
 import os
 
-# from openai import AsyncOpenAI
 import openai
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -10,6 +9,8 @@ from edsl.language_models import LanguageModel
 from edsl.inference_services.rate_limits_cache import rate_limits
 from edsl.utilities.utilities import fix_partial_correct_response
 
+from edsl.config import CONFIG
+
 
 class OpenAIService(InferenceServiceABC):
     """OpenAI service class."""
@@ -21,19 +22,36 @@ class OpenAIService(InferenceServiceABC):
     _sync_client_ = openai.OpenAI
     _async_client_ = openai.AsyncOpenAI
 
+    _sync_client_instance = None
+    _async_client_instance = None
+
+    key_sequence = ["choices", 0, "message", "content"]
+    usage_sequence = ["usage"]
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        # so subclasses have to create their own instances of the clients
+        cls._sync_client_instance = None
+        cls._async_client_instance = None
+
     @classmethod
     def sync_client(cls):
-
-
-
+        if cls._sync_client_instance is None:
+            cls._sync_client_instance = cls._sync_client_(
+                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+            )
+        return cls._sync_client_instance
 
     @classmethod
     def async_client(cls):
-
-
-
+        if cls._async_client_instance is None:
+            cls._async_client_instance = cls._async_client_(
+                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+            )
+        return cls._async_client_instance
 
-    # TODO: Make this a coop call
     model_exclude_list = [
         "whisper-1",
         "davinci-002",
@@ -48,6 +66,8 @@ class OpenAIService(InferenceServiceABC):
         "text-embedding-3-small",
         "text-embedding-ada-002",
         "ft:davinci-002:mit-horton-lab::8OfuHgoo",
+        "gpt-3.5-turbo-instruct-0914",
+        "gpt-3.5-turbo-instruct",
     ]
     _models_list_cache: List[str] = []
 
@@ -61,11 +81,8 @@ class OpenAIService(InferenceServiceABC):
 
     @classmethod
     def available(cls) -> List[str]:
-        # from openai import OpenAI
-
         if not cls._models_list_cache:
             try:
-                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                 cls._models_list_cache = [
                     m.id
                     for m in cls.get_model_list()
@@ -73,15 +90,6 @@ class OpenAIService(InferenceServiceABC):
                 ]
             except Exception as e:
                 raise
-                # print(
-                #     f"""Error retrieving models: {e}.
-                #     See instructions about storing your API keys: https://docs.expectedparrot.com/en/latest/api_keys.html"""
-                # )
-                # cls._models_list_cache = [
-                #     "gpt-3.5-turbo",
-                #     "gpt-4-1106-preview",
-                #     "gpt-4",
-                # ]  # Fallback list
         return cls._models_list_cache
 
     @classmethod
@@ -94,6 +102,14 @@ class OpenAIService(InferenceServiceABC):
             Child class of LanguageModel for interacting with OpenAI models
             """
 
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
+            _rpm = cls.get_rpm(cls)
+            _tpm = cls.get_tpm(cls)
+
             _inference_service_ = cls._inference_service_
             _model_ = model_name
             _parameters_ = {
@@ -114,15 +130,9 @@ class OpenAIService(InferenceServiceABC):
 
             @classmethod
             def available(cls) -> list[str]:
-                # import openai
-                # client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
-                # return client.models.list()
                 return cls.sync_client().models.list()
 
             def get_headers(self) -> dict[str, Any]:
-                # from openai import OpenAI
-
-                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                 client = self.sync_client()
                 response = client.chat.completions.with_raw_response.create(
                     messages=[
@@ -159,6 +169,9 @@ class OpenAIService(InferenceServiceABC):
                 user_prompt: str,
                 system_prompt: str = "",
                 encoded_image=None,
+                invigilator: Optional[
+                    "InvigilatorAI"
+                ] = None,  # TBD - can eventually be used for function-calling
             ) -> dict[str, Any]:
                 """Calls the OpenAI API and returns the API response."""
                 if encoded_image:
@@ -173,10 +186,6 @@ class OpenAIService(InferenceServiceABC):
                     )
                 else:
                     content = user_prompt
-                # self.client = AsyncOpenAI(
-                #     api_key = os.getenv(cls._env_key_name_),
-                #     base_url = cls._base_url_
-                # )
                 client = self.async_client()
                 params = {
                     "model": self.model,
@@ -195,24 +204,6 @@ class OpenAIService(InferenceServiceABC):
                 response = await client.chat.completions.create(**params)
                 return response.model_dump()
 
-            @staticmethod
-            def parse_response(raw_response: dict[str, Any]) -> str:
-                """Parses the API response and returns the response text."""
-                try:
-                    response = raw_response["choices"][0]["message"]["content"]
-                except KeyError:
-                    print("Tried to parse response but failed:")
-                    print(raw_response)
-                pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
-                match = re.match(pattern, response, re.DOTALL)
-                if match:
-                    return match.group(1)
-                else:
-                    out = fix_partial_correct_response(response)
-                    if "error" not in out:
-                        response = out["extracted_json"]
-                return response
-
         LLM.__name__ = "LanguageModel"
 
         return LLM
```
edsl/inference_services/TestService.py ADDED
```diff
@@ -0,0 +1,71 @@
+from typing import Any, List
+import os
+import asyncio
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+from edsl.inference_services.rate_limits_cache import rate_limits
+from edsl.utilities.utilities import fix_partial_correct_response
+
+from edsl.enums import InferenceServiceType
+
+
+class TestService(InferenceServiceABC):
+    """OpenAI service class."""
+
+    key_sequence = None
+    model_exclude_list = []
+    _inference_service_ = "test"
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    @classmethod
+    def available(cls) -> list[str]:
+        return ["test"]
+
+    @classmethod
+    def create_model(cls, model_name, model_class_name=None) -> LanguageModel:
+        throw_exception = False
+
+        class TestServiceLanguageModel(LanguageModel):
+            _model_ = "test"
+            _parameters_ = {"temperature": 0.5}
+            _inference_service_ = InferenceServiceType.TEST.value
+            usage_sequence = ["usage"]
+            key_sequence = ["message", 0, "text"]
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+            _rpm = 1000
+            _tpm = 100000
+
+            @property
+            def _canned_response(self):
+                if hasattr(self, "canned_response"):
+                    return self.canned_response
+                else:
+                    return "Hello, world"
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str
+            ) -> dict[str, Any]:
+                await asyncio.sleep(0.1)
+                # return {"message": """{"answer": "Hello, world"}"""}
+                if hasattr(self, "throw_exception") and self.throw_exception:
+                    raise Exception("This is a test error")
+                return {
+                    "message": [{"text": f"{self._canned_response}"}],
+                    "usage": {"prompt_tokens": 1, "completion_tokens": 1},
+                }
+
+        return TestServiceLanguageModel
+
+    # _inference_service_ = "openai"
+    # _env_key_name_ = "OPENAI_API_KEY"
+    # _base_url_ = None
+
+    # _sync_client_ = openai.OpenAI
+    # _async_client_ = openai.AsyncOpenAI
+
+    # _sync_client_instance = None
+    # _async_client_instance = None
+
+    # key_sequence = ["choices", 0, "message", "content"]
```
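A hedged sketch of driving the test model directly — assuming the class returned by `create_model` can be instantiated with no arguments, which may not match edsl's normal construction path (models are usually obtained through the registry):

```python
import asyncio

from edsl.inference_services.TestService import TestService

TestModel = TestService.create_model("test")
m = TestModel()  # assumption: a no-argument constructor works here
m.canned_response = "yes"  # picked up by the _canned_response property above

raw = asyncio.run(m.async_execute_model_call("Any question?", ""))
print(raw["message"][0]["text"])  # yes
print(raw["usage"])               # {'prompt_tokens': 1, 'completion_tokens': 1}
```

Setting `throw_exception = True` on an instance makes the call raise instead, which gives tests a way to exercise the failure path without a real API.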
edsl/inference_services/models_available_cache.py CHANGED
```diff
@@ -70,12 +70,6 @@ models_available = {
     "amazon.titan-tg1-large",
     "amazon.titan-text-lite-v1",
     "amazon.titan-text-express-v1",
-    "ai21.j2-grande-instruct",
-    "ai21.j2-jumbo-instruct",
-    "ai21.j2-mid",
-    "ai21.j2-mid-v1",
-    "ai21.j2-ultra",
-    "ai21.j2-ultra-v1",
     "anthropic.claude-instant-v1",
     "anthropic.claude-v2:1",
     "anthropic.claude-v2",
```
edsl/inference_services/registry.py CHANGED
```diff
@@ -10,6 +10,8 @@ from edsl.inference_services.GroqService import GroqService
 from edsl.inference_services.AwsBedrock import AwsBedrockService
 from edsl.inference_services.AzureAI import AzureAIService
 from edsl.inference_services.OllamaService import OllamaService
+from edsl.inference_services.TestService import TestService
+from edsl.inference_services.MistralAIService import MistralAIService
 
 default = InferenceServicesCollection(
     [
@@ -21,5 +23,7 @@ default = InferenceServicesCollection(
         AwsBedrockService,
         AzureAIService,
         OllamaService,
+        TestService,
+        MistralAIService,
     ]
 )
```
edsl/jobs/Answers.py CHANGED
```diff
@@ -2,24 +2,22 @@
 
 from collections import UserDict
 from rich.table import Table
+from edsl.data_transfer_models import EDSLResultObjectInput
 
 
 class Answers(UserDict):
     """Helper class to hold the answers to a survey."""
 
-    def add_answer(
-
-
-
-
-
-
-        >>> answers[q.question_name]
-        'yes'
-        """
-        answer = response.get("answer")
-        comment = response.pop("comment", None)
+    def add_answer(
+        self, response: EDSLResultObjectInput, question: "QuestionBase"
+    ) -> None:
+        """Add a response to the answers dictionary."""
+        answer = response.answer
+        comment = response.comment
+        generated_tokens = response.generated_tokens
         # record the answer
+        if generated_tokens:
+            self[question.question_name + "_generated_tokens"] = generated_tokens
         self[question.question_name] = answer
         if comment:
             self[question.question_name + "_comment"] = comment
```
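A sketch of the new `add_answer` flow. `EDSLResultObjectInput` lives in the reworked `edsl/data_transfer_models.py`; since its full signature is not shown here, this example uses namedtuple stand-ins that model only the attributes `add_answer` actually reads:

```python
from collections import namedtuple

from edsl.jobs.Answers import Answers

# Hypothetical stand-ins for EDSLResultObjectInput and a question object.
FakeResponse = namedtuple("FakeResponse", ["answer", "comment", "generated_tokens"])
FakeQuestion = namedtuple("FakeQuestion", ["question_name"])

answers = Answers()
answers.add_answer(
    FakeResponse(answer="yes", comment="Seems right.", generated_tokens='{"answer": "yes"}'),
    FakeQuestion(question_name="q0"),
)

print(answers["q0"])                   # yes
print(answers["q0_comment"])           # Seems right.
print(answers["q0_generated_tokens"])  # {"answer": "yes"}
```

The change from `response.get("answer")` to attribute access reflects the move from loose dicts to the typed `EDSLResultObjectInput` transfer object.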
edsl/jobs/FailedQuestion.py ADDED
```diff
@@ -0,0 +1,78 @@
+from edsl.questions import QuestionBase
+from edsl import Question, Scenario, Model, Agent
+
+from edsl.language_models.LanguageModel import LanguageModel
+
+
+class FailedQuestion:
+    # tests/jobs/test_Interview.py::test_handle_model_exceptions
+
+    # (Pdb) dir(self.exception.__traceback__)
+    # ['tb_frame', 'tb_lasti', 'tb_lineno', 'tb_next']
+
+    def __init__(
+        self, question, scenario, model, agent, raw_model_response, exception, prompts
+    ):
+        self.question = question
+        self.scenario = scenario
+        self.model = model
+        self.agent = agent
+        self.raw_model_response = raw_model_response  # JSON
+        self.exception = exception
+        self.prompts = prompts
+
+    def to_dict(self):
+        return {
+            "question": self.question._to_dict(),
+            "scenario": self.scenario._to_dict(),
+            "model": self.model._to_dict(),
+            "agent": self.agent._to_dict(),
+            "raw_model_response": self.raw_model_response,
+            "exception": self.exception.__class__.__name__,  # self.exception,
+            "prompts": self.prompts,
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        question = QuestionBase.from_dict(data["question"])
+        scenario = Scenario.from_dict(data["scenario"])
+        model = LanguageModel.from_dict(data["model"])
+        agent = Agent.from_dict(data["agent"])
+        raw_model_response = data["raw_model_response"]
+        exception = data["exception"]
+        prompts = data["prompts"]
+        return cls(
+            question, scenario, model, agent, raw_model_response, exception, prompts
+        )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(question={repr(self.question)}, scenario={repr(self.scenario)}, model={repr(self.model)}, agent={repr(self.agent)}, raw_model_response={repr(self.raw_model_response)}, exception={repr(self.exception)})"
+
+    @property
+    def jobs(self):
+        return self.question.by(self.scenario).by(self.agent).by(self.model)
+
+    def rerun(self):
+        results = self.jobs.run()
+        return results
+
+    def help(self):
+        pass
+
+    @classmethod
+    def example(cls):
+        from edsl.language_models.utilities import create_language_model
+        from edsl.language_models.utilities import create_survey
+
+        survey = create_survey(2, chained=False, take_scenario=False)
+        fail_at_number = 1
+        model = create_language_model(ValueError, fail_at_number)()
+        from edsl import Survey
+
+        results = survey.by(model).run()
+        return results.failed_questions[0][0]
+
+
+if __name__ == "__main__":
+    fq = FailedQuestion.example()
+    new_fq = FailedQuestion.from_dict(fq.to_dict())
```