edsl 0.1.14__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- edsl/Base.py +348 -38
- edsl/BaseDiff.py +260 -0
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +46 -10
- edsl/__version__.py +1 -0
- edsl/agents/Agent.py +842 -144
- edsl/agents/AgentList.py +521 -25
- edsl/agents/Invigilator.py +250 -374
- edsl/agents/InvigilatorBase.py +257 -0
- edsl/agents/PromptConstructor.py +272 -0
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/descriptors.py +43 -13
- edsl/agents/prompt_helpers.py +129 -0
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +130 -0
- edsl/auto/StageBase.py +243 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +74 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +218 -0
- edsl/base/Base.py +279 -0
- edsl/config.py +121 -104
- edsl/conversation/Conversation.py +290 -0
- edsl/conversation/car_buying.py +59 -0
- edsl/conversation/chips.py +95 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +54 -0
- edsl/coop/__init__.py +1 -0
- edsl/coop/coop.py +1029 -134
- edsl/coop/utils.py +131 -0
- edsl/data/Cache.py +560 -89
- edsl/data/CacheEntry.py +230 -0
- edsl/data/CacheHandler.py +168 -0
- edsl/data/RemoteCacheSync.py +186 -0
- edsl/data/SQLiteDict.py +292 -0
- edsl/data/__init__.py +5 -3
- edsl/data/orm.py +6 -33
- edsl/data_transfer_models.py +74 -27
- edsl/enums.py +165 -8
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +52 -46
- edsl/exceptions/agents.py +33 -15
- edsl/exceptions/cache.py +5 -0
- edsl/exceptions/coop.py +8 -0
- edsl/exceptions/general.py +34 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/jobs.py +15 -0
- edsl/exceptions/language_models.py +46 -1
- edsl/exceptions/questions.py +80 -5
- edsl/exceptions/results.py +16 -5
- edsl/exceptions/scenarios.py +29 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AnthropicService.py +106 -0
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +118 -0
- edsl/inference_services/AzureAI.py +215 -0
- edsl/inference_services/DeepInfraService.py +18 -0
- edsl/inference_services/GoogleService.py +143 -0
- edsl/inference_services/GroqService.py +20 -0
- edsl/inference_services/InferenceServiceABC.py +80 -0
- edsl/inference_services/InferenceServicesCollection.py +138 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +236 -0
- edsl/inference_services/PerplexityService.py +160 -0
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +90 -0
- edsl/inference_services/TogetherAIService.py +172 -0
- edsl/inference_services/data_structures.py +134 -0
- edsl/inference_services/models_available_cache.py +118 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/registry.py +41 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +21 -20
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +684 -204
- edsl/jobs/JobsChecks.py +172 -0
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +270 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +104 -0
- edsl/jobs/buckets/ModelBuckets.py +65 -0
- edsl/jobs/buckets/TokenBucket.py +283 -0
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +392 -0
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -0
- edsl/jobs/interviews/InterviewStatistic.py +63 -0
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -0
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -0
- edsl/jobs/interviews/InterviewStatusLog.py +92 -0
- edsl/jobs/interviews/ReportErrors.py +66 -0
- edsl/jobs/interviews/interview_status_enum.py +9 -0
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +151 -110
- edsl/jobs/runners/JobsRunnerStatus.py +298 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +244 -0
- edsl/jobs/tasks/TaskCreators.py +64 -0
- edsl/jobs/tasks/TaskHistory.py +470 -0
- edsl/jobs/tasks/TaskStatusLog.py +23 -0
- edsl/jobs/tasks/task_status_enum.py +161 -0
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -0
- edsl/jobs/tokens/TokenUsage.py +34 -0
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +507 -386
- edsl/language_models/ModelList.py +164 -0
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -0
- edsl/language_models/__init__.py +1 -8
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +109 -41
- edsl/language_models/utilities.py +65 -0
- edsl/notebooks/Notebook.py +263 -0
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/notebooks/__init__.py +1 -0
- edsl/prompts/Prompt.py +222 -93
- edsl/prompts/__init__.py +1 -1
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +518 -0
- edsl/questions/QuestionBasePromptsMixin.py +221 -0
- edsl/questions/QuestionBudget.py +164 -67
- edsl/questions/QuestionCheckBox.py +281 -62
- edsl/questions/QuestionDict.py +343 -0
- edsl/questions/QuestionExtract.py +136 -50
- edsl/questions/QuestionFreeText.py +79 -55
- edsl/questions/QuestionFunctional.py +138 -41
- edsl/questions/QuestionList.py +184 -57
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +293 -69
- edsl/questions/QuestionNumerical.py +109 -56
- edsl/questions/QuestionRank.py +244 -49
- edsl/questions/Quick.py +41 -0
- edsl/questions/SimpleAskMixin.py +74 -0
- edsl/questions/__init__.py +9 -6
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +153 -38
- edsl/questions/compose_questions.py +13 -7
- edsl/questions/data_structures.py +20 -0
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +28 -26
- edsl/questions/derived/QuestionLinearScale.py +41 -28
- edsl/questions/derived/QuestionTopK.py +34 -26
- edsl/questions/derived/QuestionYesNo.py +40 -27
- edsl/questions/descriptors.py +228 -74
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_base_gen_mixin.py +168 -0
- edsl/questions/question_registry.py +130 -46
- edsl/questions/register_questions_meta.py +71 -0
- edsl/questions/response_validator_abc.py +188 -0
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/settings.py +5 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/dict/__init__.py +0 -0
- edsl/questions/templates/dict/answering_instructions.jinja +21 -0
- edsl/questions/templates/dict/question_presentation.jinja +1 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +7 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +550 -19
- edsl/results/DatasetExportMixin.py +594 -0
- edsl/results/DatasetTree.py +295 -0
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +477 -173
- edsl/results/Results.py +987 -269
- edsl/results/ResultsExportMixin.py +28 -125
- edsl/results/ResultsGGMixin.py +83 -15
- edsl/results/TableDisplay.py +125 -0
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/file_exports.py +252 -0
- edsl/results/results_fetch_mixin.py +33 -0
- edsl/results/results_selector.py +145 -0
- edsl/results/results_tools_mixin.py +98 -0
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_display.css +78 -0
- edsl/results/table_renderers.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +543 -0
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +431 -62
- edsl/scenarios/ScenarioHtmlMixin.py +65 -0
- edsl/scenarios/ScenarioList.py +1415 -45
- edsl/scenarios/ScenarioListExportMixin.py +45 -0
- edsl/scenarios/ScenarioListPdfMixin.py +239 -0
- edsl/scenarios/__init__.py +2 -0
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +49 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/scenario_join.py +131 -0
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/shared.py +1 -0
- edsl/study/ObjectEntry.py +173 -0
- edsl/study/ProofOfWork.py +113 -0
- edsl/study/SnapShot.py +80 -0
- edsl/study/Study.py +521 -0
- edsl/study/__init__.py +4 -0
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/DAG.py +92 -11
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/Memory.py +9 -4
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/MemoryPlan.py +156 -35
- edsl/surveys/Rule.py +221 -74
- edsl/surveys/RuleCollection.py +241 -61
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +1079 -339
- edsl/surveys/SurveyCSS.py +273 -0
- edsl/surveys/SurveyExportMixin.py +235 -40
- edsl/surveys/SurveyFlowVisualization.py +181 -0
- edsl/surveys/SurveyQualtricsImport.py +284 -0
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/base.py +19 -3
- edsl/surveys/descriptors.py +17 -6
- edsl/surveys/instructions/ChangeInstruction.py +48 -0
- edsl/surveys/instructions/Instruction.py +56 -0
- edsl/surveys/instructions/InstructionCollection.py +82 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +19 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/tools/__init__.py +1 -0
- edsl/tools/clusters.py +192 -0
- edsl/tools/embeddings.py +27 -0
- edsl/tools/embeddings_plotting.py +118 -0
- edsl/tools/plotting.py +112 -0
- edsl/tools/summarize.py +18 -0
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/SystemInfo.py +5 -0
- edsl/utilities/__init__.py +21 -20
- edsl/utilities/ast_utilities.py +3 -0
- edsl/utilities/data/Registry.py +2 -0
- edsl/utilities/decorators.py +41 -0
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/interface.py +310 -60
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/naming_utilities.py +263 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/restricted_python.py +70 -0
- edsl/utilities/utilities.py +203 -13
- edsl-0.1.40.dist-info/METADATA +111 -0
- edsl-0.1.40.dist-info/RECORD +362 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/WHEEL +1 -1
- edsl/agents/AgentListExportMixin.py +0 -24
- edsl/coop/old.py +0 -31
- edsl/data/Database.py +0 -141
- edsl/data/crud.py +0 -121
- edsl/jobs/Interview.py +0 -417
- edsl/jobs/JobsRunner.py +0 -63
- edsl/jobs/JobsRunnerStatusMixin.py +0 -115
- edsl/jobs/base.py +0 -47
- edsl/jobs/buckets.py +0 -166
- edsl/jobs/runners/JobsRunnerDryRun.py +0 -19
- edsl/jobs/runners/JobsRunnerStreaming.py +0 -54
- edsl/jobs/task_management.py +0 -218
- edsl/jobs/token_tracking.py +0 -78
- edsl/language_models/DeepInfra.py +0 -69
- edsl/language_models/OpenAI.py +0 -98
- edsl/language_models/model_interfaces/GeminiPro.py +0 -66
- edsl/language_models/model_interfaces/LanguageModelOpenAIFour.py +0 -8
- edsl/language_models/model_interfaces/LanguageModelOpenAIThreeFiveTurbo.py +0 -8
- edsl/language_models/model_interfaces/LlamaTwo13B.py +0 -21
- edsl/language_models/model_interfaces/LlamaTwo70B.py +0 -21
- edsl/language_models/model_interfaces/Mixtral8x7B.py +0 -24
- edsl/language_models/registry.py +0 -81
- edsl/language_models/schemas.py +0 -15
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/prompts/QuestionInstructionsBase.py +0 -6
- edsl/prompts/library/agent_instructions.py +0 -29
- edsl/prompts/library/agent_persona.py +0 -17
- edsl/prompts/library/question_budget.py +0 -26
- edsl/prompts/library/question_checkbox.py +0 -32
- edsl/prompts/library/question_extract.py +0 -19
- edsl/prompts/library/question_freetext.py +0 -14
- edsl/prompts/library/question_linear_scale.py +0 -20
- edsl/prompts/library/question_list.py +0 -22
- edsl/prompts/library/question_multiple_choice.py +0 -44
- edsl/prompts/library/question_numerical.py +0 -31
- edsl/prompts/library/question_rank.py +0 -21
- edsl/prompts/prompt_config.py +0 -33
- edsl/prompts/registry.py +0 -185
- edsl/questions/Question.py +0 -240
- edsl/report/InputOutputDataTypes.py +0 -134
- edsl/report/RegressionMixin.py +0 -28
- edsl/report/ReportOutputs.py +0 -1228
- edsl/report/ResultsFetchMixin.py +0 -106
- edsl/report/ResultsOutputMixin.py +0 -14
- edsl/report/demo.ipynb +0 -645
- edsl/results/ResultsDBMixin.py +0 -184
- edsl/surveys/SurveyFlowVisualizationMixin.py +0 -92
- edsl/trackers/Tracker.py +0 -91
- edsl/trackers/TrackerAPI.py +0 -196
- edsl/trackers/TrackerTasks.py +0 -70
- edsl/utilities/pastebin.py +0 -141
- edsl-0.1.14.dist-info/METADATA +0 -69
- edsl-0.1.14.dist-info/RECORD +0 -141
- /edsl/{language_models/model_interfaces → inference_services}/__init__.py +0 -0
- /edsl/{report/__init__.py → jobs/runners/JobsRunnerStatusData.py} +0 -0
- /edsl/{trackers/__init__.py → language_models/ServiceDataSources.py} +0 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/LICENSE +0 -0
edsl/inference_services/MistralAIService.py
@@ -0,0 +1,120 @@
+import os
+from typing import Any, List, Optional
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models.LanguageModel import LanguageModel
+import asyncio
+from mistralai import Mistral
+
+from edsl.exceptions.language_models import LanguageModelBadResponseError
+
+
+class MistralAIService(InferenceServiceABC):
+    """Mistral AI service class."""
+
+    key_sequence = ["choices", 0, "message", "content"]
+    usage_sequence = ["usage"]
+
+    _inference_service_ = "mistral"
+    _env_key_name_ = "MISTRAL_API_KEY"  # Environment variable for Mistral API key
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    _sync_client_instance = None
+    _async_client_instance = None
+
+    _sync_client = Mistral
+    _async_client = Mistral
+
+    _models_list_cache: List[str] = []
+    model_exclude_list = []
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        # so subclasses have to create their own instances of the clients
+        cls._sync_client_instance = None
+        cls._async_client_instance = None
+
+    @classmethod
+    def sync_client(cls):
+        if cls._sync_client_instance is None:
+            cls._sync_client_instance = cls._sync_client(
+                api_key=os.getenv(cls._env_key_name_)
+            )
+        return cls._sync_client_instance
+
+    @classmethod
+    def async_client(cls):
+        if cls._async_client_instance is None:
+            cls._async_client_instance = cls._async_client(
+                api_key=os.getenv(cls._env_key_name_)
+            )
+        return cls._async_client_instance
+
+    @classmethod
+    def available(cls) -> list[str]:
+        if not cls._models_list_cache:
+            cls._models_list_cache = [
+                m.id for m in cls.sync_client().models.list().data
+            ]
+
+        return cls._models_list_cache
+
+    @classmethod
+    def create_model(
+        cls, model_name: str = "mistral", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with Mistral models.
+            """
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 512,
+                "top_p": 0.9,
+            }
+
+            def sync_client(self):
+                return cls.sync_client()
+
+            def async_client(self):
+                return cls.async_client()
+
+            async def async_execute_model_call(
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional[List["FileStore"]] = None,
+            ) -> dict[str, Any]:
+                """Calls the Mistral API and returns the API response."""
+                s = self.async_client()
+
+                try:
+                    res = await s.chat.complete_async(
+                        model=model_name,
+                        messages=[
+                            {
+                                "content": user_prompt,
+                                "role": "user",
+                            },
+                        ],
+                    )
+                except Exception as e:
+                    raise LanguageModelBadResponseError(f"Error with Mistral API: {e}")
+
+                return res.model_dump()
+
+        LLM.__name__ = model_class_name
+
+        return LLM
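The hunk above shows the factory pattern used across `edsl/inference_services`: `create_model` manufactures a `LanguageModel` subclass bound to one model name, while the class-level `_sync_client_instance`/`_async_client_instance` attributes hold a single lazily built `Mistral` client. A minimal usage sketch, assuming `MISTRAL_API_KEY` is set in the environment and using an illustrative model name:

```python
# Sketch only: model name and environment setup are illustrative assumptions.
from edsl.inference_services.MistralAIService import MistralAIService

# Model ids are fetched from the API once, then served from the class cache.
print(MistralAIService.available())

# create_model returns a class (not an instance) named after the model id.
ModelClass = MistralAIService.create_model("mistral-large-latest")
print(ModelClass.__name__, ModelClass._model_)

# The lazily built client is a class-level singleton shared by all models.
assert MistralAIService.sync_client() is MistralAIService.sync_client()
```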
edsl/inference_services/OllamaService.py
@@ -0,0 +1,18 @@
+import aiohttp
+import json
+import requests
+from typing import Any, List
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+
+from edsl.inference_services.OpenAIService import OpenAIService
+
+
+class OllamaService(OpenAIService):
+    """DeepInfra service class."""
+
+    _inference_service_ = "ollama"
+    _env_key_name_ = "DEEP_INFRA_API_KEY"
+    _base_url_ = "http://localhost:11434/v1"
+    _models_list_cache: List[str] = []
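OllamaService is configuration only: it inherits all behavior from OpenAIService and repoints `_base_url_` at a local Ollama server, which exposes an OpenAI-compatible `/v1` endpoint (the docstring and `DEEP_INFRA_API_KEY` env var appear carried over from the DeepInfra service it mirrors). The same redirection works with the stock OpenAI client, sketched here assuming a local Ollama instance and an illustrative model name:

```python
# Sketch only: assumes an Ollama server on its default port and that the
# "llama3" model has been pulled locally; both are assumptions.
import openai

client = openai.OpenAI(
    base_url="http://localhost:11434/v1",  # same URL as OllamaService._base_url_
    api_key="ollama",  # Ollama ignores the key, but the client requires one
)
response = client.chat.completions.create(
    model="llama3",
    messages=[{"role": "user", "content": "Say this is a test"}],
)
print(response.choices[0].message.content)
```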
edsl/inference_services/OpenAIService.py
@@ -0,0 +1,236 @@
+from __future__ import annotations
+from typing import Any, List, Optional, Dict, NewType
+import os
+
+
+import openai
+
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models.LanguageModel import LanguageModel
+from edsl.inference_services.rate_limits_cache import rate_limits
+from edsl.utilities.utilities import fix_partial_correct_response
+
+from edsl.config import CONFIG
+
+APIToken = NewType("APIToken", str)
+
+
+class OpenAIService(InferenceServiceABC):
+    """OpenAI service class."""
+
+    _inference_service_ = "openai"
+    _env_key_name_ = "OPENAI_API_KEY"
+    _base_url_ = None
+
+    _sync_client_ = openai.OpenAI
+    _async_client_ = openai.AsyncOpenAI
+
+    _sync_client_instances: Dict[APIToken, openai.OpenAI] = {}
+    _async_client_instances: Dict[APIToken, openai.AsyncOpenAI] = {}
+
+    key_sequence = ["choices", 0, "message", "content"]
+    usage_sequence = ["usage"]
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    available_models_url = "https://platform.openai.com/docs/models/gp"
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        # so subclasses that use the OpenAI api key have to create their own instances of the clients
+        cls._sync_client_instances = {}
+        cls._async_client_instances = {}
+
+    @classmethod
+    def sync_client(cls, api_key):
+        if api_key not in cls._sync_client_instances:
+            client = cls._sync_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
+            )
+            cls._sync_client_instances[api_key] = client
+        client = cls._sync_client_instances[api_key]
+        return client
+
+    @classmethod
+    def async_client(cls, api_key):
+        if api_key not in cls._async_client_instances:
+            client = cls._async_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
+            )
+            cls._async_client_instances[api_key] = client
+        client = cls._async_client_instances[api_key]
+        return client
+
+    model_exclude_list = [
+        "whisper-1",
+        "davinci-002",
+        "dall-e-2",
+        "tts-1-hd-1106",
+        "tts-1-hd",
+        "dall-e-3",
+        "tts-1",
+        "babbage-002",
+        "tts-1-1106",
+        "text-embedding-3-large",
+        "text-embedding-3-small",
+        "text-embedding-ada-002",
+        "ft:davinci-002:mit-horton-lab::8OfuHgoo",
+        "gpt-3.5-turbo-instruct-0914",
+        "gpt-3.5-turbo-instruct",
+    ]
+    _models_list_cache: List[str] = []
+
+    @classmethod
+    def get_model_list(cls, api_key=None):
+        if api_key is None:
+            api_key = os.getenv(cls._env_key_name_)
+        raw_list = cls.sync_client(api_key).models.list()
+        if hasattr(raw_list, "data"):
+            return raw_list.data
+        else:
+            return raw_list
+
+    @classmethod
+    def available(cls, api_token=None) -> List[str]:
+        if api_token is None:
+            api_token = os.getenv(cls._env_key_name_)
+        if not cls._models_list_cache:
+            try:
+                cls._models_list_cache = [
+                    m.id
+                    for m in cls.get_model_list(api_key=api_token)
+                    if m.id not in cls.model_exclude_list
+                ]
+            except Exception as e:
+                raise
+        return cls._models_list_cache
+
+    @classmethod
+    def create_model(cls, model_name, model_class_name=None) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with OpenAI models
+            """
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 1000,
+                "top_p": 1,
+                "frequency_penalty": 0,
+                "presence_penalty": 0,
+                "logprobs": False,
+                "top_logprobs": 3,
+            }
+
+            def sync_client(self):
+                return cls.sync_client(api_key=self.api_token)
+
+            def async_client(self):
+                return cls.async_client(api_key=self.api_token)
+
+            @classmethod
+            def available(cls) -> list[str]:
+                return cls.sync_client().models.list()
+
+            def get_headers(self) -> dict[str, Any]:
+                client = self.sync_client()
+                response = client.chat.completions.with_raw_response.create(
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": "Say this is a test",
+                        }
+                    ],
+                    model=self.model,
+                )
+                return dict(response.headers)
+
+            def get_rate_limits(self) -> dict[str, Any]:
+                try:
+                    if "openai" in rate_limits:
+                        headers = rate_limits["openai"]
+
+                    else:
+                        headers = self.get_headers()
+
+                except Exception as e:
+                    return {
+                        "rpm": 10_000,
+                        "tpm": 2_000_000,
+                    }
+                else:
+                    return {
+                        "rpm": int(headers["x-ratelimit-limit-requests"]),
+                        "tpm": int(headers["x-ratelimit-limit-tokens"]),
+                    }
+
+            async def async_execute_model_call(
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional[List["Files"]] = None,
+                invigilator: Optional[
+                    "InvigilatorAI"
+                ] = None,  # TBD - can eventually be used for function-calling
+            ) -> dict[str, Any]:
+                """Calls the OpenAI API and returns the API response."""
+                if files_list:
+                    content = [{"type": "text", "text": user_prompt}]
+                    for file_entry in files_list:
+                        content.append(
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": f"data:{file_entry.mime_type};base64,{file_entry.base64_string}"
+                                },
+                            }
+                        )
+                else:
+                    content = user_prompt
+                client = self.async_client()
+
+                messages = [
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": content},
+                ]
+                if (
+                    system_prompt == "" and self.omit_system_prompt_if_empty
+                ) or "o1" in self.model:
+                    messages = messages[1:]
+
+                params = {
+                    "model": self.model,
+                    "messages": messages,
+                    "temperature": self.temperature,
+                    "max_tokens": self.max_tokens,
+                    "top_p": self.top_p,
+                    "frequency_penalty": self.frequency_penalty,
+                    "presence_penalty": self.presence_penalty,
+                    "logprobs": self.logprobs,
+                    "top_logprobs": self.top_logprobs if self.logprobs else None,
+                }
+                if "o1" in self.model:
+                    params.pop("max_tokens")
+                    params["max_completion_tokens"] = self.max_tokens
+                    params["temperature"] = 1
+                try:
+                    response = await client.chat.completions.create(**params)
+                except Exception as e:
+                    print(e)
+                return response.model_dump()
+
+        LLM.__name__ = "LanguageModel"
+
+        return LLM
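Two details of this class are worth isolating. First, `sync_client`/`async_client` keep one client per API key in a class-level dict, and `__init_subclass__` resets those dicts so subclasses (OllamaService, PerplexityService) never reuse OpenAI's clients. Second, for `o1` models the call swaps `max_tokens` for `max_completion_tokens` and pins `temperature` to 1. A standalone sketch of the keyed-singleton piece, with illustrative names:

```python
# Minimal standalone sketch of the per-API-key client cache pattern above.
from typing import Callable, Dict


class ClientCache:
    _instances: Dict[str, object] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._instances = {}  # each subclass starts with its own empty cache

    @classmethod
    def client(cls, api_key: str, factory: Callable[[], object]) -> object:
        if api_key not in cls._instances:
            cls._instances[api_key] = factory()  # build once per key
        return cls._instances[api_key]


class SubService(ClientCache):
    pass


first = ClientCache.client("key-1", object)
assert first is ClientCache.client("key-1", object)     # reused per key
assert SubService.client("key-1", object) is not first  # isolated per subclass
```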
edsl/inference_services/PerplexityService.py
@@ -0,0 +1,160 @@
+import aiohttp
+import json
+import requests
+from typing import Any, List, Optional
+from edsl.inference_services.rate_limits_cache import rate_limits
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+
+from edsl.inference_services.OpenAIService import OpenAIService
+
+
+class PerplexityService(OpenAIService):
+    """Perplexity service class."""
+
+    _inference_service_ = "perplexity"
+    _env_key_name_ = "PERPLEXITY_API_KEY"
+    _base_url_ = "https://api.perplexity.ai"
+    _models_list_cache: List[str] = []
+    # default perplexity parameters
+    _parameters_ = {
+        "temperature": 0.5,
+        "max_tokens": 1000,
+        "top_p": 1,
+        "logprobs": False,
+        "top_logprobs": 3,
+    }
+
+    @classmethod
+    def available(cls) -> List[str]:
+        return [
+            "llama-3.1-sonar-huge-128k-online",
+            "llama-3.1-sonar-large-128k-online",
+            "llama-3.1-sonar-small-128k-online",
+        ]
+
+    @classmethod
+    def create_model(
+        cls, model_name="llama-3.1-sonar-large-128k-online", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with Perplexity models
+            """
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 1000,
+                "top_p": 1,
+                "frequency_penalty": 1,
+                "presence_penalty": 0,
+                # "logprobs": False, # Enable this returns 'Neither or both of logprobs and top_logprobs must be set.
+                # "top_logprobs": 3,
+            }
+
+            def sync_client(self):
+                return cls.sync_client()
+
+            def async_client(self):
+                return cls.async_client()
+
+            @classmethod
+            def available(cls) -> list[str]:
+                return cls.sync_client().models.list()
+
+            def get_headers(self) -> dict[str, Any]:
+                client = self.sync_client()
+                response = client.chat.completions.with_raw_response.create(
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": "Say this is a test",
+                        }
+                    ],
+                    model=self.model,
+                )
+                return dict(response.headers)
+
+            def get_rate_limits(self) -> dict[str, Any]:
+                try:
+                    if "openai" in rate_limits:
+                        headers = rate_limits["openai"]
+
+                    else:
+                        headers = self.get_headers()
+
+                except Exception as e:
+                    return {
+                        "rpm": 10_000,
+                        "tpm": 2_000_000,
+                    }
+                else:
+                    return {
+                        "rpm": int(headers["x-ratelimit-limit-requests"]),
+                        "tpm": int(headers["x-ratelimit-limit-tokens"]),
+                    }
+
+            async def async_execute_model_call(
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional[List["Files"]] = None,
+                invigilator: Optional[
+                    "InvigilatorAI"
+                ] = None,  # TBD - can eventually be used for function-calling
+            ) -> dict[str, Any]:
+                """Calls the OpenAI API and returns the API response."""
+                if files_list:
+                    encoded_image = files_list[0].base64_string
+                    content = [{"type": "text", "text": user_prompt}]
+                    content.append(
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{encoded_image}"
+                            },
+                        }
+                    )
+                else:
+                    content = user_prompt
+                client = self.async_client()
+
+                messages = [
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": content},
+                ]
+                if system_prompt == "" and self.omit_system_prompt_if_empty:
+                    messages = messages[1:]
+
+                params = {
+                    "model": self.model,
+                    "messages": messages,
+                    "temperature": self.temperature,
+                    "max_tokens": self.max_tokens,
+                    "top_p": self.top_p,
+                    "frequency_penalty": self.frequency_penalty,
+                    "presence_penalty": self.presence_penalty,
+                    # "logprobs": self.logprobs,
+                    # "top_logprobs": self.top_logprobs if self.logprobs else None,
+                }
+                try:
+                    response = await client.chat.completions.create(**params)
+                except Exception as e:
+                    print(e, flush=True)
+                return response.model_dump()
+
+        LLM.__name__ = "LanguageModel"
+
+        return LLM
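PerplexityService mostly mirrors OpenAIService but hard-codes `image/jpeg` for attachments, sends only the first file, and comments out the `logprobs` parameters that Perplexity rejects. The multimodal message shape both services build is the OpenAI chat format with a base64 data URL, sketched here with placeholder bytes:

```python
# Illustrative payload only; the image bytes are a placeholder.
import base64

encoded_image = base64.b64encode(b"<raw image bytes>").decode()

messages = [
    {"role": "system", "content": "Answer briefly."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
            },
        ],
    },
]
print(messages[1]["content"][1]["image_url"]["url"][:40])
```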
edsl/inference_services/ServiceAvailability.py
@@ -0,0 +1,135 @@
+from enum import Enum
+from typing import List, Optional, TYPE_CHECKING
+from functools import partial
+import warnings
+
+from edsl.inference_services.data_structures import AvailableModels, ModelNamesList
+
+if TYPE_CHECKING:
+    from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+
+
+class ModelSource(Enum):
+    LOCAL = "local"
+    COOP = "coop"
+    CACHE = "cache"
+
+
+class ServiceAvailability:
+    """This class is responsible for fetching the available models from different sources."""
+
+    _coop_model_list = None
+
+    def __init__(self, source_order: Optional[List[ModelSource]] = None):
+        """
+        Initialize with custom source order.
+        Default order is LOCAL -> COOP -> CACHE
+        """
+        self.source_order = source_order or [
+            ModelSource.LOCAL,
+            ModelSource.COOP,
+            ModelSource.CACHE,
+        ]
+
+        # Map sources to their fetch functions
+        self._source_fetchers = {
+            ModelSource.LOCAL: self._fetch_from_local_service,
+            ModelSource.COOP: self._fetch_from_coop,
+            ModelSource.CACHE: self._fetch_from_cache,
+        }
+
+    @classmethod
+    def models_from_coop(cls) -> AvailableModels:
+        if not cls._coop_model_list:
+            from edsl.coop.coop import Coop
+
+            c = Coop()
+            coop_model_list = c.fetch_models()
+            cls._coop_model_list = coop_model_list
+        return cls._coop_model_list
+
+    def get_service_available(
+        self, service: "InferenceServiceABC", warn: bool = False
+    ) -> ModelNamesList:
+        """
+        Try to fetch available models from sources in specified order.
+        Returns first successful result.
+        """
+        last_error = None
+
+        for source in self.source_order:
+            try:
+                fetch_func = partial(self._source_fetchers[source], service)
+                result = fetch_func()
+
+                # Cache successful result
+                service._models_list_cache = result
+                return result
+
+            except Exception as e:
+                last_error = e
+                if warn:
+                    self._warn_source_failed(service, source)
+                continue
+
+        # If we get here, all sources failed
+        raise RuntimeError(
+            f"All sources failed to fetch models. Last error: {last_error}"
+        )
+
+    @staticmethod
+    def _fetch_from_local_service(service: "InferenceServiceABC") -> ModelNamesList:
+        """Attempt to fetch models directly from the service."""
+        return service.available()
+
+    @classmethod
+    def _fetch_from_coop(cls, service: "InferenceServiceABC") -> ModelNamesList:
+        """Fetch models from Coop."""
+        models_from_coop = cls.models_from_coop()
+        return models_from_coop.get(service._inference_service_, [])
+
+    @staticmethod
+    def _fetch_from_cache(service: "InferenceServiceABC") -> ModelNamesList:
+        """Fetch models from local cache."""
+        from edsl.inference_services.models_available_cache import models_available
+
+        return models_available.get(service._inference_service_, [])
+
+    def _warn_source_failed(self, service: "InferenceServiceABC", source: ModelSource):
+        """Display appropriate warning message based on failed source."""
+        messages = {
+            ModelSource.LOCAL: f"""Error getting models for {service._inference_service_}.
+            Check that you have properly stored your Expected Parrot API key and activated remote inference,
+            or stored your own API keys for the language models that you want to use.
+            See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
+            Trying next source.""",
+            ModelSource.COOP: f"Error getting models from Coop for {service._inference_service_}. Trying next source.",
+            ModelSource.CACHE: f"Error getting models from cache for {service._inference_service_}.",
+        }
+        warnings.warn(messages[source], UserWarning)
+
+
+if __name__ == "__main__":
+    # sa = ServiceAvailability()
+    # models_from_coop = sa.models_from_coop()
+    # print(models_from_coop)
+    from edsl.inference_services.OpenAIService import OpenAIService
+
+    openai_models = ServiceAvailability._fetch_from_local_service(OpenAIService())
+    print(openai_models)
+
+    # Example usage:
+    """
+    # Default order (LOCAL -> COOP -> CACHE)
+    availability = ServiceAvailability()
+
+    # Custom order (COOP -> LOCAL -> CACHE)
+    availability_coop_first = ServiceAvailability([
+        ModelSource.COOP,
+        ModelSource.LOCAL,
+        ModelSource.CACHE
+    ])
+
+    # Get available models using custom order
+    models = availability_coop_first.get_service_available(service, warn=True)
+    """