edsl 0.1.37.dev2__py3-none-any.whl → 0.1.37.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +48 -48
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +804 -804
- edsl/agents/AgentList.py +345 -345
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +305 -305
- edsl/agents/PromptConstructor.py +312 -312
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +86 -86
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +289 -289
- edsl/config.py +149 -149
- edsl/conjure/AgentConstructionMixin.py +152 -152
- edsl/conjure/Conjure.py +62 -62
- edsl/conjure/InputData.py +659 -659
- edsl/conjure/InputDataCSV.py +48 -48
- edsl/conjure/InputDataMixinQuestionStats.py +182 -182
- edsl/conjure/InputDataPyRead.py +91 -91
- edsl/conjure/InputDataSPSS.py +8 -8
- edsl/conjure/InputDataStata.py +8 -8
- edsl/conjure/QuestionOptionMixin.py +76 -76
- edsl/conjure/QuestionTypeMixin.py +23 -23
- edsl/conjure/RawQuestion.py +65 -65
- edsl/conjure/SurveyResponses.py +7 -7
- edsl/conjure/__init__.py +9 -9
- edsl/conjure/naming_utilities.py +263 -263
- edsl/conjure/utilities.py +201 -201
- edsl/conversation/Conversation.py +238 -238
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +824 -824
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +527 -527
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +97 -97
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +40 -40
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +26 -26
- edsl/exceptions/surveys.py +34 -34
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +115 -115
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +74 -74
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1121 -1112
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +248 -248
- edsl/jobs/interviews/Interview.py +661 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +182 -182
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +338 -338
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +441 -441
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/LanguageModel.py +718 -718
- edsl/language_models/ModelList.py +102 -102
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +2 -2
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +259 -259
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +353 -353
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +616 -616
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +266 -266
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +114 -114
- edsl/questions/QuestionFunctional.py +159 -159
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +91 -91
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +418 -418
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +693 -693
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +435 -435
- edsl/results/Results.py +1160 -1160
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +118 -118
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +458 -458
- edsl/scenarios/Scenario.py +510 -510
- edsl/scenarios/ScenarioHtmlMixin.py +59 -59
- edsl/scenarios/ScenarioList.py +1101 -1101
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -4
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +324 -324
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1772 -1772
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +47 -47
- edsl/surveys/instructions/Instruction.py +51 -51
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +391 -391
- {edsl-0.1.37.dev2.dist-info → edsl-0.1.37.dev3.dist-info}/LICENSE +21 -21
- {edsl-0.1.37.dev2.dist-info → edsl-0.1.37.dev3.dist-info}/METADATA +1 -1
- edsl-0.1.37.dev3.dist-info/RECORD +279 -0
- edsl-0.1.37.dev2.dist-info/RECORD +0 -279
- {edsl-0.1.37.dev2.dist-info → edsl-0.1.37.dev3.dist-info}/WHEEL +0 -0
@@ -1,74 +1,74 @@
|
|
1
|
-
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
2
|
-
import warnings
|
3
|
-
|
4
|
-
|
5
|
-
class InferenceServicesCollection:
|
6
|
-
added_models = {}
|
7
|
-
|
8
|
-
def __init__(self, services: list[InferenceServiceABC] = None):
|
9
|
-
self.services = services or []
|
10
|
-
|
11
|
-
@classmethod
|
12
|
-
def add_model(cls, service_name, model_name):
|
13
|
-
if service_name not in cls.added_models:
|
14
|
-
cls.added_models[service_name] = []
|
15
|
-
cls.added_models[service_name].append(model_name)
|
16
|
-
|
17
|
-
@staticmethod
|
18
|
-
def _get_service_available(service, warn: bool = False) -> list[str]:
|
19
|
-
from_api = True
|
20
|
-
try:
|
21
|
-
service_models = service.available()
|
22
|
-
except Exception as e:
|
23
|
-
if warn:
|
24
|
-
warnings.warn(
|
25
|
-
f"""Error getting models for {service._inference_service_}.
|
26
|
-
Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
|
27
|
-
See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
|
28
|
-
Relying on cache.""",
|
29
|
-
UserWarning,
|
30
|
-
)
|
31
|
-
from edsl.inference_services.models_available_cache import models_available
|
32
|
-
|
33
|
-
service_models = models_available.get(service._inference_service_, [])
|
34
|
-
# cache results
|
35
|
-
service._models_list_cache = service_models
|
36
|
-
from_api = False
|
37
|
-
return service_models # , from_api
|
38
|
-
|
39
|
-
def available(self):
|
40
|
-
total_models = []
|
41
|
-
for service in self.services:
|
42
|
-
service_models = self._get_service_available(service)
|
43
|
-
for model in service_models:
|
44
|
-
total_models.append([model, service._inference_service_, -1])
|
45
|
-
|
46
|
-
for model in self.added_models.get(service._inference_service_, []):
|
47
|
-
total_models.append([model, service._inference_service_, -1])
|
48
|
-
|
49
|
-
sorted_models = sorted(total_models)
|
50
|
-
for i, model in enumerate(sorted_models):
|
51
|
-
model[2] = i
|
52
|
-
model = tuple(model)
|
53
|
-
return sorted_models
|
54
|
-
|
55
|
-
def register(self, service):
|
56
|
-
self.services.append(service)
|
57
|
-
|
58
|
-
def create_model_factory(self, model_name: str, service_name=None, index=None):
|
59
|
-
from edsl.inference_services.TestService import TestService
|
60
|
-
|
61
|
-
if model_name == "test":
|
62
|
-
return TestService.create_model(model_name)
|
63
|
-
|
64
|
-
if service_name:
|
65
|
-
for service in self.services:
|
66
|
-
if service_name == service._inference_service_:
|
67
|
-
return service.create_model(model_name)
|
68
|
-
|
69
|
-
for service in self.services:
|
70
|
-
if model_name in self._get_service_available(service):
|
71
|
-
if service_name is None or service_name == service._inference_service_:
|
72
|
-
return service.create_model(model_name)
|
73
|
-
|
74
|
-
raise Exception(f"Model {model_name} not found in any of the services")
|
1
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
2
|
+
import warnings
|
3
|
+
|
4
|
+
|
5
|
+
class InferenceServicesCollection:
|
6
|
+
added_models = {}
|
7
|
+
|
8
|
+
def __init__(self, services: list[InferenceServiceABC] = None):
|
9
|
+
self.services = services or []
|
10
|
+
|
11
|
+
@classmethod
|
12
|
+
def add_model(cls, service_name, model_name):
|
13
|
+
if service_name not in cls.added_models:
|
14
|
+
cls.added_models[service_name] = []
|
15
|
+
cls.added_models[service_name].append(model_name)
|
16
|
+
|
17
|
+
@staticmethod
|
18
|
+
def _get_service_available(service, warn: bool = False) -> list[str]:
|
19
|
+
from_api = True
|
20
|
+
try:
|
21
|
+
service_models = service.available()
|
22
|
+
except Exception as e:
|
23
|
+
if warn:
|
24
|
+
warnings.warn(
|
25
|
+
f"""Error getting models for {service._inference_service_}.
|
26
|
+
Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
|
27
|
+
See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
|
28
|
+
Relying on cache.""",
|
29
|
+
UserWarning,
|
30
|
+
)
|
31
|
+
from edsl.inference_services.models_available_cache import models_available
|
32
|
+
|
33
|
+
service_models = models_available.get(service._inference_service_, [])
|
34
|
+
# cache results
|
35
|
+
service._models_list_cache = service_models
|
36
|
+
from_api = False
|
37
|
+
return service_models # , from_api
|
38
|
+
|
39
|
+
def available(self):
|
40
|
+
total_models = []
|
41
|
+
for service in self.services:
|
42
|
+
service_models = self._get_service_available(service)
|
43
|
+
for model in service_models:
|
44
|
+
total_models.append([model, service._inference_service_, -1])
|
45
|
+
|
46
|
+
for model in self.added_models.get(service._inference_service_, []):
|
47
|
+
total_models.append([model, service._inference_service_, -1])
|
48
|
+
|
49
|
+
sorted_models = sorted(total_models)
|
50
|
+
for i, model in enumerate(sorted_models):
|
51
|
+
model[2] = i
|
52
|
+
model = tuple(model)
|
53
|
+
return sorted_models
|
54
|
+
|
55
|
+
def register(self, service):
|
56
|
+
self.services.append(service)
|
57
|
+
|
58
|
+
def create_model_factory(self, model_name: str, service_name=None, index=None):
|
59
|
+
from edsl.inference_services.TestService import TestService
|
60
|
+
|
61
|
+
if model_name == "test":
|
62
|
+
return TestService.create_model(model_name)
|
63
|
+
|
64
|
+
if service_name:
|
65
|
+
for service in self.services:
|
66
|
+
if service_name == service._inference_service_:
|
67
|
+
return service.create_model(model_name)
|
68
|
+
|
69
|
+
for service in self.services:
|
70
|
+
if model_name in self._get_service_available(service):
|
71
|
+
if service_name is None or service_name == service._inference_service_:
|
72
|
+
return service.create_model(model_name)
|
73
|
+
|
74
|
+
raise Exception(f"Model {model_name} not found in any of the services")
|
@@ -1,123 +1,123 @@
|
|
1
|
-
import os
|
2
|
-
from typing import Any, List, Optional
|
3
|
-
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
4
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
5
|
-
import asyncio
|
6
|
-
from mistralai import Mistral
|
7
|
-
|
8
|
-
from edsl.exceptions.language_models import LanguageModelBadResponseError
|
9
|
-
|
10
|
-
|
11
|
-
class MistralAIService(InferenceServiceABC):
|
12
|
-
"""Mistral AI service class."""
|
13
|
-
|
14
|
-
key_sequence = ["choices", 0, "message", "content"]
|
15
|
-
usage_sequence = ["usage"]
|
16
|
-
|
17
|
-
_inference_service_ = "mistral"
|
18
|
-
_env_key_name_ = "MISTRAL_API_KEY" # Environment variable for Mistral API key
|
19
|
-
input_token_name = "prompt_tokens"
|
20
|
-
output_token_name = "completion_tokens"
|
21
|
-
|
22
|
-
_sync_client_instance = None
|
23
|
-
_async_client_instance = None
|
24
|
-
|
25
|
-
_sync_client = Mistral
|
26
|
-
_async_client = Mistral
|
27
|
-
|
28
|
-
_models_list_cache: List[str] = []
|
29
|
-
model_exclude_list = []
|
30
|
-
|
31
|
-
def __init_subclass__(cls, **kwargs):
|
32
|
-
super().__init_subclass__(**kwargs)
|
33
|
-
# so subclasses have to create their own instances of the clients
|
34
|
-
cls._sync_client_instance = None
|
35
|
-
cls._async_client_instance = None
|
36
|
-
|
37
|
-
@classmethod
|
38
|
-
def sync_client(cls):
|
39
|
-
if cls._sync_client_instance is None:
|
40
|
-
cls._sync_client_instance = cls._sync_client(
|
41
|
-
api_key=os.getenv(cls._env_key_name_)
|
42
|
-
)
|
43
|
-
return cls._sync_client_instance
|
44
|
-
|
45
|
-
@classmethod
|
46
|
-
def async_client(cls):
|
47
|
-
if cls._async_client_instance is None:
|
48
|
-
cls._async_client_instance = cls._async_client(
|
49
|
-
api_key=os.getenv(cls._env_key_name_)
|
50
|
-
)
|
51
|
-
return cls._async_client_instance
|
52
|
-
|
53
|
-
@classmethod
|
54
|
-
def available(cls) -> list[str]:
|
55
|
-
if not cls._models_list_cache:
|
56
|
-
cls._models_list_cache = [
|
57
|
-
m.id for m in cls.sync_client().models.list().data
|
58
|
-
]
|
59
|
-
|
60
|
-
return cls._models_list_cache
|
61
|
-
|
62
|
-
@classmethod
|
63
|
-
def create_model(
|
64
|
-
cls, model_name: str = "mistral", model_class_name=None
|
65
|
-
) -> LanguageModel:
|
66
|
-
if model_class_name is None:
|
67
|
-
model_class_name = cls.to_class_name(model_name)
|
68
|
-
|
69
|
-
class LLM(LanguageModel):
|
70
|
-
"""
|
71
|
-
Child class of LanguageModel for interacting with Mistral models.
|
72
|
-
"""
|
73
|
-
|
74
|
-
key_sequence = cls.key_sequence
|
75
|
-
usage_sequence = cls.usage_sequence
|
76
|
-
|
77
|
-
input_token_name = cls.input_token_name
|
78
|
-
output_token_name = cls.output_token_name
|
79
|
-
|
80
|
-
_inference_service_ = cls._inference_service_
|
81
|
-
_model_ = model_name
|
82
|
-
_parameters_ = {
|
83
|
-
"temperature": 0.5,
|
84
|
-
"max_tokens": 512,
|
85
|
-
"top_p": 0.9,
|
86
|
-
}
|
87
|
-
|
88
|
-
_tpm = cls.get_tpm(cls)
|
89
|
-
_rpm = cls.get_rpm(cls)
|
90
|
-
|
91
|
-
def sync_client(self):
|
92
|
-
return cls.sync_client()
|
93
|
-
|
94
|
-
def async_client(self):
|
95
|
-
return cls.async_client()
|
96
|
-
|
97
|
-
async def async_execute_model_call(
|
98
|
-
self,
|
99
|
-
user_prompt: str,
|
100
|
-
system_prompt: str = "",
|
101
|
-
files_list: Optional[List["FileStore"]] = None,
|
102
|
-
) -> dict[str, Any]:
|
103
|
-
"""Calls the Mistral API and returns the API response."""
|
104
|
-
s = self.async_client()
|
105
|
-
|
106
|
-
try:
|
107
|
-
res = await s.chat.complete_async(
|
108
|
-
model=model_name,
|
109
|
-
messages=[
|
110
|
-
{
|
111
|
-
"content": user_prompt,
|
112
|
-
"role": "user",
|
113
|
-
},
|
114
|
-
],
|
115
|
-
)
|
116
|
-
except Exception as e:
|
117
|
-
raise LanguageModelBadResponseError(f"Error with Mistral API: {e}")
|
118
|
-
|
119
|
-
return res.model_dump()
|
120
|
-
|
121
|
-
LLM.__name__ = model_class_name
|
122
|
-
|
123
|
-
return LLM
|
1
|
+
import os
|
2
|
+
from typing import Any, List, Optional
|
3
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
4
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
5
|
+
import asyncio
|
6
|
+
from mistralai import Mistral
|
7
|
+
|
8
|
+
from edsl.exceptions.language_models import LanguageModelBadResponseError
|
9
|
+
|
10
|
+
|
11
|
+
class MistralAIService(InferenceServiceABC):
|
12
|
+
"""Mistral AI service class."""
|
13
|
+
|
14
|
+
key_sequence = ["choices", 0, "message", "content"]
|
15
|
+
usage_sequence = ["usage"]
|
16
|
+
|
17
|
+
_inference_service_ = "mistral"
|
18
|
+
_env_key_name_ = "MISTRAL_API_KEY" # Environment variable for Mistral API key
|
19
|
+
input_token_name = "prompt_tokens"
|
20
|
+
output_token_name = "completion_tokens"
|
21
|
+
|
22
|
+
_sync_client_instance = None
|
23
|
+
_async_client_instance = None
|
24
|
+
|
25
|
+
_sync_client = Mistral
|
26
|
+
_async_client = Mistral
|
27
|
+
|
28
|
+
_models_list_cache: List[str] = []
|
29
|
+
model_exclude_list = []
|
30
|
+
|
31
|
+
def __init_subclass__(cls, **kwargs):
|
32
|
+
super().__init_subclass__(**kwargs)
|
33
|
+
# so subclasses have to create their own instances of the clients
|
34
|
+
cls._sync_client_instance = None
|
35
|
+
cls._async_client_instance = None
|
36
|
+
|
37
|
+
@classmethod
|
38
|
+
def sync_client(cls):
|
39
|
+
if cls._sync_client_instance is None:
|
40
|
+
cls._sync_client_instance = cls._sync_client(
|
41
|
+
api_key=os.getenv(cls._env_key_name_)
|
42
|
+
)
|
43
|
+
return cls._sync_client_instance
|
44
|
+
|
45
|
+
@classmethod
|
46
|
+
def async_client(cls):
|
47
|
+
if cls._async_client_instance is None:
|
48
|
+
cls._async_client_instance = cls._async_client(
|
49
|
+
api_key=os.getenv(cls._env_key_name_)
|
50
|
+
)
|
51
|
+
return cls._async_client_instance
|
52
|
+
|
53
|
+
@classmethod
|
54
|
+
def available(cls) -> list[str]:
|
55
|
+
if not cls._models_list_cache:
|
56
|
+
cls._models_list_cache = [
|
57
|
+
m.id for m in cls.sync_client().models.list().data
|
58
|
+
]
|
59
|
+
|
60
|
+
return cls._models_list_cache
|
61
|
+
|
62
|
+
@classmethod
|
63
|
+
def create_model(
|
64
|
+
cls, model_name: str = "mistral", model_class_name=None
|
65
|
+
) -> LanguageModel:
|
66
|
+
if model_class_name is None:
|
67
|
+
model_class_name = cls.to_class_name(model_name)
|
68
|
+
|
69
|
+
class LLM(LanguageModel):
|
70
|
+
"""
|
71
|
+
Child class of LanguageModel for interacting with Mistral models.
|
72
|
+
"""
|
73
|
+
|
74
|
+
key_sequence = cls.key_sequence
|
75
|
+
usage_sequence = cls.usage_sequence
|
76
|
+
|
77
|
+
input_token_name = cls.input_token_name
|
78
|
+
output_token_name = cls.output_token_name
|
79
|
+
|
80
|
+
_inference_service_ = cls._inference_service_
|
81
|
+
_model_ = model_name
|
82
|
+
_parameters_ = {
|
83
|
+
"temperature": 0.5,
|
84
|
+
"max_tokens": 512,
|
85
|
+
"top_p": 0.9,
|
86
|
+
}
|
87
|
+
|
88
|
+
_tpm = cls.get_tpm(cls)
|
89
|
+
_rpm = cls.get_rpm(cls)
|
90
|
+
|
91
|
+
def sync_client(self):
|
92
|
+
return cls.sync_client()
|
93
|
+
|
94
|
+
def async_client(self):
|
95
|
+
return cls.async_client()
|
96
|
+
|
97
|
+
async def async_execute_model_call(
|
98
|
+
self,
|
99
|
+
user_prompt: str,
|
100
|
+
system_prompt: str = "",
|
101
|
+
files_list: Optional[List["FileStore"]] = None,
|
102
|
+
) -> dict[str, Any]:
|
103
|
+
"""Calls the Mistral API and returns the API response."""
|
104
|
+
s = self.async_client()
|
105
|
+
|
106
|
+
try:
|
107
|
+
res = await s.chat.complete_async(
|
108
|
+
model=model_name,
|
109
|
+
messages=[
|
110
|
+
{
|
111
|
+
"content": user_prompt,
|
112
|
+
"role": "user",
|
113
|
+
},
|
114
|
+
],
|
115
|
+
)
|
116
|
+
except Exception as e:
|
117
|
+
raise LanguageModelBadResponseError(f"Error with Mistral API: {e}")
|
118
|
+
|
119
|
+
return res.model_dump()
|
120
|
+
|
121
|
+
LLM.__name__ = model_class_name
|
122
|
+
|
123
|
+
return LLM
|
@@ -1,18 +1,18 @@
|
|
1
|
-
import aiohttp
|
2
|
-
import json
|
3
|
-
import requests
|
4
|
-
from typing import Any, List
|
5
|
-
|
6
|
-
# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
7
|
-
from edsl.language_models import LanguageModel
|
8
|
-
|
9
|
-
from edsl.inference_services.OpenAIService import OpenAIService
|
10
|
-
|
11
|
-
|
12
|
-
class OllamaService(OpenAIService):
|
13
|
-
"""DeepInfra service class."""
|
14
|
-
|
15
|
-
_inference_service_ = "ollama"
|
16
|
-
_env_key_name_ = "DEEP_INFRA_API_KEY"
|
17
|
-
_base_url_ = "http://localhost:11434/v1"
|
18
|
-
_models_list_cache: List[str] = []
|
1
|
+
import aiohttp
|
2
|
+
import json
|
3
|
+
import requests
|
4
|
+
from typing import Any, List
|
5
|
+
|
6
|
+
# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
7
|
+
from edsl.language_models import LanguageModel
|
8
|
+
|
9
|
+
from edsl.inference_services.OpenAIService import OpenAIService
|
10
|
+
|
11
|
+
|
12
|
+
class OllamaService(OpenAIService):
|
13
|
+
"""DeepInfra service class."""
|
14
|
+
|
15
|
+
_inference_service_ = "ollama"
|
16
|
+
_env_key_name_ = "DEEP_INFRA_API_KEY"
|
17
|
+
_base_url_ = "http://localhost:11434/v1"
|
18
|
+
_models_list_cache: List[str] = []
|