edsl 0.1.37.dev4__py3-none-any.whl → 0.1.37.dev5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +48 -48
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +855 -804
- edsl/agents/AgentList.py +350 -345
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +284 -305
- edsl/agents/PromptConstructor.py +353 -312
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +99 -86
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +289 -289
- edsl/config.py +149 -149
- edsl/conjure/AgentConstructionMixin.py +160 -152
- edsl/conjure/Conjure.py +62 -62
- edsl/conjure/InputData.py +659 -659
- edsl/conjure/InputDataCSV.py +48 -48
- edsl/conjure/InputDataMixinQuestionStats.py +182 -182
- edsl/conjure/InputDataPyRead.py +91 -91
- edsl/conjure/InputDataSPSS.py +8 -8
- edsl/conjure/InputDataStata.py +8 -8
- edsl/conjure/QuestionOptionMixin.py +76 -76
- edsl/conjure/QuestionTypeMixin.py +23 -23
- edsl/conjure/RawQuestion.py +65 -65
- edsl/conjure/SurveyResponses.py +7 -7
- edsl/conjure/__init__.py +9 -9
- edsl/conjure/naming_utilities.py +263 -263
- edsl/conjure/utilities.py +201 -201
- edsl/conversation/Conversation.py +290 -238
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/chips.py +95 -0
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +958 -827
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +527 -527
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +97 -97
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +54 -50
- edsl/exceptions/agents.py +38 -40
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +29 -26
- edsl/exceptions/scenarios.py +22 -0
- edsl/exceptions/surveys.py +37 -34
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +120 -120
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +97 -74
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1347 -1135
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +248 -248
- edsl/jobs/interviews/Interview.py +661 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -182
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +338 -338
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +442 -441
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +706 -718
- edsl/language_models/ModelList.py +102 -102
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +3 -2
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +259 -259
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +357 -353
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +656 -616
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +234 -266
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +114 -114
- edsl/questions/QuestionFunctional.py +159 -159
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +91 -91
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +413 -418
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +717 -693
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +450 -435
- edsl/results/Results.py +1071 -1160
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +135 -118
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +458 -458
- edsl/scenarios/Scenario.py +546 -510
- edsl/scenarios/ScenarioHtmlMixin.py +64 -59
- edsl/scenarios/ScenarioList.py +1112 -1101
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -4
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +330 -324
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1795 -1772
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +47 -47
- edsl/surveys/instructions/Instruction.py +51 -51
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +409 -391
- {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev5.dist-info}/LICENSE +21 -21
- {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev5.dist-info}/METADATA +1 -1
- edsl-0.1.37.dev5.dist-info/RECORD +283 -0
- edsl-0.1.37.dev4.dist-info/RECORD +0 -279
- {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev5.dist-info}/WHEEL +0 -0
edsl/inference_services/GoogleService.py
@@ -1,156 +1,156 @@
The whole file is removed and re-added; the removed and re-added text are identical line for line (156 lines each), so no content change is visible in the rendered diff (most likely whitespace or line-ending churn). The file content:

```python
import os
from typing import Any, Dict, List, Optional
import google
import google.generativeai as genai
from google.generativeai.types import GenerationConfig
from google.api_core.exceptions import InvalidArgument

from edsl.exceptions import MissingAPIKeyError
from edsl.language_models.LanguageModel import LanguageModel
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC

safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE",
    },
]


class GoogleService(InferenceServiceABC):
    _inference_service_ = "google"
    key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
    usage_sequence = ["usage_metadata"]
    input_token_name = "prompt_token_count"
    output_token_name = "candidates_token_count"

    model_exclude_list = []

    # @classmethod
    # def available(cls) -> List[str]:
    #     return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]

    @classmethod
    def available(cls) -> List[str]:
        model_list = []
        for m in genai.list_models():
            if "generateContent" in m.supported_generation_methods:
                model_list.append(m.name.split("/")[-1])
        return model_list

    @classmethod
    def create_model(
        cls, model_name: str = "gemini-pro", model_class_name=None
    ) -> LanguageModel:
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        class LLM(LanguageModel):
            _model_ = model_name
            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name
            _inference_service_ = cls._inference_service_

            _tpm = cls.get_tpm(cls)
            _rpm = cls.get_rpm(cls)

            _parameters_ = {
                "temperature": 0.5,
                "topP": 1,
                "topK": 1,
                "maxOutputTokens": 2048,
                "stopSequences": [],
            }

            api_token = None
            model = None

            @classmethod
            def initialize(cls):
                if cls.api_token is None:
                    cls.api_token = os.getenv("GOOGLE_API_KEY")
                    if not cls.api_token:
                        raise MissingAPIKeyError(
                            "GOOGLE_API_KEY environment variable is not set"
                        )
                    genai.configure(api_key=cls.api_token)
                    cls.generative_model = genai.GenerativeModel(
                        cls._model_, safety_settings=safety_settings
                    )

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.initialize()

            def get_generation_config(self) -> GenerationConfig:
                return GenerationConfig(
                    temperature=self.temperature,
                    top_p=self.topP,
                    top_k=self.topK,
                    max_output_tokens=self.maxOutputTokens,
                    stop_sequences=self.stopSequences,
                )

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional["Files"] = None,
            ) -> Dict[str, Any]:
                generation_config = self.get_generation_config()

                if files_list is None:
                    files_list = []

                if (
                    system_prompt is not None
                    and system_prompt != ""
                    and self._model_ != "gemini-pro"
                ):
                    try:
                        self.generative_model = genai.GenerativeModel(
                            self._model_,
                            safety_settings=safety_settings,
                            system_instruction=system_prompt,
                        )
                    except InvalidArgument as e:
                        print(
                            f"This model, {self._model_}, does not support system_instruction"
                        )
                        print("Will add system_prompt to user_prompt")
                        user_prompt = f"{system_prompt}\n{user_prompt}"

                combined_prompt = [user_prompt]
                for file in files_list:
                    if "google" not in file.external_locations:
                        _ = file.upload_google()
                    gen_ai_file = google.generativeai.types.file_types.File(
                        file.external_locations["google"]
                    )
                    combined_prompt.append(gen_ai_file)

                response = await self.generative_model.generate_content_async(
                    combined_prompt, generation_config=generation_config
                )
                return response.to_dict()

        LLM.__name__ = model_name
        return LLM


if __name__ == "__main__":
    pass
```
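For reference, `create_model` is a class factory: it returns a dynamically built `LanguageModel` subclass rather than an instance. A minimal usage sketch, assuming this release is installed and `GOOGLE_API_KEY` is set; the printed values follow directly from the class attributes above:

```python
from edsl.inference_services.GoogleService import GoogleService

# create_model returns a class, not an instance; the class is renamed
# after the model (see `LLM.__name__ = model_name` above).
Gemini = GoogleService.create_model("gemini-1.5-flash")

print(Gemini.__name__)                         # gemini-1.5-flash
print(Gemini._inference_service_)              # google
print(Gemini._parameters_["maxOutputTokens"])  # 2048

# Instantiation runs initialize(), which reads GOOGLE_API_KEY, calls
# genai.configure(), and raises MissingAPIKeyError if the key is unset.
model = Gemini()
```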
edsl/inference_services/GroqService.py
@@ -1,20 +1,20 @@
As above, the file is removed and re-added with line-for-line identical text (20 lines each). The file content:

```python
from typing import Any, List
from edsl.inference_services.OpenAIService import OpenAIService

import groq


class GroqService(OpenAIService):
    """DeepInfra service class."""

    _inference_service_ = "groq"
    _env_key_name_ = "GROQ_API_KEY"

    _sync_client_ = groq.Groq
    _async_client_ = groq.AsyncGroq

    model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]

    # _base_url_ = "https://api.deepinfra.com/v1/openai"
    _base_url_ = None
    _models_list_cache: List[str] = []
```
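GroqService shows how thin an OpenAI-compatible service is in this codebase: it only swaps in the provider's client classes, API-key variable, and model exclude list (its docstring still reads "DeepInfra service class.", carried over from the service it was copied from). A hypothetical sketch of wiring up another OpenAI-compatible provider the same way; `ExampleService`, `EXAMPLE_API_KEY`, and the base URL are illustrative, not part of edsl:

```python
from typing import List

import openai  # any SDK exposing OpenAI-compatible sync/async clients

from edsl.inference_services.OpenAIService import OpenAIService


class ExampleService(OpenAIService):
    """Hypothetical OpenAI-compatible service, mirroring GroqService."""

    _inference_service_ = "example"      # registry name for this service
    _env_key_name_ = "EXAMPLE_API_KEY"   # env var the API key is read from

    _sync_client_ = openai.OpenAI        # swap in the provider's clients
    _async_client_ = openai.AsyncOpenAI

    model_exclude_list: List[str] = []   # models to hide from .available()

    _base_url_ = "https://api.example.com/v1"  # provider endpoint (illustrative)
    _models_list_cache: List[str] = []
```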
edsl/inference_services/InferenceServiceABC.py
@@ -1,147 +1,147 @@
Again a full remove-and-re-add with line-for-line identical text (147 lines each). The file content:

```python
from abc import abstractmethod, ABC
import os
import re
from datetime import datetime, timedelta
from edsl.config import CONFIG


class InferenceServiceABC(ABC):
    """
    Abstract class for inference services.
    Anthropic: https://docs.anthropic.com/en/api/rate-limits
    """

    _coop_config_vars = None

    default_levels = {
        "google": {"tpm": 2_000_000, "rpm": 15},
        "openai": {"tpm": 2_000_000, "rpm": 10_000},
        "anthropic": {"tpm": 2_000_000, "rpm": 500},
    }

    def __init_subclass__(cls):
        """
        Check that the subclass has the required attributes.
        - `key_sequence` attribute determines...
        - `model_exclude_list` attribute determines...
        """
        if not hasattr(cls, "key_sequence"):
            raise NotImplementedError(
                f"Class {cls.__name__} must have a 'key_sequence' attribute."
            )
        if not hasattr(cls, "model_exclude_list"):
            raise NotImplementedError(
                f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
            )

    @classmethod
    def _should_refresh_coop_config_vars(cls):
        """
        Returns True if config vars have been fetched over 24 hours ago, and False otherwise.
        """

        if cls._last_config_fetch is None:
            return True
        return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)

    @classmethod
    def _get_limt(cls, limit_type: str) -> int:
        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
        if key in os.environ:
            return int(os.getenv(key))

        if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
            try:
                from edsl import Coop

                c = Coop()
                cls._coop_config_vars = c.fetch_rate_limit_config_vars()
                cls._last_config_fetch = datetime.now()
                if key in cls._coop_config_vars:
                    return cls._coop_config_vars[key]
            except Exception:
                cls._coop_config_vars = None
        else:
            if key in cls._coop_config_vars:
                return cls._coop_config_vars[key]

        if cls._inference_service_ in cls.default_levels:
            return int(cls.default_levels[cls._inference_service_][limit_type])

        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))

    def get_tpm(cls) -> int:
        """
        Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
        """
        return cls._get_limt(limit_type="tpm")

    def get_rpm(cls):
        """
        Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
        """
        return cls._get_limt(limit_type="rpm")

    @abstractmethod
    def available() -> list[str]:
        """
        Returns a list of available models for the service.
        """
        pass

    @abstractmethod
    def create_model():
        """
        Returns a LanguageModel object.
        """
        pass

    @staticmethod
    def to_class_name(s):
        """
        Converts a string to a valid class name.

        >>> InferenceServiceABC.to_class_name("hello world")
        'HelloWorld'
        """

        s = re.sub(r"[^a-zA-Z0-9 ]", "", s)
        s = "".join(word.title() for word in s.split())
        if s and s[0].isdigit():
            s = "Class" + s
        return s


if __name__ == "__main__":
    pass
    # deep_infra_service = DeepInfraService("deep_infra", "DEEP_INFRA_API_KEY")
    # deep_infra_service.available()
    # m = deep_infra_service.create_model("microsoft/WizardLM-2-7B")
    # response = m().hello()
    # print(response)

    # anthropic_service = AnthropicService("anthropic", "ANTHROPIC_API_KEY")
    # anthropic_service.available()
    # m = anthropic_service.create_model("claude-3-opus-20240229")
    # response = m().hello()
    # print(response)
    # factory = OpenAIService("openai", "OPENAI_API")
    # factory.available()
    # m = factory.create_model("gpt-3.5-turbo")
    # response = m().hello()

    # from edsl import QuestionFreeText
    # results = QuestionFreeText.example().by(m()).run()

    # collection = InferenceServicesCollection([
    #     OpenAIService,
    #     AnthropicService,
    #     DeepInfraService
    # ])

    # available = collection.available()
    # factory = collection.create_model_factory(*available[0])
    # m = factory()
    # from edsl import QuestionFreeText
    # results = QuestionFreeText.example().by(m).run()
    # print(results)
```
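`_get_limt` (the misspelling is in the released source) resolves a limit in this order: a service-specific environment variable, then rate-limit config vars fetched from Coop (cached, refreshed after 24 hours), then the hard-coded `default_levels`, and finally the `EDSL_SERVICE_{TPM|RPM}_BASELINE` value from `CONFIG`. A minimal sketch of the highest-precedence path, assuming this release and its Google dependencies are installed:

```python
import os

from edsl.inference_services.GoogleService import GoogleService
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC

# Highest-precedence source: EDSL_SERVICE_{TPM|RPM}_{SERVICE} env vars,
# matching the key built inside _get_limt.
os.environ["EDSL_SERVICE_RPM_GOOGLE"] = "30"
os.environ["EDSL_SERVICE_TPM_GOOGLE"] = "4000000"

# get_tpm/get_rpm take the class itself as their argument (hence
# cls.get_rpm(cls) in GoogleService.create_model); called the same way here.
print(GoogleService.get_rpm(GoogleService))  # 30
print(GoogleService.get_tpm(GoogleService))  # 4000000

# to_class_name sanitizes arbitrary model names into class names:
print(InferenceServiceABC.to_class_name("hello world"))       # HelloWorld
print(InferenceServiceABC.to_class_name("gemini-1.5-flash"))  # Gemini15Flash
```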