edsl 0.1.14__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +348 -38
- edsl/BaseDiff.py +260 -0
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +46 -10
- edsl/__version__.py +1 -0
- edsl/agents/Agent.py +842 -144
- edsl/agents/AgentList.py +521 -25
- edsl/agents/Invigilator.py +250 -374
- edsl/agents/InvigilatorBase.py +257 -0
- edsl/agents/PromptConstructor.py +272 -0
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/descriptors.py +43 -13
- edsl/agents/prompt_helpers.py +129 -0
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +130 -0
- edsl/auto/StageBase.py +243 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +74 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +218 -0
- edsl/base/Base.py +279 -0
- edsl/config.py +121 -104
- edsl/conversation/Conversation.py +290 -0
- edsl/conversation/car_buying.py +59 -0
- edsl/conversation/chips.py +95 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +54 -0
- edsl/coop/__init__.py +1 -0
- edsl/coop/coop.py +1029 -134
- edsl/coop/utils.py +131 -0
- edsl/data/Cache.py +560 -89
- edsl/data/CacheEntry.py +230 -0
- edsl/data/CacheHandler.py +168 -0
- edsl/data/RemoteCacheSync.py +186 -0
- edsl/data/SQLiteDict.py +292 -0
- edsl/data/__init__.py +5 -3
- edsl/data/orm.py +6 -33
- edsl/data_transfer_models.py +74 -27
- edsl/enums.py +165 -8
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +52 -46
- edsl/exceptions/agents.py +33 -15
- edsl/exceptions/cache.py +5 -0
- edsl/exceptions/coop.py +8 -0
- edsl/exceptions/general.py +34 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/jobs.py +15 -0
- edsl/exceptions/language_models.py +46 -1
- edsl/exceptions/questions.py +80 -5
- edsl/exceptions/results.py +16 -5
- edsl/exceptions/scenarios.py +29 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AnthropicService.py +106 -0
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +118 -0
- edsl/inference_services/AzureAI.py +215 -0
- edsl/inference_services/DeepInfraService.py +18 -0
- edsl/inference_services/GoogleService.py +143 -0
- edsl/inference_services/GroqService.py +20 -0
- edsl/inference_services/InferenceServiceABC.py +80 -0
- edsl/inference_services/InferenceServicesCollection.py +138 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +236 -0
- edsl/inference_services/PerplexityService.py +160 -0
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +90 -0
- edsl/inference_services/TogetherAIService.py +172 -0
- edsl/inference_services/data_structures.py +134 -0
- edsl/inference_services/models_available_cache.py +118 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/registry.py +41 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +21 -20
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +684 -204
- edsl/jobs/JobsChecks.py +172 -0
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +270 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +104 -0
- edsl/jobs/buckets/ModelBuckets.py +65 -0
- edsl/jobs/buckets/TokenBucket.py +283 -0
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +392 -0
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -0
- edsl/jobs/interviews/InterviewStatistic.py +63 -0
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -0
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -0
- edsl/jobs/interviews/InterviewStatusLog.py +92 -0
- edsl/jobs/interviews/ReportErrors.py +66 -0
- edsl/jobs/interviews/interview_status_enum.py +9 -0
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +151 -110
- edsl/jobs/runners/JobsRunnerStatus.py +298 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +244 -0
- edsl/jobs/tasks/TaskCreators.py +64 -0
- edsl/jobs/tasks/TaskHistory.py +470 -0
- edsl/jobs/tasks/TaskStatusLog.py +23 -0
- edsl/jobs/tasks/task_status_enum.py +161 -0
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -0
- edsl/jobs/tokens/TokenUsage.py +34 -0
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +507 -386
- edsl/language_models/ModelList.py +164 -0
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -0
- edsl/language_models/__init__.py +1 -8
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +109 -41
- edsl/language_models/utilities.py +65 -0
- edsl/notebooks/Notebook.py +263 -0
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/notebooks/__init__.py +1 -0
- edsl/prompts/Prompt.py +222 -93
- edsl/prompts/__init__.py +1 -1
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +518 -0
- edsl/questions/QuestionBasePromptsMixin.py +221 -0
- edsl/questions/QuestionBudget.py +164 -67
- edsl/questions/QuestionCheckBox.py +281 -62
- edsl/questions/QuestionDict.py +343 -0
- edsl/questions/QuestionExtract.py +136 -50
- edsl/questions/QuestionFreeText.py +79 -55
- edsl/questions/QuestionFunctional.py +138 -41
- edsl/questions/QuestionList.py +184 -57
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +293 -69
- edsl/questions/QuestionNumerical.py +109 -56
- edsl/questions/QuestionRank.py +244 -49
- edsl/questions/Quick.py +41 -0
- edsl/questions/SimpleAskMixin.py +74 -0
- edsl/questions/__init__.py +9 -6
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +153 -38
- edsl/questions/compose_questions.py +13 -7
- edsl/questions/data_structures.py +20 -0
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +28 -26
- edsl/questions/derived/QuestionLinearScale.py +41 -28
- edsl/questions/derived/QuestionTopK.py +34 -26
- edsl/questions/derived/QuestionYesNo.py +40 -27
- edsl/questions/descriptors.py +228 -74
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_base_gen_mixin.py +168 -0
- edsl/questions/question_registry.py +130 -46
- edsl/questions/register_questions_meta.py +71 -0
- edsl/questions/response_validator_abc.py +188 -0
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/settings.py +5 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/dict/__init__.py +0 -0
- edsl/questions/templates/dict/answering_instructions.jinja +21 -0
- edsl/questions/templates/dict/question_presentation.jinja +1 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +7 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +550 -19
- edsl/results/DatasetExportMixin.py +594 -0
- edsl/results/DatasetTree.py +295 -0
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +477 -173
- edsl/results/Results.py +987 -269
- edsl/results/ResultsExportMixin.py +28 -125
- edsl/results/ResultsGGMixin.py +83 -15
- edsl/results/TableDisplay.py +125 -0
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/file_exports.py +252 -0
- edsl/results/results_fetch_mixin.py +33 -0
- edsl/results/results_selector.py +145 -0
- edsl/results/results_tools_mixin.py +98 -0
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_display.css +78 -0
- edsl/results/table_renderers.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +543 -0
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +431 -62
- edsl/scenarios/ScenarioHtmlMixin.py +65 -0
- edsl/scenarios/ScenarioList.py +1415 -45
- edsl/scenarios/ScenarioListExportMixin.py +45 -0
- edsl/scenarios/ScenarioListPdfMixin.py +239 -0
- edsl/scenarios/__init__.py +2 -0
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +49 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/scenario_join.py +131 -0
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/shared.py +1 -0
- edsl/study/ObjectEntry.py +173 -0
- edsl/study/ProofOfWork.py +113 -0
- edsl/study/SnapShot.py +80 -0
- edsl/study/Study.py +521 -0
- edsl/study/__init__.py +4 -0
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/DAG.py +92 -11
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/Memory.py +9 -4
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/MemoryPlan.py +156 -35
- edsl/surveys/Rule.py +221 -74
- edsl/surveys/RuleCollection.py +241 -61
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +1079 -339
- edsl/surveys/SurveyCSS.py +273 -0
- edsl/surveys/SurveyExportMixin.py +235 -40
- edsl/surveys/SurveyFlowVisualization.py +181 -0
- edsl/surveys/SurveyQualtricsImport.py +284 -0
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/base.py +19 -3
- edsl/surveys/descriptors.py +17 -6
- edsl/surveys/instructions/ChangeInstruction.py +48 -0
- edsl/surveys/instructions/Instruction.py +56 -0
- edsl/surveys/instructions/InstructionCollection.py +82 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +19 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/tools/__init__.py +1 -0
- edsl/tools/clusters.py +192 -0
- edsl/tools/embeddings.py +27 -0
- edsl/tools/embeddings_plotting.py +118 -0
- edsl/tools/plotting.py +112 -0
- edsl/tools/summarize.py +18 -0
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/SystemInfo.py +5 -0
- edsl/utilities/__init__.py +21 -20
- edsl/utilities/ast_utilities.py +3 -0
- edsl/utilities/data/Registry.py +2 -0
- edsl/utilities/decorators.py +41 -0
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/interface.py +310 -60
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/naming_utilities.py +263 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/restricted_python.py +70 -0
- edsl/utilities/utilities.py +203 -13
- edsl-0.1.40.dist-info/METADATA +111 -0
- edsl-0.1.40.dist-info/RECORD +362 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/WHEEL +1 -1
- edsl/agents/AgentListExportMixin.py +0 -24
- edsl/coop/old.py +0 -31
- edsl/data/Database.py +0 -141
- edsl/data/crud.py +0 -121
- edsl/jobs/Interview.py +0 -417
- edsl/jobs/JobsRunner.py +0 -63
- edsl/jobs/JobsRunnerStatusMixin.py +0 -115
- edsl/jobs/base.py +0 -47
- edsl/jobs/buckets.py +0 -166
- edsl/jobs/runners/JobsRunnerDryRun.py +0 -19
- edsl/jobs/runners/JobsRunnerStreaming.py +0 -54
- edsl/jobs/task_management.py +0 -218
- edsl/jobs/token_tracking.py +0 -78
- edsl/language_models/DeepInfra.py +0 -69
- edsl/language_models/OpenAI.py +0 -98
- edsl/language_models/model_interfaces/GeminiPro.py +0 -66
- edsl/language_models/model_interfaces/LanguageModelOpenAIFour.py +0 -8
- edsl/language_models/model_interfaces/LanguageModelOpenAIThreeFiveTurbo.py +0 -8
- edsl/language_models/model_interfaces/LlamaTwo13B.py +0 -21
- edsl/language_models/model_interfaces/LlamaTwo70B.py +0 -21
- edsl/language_models/model_interfaces/Mixtral8x7B.py +0 -24
- edsl/language_models/registry.py +0 -81
- edsl/language_models/schemas.py +0 -15
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/prompts/QuestionInstructionsBase.py +0 -6
- edsl/prompts/library/agent_instructions.py +0 -29
- edsl/prompts/library/agent_persona.py +0 -17
- edsl/prompts/library/question_budget.py +0 -26
- edsl/prompts/library/question_checkbox.py +0 -32
- edsl/prompts/library/question_extract.py +0 -19
- edsl/prompts/library/question_freetext.py +0 -14
- edsl/prompts/library/question_linear_scale.py +0 -20
- edsl/prompts/library/question_list.py +0 -22
- edsl/prompts/library/question_multiple_choice.py +0 -44
- edsl/prompts/library/question_numerical.py +0 -31
- edsl/prompts/library/question_rank.py +0 -21
- edsl/prompts/prompt_config.py +0 -33
- edsl/prompts/registry.py +0 -185
- edsl/questions/Question.py +0 -240
- edsl/report/InputOutputDataTypes.py +0 -134
- edsl/report/RegressionMixin.py +0 -28
- edsl/report/ReportOutputs.py +0 -1228
- edsl/report/ResultsFetchMixin.py +0 -106
- edsl/report/ResultsOutputMixin.py +0 -14
- edsl/report/demo.ipynb +0 -645
- edsl/results/ResultsDBMixin.py +0 -184
- edsl/surveys/SurveyFlowVisualizationMixin.py +0 -92
- edsl/trackers/Tracker.py +0 -91
- edsl/trackers/TrackerAPI.py +0 -196
- edsl/trackers/TrackerTasks.py +0 -70
- edsl/utilities/pastebin.py +0 -141
- edsl-0.1.14.dist-info/METADATA +0 -69
- edsl-0.1.14.dist-info/RECORD +0 -141
- /edsl/{language_models/model_interfaces → inference_services}/__init__.py +0 -0
- /edsl/{report/__init__.py → jobs/runners/JobsRunnerStatusData.py} +0 -0
- /edsl/{trackers/__init__.py → language_models/ServiceDataSources.py} +0 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/LICENSE +0 -0
@@ -0,0 +1,215 @@
|
|
1
|
+
import os
|
2
|
+
from typing import Any, Optional, List
|
3
|
+
import re
|
4
|
+
from openai import AsyncAzureOpenAI
|
5
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
6
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
7
|
+
|
8
|
+
from azure.ai.inference.aio import ChatCompletionsClient
|
9
|
+
from azure.core.credentials import AzureKeyCredential
|
10
|
+
from azure.ai.inference.models import SystemMessage, UserMessage
|
11
|
+
import asyncio
|
12
|
+
import json
|
13
|
+
from edsl.utilities.utilities import fix_partial_correct_response
|
14
|
+
|
15
|
+
|
16
|
+
def json_handle_none(value: Any) -> Any:
    """
    Handle None values during JSON serialization.

    Returns the string "null" when *value* is None; for any other value it
    returns nothing (implicitly None), matching the original contract.
    """
    return "null" if value is None else None
|
23
|
+
|
24
|
+
|
25
|
+
class AzureAIService(InferenceServiceABC):
    """Azure AI service class.

    Parses the ``AZURE_ENDPOINT_URL_AND_KEY`` environment variable (a
    comma-separated list of ``endpoint:key`` entries) to discover deployed
    models, and builds ``LanguageModel`` subclasses that call them.
    """

    # Path into the raw API response dict to the generated message text.
    key_sequence = ["choices", 0, "message", "content"]
    usage_sequence = ["usage"]
    input_token_name = "prompt_tokens"
    output_token_name = "completion_tokens"

    _inference_service_ = "azure"
    _env_key_name_ = (
        "AZURE_ENDPOINT_URL_AND_KEY"  # Environment variable for Azure API key
    )
    # model id -> {"endpoint", "azure_endpoint_key", optional "api_version"},
    # populated as a side effect of available().
    _model_id_to_endpoint_and_key = {}
    model_exclude_list = [
        "Cohere-command-r-plus-xncmg",
        "Mistral-Nemo-klfsi",
        "Mistral-large-2407-ojfld",
    ]

    @staticmethod
    def _segment_after(text: str, marker: str, stop: str) -> str:
        """Return the substring of *text* after *marker*, up to the first
        occurrence of *stop* (or to the end of *text* if *stop* is absent).

        Raises ValueError if *marker* is not present.
        """
        start = text.index(marker) + len(marker)
        end = text.index(stop, start) if stop in text[start:] else len(text)
        return text[start:end]

    @classmethod
    def available(cls):
        """Return the available Azure model ids, minus the exclude list.

        Side effect: fills ``cls._model_id_to_endpoint_and_key`` so that
        ``create_model`` can later look up per-model credentials.

        :raises EnvironmentError: if AZURE_ENDPOINT_URL_AND_KEY is not set.
        """
        out = []
        azure_endpoints = os.getenv("AZURE_ENDPOINT_URL_AND_KEY", None)
        if not azure_endpoints:
            # Fixed: the original raised with an f-string that had no
            # placeholders.
            raise EnvironmentError("AZURE_ENDPOINT_URL_AND_KEY is not defined")
        for data in azure_endpoints.split(","):
            # Non-OpenAI entries look like:
            #   https://model_id.azure_endpoint:azure_key
            _, endpoint, azure_endpoint_key = data.split(":")
            if "openai" not in endpoint:
                model_id = endpoint.split(".")[0].replace("/", "")
                out.append(model_id)
                cls._model_id_to_endpoint_and_key[model_id] = {
                    "endpoint": f"https:{endpoint}",
                    "azure_endpoint_key": azure_endpoint_key,
                }
            elif "/deployments/" in endpoint:
                # OpenAI entries look like:
                #   https://project.openai.azure.com/openai/deployments/<model>/chat/completions?api-version=<ver>:azure_key
                model_id = cls._segment_after(endpoint, "/deployments/", "/")
                api_version_value = None
                if "api-version=" in endpoint:
                    api_version_value = cls._segment_after(
                        endpoint, "api-version=", "&"
                    )
                cls._model_id_to_endpoint_and_key[f"azure:{model_id}"] = {
                    "endpoint": f"https:{endpoint}",
                    "azure_endpoint_key": azure_endpoint_key,
                    "api_version": api_version_value,
                }
                out.append(f"azure:{model_id}")
        return [m for m in out if m not in cls.model_exclude_list]

    @classmethod
    def create_model(
        cls, model_name: str = "azureai", model_class_name=None
    ) -> LanguageModel:
        """Build and return a ``LanguageModel`` subclass bound to *model_name*.

        :param model_name: model id as returned by ``available()``.
        :param model_class_name: optional class name; derived from
            *model_name* when omitted.
        """
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        class LLM(LanguageModel):
            """
            Child class of LanguageModel for interacting with Azure OpenAI models.
            """

            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name
            _inference_service_ = cls._inference_service_
            _model_ = model_name
            _parameters_ = {
                "temperature": 0.5,
                "max_tokens": 512,
                "top_p": 0.9,
            }

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional[List["FileStore"]] = None,
            ) -> dict[str, Any]:
                """Calls the Azure OpenAI API and returns the API response."""
                # Fixed: the original wrapped these lookups in bare
                # ``except:`` clauses; use .get so only a genuinely missing
                # registry entry yields the EnvironmentError below.
                entry = cls._model_id_to_endpoint_and_key.get(model_name, {})
                api_key = entry.get("azure_endpoint_key")
                if not api_key:
                    raise EnvironmentError(
                        f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
                    )

                endpoint = entry.get("endpoint")
                if not endpoint:
                    raise EnvironmentError(
                        f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
                    )

                if "openai" not in endpoint:
                    # Non-OpenAI Azure model: use the azure-ai-inference
                    # chat client.
                    client = ChatCompletionsClient(
                        endpoint=endpoint,
                        credential=AzureKeyCredential(api_key),
                        temperature=self.temperature,
                        top_p=self.top_p,
                        max_tokens=self.max_tokens,
                    )
                    try:
                        response = await client.complete(
                            messages=[
                                SystemMessage(content=system_prompt),
                                UserMessage(content=user_prompt),
                            ],
                        )
                        return response.as_dict()
                    except Exception as e:
                        # Best-effort: surface the failure in-band rather
                        # than raising (callers inspect the dict).
                        return {"error": str(e)}
                    finally:
                        # The original closed the client on both paths;
                        # finally expresses that once.
                        await client.close()
                else:
                    # Azure OpenAI deployment: use the async OpenAI client.
                    api_version = cls._model_id_to_endpoint_and_key[model_name][
                        "api_version"
                    ]
                    client = AsyncAzureOpenAI(
                        azure_endpoint=endpoint,
                        api_version=api_version,
                        api_key=api_key,
                    )
                    response = await client.chat.completions.create(
                        model=model_name,
                        messages=[
                            {
                                "role": "user",
                                # NOTE(review): system_prompt is not sent on
                                # this path — confirm that is intentional.
                                "content": user_prompt,
                            },
                        ],
                    )
                    return response.model_dump()

        LLM.__name__ = model_class_name

        return LLM
@@ -0,0 +1,18 @@
|
|
1
|
+
import aiohttp
|
2
|
+
import json
|
3
|
+
import requests
|
4
|
+
from typing import Any, List
|
5
|
+
|
6
|
+
# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
7
|
+
from edsl.language_models import LanguageModel
|
8
|
+
|
9
|
+
from edsl.inference_services.OpenAIService import OpenAIService
|
10
|
+
|
11
|
+
|
12
|
+
class DeepInfraService(OpenAIService):
    """DeepInfra service class.

    Reuses the OpenAI-compatible machinery from OpenAIService, pointed at
    DeepInfra's OpenAI-style endpoint.
    """

    _inference_service_ = "deep_infra"
    _env_key_name_ = "DEEP_INFRA_API_KEY"
    _base_url_ = "https://api.deepinfra.com/v1/openai"
    # Cached model listing, shared at class level.
    _models_list_cache: List[str] = []
@@ -0,0 +1,143 @@
|
|
1
|
+
# import os
|
2
|
+
from typing import Any, Dict, List, Optional
|
3
|
+
import google
|
4
|
+
import google.generativeai as genai
|
5
|
+
from google.generativeai.types import GenerationConfig
|
6
|
+
from google.api_core.exceptions import InvalidArgument
|
7
|
+
|
8
|
+
# from edsl.exceptions.general import MissingAPIKeyError
|
9
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
10
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
11
|
+
from edsl.coop import Coop
|
12
|
+
|
13
|
+
# Per-category safety thresholds passed to every GenerativeModel built below.
# BLOCK_NONE disables Google's content filtering for each harm category so
# model responses are not silently suppressed by the safety layer.
safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE",
    },
]
|
31
|
+
|
32
|
+
|
33
|
+
class GoogleService(InferenceServiceABC):
    """Google (Gemini) inference service.

    Lists generateContent-capable models via the google.generativeai SDK
    and builds ``LanguageModel`` subclasses that call them asynchronously.
    """

    _inference_service_ = "google"
    # Path into the raw response dict to the generated text.
    key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
    usage_sequence = ["usage_metadata"]
    input_token_name = "prompt_token_count"
    output_token_name = "candidates_token_count"

    model_exclude_list = []

    @classmethod
    def get_model_list(cls):
        """Return model names (minus the 'models/' prefix) that support
        the generateContent method."""
        model_list = []
        for m in genai.list_models():
            if "generateContent" in m.supported_generation_methods:
                model_list.append(m.name.split("/")[-1])
        return model_list

    @classmethod
    def available(cls) -> List[str]:
        """Return the available Google model names."""
        return cls.get_model_list()

    @classmethod
    def create_model(
        cls, model_name: str = "gemini-pro", model_class_name=None
    ) -> LanguageModel:
        """Build and return a ``LanguageModel`` subclass bound to *model_name*.

        :param model_name: Gemini model id as returned by ``available()``.
        :param model_class_name: optional class name; derived from
            *model_name* when omitted.
        """
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        class LLM(LanguageModel):
            _model_ = model_name
            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name
            _inference_service_ = cls._inference_service_

            _parameters_ = {
                "temperature": 0.5,
                "topP": 1,
                "topK": 1,
                "maxOutputTokens": 2048,
                "stopSequences": [],
            }

            model = None

            def get_generation_config(self) -> GenerationConfig:
                """Build a GenerationConfig from this model's parameters."""
                return GenerationConfig(
                    temperature=self.temperature,
                    top_p=self.topP,
                    top_k=self.topK,
                    max_output_tokens=self.maxOutputTokens,
                    stop_sequences=self.stopSequences,
                )

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional["Files"] = None,
            ) -> Dict[str, Any]:
                """Call the Gemini API and return the raw response as a dict."""
                generation_config = self.get_generation_config()

                if files_list is None:
                    files_list = []
                genai.configure(api_key=self.api_token)
                if (
                    system_prompt is not None
                    and system_prompt != ""
                    and self._model_ != "gemini-pro"
                ):
                    try:
                        self.generative_model = genai.GenerativeModel(
                            self._model_,
                            safety_settings=safety_settings,
                            system_instruction=system_prompt,
                        )
                    except InvalidArgument:
                        print(
                            f"This model, {self._model_}, does not support system_instruction"
                        )
                        print("Will add system_prompt to user_prompt")
                        user_prompt = f"{system_prompt}\n{user_prompt}"
                        # Fixed: the original never constructed a fallback
                        # model here, so the generate call below would fail
                        # with AttributeError on this path.
                        self.generative_model = genai.GenerativeModel(
                            self._model_,
                            safety_settings=safety_settings,
                        )
                else:
                    self.generative_model = genai.GenerativeModel(
                        self._model_,
                        safety_settings=safety_settings,
                    )
                combined_prompt = [user_prompt]
                for file in files_list:
                    # Ensure each file is uploaded to Google before
                    # referencing it in the prompt.
                    if "google" not in file.external_locations:
                        _ = file.upload_google()
                    gen_ai_file = google.generativeai.types.file_types.File(
                        file.external_locations["google"]
                    )
                    combined_prompt.append(gen_ai_file)

                response = await self.generative_model.generate_content_async(
                    combined_prompt, generation_config=generation_config
                )
                return response.to_dict()

        # Fixed: the original assigned ``model_name`` here, leaving the
        # computed ``model_class_name`` unused (AzureAIService uses the
        # class name; this makes the two services consistent).
        LLM.__name__ = model_class_name
        return LLM
|
140
|
+
|
141
|
+
|
142
|
+
# Import-only module: no CLI behavior is defined.
if __name__ == "__main__":
    pass
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from typing import Any, List
|
2
|
+
from edsl.inference_services.OpenAIService import OpenAIService
|
3
|
+
|
4
|
+
import groq
|
5
|
+
|
6
|
+
|
7
|
+
class GroqService(OpenAIService):
    """Groq inference service (OpenAI-compatible API).

    Fix: the original docstring said "DeepInfra service class", which was a
    copy-paste leftover — this class configures the Groq client/key.
    """

    _inference_service_ = "groq"
    _env_key_name_ = "GROQ_API_KEY"

    _sync_client_ = groq.Groq
    _async_client_ = groq.AsyncGroq

    # Audio-transcription models: excluded because they cannot serve the
    # chat-completion calls this service makes.
    model_exclude_list = ["whisper-large-v3", "distil-whisper-large-v3-en"]

    # None: the groq client supplies its own default base URL.
    _base_url_ = None
    _models_list_cache: List[str] = []
|
@@ -0,0 +1,80 @@
|
|
1
|
+
from abc import abstractmethod, ABC
|
2
|
+
import re
|
3
|
+
from datetime import datetime, timedelta
|
4
|
+
from edsl.config import CONFIG
|
5
|
+
|
6
|
+
|
7
|
+
class InferenceServiceABC(ABC):
    """
    Abstract class for inference services.
    """

    # Cached Coop configuration variables; None until first fetch.
    _coop_config_vars = None
    # Timestamp of the last Coop config fetch; None means "never fetched".
    # Bug fix: `_should_refresh_coop_config_vars` reads this attribute, but it
    # was never defined, so the method always raised AttributeError.
    _last_config_fetch = None

    def __init_subclass__(cls):
        """
        Check that the subclass has the required attributes.
        - `key_sequence` attribute determines...
        - `model_exclude_list` attribute determines...

        Raises NotImplementedError at class-definition time if any required
        attribute is missing.
        """
        must_have_attributes = [
            "key_sequence",
            "model_exclude_list",
            "usage_sequence",
            "input_token_name",
            "output_token_name",
        ]
        for attr in must_have_attributes:
            if not hasattr(cls, attr):
                raise NotImplementedError(
                    f"Class {cls.__name__} must have a '{attr}' attribute."
                )

    @property
    def service_name(self):
        # NOTE(review): relies on subclasses defining `_inference_service_`;
        # the ABC itself does not declare that attribute — confirm intended.
        return self._inference_service_

    @classmethod
    def _should_refresh_coop_config_vars(cls):
        """
        Returns True if config vars have been fetched over 24 hours ago, and False otherwise.
        """
        if cls._last_config_fetch is None:
            return True
        return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)

    @abstractmethod
    def available() -> list[str]:
        """
        Returns a list of available models for the service.
        """
        # NOTE(review): declared without self/cls; subclasses presumably
        # implement this as a classmethod — confirm.
        pass

    @abstractmethod
    def create_model():
        """
        Returns a LanguageModel object.
        """
        pass

    @staticmethod
    def to_class_name(s):
        """
        Converts a string to a valid class name.

        >>> InferenceServiceABC.to_class_name("hello world")
        'HelloWorld'
        """
        # Strip everything except letters, digits, and spaces, then TitleCase.
        s = re.sub(r"[^a-zA-Z0-9 ]", "", s)
        s = "".join(word.title() for word in s.split())
        # A Python identifier cannot start with a digit; prefix it.
        if s and s[0].isdigit():
            s = "Class" + s
        return s
|
75
|
+
|
76
|
+
|
77
|
+
if __name__ == "__main__":
    import doctest

    doctest.testmod()  # run the embedded doctests when executed directly
|
@@ -0,0 +1,138 @@
|
|
1
|
+
from functools import lru_cache
|
2
|
+
from collections import defaultdict
|
3
|
+
from typing import Optional, Protocol, Dict, List, Tuple, TYPE_CHECKING, Literal
|
4
|
+
|
5
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
6
|
+
from edsl.inference_services.AvailableModelFetcher import AvailableModelFetcher
|
7
|
+
from edsl.exceptions.inference_services import InferenceServiceError
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
11
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
12
|
+
|
13
|
+
|
14
|
+
class ModelCreator(Protocol):
    """Structural type for anything that can build a LanguageModel by name."""

    def create_model(self, model_name: str) -> "LanguageModel":
        """Return a LanguageModel instance for ``model_name``."""
|
17
|
+
|
18
|
+
|
19
|
+
from edsl.enums import InferenceServiceLiteral
|
20
|
+
|
21
|
+
|
22
|
+
class ModelResolver:
    def __init__(
        self,
        services: List[InferenceServiceLiteral],
        models_to_services: Dict[InferenceServiceLiteral, InferenceServiceABC],
        availability_fetcher: "AvailableModelFetcher",
    ):
        """
        Class for determining which service to use for a given model.

        :param services: Registered inference-service objects.
        :param models_to_services: Shared cache mapping a model name to the
            service that was found to provide it.
        :param availability_fetcher: Used to query each service's model list.
        """
        self.services = services
        self._models_to_services = models_to_services
        self.availability_fetcher = availability_fetcher
        # Index services by their declared `_inference_service_` name.
        self._service_names_to_classes = {
            service._inference_service_: service for service in services
        }

    def resolve_model(
        self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
    ) -> InferenceServiceABC:
        """Returns an InferenceServiceABC object for the given model name.

        :param model_name: The name of the model to resolve. E.g., 'gpt-4o'
        :param service_name: The name of the service to use. E.g., 'openai'
        :return: An InferenceServiceABC object
        :raises InferenceServiceError: if the service or model cannot be found.
        """
        # Special-case: the "test" model short-circuits to the TestService.
        if model_name == "test":
            from edsl.inference_services.TestService import TestService

            return TestService()

        # An explicit service wins; no availability lookup is performed.
        if service_name is not None:
            service: InferenceServiceABC = self._service_names_to_classes.get(
                service_name
            )
            if not service:
                raise InferenceServiceError(f"Service {service_name} not found")
            return service

        if model_name in self._models_to_services:  # maybe we've seen it before!
            return self._models_to_services[model_name]

        for service in self.services:
            # Fix: the original unpacked the second tuple element into
            # `service_name`, shadowing the method parameter; discard it.
            (
                available_models,
                _,
            ) = self.availability_fetcher.get_available_models_by_service(service)
            if model_name in available_models:
                # Cache the hit so future resolutions skip the fetch.
                self._models_to_services[model_name] = service
                return service

        raise InferenceServiceError(
            f"""Model {model_name} not found in any services.
            If you know the service that has this model, use the service_name parameter directly.
            E.g., Model("gpt-4o", service_name="openai")
            """
        )
|
80
|
+
|
81
|
+
|
82
|
+
class InferenceServicesCollection:
    """Registry of inference services with model-availability helpers."""

    # Class-level on purpose ("Moved back to class level" in the original):
    # models registered via `add_model` are shared across all instances.
    added_models = defaultdict(list)

    def __init__(self, services: Optional[List[InferenceServiceABC]] = None):
        self.services = services or []
        # Per-instance cache: model name -> service that provides it.
        self._models_to_services: Dict[str, InferenceServiceABC] = {}

        self.availability_fetcher = AvailableModelFetcher(
            self.services, self.added_models
        )
        self.resolver = ModelResolver(
            self.services, self._models_to_services, self.availability_fetcher
        )

    @classmethod
    def add_model(cls, service_name: str, model_name: str) -> None:
        """Register `model_name` as provided by `service_name`.

        Bug fix: the original guarded on `service_name not in cls.added_models`,
        so only the FIRST model per service was ever recorded and later models
        for the same service were silently dropped. Guard on the model instead
        (the defaultdict creates the list on first access).
        """
        if model_name not in cls.added_models[service_name]:
            cls.added_models[service_name].append(model_name)

    def service_names_to_classes(self) -> Dict[str, InferenceServiceABC]:
        """Map each registered service's declared name to its object."""
        return {service._inference_service_: service for service in self.services}

    def available(
        self,
        service: Optional[str] = None,
    ) -> List[Tuple[str, str, int]]:
        """Return available (model, service, index) tuples, optionally filtered."""
        return self.availability_fetcher.available(service)

    def reset_cache(self) -> None:
        """Clear the availability fetcher's cached model lists."""
        self.availability_fetcher.reset_cache()

    @property
    def num_cache_entries(self) -> int:
        return self.availability_fetcher.num_cache_entries

    def register(self, service: InferenceServiceABC) -> None:
        """Add a service to this collection."""
        self.services.append(service)

    def create_model_factory(
        self, model_name: str, service_name: Optional[InferenceServiceLiteral] = None
    ) -> "LanguageModel":
        """Build a LanguageModel for `model_name`, resolving the service if needed.

        :raises InferenceServiceError: if no service provides the model.
        """
        if service_name is None:  # we try to find the right service
            service = self.resolver.resolve_model(model_name, service_name)
        else:  # if they passed a service, we'll use that
            service = self.service_names_to_classes().get(service_name)

        if not service:  # but if we can't find it, we'll raise an error
            raise InferenceServiceError(f"Service {service_name} not found")

        return service.create_model(model_name)
|
133
|
+
|
134
|
+
|
135
|
+
if __name__ == "__main__":
    import doctest

    doctest.testmod()  # run the embedded doctests when executed directly
|