edsl 0.1.39.dev3__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +413 -332
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +57 -49
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +1071 -867
- edsl/agents/AgentList.py +551 -413
- edsl/agents/Invigilator.py +284 -233
- edsl/agents/InvigilatorBase.py +257 -270
- edsl/agents/PromptConstructor.py +272 -354
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +2 -3
- edsl/agents/descriptors.py +99 -99
- edsl/agents/prompt_helpers.py +129 -129
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +130 -117
- edsl/auto/StageBase.py +243 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +74 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +218 -224
- edsl/base/Base.py +279 -279
- edsl/config.py +177 -157
- edsl/conversation/Conversation.py +290 -290
- edsl/conversation/car_buying.py +59 -58
- edsl/conversation/chips.py +95 -95
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +1106 -1028
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +573 -555
- edsl/data/CacheEntry.py +230 -233
- edsl/data/CacheHandler.py +168 -149
- edsl/data/RemoteCacheSync.py +186 -78
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +5 -4
- edsl/data/hack.py +10 -0
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +74 -73
- edsl/enums.py +202 -175
- edsl/exceptions/BaseException.py +21 -21
- edsl/exceptions/__init__.py +54 -54
- edsl/exceptions/agents.py +54 -42
- edsl/exceptions/cache.py +5 -5
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +109 -91
- edsl/exceptions/results.py +29 -29
- edsl/exceptions/scenarios.py +29 -22
- edsl/exceptions/surveys.py +37 -37
- edsl/inference_services/AnthropicService.py +106 -87
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +118 -120
- edsl/inference_services/AzureAI.py +215 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +143 -148
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +80 -147
- edsl/inference_services/InferenceServicesCollection.py +138 -97
- edsl/inference_services/MistralAIService.py +120 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +236 -224
- edsl/inference_services/PerplexityService.py +160 -163
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +90 -89
- edsl/inference_services/TogetherAIService.py +172 -170
- edsl/inference_services/data_structures.py +134 -0
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +41 -41
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +43 -56
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +823 -898
- edsl/jobs/JobsChecks.py +172 -147
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +270 -268
- edsl/jobs/JobsRemoteInferenceHandler.py +311 -239
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +104 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +283 -251
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +396 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +151 -466
- edsl/jobs/runners/JobsRunnerStatus.py +297 -330
- edsl/jobs/tasks/QuestionTaskCreator.py +244 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +470 -450
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +161 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +626 -668
- edsl/language_models/ModelList.py +164 -155
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +2 -3
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +156 -156
- edsl/language_models/utilities.py +65 -64
- edsl/notebooks/Notebook.py +263 -258
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +352 -362
- edsl/prompts/__init__.py +2 -2
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +518 -664
- edsl/questions/QuestionBasePromptsMixin.py +221 -217
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +180 -182
- edsl/questions/QuestionFreeText.py +113 -114
- edsl/questions/QuestionFunctional.py +166 -166
- edsl/questions/QuestionList.py +223 -231
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +330 -286
- edsl/questions/QuestionNumerical.py +151 -153
- edsl/questions/QuestionRank.py +314 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/SimpleAskMixin.py +74 -73
- edsl/questions/__init__.py +27 -26
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +334 -289
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/data_structures.py +20 -0
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +90 -87
- edsl/questions/derived/QuestionTopK.py +93 -93
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +427 -413
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +168 -161
- edsl/questions/question_registry.py +177 -177
- edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +71 -71
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +188 -174
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/CSSParameterizer.py +108 -108
- edsl/results/Dataset.py +587 -424
- edsl/results/DatasetExportMixin.py +594 -731
- edsl/results/DatasetTree.py +295 -275
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +557 -465
- edsl/results/Results.py +1183 -1165
- edsl/results/ResultsExportMixin.py +45 -43
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/TableDisplay.py +125 -198
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +2 -2
- edsl/results/file_exports.py +252 -0
- edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +33 -33
- edsl/results/{Selector.py → results_selector.py} +145 -135
- edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +98 -98
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_display.css +77 -77
- edsl/results/table_renderers.py +118 -0
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +511 -632
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +498 -601
- edsl/scenarios/ScenarioHtmlMixin.py +65 -64
- edsl/scenarios/ScenarioList.py +1458 -1287
- edsl/scenarios/ScenarioListExportMixin.py +45 -52
- edsl/scenarios/ScenarioListPdfMixin.py +239 -261
- edsl/scenarios/__init__.py +3 -4
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +38 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +131 -127
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +521 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +327 -326
- edsl/surveys/RuleCollection.py +385 -387
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +1280 -1801
- edsl/surveys/SurveyCSS.py +273 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +181 -179
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +5 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +60 -56
- edsl/surveys/instructions/ChangeInstruction.py +48 -49
- edsl/surveys/instructions/Instruction.py +56 -65
- edsl/surveys/instructions/InstructionCollection.py +82 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +19 -19
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/test_h +1 -0
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/gcp_bucket/example.py +50 -0
- edsl/utilities/interface.py +627 -627
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/naming_utilities.py +263 -263
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +436 -424
- {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +21 -21
- {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +13 -11
- edsl-0.1.39.dev4.dist-info/RECORD +361 -0
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/registry.py +0 -190
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.39.dev3.dist-info/RECORD +0 -277
- {edsl-0.1.39.dev3.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
@@ -1,97 +1,138 @@
|
|
1
|
-
from
|
2
|
-
import
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
1
|
+
from functools import lru_cache
|
2
|
+
from collections import defaultdict
|
3
|
+
from typing import Optional, Protocol, Dict, List, Tuple, TYPE_CHECKING, Literal
|
4
|
+
|
5
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
6
|
+
from edsl.inference_services.AvailableModelFetcher import AvailableModelFetcher
|
7
|
+
from edsl.exceptions.inference_services import InferenceServiceError
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
11
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
12
|
+
|
13
|
+
|
14
|
+
class ModelCreator(Protocol):
    """Structural (duck-typed) interface: anything exposing ``create_model``."""

    def create_model(self, model_name: str) -> "LanguageModel":
        """Build the language model identified by *model_name*."""
        ...
|
17
|
+
|
18
|
+
|
19
|
+
from edsl.enums import InferenceServiceLiteral
|
20
|
+
|
21
|
+
|
22
|
+
class ModelResolver:
    """Determine which inference service to use for a given model name.

    Resolution order used by ``resolve_model``:

    1. the special model name ``"test"`` always yields a new ``TestService()``;
    2. an explicit ``service_name`` is looked up among ``self.services``;
    3. a model resolved earlier is returned from the shared cache;
    4. otherwise each service is asked, in list order, for its available
       models via the availability fetcher; the first match wins and is cached.
    """

    def __init__(
        self,
        services: "List[InferenceServiceABC]",
        models_to_services: "Dict[str, InferenceServiceABC]",
        availability_fetcher: "AvailableModelFetcher",
    ):
        """
        :param services: The services to search. Kept by reference, so services
            appended to this list later are also considered.
        :param models_to_services: Cache mapping model name -> service. Kept by
            reference and updated in place whenever a model is resolved.
        :param availability_fetcher: Object exposing
            ``get_available_models_by_service(service)`` which returns a
            ``(available_models, service_name)`` pair.
        """
        self.services = services
        self._models_to_services = models_to_services
        self.availability_fetcher = availability_fetcher
        # Snapshot kept for backward compatibility only; name lookups go through
        # _lookup_service_by_name so later-registered services are found too.
        self._service_names_to_classes = {
            service._inference_service_: service for service in services
        }

    def _lookup_service_by_name(self, service_name: str):
        """Return the service whose ``_inference_service_`` is *service_name*, or None."""
        # Built fresh from self.services on every call: the list is shared with
        # the owning collection, which may append to it after construction.
        # Dict semantics (last duplicate wins) match the original snapshot.
        return {
            service._inference_service_: service for service in self.services
        }.get(service_name)

    def resolve_model(
        self, model_name: str, service_name: "Optional[InferenceServiceLiteral]" = None
    ) -> "InferenceServiceABC":
        """Returns an InferenceServiceABC object for the given model name.

        :param model_name: The name of the model to resolve. E.g., 'gpt-4o'
        :param service_name: The name of the service to use. E.g., 'openai'
        :return: An InferenceServiceABC object
        :raises InferenceServiceError: if ``service_name`` is given but unknown,
            or if no service lists ``model_name``.
        """
        if model_name == "test":
            from edsl.inference_services.TestService import TestService

            return TestService()

        if service_name is not None:
            service = self._lookup_service_by_name(service_name)
            if not service:
                raise InferenceServiceError(f"Service {service_name} not found")
            return service

        if model_name in self._models_to_services:  # maybe we've seen it before!
            return self._models_to_services[model_name]

        for service in self.services:
            # The second element (the service's name) is not needed here; it is
            # discarded rather than shadowing the ``service_name`` parameter.
            available_models, _ = (
                self.availability_fetcher.get_available_models_by_service(service)
            )
            if model_name in available_models:
                self._models_to_services[model_name] = service
                return service

        raise InferenceServiceError(
            f"Model {model_name} not found in any services.\n"
            "If you know the service that has this model, use the service_name parameter directly.\n"
            'E.g., Model("gpt-4o", service_name="openai")\n'
        )
|
80
|
+
|
81
|
+
|
82
|
+
class InferenceServicesCollection:
    """A registry of inference services with model -> service resolution.

    ``added_models`` is a class-level mapping of service name -> list of model
    names, shared by every instance. It is passed to ``AvailableModelFetcher``
    (presumably so manually added models are reported as available — confirm
    against that class).
    """

    added_models = defaultdict(list)  # Moved back to class level

    def __init__(self, services: "Optional[List[InferenceServiceABC]]" = None):
        """
        :param services: Initial services. The same list object is shared with
            the availability fetcher and the resolver, so ``register`` affects both.
        """
        self.services = services or []
        self._models_to_services: "Dict[str, InferenceServiceABC]" = {}

        self.availability_fetcher = AvailableModelFetcher(
            self.services, self.added_models
        )
        self.resolver = ModelResolver(
            self.services, self._models_to_services, self.availability_fetcher
        )

    @classmethod
    def add_model(cls, service_name: str, model_name: str) -> None:
        """Record *model_name* as an extra model offered by *service_name*.

        Every distinct model is kept; adding the same (service, model) pair
        twice is a no-op.
        """
        # Previously the check was "service not yet a key", which silently
        # dropped every model after the first one for a given service.
        models = cls.added_models[service_name]
        if model_name not in models:
            models.append(model_name)

    def service_names_to_classes(self) -> "Dict[str, InferenceServiceABC]":
        """Map each registered service's ``_inference_service_`` name to the service."""
        return {service._inference_service_: service for service in self.services}

    def available(
        self,
        service: Optional[str] = None,
    ) -> List[Tuple[str, str, int]]:
        """Delegate to the availability fetcher, optionally filtered by *service*."""
        return self.availability_fetcher.available(service)

    def reset_cache(self) -> None:
        """Clear the availability fetcher's cache."""
        self.availability_fetcher.reset_cache()

    @property
    def num_cache_entries(self) -> int:
        """Number of entries in the availability fetcher's cache."""
        return self.availability_fetcher.num_cache_entries

    def register(self, service: "InferenceServiceABC") -> None:
        """Append *service* to the shared services list."""
        self.services.append(service)

    def create_model_factory(
        self, model_name: str, service_name: "Optional[InferenceServiceLiteral]" = None
    ) -> "LanguageModel":
        """Return ``service.create_model(model_name)`` for the resolved service.

        :param model_name: Model to create, e.g. ``'gpt-4o'``.
        :param service_name: Optional explicit service name, e.g. ``'openai'``.
        :raises InferenceServiceError: if the service cannot be determined.
        """
        if service_name is None:  # we try to find the right service
            service = self.resolver.resolve_model(model_name, service_name)
        else:  # if they passed a service, we'll use that
            service = self.service_names_to_classes().get(service_name)

        if not service:  # but if we can't find it, we'll raise an error
            raise InferenceServiceError(f"Service {service_name} not found")

        return service.create_model(model_name)
|
133
|
+
|
134
|
+
|
135
|
+
if __name__ == "__main__":
    # Executed only when run as a script: check the module's doctests.
    from doctest import testmod

    testmod()
|
@@ -1,123 +1,120 @@
|
|
1
|
-
import os
|
2
|
-
from typing import Any, List, Optional
|
3
|
-
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
4
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
5
|
-
import asyncio
|
6
|
-
from mistralai import Mistral
|
7
|
-
|
8
|
-
from edsl.exceptions.language_models import LanguageModelBadResponseError
|
9
|
-
|
10
|
-
|
11
|
-
class MistralAIService(InferenceServiceABC):
|
12
|
-
"""Mistral AI service class."""
|
13
|
-
|
14
|
-
key_sequence = ["choices", 0, "message", "content"]
|
15
|
-
usage_sequence = ["usage"]
|
16
|
-
|
17
|
-
_inference_service_ = "mistral"
|
18
|
-
_env_key_name_ = "MISTRAL_API_KEY" # Environment variable for Mistral API key
|
19
|
-
input_token_name = "prompt_tokens"
|
20
|
-
output_token_name = "completion_tokens"
|
21
|
-
|
22
|
-
_sync_client_instance = None
|
23
|
-
_async_client_instance = None
|
24
|
-
|
25
|
-
_sync_client = Mistral
|
26
|
-
_async_client = Mistral
|
27
|
-
|
28
|
-
_models_list_cache: List[str] = []
|
29
|
-
model_exclude_list = []
|
30
|
-
|
31
|
-
def __init_subclass__(cls, **kwargs):
|
32
|
-
super().__init_subclass__(**kwargs)
|
33
|
-
# so subclasses have to create their own instances of the clients
|
34
|
-
cls._sync_client_instance = None
|
35
|
-
cls._async_client_instance = None
|
36
|
-
|
37
|
-
@classmethod
|
38
|
-
def sync_client(cls):
|
39
|
-
if cls._sync_client_instance is None:
|
40
|
-
cls._sync_client_instance = cls._sync_client(
|
41
|
-
api_key=os.getenv(cls._env_key_name_)
|
42
|
-
)
|
43
|
-
return cls._sync_client_instance
|
44
|
-
|
45
|
-
@classmethod
|
46
|
-
def async_client(cls):
|
47
|
-
if cls._async_client_instance is None:
|
48
|
-
cls._async_client_instance = cls._async_client(
|
49
|
-
api_key=os.getenv(cls._env_key_name_)
|
50
|
-
)
|
51
|
-
return cls._async_client_instance
|
52
|
-
|
53
|
-
@classmethod
|
54
|
-
def available(cls) -> list[str]:
|
55
|
-
if not cls._models_list_cache:
|
56
|
-
cls._models_list_cache = [
|
57
|
-
m.id for m in cls.sync_client().models.list().data
|
58
|
-
]
|
59
|
-
|
60
|
-
return cls._models_list_cache
|
61
|
-
|
62
|
-
@classmethod
|
63
|
-
def create_model(
|
64
|
-
cls, model_name: str = "mistral", model_class_name=None
|
65
|
-
) -> LanguageModel:
|
66
|
-
if model_class_name is None:
|
67
|
-
model_class_name = cls.to_class_name(model_name)
|
68
|
-
|
69
|
-
class LLM(LanguageModel):
|
70
|
-
"""
|
71
|
-
Child class of LanguageModel for interacting with Mistral models.
|
72
|
-
"""
|
73
|
-
|
74
|
-
key_sequence = cls.key_sequence
|
75
|
-
usage_sequence = cls.usage_sequence
|
76
|
-
|
77
|
-
input_token_name = cls.input_token_name
|
78
|
-
output_token_name = cls.output_token_name
|
79
|
-
|
80
|
-
_inference_service_ = cls._inference_service_
|
81
|
-
_model_ = model_name
|
82
|
-
_parameters_ = {
|
83
|
-
"temperature": 0.5,
|
84
|
-
"max_tokens": 512,
|
85
|
-
"top_p": 0.9,
|
86
|
-
}
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
def
|
92
|
-
return cls.
|
93
|
-
|
94
|
-
def
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
LLM.__name__ = model_class_name
|
122
|
-
|
123
|
-
return LLM
|
1
|
+
import os
|
2
|
+
from typing import Any, List, Optional
|
3
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
4
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
5
|
+
import asyncio
|
6
|
+
from mistralai import Mistral
|
7
|
+
|
8
|
+
from edsl.exceptions.language_models import LanguageModelBadResponseError
|
9
|
+
|
10
|
+
|
11
|
+
class MistralAIService(InferenceServiceABC):
    """Mistral AI service class.

    Builds ``LanguageModel`` subclasses that call Mistral's chat completion API
    through the ``mistralai`` client. Clients are created lazily from the
    ``MISTRAL_API_KEY`` environment variable and cached per class.
    """

    # Paths presumably used by LanguageModel to pull the answer text and token
    # usage out of the dict returned by async_execute_model_call.
    key_sequence = ["choices", 0, "message", "content"]
    usage_sequence = ["usage"]

    _inference_service_ = "mistral"
    _env_key_name_ = "MISTRAL_API_KEY"  # Environment variable for Mistral API key
    input_token_name = "prompt_tokens"
    output_token_name = "completion_tokens"

    _sync_client_instance = None
    _async_client_instance = None

    # The same Mistral client class serves both paths; async calls use
    # ``chat.complete_async`` on it.
    _sync_client = Mistral
    _async_client = Mistral

    _models_list_cache: List[str] = []
    model_exclude_list = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # so subclasses have to create their own instances of the clients
        cls._sync_client_instance = None
        cls._async_client_instance = None

    @classmethod
    def sync_client(cls):
        """Return this class's cached sync client, creating it on first use."""
        if cls._sync_client_instance is None:
            cls._sync_client_instance = cls._sync_client(
                api_key=os.getenv(cls._env_key_name_)
            )
        return cls._sync_client_instance

    @classmethod
    def async_client(cls):
        """Return this class's cached async client, creating it on first use."""
        if cls._async_client_instance is None:
            cls._async_client_instance = cls._async_client(
                api_key=os.getenv(cls._env_key_name_)
            )
        return cls._async_client_instance

    @classmethod
    def available(cls) -> list[str]:
        """Return model ids listed by the Mistral API, cached on the class.

        An empty result is falsy and therefore not cached; the API is queried
        again on the next call.
        """
        if not cls._models_list_cache:
            cls._models_list_cache = [
                m.id for m in cls.sync_client().models.list().data
            ]

        return cls._models_list_cache

    @classmethod
    def create_model(
        cls, model_name: str = "mistral", model_class_name=None
    ) -> LanguageModel:
        """Create a LanguageModel subclass bound to *model_name*.

        :param model_name: Mistral model id sent with each request.
        :param model_class_name: ``__name__`` for the generated class; derived
            from *model_name* via ``cls.to_class_name`` when omitted.
        :return: The generated class (not an instance).
        """
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        class LLM(LanguageModel):
            """
            Child class of LanguageModel for interacting with Mistral models.
            """

            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence

            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name

            _inference_service_ = cls._inference_service_
            _model_ = model_name
            _parameters_ = {
                "temperature": 0.5,
                "max_tokens": 512,
                "top_p": 0.9,
            }

            def sync_client(self):
                return cls.sync_client()

            def async_client(self):
                return cls.async_client()

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional[List["FileStore"]] = None,
            ) -> dict[str, Any]:
                """Calls the Mistral API and returns the API response.

                A non-empty *system_prompt* is sent as a leading ``system``
                message before the ``user`` message.

                :raises LanguageModelBadResponseError: if the API call fails;
                    the original exception is chained as ``__cause__``.
                """
                # NOTE: files_list is accepted for interface compatibility but is
                # not sent to the API.
                # NOTE(review): _parameters_ (temperature, max_tokens, top_p) are
                # not forwarded to complete_async here — confirm whether intended.
                s = self.async_client()

                messages = []
                if system_prompt:
                    messages.append({"content": system_prompt, "role": "system"})
                messages.append({"content": user_prompt, "role": "user"})

                try:
                    res = await s.chat.complete_async(
                        model=model_name,
                        messages=messages,
                    )
                except Exception as e:
                    raise LanguageModelBadResponseError(
                        f"Error with Mistral API: {e}"
                    ) from e

                return res.model_dump()

        LLM.__name__ = model_class_name

        return LLM
|
@@ -1,18 +1,18 @@
|
|
1
|
-
import aiohttp
|
2
|
-
import json
|
3
|
-
import requests
|
4
|
-
from typing import Any, List
|
5
|
-
|
6
|
-
# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
7
|
-
from edsl.language_models import LanguageModel
|
8
|
-
|
9
|
-
from edsl.inference_services.OpenAIService import OpenAIService
|
10
|
-
|
11
|
-
|
12
|
-
class OllamaService(OpenAIService):
    """Ollama service class.

    Subclasses OpenAIService and inherits its request logic. ``_base_url_`` is
    overridden to point at a local server on port 11434 (Ollama's default
    OpenAI-compatible endpoint).
    """

    _inference_service_ = "ollama"
    # NOTE(review): this key name is DeepInfra's and looks copied from
    # DeepInfraService — confirm whether Ollama needs an API key at all.
    _env_key_name_ = "DEEP_INFRA_API_KEY"
    _base_url_ = "http://localhost:11434/v1"
    _models_list_cache: List[str] = []
|
1
|
+
import aiohttp
|
2
|
+
import json
|
3
|
+
import requests
|
4
|
+
from typing import Any, List
|
5
|
+
|
6
|
+
# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
7
|
+
from edsl.language_models import LanguageModel
|
8
|
+
|
9
|
+
from edsl.inference_services.OpenAIService import OpenAIService
|
10
|
+
|
11
|
+
|
12
|
+
class OllamaService(OpenAIService):
    """Ollama service class.

    Subclasses OpenAIService and inherits its request logic. ``_base_url_`` is
    overridden to point at a local server on port 11434 (Ollama's default
    OpenAI-compatible endpoint).
    """

    _inference_service_ = "ollama"
    # NOTE(review): this key name is DeepInfra's and looks copied from
    # DeepInfraService — confirm whether Ollama needs an API key at all.
    _env_key_name_ = "DEEP_INFRA_API_KEY"
    _base_url_ = "http://localhost:11434/v1"
    _models_list_cache: List[str] = []
|