edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +3 -9
- edsl/__init__.py +3 -8
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -40
- edsl/agents/AgentList.py +0 -43
- edsl/agents/Invigilator.py +219 -135
- edsl/agents/InvigilatorBase.py +59 -148
- edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
- edsl/agents/__init__.py +0 -1
- edsl/config.py +56 -47
- edsl/coop/coop.py +7 -50
- edsl/data/Cache.py +1 -35
- edsl/data_transfer_models.py +38 -73
- edsl/enums.py +0 -4
- edsl/exceptions/language_models.py +1 -25
- edsl/exceptions/questions.py +5 -62
- edsl/exceptions/results.py +0 -4
- edsl/inference_services/AnthropicService.py +11 -13
- edsl/inference_services/AwsBedrock.py +17 -19
- edsl/inference_services/AzureAI.py +20 -37
- edsl/inference_services/GoogleService.py +12 -16
- edsl/inference_services/GroqService.py +0 -2
- edsl/inference_services/InferenceServiceABC.py +3 -58
- edsl/inference_services/OpenAIService.py +54 -48
- edsl/inference_services/models_available_cache.py +6 -0
- edsl/inference_services/registry.py +0 -6
- edsl/jobs/Answers.py +12 -10
- edsl/jobs/Jobs.py +21 -36
- edsl/jobs/buckets/BucketCollection.py +15 -24
- edsl/jobs/buckets/TokenBucket.py +14 -93
- edsl/jobs/interviews/Interview.py +78 -366
- edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
- edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
- edsl/jobs/interviews/retry_management.py +37 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
- edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
- edsl/jobs/tasks/TaskHistory.py +213 -148
- edsl/language_models/LanguageModel.py +156 -261
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
- edsl/language_models/registry.py +6 -23
- edsl/language_models/repair.py +19 -0
- edsl/prompts/Prompt.py +2 -52
- edsl/questions/AnswerValidatorMixin.py +26 -23
- edsl/questions/QuestionBase.py +249 -329
- edsl/questions/QuestionBudget.py +41 -99
- edsl/questions/QuestionCheckBox.py +35 -227
- edsl/questions/QuestionExtract.py +27 -98
- edsl/questions/QuestionFreeText.py +29 -52
- edsl/questions/QuestionFunctional.py +0 -7
- edsl/questions/QuestionList.py +22 -141
- edsl/questions/QuestionMultipleChoice.py +65 -159
- edsl/questions/QuestionNumerical.py +46 -88
- edsl/questions/QuestionRank.py +24 -182
- edsl/questions/RegisterQuestionsMeta.py +12 -31
- edsl/questions/__init__.py +4 -3
- edsl/questions/derived/QuestionLikertFive.py +5 -10
- edsl/questions/derived/QuestionLinearScale.py +2 -15
- edsl/questions/derived/QuestionTopK.py +1 -10
- edsl/questions/derived/QuestionYesNo.py +3 -24
- edsl/questions/descriptors.py +7 -43
- edsl/questions/question_registry.py +2 -6
- edsl/results/Dataset.py +0 -20
- edsl/results/DatasetExportMixin.py +48 -46
- edsl/results/Result.py +5 -32
- edsl/results/Results.py +46 -135
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/scenarios/FileStore.py +10 -71
- edsl/scenarios/Scenario.py +25 -96
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +39 -361
- edsl/scenarios/ScenarioListExportMixin.py +0 -9
- edsl/scenarios/ScenarioListPdfMixin.py +4 -150
- edsl/study/SnapShot.py +1 -8
- edsl/study/Study.py +0 -32
- edsl/surveys/Rule.py +1 -10
- edsl/surveys/RuleCollection.py +5 -21
- edsl/surveys/Survey.py +310 -636
- edsl/surveys/SurveyExportMixin.py +9 -71
- edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
- edsl/surveys/SurveyQualtricsImport.py +4 -75
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/utilities.py +1 -9
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
- edsl-0.1.33.dev1.dist-info/RECORD +209 -0
- edsl/TemplateLoader.py +0 -24
- edsl/auto/AutoStudy.py +0 -117
- edsl/auto/StageBase.py +0 -230
- edsl/auto/StageGenerateSurvey.py +0 -178
- edsl/auto/StageLabelQuestions.py +0 -125
- edsl/auto/StagePersona.py +0 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
- edsl/auto/StagePersonaDimensionValues.py +0 -74
- edsl/auto/StagePersonaDimensions.py +0 -69
- edsl/auto/StageQuestions.py +0 -73
- edsl/auto/SurveyCreatorPipeline.py +0 -21
- edsl/auto/utilities.py +0 -224
- edsl/coop/PriceFetcher.py +0 -58
- edsl/inference_services/MistralAIService.py +0 -120
- edsl/inference_services/TestService.py +0 -80
- edsl/inference_services/TogetherAIService.py +0 -170
- edsl/jobs/FailedQuestion.py +0 -78
- edsl/jobs/runners/JobsRunnerStatus.py +0 -331
- edsl/language_models/fake_openai_call.py +0 -15
- edsl/language_models/fake_openai_service.py +0 -61
- edsl/language_models/utilities.py +0 -61
- edsl/questions/QuestionBaseGenMixin.py +0 -133
- edsl/questions/QuestionBasePromptsMixin.py +0 -266
- edsl/questions/Quick.py +0 -41
- edsl/questions/ResponseValidatorABC.py +0 -170
- edsl/questions/decorators.py +0 -21
- edsl/questions/prompt_templates/question_budget.jinja +0 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
- edsl/questions/prompt_templates/question_extract.jinja +0 -11
- edsl/questions/prompt_templates/question_free_text.jinja +0 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
- edsl/questions/prompt_templates/question_list.jinja +0 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
- edsl/questions/prompt_templates/question_numerical.jinja +0 -37
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +0 -7
- edsl/questions/templates/budget/question_presentation.jinja +0 -7
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +0 -7
- edsl/questions/templates/extract/question_presentation.jinja +0 -1
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +0 -1
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +0 -4
- edsl/questions/templates/list/question_presentation.jinja +0 -5
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
- edsl/questions/templates/numerical/question_presentation.jinja +0 -7
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +0 -11
- edsl/questions/templates/rank/question_presentation.jinja +0 -15
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
- edsl/questions/templates/top_k/question_presentation.jinja +0 -22
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
- edsl/results/DatasetTree.py +0 -145
- edsl/results/Selector.py +0 -118
- edsl/results/tree_explore.py +0 -115
- edsl/surveys/instructions/ChangeInstruction.py +0 -47
- edsl/surveys/instructions/Instruction.py +0 -34
- edsl/surveys/instructions/InstructionCollection.py +0 -77
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +0 -24
- edsl/templates/error_reporting/exceptions_by_model.html +0 -35
- edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
- edsl/templates/error_reporting/exceptions_by_type.html +0 -17
- edsl/templates/error_reporting/interview_details.html +0 -116
- edsl/templates/error_reporting/interviews.html +0 -10
- edsl/templates/error_reporting/overview.html +0 -5
- edsl/templates/error_reporting/performance_plot.html +0 -2
- edsl/templates/error_reporting/report.css +0 -74
- edsl/templates/error_reporting/report.html +0 -118
- edsl/templates/error_reporting/report.js +0 -25
- edsl-0.1.33.dist-info/RECORD +0 -295
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
edsl/inference_services/AwsBedrock.py
CHANGED
@@ -16,18 +16,6 @@ class AwsBedrockService(InferenceServiceABC):
     _env_key_name_ = (
         "AWS_ACCESS_KEY_ID"  # or any other environment key for AWS credentials
     )
-    key_sequence = ["output", "message", "content", 0, "text"]
-    input_token_name = "inputTokens"
-    output_token_name = "outputTokens"
-    usage_sequence = ["usage"]
-    model_exclude_list = [
-        "ai21.j2-grande-instruct",
-        "ai21.j2-jumbo-instruct",
-        "ai21.j2-mid",
-        "ai21.j2-mid-v1",
-        "ai21.j2-ultra",
-        "ai21.j2-ultra-v1",
-    ]
 
     @classmethod
     def available(cls):
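The attributes removed above are the machinery 0.1.33 uses for generic response handling: LanguageModel walks key_sequence into the raw API response to find the generated text, and usage_sequence plus the two token-name attributes to find token counts. A minimal sketch of that key-path traversal, assuming plain indexing at each step (the helper is illustrative, not edsl's actual implementation):

    from functools import reduce
    from typing import Any, Sequence

    def follow_key_sequence(response: dict, key_sequence: Sequence[Any]) -> Any:
        """Walk a mixed path of dict keys and list indices into a nested response."""
        return reduce(lambda node, key: node[key], key_sequence, response)

    # The Bedrock key_sequence from the hunk above:
    raw = {"output": {"message": {"content": [{"text": "Hello"}]}}}
    assert follow_key_sequence(raw, ["output", "message", "content", 0, "text"]) == "Hello"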
@@ -40,7 +28,7 @@ class AwsBedrockService(InferenceServiceABC):
         else:
             all_models_ids = cls._models_list_cache
 
-        return
+        return all_models_ids
 
     @classmethod
     def create_model(
@@ -54,8 +42,6 @@ class AwsBedrockService(InferenceServiceABC):
            Child class of LanguageModel for interacting with AWS Bedrock models.
            """
 
-            key_sequence = cls.key_sequence
-            usage_sequence = cls.usage_sequence
            _inference_service_ = cls._inference_service_
            _model_ = model_name
            _parameters_ = {
@@ -63,10 +49,6 @@ class AwsBedrockService(InferenceServiceABC):
                "max_tokens": 512,
                "top_p": 0.9,
            }
-            input_token_name = cls.input_token_name
-            output_token_name = cls.output_token_name
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)
 
            async def async_execute_model_call(
                self, user_prompt: str, system_prompt: str = ""
@@ -107,6 +89,22 @@ class AwsBedrockService(InferenceServiceABC):
                    print(e)
                    return {"error": str(e)}
 
+            @staticmethod
+            def parse_response(raw_response: dict[str, Any]) -> str:
+                """Parses the API response and returns the response text."""
+                if "output" in raw_response and "message" in raw_response["output"]:
+                    response = raw_response["output"]["message"]["content"][0]["text"]
+                    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+                    match = re.match(pattern, response, re.DOTALL)
+                    if match:
+                        return match.group(1)
+                    else:
+                        out = fix_partial_correct_response(response)
+                        if "error" not in out:
+                            response = out["extracted_json"]
+                        return response
+                return "Error parsing response"
+
        LLM.__name__ = model_class_name
 
        return LLM
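The parse_response restored above (and its Azure and OpenAI counterparts later in this diff) shares one unwrapping step: if the model returned its JSON inside a ```json fence, the regex strips the fence, accepting either a real newline or a literal backslash-n around the payload; anything else is handed to fix_partial_correct_response (imported from edsl.utilities.utilities) to salvage embedded JSON. A standalone check of just the regex step:

    import re

    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"

    fenced = '```json\n{"answer": "yes"}\n```'
    match = re.match(pattern, fenced, re.DOTALL)  # DOTALL lets .+? span lines
    assert match and match.group(1) == '{"answer": "yes"}'

    # The \\n alternative also accepts a literal two-character "\n" sequence:
    escaped = '```json\\n{"answer": "yes"}\\n```'
    assert re.match(pattern, escaped, re.DOTALL).group(1) == '{"answer": "yes"}'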
edsl/inference_services/AzureAI.py
CHANGED
@@ -25,22 +25,11 @@ def json_handle_none(value: Any) -> Any:
 class AzureAIService(InferenceServiceABC):
     """Azure AI service class."""
 
-    # key_sequence = ["content", 0, "text"]  # ["content"][0]["text"]
-    key_sequence = ["choices", 0, "message", "content"]
-    usage_sequence = ["usage"]
-    input_token_name = "prompt_tokens"
-    output_token_name = "completion_tokens"
-
     _inference_service_ = "azure"
     _env_key_name_ = (
         "AZURE_ENDPOINT_URL_AND_KEY"  # Environment variable for Azure API key
     )
     _model_id_to_endpoint_and_key = {}
-    model_exclude_list = [
-        "Cohere-command-r-plus-xncmg",
-        "Mistral-Nemo-klfsi",
-        "Mistral-large-2407-ojfld",
-    ]
 
     @classmethod
     def available(cls):
@@ -93,7 +82,7 @@ class AzureAIService(InferenceServiceABC):
 
         except Exception as e:
             raise e
-        return
+        return out
 
     @classmethod
     def create_model(
@@ -107,10 +96,6 @@ class AzureAIService(InferenceServiceABC):
            Child class of LanguageModel for interacting with Azure OpenAI models.
            """
 
-            key_sequence = cls.key_sequence
-            usage_sequence = cls.usage_sequence
-            input_token_name = cls.input_token_name
-            output_token_name = cls.output_token_name
            _inference_service_ = cls._inference_service_
            _model_ = model_name
            _parameters_ = {
@@ -118,8 +103,6 @@ class AzureAIService(InferenceServiceABC):
                "max_tokens": 512,
                "top_p": 0.9,
            }
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)
 
            async def async_execute_model_call(
                self, user_prompt: str, system_prompt: str = ""
@@ -189,25 +172,25 @@ class AzureAIService(InferenceServiceABC):
            )
            return response.model_dump()
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            @staticmethod
+            def parse_response(raw_response: dict[str, Any]) -> str:
+                """Parses the API response and returns the response text."""
+                if (
+                    raw_response
+                    and "choices" in raw_response
+                    and raw_response["choices"]
+                ):
+                    response = raw_response["choices"][0]["message"]["content"]
+                    pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+                    match = re.match(pattern, response, re.DOTALL)
+                    if match:
+                        return match.group(1)
+                    else:
+                        out = fix_partial_correct_response(response)
+                        if "error" not in out:
+                            response = out["extracted_json"]
+                        return response
+                return "Error parsing response"
 
        LLM.__name__ = model_class_name
 
edsl/inference_services/GoogleService.py
CHANGED
@@ -10,16 +10,10 @@ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 
 class GoogleService(InferenceServiceABC):
     _inference_service_ = "google"
-    key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
-    usage_sequence = ["usageMetadata"]
-    input_token_name = "promptTokenCount"
-    output_token_name = "candidatesTokenCount"
-
-    model_exclude_list = []
 
     @classmethod
     def available(cls):
-        return ["gemini-pro"
+        return ["gemini-pro"]
 
     @classmethod
     def create_model(
@@ -30,15 +24,7 @@ class GoogleService(InferenceServiceABC):
 
        class LLM(LanguageModel):
            _model_ = model_name
-            key_sequence = cls.key_sequence
-            usage_sequence = cls.usage_sequence
-            input_token_name = cls.input_token_name
-            output_token_name = cls.output_token_name
            _inference_service_ = cls._inference_service_
-
-            _tpm = cls.get_tpm(cls)
-            _rpm = cls.get_rpm(cls)
-
            _parameters_ = {
                "temperature": 0.5,
                "topP": 1,
@@ -64,7 +50,7 @@ class GoogleService(InferenceServiceABC):
                        "stopSequences": self.stopSequences,
                    },
                }
-
+
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        url, headers=headers, data=json.dumps(data)
@@ -72,6 +58,16 @@ class GoogleService(InferenceServiceABC):
                        raw_response_text = await response.text()
                        return json.loads(raw_response_text)
 
+            def parse_response(self, raw_response: dict[str, Any]) -> str:
+                data = raw_response
+                try:
+                    return data["candidates"][0]["content"]["parts"][0]["text"]
+                except KeyError as e:
+                    print(
+                        f"The data return was {data}, which was missing the key 'candidates'"
+                    )
+                    raise e
+
        LLM.__name__ = model_name
 
        return LLM
edsl/inference_services/GroqService.py
CHANGED
@@ -13,8 +13,6 @@ class GroqService(OpenAIService):
     _sync_client_ = groq.Groq
     _async_client_ = groq.AsyncGroq
 
-    model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]
-
     # _base_url_ = "https://api.deepinfra.com/v1/openai"
     _base_url_ = None
     _models_list_cache: List[str] = []
edsl/inference_services/InferenceServiceABC.py
CHANGED
@@ -1,77 +1,22 @@
 from abc import abstractmethod, ABC
-import os
+from typing import Any
 import re
-from edsl.config import CONFIG
 
 
 class InferenceServiceABC(ABC):
-    """
-    Abstract class for inference services.
-    Anthropic: https://docs.anthropic.com/en/api/rate-limits
-    """
-
-    default_levels = {
-        "google": {"tpm": 2_000_000, "rpm": 15},
-        "openai": {"tpm": 2_000_000, "rpm": 10_000},
-        "anthropic": {"tpm": 2_000_000, "rpm": 500},
-    }
-
-    def __init_subclass__(cls):
-        """
-        Check that the subclass has the required attributes.
-        - `key_sequence` attribute determines...
-        - `model_exclude_list` attribute determines...
-        """
-        if not hasattr(cls, "key_sequence"):
-            raise NotImplementedError(
-                f"Class {cls.__name__} must have a 'key_sequence' attribute."
-            )
-        if not hasattr(cls, "model_exclude_list"):
-            raise NotImplementedError(
-                f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
-            )
-
-    @classmethod
-    def _get_limt(cls, limit_type: str) -> int:
-        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
-        if key in os.environ:
-            return int(os.getenv(key))
-
-        if cls._inference_service_ in cls.default_levels:
-            return int(cls.default_levels[cls._inference_service_][limit_type])
-
-        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
-
-    def get_tpm(cls) -> int:
-        """
-        Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
-        """
-        return cls._get_limt(limit_type="tpm")
-
-    def get_rpm(cls):
-        """
-        Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
-        """
-        return cls._get_limt(limit_type="rpm")
+    """Abstract class for inference services."""
 
     @abstractmethod
     def available() -> list[str]:
-        """
-        Returns a list of available models for the service.
-        """
        pass
 
     @abstractmethod
     def create_model():
-        """
-        Returns a LanguageModel object.
-        """
        pass
 
     @staticmethod
     def to_class_name(s):
-        """
-        Converts a string to a valid class name.
+        """Convert a string to a valid class name.
 
        >>> InferenceServiceABC.to_class_name("hello world")
        'HelloWorld'
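The 0.1.33 side of this file is also where per-service rate limits come from: get_tpm and get_rpm resolve a limit from a service-specific environment variable if one is set, then from default_levels, then from a baseline stored in edsl's CONFIG. A standalone sketch of that resolution order (the BASELINE dict is a stand-in for the EDSL_SERVICE_*_BASELINE values that CONFIG would supply):

    import os

    # Copied from the removed InferenceServiceABC.default_levels above.
    DEFAULT_LEVELS = {
        "google": {"tpm": 2_000_000, "rpm": 15},
        "openai": {"tpm": 2_000_000, "rpm": 10_000},
        "anthropic": {"tpm": 2_000_000, "rpm": 500},
    }
    # Stand-in values; the real ones come from CONFIG.get("EDSL_SERVICE_..._BASELINE").
    BASELINE = {"tpm": 2_000_000, "rpm": 100}

    def get_limit(service: str, limit_type: str) -> int:
        """Env var first, then the per-service default, then the global baseline."""
        env_key = f"EDSL_SERVICE_{limit_type.upper()}_{service.upper()}"
        if env_key in os.environ:
            return int(os.environ[env_key])
        if service in DEFAULT_LEVELS:
            return int(DEFAULT_LEVELS[service][limit_type])
        return BASELINE[limit_type]

    print(get_limit("anthropic", "rpm"))  # 500, unless EDSL_SERVICE_RPM_ANTHROPIC overrides it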
edsl/inference_services/OpenAIService.py
CHANGED
@@ -1,7 +1,8 @@
-from typing import Any, List, Optional
-
+from typing import Any, List
+import re
 import os
 
+# from openai import AsyncOpenAI
 import openai
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -9,8 +10,6 @@ from edsl.language_models import LanguageModel
 from edsl.inference_services.rate_limits_cache import rate_limits
 from edsl.utilities.utilities import fix_partial_correct_response
 
-from edsl.config import CONFIG
-
 
 class OpenAIService(InferenceServiceABC):
     """OpenAI service class."""
@@ -22,36 +21,19 @@ class OpenAIService(InferenceServiceABC):
     _sync_client_ = openai.OpenAI
     _async_client_ = openai.AsyncOpenAI
 
-    _sync_client_instance = None
-    _async_client_instance = None
-
-    key_sequence = ["choices", 0, "message", "content"]
-    usage_sequence = ["usage"]
-    input_token_name = "prompt_tokens"
-    output_token_name = "completion_tokens"
-
-    def __init_subclass__(cls, **kwargs):
-        super().__init_subclass__(**kwargs)
-        # so subclasses have to create their own instances of the clients
-        cls._sync_client_instance = None
-        cls._async_client_instance = None
-
     @classmethod
     def sync_client(cls):
-        if cls._sync_client_instance is None:
-            cls._sync_client_instance = cls._sync_client_(
-                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-            )
-        return cls._sync_client_instance
+        return cls._sync_client_(
+            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+        )
 
     @classmethod
     def async_client(cls):
-        if cls._async_client_instance is None:
-            cls._async_client_instance = cls._async_client_(
-                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-            )
-        return cls._async_client_instance
+        return cls._async_client_(
+            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+        )
 
+    # TODO: Make this a coop call
     model_exclude_list = [
        "whisper-1",
        "davinci-002",
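The 0.1.33 side caches one client object per service subclass, which is why the removed __init_subclass__ resets the cache so subclasses do not share a client; the dev1 side simply builds a fresh client on every call. A sketch of the caching pattern, reconstructed from the removed lines (the class name and env key here are illustrative, not edsl's):

    import os
    import openai

    class CachedClientService:
        _env_key_name_ = "OPENAI_API_KEY"  # illustrative; each service defines its own
        _base_url_ = None
        _sync_client_instance = None

        def __init_subclass__(cls, **kwargs):
            super().__init_subclass__(**kwargs)
            # Reset the cache so each subclass builds its own client.
            cls._sync_client_instance = None

        @classmethod
        def sync_client(cls):
            if cls._sync_client_instance is None:
                cls._sync_client_instance = openai.OpenAI(
                    api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
                )
            return cls._sync_client_instance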
@@ -66,8 +48,6 @@ class OpenAIService(InferenceServiceABC):
        "text-embedding-3-small",
        "text-embedding-ada-002",
        "ft:davinci-002:mit-horton-lab::8OfuHgoo",
-        "gpt-3.5-turbo-instruct-0914",
-        "gpt-3.5-turbo-instruct",
    ]
    _models_list_cache: List[str] = []
 
@@ -81,8 +61,11 @@ class OpenAIService(InferenceServiceABC):
 
    @classmethod
    def available(cls) -> List[str]:
+        # from openai import OpenAI
+
        if not cls._models_list_cache:
            try:
+                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                cls._models_list_cache = [
                    m.id
                    for m in cls.get_model_list()
@@ -90,6 +73,15 @@ class OpenAIService(InferenceServiceABC):
                ]
            except Exception as e:
                raise
+                # print(
+                #     f"""Error retrieving models: {e}.
+                #     See instructions about storing your API keys: https://docs.expectedparrot.com/en/latest/api_keys.html"""
+                # )
+                # cls._models_list_cache = [
+                #     "gpt-3.5-turbo",
+                #     "gpt-4-1106-preview",
+                #     "gpt-4",
+                # ]  # Fallback list
        return cls._models_list_cache
 
    @classmethod
@@ -102,14 +94,6 @@ class OpenAIService(InferenceServiceABC):
            Child class of LanguageModel for interacting with OpenAI models
            """
 
-            key_sequence = cls.key_sequence
-            usage_sequence = cls.usage_sequence
-            input_token_name = cls.input_token_name
-            output_token_name = cls.output_token_name
-
-            _rpm = cls.get_rpm(cls)
-            _tpm = cls.get_tpm(cls)
-
            _inference_service_ = cls._inference_service_
            _model_ = model_name
            _parameters_ = {
@@ -130,9 +114,15 @@ class OpenAIService(InferenceServiceABC):
 
            @classmethod
            def available(cls) -> list[str]:
+                # import openai
+                # client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                # return client.models.list()
                return cls.sync_client().models.list()
 
            def get_headers(self) -> dict[str, Any]:
+                # from openai import OpenAI
+
+                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                client = self.sync_client()
                response = client.chat.completions.with_raw_response.create(
                    messages=[
@@ -169,9 +159,6 @@ class OpenAIService(InferenceServiceABC):
                user_prompt: str,
                system_prompt: str = "",
                encoded_image=None,
-                invigilator: Optional[
-                    "InvigilatorAI"
-                ] = None,  # TBD - can eventually be used for function-calling
            ) -> dict[str, Any]:
                """Calls the OpenAI API and returns the API response."""
                if encoded_image:
@@ -186,16 +173,17 @@ class OpenAIService(InferenceServiceABC):
                    )
                else:
                    content = user_prompt
+                # self.client = AsyncOpenAI(
+                #     api_key = os.getenv(cls._env_key_name_),
+                #     base_url = cls._base_url_
+                # )
                client = self.async_client()
-                messages = [
-                    {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": content},
-                ]
-                if system_prompt == "" and self.omit_system_prompt_if_empty:
-                    messages = messages[1:]
                params = {
                    "model": self.model,
-                    "messages": messages,
+                    "messages": [
+                        {"role": "system", "content": system_prompt},
+                        {"role": "user", "content": content},
+                    ],
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                    "top_p": self.top_p,
@@ -207,6 +195,24 @@ class OpenAIService(InferenceServiceABC):
                response = await client.chat.completions.create(**params)
                return response.model_dump()
 
+            @staticmethod
+            def parse_response(raw_response: dict[str, Any]) -> str:
+                """Parses the API response and returns the response text."""
+                try:
+                    response = raw_response["choices"][0]["message"]["content"]
+                except KeyError:
+                    print("Tried to parse response but failed:")
+                    print(raw_response)
+                pattern = r"^```json(?:\\n|\n)(.+?)(?:\\n|\n)```$"
+                match = re.match(pattern, response, re.DOTALL)
+                if match:
+                    return match.group(1)
+                else:
+                    out = fix_partial_correct_response(response)
+                    if "error" not in out:
+                        response = out["extracted_json"]
+                    return response
+
        LLM.__name__ = "LanguageModel"
 
        return LLM
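One behavioral difference buried in the messages hunk earlier in this file's diff: 0.1.33 can drop an empty system message when the model sets omit_system_prompt_if_empty, whereas the dev1 side always sends both roles. A minimal illustration of the 0.1.33 branch:

    # Illustrative; omit_system_prompt_if_empty is a model-level flag in 0.1.33.
    system_prompt, content = "", "Hi"
    omit_system_prompt_if_empty = True

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": content},
    ]
    if system_prompt == "" and omit_system_prompt_if_empty:
        messages = messages[1:]  # drop the empty system message

    assert messages == [{"role": "user", "content": "Hi"}]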
edsl/inference_services/models_available_cache.py
CHANGED
@@ -70,6 +70,12 @@ models_available = {
        "amazon.titan-tg1-large",
        "amazon.titan-text-lite-v1",
        "amazon.titan-text-express-v1",
+        "ai21.j2-grande-instruct",
+        "ai21.j2-jumbo-instruct",
+        "ai21.j2-mid",
+        "ai21.j2-mid-v1",
+        "ai21.j2-ultra",
+        "ai21.j2-ultra-v1",
        "anthropic.claude-instant-v1",
        "anthropic.claude-v2:1",
        "anthropic.claude-v2",
edsl/inference_services/registry.py
CHANGED
@@ -10,9 +10,6 @@ from edsl.inference_services.GroqService import GroqService
 from edsl.inference_services.AwsBedrock import AwsBedrockService
 from edsl.inference_services.AzureAI import AzureAIService
 from edsl.inference_services.OllamaService import OllamaService
-from edsl.inference_services.TestService import TestService
-from edsl.inference_services.MistralAIService import MistralAIService
-from edsl.inference_services.TogetherAIService import TogetherAIService
 
 default = InferenceServicesCollection(
     [
@@ -24,8 +21,5 @@ default = InferenceServicesCollection(
        AwsBedrockService,
        AzureAIService,
        OllamaService,
-        TestService,
-        MistralAIService,
-        TogetherAIService,
    ]
 )
edsl/jobs/Answers.py
CHANGED
@@ -2,22 +2,24 @@
 
 from collections import UserDict
 from rich.table import Table
-from edsl.data_transfer_models import EDSLResultObjectInput
 
 
 class Answers(UserDict):
     """Helper class to hold the answers to a survey."""
 
-    def add_answer(
-        self, response: EDSLResultObjectInput, question: "QuestionBase"
-    ) -> None:
-        """Add a response to the answers dictionary."""
-        answer = response.answer
-        comment = response.comment
-        generated_tokens = response.generated_tokens
+    def add_answer(self, response, question) -> None:
+        """Add a response to the answers dictionary.
+
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText.example()
+        >>> answers = Answers()
+        >>> answers.add_answer({"answer": "yes"}, q)
+        >>> answers[q.question_name]
+        'yes'
+        """
+        answer = response.get("answer")
+        comment = response.pop("comment", None)
        # record the answer
-        if generated_tokens:
-            self[question.question_name + "_generated_tokens"] = generated_tokens
        self[question.question_name] = answer
        if comment:
            self[question.question_name + "_comment"] = comment
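The calling convention changes with this hunk: 0.1.33 expects a typed result object read by attribute, while 0.1.33.dev1 accepts a plain dict. Side by side (the dataclass is a stand-in for edsl.data_transfer_models.EDSLResultObjectInput, with fields inferred from the removed lines):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeResult:  # stand-in, not edsl's actual EDSLResultObjectInput
        answer: str
        comment: Optional[str] = None
        generated_tokens: Optional[str] = None

    # answers.add_answer(FakeResult(answer="yes"), q)   # 0.1.33: attribute access
    # answers.add_answer({"answer": "yes"}, q)          # 0.1.33.dev1: plain dict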