edsl 0.1.46__py3-none-any.whl → 0.1.48__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- edsl/__init__.py +44 -39
- edsl/__version__.py +1 -1
- edsl/agents/__init__.py +4 -2
- edsl/agents/{Agent.py → agent.py} +442 -152
- edsl/agents/{AgentList.py → agent_list.py} +220 -162
- edsl/agents/descriptors.py +46 -7
- edsl/{exceptions/agents.py → agents/exceptions.py} +3 -12
- edsl/base/__init__.py +75 -0
- edsl/base/base_class.py +1303 -0
- edsl/base/data_transfer_models.py +114 -0
- edsl/base/enums.py +215 -0
- edsl/base.py +8 -0
- edsl/buckets/__init__.py +25 -0
- edsl/buckets/bucket_collection.py +324 -0
- edsl/buckets/model_buckets.py +206 -0
- edsl/buckets/token_bucket.py +502 -0
- edsl/{jobs/buckets/TokenBucketAPI.py → buckets/token_bucket_api.py} +1 -1
- edsl/buckets/token_bucket_client.py +509 -0
- edsl/caching/__init__.py +20 -0
- edsl/caching/cache.py +814 -0
- edsl/caching/cache_entry.py +427 -0
- edsl/{data/CacheHandler.py → caching/cache_handler.py} +14 -15
- edsl/caching/exceptions.py +24 -0
- edsl/caching/orm.py +30 -0
- edsl/{data/RemoteCacheSync.py → caching/remote_cache_sync.py} +3 -3
- edsl/caching/sql_dict.py +441 -0
- edsl/config/__init__.py +8 -0
- edsl/config/config_class.py +177 -0
- edsl/config.py +4 -176
- edsl/conversation/Conversation.py +7 -7
- edsl/conversation/car_buying.py +4 -4
- edsl/conversation/chips.py +6 -6
- edsl/coop/__init__.py +25 -2
- edsl/coop/coop.py +430 -113
- edsl/coop/{ExpectedParrotKeyHandler.py → ep_key_handling.py} +86 -10
- edsl/coop/exceptions.py +62 -0
- edsl/coop/price_fetcher.py +126 -0
- edsl/coop/utils.py +89 -24
- edsl/data_transfer_models.py +5 -72
- edsl/dataset/__init__.py +10 -0
- edsl/{results/Dataset.py → dataset/dataset.py} +116 -36
- edsl/dataset/dataset_operations_mixin.py +1492 -0
- edsl/{results/DatasetTree.py → dataset/dataset_tree.py} +156 -75
- edsl/{results/TableDisplay.py → dataset/display/table_display.py} +18 -7
- edsl/{results → dataset/display}/table_renderers.py +58 -2
- edsl/{results → dataset}/file_exports.py +4 -5
- edsl/{results → dataset}/smart_objects.py +2 -2
- edsl/enums.py +5 -205
- edsl/inference_services/__init__.py +5 -0
- edsl/inference_services/{AvailableModelCacheHandler.py → available_model_cache_handler.py} +2 -3
- edsl/inference_services/{AvailableModelFetcher.py → available_model_fetcher.py} +8 -14
- edsl/inference_services/data_structures.py +3 -2
- edsl/{exceptions/inference_services.py → inference_services/exceptions.py} +1 -1
- edsl/inference_services/{InferenceServiceABC.py → inference_service_abc.py} +1 -1
- edsl/inference_services/{InferenceServicesCollection.py → inference_services_collection.py} +8 -7
- edsl/inference_services/registry.py +4 -41
- edsl/inference_services/{ServiceAvailability.py → service_availability.py} +5 -25
- edsl/inference_services/services/__init__.py +31 -0
- edsl/inference_services/{AnthropicService.py → services/anthropic_service.py} +3 -3
- edsl/inference_services/{AwsBedrock.py → services/aws_bedrock.py} +2 -2
- edsl/inference_services/{AzureAI.py → services/azure_ai.py} +2 -2
- edsl/inference_services/{DeepInfraService.py → services/deep_infra_service.py} +1 -3
- edsl/inference_services/{DeepSeekService.py → services/deep_seek_service.py} +2 -4
- edsl/inference_services/{GoogleService.py → services/google_service.py} +5 -4
- edsl/inference_services/{GroqService.py → services/groq_service.py} +1 -1
- edsl/inference_services/{MistralAIService.py → services/mistral_ai_service.py} +3 -3
- edsl/inference_services/{OllamaService.py → services/ollama_service.py} +1 -7
- edsl/inference_services/{OpenAIService.py → services/open_ai_service.py} +5 -6
- edsl/inference_services/{PerplexityService.py → services/perplexity_service.py} +12 -12
- edsl/inference_services/{TestService.py → services/test_service.py} +7 -6
- edsl/inference_services/{TogetherAIService.py → services/together_ai_service.py} +2 -6
- edsl/inference_services/{XAIService.py → services/xai_service.py} +1 -1
- edsl/inference_services/write_available.py +1 -2
- edsl/instructions/__init__.py +6 -0
- edsl/{surveys/instructions/Instruction.py → instructions/instruction.py} +11 -6
- edsl/{surveys/instructions/InstructionCollection.py → instructions/instruction_collection.py} +10 -5
- edsl/{surveys/InstructionHandler.py → instructions/instruction_handler.py} +3 -3
- edsl/{jobs/interviews → interviews}/ReportErrors.py +2 -2
- edsl/interviews/__init__.py +4 -0
- edsl/{jobs/AnswerQuestionFunctionConstructor.py → interviews/answering_function.py} +45 -18
- edsl/{jobs/interviews/InterviewExceptionEntry.py → interviews/exception_tracking.py} +107 -22
- edsl/interviews/interview.py +638 -0
- edsl/{jobs/interviews/InterviewStatusDictionary.py → interviews/interview_status_dictionary.py} +21 -12
- edsl/{jobs/interviews/InterviewStatusLog.py → interviews/interview_status_log.py} +16 -7
- edsl/{jobs/InterviewTaskManager.py → interviews/interview_task_manager.py} +12 -7
- edsl/{jobs/RequestTokenEstimator.py → interviews/request_token_estimator.py} +8 -3
- edsl/{jobs/interviews/InterviewStatistic.py → interviews/statistics.py} +36 -10
- edsl/invigilators/__init__.py +38 -0
- edsl/invigilators/invigilator_base.py +477 -0
- edsl/{agents/Invigilator.py → invigilators/invigilators.py} +263 -10
- edsl/invigilators/prompt_constructor.py +476 -0
- edsl/{agents → invigilators}/prompt_helpers.py +2 -1
- edsl/{agents/QuestionInstructionPromptBuilder.py → invigilators/question_instructions_prompt_builder.py} +18 -13
- edsl/{agents → invigilators}/question_option_processor.py +96 -21
- edsl/{agents/QuestionTemplateReplacementsBuilder.py → invigilators/question_template_replacements_builder.py} +64 -12
- edsl/jobs/__init__.py +7 -1
- edsl/jobs/async_interview_runner.py +99 -35
- edsl/jobs/check_survey_scenario_compatibility.py +7 -5
- edsl/jobs/data_structures.py +153 -22
- edsl/{exceptions/jobs.py → jobs/exceptions.py} +2 -1
- edsl/jobs/{FetchInvigilator.py → fetch_invigilator.py} +4 -4
- edsl/jobs/{loggers/HTMLTableJobLogger.py → html_table_job_logger.py} +6 -2
- edsl/jobs/{Jobs.py → jobs.py} +321 -155
- edsl/jobs/{JobsChecks.py → jobs_checks.py} +15 -7
- edsl/jobs/{JobsComponentConstructor.py → jobs_component_constructor.py} +20 -17
- edsl/jobs/{InterviewsConstructor.py → jobs_interview_constructor.py} +10 -5
- edsl/jobs/jobs_pricing_estimation.py +347 -0
- edsl/jobs/{JobsRemoteInferenceLogger.py → jobs_remote_inference_logger.py} +4 -3
- edsl/jobs/jobs_runner_asyncio.py +282 -0
- edsl/jobs/{JobsRemoteInferenceHandler.py → remote_inference.py} +19 -22
- edsl/jobs/results_exceptions_handler.py +2 -2
- edsl/key_management/__init__.py +28 -0
- edsl/key_management/key_lookup.py +161 -0
- edsl/{language_models/key_management/KeyLookupBuilder.py → key_management/key_lookup_builder.py} +118 -47
- edsl/key_management/key_lookup_collection.py +82 -0
- edsl/key_management/models.py +218 -0
- edsl/language_models/__init__.py +7 -2
- edsl/language_models/{ComputeCost.py → compute_cost.py} +18 -3
- edsl/{exceptions/language_models.py → language_models/exceptions.py} +2 -1
- edsl/language_models/language_model.py +1080 -0
- edsl/language_models/model.py +10 -25
- edsl/language_models/{ModelList.py → model_list.py} +9 -14
- edsl/language_models/{RawResponseHandler.py → raw_response_handler.py} +1 -1
- edsl/language_models/{RegisterLanguageModelsMeta.py → registry.py} +1 -1
- edsl/language_models/repair.py +4 -4
- edsl/language_models/utilities.py +4 -4
- edsl/notebooks/__init__.py +3 -1
- edsl/notebooks/{Notebook.py → notebook.py} +7 -8
- edsl/prompts/__init__.py +1 -1
- edsl/{exceptions/prompts.py → prompts/exceptions.py} +3 -1
- edsl/prompts/{Prompt.py → prompt.py} +101 -95
- edsl/questions/HTMLQuestion.py +1 -1
- edsl/questions/__init__.py +154 -25
- edsl/questions/answer_validator_mixin.py +1 -1
- edsl/questions/compose_questions.py +4 -3
- edsl/questions/derived/question_likert_five.py +166 -0
- edsl/questions/derived/{QuestionLinearScale.py → question_linear_scale.py} +4 -4
- edsl/questions/derived/{QuestionTopK.py → question_top_k.py} +4 -4
- edsl/questions/derived/{QuestionYesNo.py → question_yes_no.py} +4 -5
- edsl/questions/descriptors.py +24 -30
- edsl/questions/loop_processor.py +65 -19
- edsl/questions/question_base.py +881 -0
- edsl/questions/question_base_gen_mixin.py +15 -16
- edsl/questions/{QuestionBasePromptsMixin.py → question_base_prompts_mixin.py} +2 -2
- edsl/questions/{QuestionBudget.py → question_budget.py} +3 -4
- edsl/questions/{QuestionCheckBox.py → question_check_box.py} +16 -16
- edsl/questions/{QuestionDict.py → question_dict.py} +39 -5
- edsl/questions/{QuestionExtract.py → question_extract.py} +9 -9
- edsl/questions/question_free_text.py +282 -0
- edsl/questions/{QuestionFunctional.py → question_functional.py} +6 -5
- edsl/questions/{QuestionList.py → question_list.py} +6 -7
- edsl/questions/{QuestionMatrix.py → question_matrix.py} +6 -5
- edsl/questions/{QuestionMultipleChoice.py → question_multiple_choice.py} +126 -21
- edsl/questions/{QuestionNumerical.py → question_numerical.py} +5 -5
- edsl/questions/{QuestionRank.py → question_rank.py} +6 -6
- edsl/questions/question_registry.py +10 -16
- edsl/questions/register_questions_meta.py +8 -4
- edsl/questions/response_validator_abc.py +17 -16
- edsl/results/__init__.py +4 -1
- edsl/{exceptions/results.py → results/exceptions.py} +1 -1
- edsl/results/report.py +197 -0
- edsl/results/{Result.py → result.py} +131 -45
- edsl/results/{Results.py → results.py} +420 -216
- edsl/results/results_selector.py +344 -25
- edsl/scenarios/__init__.py +30 -3
- edsl/scenarios/{ConstructDownloadLink.py → construct_download_link.py} +7 -0
- edsl/scenarios/directory_scanner.py +156 -13
- edsl/scenarios/document_chunker.py +186 -0
- edsl/scenarios/exceptions.py +101 -0
- edsl/scenarios/file_methods.py +2 -3
- edsl/scenarios/file_store.py +755 -0
- edsl/scenarios/handlers/__init__.py +14 -14
- edsl/scenarios/handlers/{csv.py → csv_file_store.py} +1 -2
- edsl/scenarios/handlers/{docx.py → docx_file_store.py} +8 -7
- edsl/scenarios/handlers/{html.py → html_file_store.py} +1 -2
- edsl/scenarios/handlers/{jpeg.py → jpeg_file_store.py} +1 -1
- edsl/scenarios/handlers/{json.py → json_file_store.py} +1 -1
- edsl/scenarios/handlers/latex_file_store.py +5 -0
- edsl/scenarios/handlers/{md.py → md_file_store.py} +1 -1
- edsl/scenarios/handlers/{pdf.py → pdf_file_store.py} +2 -2
- edsl/scenarios/handlers/{png.py → png_file_store.py} +1 -1
- edsl/scenarios/handlers/{pptx.py → pptx_file_store.py} +8 -7
- edsl/scenarios/handlers/{py.py → py_file_store.py} +1 -3
- edsl/scenarios/handlers/{sql.py → sql_file_store.py} +2 -1
- edsl/scenarios/handlers/{sqlite.py → sqlite_file_store.py} +2 -3
- edsl/scenarios/handlers/{txt.py → txt_file_store.py} +1 -1
- edsl/scenarios/scenario.py +928 -0
- edsl/scenarios/scenario_join.py +18 -5
- edsl/scenarios/{ScenarioList.py → scenario_list.py} +424 -106
- edsl/scenarios/{ScenarioListPdfMixin.py → scenario_list_pdf_tools.py} +16 -15
- edsl/scenarios/scenario_selector.py +5 -1
- edsl/study/ObjectEntry.py +2 -2
- edsl/study/SnapShot.py +5 -5
- edsl/study/Study.py +20 -21
- edsl/study/__init__.py +6 -4
- edsl/surveys/__init__.py +7 -4
- edsl/surveys/dag/__init__.py +2 -0
- edsl/surveys/{ConstructDAG.py → dag/construct_dag.py} +3 -3
- edsl/surveys/{DAG.py → dag/dag.py} +13 -10
- edsl/surveys/descriptors.py +1 -1
- edsl/surveys/{EditSurvey.py → edit_survey.py} +9 -9
- edsl/{exceptions/surveys.py → surveys/exceptions.py} +1 -2
- edsl/surveys/memory/__init__.py +3 -0
- edsl/surveys/{MemoryPlan.py → memory/memory_plan.py} +10 -9
- edsl/surveys/rules/__init__.py +3 -0
- edsl/surveys/{Rule.py → rules/rule.py} +103 -43
- edsl/surveys/{RuleCollection.py → rules/rule_collection.py} +21 -30
- edsl/surveys/{RuleManager.py → rules/rule_manager.py} +19 -13
- edsl/surveys/survey.py +1743 -0
- edsl/surveys/{SurveyExportMixin.py → survey_export.py} +22 -27
- edsl/surveys/{SurveyFlowVisualization.py → survey_flow_visualization.py} +11 -2
- edsl/surveys/{Simulator.py → survey_simulator.py} +10 -3
- edsl/tasks/__init__.py +32 -0
- edsl/{jobs/tasks/QuestionTaskCreator.py → tasks/question_task_creator.py} +115 -57
- edsl/tasks/task_creators.py +135 -0
- edsl/{jobs/tasks/TaskHistory.py → tasks/task_history.py} +86 -47
- edsl/{jobs/tasks → tasks}/task_status_enum.py +91 -7
- edsl/tasks/task_status_log.py +85 -0
- edsl/tokens/__init__.py +2 -0
- edsl/tokens/interview_token_usage.py +53 -0
- edsl/utilities/PrettyList.py +1 -1
- edsl/utilities/SystemInfo.py +25 -22
- edsl/utilities/__init__.py +29 -21
- edsl/utilities/gcp_bucket/__init__.py +2 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +99 -96
- edsl/utilities/interface.py +44 -536
- edsl/{results/MarkdownToPDF.py → utilities/markdown_to_pdf.py} +13 -5
- edsl/utilities/repair_functions.py +1 -1
- {edsl-0.1.46.dist-info → edsl-0.1.48.dist-info}/METADATA +3 -2
- edsl-0.1.48.dist-info/RECORD +347 -0
- edsl/Base.py +0 -426
- edsl/BaseDiff.py +0 -260
- edsl/agents/InvigilatorBase.py +0 -260
- edsl/agents/PromptConstructor.py +0 -318
- edsl/auto/AutoStudy.py +0 -130
- edsl/auto/StageBase.py +0 -243
- edsl/auto/StageGenerateSurvey.py +0 -178
- edsl/auto/StageLabelQuestions.py +0 -125
- edsl/auto/StagePersona.py +0 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
- edsl/auto/StagePersonaDimensionValues.py +0 -74
- edsl/auto/StagePersonaDimensions.py +0 -69
- edsl/auto/StageQuestions.py +0 -74
- edsl/auto/SurveyCreatorPipeline.py +0 -21
- edsl/auto/utilities.py +0 -218
- edsl/base/Base.py +0 -279
- edsl/coop/PriceFetcher.py +0 -54
- edsl/data/Cache.py +0 -580
- edsl/data/CacheEntry.py +0 -230
- edsl/data/SQLiteDict.py +0 -292
- edsl/data/__init__.py +0 -5
- edsl/data/orm.py +0 -10
- edsl/exceptions/cache.py +0 -5
- edsl/exceptions/coop.py +0 -14
- edsl/exceptions/data.py +0 -14
- edsl/exceptions/scenarios.py +0 -29
- edsl/jobs/Answers.py +0 -43
- edsl/jobs/JobsPrompts.py +0 -354
- edsl/jobs/buckets/BucketCollection.py +0 -134
- edsl/jobs/buckets/ModelBuckets.py +0 -65
- edsl/jobs/buckets/TokenBucket.py +0 -283
- edsl/jobs/buckets/TokenBucketClient.py +0 -191
- edsl/jobs/interviews/Interview.py +0 -395
- edsl/jobs/interviews/InterviewExceptionCollection.py +0 -99
- edsl/jobs/interviews/InterviewStatisticsCollection.py +0 -25
- edsl/jobs/runners/JobsRunnerAsyncio.py +0 -163
- edsl/jobs/runners/JobsRunnerStatusData.py +0 -0
- edsl/jobs/tasks/TaskCreators.py +0 -64
- edsl/jobs/tasks/TaskStatusLog.py +0 -23
- edsl/jobs/tokens/InterviewTokenUsage.py +0 -27
- edsl/language_models/LanguageModel.py +0 -635
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/key_management/KeyLookup.py +0 -63
- edsl/language_models/key_management/KeyLookupCollection.py +0 -38
- edsl/language_models/key_management/models.py +0 -137
- edsl/questions/QuestionBase.py +0 -539
- edsl/questions/QuestionFreeText.py +0 -130
- edsl/questions/derived/QuestionLikertFive.py +0 -76
- edsl/results/DatasetExportMixin.py +0 -911
- edsl/results/ResultsExportMixin.py +0 -45
- edsl/results/TextEditor.py +0 -50
- edsl/results/results_fetch_mixin.py +0 -33
- edsl/results/results_tools_mixin.py +0 -98
- edsl/scenarios/DocumentChunker.py +0 -104
- edsl/scenarios/FileStore.py +0 -564
- edsl/scenarios/Scenario.py +0 -548
- edsl/scenarios/ScenarioHtmlMixin.py +0 -65
- edsl/scenarios/ScenarioListExportMixin.py +0 -45
- edsl/scenarios/handlers/latex.py +0 -5
- edsl/shared.py +0 -1
- edsl/surveys/Survey.py +0 -1306
- edsl/surveys/SurveyQualtricsImport.py +0 -284
- edsl/surveys/SurveyToApp.py +0 -141
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/tools/__init__.py +0 -1
- edsl/tools/clusters.py +0 -192
- edsl/tools/embeddings.py +0 -27
- edsl/tools/embeddings_plotting.py +0 -118
- edsl/tools/plotting.py +0 -112
- edsl/tools/summarize.py +0 -18
- edsl/utilities/data/Registry.py +0 -6
- edsl/utilities/data/__init__.py +0 -1
- edsl/utilities/data/scooter_results.json +0 -1
- edsl-0.1.46.dist-info/RECORD +0 -366
- /edsl/coop/{CoopFunctionsMixin.py → coop_functions.py} +0 -0
- /edsl/{results → dataset/display}/CSSParameterizer.py +0 -0
- /edsl/{language_models/key_management → dataset/display}/__init__.py +0 -0
- /edsl/{results → dataset/display}/table_data_class.py +0 -0
- /edsl/{results → dataset/display}/table_display.css +0 -0
- /edsl/{results/ResultsGGMixin.py → dataset/r/ggplot.py} +0 -0
- /edsl/{results → dataset}/tree_explore.py +0 -0
- /edsl/{surveys/instructions/ChangeInstruction.py → instructions/change_instruction.py} +0 -0
- /edsl/{jobs/interviews → interviews}/interview_status_enum.py +0 -0
- /edsl/jobs/{runners/JobsRunnerStatus.py → jobs_runner_status.py} +0 -0
- /edsl/language_models/{PriceManager.py → price_manager.py} +0 -0
- /edsl/language_models/{fake_openai_call.py → unused/fake_openai_call.py} +0 -0
- /edsl/language_models/{fake_openai_service.py → unused/fake_openai_service.py} +0 -0
- /edsl/notebooks/{NotebookToLaTeX.py → notebook_to_latex.py} +0 -0
- /edsl/{exceptions/questions.py → questions/exceptions.py} +0 -0
- /edsl/questions/{SimpleAskMixin.py → simple_ask_mixin.py} +0 -0
- /edsl/surveys/{Memory.py → memory/memory.py} +0 -0
- /edsl/surveys/{MemoryManagement.py → memory/memory_management.py} +0 -0
- /edsl/surveys/{SurveyCSS.py → survey_css.py} +0 -0
- /edsl/{jobs/tokens/TokenUsage.py → tokens/token_usage.py} +0 -0
- /edsl/{results/MarkdownToDocx.py → utilities/markdown_to_docx.py} +0 -0
- /edsl/{TemplateLoader.py → utilities/template_loader.py} +0 -0
- {edsl-0.1.46.dist-info → edsl-0.1.48.dist-info}/LICENSE +0 -0
- {edsl-0.1.46.dist-info → edsl-0.1.48.dist-info}/WHEEL +0 -0
edsl/language_models/language_model.py
@@ -0,0 +1,1080 @@
"""Language model interface and base implementation for EDSL.

This module contains the LanguageModel abstract base class, which defines the interface
for all language models in the EDSL framework. The LanguageModel class handles common
functionality like caching, response parsing, token usage tracking, and cost calculation,
while specific model implementations (like GPT, Claude, etc.) implement the actual API calls.

Key concepts and terminology:

- raw_response: The complete JSON response returned directly from the model API.
  Contains all model metadata and response information.

- edsl_augmented_response: The raw model response augmented with EDSL-specific
  information, such as cache keys, token usage statistics, and cost data.

- generated_tokens: The actual text output generated by the model in response
  to the prompt. This is the content displayed to the user.

- edsl_answer_dict: The standardized, parsed response from the model in the format
  either {'answer': ...} or {'answer': ..., 'comment': ...} that EDSL uses internally.

The LanguageModel class handles these different representations and provides a
consistent interface regardless of which model provider is being used.
"""

from __future__ import annotations
from functools import wraps
import asyncio
import json
import os
import warnings
from abc import ABC, abstractmethod

from typing import (
    Coroutine,
    Any,
    Type,
    Union,
    List,
    get_type_hints,
    TypedDict,
    Optional,
    TYPE_CHECKING,
)

from ..data_transfer_models import (
    ModelResponse,
    ModelInputs,
    EDSLOutput,
    AgentResponseDict,
)

if TYPE_CHECKING:
    from ..caching import Cache
    from ..scenarios import FileStore
    from ..questions import QuestionBase
    from ..key_management import KeyLookup

from ..enums import InferenceServiceType

from ..utilities import sync_wrapper, jupyter_nb_handler, remove_edsl_version, dict_hash
from ..base import PersistenceMixin, RepresentationMixin, HashingMixin
from ..key_management import KeyLookupCollection

from .registry import RegisterLanguageModelsMeta
from .raw_response_handler import RawResponseHandler


def handle_key_error(func):
    """Decorator to catch and provide user-friendly error messages for KeyError exceptions.

    This decorator gracefully handles KeyError exceptions that may occur when parsing
    model responses with unexpected structures, providing a clear error message to
    help users understand what went wrong.

    Args:
        func: The function to decorate

    Returns:
        Decorated function that catches KeyError exceptions
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except KeyError as e:
            return f"""KeyError occurred: {e}. This is most likely because the model you are using
            returned a JSON object we were not expecting."""

    return wrapper


class classproperty:
    """Descriptor that combines @classmethod and @property behaviors.

    This descriptor allows defining properties that work on the class itself
    rather than on instances, making it possible to have computed attributes
    at the class level.

    Usage:
        class MyClass:
            @classproperty
            def my_prop(cls):
                return cls.__name__
    """

    def __init__(self, method):
        """Initialize with the decorated method.

        Args:
            method: The class method to be accessed as a property
        """
        self.method = method

    def __get__(self, instance, cls):
        """Return the result of calling the method on the class.

        Args:
            instance: The instance (if called on an instance)
            cls: The class (always provided)

        Returns:
            The result of calling the method with the class as argument
        """
        return self.method(cls)


class LanguageModel(
    PersistenceMixin,
    RepresentationMixin,
    HashingMixin,
    ABC,
    metaclass=RegisterLanguageModelsMeta,
):
    """Abstract base class for all language model implementations in EDSL.

    This class defines the common interface and functionality for interacting with
    various language model providers (OpenAI, Anthropic, etc.). It handles caching,
    response parsing, token usage tracking, and cost calculation, providing a
    consistent interface regardless of the underlying model.

    Subclasses must implement the async_execute_model_call method to handle the
    actual API call to the model provider. Other methods may also be overridden
    to customize behavior for specific models.

    The class uses several mixins to provide serialization, pretty printing, and
    hashing functionality, and a metaclass to automatically register model implementations.

    Attributes:
        _model_: The default model identifier (set by subclasses)
        key_sequence: Path to extract generated text from model responses
        DEFAULT_RPM: Default requests per minute rate limit
        DEFAULT_TPM: Default tokens per minute rate limit
    """

    _model_ = None
    key_sequence = (
        None  # This should be something like ["choices", 0, "message", "content"]
    )

    DEFAULT_RPM = 100
    DEFAULT_TPM = 1000

    @classproperty
    def response_handler(cls):
        """Get a handler for processing raw model responses.

        This property creates a RawResponseHandler configured for the specific
        model implementation, using the class's key_sequence and usage_sequence
        attributes to know how to extract information from the model's response format.

        Returns:
            RawResponseHandler: Handler configured for this model type
        """
        key_sequence = cls.key_sequence
        usage_sequence = cls.usage_sequence if hasattr(cls, "usage_sequence") else None
        return RawResponseHandler(key_sequence, usage_sequence)

    def __init__(
        self,
        tpm: Optional[float] = None,
        rpm: Optional[float] = None,
        omit_system_prompt_if_empty_string: bool = True,
        key_lookup: Optional["KeyLookup"] = None,
        **kwargs,
    ):
        """Initialize a new language model instance.

        Args:
            tpm: Optional tokens per minute rate limit override
            rpm: Optional requests per minute rate limit override
            omit_system_prompt_if_empty_string: Whether to omit the system prompt when empty
            key_lookup: Optional custom key lookup for API credentials
            **kwargs: Additional parameters to pass to the model provider

        The initialization process:
        1. Sets up the model identifier from the class attribute
        2. Configures model parameters by merging defaults with provided values
        3. Sets up API key lookup and rate limits
        4. Applies all parameters as instance attributes

        For subclasses that define a _parameters_ class attribute, these will be
        used as default parameters that can be overridden by kwargs.
        """
        # Get the model identifier from the class attribute
        self.model = getattr(self, "_model_", None)

        # Set up model parameters by combining defaults with provided values
        default_parameters = getattr(self, "_parameters_", None)
        parameters = self._overide_default_parameters(kwargs, default_parameters)
        self.parameters = parameters

        # Initialize basic settings
        self.remote = False
        self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string

        # Set up API key lookup and fetch model information
        self.key_lookup = self._set_key_lookup(key_lookup)
        self.model_info = self.key_lookup.get(self._inference_service_)

        # Apply rate limit overrides if provided
        if rpm is not None:
            self._rpm = rpm

        if tpm is not None:
            self._tpm = tpm

        # Apply all parameters as instance attributes
        for key, value in parameters.items():
            setattr(self, key, value)

        # Apply any additional kwargs that aren't part of the standard parameters
        for key, value in kwargs.items():
            if key not in parameters:
                setattr(self, key, value)

        # Handle API key check skip for testing
        if kwargs.get("skip_api_key_check", False):
            # Skip the API key check. Sometimes this is useful for testing.
            self._api_token = None

    def _set_key_lookup(self, key_lookup: "KeyLookup") -> "KeyLookup":
        """Set up the API key lookup mechanism.

        This method either uses the provided key lookup object or creates a default
        one that looks for API keys in config files and environment variables.

        Args:
            key_lookup: Optional custom key lookup object

        Returns:
            KeyLookup: The key lookup object to use for API credentials
        """
        if key_lookup is not None:
            return key_lookup
        else:
            klc = KeyLookupCollection()
            klc.add_key_lookup(fetch_order=("config", "env"))
            return klc.get(("config", "env"))

    def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
        """Update the key lookup mechanism after initialization.

        This method allows changing the API key lookup after the model has been
        created, clearing any cached API tokens.

        Args:
            key_lookup: The new key lookup object to use
        """
        if hasattr(self, "_api_token"):
            del self._api_token
        self.key_lookup = key_lookup

    def ask_question(self, question: "QuestionBase") -> str:
        """Ask a question using this language model and return the response.

        This is a convenience method that extracts the necessary prompts from a
        question object and makes a model call.

        Args:
            question: The EDSL question object to ask

        Returns:
            str: The model's response to the question
        """
        user_prompt = question.get_instructions().render(question.data).text
        system_prompt = "You are a helpful agent pretending to be a human."
        return self.execute_model_call(user_prompt, system_prompt)

    @property
    def rpm(self):
        """Get the requests per minute rate limit for this model.

        This property provides the rate limit either from an explicitly set value,
        from the model info in the key lookup, or from the default value.

        Returns:
            float: The requests per minute rate limit
        """
        if not hasattr(self, "_rpm"):
            if self.model_info is None:
                self._rpm = self.DEFAULT_RPM
            else:
                self._rpm = self.model_info.rpm
        return self._rpm

    @property
    def tpm(self):
        """Get the tokens per minute rate limit for this model.

        This property provides the rate limit either from an explicitly set value,
        from the model info in the key lookup, or from the default value.

        Returns:
            float: The tokens per minute rate limit
        """
        if not hasattr(self, "_tpm"):
            if self.model_info is None:
                self._tpm = self.DEFAULT_TPM
            else:
                self._tpm = self.model_info.tpm
        return self._tpm

    # Setters for rate limits
    @tpm.setter
    def tpm(self, value):
        """Set the tokens per minute rate limit.

        Args:
            value: The new tokens per minute limit
        """
        self._tpm = value

    @rpm.setter
    def rpm(self, value):
        """Set the requests per minute rate limit.

        Args:
            value: The new requests per minute limit
        """
        self._rpm = value

    @property
    def api_token(self) -> str:
        """Get the API token for this model's service.

        This property lazily fetches the API token from the key lookup
        mechanism when first accessed, caching it for subsequent uses.

        Returns:
            str: The API token for authenticating with the model provider

        Raises:
            ValueError: If no API key is found for this model's service
        """
        if not hasattr(self, "_api_token"):
            info = self.key_lookup.get(self._inference_service_, None)
            if info is None:
                raise ValueError(
                    f"No key found for service '{self._inference_service_}'"
                )
            self._api_token = info.api_token
        return self._api_token

    def __getitem__(self, key):
        """Allow dictionary-style access to model attributes.

        Args:
            key: The attribute name to access

        Returns:
            The value of the specified attribute
        """
        return getattr(self, key)

    def hello(self, verbose=False):
        """Run a simple test to verify the model connection is working.

        This method makes a basic model call to check if the API credentials
        are valid and the model is responsive.

        Args:
            verbose: If True, prints the masked API token

        Returns:
            str: The model's response to a simple greeting
        """
        token = self.api_token
        masked = token[: min(8, len(token))] + "..."
        if verbose:
            print(f"Current key is {masked}")
        return self.execute_model_call(
            user_prompt="Hello, model!", system_prompt="You are a helpful agent."
        )

    def has_valid_api_key(self) -> bool:
        """Check if the model has a valid API key available.

        This method verifies if the necessary API key is available in
        environment variables or configuration for this model's service.
        Test models always return True.

        Returns:
            bool: True if a valid API key is available, False otherwise

        Examples:
            >>> LanguageModel.example().has_valid_api_key()  # doctest: +SKIP
            True
        """
        from ..enums import service_to_api_keyname

        if self._model_ == "test":
            return True

        key_name = service_to_api_keyname.get(self._inference_service_, "NOT FOUND")
        key_value = os.getenv(key_name)
        return key_value is not None

    def __hash__(self) -> int:
        """Generate a hash value based on model identity and parameters.

        This method allows language model instances to be used as dictionary
        keys or in sets by providing a stable hash value based on the
        model's essential characteristics.

        Returns:
            int: A hash value for the model instance

        Examples:
            >>> m = LanguageModel.example()
            >>> hash(m)  # Actual value may vary
            325654563661254408
        """
        return dict_hash(self.to_dict(add_edsl_version=False))

    def __eq__(self, other) -> bool:
        """Check if two language model instances are functionally equivalent.

        Two models are considered equal if they have the same model identifier
        and the same parameter settings, meaning they would produce the same
        outputs given the same inputs.

        Args:
            other: Another model to compare with

        Returns:
            bool: True if the models are functionally equivalent

        Examples:
            >>> m1 = LanguageModel.example()
            >>> m2 = LanguageModel.example()
            >>> m1 == m2
            True
        """
        return self.model == other.model and self.parameters == other.parameters

    @staticmethod
    def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
        """Merge default parameters with user-specified parameters.

        This method creates a parameter dictionary where explicitly passed
        parameters take precedence over default values, while ensuring all
        required parameters have a value.

        Args:
            passed_parameter_dict: Dictionary of user-specified parameters
            default_parameter_dict: Dictionary of default parameter values

        Returns:
            dict: Combined parameter dictionary with defaults and overrides

        Examples:
            >>> LanguageModel._overide_default_parameters(
            ...     passed_parameter_dict={"temperature": 0.5},
            ...     default_parameter_dict={"temperature": 0.9})
            {'temperature': 0.5}

            >>> LanguageModel._overide_default_parameters(
            ...     passed_parameter_dict={"temperature": 0.5},
            ...     default_parameter_dict={"temperature": 0.9, "max_tokens": 1000})
            {'temperature': 0.5, 'max_tokens': 1000}
        """
        # Handle the case when data is loaded from a dict after serialization
        if "parameters" in passed_parameter_dict:
            passed_parameter_dict = passed_parameter_dict["parameters"]

        # Create new dict with defaults, overridden by passed parameters
        return {
            parameter_name: passed_parameter_dict.get(parameter_name, default_value)
            for parameter_name, default_value in default_parameter_dict.items()
        }

    def __call__(self, user_prompt: str, system_prompt: str):
        """Allow the model to be called directly as a function.

        This method provides a convenient way to use the model by calling
        it directly with prompts, like `response = model(user_prompt, system_prompt)`.

        Args:
            user_prompt: The user message or input prompt
            system_prompt: The system message or context

        Returns:
            The response from the model
        """
        return self.execute_model_call(user_prompt, system_prompt)

    @abstractmethod
    async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
        """Execute the model call asynchronously.

        This abstract method must be implemented by all model subclasses
        to handle the actual API call to the language model provider.

        Args:
            user_prompt: The user message or input prompt
            system_prompt: The system message or context

        Returns:
            Coroutine that resolves to the model response

        Note:
            Implementations should handle the actual API communication,
            including authentication, request formatting, and response parsing.
        """
        pass

    async def remote_async_execute_model_call(
        self, user_prompt: str, system_prompt: str
    ):
        """Execute the model call remotely through the EDSL Coop service.

        This method allows offloading the model call to a remote server,
        which can be useful for models not available in the local environment
        or to avoid rate limits.

        Args:
            user_prompt: The user message or input prompt
            system_prompt: The system message or context

        Returns:
            Coroutine that resolves to the model response from the remote service
        """
        from ..coop import Coop

        client = Coop()
        response_data = await client.remote_async_execute_model_call(
            self.to_dict(), user_prompt, system_prompt
        )
        return response_data

    @jupyter_nb_handler
    def execute_model_call(self, *args, **kwargs):
        """Execute a model call synchronously.

        This method is a synchronous wrapper around the asynchronous execution,
        making it easier to use the model in non-async contexts. It's decorated
        with jupyter_nb_handler to ensure proper handling in notebook environments.

        Args:
            *args: Positional arguments to pass to async_execute_model_call
            **kwargs: Keyword arguments to pass to async_execute_model_call

        Returns:
            The model response
        """
        async def main():
            results = await asyncio.gather(
                self.async_execute_model_call(*args, **kwargs)
            )
            return results[0]  # Since there's only one task, return its result

        return main()

    @classmethod
    def get_generated_token_string(cls, raw_response: dict[str, Any]) -> str:
        """Extract the generated text from a raw model response.

        This method navigates the response structure using the model's key_sequence
        to find and return just the generated text, without metadata.

        Args:
            raw_response: The complete response dictionary from the model API

        Returns:
            str: The generated text string

        Examples:
            >>> m = LanguageModel.example(test_model=True)
            >>> raw_response = m.execute_model_call("Hello, model!", "You are a helpful agent.")
            >>> m.get_generated_token_string(raw_response)
            'Hello world'
        """
        return cls.response_handler.get_generated_token_string(raw_response)

    @classmethod
    def get_usage_dict(cls, raw_response: dict[str, Any]) -> dict[str, Any]:
        """Extract token usage statistics from a raw model response.

        This method navigates the response structure to find and return
        information about token usage, which is used for cost calculation
        and monitoring.

        Args:
            raw_response: The complete response dictionary from the model API

        Returns:
            dict: Dictionary of token usage statistics (input tokens, output tokens, etc.)
        """
        return cls.response_handler.get_usage_dict(raw_response)

    @classmethod
    def parse_response(cls, raw_response: dict[str, Any]) -> EDSLOutput:
        """Parse the raw API response into a standardized EDSL output format.

        This method processes the model's response to extract the generated content
        and format it according to EDSL's expected structure, making it consistent
        across different model providers.

        Args:
            raw_response: The complete response dictionary from the model API

        Returns:
            EDSLOutput: Standardized output structure with answer and optional comment
        """
        return cls.response_handler.parse_response(raw_response)

    async def _async_get_intended_model_call_outcome(
        self,
        user_prompt: str,
        system_prompt: str,
        cache: "Cache",
        iteration: int = 0,
        files_list: Optional[List["FileStore"]] = None,
        invigilator=None,
    ) -> ModelResponse:
        """Handle model calls with caching for efficiency.

        This method implements the caching logic for model calls, checking if a
        response is already cached before making an actual API call. It handles
        the complete workflow of:
        1. Creating a cache key from the prompts and parameters
        2. Checking if the response is already in the cache
        3. Making the API call if needed
        4. Storing new responses in the cache
        5. Adding metadata like cost and cache status

        Args:
            user_prompt: The user's message or input prompt
            system_prompt: The system's message or context
            cache: The cache object to use for storing/retrieving responses
            iteration: The iteration number, used for the cache key
            files_list: Optional list of files to include in the prompt
            invigilator: Optional invigilator object, not used in caching

        Returns:
            ModelResponse: Response object with the model output and metadata

        Examples:
            >>> from edsl import Cache
            >>> m = LanguageModel.example(test_model=True)
            >>> m._get_intended_model_call_outcome(user_prompt="Hello", system_prompt="hello", cache=Cache())
            ModelResponse(...)
        """
        # Add file hashes to the prompt if files are provided
        if files_list:
            files_hash = "+".join([str(hash(file)) for file in files_list])
            user_prompt_with_hashes = user_prompt + f" {files_hash}"
        else:
            user_prompt_with_hashes = user_prompt

        # Prepare parameters for cache lookup
        cache_call_params = {
            "model": str(self.model),
            "parameters": self.parameters,
            "system_prompt": system_prompt,
            "user_prompt": user_prompt_with_hashes,
            "iteration": iteration,
        }

        # Try to fetch from cache
        cached_response, cache_key = cache.fetch(**cache_call_params)

        if cache_used := cached_response is not None:
            # Cache hit - use the cached response
            response = json.loads(cached_response)
        else:
            # Cache miss - make a new API call
            # Determine whether to use remote or local execution
            f = (
                self.remote_async_execute_model_call
                if hasattr(self, "remote") and self.remote
                else self.async_execute_model_call
            )

            # Prepare parameters for the model call
            params = {
                "user_prompt": user_prompt,
                "system_prompt": system_prompt,
                "files_list": files_list,
            }

            # Get timeout from configuration
            from ..config import CONFIG
            TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))

            # Execute the model call with timeout
            response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)

            # Store the response in the cache
            new_cache_key = cache.store(
                **cache_call_params, response=response, service=self._inference_service_
            )
            assert new_cache_key == cache_key  # Verify cache key integrity

        # Calculate cost for the response
        cost = self.cost(response)

        # Return a structured response with metadata
        return ModelResponse(
            response=response,
            cache_used=cache_used,
            cache_key=cache_key,
            cached_response=cached_response,
            cost=cost,
        )

    _get_intended_model_call_outcome = sync_wrapper(
        _async_get_intended_model_call_outcome
    )

    def simple_ask(
        self,
        question: QuestionBase,
        system_prompt="You are a helpful agent pretending to be a human.",
        top_logprobs=2,
    ):
        """Ask a simple question with log probability tracking.

        This is a convenience method for getting responses with log probabilities,
        which can be useful for analyzing model confidence and alternatives.

        Args:
            question: The EDSL question object to ask
            system_prompt: System message to use (default is human agent instruction)
            top_logprobs: Number of top alternative tokens to return probabilities for

        Returns:
            The model response, including log probabilities if supported
        """
        self.logprobs = True
        self.top_logprobs = top_logprobs
        return self.execute_model_call(
            user_prompt=question.human_readable(), system_prompt=system_prompt
        )

    async def async_get_response(
        self,
        user_prompt: str,
        system_prompt: str,
        cache: Cache,
        iteration: int = 1,
        files_list: Optional[List[FileStore]] = None,
        **kwargs,
    ) -> AgentResponseDict:
        """Get a complete response with all metadata and parsed format.

        This method handles the complete pipeline for:
        1. Making a model call (with caching)
        2. Parsing the response
        3. Constructing a full response object with inputs, outputs, and parsed data

        It's the primary method used by higher-level components to interact with models.

        Args:
            user_prompt: The user's message or input prompt
            system_prompt: The system's message or context
            cache: The cache object to use for storing/retrieving responses
            iteration: The iteration number (default: 1)
            files_list: Optional list of files to include in the prompt
            **kwargs: Additional parameters (invigilator can be provided here)

        Returns:
            AgentResponseDict: Complete response object with inputs, raw outputs, and parsed data
        """
        # Prepare parameters for the cached model call
        params = {
            "user_prompt": user_prompt,
            "system_prompt": system_prompt,
            "iteration": iteration,
            "cache": cache,
            "files_list": files_list,
        }

        # Add invigilator if provided
        if "invigilator" in kwargs:
            params.update({"invigilator": kwargs["invigilator"]})

        # Create structured input record
        model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)

        # Get model response (using cache if available)
        model_outputs: ModelResponse = (
            await self._async_get_intended_model_call_outcome(**params)
        )

        # Parse the response into EDSL's standard format
        edsl_dict: EDSLOutput = self.parse_response(model_outputs.response)

        # Combine everything into a complete response object
        agent_response_dict = AgentResponseDict(
            model_inputs=model_inputs,
            model_outputs=model_outputs,
            edsl_dict=edsl_dict,
        )

        return agent_response_dict

    get_response = sync_wrapper(async_get_response)

    def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
        """Calculate the monetary cost of a model API call.

        This method extracts token usage information from the response and
        uses the price manager to calculate the actual cost in dollars based
        on the model's pricing structure and token counts.

        Args:
            raw_response: The complete response dictionary from the model API

        Returns:
            Union[float, str]: The calculated cost in dollars, or an error message
        """
        # Extract token usage data from the response
        usage = self.get_usage_dict(raw_response)

        # Use the price manager to calculate the actual cost
        from .price_manager import PriceManager
        price_manager = PriceManager()

        return price_manager.calculate_cost(
            inference_service=self._inference_service_,
            model=self.model,
            usage=usage,
            input_token_name=self.input_token_name,
            output_token_name=self.output_token_name,
        )

    def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
        """Serialize the model instance to a dictionary representation.

        This method creates a dictionary containing all the information needed
        to recreate this model, including its identifier, parameters, and service.
        Optionally includes EDSL version information for compatibility checking.

        Args:
            add_edsl_version: Whether to include EDSL version and class name (default: True)

        Returns:
            dict: Dictionary representation of this model instance

        Examples:
            >>> m = LanguageModel.example()
            >>> m.to_dict()
            {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'inference_service': 'openai', 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
        """
        # Build the base dictionary with essential model information
        d = {
            "model": self.model,
            "parameters": self.parameters,
            "inference_service": self._inference_service_,
        }

        # Add EDSL version and class information if requested
        if add_edsl_version:
            from edsl import __version__

            d["edsl_version"] = __version__
            d["edsl_class_name"] = self.__class__.__name__

        return d

    @classmethod
    @remove_edsl_version
    def from_dict(cls, data: dict) -> "LanguageModel":
        """Create a language model instance from a dictionary representation.

        This class method deserializes a model from its dictionary representation,
        finding the correct model class based on the model identifier and service.

        Args:
            data: Dictionary containing the model configuration

        Returns:
            LanguageModel: A new model instance of the appropriate type

        Note:
            This method does not use the stored inference_service directly but
            fetches the model class based on the model name and service name.
        """
        from .model import get_model_class

        # Determine the appropriate model class
        model_class = get_model_class(
            data["model"], service_name=data.get("inference_service", None)
        )

        # Create and return a new instance
        return model_class(**data)

    def __repr__(self) -> str:
        """Generate a string representation of the model.

        This representation includes the model identifier and all parameters,
        providing a clear picture of how the model is configured.

        Returns:
            str: A string representation of the model
        """
        # Format the parameters as a string
        param_string = ", ".join(
            f"{key} = {value}" for key, value in self.parameters.items()
        )

        # Combine model name and parameters
        return (
            f"Model(model_name = '{self.model}'"
            + (f", {param_string}" if param_string else "")
            + ")"
        )

    def __add__(self, other_model: "LanguageModel") -> "LanguageModel":
        """Define behavior when models are combined with the + operator.

        This operator is used in survey builder contexts, but note that it
        replaces the left model with the right model rather than combining them.

        Args:
            other_model: Another model to combine with this one

        Returns:
            LanguageModel: The other model if provided, otherwise this model

        Warning:
            This doesn't truly combine models - it replaces one with the other.
            For running multiple models, use a single 'by' call with multiple models.
        """
        warnings.warn(
            """Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
            by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
        )
        return other_model or self

    @classmethod
    def example(
        cls,
        test_model: bool = False,
        canned_response: str = "Hello world",
        throw_exception: bool = False,
    ) -> "LanguageModel":
        """Create an example language model instance for testing and demonstration.

        This method provides a convenient way to create a model instance for
        examples, tests, and documentation. It can create either a real model
        (with API key checking disabled) or a test model that returns predefined
        responses.

        Args:
            test_model: If True, creates a test model that doesn't make real API calls
            canned_response: For test models, the predefined response to return
            throw_exception: For test models, whether to throw an exception instead of responding

        Returns:
            LanguageModel: An example model instance

        Examples:
            Create a test model with a custom response:

            >>> from edsl.language_models import LanguageModel
            >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!")
            >>> isinstance(m, LanguageModel)
            True

            Use the test model to answer a question:

            >>> from edsl import QuestionFreeText
            >>> q = QuestionFreeText(question_text="What is your name?", question_name='example')
            >>> q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True).select('example').first()
            'WOWZA!'

            Create a test model that throws exceptions:

            >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)
            >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)
            Exception report saved to ...
        """
        from ..language_models import Model

        if test_model:
            # Create a test model with predefined behavior
            m = Model(
                "test", canned_response=canned_response, throw_exception=throw_exception
            )
            return m
        else:
            # Create a regular model with API key checking disabled
            return Model(skip_api_key_check=True)

    def from_cache(self, cache: "Cache") -> "LanguageModel":
        """Create a new model that only returns responses from the cache.

        This method creates a modified copy of the model that will only use
        cached responses, never making new API calls. This is useful for
        offline operation or repeatable experiments.

        Args:
            cache: The cache object containing previously cached responses

        Returns:
            LanguageModel: A new model instance that only reads from cache
        """
        from copy import deepcopy
        from types import MethodType
        from ..caching import Cache

        # Create a deep copy of this model instance
        new_instance = deepcopy(self)
        print("Cache entries", len(cache))

        # Filter the cache to only include entries for this model
        new_instance.cache = Cache(
            data={k: v for k, v in cache.items() if v.model == self.model}
        )
        print("Cache entries with same model", len(new_instance.cache))

        # Store prompt lists for reference
        new_instance.user_prompts = [
            ce.user_prompt for ce in new_instance.cache.values()
        ]
        new_instance.system_prompts = [
            ce.system_prompt for ce in new_instance.cache.values()
        ]

        # Define a new async_execute_model_call that only reads from cache
        async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
            """Only use cached responses, never making new API calls."""
            cache_call_params = {
                "model": str(self.model),
                "parameters": self.parameters,
                "system_prompt": system_prompt,
                "user_prompt": user_prompt,
                "iteration": 1,
            }
            cached_response, cache_key = cache.fetch(**cache_call_params)
            response = json.loads(cached_response)
            cost = 0
            return ModelResponse(
                response=response,
                cache_used=True,
                cache_key=cache_key,
                cached_response=cached_response,
                cost=cost,
            )

        # Bind the new method to the copied instance
        setattr(
            new_instance,
            "async_execute_model_call",
            MethodType(async_execute_model_call, new_instance),
        )

        return new_instance


if __name__ == "__main__":
    """Run the module's test suite."""
    import doctest

    doctest.testmod(optionflags=doctest.ELLIPSIS)
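Taken together, the doctests in this new module describe the response pipeline end to end: execute_model_call produces the raw payload, get_generated_token_string pulls out the generated text, and parse_response converts it into the standardized answer dictionary. A minimal usage sketch of that pipeline, using only the test model and the names that appear in this diff (no API key or network access assumed):

    from edsl.language_models import LanguageModel

    # The test model returns a canned response instead of calling a provider API.
    m = LanguageModel.example(test_model=True, canned_response="Hello world")

    # raw_response: the complete payload returned by the (test) model call
    raw_response = m.execute_model_call("Hello, model!", "You are a helpful agent.")

    # generated_tokens: just the text the model produced
    print(m.get_generated_token_string(raw_response))  # 'Hello world'

    # edsl_answer_dict: the standardized {'answer': ...} structure EDSL uses internally
    print(m.parse_response(raw_response))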