edsl 0.1.14__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +348 -38
- edsl/BaseDiff.py +260 -0
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +46 -10
- edsl/__version__.py +1 -0
- edsl/agents/Agent.py +842 -144
- edsl/agents/AgentList.py +521 -25
- edsl/agents/Invigilator.py +250 -374
- edsl/agents/InvigilatorBase.py +257 -0
- edsl/agents/PromptConstructor.py +272 -0
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/descriptors.py +43 -13
- edsl/agents/prompt_helpers.py +129 -0
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +130 -0
- edsl/auto/StageBase.py +243 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +74 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +218 -0
- edsl/base/Base.py +279 -0
- edsl/config.py +121 -104
- edsl/conversation/Conversation.py +290 -0
- edsl/conversation/car_buying.py +59 -0
- edsl/conversation/chips.py +95 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +54 -0
- edsl/coop/__init__.py +1 -0
- edsl/coop/coop.py +1029 -134
- edsl/coop/utils.py +131 -0
- edsl/data/Cache.py +560 -89
- edsl/data/CacheEntry.py +230 -0
- edsl/data/CacheHandler.py +168 -0
- edsl/data/RemoteCacheSync.py +186 -0
- edsl/data/SQLiteDict.py +292 -0
- edsl/data/__init__.py +5 -3
- edsl/data/orm.py +6 -33
- edsl/data_transfer_models.py +74 -27
- edsl/enums.py +165 -8
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +52 -46
- edsl/exceptions/agents.py +33 -15
- edsl/exceptions/cache.py +5 -0
- edsl/exceptions/coop.py +8 -0
- edsl/exceptions/general.py +34 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/jobs.py +15 -0
- edsl/exceptions/language_models.py +46 -1
- edsl/exceptions/questions.py +80 -5
- edsl/exceptions/results.py +16 -5
- edsl/exceptions/scenarios.py +29 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AnthropicService.py +106 -0
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +118 -0
- edsl/inference_services/AzureAI.py +215 -0
- edsl/inference_services/DeepInfraService.py +18 -0
- edsl/inference_services/GoogleService.py +143 -0
- edsl/inference_services/GroqService.py +20 -0
- edsl/inference_services/InferenceServiceABC.py +80 -0
- edsl/inference_services/InferenceServicesCollection.py +138 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +236 -0
- edsl/inference_services/PerplexityService.py +160 -0
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +90 -0
- edsl/inference_services/TogetherAIService.py +172 -0
- edsl/inference_services/data_structures.py +134 -0
- edsl/inference_services/models_available_cache.py +118 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/registry.py +41 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +21 -20
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +684 -204
- edsl/jobs/JobsChecks.py +172 -0
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +270 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +104 -0
- edsl/jobs/buckets/ModelBuckets.py +65 -0
- edsl/jobs/buckets/TokenBucket.py +283 -0
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +392 -0
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -0
- edsl/jobs/interviews/InterviewStatistic.py +63 -0
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -0
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -0
- edsl/jobs/interviews/InterviewStatusLog.py +92 -0
- edsl/jobs/interviews/ReportErrors.py +66 -0
- edsl/jobs/interviews/interview_status_enum.py +9 -0
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +151 -110
- edsl/jobs/runners/JobsRunnerStatus.py +298 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +244 -0
- edsl/jobs/tasks/TaskCreators.py +64 -0
- edsl/jobs/tasks/TaskHistory.py +470 -0
- edsl/jobs/tasks/TaskStatusLog.py +23 -0
- edsl/jobs/tasks/task_status_enum.py +161 -0
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -0
- edsl/jobs/tokens/TokenUsage.py +34 -0
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +507 -386
- edsl/language_models/ModelList.py +164 -0
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -0
- edsl/language_models/__init__.py +1 -8
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +109 -41
- edsl/language_models/utilities.py +65 -0
- edsl/notebooks/Notebook.py +263 -0
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/notebooks/__init__.py +1 -0
- edsl/prompts/Prompt.py +222 -93
- edsl/prompts/__init__.py +1 -1
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +518 -0
- edsl/questions/QuestionBasePromptsMixin.py +221 -0
- edsl/questions/QuestionBudget.py +164 -67
- edsl/questions/QuestionCheckBox.py +281 -62
- edsl/questions/QuestionDict.py +343 -0
- edsl/questions/QuestionExtract.py +136 -50
- edsl/questions/QuestionFreeText.py +79 -55
- edsl/questions/QuestionFunctional.py +138 -41
- edsl/questions/QuestionList.py +184 -57
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +293 -69
- edsl/questions/QuestionNumerical.py +109 -56
- edsl/questions/QuestionRank.py +244 -49
- edsl/questions/Quick.py +41 -0
- edsl/questions/SimpleAskMixin.py +74 -0
- edsl/questions/__init__.py +9 -6
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +153 -38
- edsl/questions/compose_questions.py +13 -7
- edsl/questions/data_structures.py +20 -0
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +28 -26
- edsl/questions/derived/QuestionLinearScale.py +41 -28
- edsl/questions/derived/QuestionTopK.py +34 -26
- edsl/questions/derived/QuestionYesNo.py +40 -27
- edsl/questions/descriptors.py +228 -74
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_base_gen_mixin.py +168 -0
- edsl/questions/question_registry.py +130 -46
- edsl/questions/register_questions_meta.py +71 -0
- edsl/questions/response_validator_abc.py +188 -0
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/settings.py +5 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/dict/__init__.py +0 -0
- edsl/questions/templates/dict/answering_instructions.jinja +21 -0
- edsl/questions/templates/dict/question_presentation.jinja +1 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +7 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +550 -19
- edsl/results/DatasetExportMixin.py +594 -0
- edsl/results/DatasetTree.py +295 -0
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +477 -173
- edsl/results/Results.py +987 -269
- edsl/results/ResultsExportMixin.py +28 -125
- edsl/results/ResultsGGMixin.py +83 -15
- edsl/results/TableDisplay.py +125 -0
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/file_exports.py +252 -0
- edsl/results/results_fetch_mixin.py +33 -0
- edsl/results/results_selector.py +145 -0
- edsl/results/results_tools_mixin.py +98 -0
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_display.css +78 -0
- edsl/results/table_renderers.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +543 -0
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +431 -62
- edsl/scenarios/ScenarioHtmlMixin.py +65 -0
- edsl/scenarios/ScenarioList.py +1415 -45
- edsl/scenarios/ScenarioListExportMixin.py +45 -0
- edsl/scenarios/ScenarioListPdfMixin.py +239 -0
- edsl/scenarios/__init__.py +2 -0
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +49 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/scenario_join.py +131 -0
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/shared.py +1 -0
- edsl/study/ObjectEntry.py +173 -0
- edsl/study/ProofOfWork.py +113 -0
- edsl/study/SnapShot.py +80 -0
- edsl/study/Study.py +521 -0
- edsl/study/__init__.py +4 -0
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/DAG.py +92 -11
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/Memory.py +9 -4
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/MemoryPlan.py +156 -35
- edsl/surveys/Rule.py +221 -74
- edsl/surveys/RuleCollection.py +241 -61
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +1079 -339
- edsl/surveys/SurveyCSS.py +273 -0
- edsl/surveys/SurveyExportMixin.py +235 -40
- edsl/surveys/SurveyFlowVisualization.py +181 -0
- edsl/surveys/SurveyQualtricsImport.py +284 -0
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/base.py +19 -3
- edsl/surveys/descriptors.py +17 -6
- edsl/surveys/instructions/ChangeInstruction.py +48 -0
- edsl/surveys/instructions/Instruction.py +56 -0
- edsl/surveys/instructions/InstructionCollection.py +82 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +19 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/tools/__init__.py +1 -0
- edsl/tools/clusters.py +192 -0
- edsl/tools/embeddings.py +27 -0
- edsl/tools/embeddings_plotting.py +118 -0
- edsl/tools/plotting.py +112 -0
- edsl/tools/summarize.py +18 -0
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/SystemInfo.py +5 -0
- edsl/utilities/__init__.py +21 -20
- edsl/utilities/ast_utilities.py +3 -0
- edsl/utilities/data/Registry.py +2 -0
- edsl/utilities/decorators.py +41 -0
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/interface.py +310 -60
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/naming_utilities.py +263 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/restricted_python.py +70 -0
- edsl/utilities/utilities.py +203 -13
- edsl-0.1.40.dist-info/METADATA +111 -0
- edsl-0.1.40.dist-info/RECORD +362 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/WHEEL +1 -1
- edsl/agents/AgentListExportMixin.py +0 -24
- edsl/coop/old.py +0 -31
- edsl/data/Database.py +0 -141
- edsl/data/crud.py +0 -121
- edsl/jobs/Interview.py +0 -417
- edsl/jobs/JobsRunner.py +0 -63
- edsl/jobs/JobsRunnerStatusMixin.py +0 -115
- edsl/jobs/base.py +0 -47
- edsl/jobs/buckets.py +0 -166
- edsl/jobs/runners/JobsRunnerDryRun.py +0 -19
- edsl/jobs/runners/JobsRunnerStreaming.py +0 -54
- edsl/jobs/task_management.py +0 -218
- edsl/jobs/token_tracking.py +0 -78
- edsl/language_models/DeepInfra.py +0 -69
- edsl/language_models/OpenAI.py +0 -98
- edsl/language_models/model_interfaces/GeminiPro.py +0 -66
- edsl/language_models/model_interfaces/LanguageModelOpenAIFour.py +0 -8
- edsl/language_models/model_interfaces/LanguageModelOpenAIThreeFiveTurbo.py +0 -8
- edsl/language_models/model_interfaces/LlamaTwo13B.py +0 -21
- edsl/language_models/model_interfaces/LlamaTwo70B.py +0 -21
- edsl/language_models/model_interfaces/Mixtral8x7B.py +0 -24
- edsl/language_models/registry.py +0 -81
- edsl/language_models/schemas.py +0 -15
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/prompts/QuestionInstructionsBase.py +0 -6
- edsl/prompts/library/agent_instructions.py +0 -29
- edsl/prompts/library/agent_persona.py +0 -17
- edsl/prompts/library/question_budget.py +0 -26
- edsl/prompts/library/question_checkbox.py +0 -32
- edsl/prompts/library/question_extract.py +0 -19
- edsl/prompts/library/question_freetext.py +0 -14
- edsl/prompts/library/question_linear_scale.py +0 -20
- edsl/prompts/library/question_list.py +0 -22
- edsl/prompts/library/question_multiple_choice.py +0 -44
- edsl/prompts/library/question_numerical.py +0 -31
- edsl/prompts/library/question_rank.py +0 -21
- edsl/prompts/prompt_config.py +0 -33
- edsl/prompts/registry.py +0 -185
- edsl/questions/Question.py +0 -240
- edsl/report/InputOutputDataTypes.py +0 -134
- edsl/report/RegressionMixin.py +0 -28
- edsl/report/ReportOutputs.py +0 -1228
- edsl/report/ResultsFetchMixin.py +0 -106
- edsl/report/ResultsOutputMixin.py +0 -14
- edsl/report/demo.ipynb +0 -645
- edsl/results/ResultsDBMixin.py +0 -184
- edsl/surveys/SurveyFlowVisualizationMixin.py +0 -92
- edsl/trackers/Tracker.py +0 -91
- edsl/trackers/TrackerAPI.py +0 -196
- edsl/trackers/TrackerTasks.py +0 -70
- edsl/utilities/pastebin.py +0 -141
- edsl-0.1.14.dist-info/METADATA +0 -69
- edsl-0.1.14.dist-info/RECORD +0 -141
- /edsl/{language_models/model_interfaces → inference_services}/__init__.py +0 -0
- /edsl/{report/__init__.py → jobs/runners/JobsRunnerStatusData.py} +0 -0
- /edsl/{trackers/__init__.py → language_models/ServiceDataSources.py} +0 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/LICENSE +0 -0
@@ -0,0 +1,211 @@
|
|
1
|
+
from fastapi import FastAPI, HTTPException
|
2
|
+
from pydantic import BaseModel
|
3
|
+
from typing import Union, Dict
|
4
|
+
from typing import Union, List, Any, Optional
|
5
|
+
from threading import RLock
|
6
|
+
from edsl.jobs.buckets.TokenBucket import TokenBucket # Original implementation
|
7
|
+
|
8
|
+
|
9
|
+
def safe_float_for_json(value: float) -> Union[float, str]:
    """Make an infinite number representable in a JSON response.

    JSON has no literal for infinity, so infinite values are replaced by
    strings that ``float()`` can parse back: ``"infinity"`` for positive
    infinity and ``"-infinity"`` for negative infinity.

    Args:
        value: The number to convert.

    Returns:
        ``value`` unchanged when it is not infinite, otherwise the string
        ``"infinity"`` or ``"-infinity"``.
    """
    if value == float("inf"):
        return "infinity"
    if value == float("-inf"):
        # Negative infinity is just as unserializable as positive infinity.
        return "-infinity"
    return value
|
21
|
+
|
22
|
+
|
23
|
+
# FastAPI application object; the route decorators below register on it.
app = FastAPI()

# In-memory storage for TokenBucket instances, keyed by bucket id
# (create_bucket builds the id as "<bucket_name>_<bucket_type>").
buckets: Dict[str, TokenBucket] = {}
|
27
|
+
|
28
|
+
|
29
|
+
class TokenBucketCreate(BaseModel):
    """Request body accepted by the ``POST /bucket`` endpoint."""

    # bucket_name and bucket_type are joined as "<name>_<type>" to form the
    # bucket id under which the bucket is stored.
    bucket_name: str
    bucket_type: str
    # capacity and refill_rate are forwarded unchanged to the TokenBucket
    # constructor after the endpoint has rejected NaN and infinite values.
    capacity: Union[int, float]
    refill_rate: Union[int, float]
|
34
|
+
|
35
|
+
|
36
|
+
@app.get("/buckets")
async def list_buckets(
    bucket_type: Optional[str] = None,
    bucket_name: Optional[str] = None,
    include_logs: bool = False,
):
    """List all buckets and their current status.

    Args:
        bucket_type: Optional filter by bucket type
        bucket_name: Optional filter by bucket name
        include_logs: Whether to include the full logs in the response

    Returns:
        A dict mapping each matching bucket id to a dict of its fields.
        Float fields that are infinite are replaced by a string via
        ``safe_float_for_json`` so the response stays JSON-serializable.
    """
    result = {}

    for bucket_id, bucket in buckets.items():
        # Apply filters if specified
        if bucket_type and bucket.bucket_type != bucket_type:
            continue
        if bucket_name and bucket.bucket_name != bucket_name:
            continue

        # Get basic bucket info
        bucket_info = {
            "bucket_name": bucket.bucket_name,
            "bucket_type": bucket.bucket_type,
            "tokens": bucket.tokens,
            "capacity": bucket.capacity,
            "refill_rate": bucket.refill_rate,
            "turbo_mode": bucket.turbo_mode,
            "num_requests": bucket.num_requests,
            "num_released": bucket.num_released,
            "tokens_returned": bucket.tokens_returned,
        }
        for k, v in bucket_info.items():
            if isinstance(v, float):
                bucket_info[k] = safe_float_for_json(v)

        # Only include logs if requested
        if include_logs:
            # Sanitize the logged values the same way get_bucket_status
            # does, but into a fresh list so the bucket's own log is left
            # untouched.
            bucket_info["log"] = [
                (ts, safe_float_for_json(value)) for ts, value in bucket.log
            ]

        result[bucket_id] = bucket_info

    return result
|
81
|
+
|
82
|
+
|
83
|
+
@app.post("/bucket/{bucket_id}/add_tokens")
async def add_tokens(bucket_id: str, amount: float):
    """Add tokens to an existing bucket.

    Args:
        bucket_id: Id of the bucket to top up.
        amount: Number of tokens to add; must be a finite number.

    Returns:
        ``{"status": "success", "current_tokens": ...}`` where
        ``current_tokens`` is the bucket's token count after the addition,
        passed through ``safe_float_for_json``.

    Raises:
        HTTPException: 404 when the bucket does not exist, 400 when
            ``amount`` is NaN or infinite.
    """
    if bucket_id not in buckets:
        raise HTTPException(status_code=404, detail="Bucket not found")

    if not isinstance(amount, (int, float)) or amount != amount:  # Check for NaN
        raise HTTPException(status_code=400, detail="Invalid amount specified")

    if amount in (float("inf"), float("-inf")):
        raise HTTPException(status_code=400, detail="Amount cannot be infinite")

    bucket = buckets[bucket_id]
    bucket.add_tokens(amount)

    # Ensure we return a JSON-serializable value.  A positive-infinite
    # balance is reported through safe_float_for_json (as the other
    # endpoints do) instead of being collapsed to 0.0, which would tell the
    # caller the bucket is empty.  NaN and negative infinity keep the
    # original 0.0 fallback.
    current_tokens = float(bucket.tokens)
    if current_tokens != current_tokens or current_tokens == float("-inf"):
        current_tokens = 0.0

    return {
        "status": "success",
        "current_tokens": safe_float_for_json(current_tokens),
    }
|
104
|
+
|
105
|
+
|
106
|
+
# @app.post("/bucket")
|
107
|
+
# async def create_bucket(bucket: TokenBucketCreate):
|
108
|
+
# bucket_id = f"{bucket.bucket_name}_{bucket.bucket_type}"
|
109
|
+
# if bucket_id in buckets:
|
110
|
+
# raise HTTPException(status_code=400, detail="Bucket already exists")
|
111
|
+
|
112
|
+
# # Create an actual TokenBucket instance
|
113
|
+
# buckets[bucket_id] = TokenBucket(
|
114
|
+
# bucket_name=bucket.bucket_name,
|
115
|
+
# bucket_type=bucket.bucket_type,
|
116
|
+
# capacity=bucket.capacity,
|
117
|
+
# refill_rate=bucket.refill_rate,
|
118
|
+
# )
|
119
|
+
# return {"status": "created"}
|
120
|
+
|
121
|
+
|
122
|
+
@app.post("/bucket")
async def create_bucket(bucket: TokenBucketCreate):
    """Create a bucket, or report the one already stored under the same id.

    The bucket id is ``"<bucket_name>_<bucket_type>"``.

    Args:
        bucket: Name, type, capacity and refill rate of the bucket.

    Returns:
        ``{"status": "created"}`` for a new bucket.  When the id is already
        in use, ``{"status": "existing", "bucket": {...}}`` carrying the
        stored bucket's capacity and refill rate (infinite values are
        converted by ``safe_float_for_json``).

    Raises:
        HTTPException: 400 when capacity or refill rate is NaN or infinite.
    """
    for label, value in (
        ("capacity", bucket.capacity),
        ("refill rate", bucket.refill_rate),
    ):
        if not isinstance(value, (int, float)) or value != value:  # Check for NaN
            raise HTTPException(status_code=400, detail=f"Invalid {label} value")

    # Reject both signs of infinity, matching the check in add_tokens.
    infinities = (float("inf"), float("-inf"))
    if bucket.capacity in infinities or bucket.refill_rate in infinities:
        raise HTTPException(status_code=400, detail="Values cannot be infinite")

    bucket_id = f"{bucket.bucket_name}_{bucket.bucket_type}"
    if bucket_id in buckets:
        # Instead of error, return success with "existing" status
        existing = buckets[bucket_id]
        return {
            "status": "existing",
            "bucket": {
                "capacity": safe_float_for_json(existing.capacity),
                "refill_rate": safe_float_for_json(existing.refill_rate),
            },
        }

    # Create a new bucket
    buckets[bucket_id] = TokenBucket(
        bucket_name=bucket.bucket_name,
        bucket_type=bucket.bucket_type,
        capacity=bucket.capacity,
        refill_rate=bucket.refill_rate,
    )
    return {"status": "created"}
|
155
|
+
|
156
|
+
|
157
|
+
@app.post("/bucket/{bucket_id}/get_tokens")
async def get_tokens(bucket_id: str, amount: float, cheat_bucket_capacity: bool = True):
    """Take ``amount`` tokens from a bucket by awaiting its ``get_tokens``.

    Both ``amount`` and ``cheat_bucket_capacity`` are forwarded to the
    bucket unchanged.  Responds with 404 when no bucket is stored under
    ``bucket_id``, otherwise with ``{"status": "success"}`` once the
    bucket call has completed.
    """
    target = buckets.get(bucket_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Bucket not found")

    await target.get_tokens(amount, cheat_bucket_capacity)
    return {"status": "success"}
|
165
|
+
|
166
|
+
|
167
|
+
@app.post("/bucket/{bucket_id}/turbo_mode/{state}")
async def set_turbo_mode(bucket_id: str, state: bool):
    """Switch a bucket's turbo mode on (``state`` true) or off (``state`` false).

    Responds with 404 when no bucket is stored under ``bucket_id``,
    otherwise with ``{"status": "success"}``.
    """
    target = buckets.get(bucket_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Bucket not found")

    # Pick the matching bucket method, then invoke it.
    switch = target.turbo_mode_on if state else target.turbo_mode_off
    switch()
    return {"status": "success"}
|
178
|
+
|
179
|
+
|
180
|
+
@app.get("/bucket/{bucket_id}/status")
async def get_bucket_status(bucket_id: str):
    """Return the current counters and the full log of one bucket.

    Args:
        bucket_id: Id of the bucket to inspect.

    Returns:
        A dict with the bucket's tokens, capacity, refill rate, turbo-mode
        flag, request/release counters and log.  Infinite floats, both in
        the top-level fields and in the logged values, are converted by
        ``safe_float_for_json``.

    Raises:
        HTTPException: 404 when the bucket does not exist.
    """
    if bucket_id not in buckets:
        raise HTTPException(status_code=404, detail="Bucket not found")

    bucket = buckets[bucket_id]
    status = {
        "tokens": bucket.tokens,
        "capacity": bucket.capacity,
        "refill_rate": bucket.refill_rate,
        "turbo_mode": bucket.turbo_mode,
        "num_requests": bucket.num_requests,
        "num_released": bucket.num_released,
        "tokens_returned": bucket.tokens_returned,
    }
    for k, v in status.items():
        if isinstance(v, float):
            status[k] = safe_float_for_json(v)

    # Build a sanitized copy of the log.  Rewriting entries of bucket.log
    # itself would replace the bucket's recorded floats with strings every
    # time the status is requested.
    status["log"] = [(ts, safe_float_for_json(value)) for ts, value in bucket.log]

    return status
|
206
|
+
|
207
|
+
|
208
|
+
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this module is run as a
    # script.
    import uvicorn

    # Serve the app on every network interface, port 8001.
    uvicorn.run(app, host="0.0.0.0", port=8001)
|
@@ -0,0 +1,191 @@
|
|
1
|
+
from typing import Union, Optional
|
2
|
+
import asyncio
|
3
|
+
import time
|
4
|
+
import aiohttp
|
5
|
+
|
6
|
+
|
7
|
+
class TokenBucketClient:
|
8
|
+
"""REST API client version of TokenBucket that maintains the same interface
|
9
|
+
by delegating to a server running the original TokenBucket implementation."""
|
10
|
+
|
11
|
+
def __init__(
    self,
    *,
    bucket_name: str,
    bucket_type: str,
    capacity: Union[int, float],
    refill_rate: Union[int, float],
    api_base_url: str = "http://localhost:8000",
):
    """Store the bucket parameters and register the bucket with the server.

    Args:
        bucket_name: Name of the bucket.
        bucket_type: Type of the bucket; joined with the name as
            ``"<bucket_name>_<bucket_type>"`` to form the bucket id.
        capacity: Requested capacity; replaced by the server's value when
            the bucket already exists there.
        refill_rate: Requested refill rate; replaced by the server's value
            when the bucket already exists there.
        api_base_url: Base URL of the token-bucket server.
            NOTE(review): the server module's script entry point listens
            on port 8001 while this default uses 8000 - confirm.
    """
    self.bucket_name = bucket_name
    self.bucket_type = bucket_type
    self.capacity = capacity
    self.refill_rate = refill_rate
    self.api_base_url = api_base_url
    self.bucket_id = f"{bucket_name}_{bucket_type}"

    # Initialize the bucket on the server.
    # NOTE: asyncio.run() cannot be called from a thread that already has a
    # running event loop, so the constructor must be used from synchronous
    # code.
    asyncio.run(self._create_bucket())

    # Cache some values locally
    self.creation_time = time.monotonic()
    self.turbo_mode = False
|
33
|
+
|
34
|
+
async def _create_bucket(self):
    """Register this bucket with the server via ``POST {api_base_url}/bucket``.

    When the server answers that the bucket already exists, the local
    ``capacity`` and ``refill_rate`` are overwritten with the values the
    server reports.

    Raises:
        ValueError: If the server responds with a status other than 200.
    """
    request_body = {
        "bucket_name": self.bucket_name,
        "bucket_type": self.bucket_type,
        "capacity": self.capacity,
        "refill_rate": self.refill_rate,
    }
    endpoint = f"{self.api_base_url}/bucket"

    async with aiohttp.ClientSession() as http:
        async with http.post(endpoint, json=request_body) as reply:
            if reply.status != 200:
                raise ValueError(f"Unexpected error: {await reply.text()}")

            body = await reply.json()
            if body["status"] != "existing":
                return

            # Update our local values to match the existing bucket
            stored = body["bucket"]
            self.capacity = float(stored["capacity"])
            self.refill_rate = float(stored["refill_rate"])
|
54
|
+
|
55
|
+
def turbo_mode_on(self) -> None:
    """Ask the server to turn turbo mode on, then record the flag locally.

    Blocks until the request completes.  The original note here said turbo
    mode sets the refill rate to infinity; that happens in the server-side
    TokenBucket - presumably, verify there.
    """
    asyncio.run(self._set_turbo_mode(True))
    # Only reached when the request succeeded; a failure raises above.
    self.turbo_mode = True
|
59
|
+
|
60
|
+
def turbo_mode_off(self) -> None:
    """Ask the server to turn turbo mode off, then clear the local flag.

    Blocks until the request completes.  The original note here said this
    restores the refill rate to its original value; that happens in the
    server-side TokenBucket - presumably, verify there.
    """
    asyncio.run(self._set_turbo_mode(False))
    # Only reached when the request succeeded; a failure raises above.
    self.turbo_mode = False
|
64
|
+
|
65
|
+
async def add_tokens(self, amount: Union[int, float]):
    """Send a request asking the server to add ``amount`` tokens to this bucket.

    Raises:
        ValueError: If the server responds with a status other than 200.
    """
    endpoint = f"{self.api_base_url}/bucket/{self.bucket_id}/add_tokens"
    query = {"amount": amount}

    async with aiohttp.ClientSession() as http:
        async with http.post(endpoint, params=query) as reply:
            if reply.status == 200:
                return
            raise ValueError(f"Failed to add tokens: {await reply.text()}")
|
74
|
+
|
75
|
+
async def _set_turbo_mode(self, state: bool):
    """Post the desired turbo-mode state to the server.

    ``state`` is sent as the lower-cased text of the boolean (``true`` or
    ``false``) in the last URL segment.

    Raises:
        ValueError: If the server responds with a status other than 200.
    """
    endpoint = (
        f"{self.api_base_url}/bucket/{self.bucket_id}/turbo_mode/{str(state).lower()}"
    )

    async with aiohttp.ClientSession() as http:
        async with http.post(endpoint) as reply:
            if reply.status == 200:
                return
            raise ValueError(f"Failed to set turbo mode: {await reply.text()}")
|
84
|
+
|
85
|
+
async def get_tokens(
    self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
) -> None:
    """Request ``amount`` tokens from the server-side bucket.

    ``cheat_bucket_capacity`` is sent as an integer (0 or 1) query
    parameter.  The coroutine finishes once the server has answered.

    Raises:
        ValueError: If the server responds with a status other than 200.
    """
    endpoint = f"{self.api_base_url}/bucket/{self.bucket_id}/get_tokens"
    query = {
        "amount": amount,
        "cheat_bucket_capacity": int(cheat_bucket_capacity),
    }

    async with aiohttp.ClientSession() as http:
        async with http.post(endpoint, params=query) as reply:
            if reply.status == 200:
                return
            raise ValueError(f"Failed to get tokens: {await reply.text()}")
|
98
|
+
|
99
|
+
def get_throughput(self, time_window: Optional[float] = None) -> float:
|
100
|
+
status = asyncio.run(self._get_status())
|
101
|
+
now = time.monotonic()
|
102
|
+
|
103
|
+
if time_window is None:
|
104
|
+
start_time = self.creation_time
|
105
|
+
else:
|
106
|
+
start_time = now - time_window
|
107
|
+
|
108
|
+
if start_time < self.creation_time:
|
109
|
+
start_time = self.creation_time
|
110
|
+
|
111
|
+
elapsed_time = now - start_time
|
112
|
+
|
113
|
+
if elapsed_time == 0:
|
114
|
+
return status["num_released"] / 0.001
|
115
|
+
|
116
|
+
return (status["num_released"] / elapsed_time) * 60
|
117
|
+
|
118
|
+
async def _get_status(self) -> dict:
    """Fetch this bucket's status document from the server.

    Returns:
        The decoded JSON body of ``GET /bucket/{bucket_id}/status``.

    Raises:
        ValueError: If the server responds with a status other than 200.
    """
    endpoint = f"{self.api_base_url}/bucket/{self.bucket_id}/status"

    async with aiohttp.ClientSession() as http:
        async with http.get(endpoint) as reply:
            if reply.status == 200:
                return await reply.json()
            raise ValueError(f"Failed to get bucket status: {await reply.text()}")
|
128
|
+
|
129
|
+
def __add__(self, other) -> "TokenBucketClient":
    """Combine two token buckets.

    The new client keeps this bucket's name, type and API base URL, and
    takes the smaller capacity and the smaller refill rate of the two.
    """
    combined = dict(
        bucket_name=self.bucket_name,
        bucket_type=self.bucket_type,
        capacity=min(self.capacity, other.capacity),
        refill_rate=min(self.refill_rate, other.refill_rate),
        api_base_url=self.api_base_url,
    )
    return TokenBucketClient(**combined)
|
138
|
+
|
139
|
+
@property
def tokens(self) -> float:
    """Get the number of tokens remaining in the bucket.

    Every access performs a status request via ``asyncio.run``, so the
    value is a fresh snapshot each time.
    """
    snapshot = asyncio.run(self._get_status())
    remaining = snapshot["tokens"]
    return float(remaining)
|
144
|
+
|
145
|
+
def wait_time(self, requested_tokens: Union[float, int]) -> float:
    """Calculate the time to wait for the requested number of tokens.

    Args:
        requested_tokens: how many tokens the caller needs.

    Returns:
        ``0.0`` when the bucket already holds enough tokens, otherwise the
        shortfall divided by ``self.refill_rate``.

    Raises:
        ValueError: if the wait cannot be computed (for example a zero
            refill rate); the underlying error is chained as the cause.
    """
    # ``self.tokens`` performs a status request on every access. Read it
    # once so the comparison and the arithmetic use the same snapshot and
    # only one round-trip is made.
    available = self.tokens
    if available >= float(requested_tokens):
        return 0.0
    try:
        return (requested_tokens - available) / self.refill_rate
    except Exception as e:
        raise ValueError(f"Error calculating wait time: {e}") from e
|
154
|
+
|
155
|
+
# def wait_time(self, num_tokens: Union[int, float]) -> float:
|
156
|
+
# return 0 # TODO - Need to implement this on the server side
|
157
|
+
|
158
|
+
def visualize(self):
    """Visualize the token bucket over time.

    Fetches the status log (pairs of time and token count), shifts the
    times so the first entry is at zero, and plots tokens against elapsed
    time with matplotlib.

    Raises:
        ValueError: if the status log is empty, since there is nothing to
            plot.
    """
    status = asyncio.run(self._get_status())
    log = status["log"]
    if not log:
        # Without this guard the unpacking below fails with an opaque
        # "not enough values to unpack" ValueError.
        raise ValueError(
            "Cannot visualize the token bucket: the status log is empty."
        )

    times, token_counts = zip(*log)
    start_time = times[0]
    times = [t - start_time for t in times]

    # Imported lazily so matplotlib is only needed when plotting.
    from matplotlib import pyplot as plt

    plt.figure(figsize=(10, 6))
    plt.plot(times, token_counts, label="Tokens Available")
    plt.xlabel("Time (seconds)", fontsize=12)
    plt.ylabel("Number of Tokens", fontsize=12)
    details = f"{self.bucket_name} ({self.bucket_type}) Bucket Usage Over Time\nCapacity: {self.capacity:.1f}, Refill Rate: {self.refill_rate:.1f}/second"
    plt.title(details, fontsize=14)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
|
177
|
+
|
178
|
+
|
179
|
+
# When executed as a script, run the doctests found in this module.
if __name__ == "__main__":
    import doctest

    doctest.testmod()
|
183
|
+
# bucket = TokenBucketClient(
|
184
|
+
# bucket_name="test", bucket_type="test", capacity=100, refill_rate=10
|
185
|
+
# )
|
186
|
+
# asyncio.run(bucket.get_tokens(50))
|
187
|
+
# time.sleep(1) # Wait for 1 second
|
188
|
+
# asyncio.run(bucket.get_tokens(30))
|
189
|
+
# throughput = bucket.get_throughput(1)
|
190
|
+
# print(throughput)
|
191
|
+
# bucket.visualize()
|
@@ -0,0 +1,85 @@
|
|
1
|
+
import warnings
|
2
|
+
from typing import TYPE_CHECKING
|
3
|
+
|
4
|
+
if TYPE_CHECKING:
|
5
|
+
from edsl.surveys.Survey import Survey
|
6
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
7
|
+
|
8
|
+
|
9
|
+
class CheckSurveyScenarioCompatibility:
    """Compare the parameters a survey uses with those its scenarios supply."""

    def __init__(self, survey: "Survey", scenarios: "ScenarioList"):
        self.survey = survey
        self.scenarios = scenarios

    def check(self, strict: bool = False, warn: bool = False) -> None:
        """Check if the parameters in the survey and scenarios are consistent.

        A name that is both a survey question name and a scenario key always
        raises ``ValueError``. Parameters present on only one side raise
        ``ValueError`` when ``strict`` is true, emit a warning when only
        ``warn`` is true, and are otherwise ignored. Finally, if the
        scenarios report Jinja braces, a warning is emitted and
        ``self.scenarios`` is replaced by the converted scenarios.

        >>> from edsl.jobs.Jobs import Jobs
        >>> from edsl.questions.QuestionFreeText import QuestionFreeText
        >>> from edsl.surveys.Survey import Survey
        >>> from edsl.scenarios.Scenario import Scenario
        >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
        >>> j = Jobs(survey = Survey(questions=[q]))
        >>> cs = CheckSurveyScenarioCompatibility(j.survey, j.scenarios)
        >>> with warnings.catch_warnings(record=True) as w:
        ...     cs.check(warn = True)
        ...     assert len(w) == 1
        ...     assert issubclass(w[-1].category, UserWarning)
        ...     assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)

        >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
        >>> s = Scenario({'plop': "A", 'poo': "B"})
        >>> j = Jobs(survey = Survey(questions=[q])).by(s)
        >>> cs = CheckSurveyScenarioCompatibility(j.survey, j.scenarios)
        >>> cs.check(strict = True)
        Traceback (most recent call last):
        ...
        ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}

        >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
        >>> s = Scenario({'ugly_question': "B"})
        >>> from edsl.scenarios.ScenarioList import ScenarioList
        >>> cs = CheckSurveyScenarioCompatibility(Survey(questions=[q]), ScenarioList([s]))
        >>> cs.check()
        Traceback (most recent call last):
        ...
        ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
        """
        survey_params: set = self.survey.parameters
        scenario_params: set = self.scenarios.parameters

        # A scenario key that shadows a question name is always fatal.
        clashing = set(self.scenarios.parameters) & set(self.survey.question_names)
        if clashing:
            raise ValueError(
                f"The following names are in both the survey question_names and the scenario keys: {clashing}. This will create issues."
            )

        # Collect one message per direction of mismatch, survey side first.
        problems = []
        only_in_survey = survey_params - scenario_params
        if only_in_survey:
            problems.append(
                f"The following parameters are in the survey but not in the scenarios: {only_in_survey}"
            )
        only_in_scenarios = scenario_params - survey_params
        if only_in_scenarios:
            problems.append(
                f"The following parameters are in the scenarios but not in the survey: {only_in_scenarios}"
            )

        if problems:
            message = "\n".join(problems)
            if strict:
                raise ValueError(message)
            if warn:
                warnings.warn(message)

        if self.scenarios.has_jinja_braces:
            warnings.warn(
                "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
            )
            self.scenarios = self.scenarios._convert_jinja_braces()
|
80
|
+
|
81
|
+
|
82
|
+
# When executed as a script, run the doctests found in this module.
if __name__ == "__main__":
    import doctest

    doctest.testmod()
|
@@ -0,0 +1,120 @@
|
|
1
|
+
from typing import Optional, Literal
|
2
|
+
from dataclasses import dataclass, asdict
|
3
|
+
|
4
|
+
# from edsl.data_transfer_models import VisibilityType
|
5
|
+
from edsl.data.Cache import Cache
|
6
|
+
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
7
|
+
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
8
|
+
from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
|
9
|
+
|
10
|
+
VisibilityType = Literal["private", "public", "unlisted"]
|
11
|
+
from edsl.Base import Base
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass
class RunEnvironment:
    """Bundle of optional objects attached to a run; every field defaults to None."""

    # Cache instance, if one has been supplied.
    cache: Optional[Cache] = None
    # Bucket collection, if one has been supplied.
    # NOTE(review): presumably the rate-limit buckets for the run; confirm.
    bucket_collection: Optional[BucketCollection] = None
    # Key lookup, if one has been supplied.
    key_lookup: Optional[KeyLookup] = None
    # Runner status object, if one has been supplied.
    jobs_runner_status: Optional["JobsRunnerStatus"] = None
|
20
|
+
|
21
|
+
|
22
|
+
@dataclass
class RunParameters(Base):
    """Plain-value options for a run, serialisable to and from a dict."""

    n: int = 1
    progress_bar: bool = False
    stop_on_exception: bool = False
    check_api_keys: bool = False
    verbose: bool = True
    print_exceptions: bool = True
    remote_cache_description: Optional[str] = None
    remote_inference_description: Optional[str] = None
    remote_inference_results_visibility: Optional[VisibilityType] = "unlisted"
    skip_retry: bool = False
    raise_validation_errors: bool = False
    disable_remote_cache: bool = False
    disable_remote_inference: bool = False
    job_uuid: Optional[str] = None

    def to_dict(self, add_edsl_version=False) -> dict:
        """Return the fields as a dict.

        With ``add_edsl_version`` set, the ``edsl_version`` and
        ``edsl_class_name`` metadata keys are added as well.
        """
        d = asdict(self)
        if add_edsl_version:
            from edsl import __version__

            d["edsl_version"] = __version__
            # NOTE(review): the tag says "RunConfig" although this class is
            # RunParameters. It is kept unchanged because consumers of the
            # tag are not visible here; confirm before renaming.
            d["edsl_class_name"] = "RunConfig"
        return d

    @classmethod
    def from_dict(cls, data: dict) -> "RunParameters":
        """Build an instance from a dict produced by ``to_dict``.

        The ``edsl_version`` / ``edsl_class_name`` metadata keys are
        ignored, so output of ``to_dict(add_edsl_version=True)`` loads
        instead of failing with an unexpected-keyword ``TypeError``.
        """
        metadata_keys = ("edsl_version", "edsl_class_name")
        field_values = {
            key: value for key, value in data.items() if key not in metadata_keys
        }
        return cls(**field_values)

    def code(self):
        """Return Python source that rebuilds this object."""
        # Must name this class: RunConfig takes (environment, parameters)
        # and cannot accept these keyword arguments.
        return f"RunParameters(**{self.to_dict()})"

    @classmethod
    def example(cls) -> "RunParameters":
        """Return an instance with all default values."""
        return cls()
|
58
|
+
|
59
|
+
|
60
|
+
@dataclass
class RunConfig:
    """Pair a RunEnvironment with a RunParameters instance.

    The ``add_*`` methods overwrite the corresponding attribute in place and
    return nothing.
    """

    environment: RunEnvironment
    parameters: RunParameters

    def add_environment(self, environment: RunEnvironment):
        """Replace the whole environment."""
        self.environment = environment

    def add_bucket_collection(self, bucket_collection: BucketCollection):
        """Set ``bucket_collection`` on the current environment."""
        self.environment.bucket_collection = bucket_collection

    def add_cache(self, cache: Cache):
        """Set ``cache`` on the current environment."""
        self.environment.cache = cache

    def add_key_lookup(self, key_lookup: KeyLookup):
        """Set ``key_lookup`` on the current environment."""
        self.environment.key_lookup = key_lookup
|
76
|
+
|
77
|
+
|
78
|
+
"""This module contains the Answers class, which is a helper class to hold the answers to a survey."""
|
79
|
+
|
80
|
+
from collections import UserDict
|
81
|
+
from edsl.data_transfer_models import EDSLResultObjectInput
|
82
|
+
|
83
|
+
|
84
|
+
class Answers(UserDict):
    """Helper class to hold the answers to a survey."""

    def add_answer(
        self, response: "EDSLResultObjectInput", question: "QuestionBase"
    ) -> None:
        """Store a response under the question's name.

        The answer is always stored under ``question_name``. A truthy
        ``generated_tokens`` or ``comment`` is stored under
        ``question_name + "_generated_tokens"`` / ``question_name + "_comment"``.
        """
        value = response.answer
        remark = response.comment
        raw_tokens = response.generated_tokens
        key = question.question_name

        # Insertion order is kept: generated tokens, answer, then comment.
        if raw_tokens:
            self[key + "_generated_tokens"] = raw_tokens
        self[key] = value
        if remark:
            self[key + "_comment"] = remark

    def replace_missing_answers_with_none(self, survey: "Survey") -> None:
        """Give every survey question without an answer the value None.

        Answers can be missing if the agent skips a question.
        """
        for name in survey.question_names:
            self.setdefault(name, None)

    def to_dict(self):
        """Return the underlying dictionary of answers (not a copy)."""
        return self.data

    @classmethod
    def from_dict(cls, d):
        """Build an Answers object from a dictionary."""
        return cls(d)
|
115
|
+
|
116
|
+
|
117
|
+
# When executed as a script, run the doctests found in this module.
if __name__ == "__main__":
    import doctest

    doctest.testmod()
|
edsl/jobs/decorators.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
from functools import wraps
|
2
|
+
from threading import RLock
|
3
|
+
import inspect
|
4
|
+
|
5
|
+
|
6
|
+
def synchronized_class(wrapped_class):
    """Class decorator that makes all methods thread-safe.

    One ``RLock`` is stored on the class as ``_lock``, so it is shared by
    all instances unless an instance sets its own ``_lock``. Every plain
    function found on the class (inherited ones included) is replaced by a
    wrapper that holds the lock for the duration of the call. Dunder
    methods are left alone except ``__getitem__``, ``__setitem__`` and
    ``__delitem__``.

    Static methods are wrapped with the class lock and re-installed as
    static methods. Previously they were replaced by a plain function,
    which broke them: calls received ``self`` as an extra first argument
    or failed on ``args[0]._lock``.

    Args:
        wrapped_class: the class to modify in place.

    Returns:
        The same class object, with its methods wrapped.
    """
    # Add a lock to the class
    setattr(wrapped_class, "_lock", RLock())

    synchronized_dunders = ("__getitem__", "__setitem__", "__delitem__")

    def synchronize_method(method):
        """Wrap an instance method so it runs under the instance's ``_lock``."""

        @wraps(method)
        def synchronized_method(*args, **kwargs):
            instance = args[0]  # first arg is self
            with instance._lock:
                return method(*args, **kwargs)

        return synchronized_method

    def synchronize_static(function):
        """Wrap a static method so it runs under the class-level ``_lock``."""

        @wraps(function)
        def synchronized_function(*args, **kwargs):
            with wrapped_class._lock:
                return function(*args, **kwargs)

        return staticmethod(synchronized_function)

    for name, method in inspect.getmembers(wrapped_class, inspect.isfunction):
        # Skip magic methods except __getitem__, __setitem__, __delitem__
        if name.startswith("__") and name not in synchronized_dunders:
            continue

        # getmembers() returns the unwrapped function for a staticmethod, so
        # inspect the raw attribute to tell the two cases apart.
        if isinstance(inspect.getattr_static(wrapped_class, name), staticmethod):
            replacement = synchronize_static(method)
        else:
            replacement = synchronize_method(method)

        setattr(wrapped_class, name, replacement)

    return wrapped_class
|