edsl 0.1.15__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +348 -38
- edsl/BaseDiff.py +260 -0
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +45 -10
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +842 -144
- edsl/agents/AgentList.py +521 -25
- edsl/agents/Invigilator.py +250 -374
- edsl/agents/InvigilatorBase.py +257 -0
- edsl/agents/PromptConstructor.py +272 -0
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/descriptors.py +43 -13
- edsl/agents/prompt_helpers.py +129 -0
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +130 -0
- edsl/auto/StageBase.py +243 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +74 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +218 -0
- edsl/base/Base.py +279 -0
- edsl/config.py +115 -113
- edsl/conversation/Conversation.py +290 -0
- edsl/conversation/car_buying.py +59 -0
- edsl/conversation/chips.py +95 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +54 -0
- edsl/coop/__init__.py +1 -0
- edsl/coop/coop.py +1029 -134
- edsl/coop/utils.py +131 -0
- edsl/data/Cache.py +560 -89
- edsl/data/CacheEntry.py +230 -0
- edsl/data/CacheHandler.py +168 -0
- edsl/data/RemoteCacheSync.py +186 -0
- edsl/data/SQLiteDict.py +292 -0
- edsl/data/__init__.py +5 -3
- edsl/data/orm.py +6 -33
- edsl/data_transfer_models.py +74 -27
- edsl/enums.py +165 -8
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +52 -46
- edsl/exceptions/agents.py +33 -15
- edsl/exceptions/cache.py +5 -0
- edsl/exceptions/coop.py +8 -0
- edsl/exceptions/general.py +34 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/jobs.py +15 -0
- edsl/exceptions/language_models.py +46 -1
- edsl/exceptions/questions.py +80 -5
- edsl/exceptions/results.py +16 -5
- edsl/exceptions/scenarios.py +29 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AnthropicService.py +106 -0
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +118 -0
- edsl/inference_services/AzureAI.py +215 -0
- edsl/inference_services/DeepInfraService.py +18 -0
- edsl/inference_services/GoogleService.py +143 -0
- edsl/inference_services/GroqService.py +20 -0
- edsl/inference_services/InferenceServiceABC.py +80 -0
- edsl/inference_services/InferenceServicesCollection.py +138 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +236 -0
- edsl/inference_services/PerplexityService.py +160 -0
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +90 -0
- edsl/inference_services/TogetherAIService.py +172 -0
- edsl/inference_services/data_structures.py +134 -0
- edsl/inference_services/models_available_cache.py +118 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/registry.py +41 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +21 -20
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +684 -206
- edsl/jobs/JobsChecks.py +172 -0
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +270 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +104 -0
- edsl/jobs/buckets/ModelBuckets.py +65 -0
- edsl/jobs/buckets/TokenBucket.py +283 -0
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +392 -0
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -0
- edsl/jobs/interviews/InterviewStatistic.py +63 -0
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -0
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -0
- edsl/jobs/interviews/InterviewStatusLog.py +92 -0
- edsl/jobs/interviews/ReportErrors.py +66 -0
- edsl/jobs/interviews/interview_status_enum.py +9 -0
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +151 -110
- edsl/jobs/runners/JobsRunnerStatus.py +298 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +244 -0
- edsl/jobs/tasks/TaskCreators.py +64 -0
- edsl/jobs/tasks/TaskHistory.py +470 -0
- edsl/jobs/tasks/TaskStatusLog.py +23 -0
- edsl/jobs/tasks/task_status_enum.py +161 -0
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -0
- edsl/jobs/tokens/TokenUsage.py +34 -0
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +507 -386
- edsl/language_models/ModelList.py +164 -0
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -0
- edsl/language_models/__init__.py +1 -8
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +109 -41
- edsl/language_models/utilities.py +65 -0
- edsl/notebooks/Notebook.py +263 -0
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/notebooks/__init__.py +1 -0
- edsl/prompts/Prompt.py +222 -93
- edsl/prompts/__init__.py +1 -1
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +518 -0
- edsl/questions/QuestionBasePromptsMixin.py +221 -0
- edsl/questions/QuestionBudget.py +164 -67
- edsl/questions/QuestionCheckBox.py +281 -62
- edsl/questions/QuestionDict.py +343 -0
- edsl/questions/QuestionExtract.py +136 -50
- edsl/questions/QuestionFreeText.py +79 -55
- edsl/questions/QuestionFunctional.py +138 -41
- edsl/questions/QuestionList.py +184 -57
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +293 -69
- edsl/questions/QuestionNumerical.py +109 -56
- edsl/questions/QuestionRank.py +244 -49
- edsl/questions/Quick.py +41 -0
- edsl/questions/SimpleAskMixin.py +74 -0
- edsl/questions/__init__.py +9 -6
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +153 -38
- edsl/questions/compose_questions.py +13 -7
- edsl/questions/data_structures.py +20 -0
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +28 -26
- edsl/questions/derived/QuestionLinearScale.py +41 -28
- edsl/questions/derived/QuestionTopK.py +34 -26
- edsl/questions/derived/QuestionYesNo.py +40 -27
- edsl/questions/descriptors.py +228 -74
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_base_gen_mixin.py +168 -0
- edsl/questions/question_registry.py +130 -46
- edsl/questions/register_questions_meta.py +71 -0
- edsl/questions/response_validator_abc.py +188 -0
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/settings.py +5 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/dict/__init__.py +0 -0
- edsl/questions/templates/dict/answering_instructions.jinja +21 -0
- edsl/questions/templates/dict/question_presentation.jinja +1 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +7 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +550 -19
- edsl/results/DatasetExportMixin.py +594 -0
- edsl/results/DatasetTree.py +295 -0
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +477 -173
- edsl/results/Results.py +987 -269
- edsl/results/ResultsExportMixin.py +28 -125
- edsl/results/ResultsGGMixin.py +83 -15
- edsl/results/TableDisplay.py +125 -0
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/file_exports.py +252 -0
- edsl/results/results_fetch_mixin.py +33 -0
- edsl/results/results_selector.py +145 -0
- edsl/results/results_tools_mixin.py +98 -0
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_display.css +78 -0
- edsl/results/table_renderers.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +543 -0
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +431 -62
- edsl/scenarios/ScenarioHtmlMixin.py +65 -0
- edsl/scenarios/ScenarioList.py +1415 -45
- edsl/scenarios/ScenarioListExportMixin.py +45 -0
- edsl/scenarios/ScenarioListPdfMixin.py +239 -0
- edsl/scenarios/__init__.py +2 -0
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +49 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/scenario_join.py +131 -0
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/shared.py +1 -0
- edsl/study/ObjectEntry.py +173 -0
- edsl/study/ProofOfWork.py +113 -0
- edsl/study/SnapShot.py +80 -0
- edsl/study/Study.py +521 -0
- edsl/study/__init__.py +4 -0
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/DAG.py +92 -11
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/Memory.py +9 -4
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/MemoryPlan.py +156 -35
- edsl/surveys/Rule.py +221 -74
- edsl/surveys/RuleCollection.py +241 -61
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +1079 -339
- edsl/surveys/SurveyCSS.py +273 -0
- edsl/surveys/SurveyExportMixin.py +235 -40
- edsl/surveys/SurveyFlowVisualization.py +181 -0
- edsl/surveys/SurveyQualtricsImport.py +284 -0
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/base.py +19 -3
- edsl/surveys/descriptors.py +17 -6
- edsl/surveys/instructions/ChangeInstruction.py +48 -0
- edsl/surveys/instructions/Instruction.py +56 -0
- edsl/surveys/instructions/InstructionCollection.py +82 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +19 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/tools/__init__.py +1 -0
- edsl/tools/clusters.py +192 -0
- edsl/tools/embeddings.py +27 -0
- edsl/tools/embeddings_plotting.py +118 -0
- edsl/tools/plotting.py +112 -0
- edsl/tools/summarize.py +18 -0
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/SystemInfo.py +5 -0
- edsl/utilities/__init__.py +21 -20
- edsl/utilities/ast_utilities.py +3 -0
- edsl/utilities/data/Registry.py +2 -0
- edsl/utilities/decorators.py +41 -0
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/interface.py +310 -60
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/naming_utilities.py +263 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/restricted_python.py +70 -0
- edsl/utilities/utilities.py +203 -13
- edsl-0.1.40.dist-info/METADATA +111 -0
- edsl-0.1.40.dist-info/RECORD +362 -0
- {edsl-0.1.15.dist-info → edsl-0.1.40.dist-info}/WHEEL +1 -1
- edsl/agents/AgentListExportMixin.py +0 -24
- edsl/coop/old.py +0 -31
- edsl/data/Database.py +0 -141
- edsl/data/crud.py +0 -121
- edsl/jobs/Interview.py +0 -435
- edsl/jobs/JobsRunner.py +0 -63
- edsl/jobs/JobsRunnerStatusMixin.py +0 -115
- edsl/jobs/base.py +0 -47
- edsl/jobs/buckets.py +0 -178
- edsl/jobs/runners/JobsRunnerDryRun.py +0 -19
- edsl/jobs/runners/JobsRunnerStreaming.py +0 -54
- edsl/jobs/task_management.py +0 -215
- edsl/jobs/token_tracking.py +0 -78
- edsl/language_models/DeepInfra.py +0 -69
- edsl/language_models/OpenAI.py +0 -98
- edsl/language_models/model_interfaces/GeminiPro.py +0 -66
- edsl/language_models/model_interfaces/LanguageModelOpenAIFour.py +0 -8
- edsl/language_models/model_interfaces/LanguageModelOpenAIThreeFiveTurbo.py +0 -8
- edsl/language_models/model_interfaces/LlamaTwo13B.py +0 -21
- edsl/language_models/model_interfaces/LlamaTwo70B.py +0 -21
- edsl/language_models/model_interfaces/Mixtral8x7B.py +0 -24
- edsl/language_models/registry.py +0 -81
- edsl/language_models/schemas.py +0 -15
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/prompts/QuestionInstructionsBase.py +0 -6
- edsl/prompts/library/agent_instructions.py +0 -29
- edsl/prompts/library/agent_persona.py +0 -17
- edsl/prompts/library/question_budget.py +0 -26
- edsl/prompts/library/question_checkbox.py +0 -32
- edsl/prompts/library/question_extract.py +0 -19
- edsl/prompts/library/question_freetext.py +0 -14
- edsl/prompts/library/question_linear_scale.py +0 -20
- edsl/prompts/library/question_list.py +0 -22
- edsl/prompts/library/question_multiple_choice.py +0 -44
- edsl/prompts/library/question_numerical.py +0 -31
- edsl/prompts/library/question_rank.py +0 -21
- edsl/prompts/prompt_config.py +0 -33
- edsl/prompts/registry.py +0 -185
- edsl/questions/Question.py +0 -240
- edsl/report/InputOutputDataTypes.py +0 -134
- edsl/report/RegressionMixin.py +0 -28
- edsl/report/ReportOutputs.py +0 -1228
- edsl/report/ResultsFetchMixin.py +0 -106
- edsl/report/ResultsOutputMixin.py +0 -14
- edsl/report/demo.ipynb +0 -645
- edsl/results/ResultsDBMixin.py +0 -184
- edsl/surveys/SurveyFlowVisualizationMixin.py +0 -92
- edsl/trackers/Tracker.py +0 -91
- edsl/trackers/TrackerAPI.py +0 -196
- edsl/trackers/TrackerTasks.py +0 -70
- edsl/utilities/pastebin.py +0 -141
- edsl-0.1.15.dist-info/METADATA +0 -69
- edsl-0.1.15.dist-info/RECORD +0 -142
- /edsl/{language_models/model_interfaces → inference_services}/__init__.py +0 -0
- /edsl/{report/__init__.py → jobs/runners/JobsRunnerStatusData.py} +0 -0
- /edsl/{trackers/__init__.py → language_models/ServiceDataSources.py} +0 -0
- {edsl-0.1.15.dist-info → edsl-0.1.40.dist-info}/LICENSE +0 -0
@@ -1,310 +1,307 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
import asyncio
|
5
|
-
import json
|
6
|
-
import time
|
7
|
-
import inspect
|
8
|
-
from typing import Coroutine
|
9
|
-
from abc import ABC, abstractmethod, ABCMeta
|
10
|
-
from rich.console import Console
|
11
|
-
from rich.table import Table
|
1
|
+
"""This module contains the LanguageModel class, which is an abstract base class for all language models.
|
2
|
+
|
3
|
+
Terminology:
|
12
4
|
|
5
|
+
raw_response: The JSON response from the model. This has all the model meta-data about the call.
|
13
6
|
|
14
|
-
from
|
15
|
-
|
16
|
-
from typing import Any, Callable, Type, List
|
17
|
-
from edsl.data import CRUDOperations, CRUD
|
18
|
-
from edsl.exceptions import LanguageModelResponseNotJSONError
|
19
|
-
from edsl.language_models.schemas import model_prices
|
20
|
-
from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
|
7
|
+
edsl_augmented_response: The JSON response from model, but augmented with EDSL-specific information,
|
8
|
+
such as the cache key, token usage, etc.
|
21
9
|
|
22
|
-
|
23
|
-
from
|
10
|
+
generated_tokens: The actual tokens generated by the model. This is the output that is used by the user.
|
11
|
+
edsl_answer_dict: The parsed JSON response from the model either {'answer': ...} or {'answer': ..., 'comment': ...}
|
24
12
|
|
25
|
-
|
26
|
-
from edsl.enums import LanguageModelType, InferenceServiceType
|
13
|
+
"""
|
27
14
|
|
28
|
-
from
|
15
|
+
from __future__ import annotations
|
16
|
+
import warnings
|
17
|
+
from functools import wraps
|
18
|
+
import asyncio
|
19
|
+
import json
|
20
|
+
import os
|
21
|
+
from typing import (
|
22
|
+
Coroutine,
|
23
|
+
Any,
|
24
|
+
Type,
|
25
|
+
Union,
|
26
|
+
List,
|
27
|
+
get_type_hints,
|
28
|
+
TypedDict,
|
29
|
+
Optional,
|
30
|
+
TYPE_CHECKING,
|
31
|
+
)
|
32
|
+
from abc import ABC, abstractmethod
|
33
|
+
|
34
|
+
from edsl.data_transfer_models import (
|
35
|
+
ModelResponse,
|
36
|
+
ModelInputs,
|
37
|
+
EDSLOutput,
|
38
|
+
AgentResponseDict,
|
39
|
+
)
|
40
|
+
|
41
|
+
if TYPE_CHECKING:
|
42
|
+
from edsl.data.Cache import Cache
|
43
|
+
from edsl.scenarios.FileStore import FileStore
|
44
|
+
from edsl.questions.QuestionBase import QuestionBase
|
45
|
+
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
46
|
+
|
47
|
+
from edsl.enums import InferenceServiceType
|
48
|
+
|
49
|
+
from edsl.utilities.decorators import (
|
50
|
+
sync_wrapper,
|
51
|
+
jupyter_nb_handler,
|
52
|
+
)
|
53
|
+
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
54
|
+
|
55
|
+
from edsl.Base import PersistenceMixin, RepresentationMixin
|
56
|
+
from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
|
57
|
+
|
58
|
+
from edsl.language_models.key_management.KeyLookupCollection import (
|
59
|
+
KeyLookupCollection,
|
60
|
+
)
|
61
|
+
|
62
|
+
from edsl.language_models.RawResponseHandler import RawResponseHandler
|
29
63
|
|
30
64
|
|
31
65
|
def handle_key_error(func):
|
66
|
+
"""Handle KeyError exceptions."""
|
67
|
+
|
32
68
|
@wraps(func)
|
33
69
|
def wrapper(*args, **kwargs):
|
34
70
|
try:
|
35
71
|
return func(*args, **kwargs)
|
36
72
|
assert True == False
|
37
73
|
except KeyError as e:
|
38
|
-
# Handle the KeyError exception
|
39
74
|
return f"""KeyError occurred: {e}. This is most likely because the model you are using
|
40
75
|
returned a JSON object we were not expecting."""
|
41
76
|
|
42
77
|
return wrapper
|
43
78
|
|
44
79
|
|
45
|
-
class
|
46
|
-
|
47
|
-
|
48
|
-
REQUIRED_CLASS_ATTRIBUTES = ["_model_", "_parameters_", "_inference_service_"]
|
80
|
+
class classproperty:
|
81
|
+
def __init__(self, method):
|
82
|
+
self.method = method
|
49
83
|
|
50
|
-
def
|
51
|
-
|
52
|
-
# if name != "LanguageModel":
|
53
|
-
if (model_name := getattr(cls, "_model_", None)) is not None:
|
54
|
-
RegisterLanguageModelsMeta.check_required_class_variables(
|
55
|
-
cls, RegisterLanguageModelsMeta.REQUIRED_CLASS_ATTRIBUTES
|
56
|
-
)
|
84
|
+
def __get__(self, instance, cls):
|
85
|
+
return self.method(cls)
|
57
86
|
|
58
|
-
## Check that model name is valid
|
59
|
-
if not LanguageModelType.is_value_valid(model_name):
|
60
|
-
acceptable_values = [item.value for item in LanguageModelType]
|
61
|
-
raise LanguageModelAttributeTypeError(
|
62
|
-
f"""A LanguageModel's model must be one of {LanguageModelType} values, which are
|
63
|
-
{acceptable_values}. You passed {model_name}."""
|
64
|
-
)
|
65
87
|
|
66
|
-
|
67
|
-
inference_service := getattr(cls, "_inference_service_", None)
|
68
|
-
):
|
69
|
-
acceptable_values = [item.value for item in InferenceServiceType]
|
70
|
-
raise LanguageModelAttributeTypeError(
|
71
|
-
f"""A LanguageModel's model must have an _inference_service_ value from
|
72
|
-
{acceptable_values}. You passed {inference_service}."""
|
73
|
-
)
|
88
|
+
from edsl.Base import HashingMixin
|
74
89
|
|
75
|
-
# LanguageModel children have to implement the async_execute_model_call method
|
76
|
-
RegisterLanguageModelsMeta.verify_method(
|
77
|
-
candidate_class=cls,
|
78
|
-
method_name="async_execute_model_call",
|
79
|
-
expected_return_type=dict[str, Any],
|
80
|
-
required_parameters=[("user_prompt", str), ("system_prompt", str)],
|
81
|
-
must_be_async=True,
|
82
|
-
)
|
83
|
-
# LanguageModel children have to implement the parse_response method
|
84
|
-
RegisterLanguageModelsMeta.verify_method(
|
85
|
-
candidate_class=cls,
|
86
|
-
method_name="parse_response",
|
87
|
-
expected_return_type=str,
|
88
|
-
required_parameters=[("raw_response", dict[str, Any])],
|
89
|
-
must_be_async=False,
|
90
|
-
)
|
91
|
-
RegisterLanguageModelsMeta._registry[model_name] = cls
|
92
90
|
|
93
|
-
|
94
|
-
|
95
|
-
|
91
|
+
class LanguageModel(
|
92
|
+
PersistenceMixin,
|
93
|
+
RepresentationMixin,
|
94
|
+
HashingMixin,
|
95
|
+
ABC,
|
96
|
+
metaclass=RegisterLanguageModelsMeta,
|
97
|
+
):
|
98
|
+
"""ABC for Language Models."""
|
96
99
|
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
+
_model_ = None
|
101
|
+
key_sequence = (
|
102
|
+
None # This should be something like ["choices", 0, "message", "content"]
|
103
|
+
)
|
104
|
+
|
105
|
+
DEFAULT_RPM = 100
|
106
|
+
DEFAULT_TPM = 1000
|
107
|
+
|
108
|
+
@classproperty
|
109
|
+
def response_handler(cls):
|
110
|
+
key_sequence = cls.key_sequence
|
111
|
+
usage_sequence = cls.usage_sequence if hasattr(cls, "usage_sequence") else None
|
112
|
+
return RawResponseHandler(key_sequence, usage_sequence)
|
113
|
+
|
114
|
+
def __init__(
|
115
|
+
self,
|
116
|
+
tpm: Optional[float] = None,
|
117
|
+
rpm: Optional[float] = None,
|
118
|
+
omit_system_prompt_if_empty_string: bool = True,
|
119
|
+
key_lookup: Optional["KeyLookup"] = None,
|
120
|
+
**kwargs,
|
100
121
|
):
|
101
|
-
"""
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
>>> RegisterLanguageModelsMeta.check_required_class_variables(M2, ["_model_", "_parameters_"])
|
109
|
-
Traceback (most recent call last):
|
110
|
-
...
|
111
|
-
Exception: Class M2 does not have required attribute _parameters_
|
112
|
-
"""
|
113
|
-
required_attributes = required_attributes or []
|
114
|
-
for attribute in required_attributes:
|
115
|
-
if not hasattr(candidate_class, attribute):
|
116
|
-
raise Exception(
|
117
|
-
f"Class {candidate_class.__name__} does not have required attribute {attribute}"
|
118
|
-
)
|
122
|
+
"""Initialize the LanguageModel."""
|
123
|
+
self.model = getattr(self, "_model_", None)
|
124
|
+
default_parameters = getattr(self, "_parameters_", None)
|
125
|
+
parameters = self._overide_default_parameters(kwargs, default_parameters)
|
126
|
+
self.parameters = parameters
|
127
|
+
self.remote = False
|
128
|
+
self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
|
119
129
|
|
120
|
-
|
121
|
-
|
122
|
-
candidate_class: LanguageModel,
|
123
|
-
method_name: str,
|
124
|
-
expected_return_type: Any,
|
125
|
-
required_parameters: List[tuple[str, Any]] = None,
|
126
|
-
must_be_async: bool = False,
|
127
|
-
):
|
128
|
-
RegisterLanguageModelsMeta._check_method_defined(candidate_class, method_name)
|
130
|
+
self.key_lookup = self._set_key_lookup(key_lookup)
|
131
|
+
self.model_info = self.key_lookup.get(self._inference_service_)
|
129
132
|
|
130
|
-
|
131
|
-
|
132
|
-
signature = inspect.signature(method)
|
133
|
+
if rpm is not None:
|
134
|
+
self._rpm = rpm
|
133
135
|
|
134
|
-
|
136
|
+
if tpm is not None:
|
137
|
+
self._tpm = tpm
|
135
138
|
|
136
|
-
|
137
|
-
|
139
|
+
for key, value in parameters.items():
|
140
|
+
setattr(self, key, value)
|
138
141
|
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
RegisterLanguageModelsMeta._verify_parameter(
|
143
|
-
params, param_name, param_type, method_name
|
144
|
-
)
|
142
|
+
for key, value in kwargs.items():
|
143
|
+
if key not in parameters:
|
144
|
+
setattr(self, key, value)
|
145
145
|
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
>>> class M:
|
150
|
-
... def f(self): pass
|
151
|
-
>>> RegisterLanguageModelsMeta._check_method_defined(M, "f")
|
152
|
-
>>> RegisterLanguageModelsMeta._check_method_defined(M, "g")
|
153
|
-
Traceback (most recent call last):
|
154
|
-
...
|
155
|
-
NotImplementedError: g method must be implemented
|
156
|
-
"""
|
157
|
-
if not hasattr(cls, method_name):
|
158
|
-
raise NotImplementedError(f"{method_name} method must be implemented")
|
146
|
+
if kwargs.get("skip_api_key_check", False):
|
147
|
+
# Skip the API key check. Sometimes this is useful for testing.
|
148
|
+
self._api_token = None
|
159
149
|
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
TypeError: A LangugeModel class with method f must be an asynchronous method
|
169
|
-
"""
|
170
|
-
if not inspect.iscoroutinefunction(func):
|
171
|
-
raise TypeError(
|
172
|
-
f"A LangugeModel class with method {func.__name__} must be an asynchronous method"
|
173
|
-
)
|
150
|
+
def _set_key_lookup(self, key_lookup: "KeyLookup") -> "KeyLookup":
|
151
|
+
"""Set the key lookup."""
|
152
|
+
if key_lookup is not None:
|
153
|
+
return key_lookup
|
154
|
+
else:
|
155
|
+
klc = KeyLookupCollection()
|
156
|
+
klc.add_key_lookup(fetch_order=("config", "env"))
|
157
|
+
return klc.get(("config", "env"))
|
174
158
|
|
175
|
-
|
176
|
-
|
177
|
-
if
|
178
|
-
|
179
|
-
|
180
|
-
"""
|
181
|
-
)
|
182
|
-
if params[param_name].annotation != param_type:
|
183
|
-
raise TypeError(
|
184
|
-
f"""Parameter "{param_name}" of method "{method_name}" must be of type {param_type.__name__}.
|
185
|
-
Got {params[param_name].annotation} instead.
|
186
|
-
"""
|
187
|
-
)
|
159
|
+
def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
|
160
|
+
"""Set the key lookup, later"""
|
161
|
+
if hasattr(self, "_api_token"):
|
162
|
+
del self._api_token
|
163
|
+
self.key_lookup = key_lookup
|
188
164
|
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
>>> class M:
|
194
|
-
... async def f(self) -> str: pass
|
195
|
-
>>> RegisterLanguageModelsMeta._check_return_type(M.f, str)
|
196
|
-
>>> class N:
|
197
|
-
... async def f(self) -> int: pass
|
198
|
-
>>> RegisterLanguageModelsMeta._check_return_type(N.f, str)
|
199
|
-
Traceback (most recent call last):
|
200
|
-
...
|
201
|
-
TypeError: Return type of f must be <class 'str'>. Got <class 'int'>
|
165
|
+
def ask_question(self, question: "QuestionBase") -> str:
|
166
|
+
"""Ask a question and return the response.
|
167
|
+
|
168
|
+
:param question: The question to ask.
|
202
169
|
"""
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
if return_type != expected_return_type:
|
207
|
-
raise TypeError(
|
208
|
-
f"Return type of {method.__name__} must be {expected_return_type}. Got {return_type}."
|
209
|
-
)
|
170
|
+
user_prompt = question.get_instructions().render(question.data).text
|
171
|
+
system_prompt = "You are a helpful agent pretending to be a human."
|
172
|
+
return self.execute_model_call(user_prompt, system_prompt)
|
210
173
|
|
211
|
-
@
|
212
|
-
def
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
d[cls._model_] = cls
|
174
|
+
@property
|
175
|
+
def rpm(self):
|
176
|
+
if not hasattr(self, "_rpm"):
|
177
|
+
if self.model_info is None:
|
178
|
+
self._rpm = self.DEFAULT_RPM
|
217
179
|
else:
|
218
|
-
|
219
|
-
|
220
|
-
)
|
221
|
-
return d
|
180
|
+
self._rpm = self.model_info.rpm
|
181
|
+
return self._rpm
|
222
182
|
|
183
|
+
@property
|
184
|
+
def tpm(self):
|
185
|
+
if not hasattr(self, "_tpm"):
|
186
|
+
if self.model_info is None:
|
187
|
+
self._tpm = self.DEFAULT_TPM
|
188
|
+
else:
|
189
|
+
self._tpm = self.model_info.tpm
|
190
|
+
return self._tpm
|
223
191
|
|
224
|
-
|
225
|
-
|
226
|
-
):
|
227
|
-
|
192
|
+
# in case we want to override the default values
|
193
|
+
@tpm.setter
|
194
|
+
def tpm(self, value):
|
195
|
+
self._tpm = value
|
228
196
|
|
229
|
-
|
197
|
+
@rpm.setter
|
198
|
+
def rpm(self, value):
|
199
|
+
self._rpm = value
|
230
200
|
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
201
|
+
@property
|
202
|
+
def api_token(self) -> str:
|
203
|
+
if not hasattr(self, "_api_token"):
|
204
|
+
info = self.key_lookup.get(self._inference_service_, None)
|
205
|
+
if info is None:
|
206
|
+
raise ValueError(
|
207
|
+
f"No key found for service '{self._inference_service_}'"
|
208
|
+
)
|
209
|
+
self._api_token = info.api_token
|
210
|
+
return self._api_token
|
211
|
+
|
212
|
+
def __getitem__(self, key):
|
213
|
+
return getattr(self, key)
|
214
|
+
|
215
|
+
def hello(self, verbose=False):
|
216
|
+
"""Runs a simple test to check if the model is working."""
|
217
|
+
token = self.api_token
|
218
|
+
masked = token[: min(8, len(token))] + "..."
|
219
|
+
if verbose:
|
220
|
+
print(f"Current key is {masked}")
|
221
|
+
return self.execute_model_call(
|
222
|
+
user_prompt="Hello, model!", system_prompt="You are a helpful agent."
|
223
|
+
)
|
235
224
|
|
236
|
-
def
|
237
|
-
"""
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
225
|
+
def has_valid_api_key(self) -> bool:
|
226
|
+
"""Check if the model has a valid API key.
|
227
|
+
|
228
|
+
>>> LanguageModel.example().has_valid_api_key() # doctest: +SKIP
|
229
|
+
True
|
230
|
+
|
231
|
+
This method is used to check if the model has a valid API key.
|
242
232
|
"""
|
243
|
-
|
244
|
-
default_parameters = getattr(self, "_parameters_", None)
|
245
|
-
parameters = self._overide_default_parameters(kwargs, default_parameters)
|
246
|
-
self.parameters = parameters
|
233
|
+
from edsl.enums import service_to_api_keyname
|
247
234
|
|
248
|
-
|
249
|
-
|
235
|
+
if self._model_ == "test":
|
236
|
+
return True
|
250
237
|
|
251
|
-
|
252
|
-
|
253
|
-
|
238
|
+
key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
|
239
|
+
key_value = os.getenv(key_name)
|
240
|
+
return key_value is not None
|
254
241
|
|
255
|
-
|
256
|
-
|
257
|
-
self.crud = crud
|
242
|
+
def __hash__(self) -> str:
|
243
|
+
"""Allow the model to be used as a key in a dictionary.
|
258
244
|
|
259
|
-
|
260
|
-
|
261
|
-
|
245
|
+
>>> m = LanguageModel.example()
|
246
|
+
>>> hash(m)
|
247
|
+
1811901442659237949
|
248
|
+
"""
|
249
|
+
from edsl.utilities.utilities import dict_hash
|
262
250
|
|
263
|
-
|
264
|
-
return self.model == other.model and self.parameters == other.parameters
|
251
|
+
return dict_hash(self.to_dict(add_edsl_version=False))
|
265
252
|
|
266
|
-
def
|
267
|
-
|
268
|
-
if hasattr(self, "get_rate_limits"):
|
269
|
-
self.__rate_limits = self.get_rate_limits()
|
270
|
-
else:
|
271
|
-
self.__rate_limits = self.__default_rate_limits
|
253
|
+
def __eq__(self, other) -> bool:
|
254
|
+
"""Check if two models are the same.
|
272
255
|
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
return self._safety_factor * self.__rate_limits["rpm"]
|
256
|
+
>>> m1 = LanguageModel.example()
|
257
|
+
>>> m2 = LanguageModel.example()
|
258
|
+
>>> m1 == m2
|
259
|
+
True
|
278
260
|
|
279
|
-
|
280
|
-
|
281
|
-
"Model's tokens-per-minute limit"
|
282
|
-
self._set_rate_limits()
|
283
|
-
return self._safety_factor * self.__rate_limits["tpm"]
|
261
|
+
"""
|
262
|
+
return self.model == other.model and self.parameters == other.parameters
|
284
263
|
|
285
264
|
@staticmethod
|
286
265
|
def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
|
287
|
-
"""
|
266
|
+
"""Return a dictionary of parameters, with passed parameters taking precedence over defaults.
|
288
267
|
|
289
268
|
>>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9})
|
290
269
|
{'temperature': 0.5}
|
291
270
|
>>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
|
292
271
|
{'temperature': 0.5, 'max_tokens': 1000}
|
293
272
|
"""
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
273
|
+
# this is the case when data is loaded from a dict after serialization
|
274
|
+
if "parameters" in passed_parameter_dict:
|
275
|
+
passed_parameter_dict = passed_parameter_dict["parameters"]
|
276
|
+
return {
|
277
|
+
parameter_name: passed_parameter_dict.get(parameter_name, default_value)
|
278
|
+
for parameter_name, default_value in default_parameter_dict.items()
|
279
|
+
}
|
280
|
+
|
281
|
+
def __call__(self, user_prompt: str, system_prompt: str):
|
282
|
+
return self.execute_model_call(user_prompt, system_prompt)
|
301
283
|
|
302
284
|
@abstractmethod
|
303
|
-
async def async_execute_model_call():
|
285
|
+
async def async_execute_model_call(user_prompt: str, system_prompt: str):
|
286
|
+
"""Execute the model call and return a coroutine."""
|
304
287
|
pass
|
305
288
|
|
289
|
+
async def remote_async_execute_model_call(
|
290
|
+
self, user_prompt: str, system_prompt: str
|
291
|
+
):
|
292
|
+
"""Execute the model call and return the result as a coroutine, using Coop."""
|
293
|
+
from edsl.coop import Coop
|
294
|
+
|
295
|
+
client = Coop()
|
296
|
+
response_data = await client.remote_async_execute_model_call(
|
297
|
+
self.to_dict(), user_prompt, system_prompt
|
298
|
+
)
|
299
|
+
return response_data
|
300
|
+
|
306
301
|
@jupyter_nb_handler
|
307
302
|
def execute_model_call(self, *args, **kwargs) -> Coroutine:
|
303
|
+
"""Execute the model call and return the result as a coroutine."""
|
304
|
+
|
308
305
|
async def main():
|
309
306
|
results = await asyncio.gather(
|
310
307
|
self.async_execute_model_call(*args, **kwargs)
|
@@ -313,193 +310,317 @@ class LanguageModel(
|
|
313
310
|
|
314
311
|
return main()
|
315
312
|
|
316
|
-
@
|
317
|
-
def
|
318
|
-
"""
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
"choices": [
|
325
|
-
{
|
326
|
-
"finish_reason": "stop",
|
327
|
-
"index": 0,
|
328
|
-
"logprobs": None,
|
329
|
-
"message": {
|
330
|
-
"content": "Hello! How can I assist you today? If you have any questions or need information on a particular topic, feel free to ask.",
|
331
|
-
"role": "assistant",
|
332
|
-
"function_call": None,
|
333
|
-
"tool_calls": None,
|
334
|
-
},
|
335
|
-
}
|
336
|
-
],
|
337
|
-
"created": 1704637774,
|
338
|
-
"model": "gpt-4-1106-preview",
|
339
|
-
"object": "chat.completion",
|
340
|
-
"system_fingerprint": "fp_168383a679",
|
341
|
-
"usage": {"completion_tokens": 27, "prompt_tokens": 13, "total_tokens": 40},
|
342
|
-
}
|
313
|
+
@classmethod
|
314
|
+
def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
|
315
|
+
"""Return the generated token string from the raw response.
|
316
|
+
|
317
|
+
>>> m = LanguageModel.example(test_model = True)
|
318
|
+
>>> raw_response = m.execute_model_call("Hello, model!", "You are a helpful agent.")
|
319
|
+
>>> m.get_generated_token_string(raw_response)
|
320
|
+
'Hello world'
|
343
321
|
|
344
|
-
To actually tract the response, we need to grab
|
345
|
-
data["choices[0]"]["message"]["content"].
|
346
322
|
"""
|
347
|
-
|
323
|
+
return cls.response_handler.get_generated_token_string(raw_response)
|
348
324
|
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
return
|
358
|
-
|
359
|
-
async def
|
360
|
-
self,
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
325
|
+
@classmethod
|
326
|
+
def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
|
327
|
+
"""Return the usage dictionary from the raw response."""
|
328
|
+
return cls.response_handler.get_usage_dict(raw_response)
|
329
|
+
|
330
|
+
@classmethod
|
331
|
+
def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
|
332
|
+
"""Parse the API response and return it as an EDSLOutput."""
|
333
|
+
return cls.response_handler.parse_response(raw_response)
|
334
|
+
|
335
|
+
async def _async_get_intended_model_call_outcome(
|
336
|
+
self,
|
337
|
+
user_prompt: str,
|
338
|
+
system_prompt: str,
|
339
|
+
cache: Cache,
|
340
|
+
iteration: int = 0,
|
341
|
+
files_list: Optional[List[FileStore]] = None,
|
342
|
+
invigilator=None,
|
343
|
+
) -> ModelResponse:
|
344
|
+
"""Handle caching of responses.
|
345
|
+
|
346
|
+
:param user_prompt: The user's prompt.
|
347
|
+
:param system_prompt: The system's prompt.
|
348
|
+
:param iteration: The iteration number.
|
349
|
+
:param cache: The cache to use.
|
350
|
+
:param files_list: The list of files to use.
|
351
|
+
:param invigilator: The invigilator to use.
|
352
|
+
|
353
|
+
If the cache isn't being used, it just returns a 'fresh' call to the LLM.
|
365
354
|
But if cache is being used, it first checks the database to see if the response is already there.
|
366
355
|
If it is, it returns the cached response, but again appends some tracking information.
|
367
356
|
If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
|
368
357
|
|
369
358
|
If self.use_cache is True, then attempts to retrieve the response from the database;
|
370
|
-
if not in the DB, calls the LLM and writes the response to the DB.
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
cached_response = self.crud.get_LLMOutputData(
|
378
|
-
model=str(self.model),
|
379
|
-
parameters=str(self.parameters),
|
380
|
-
system_prompt=system_prompt,
|
381
|
-
prompt=user_prompt,
|
382
|
-
)
|
359
|
+
if not in the DB, calls the LLM and writes the response to the DB.
|
360
|
+
|
361
|
+
>>> from edsl import Cache
|
362
|
+
>>> m = LanguageModel.example(test_model = True)
|
363
|
+
>>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
|
364
|
+
ModelResponse(...)"""
|
383
365
|
|
384
|
-
if
|
366
|
+
if files_list:
|
367
|
+
files_hash = "+".join([str(hash(file)) for file in files_list])
|
368
|
+
user_prompt_with_hashes = user_prompt + f" {files_hash}"
|
369
|
+
else:
|
370
|
+
user_prompt_with_hashes = user_prompt
|
371
|
+
|
372
|
+
cache_call_params = {
|
373
|
+
"model": str(self.model),
|
374
|
+
"parameters": self.parameters,
|
375
|
+
"system_prompt": system_prompt,
|
376
|
+
"user_prompt": user_prompt_with_hashes,
|
377
|
+
"iteration": iteration,
|
378
|
+
}
|
379
|
+
cached_response, cache_key = cache.fetch(**cache_call_params)
|
380
|
+
|
381
|
+
if cache_used := cached_response is not None:
|
385
382
|
response = json.loads(cached_response)
|
386
|
-
cache_used = True
|
387
383
|
else:
|
388
|
-
|
389
|
-
|
390
|
-
|
384
|
+
f = (
|
385
|
+
self.remote_async_execute_model_call
|
386
|
+
if hasattr(self, "remote") and self.remote
|
387
|
+
else self.async_execute_model_call
|
388
|
+
)
|
389
|
+
params = {
|
390
|
+
"user_prompt": user_prompt,
|
391
|
+
"system_prompt": system_prompt,
|
392
|
+
"files_list": files_list,
|
393
|
+
}
|
394
|
+
from edsl.config import CONFIG
|
395
|
+
|
396
|
+
TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
|
397
|
+
|
398
|
+
response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
|
399
|
+
new_cache_key = cache.store(
|
400
|
+
**cache_call_params, response=response
|
401
|
+
) # store the response in the cache
|
402
|
+
assert new_cache_key == cache_key # should be the same
|
403
|
+
|
404
|
+
cost = self.cost(response)
|
405
|
+
return ModelResponse(
|
406
|
+
response=response,
|
407
|
+
cache_used=cache_used,
|
408
|
+
cache_key=cache_key,
|
409
|
+
cached_response=cached_response,
|
410
|
+
cost=cost,
|
411
|
+
)
|
391
412
|
|
392
|
-
|
413
|
+
_get_intended_model_call_outcome = sync_wrapper(
|
414
|
+
_async_get_intended_model_call_outcome
|
415
|
+
)
|
393
416
|
|
394
|
-
|
417
|
+
def simple_ask(
|
418
|
+
self,
|
419
|
+
question: QuestionBase,
|
420
|
+
system_prompt="You are a helpful agent pretending to be a human.",
|
421
|
+
top_logprobs=2,
|
422
|
+
):
|
423
|
+
"""Ask a question and return the response."""
|
424
|
+
self.logprobs = True
|
425
|
+
self.top_logprobs = top_logprobs
|
426
|
+
return self.execute_model_call(
|
427
|
+
user_prompt=question.human_readable(), system_prompt=system_prompt
|
428
|
+
)
|
395
429
|
|
396
|
-
def
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
430
|
+
async def async_get_response(
|
431
|
+
self,
|
432
|
+
user_prompt: str,
|
433
|
+
system_prompt: str,
|
434
|
+
cache: Cache,
|
435
|
+
iteration: int = 1,
|
436
|
+
files_list: Optional[List[FileStore]] = None,
|
437
|
+
**kwargs,
|
438
|
+
) -> dict:
|
439
|
+
"""Get the model response, parse it, and return an AgentResponseDict.
|
440
|
+
|
441
|
+
:param user_prompt: The user's prompt.
|
442
|
+
:param system_prompt: The system's prompt.
|
443
|
+
:param cache: The cache to use.
|
444
|
+
:param iteration: The iteration number.
|
445
|
+
:param files_list: The list of files to use.
|
446
|
+
|
447
|
+
"""
|
448
|
+
params = {
|
449
|
+
"user_prompt": user_prompt,
|
450
|
+
"system_prompt": system_prompt,
|
451
|
+
"iteration": iteration,
|
452
|
+
"cache": cache,
|
453
|
+
"files_list": files_list,
|
454
|
+
}
|
455
|
+
if "invigilator" in kwargs:
|
456
|
+
params.update({"invigilator": kwargs["invigilator"]})
|
457
|
+
|
458
|
+
model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
|
459
|
+
model_outputs: ModelResponse = (
|
460
|
+
await self._async_get_intended_model_call_outcome(**params)
|
407
461
|
)
|
462
|
+
edsl_dict: EDSLOutput = self.parse_response(model_outputs.response)
|
408
463
|
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
except json.JSONDecodeError as e:
|
416
|
-
# TODO: Turn into logs to generate issues
|
417
|
-
dict_response, success = await repair(response, str(e))
|
418
|
-
if not success:
|
419
|
-
raise Exception("Even the repair failed.")
|
420
|
-
|
421
|
-
dict_response["cached_response"] = raw_response["cached_response"]
|
422
|
-
dict_response["usage"] = raw_response.get("usage", {})
|
423
|
-
dict_response["raw_model_response"] = raw_response
|
424
|
-
return dict_response
|
464
|
+
agent_response_dict = AgentResponseDict(
|
465
|
+
model_inputs=model_inputs,
|
466
|
+
model_outputs=model_outputs,
|
467
|
+
edsl_dict=edsl_dict,
|
468
|
+
)
|
469
|
+
return agent_response_dict
|
425
470
|
|
426
471
|
get_response = sync_wrapper(async_get_response)
|
427
472
|
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
"""
|
433
|
-
|
434
|
-
usage.
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
473
|
+
def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
|
474
|
+
"""Return the dollar cost of a raw response.
|
475
|
+
|
476
|
+
:param raw_response: The raw response from the model.
|
477
|
+
"""
|
478
|
+
|
479
|
+
usage = self.get_usage_dict(raw_response)
|
480
|
+
from edsl.language_models.PriceManager import PriceManager
|
481
|
+
|
482
|
+
price_manger = PriceManager()
|
483
|
+
return price_manger.calculate_cost(
|
484
|
+
inference_service=self._inference_service_,
|
485
|
+
model=self.model,
|
486
|
+
usage=usage,
|
487
|
+
input_token_name=self.input_token_name,
|
488
|
+
output_token_name=self.output_token_name,
|
440
489
|
)
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
490
|
+
|
491
|
+
def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
|
492
|
+
"""Convert instance to a dictionary
|
493
|
+
|
494
|
+
:param add_edsl_version: Whether to add the EDSL version to the dictionary.
|
495
|
+
|
496
|
+
>>> m = LanguageModel.example()
|
497
|
+
>>> m.to_dict()
|
498
|
+
{'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
|
499
|
+
"""
|
500
|
+
d = {
|
501
|
+
"model": self.model,
|
502
|
+
"parameters": self.parameters,
|
503
|
+
}
|
504
|
+
if add_edsl_version:
|
505
|
+
from edsl import __version__
|
506
|
+
|
507
|
+
d["edsl_version"] = __version__
|
508
|
+
d["edsl_class_name"] = self.__class__.__name__
|
509
|
+
return d
|
456
510
|
|
457
511
|
@classmethod
|
512
|
+
@remove_edsl_version
|
458
513
|
def from_dict(cls, data: dict) -> Type[LanguageModel]:
|
459
|
-
"""
|
460
|
-
from edsl.language_models.
|
514
|
+
"""Convert dictionary to a LanguageModel child instance."""
|
515
|
+
from edsl.language_models.model import get_model_class
|
461
516
|
|
462
517
|
model_class = get_model_class(data["model"])
|
463
|
-
data["use_cache"] = True
|
464
518
|
return model_class(**data)
|
465
519
|
|
466
|
-
#######################
|
467
|
-
# DUNDER METHODS
|
468
|
-
#######################
|
469
520
|
def __repr__(self) -> str:
|
470
|
-
|
521
|
+
"""Return a representation of the object."""
|
522
|
+
param_string = ", ".join(
|
523
|
+
f"{key} = {value}" for key, value in self.parameters.items()
|
524
|
+
)
|
525
|
+
return (
|
526
|
+
f"Model(model_name = '{self.model}'"
|
527
|
+
+ (f", {param_string}" if param_string else "")
|
528
|
+
+ ")"
|
529
|
+
)
|
471
530
|
|
472
531
|
def __add__(self, other_model: Type[LanguageModel]) -> Type[LanguageModel]:
|
473
|
-
"""Combine two models into a single model (other_model takes precedence over self)"""
|
474
|
-
|
532
|
+
"""Combine two models into a single model (other_model takes precedence over self)."""
|
533
|
+
import warnings
|
534
|
+
|
535
|
+
warnings.warn(
|
475
536
|
f"""Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
|
476
537
|
by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
|
477
538
|
)
|
478
539
|
return other_model or self
|
479
540
|
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
541
|
+
@classmethod
|
542
|
+
def example(
|
543
|
+
cls,
|
544
|
+
test_model: bool = False,
|
545
|
+
canned_response: str = "Hello world",
|
546
|
+
throw_exception: bool = False,
|
547
|
+
) -> LanguageModel:
|
548
|
+
"""Return a default instance of the class.
|
549
|
+
|
550
|
+
>>> from edsl.language_models import LanguageModel
|
551
|
+
>>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
|
552
|
+
>>> isinstance(m, LanguageModel)
|
553
|
+
True
|
554
|
+
>>> from edsl import QuestionFreeText
|
555
|
+
>>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
|
556
|
+
>>> q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True).select('example').first()
|
557
|
+
'WOWZA!'
|
558
|
+
>>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!", throw_exception = True)
|
559
|
+
>>> r = q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True, print_exceptions = True)
|
560
|
+
Exception report saved to ...
|
561
|
+
Also see: ...
|
562
|
+
"""
|
563
|
+
from edsl.language_models.model import Model
|
485
564
|
|
486
|
-
|
487
|
-
|
488
|
-
|
565
|
+
if test_model:
|
566
|
+
m = Model(
|
567
|
+
"test", canned_response=canned_response, throw_exception=throw_exception
|
568
|
+
)
|
569
|
+
return m
|
570
|
+
else:
|
571
|
+
return Model(skip_api_key_check=True)
|
489
572
|
|
490
|
-
|
573
|
+
def from_cache(self, cache: "Cache") -> LanguageModel:
|
491
574
|
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
575
|
+
from copy import deepcopy
|
576
|
+
from types import MethodType
|
577
|
+
from edsl import Cache
|
578
|
+
|
579
|
+
new_instance = deepcopy(self)
|
580
|
+
print("Cache entries", len(cache))
|
581
|
+
new_instance.cache = Cache(
|
582
|
+
data={k: v for k, v in cache.items() if v.model == self.model}
|
583
|
+
)
|
584
|
+
print("Cache entries with same model", len(new_instance.cache))
|
585
|
+
|
586
|
+
new_instance.user_prompts = [
|
587
|
+
ce.user_prompt for ce in new_instance.cache.values()
|
588
|
+
]
|
589
|
+
new_instance.system_prompts = [
|
590
|
+
ce.system_prompt for ce in new_instance.cache.values()
|
591
|
+
]
|
592
|
+
|
593
|
+
async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
|
594
|
+
cache_call_params = {
|
595
|
+
"model": str(self.model),
|
596
|
+
"parameters": self.parameters,
|
597
|
+
"system_prompt": system_prompt,
|
598
|
+
"user_prompt": user_prompt,
|
599
|
+
"iteration": 1,
|
600
|
+
}
|
601
|
+
cached_response, cache_key = cache.fetch(**cache_call_params)
|
602
|
+
response = json.loads(cached_response)
|
603
|
+
cost = 0
|
604
|
+
return ModelResponse(
|
605
|
+
response=response,
|
606
|
+
cache_used=True,
|
607
|
+
cache_key=cache_key,
|
608
|
+
cached_response=cached_response,
|
609
|
+
cost=cost,
|
610
|
+
)
|
611
|
+
|
612
|
+
# Bind the new method to the copied instance
|
613
|
+
setattr(
|
614
|
+
new_instance,
|
615
|
+
"async_execute_model_call",
|
616
|
+
MethodType(async_execute_model_call, new_instance),
|
617
|
+
)
|
496
618
|
|
497
|
-
return
|
619
|
+
return new_instance
|
498
620
|
|
499
621
|
|
500
622
|
if __name__ == "__main__":
|
501
|
-
|
502
|
-
|
503
|
-
from edsl.language_models import LanguageModel
|
623
|
+
"""Run the module's test suite."""
|
624
|
+
import doctest
|
504
625
|
|
505
|
-
|
626
|
+
doctest.testmod(optionflags=doctest.ELLIPSIS)
|