edsl 0.1.14__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +348 -38
- edsl/BaseDiff.py +260 -0
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +46 -10
- edsl/__version__.py +1 -0
- edsl/agents/Agent.py +842 -144
- edsl/agents/AgentList.py +521 -25
- edsl/agents/Invigilator.py +250 -374
- edsl/agents/InvigilatorBase.py +257 -0
- edsl/agents/PromptConstructor.py +272 -0
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/descriptors.py +43 -13
- edsl/agents/prompt_helpers.py +129 -0
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +130 -0
- edsl/auto/StageBase.py +243 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +74 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +218 -0
- edsl/base/Base.py +279 -0
- edsl/config.py +121 -104
- edsl/conversation/Conversation.py +290 -0
- edsl/conversation/car_buying.py +59 -0
- edsl/conversation/chips.py +95 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +54 -0
- edsl/coop/__init__.py +1 -0
- edsl/coop/coop.py +1029 -134
- edsl/coop/utils.py +131 -0
- edsl/data/Cache.py +560 -89
- edsl/data/CacheEntry.py +230 -0
- edsl/data/CacheHandler.py +168 -0
- edsl/data/RemoteCacheSync.py +186 -0
- edsl/data/SQLiteDict.py +292 -0
- edsl/data/__init__.py +5 -3
- edsl/data/orm.py +6 -33
- edsl/data_transfer_models.py +74 -27
- edsl/enums.py +165 -8
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +52 -46
- edsl/exceptions/agents.py +33 -15
- edsl/exceptions/cache.py +5 -0
- edsl/exceptions/coop.py +8 -0
- edsl/exceptions/general.py +34 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/jobs.py +15 -0
- edsl/exceptions/language_models.py +46 -1
- edsl/exceptions/questions.py +80 -5
- edsl/exceptions/results.py +16 -5
- edsl/exceptions/scenarios.py +29 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AnthropicService.py +106 -0
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +118 -0
- edsl/inference_services/AzureAI.py +215 -0
- edsl/inference_services/DeepInfraService.py +18 -0
- edsl/inference_services/GoogleService.py +143 -0
- edsl/inference_services/GroqService.py +20 -0
- edsl/inference_services/InferenceServiceABC.py +80 -0
- edsl/inference_services/InferenceServicesCollection.py +138 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +236 -0
- edsl/inference_services/PerplexityService.py +160 -0
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +90 -0
- edsl/inference_services/TogetherAIService.py +172 -0
- edsl/inference_services/data_structures.py +134 -0
- edsl/inference_services/models_available_cache.py +118 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/registry.py +41 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +21 -20
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +684 -204
- edsl/jobs/JobsChecks.py +172 -0
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +270 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +104 -0
- edsl/jobs/buckets/ModelBuckets.py +65 -0
- edsl/jobs/buckets/TokenBucket.py +283 -0
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +392 -0
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -0
- edsl/jobs/interviews/InterviewStatistic.py +63 -0
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -0
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -0
- edsl/jobs/interviews/InterviewStatusLog.py +92 -0
- edsl/jobs/interviews/ReportErrors.py +66 -0
- edsl/jobs/interviews/interview_status_enum.py +9 -0
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +151 -110
- edsl/jobs/runners/JobsRunnerStatus.py +298 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +244 -0
- edsl/jobs/tasks/TaskCreators.py +64 -0
- edsl/jobs/tasks/TaskHistory.py +470 -0
- edsl/jobs/tasks/TaskStatusLog.py +23 -0
- edsl/jobs/tasks/task_status_enum.py +161 -0
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -0
- edsl/jobs/tokens/TokenUsage.py +34 -0
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +507 -386
- edsl/language_models/ModelList.py +164 -0
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -0
- edsl/language_models/__init__.py +1 -8
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +109 -41
- edsl/language_models/utilities.py +65 -0
- edsl/notebooks/Notebook.py +263 -0
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/notebooks/__init__.py +1 -0
- edsl/prompts/Prompt.py +222 -93
- edsl/prompts/__init__.py +1 -1
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +518 -0
- edsl/questions/QuestionBasePromptsMixin.py +221 -0
- edsl/questions/QuestionBudget.py +164 -67
- edsl/questions/QuestionCheckBox.py +281 -62
- edsl/questions/QuestionDict.py +343 -0
- edsl/questions/QuestionExtract.py +136 -50
- edsl/questions/QuestionFreeText.py +79 -55
- edsl/questions/QuestionFunctional.py +138 -41
- edsl/questions/QuestionList.py +184 -57
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +293 -69
- edsl/questions/QuestionNumerical.py +109 -56
- edsl/questions/QuestionRank.py +244 -49
- edsl/questions/Quick.py +41 -0
- edsl/questions/SimpleAskMixin.py +74 -0
- edsl/questions/__init__.py +9 -6
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +153 -38
- edsl/questions/compose_questions.py +13 -7
- edsl/questions/data_structures.py +20 -0
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +28 -26
- edsl/questions/derived/QuestionLinearScale.py +41 -28
- edsl/questions/derived/QuestionTopK.py +34 -26
- edsl/questions/derived/QuestionYesNo.py +40 -27
- edsl/questions/descriptors.py +228 -74
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_base_gen_mixin.py +168 -0
- edsl/questions/question_registry.py +130 -46
- edsl/questions/register_questions_meta.py +71 -0
- edsl/questions/response_validator_abc.py +188 -0
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/settings.py +5 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/dict/__init__.py +0 -0
- edsl/questions/templates/dict/answering_instructions.jinja +21 -0
- edsl/questions/templates/dict/question_presentation.jinja +1 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +7 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +550 -19
- edsl/results/DatasetExportMixin.py +594 -0
- edsl/results/DatasetTree.py +295 -0
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +477 -173
- edsl/results/Results.py +987 -269
- edsl/results/ResultsExportMixin.py +28 -125
- edsl/results/ResultsGGMixin.py +83 -15
- edsl/results/TableDisplay.py +125 -0
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/file_exports.py +252 -0
- edsl/results/results_fetch_mixin.py +33 -0
- edsl/results/results_selector.py +145 -0
- edsl/results/results_tools_mixin.py +98 -0
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_display.css +78 -0
- edsl/results/table_renderers.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +543 -0
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +431 -62
- edsl/scenarios/ScenarioHtmlMixin.py +65 -0
- edsl/scenarios/ScenarioList.py +1415 -45
- edsl/scenarios/ScenarioListExportMixin.py +45 -0
- edsl/scenarios/ScenarioListPdfMixin.py +239 -0
- edsl/scenarios/__init__.py +2 -0
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +49 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/scenario_join.py +131 -0
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/shared.py +1 -0
- edsl/study/ObjectEntry.py +173 -0
- edsl/study/ProofOfWork.py +113 -0
- edsl/study/SnapShot.py +80 -0
- edsl/study/Study.py +521 -0
- edsl/study/__init__.py +4 -0
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/DAG.py +92 -11
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/Memory.py +9 -4
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/MemoryPlan.py +156 -35
- edsl/surveys/Rule.py +221 -74
- edsl/surveys/RuleCollection.py +241 -61
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +1079 -339
- edsl/surveys/SurveyCSS.py +273 -0
- edsl/surveys/SurveyExportMixin.py +235 -40
- edsl/surveys/SurveyFlowVisualization.py +181 -0
- edsl/surveys/SurveyQualtricsImport.py +284 -0
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/base.py +19 -3
- edsl/surveys/descriptors.py +17 -6
- edsl/surveys/instructions/ChangeInstruction.py +48 -0
- edsl/surveys/instructions/Instruction.py +56 -0
- edsl/surveys/instructions/InstructionCollection.py +82 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +19 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/tools/__init__.py +1 -0
- edsl/tools/clusters.py +192 -0
- edsl/tools/embeddings.py +27 -0
- edsl/tools/embeddings_plotting.py +118 -0
- edsl/tools/plotting.py +112 -0
- edsl/tools/summarize.py +18 -0
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/SystemInfo.py +5 -0
- edsl/utilities/__init__.py +21 -20
- edsl/utilities/ast_utilities.py +3 -0
- edsl/utilities/data/Registry.py +2 -0
- edsl/utilities/decorators.py +41 -0
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/interface.py +310 -60
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/naming_utilities.py +263 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/restricted_python.py +70 -0
- edsl/utilities/utilities.py +203 -13
- edsl-0.1.40.dist-info/METADATA +111 -0
- edsl-0.1.40.dist-info/RECORD +362 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/WHEEL +1 -1
- edsl/agents/AgentListExportMixin.py +0 -24
- edsl/coop/old.py +0 -31
- edsl/data/Database.py +0 -141
- edsl/data/crud.py +0 -121
- edsl/jobs/Interview.py +0 -417
- edsl/jobs/JobsRunner.py +0 -63
- edsl/jobs/JobsRunnerStatusMixin.py +0 -115
- edsl/jobs/base.py +0 -47
- edsl/jobs/buckets.py +0 -166
- edsl/jobs/runners/JobsRunnerDryRun.py +0 -19
- edsl/jobs/runners/JobsRunnerStreaming.py +0 -54
- edsl/jobs/task_management.py +0 -218
- edsl/jobs/token_tracking.py +0 -78
- edsl/language_models/DeepInfra.py +0 -69
- edsl/language_models/OpenAI.py +0 -98
- edsl/language_models/model_interfaces/GeminiPro.py +0 -66
- edsl/language_models/model_interfaces/LanguageModelOpenAIFour.py +0 -8
- edsl/language_models/model_interfaces/LanguageModelOpenAIThreeFiveTurbo.py +0 -8
- edsl/language_models/model_interfaces/LlamaTwo13B.py +0 -21
- edsl/language_models/model_interfaces/LlamaTwo70B.py +0 -21
- edsl/language_models/model_interfaces/Mixtral8x7B.py +0 -24
- edsl/language_models/registry.py +0 -81
- edsl/language_models/schemas.py +0 -15
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/prompts/QuestionInstructionsBase.py +0 -6
- edsl/prompts/library/agent_instructions.py +0 -29
- edsl/prompts/library/agent_persona.py +0 -17
- edsl/prompts/library/question_budget.py +0 -26
- edsl/prompts/library/question_checkbox.py +0 -32
- edsl/prompts/library/question_extract.py +0 -19
- edsl/prompts/library/question_freetext.py +0 -14
- edsl/prompts/library/question_linear_scale.py +0 -20
- edsl/prompts/library/question_list.py +0 -22
- edsl/prompts/library/question_multiple_choice.py +0 -44
- edsl/prompts/library/question_numerical.py +0 -31
- edsl/prompts/library/question_rank.py +0 -21
- edsl/prompts/prompt_config.py +0 -33
- edsl/prompts/registry.py +0 -185
- edsl/questions/Question.py +0 -240
- edsl/report/InputOutputDataTypes.py +0 -134
- edsl/report/RegressionMixin.py +0 -28
- edsl/report/ReportOutputs.py +0 -1228
- edsl/report/ResultsFetchMixin.py +0 -106
- edsl/report/ResultsOutputMixin.py +0 -14
- edsl/report/demo.ipynb +0 -645
- edsl/results/ResultsDBMixin.py +0 -184
- edsl/surveys/SurveyFlowVisualizationMixin.py +0 -92
- edsl/trackers/Tracker.py +0 -91
- edsl/trackers/TrackerAPI.py +0 -196
- edsl/trackers/TrackerTasks.py +0 -70
- edsl/utilities/pastebin.py +0 -141
- edsl-0.1.14.dist-info/METADATA +0 -69
- edsl-0.1.14.dist-info/RECORD +0 -141
- /edsl/{language_models/model_interfaces → inference_services}/__init__.py +0 -0
- /edsl/{report/__init__.py → jobs/runners/JobsRunnerStatusData.py} +0 -0
- /edsl/{trackers/__init__.py → language_models/ServiceDataSources.py} +0 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/LICENSE +0 -0
@@ -0,0 +1,184 @@
|
|
1
|
+
from typing import List, Optional, get_args, Union
|
2
|
+
from pathlib import Path
|
3
|
+
import sqlite3
|
4
|
+
from datetime import datetime
|
5
|
+
import tempfile
|
6
|
+
from platformdirs import user_cache_dir
|
7
|
+
from dataclasses import dataclass
|
8
|
+
import os
|
9
|
+
|
10
|
+
from edsl.inference_services.data_structures import LanguageModelInfo, AvailableModels
|
11
|
+
from edsl.enums import InferenceServiceLiteral
|
12
|
+
|
13
|
+
|
14
|
+
class AvailableModelCacheHandler:
    """SQLite-backed cache of available (model_name, service_name) pairs.

    Rows older than ``cache_validity_hours`` are ignored when reading.
    Timestamps are stored as Unix-epoch floats so the SQL filter
    ``timestamp > ?`` compares numerically against
    ``datetime.now().timestamp()``.  (Storing ``datetime`` objects, as this
    class previously did, made sqlite3 persist ISO-8601 TEXT, and SQLite
    orders every TEXT value above every REAL value, so the validity filter
    never expired anything.  Epoch floats also avoid the Python 3.12
    deprecation of the default datetime adapter.)
    """

    MAX_ROWS = 1000  # soft cap on cache size, enforced by _prune_old_entries
    CACHE_VALIDITY_HOURS = 48

    def __init__(
        self,
        cache_validity_hours: int = 48,
        verbose: bool = False,
        testing_db_name: Optional[str] = None,
    ):
        """Open (or create) the cache database.

        :param cache_validity_hours: how long a cached entry counts as fresh.
        :param verbose: print diagnostic messages.
        :param testing_db_name: if given, place the DB in a throwaway temp
            directory so tests never touch the real user cache.
        """
        self.cache_validity_hours = cache_validity_hours
        self.verbose = verbose

        if testing_db_name:
            self.cache_dir = Path(tempfile.mkdtemp())
            self.db_path = self.cache_dir / testing_db_name
        else:
            self.cache_dir = Path(user_cache_dir("edsl", "model_availability"))
            self.db_path = self.cache_dir / "available_models.db"
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if os.path.exists(self.db_path):
            if self.verbose:
                print(f"Using existing cache DB: {self.db_path}")
        else:
            self._initialize_db()

    @property
    def path_to_db(self):
        """Path to the underlying SQLite file."""
        return self.db_path

    def _initialize_db(self):
        """Initialize the SQLite database with the required schema."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # Drop the old table if it exists (for migration)
            cursor.execute("DROP TABLE IF EXISTS model_cache")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS model_cache (
                    timestamp DATETIME NOT NULL,
                    model_name TEXT NOT NULL,
                    service_name TEXT NOT NULL,
                    UNIQUE(model_name, service_name)
                )
                """
            )
            conn.commit()

    def _prune_old_entries(self, conn: sqlite3.Connection):
        """Delete oldest entries when MAX_ROWS is exceeded."""
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM model_cache")
        count = cursor.fetchone()[0]

        if count > self.MAX_ROWS:
            cursor.execute(
                """
                DELETE FROM model_cache
                WHERE rowid IN (
                    SELECT rowid
                    FROM model_cache
                    ORDER BY timestamp ASC
                    LIMIT ?
                )
                """,
                (count - self.MAX_ROWS,),
            )
            conn.commit()

    @classmethod
    def example_models(cls) -> List["LanguageModelInfo"]:
        """Return two hard-coded sample entries, for demos and tests."""
        return [
            LanguageModelInfo(
                "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "deep_infra"
            ),
            LanguageModelInfo("openai/gpt-4", "openai"),
        ]

    def add_models_to_cache(self, models_data: List["LanguageModelInfo"]):
        """Add new models to the cache, updating timestamps for existing entries.

        Each entry only needs ``model_name`` and ``service_name`` attributes.
        """
        # Epoch float: must stay numerically comparable to the REAL bound
        # used in models() (see class docstring).
        current_time = datetime.now().timestamp()

        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            for model in models_data:
                cursor.execute(
                    """
                    INSERT INTO model_cache (timestamp, model_name, service_name)
                    VALUES (?, ?, ?)
                    ON CONFLICT(model_name, service_name)
                    DO UPDATE SET timestamp = excluded.timestamp
                    """,
                    (current_time, model.model_name, model.service_name),
                )

            # self._prune_old_entries(conn)
            conn.commit()

    def reset_cache(self):
        """Clear all entries from the cache."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM model_cache")
            conn.commit()

    @property
    def num_cache_entries(self):
        """Return the number of entries in the cache."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM model_cache")
            count = cursor.fetchone()[0]
            return count

    def models(
        self,
        service: Optional["InferenceServiceLiteral"],
    ) -> Union[None, "AvailableModels"]:
        """Return the available models within the cache validity period.

        :param service: if given, filter results to this service name.
        :return: an AvailableModels list, or None when the cache is empty
            or every entry is stale.
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # Oldest acceptable timestamp, as an epoch float.
            valid_time = datetime.now().timestamp() - (self.cache_validity_hours * 3600)

            if self.verbose:
                print(f"Fetching all with timestamp greater than {valid_time}")

            cursor.execute(
                """
                SELECT DISTINCT model_name, service_name
                FROM model_cache
                WHERE timestamp > ?
                ORDER BY timestamp DESC
                """,
                (valid_time,),
            )

            results = cursor.fetchall()
            if not results:
                if self.verbose:
                    print("No results found in cache DB.")
                return None

            matching_models = [
                LanguageModelInfo(model_name=row[0], service_name=row[1])
                for row in results
            ]

            if self.verbose:
                print(f"Found {len(matching_models)} models in cache DB.")
            if service:
                matching_models = [
                    model for model in matching_models if model.service_name == service
                ]

            return AvailableModels(matching_models)
173
|
+
|
174
|
+
|
175
|
+
if __name__ == "__main__":
    # Run the embedded doctests when executed as a script.
    import doctest

    doctest.testmod()
|
@@ -0,0 +1,215 @@
|
|
1
|
+
from typing import Any, List, Tuple, Optional, Dict, TYPE_CHECKING, Union, Generator
|
2
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
3
|
+
from collections import UserList
|
4
|
+
|
5
|
+
from edsl.inference_services.ServiceAvailability import ServiceAvailability
|
6
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
7
|
+
from edsl.inference_services.data_structures import ModelNamesList
|
8
|
+
from edsl.enums import InferenceServiceLiteral
|
9
|
+
|
10
|
+
from edsl.inference_services.data_structures import LanguageModelInfo
|
11
|
+
from edsl.inference_services.AvailableModelCacheHandler import (
|
12
|
+
AvailableModelCacheHandler,
|
13
|
+
)
|
14
|
+
|
15
|
+
|
16
|
+
from edsl.inference_services.data_structures import AvailableModels
|
17
|
+
|
18
|
+
|
19
|
+
class AvailableModelFetcher:
|
20
|
+
"""Fetches available models from the various services with JSON caching."""
|
21
|
+
|
22
|
+
service_availability = ServiceAvailability()
|
23
|
+
CACHE_VALIDITY_HOURS = 48 # Cache validity period in hours
|
24
|
+
|
25
|
+
def __init__(
|
26
|
+
self,
|
27
|
+
services: List["InferenceServiceABC"],
|
28
|
+
added_models: Dict[str, List[str]],
|
29
|
+
verbose: bool = False,
|
30
|
+
use_cache: bool = True,
|
31
|
+
):
|
32
|
+
self.services = services
|
33
|
+
self.added_models = added_models
|
34
|
+
self._service_map = {
|
35
|
+
service._inference_service_: service for service in services
|
36
|
+
}
|
37
|
+
self.verbose = verbose
|
38
|
+
if use_cache:
|
39
|
+
self.cache_handler = AvailableModelCacheHandler()
|
40
|
+
else:
|
41
|
+
self.cache_handler = None
|
42
|
+
|
43
|
+
@property
|
44
|
+
def num_cache_entries(self):
|
45
|
+
return self.cache_handler.num_cache_entries
|
46
|
+
|
47
|
+
@property
|
48
|
+
def path_to_db(self):
|
49
|
+
return self.cache_handler.path_to_db
|
50
|
+
|
51
|
+
def reset_cache(self):
|
52
|
+
if self.cache_handler:
|
53
|
+
self.cache_handler.reset_cache()
|
54
|
+
|
55
|
+
def available(
|
56
|
+
self,
|
57
|
+
service: Optional[InferenceServiceABC] = None,
|
58
|
+
force_refresh: bool = False,
|
59
|
+
) -> List[LanguageModelInfo]:
|
60
|
+
"""
|
61
|
+
Get available models from all services, using cached data when available.
|
62
|
+
|
63
|
+
:param service: Optional[InferenceServiceABC] - If specified, only fetch models for this service.
|
64
|
+
|
65
|
+
>>> from edsl.inference_services.OpenAIService import OpenAIService
|
66
|
+
>>> af = AvailableModelFetcher([OpenAIService()], {})
|
67
|
+
>>> af.available(service="openai")
|
68
|
+
[LanguageModelInfo(model_name='...', service_name='openai'), ...]
|
69
|
+
|
70
|
+
Returns a list of [model, service_name, index] entries.
|
71
|
+
"""
|
72
|
+
|
73
|
+
if service: # they passed a specific service
|
74
|
+
matching_models, _ = self.get_available_models_by_service(
|
75
|
+
service=service, force_refresh=force_refresh
|
76
|
+
)
|
77
|
+
return matching_models
|
78
|
+
|
79
|
+
# Nope, we need to fetch them all
|
80
|
+
all_models = self._get_all_models()
|
81
|
+
|
82
|
+
# if self.cache_handler:
|
83
|
+
# self.cache_handler.add_models_to_cache(all_models)
|
84
|
+
|
85
|
+
return all_models
|
86
|
+
|
87
|
+
def get_available_models_by_service(
|
88
|
+
self,
|
89
|
+
service: Union["InferenceServiceABC", InferenceServiceLiteral],
|
90
|
+
force_refresh: bool = False,
|
91
|
+
) -> Tuple[AvailableModels, InferenceServiceLiteral]:
|
92
|
+
"""Get models for a single service.
|
93
|
+
|
94
|
+
:param service: InferenceServiceABC - e.g., OpenAIService or "openai"
|
95
|
+
:return: Tuple[List[LanguageModelInfo], InferenceServiceLiteral]
|
96
|
+
"""
|
97
|
+
if isinstance(service, str):
|
98
|
+
service = self._fetch_service_by_service_name(service)
|
99
|
+
|
100
|
+
if not force_refresh:
|
101
|
+
models_from_cache = self.cache_handler.models(
|
102
|
+
service=service._inference_service_
|
103
|
+
)
|
104
|
+
if self.verbose:
|
105
|
+
print(
|
106
|
+
"Searching cache for models with service name:",
|
107
|
+
service._inference_service_,
|
108
|
+
)
|
109
|
+
print("Got models from cache:", models_from_cache)
|
110
|
+
else:
|
111
|
+
models_from_cache = None
|
112
|
+
|
113
|
+
if models_from_cache:
|
114
|
+
# print(f"Models from cache for {service}: {models_from_cache}")
|
115
|
+
# print(hasattr(models_from_cache[0], "service_name"))
|
116
|
+
return models_from_cache, service._inference_service_
|
117
|
+
else:
|
118
|
+
return self.get_available_models_by_service_fresh(service)
|
119
|
+
|
120
|
+
def get_available_models_by_service_fresh(
|
121
|
+
self, service: Union["InferenceServiceABC", InferenceServiceLiteral]
|
122
|
+
) -> Tuple[AvailableModels, InferenceServiceLiteral]:
|
123
|
+
"""Get models for a single service. This method always fetches fresh data.
|
124
|
+
|
125
|
+
:param service: InferenceServiceABC - e.g., OpenAIService or "openai"
|
126
|
+
:return: Tuple[List[LanguageModelInfo], InferenceServiceLiteral]
|
127
|
+
"""
|
128
|
+
if isinstance(service, str):
|
129
|
+
service = self._fetch_service_by_service_name(service)
|
130
|
+
|
131
|
+
service_models: ModelNamesList = (
|
132
|
+
self.service_availability.get_service_available(service, warn=False)
|
133
|
+
)
|
134
|
+
service_name = service._inference_service_
|
135
|
+
|
136
|
+
if not service_models:
|
137
|
+
import warnings
|
138
|
+
|
139
|
+
warnings.warn(f"No models found for service {service_name}")
|
140
|
+
return [], service_name
|
141
|
+
|
142
|
+
models_list = AvailableModels(
|
143
|
+
[
|
144
|
+
LanguageModelInfo(
|
145
|
+
model_name=model_name,
|
146
|
+
service_name=service_name,
|
147
|
+
)
|
148
|
+
for model_name in service_models
|
149
|
+
]
|
150
|
+
)
|
151
|
+
self.cache_handler.add_models_to_cache(models_list) # update the cache
|
152
|
+
return models_list, service_name
|
153
|
+
|
154
|
+
def _fetch_service_by_service_name(
|
155
|
+
self, service_name: InferenceServiceLiteral
|
156
|
+
) -> "InferenceServiceABC":
|
157
|
+
"""The service name is the _inference_service_ attribute of the service."""
|
158
|
+
if service_name in self._service_map:
|
159
|
+
return self._service_map[service_name]
|
160
|
+
raise ValueError(f"Service {service_name} not found")
|
161
|
+
|
162
|
+
def _get_all_models(self, force_refresh=False) -> List[LanguageModelInfo]:
    """Query every registered service concurrently and combine the results.

    Fetching model lists is I/O-bound, so each service is queried on a
    thread pool; manually added models from self.added_models are appended
    after each service's fetched list.

    :param force_refresh: if True, bypass any cached model lists.
    :return: AvailableModels containing one entry per (model, service).
    """
    all_models = []
    if not self.services:
        # ThreadPoolExecutor rejects max_workers=0, so short-circuit.
        return AvailableModels(all_models)

    with ThreadPoolExecutor(max_workers=min(len(self.services), 10)) as executor:
        future_to_service = {
            executor.submit(
                self.get_available_models_by_service, service, force_refresh
            ): service
            for service in self.services
        }

        for future in as_completed(future_to_service):
            # Look the service up from the map BEFORE calling result():
            # the original read `service_name` in the except clause, which
            # is unbound when future.result() itself raises (NameError).
            failed_service = future_to_service[future]
            try:
                models, service_name = future.result()
                all_models.extend(models)

                # Add any additional models for this service
                for model in self.added_models.get(service_name, []):
                    all_models.append(
                        LanguageModelInfo(
                            model_name=model, service_name=service_name
                        )
                    )

            except Exception as exc:
                print(f"Service query failed for service {failed_service}: {exc}")
                continue

    return AvailableModels(all_models)
|
192
|
+
def main():
    """Ad-hoc smoke test: fetch and print all OpenAI models, bypassing the cache."""
    from edsl.inference_services.OpenAIService import OpenAIService

    fetcher = AvailableModelFetcher([OpenAIService()], {}, verbose=True)
    # print(fetcher.available(service="openai"))
    fresh_models = AvailableModelFetcher([OpenAIService()], {})._get_all_models(
        force_refresh=True
    )
    print(fresh_models)
|
+
|
203
|
+
if __name__ == "__main__":
    # Run this module's doctests; ELLIPSIS lets examples elide long output.
    # Call main() instead to run the live smoke test against OpenAI.
    import doctest

    doctest.testmod(optionflags=doctest.ELLIPSIS)
|
@@ -0,0 +1,118 @@
|
|
1
|
+
import os
|
2
|
+
from typing import Any, List, Optional
|
3
|
+
import re
|
4
|
+
import boto3
|
5
|
+
from botocore.exceptions import ClientError
|
6
|
+
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
7
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
8
|
+
import json
|
9
|
+
from edsl.utilities.utilities import fix_partial_correct_response
|
10
|
+
|
11
|
+
|
12
|
+
class AwsBedrockService(InferenceServiceABC):
    """AWS Bedrock service class."""

    _inference_service_ = "bedrock"
    _env_key_name_ = (
        "AWS_ACCESS_KEY_ID"  # or any other environment key for AWS credentials
    )
    # Path to the generated text inside a Bedrock `converse` response dict.
    key_sequence = ["output", "message", "content", 0, "text"]
    input_token_name = "inputTokens"
    output_token_name = "outputTokens"
    usage_sequence = ["usage"]
    # Model ids filtered out of available(); presumably deprecated/unsupported.
    model_exclude_list = [
        "ai21.j2-grande-instruct",
        "ai21.j2-jumbo-instruct",
        "ai21.j2-mid",
        "ai21.j2-mid-v1",
        "ai21.j2-ultra",
        "ai21.j2-ultra-v1",
    ]
    # Class-level model-id cache. NOTE(review): nothing in this class ever
    # writes to it, so it only takes effect if populated externally.
    _models_list_cache: List[str] = []

    @classmethod
    def available(cls) -> List[str]:
        """Fetch available model ids from AWS Bedrock, minus the exclude list.

        Uses the cached id list when non-empty; otherwise calls the Bedrock
        `list_foundation_models` API in the AWS_REGION (default us-east-1).
        """
        region = os.getenv("AWS_REGION", "us-east-1")

        if not cls._models_list_cache:
            client = boto3.client("bedrock", region_name=region)
            all_models_ids = [
                x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
            ]
        else:
            all_models_ids = cls._models_list_cache

        return [m for m in all_models_ids if m not in cls.model_exclude_list]

    @classmethod
    def create_model(
        cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
    ) -> LanguageModel:
        """Create a LanguageModel subclass bound to one Bedrock model.

        :param model_name: Bedrock model id to invoke.
        :param model_class_name: class name for the subclass; derived from
            model_name via to_class_name when None.
        :return: the new LanguageModel subclass (a class, not an instance).
        """
        if model_class_name is None:
            model_class_name = cls.to_class_name(model_name)

        class LLM(LanguageModel):
            """
            Child class of LanguageModel for interacting with AWS Bedrock models.
            """

            key_sequence = cls.key_sequence
            usage_sequence = cls.usage_sequence
            _inference_service_ = cls._inference_service_
            _model_ = model_name
            _parameters_ = {
                "temperature": 0.5,
                "max_tokens": 512,
                "top_p": 0.9,
            }
            input_token_name = cls.input_token_name
            output_token_name = cls.output_token_name

            async def async_execute_model_call(
                self,
                user_prompt: str,
                system_prompt: str = "",
                files_list: Optional[List["FileStore"]] = None,
            ) -> dict[str, Any]:
                """Calls the AWS Bedrock API and returns the API response.

                On any failure the error is printed and {"error": <message>}
                is returned instead of raising.
                """
                # Accessing the property checks that the AWS credential env
                # variables are set; the value itself is not used here.
                _ = self.api_token

                region = os.getenv("AWS_REGION", "us-east-1")
                client = boto3.client("bedrock-runtime", region_name=region)

                conversation = [
                    {
                        "role": "user",
                        "content": [{"text": user_prompt}],
                    }
                ]
                # Built for Bedrock's `system=` parameter, which is currently
                # disabled below; kept so re-enabling it is a one-line change.
                system = [
                    {
                        "text": system_prompt,
                    }
                ]
                try:
                    response = client.converse(
                        modelId=self._model_,
                        messages=conversation,
                        inferenceConfig={
                            "maxTokens": self.max_tokens,
                            "temperature": self.temperature,
                            "topP": self.top_p,
                        },
                        # system=system,
                        additionalModelRequestFields={},
                    )
                    return response
                # The original `except (ClientError, Exception)` was redundant:
                # ClientError already subclasses Exception, so behavior is the same.
                except Exception as e:
                    print(e)
                    return {"error": str(e)}

        LLM.__name__ = model_class_name

        return LLM
|