edsl 0.1.14__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +348 -38
- edsl/BaseDiff.py +260 -0
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +46 -10
- edsl/__version__.py +1 -0
- edsl/agents/Agent.py +842 -144
- edsl/agents/AgentList.py +521 -25
- edsl/agents/Invigilator.py +250 -374
- edsl/agents/InvigilatorBase.py +257 -0
- edsl/agents/PromptConstructor.py +272 -0
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/descriptors.py +43 -13
- edsl/agents/prompt_helpers.py +129 -0
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +130 -0
- edsl/auto/StageBase.py +243 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +74 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +218 -0
- edsl/base/Base.py +279 -0
- edsl/config.py +121 -104
- edsl/conversation/Conversation.py +290 -0
- edsl/conversation/car_buying.py +59 -0
- edsl/conversation/chips.py +95 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +54 -0
- edsl/coop/__init__.py +1 -0
- edsl/coop/coop.py +1029 -134
- edsl/coop/utils.py +131 -0
- edsl/data/Cache.py +560 -89
- edsl/data/CacheEntry.py +230 -0
- edsl/data/CacheHandler.py +168 -0
- edsl/data/RemoteCacheSync.py +186 -0
- edsl/data/SQLiteDict.py +292 -0
- edsl/data/__init__.py +5 -3
- edsl/data/orm.py +6 -33
- edsl/data_transfer_models.py +74 -27
- edsl/enums.py +165 -8
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +52 -46
- edsl/exceptions/agents.py +33 -15
- edsl/exceptions/cache.py +5 -0
- edsl/exceptions/coop.py +8 -0
- edsl/exceptions/general.py +34 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/jobs.py +15 -0
- edsl/exceptions/language_models.py +46 -1
- edsl/exceptions/questions.py +80 -5
- edsl/exceptions/results.py +16 -5
- edsl/exceptions/scenarios.py +29 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AnthropicService.py +106 -0
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +118 -0
- edsl/inference_services/AzureAI.py +215 -0
- edsl/inference_services/DeepInfraService.py +18 -0
- edsl/inference_services/GoogleService.py +143 -0
- edsl/inference_services/GroqService.py +20 -0
- edsl/inference_services/InferenceServiceABC.py +80 -0
- edsl/inference_services/InferenceServicesCollection.py +138 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +236 -0
- edsl/inference_services/PerplexityService.py +160 -0
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +90 -0
- edsl/inference_services/TogetherAIService.py +172 -0
- edsl/inference_services/data_structures.py +134 -0
- edsl/inference_services/models_available_cache.py +118 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/registry.py +41 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +21 -20
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +684 -204
- edsl/jobs/JobsChecks.py +172 -0
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +270 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +104 -0
- edsl/jobs/buckets/ModelBuckets.py +65 -0
- edsl/jobs/buckets/TokenBucket.py +283 -0
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +392 -0
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -0
- edsl/jobs/interviews/InterviewStatistic.py +63 -0
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -0
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -0
- edsl/jobs/interviews/InterviewStatusLog.py +92 -0
- edsl/jobs/interviews/ReportErrors.py +66 -0
- edsl/jobs/interviews/interview_status_enum.py +9 -0
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +151 -110
- edsl/jobs/runners/JobsRunnerStatus.py +298 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +244 -0
- edsl/jobs/tasks/TaskCreators.py +64 -0
- edsl/jobs/tasks/TaskHistory.py +470 -0
- edsl/jobs/tasks/TaskStatusLog.py +23 -0
- edsl/jobs/tasks/task_status_enum.py +161 -0
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -0
- edsl/jobs/tokens/TokenUsage.py +34 -0
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +507 -386
- edsl/language_models/ModelList.py +164 -0
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -0
- edsl/language_models/__init__.py +1 -8
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +109 -41
- edsl/language_models/utilities.py +65 -0
- edsl/notebooks/Notebook.py +263 -0
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/notebooks/__init__.py +1 -0
- edsl/prompts/Prompt.py +222 -93
- edsl/prompts/__init__.py +1 -1
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +518 -0
- edsl/questions/QuestionBasePromptsMixin.py +221 -0
- edsl/questions/QuestionBudget.py +164 -67
- edsl/questions/QuestionCheckBox.py +281 -62
- edsl/questions/QuestionDict.py +343 -0
- edsl/questions/QuestionExtract.py +136 -50
- edsl/questions/QuestionFreeText.py +79 -55
- edsl/questions/QuestionFunctional.py +138 -41
- edsl/questions/QuestionList.py +184 -57
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +293 -69
- edsl/questions/QuestionNumerical.py +109 -56
- edsl/questions/QuestionRank.py +244 -49
- edsl/questions/Quick.py +41 -0
- edsl/questions/SimpleAskMixin.py +74 -0
- edsl/questions/__init__.py +9 -6
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +153 -38
- edsl/questions/compose_questions.py +13 -7
- edsl/questions/data_structures.py +20 -0
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +28 -26
- edsl/questions/derived/QuestionLinearScale.py +41 -28
- edsl/questions/derived/QuestionTopK.py +34 -26
- edsl/questions/derived/QuestionYesNo.py +40 -27
- edsl/questions/descriptors.py +228 -74
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_base_gen_mixin.py +168 -0
- edsl/questions/question_registry.py +130 -46
- edsl/questions/register_questions_meta.py +71 -0
- edsl/questions/response_validator_abc.py +188 -0
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/settings.py +5 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/dict/__init__.py +0 -0
- edsl/questions/templates/dict/answering_instructions.jinja +21 -0
- edsl/questions/templates/dict/question_presentation.jinja +1 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +7 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +550 -19
- edsl/results/DatasetExportMixin.py +594 -0
- edsl/results/DatasetTree.py +295 -0
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +477 -173
- edsl/results/Results.py +987 -269
- edsl/results/ResultsExportMixin.py +28 -125
- edsl/results/ResultsGGMixin.py +83 -15
- edsl/results/TableDisplay.py +125 -0
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/file_exports.py +252 -0
- edsl/results/results_fetch_mixin.py +33 -0
- edsl/results/results_selector.py +145 -0
- edsl/results/results_tools_mixin.py +98 -0
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_display.css +78 -0
- edsl/results/table_renderers.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +543 -0
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +431 -62
- edsl/scenarios/ScenarioHtmlMixin.py +65 -0
- edsl/scenarios/ScenarioList.py +1415 -45
- edsl/scenarios/ScenarioListExportMixin.py +45 -0
- edsl/scenarios/ScenarioListPdfMixin.py +239 -0
- edsl/scenarios/__init__.py +2 -0
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +49 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/scenario_join.py +131 -0
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/shared.py +1 -0
- edsl/study/ObjectEntry.py +173 -0
- edsl/study/ProofOfWork.py +113 -0
- edsl/study/SnapShot.py +80 -0
- edsl/study/Study.py +521 -0
- edsl/study/__init__.py +4 -0
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/DAG.py +92 -11
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/Memory.py +9 -4
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/MemoryPlan.py +156 -35
- edsl/surveys/Rule.py +221 -74
- edsl/surveys/RuleCollection.py +241 -61
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +1079 -339
- edsl/surveys/SurveyCSS.py +273 -0
- edsl/surveys/SurveyExportMixin.py +235 -40
- edsl/surveys/SurveyFlowVisualization.py +181 -0
- edsl/surveys/SurveyQualtricsImport.py +284 -0
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/base.py +19 -3
- edsl/surveys/descriptors.py +17 -6
- edsl/surveys/instructions/ChangeInstruction.py +48 -0
- edsl/surveys/instructions/Instruction.py +56 -0
- edsl/surveys/instructions/InstructionCollection.py +82 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +19 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/tools/__init__.py +1 -0
- edsl/tools/clusters.py +192 -0
- edsl/tools/embeddings.py +27 -0
- edsl/tools/embeddings_plotting.py +118 -0
- edsl/tools/plotting.py +112 -0
- edsl/tools/summarize.py +18 -0
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/SystemInfo.py +5 -0
- edsl/utilities/__init__.py +21 -20
- edsl/utilities/ast_utilities.py +3 -0
- edsl/utilities/data/Registry.py +2 -0
- edsl/utilities/decorators.py +41 -0
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/interface.py +310 -60
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/naming_utilities.py +263 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/restricted_python.py +70 -0
- edsl/utilities/utilities.py +203 -13
- edsl-0.1.40.dist-info/METADATA +111 -0
- edsl-0.1.40.dist-info/RECORD +362 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/WHEEL +1 -1
- edsl/agents/AgentListExportMixin.py +0 -24
- edsl/coop/old.py +0 -31
- edsl/data/Database.py +0 -141
- edsl/data/crud.py +0 -121
- edsl/jobs/Interview.py +0 -417
- edsl/jobs/JobsRunner.py +0 -63
- edsl/jobs/JobsRunnerStatusMixin.py +0 -115
- edsl/jobs/base.py +0 -47
- edsl/jobs/buckets.py +0 -166
- edsl/jobs/runners/JobsRunnerDryRun.py +0 -19
- edsl/jobs/runners/JobsRunnerStreaming.py +0 -54
- edsl/jobs/task_management.py +0 -218
- edsl/jobs/token_tracking.py +0 -78
- edsl/language_models/DeepInfra.py +0 -69
- edsl/language_models/OpenAI.py +0 -98
- edsl/language_models/model_interfaces/GeminiPro.py +0 -66
- edsl/language_models/model_interfaces/LanguageModelOpenAIFour.py +0 -8
- edsl/language_models/model_interfaces/LanguageModelOpenAIThreeFiveTurbo.py +0 -8
- edsl/language_models/model_interfaces/LlamaTwo13B.py +0 -21
- edsl/language_models/model_interfaces/LlamaTwo70B.py +0 -21
- edsl/language_models/model_interfaces/Mixtral8x7B.py +0 -24
- edsl/language_models/registry.py +0 -81
- edsl/language_models/schemas.py +0 -15
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/prompts/QuestionInstructionsBase.py +0 -6
- edsl/prompts/library/agent_instructions.py +0 -29
- edsl/prompts/library/agent_persona.py +0 -17
- edsl/prompts/library/question_budget.py +0 -26
- edsl/prompts/library/question_checkbox.py +0 -32
- edsl/prompts/library/question_extract.py +0 -19
- edsl/prompts/library/question_freetext.py +0 -14
- edsl/prompts/library/question_linear_scale.py +0 -20
- edsl/prompts/library/question_list.py +0 -22
- edsl/prompts/library/question_multiple_choice.py +0 -44
- edsl/prompts/library/question_numerical.py +0 -31
- edsl/prompts/library/question_rank.py +0 -21
- edsl/prompts/prompt_config.py +0 -33
- edsl/prompts/registry.py +0 -185
- edsl/questions/Question.py +0 -240
- edsl/report/InputOutputDataTypes.py +0 -134
- edsl/report/RegressionMixin.py +0 -28
- edsl/report/ReportOutputs.py +0 -1228
- edsl/report/ResultsFetchMixin.py +0 -106
- edsl/report/ResultsOutputMixin.py +0 -14
- edsl/report/demo.ipynb +0 -645
- edsl/results/ResultsDBMixin.py +0 -184
- edsl/surveys/SurveyFlowVisualizationMixin.py +0 -92
- edsl/trackers/Tracker.py +0 -91
- edsl/trackers/TrackerAPI.py +0 -196
- edsl/trackers/TrackerTasks.py +0 -70
- edsl/utilities/pastebin.py +0 -141
- edsl-0.1.14.dist-info/METADATA +0 -69
- edsl-0.1.14.dist-info/RECORD +0 -141
- /edsl/{language_models/model_interfaces → inference_services}/__init__.py +0 -0
- /edsl/{report/__init__.py → jobs/runners/JobsRunnerStatusData.py} +0 -0
- /edsl/{trackers/__init__.py → language_models/ServiceDataSources.py} +0 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/LICENSE +0 -0
@@ -1,115 +0,0 @@
|
|
1
|
-
from typing import List
|
2
|
-
import asyncio
|
3
|
-
|
4
|
-
from rich.table import Table
|
5
|
-
from rich.text import Text
|
6
|
-
from rich.box import SIMPLE
|
7
|
-
|
8
|
-
from edsl.jobs.token_tracking import TokenPricing
|
9
|
-
|
10
|
-
# Per-1k-token prices keyed by model name (presumably consumed by cost reporting).
# Only the two OpenAI entries carry non-zero prices; every other listed model is
# priced at zero here.
pricing = {
    "gpt-3.5-turbo": TokenPricing(
        model_name="gpt-3.5-turbo",
        prompt_token_price_per_k=0.0005,
        completion_token_price_per_k=0.0015,
    ),
    # NOTE(review): the key is "gpt-4-1106-preview" but model_name is "gpt-4";
    # preserved as-is -- confirm which name downstream consumers expect.
    "gpt-4-1106-preview": TokenPricing(
        model_name="gpt-4",
        prompt_token_price_per_k=0.01,
        completion_token_price_per_k=0.03,
    ),
    # Zero-priced models, in the same insertion order as before.
    **{
        free_model: TokenPricing(
            model_name=free_model,
            prompt_token_price_per_k=0.0,
            completion_token_price_per_k=0.0,
        )
        for free_model in (
            "test",
            "gemini_pro",
            "llama-2-13b-chat-hf",
            "llama-2-70b-chat-hf",
            "mixtral-8x7B-instruct-v0.1",
        )
    },
}
|
47
|
-
|
48
|
-
|
49
|
-
class ModelStatus:
    """Plain record pairing a model with its rate limits.

    Attributes:
        model: the model, stored exactly as given.
        TPM: tokens-per-minute limit.
        RPM: requests-per-minute limit.
    """

    def __init__(self, model, TPM, RPM):
        self.model, self.TPM, self.RPM = model, TPM, RPM
|
54
|
-
|
55
|
-
|
56
|
-
from edsl.jobs.token_tracking import InterviewTokenUsage
|
57
|
-
from edsl.jobs.task_management import InterviewStatusDictionary
|
58
|
-
|
59
|
-
from collections import defaultdict
|
60
|
-
|
61
|
-
|
62
|
-
class JobsRunnerStatusMixin:
    """Mixin that renders a ``rich`` status table for a running job.

    The host class must provide ``self.interviews``, a sized collection of the
    interviews requested for the job.
    """

    def _generate_status_table(self, data: List[asyncio.Task], elapsed_time):
        """Build a table summarizing how many interviews have completed.

        Args:
            data: The completed interviews/tasks so far; only its length is used.
            elapsed_time: Seconds since the job started. Accepted for interface
                compatibility; not currently shown in the table.

        Returns:
            A ``rich.table.Table`` with the task-status rows.
        """
        total = len(self.interviews)
        completed = len(data)
        # An empty job would otherwise divide by zero.
        pct_complete = completed / total * 100 if total else 0.0

        table = Table(
            title="Job Status",
            show_header=True,
            header_style="bold magenta",
            box=SIMPLE,
        )
        table.add_column("Key", style="dim", no_wrap=True)
        table.add_column("Value")

        table.add_row(Text("Task status", style="bold red"), "")
        table.add_row("Total interviews requested", str(total))
        table.add_row("Completed interviews", str(completed))
        table.add_row("Percent complete", f"{pct_complete:.2f}%")
        table.add_row("", "")

        return table
|
edsl/jobs/base.py
DELETED
@@ -1,47 +0,0 @@
|
|
1
|
-
from collections import UserDict
|
2
|
-
import importlib
|
3
|
-
|
4
|
-
from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
|
5
|
-
from edsl.jobs.runners.JobsRunnerDryRun import JobsRunnerDryRun
|
6
|
-
|
7
|
-
from edsl.exceptions import JobsRunError
|
8
|
-
from edsl.jobs.JobsRunner import RegisterJobsRunnerMeta
|
9
|
-
|
10
|
-
|
11
|
-
class JobsRunnersRegistryDict(UserDict):
    """Registry mapping that raises ``JobsRunError`` for unknown runner names."""

    def __getitem__(self, key):
        """Return the entry registered under *key*.

        Raises:
            JobsRunError: if *key* is not in the registry.
        """
        try:
            return super().__getitem__(key)
        except KeyError:
            # Translate to the domain error; the internal KeyError context adds
            # nothing and would otherwise read as a second, unrelated failure.
            raise JobsRunError(f"JobsRunner '{key}' not found in registry.") from None
|
17
|
-
|
18
|
-
|
19
|
-
# Snapshot of the runners known to RegisterJobsRunnerMeta (presumably keyed by
# each runner's ``runner_name`` -- TODO confirm), wrapped so that looking up an
# unknown name raises JobsRunError instead of KeyError.
registry_data = RegisterJobsRunnerMeta.lookup()
JobsRunnersRegistry = JobsRunnersRegistryDict(registry_data)
|
21
|
-
|
22
|
-
|
23
|
-
class JobsRunnerDescriptor:
    """Data descriptor that only accepts names present in ``JobsRunnersRegistry``.

    The value is stored on the owning instance under ``"_" + <attribute name>``.
    """

    def validate(self, value: str) -> None:
        """Raise ValueError unless *value* is a key of ``JobsRunnersRegistry``."""
        if value not in JobsRunnersRegistry:
            raise ValueError(
                f"JobsRunner must be one of {list(JobsRunnersRegistry.keys())}"
            )

    def __get__(self, instance, owner):
        """Return the stored runner name, or None if it has not been set.

        When accessed on the owner class itself, return the descriptor.
        """
        if instance is None:
            return self
        return instance.__dict__.get(self.name)

    def __set__(self, instance, value: str) -> None:
        """Validate *value* and store it on *instance*.

        Raises:
            ValueError: if *value* is not a registered runner name.
        """
        # validate() takes only the value; passing the instance as well made
        # every assignment fail with TypeError.
        self.validate(value)
        instance.__dict__[self.name] = value

    def __set_name__(self, owner, name: str) -> None:
        """Record the private storage key derived from the attribute name."""
        self.name = "_" + name


if __name__ == "__main__":
    pass
|
edsl/jobs/buckets.py
DELETED
@@ -1,166 +0,0 @@
|
|
1
|
-
from typing import Union
|
2
|
-
import asyncio
|
3
|
-
import time
|
4
|
-
from collections import UserDict
|
5
|
-
from matplotlib import pyplot as plt
|
6
|
-
|
7
|
-
|
8
|
-
class TokenBucket:
    """A token bucket used to respect rate limits of services.

    The bucket holds at most ``capacity`` tokens and regains ``refill_rate``
    tokens per second, measured with ``time.monotonic``. Each change to the
    token count is appended to ``self.log`` as ``(timestamp, tokens)``.
    """

    def __init__(
        self,
        *,
        bucket_name,
        bucket_type: str,
        capacity: Union[int, float],
        refill_rate: Union[int, float],
    ):
        self.bucket_name = bucket_name
        self.bucket_type = bucket_type
        self.capacity = capacity  # Maximum number of tokens
        self.tokens = capacity  # Current number of available tokens (starts full)
        self.refill_rate = refill_rate  # Tokens regained per second
        self.last_refill = time.monotonic()  # Last refill time

        self.log = []  # (monotonic timestamp, tokens) samples

    def __add__(self, other) -> "TokenBucket":
        """Combine two token buckets.

        The result has the minimum capacity and minimum refill rate of the two,
        e.g. for two calls to the same model on the same service that differ
        only in temperature. The combined bucket starts full.
        """
        return TokenBucket(
            bucket_name=self.bucket_name,
            bucket_type=self.bucket_type,
            capacity=min(self.capacity, other.capacity),
            refill_rate=min(self.refill_rate, other.refill_rate),
        )

    def __repr__(self):
        return f"TokenBucket(bucket_name={self.bucket_name}, bucket_type='{self.bucket_type}', capacity={self.capacity}, refill_rate={self.refill_rate})"

    def add_tokens(self, tokens: Union[int, float]) -> None:
        """Add tokens to the bucket, up to the maximum capacity."""
        self.tokens = min(self.capacity, self.tokens + tokens)
        self.log.append((time.monotonic(), self.tokens))

    def refill(self) -> None:
        """Refill the bucket with new tokens based on elapsed time."""
        now = time.monotonic()
        elapsed = now - self.last_refill
        refill_amount = elapsed * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + refill_amount)
        self.last_refill = now

        self.log.append((now, self.tokens))

    def wait_time(self, requested_tokens) -> float:
        """Return the seconds to wait until *requested_tokens* are available.

        Returns 0.0 when enough tokens are available now, and ``inf`` when more
        are needed but the bucket cannot refill (``refill_rate <= 0``).
        """
        now = time.monotonic()
        elapsed = now - self.last_refill
        refill_amount = elapsed * self.refill_rate
        available_tokens = min(self.capacity, self.tokens + refill_amount)
        shortfall = max(0, requested_tokens - available_tokens)
        if shortfall == 0:
            return 0.0
        if self.refill_rate <= 0:
            # Dividing here would raise ZeroDivisionError; the wait is unbounded.
            return float("inf")
        return shortfall / self.refill_rate

    async def get_tokens(self, amount=1) -> None:
        """Wait for the specified number of tokens to become available, then take them.

        Note that this method is a coroutine.

        Raises:
            ValueError: if *amount* exceeds the bucket capacity, or if the
                bucket lacks *amount* tokens and can never refill.
        """
        if amount > self.capacity:
            raise ValueError(
                f"Requested tokens exceed bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}"
            )
        if self.tokens < amount and self.refill_rate <= 0:
            # Without this guard the polling loop below would never terminate.
            raise ValueError(
                f"Bucket cannot refill (refill rate: {self.refill_rate}); requested amount: {amount}"
            )
        while self.tokens < amount:
            self.refill()
            await asyncio.sleep(0.1)  # Sleep briefly to prevent busy waiting
        self.tokens -= amount

        now = time.monotonic()
        self.log.append((now, self.tokens))

    def get_log(self) -> list[tuple]:
        """Return the recorded ``(timestamp, tokens)`` samples."""
        return self.log

    def visualize(self):
        """Plot the token level over time with matplotlib and show it.

        If nothing has been logged yet, the current state is plotted as a
        single point instead of failing on an empty log.
        """
        samples = self.get_log() or [(self.last_refill, self.tokens)]
        times, tokens = zip(*samples)
        start_time = times[0]
        times = [t - start_time for t in times]  # Normalize time to start from 0

        plt.figure(figsize=(10, 6))
        plt.plot(times, tokens, label="Tokens Available")
        plt.xlabel("Time (seconds)", fontsize=12)
        plt.ylabel("Number of Tokens", fontsize=12)
        details = f"{self.bucket_name} ({self.bucket_type}) Bucket Usage Over Time\nCapacity: {self.capacity:.1f}, Refill Rate: {self.refill_rate:.1f}/second"
        plt.title(details, fontsize=14)

        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()
|
101
|
-
|
102
|
-
|
103
|
-
class ModelBuckets:
    """The request bucket and the token bucket for one model.

    Most LLM services limit both requests-per-minute (RPM) and
    tokens-per-minute (TPM). One call to the service is one request; how many
    tokens it consumes depends on its parameters.
    """

    def __init__(self, requests_bucket: TokenBucket, tokens_bucket: TokenBucket):
        self.requests_bucket, self.tokens_bucket = requests_bucket, tokens_bucket

    def __add__(self, other):
        """Combine with *other* bucket-by-bucket using each bucket's ``+``."""
        merged_requests = self.requests_bucket + other.requests_bucket
        merged_tokens = self.tokens_bucket + other.tokens_bucket
        return ModelBuckets(
            requests_bucket=merged_requests,
            tokens_bucket=merged_tokens,
        )

    def visualize(self):
        """Visualize both buckets and return their results as ``(requests, tokens)``."""
        return tuple(
            bucket.visualize()
            for bucket in (self.requests_bucket, self.tokens_bucket)
        )

    def __repr__(self):
        return f"ModelBuckets(requests_bucket={self.requests_bucket}, tokens_bucket={self.tokens_bucket})"
|
126
|
-
|
127
|
-
|
128
|
-
class BucketCollection(UserDict):
    """Maps each model used by a job to its ModelBuckets.

    A job may use several models, so keys are the model objects themselves
    (which must be hashable) and values are their ModelBuckets.
    """

    def __init__(self):
        super().__init__()

    def __repr__(self):
        return f"BucketCollection({self.data})"

    def add_model(self, model) -> None:
        """Create the request and token buckets for *model* and register them.

        If *model* is already present, the new buckets are combined with the
        existing ones via ``ModelBuckets.__add__``.
        """

        def make_bucket(bucket_type, per_second):
            # Capacity equals one second's allowance.
            # NOTE(review): with RPM/TPM below 60 this capacity is < 1, so a
            # request for 1 token would exceed it -- confirm this is intended.
            return TokenBucket(
                bucket_name=model.model,
                bucket_type=bucket_type,
                capacity=per_second,
                refill_rate=per_second,
            )

        tokens_per_second = model.TPM / 60.0
        requests_per_second = model.RPM / 60.0
        fresh = ModelBuckets(
            make_bucket("requests", requests_per_second),
            make_bucket("tokens", tokens_per_second),
        )
        self[model] = self[model] + fresh if model in self else fresh

    def visualize(self) -> dict:
        """Visualize every model's buckets, keyed by model."""
        return {model: buckets.visualize() for model, buckets in self.items()}
|
@@ -1,19 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
|
3
|
-
from edsl.jobs import Jobs
|
4
|
-
from edsl.results import Results, Result
|
5
|
-
from edsl.jobs.JobsRunner import JobsRunner
|
6
|
-
|
7
|
-
|
8
|
-
class JobsRunnerDryRun(JobsRunner):
    """Runner that reports how many interviews a job would run, without running them."""

    runner_name = "dryrun"

    def __init__(self, jobs: Jobs):
        super().__init__(jobs)

    def run(
        self, n=1, verbose=False, sleep=0, debug=False, progress_bar=False
    ) -> None:
        """Print the number of interviews in the job; nothing is executed.

        The keyword parameters are accepted but ignored here (presumably so the
        call signature matches the other runners). No ``Results`` object is
        produced, so this returns None.
        """
        print(f"This will run {len(self.interviews)} interviews.")
|
@@ -1,54 +0,0 @@
|
|
1
|
-
import json
|
2
|
-
import threading
|
3
|
-
import time
|
4
|
-
import uuid
|
5
|
-
from edsl.data import CRUD
|
6
|
-
from edsl.jobs.JobsRunner import JobsRunner
|
7
|
-
from edsl.results import Results
|
8
|
-
|
9
|
-
|
10
|
-
class JobsRunnerStreaming(JobsRunner):
    """This JobRunner conducts interviews serially.

    ``run`` starts a background thread that conducts the interviews one after
    another, writing each result via ``CRUD.write_result`` as soon as it
    completes, and returns immediately with an empty ``Results`` placeholder.
    """

    runner_name = "streaming"

    def run(
        self,
        debug: bool = False,
        sleep: int = 0,
        n: int = 1,
        verbose: bool = False,
        progress_bar: bool = False,
    ) -> Results:
        """Start conducting the interviews **serially** on a background thread.

        Args:
            debug: passed to each ``interview.conduct_interview(debug=...)`` call.
            sleep: seconds to pause after each interview's result is written.
            n: accepted but currently ignored; each interview runs exactly once.
            verbose: accepted but currently ignored.
            progress_bar: accepted but currently ignored.

        Returns:
            A ``Results`` with ``data=[]``, a fresh ``job_uuid`` and the expected
            ``total_results``. The actual rows are written separately via
            ``CRUD.write_result`` (keyed by ``job_uuid`` and the interview index).

        The worker thread is stored on ``self.interview_thread``; call its
        ``join()`` to wait until every result has been written.
        """
        job_uuid = str(uuid.uuid4())
        total_results = len(self.interviews)

        def conduct_and_save_interviews():
            # Runs on the worker thread; result_uuid is the interview's position.
            for i, interview in enumerate(self.interviews):
                answer = interview.conduct_interview(debug=debug)
                CRUD.write_result(
                    job_uuid=job_uuid,
                    result_uuid=str(i),
                    agent=json.dumps(interview.agent.to_dict()),
                    scenario=json.dumps(interview.scenario.to_dict()),
                    model=json.dumps(interview.model.to_dict()),
                    answer=json.dumps(answer),
                )
                time.sleep(sleep)

        # Keep a handle on the worker so callers can join() it.
        self.interview_thread = threading.Thread(target=conduct_and_save_interviews)
        self.interview_thread.start()

        return Results(
            survey=self.jobs.survey,
            data=[],
            job_uuid=job_uuid,
            total_results=total_results,
        )
|
edsl/jobs/task_management.py
DELETED
@@ -1,218 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
import enum
|
3
|
-
from typing import Callable
|
4
|
-
from collections import UserDict, UserList
|
5
|
-
|
6
|
-
from edsl.jobs.buckets import ModelBuckets
|
7
|
-
from edsl.jobs.token_tracking import TokenUsage
|
8
|
-
from edsl.questions import Question
|
9
|
-
|
10
|
-
from edsl.exceptions import InterviewErrorPriorTaskCanceled
|
11
|
-
|
12
|
-
|
13
|
-
class TaskStatus(enum.Enum):
    """These are the possible statuses for a task."""

    NOT_STARTED = enum.auto()
    WAITING_ON_DEPENDENCIES = enum.auto()
    CANCELLED = enum.auto()
    PARENT_FAILED = enum.auto()
    DEPENDENCIES_COMPLETE = enum.auto()
    WAITING_FOR_REQUEST_CAPCITY = enum.auto()
    REQUEST_CAPACITY_ACQUIRED = enum.auto()
    WAITING_FOR_TOKEN_CAPCITY = enum.auto()
    TOKEN_CAPACITY_ACQUIRED = enum.auto()
    API_CALL_IN_PROGRESS = enum.auto()
    API_CALL_COMPLETE = enum.auto()

    # Correctly spelled aliases for the two misspelled members above. The old
    # names stay for backward compatibility. An alias shares its member's value,
    # so it is the same object and is skipped when iterating TaskStatus. The
    # aliases are defined last so they do not shift the auto() numbering.
    WAITING_FOR_REQUEST_CAPACITY = WAITING_FOR_REQUEST_CAPCITY
    WAITING_FOR_TOKEN_CAPACITY = WAITING_FOR_TOKEN_CAPCITY
|
26
|
-
|
27
|
-
|
28
|
-
class InterviewStatusDictionary(UserDict):
    """Counts keyed by every ``TaskStatus`` plus a ``"number_from_cache"`` tally.

    Two instances can be summed with ``+``; the result adds the counts key by key.
    """

    def __init__(self, data=None):
        if not data:
            # No (or empty) data: start every counter at zero.
            counts = dict.fromkeys(TaskStatus, 0)
            counts["number_from_cache"] = 0
            super().__init__(counts)
            return
        # Supplied data must already contain an entry for every TaskStatus.
        assert not [status for status in TaskStatus if status not in data]
        super().__init__(data)

    def __add__(
        self, other: "InterviewStatusDictionary"
    ) -> "InterviewStatusDictionary":
        if isinstance(other, InterviewStatusDictionary):
            summed = {key: self[key] + other[key] for key in self.keys()}
            return InterviewStatusDictionary(summed)
        raise ValueError(f"Can't add {type(other)} to InterviewStatusDictionary")

    def __repr__(self):
        return f"InterviewStatusDictionary({self.data})"
|
52
|
-
|
53
|
-
|
54
|
-
# Configure logging
|
55
|
-
# logging.basicConfig(level=logging.INFO)
|
56
|
-
|
57
|
-
|
58
|
-
class TaskStatusDescriptor:
    """Data descriptor that validates and stores a ``TaskStatus`` per owner instance.

    The descriptor object is a single class attribute shared by every instance
    of the owning class. The value is therefore stored on each instance under a
    private name, not on the descriptor itself; storing it on the descriptor
    would make all instances share one status.
    """

    def __init__(self):
        # Name of the per-instance attribute holding the value. __set_name__
        # refines it when the descriptor is assigned in a class body.
        self._storage_name = "_task_status"

    def __set_name__(self, owner, name):
        self._storage_name = f"_{name}"

    def __get__(self, instance, owner):
        if instance is None:
            # Accessed on the class itself: return the descriptor, as is conventional.
            return self
        # None until a status has been set on this instance.
        return getattr(instance, self._storage_name, None)

    def __set__(self, instance, value):
        if not isinstance(value, TaskStatus):
            raise ValueError("Value must be an instance of TaskStatus enum")
        setattr(instance, self._storage_name, value)

    def __delete__(self, instance):
        # Deleting resets the status to None rather than removing it, so later reads still work.
        setattr(instance, self._storage_name, None)
|
73
|
-
|
74
|
-
|
75
|
-
class QuestionTaskCreator(UserList):
    """Create and manage the asyncio task that answers one question.

    The instance is a UserList of the tasks that must complete before the focal
    question can run. ``generate_task`` returns an ``asyncio.Task`` that does the
    following, recording progress in ``task_status``:
    1. awaits those dependencies;
    2. acquires capacity from the model's token and request buckets;
    3. calls ``answer_question_func``.
    """

    task_status = TaskStatusDescriptor()

    def __init__(
        self,
        *,
        question: Question,
        answer_question_func: Callable,
        model_buckets: ModelBuckets,
        token_estimator: Callable = None,
    ):
        """Set up the creator for ``question``.

        Args:
            question: the question to answer; ``question.question_name`` names the task.
            answer_question_func: async callable awaited as
                ``answer_question_func(question, debug)``.
            model_buckets: supplies ``requests_bucket`` and ``tokens_bucket``.
            token_estimator: callable ``token_estimator(question) -> int``. It
                defaults to None, but ``estimated_tokens`` calls it
                unconditionally, so it must be supplied before the task runs.
        """
        super().__init__([])
        self.answer_question_func = answer_question_func
        self.question = question

        self.model_buckets = model_buckets
        self.requests_bucket = self.model_buckets.requests_bucket
        self.tokens_bucket = self.model_buckets.tokens_bucket
        self.token_estimator = token_estimator

        self.from_cache = False
        # Set True by _run_focal_task when it has to wait for request capacity;
        # initialised here so the attribute always exists.
        self.waiting = False

        self.cached_token_usage = TokenUsage(from_cache=True)
        self.new_token_usage = TokenUsage(from_cache=False)

        self.task_status = TaskStatus.NOT_STARTED

    def add_dependency(self, task) -> None:
        """Add a task that must finish before the focal task can run."""
        self.append(task)

    def __repr__(self):
        return f"QuestionTaskCreator for {self.question.question_name}"

    def generate_task(self, debug) -> asyncio.Task:
        """Schedule the focal task; it first awaits the registered dependencies.

        The returned task is tagged with ``edsl_name`` (this question's name) and
        ``depends_on`` (the ``edsl_name`` of each dependency).
        """
        task = asyncio.create_task(self._run_task_async(debug))
        task.edsl_name = self.question.question_name
        task.depends_on = [x.edsl_name for x in self]
        return task

    def estimated_tokens(self) -> int:
        """Estimate the tokens needed to run the focal task via ``token_estimator``."""
        return self.token_estimator(self.question)

    def token_usage(self) -> dict:
        """Return the cached and newly incurred token usage trackers."""
        return {
            "cached_tokens": self.cached_token_usage,
            "new_tokens": self.new_token_usage,
        }

    async def _run_focal_task(self, debug) -> "Answers":
        """Run the question we actually want answered.

        Called only after all dependency tasks have completed. Reserves the
        estimated tokens and one request from the buckets, calls the answer
        function, and refunds the reservation if the answer came from cache.
        """
        requested_tokens = self.estimated_tokens()
        if self.tokens_bucket.wait_time(requested_tokens) > 0:
            self.task_status = TaskStatus.WAITING_FOR_TOKEN_CAPCITY
        await self.tokens_bucket.get_tokens(requested_tokens)
        self.task_status = TaskStatus.TOKEN_CAPACITY_ACQUIRED

        if self.requests_bucket.wait_time(1) > 0:
            self.waiting = True
            self.task_status = TaskStatus.WAITING_FOR_REQUEST_CAPCITY
        await self.requests_bucket.get_tokens(1)
        self.task_status = TaskStatus.REQUEST_CAPACITY_ACQUIRED

        self.task_status = TaskStatus.API_CALL_IN_PROGRESS
        results = await self.answer_question_func(self.question, debug)
        self.task_status = TaskStatus.API_CALL_COMPLETE

        if results.get("cached_response"):
            # A cached answer made no API call: give back the reserved capacity.
            self.tokens_bucket.add_tokens(requested_tokens)
            self.requests_bucket.add_tokens(1)
            self.from_cache = True

        tracker = self.cached_token_usage if self.from_cache else self.new_token_usage

        # TODO: the answer function should always return a 'usage' entry; until then
        # treat a missing or None usage as zero tokens.
        usage = results.get("usage") or {}
        tracker.add_tokens(
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
        )

        return results

    async def _run_task_async(self, debug) -> "Answers":
        """Await the dependency tasks, then run and return the focal task's result.

        Raises:
            asyncio.CancelledError: re-raised after recording CANCELLED status.
            InterviewErrorPriorTaskCanceled: if a dependency task raised; the
                original exception is chained as ``__cause__``.
        """
        try:
            # No return_exceptions=True: the first failing dependency raises
            # here immediately, which is what we want.
            self.task_status = TaskStatus.WAITING_ON_DEPENDENCIES
            await asyncio.gather(*self)
        except asyncio.CancelledError:
            # Must go through the task_status descriptor, which the rest of the
            # class reads; a plain ``status`` attribute would be invisible to it.
            self.task_status = TaskStatus.CANCELLED
            raise
        except Exception as e:
            self.task_status = TaskStatus.PARENT_FAILED
            # Convert the dependency's failure into a custom exception, keeping
            # the original exception as its cause.
            raise InterviewErrorPriorTaskCanceled(
                f"Required tasks failed for {self.question.question_name}"
            ) from e
        else:
            self.task_status = TaskStatus.DEPENDENCIES_COMPLETE
            return await self._run_focal_task(debug)
|
202
|
-
|
203
|
-
|
204
|
-
class TasksList(UserList):
    """A list of asyncio tasks with a debugging status printout."""

    def status(self, debug=False):
        """Print a per-task summary when ``debug`` is True; otherwise do nothing.

        Each task is expected to carry the ``edsl_name`` and ``depends_on``
        attributes that ``QuestionTaskCreator.generate_task`` attaches.

        A task that finished by raising is reported on an ``EXCEPTION:`` line.
        Calling ``task.result()`` on it would re-raise the exception and abort
        the rest of the printout.
        """
        if not debug:
            return
        for task in self:
            print(f"Task {task.edsl_name}")
            print(f"\t DEPENDS ON: {task.depends_on}")
            print(f"\t DONE: {task.done()}")
            print(f"\t CANCELLED: {task.cancelled()}")
            if not task.cancelled():
                if not task.done():
                    print("\t RESULT: None - Not done yet")
                elif (error := task.exception()) is not None:
                    print(f"\t EXCEPTION: {error!r}")
                else:
                    print(f"\t RESULT: {task.result()}")
            # NOTE(review): the source's indentation was lost; this separator is
            # assumed to be printed once per task. Confirm against the original.
            print("---------------------")
|