edsl 0.1.47__py3-none-any.whl → 0.1.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +44 -39
- edsl/__version__.py +1 -1
- edsl/agents/__init__.py +4 -2
- edsl/agents/{Agent.py → agent.py} +442 -152
- edsl/agents/{AgentList.py → agent_list.py} +220 -162
- edsl/agents/descriptors.py +46 -7
- edsl/{exceptions/agents.py → agents/exceptions.py} +3 -12
- edsl/base/__init__.py +75 -0
- edsl/base/base_class.py +1303 -0
- edsl/base/data_transfer_models.py +114 -0
- edsl/base/enums.py +215 -0
- edsl/base.py +8 -0
- edsl/buckets/__init__.py +25 -0
- edsl/buckets/bucket_collection.py +324 -0
- edsl/buckets/model_buckets.py +206 -0
- edsl/buckets/token_bucket.py +502 -0
- edsl/{jobs/buckets/TokenBucketAPI.py → buckets/token_bucket_api.py} +1 -1
- edsl/buckets/token_bucket_client.py +509 -0
- edsl/caching/__init__.py +20 -0
- edsl/caching/cache.py +814 -0
- edsl/caching/cache_entry.py +427 -0
- edsl/{data/CacheHandler.py → caching/cache_handler.py} +14 -15
- edsl/caching/exceptions.py +24 -0
- edsl/caching/orm.py +30 -0
- edsl/{data/RemoteCacheSync.py → caching/remote_cache_sync.py} +3 -3
- edsl/caching/sql_dict.py +441 -0
- edsl/config/__init__.py +8 -0
- edsl/config/config_class.py +177 -0
- edsl/config.py +4 -176
- edsl/conversation/Conversation.py +7 -7
- edsl/conversation/car_buying.py +4 -4
- edsl/conversation/chips.py +6 -6
- edsl/coop/__init__.py +25 -2
- edsl/coop/coop.py +311 -75
- edsl/coop/{ExpectedParrotKeyHandler.py → ep_key_handling.py} +86 -10
- edsl/coop/exceptions.py +62 -0
- edsl/coop/price_fetcher.py +126 -0
- edsl/coop/utils.py +89 -24
- edsl/data_transfer_models.py +5 -72
- edsl/dataset/__init__.py +10 -0
- edsl/{results/Dataset.py → dataset/dataset.py} +116 -36
- edsl/{results/DatasetExportMixin.py → dataset/dataset_operations_mixin.py} +606 -122
- edsl/{results/DatasetTree.py → dataset/dataset_tree.py} +156 -75
- edsl/{results/TableDisplay.py → dataset/display/table_display.py} +18 -7
- edsl/{results → dataset/display}/table_renderers.py +58 -2
- edsl/{results → dataset}/file_exports.py +4 -5
- edsl/{results → dataset}/smart_objects.py +2 -2
- edsl/enums.py +5 -205
- edsl/inference_services/__init__.py +5 -0
- edsl/inference_services/{AvailableModelCacheHandler.py → available_model_cache_handler.py} +2 -3
- edsl/inference_services/{AvailableModelFetcher.py → available_model_fetcher.py} +8 -14
- edsl/inference_services/data_structures.py +3 -2
- edsl/{exceptions/inference_services.py → inference_services/exceptions.py} +1 -1
- edsl/inference_services/{InferenceServiceABC.py → inference_service_abc.py} +1 -1
- edsl/inference_services/{InferenceServicesCollection.py → inference_services_collection.py} +8 -7
- edsl/inference_services/registry.py +4 -41
- edsl/inference_services/{ServiceAvailability.py → service_availability.py} +5 -25
- edsl/inference_services/services/__init__.py +31 -0
- edsl/inference_services/{AnthropicService.py → services/anthropic_service.py} +3 -3
- edsl/inference_services/{AwsBedrock.py → services/aws_bedrock.py} +2 -2
- edsl/inference_services/{AzureAI.py → services/azure_ai.py} +2 -2
- edsl/inference_services/{DeepInfraService.py → services/deep_infra_service.py} +1 -3
- edsl/inference_services/{DeepSeekService.py → services/deep_seek_service.py} +2 -4
- edsl/inference_services/{GoogleService.py → services/google_service.py} +5 -4
- edsl/inference_services/{GroqService.py → services/groq_service.py} +1 -1
- edsl/inference_services/{MistralAIService.py → services/mistral_ai_service.py} +3 -3
- edsl/inference_services/{OllamaService.py → services/ollama_service.py} +1 -7
- edsl/inference_services/{OpenAIService.py → services/open_ai_service.py} +5 -6
- edsl/inference_services/{PerplexityService.py → services/perplexity_service.py} +3 -7
- edsl/inference_services/{TestService.py → services/test_service.py} +7 -6
- edsl/inference_services/{TogetherAIService.py → services/together_ai_service.py} +2 -6
- edsl/inference_services/{XAIService.py → services/xai_service.py} +1 -1
- edsl/inference_services/write_available.py +1 -2
- edsl/instructions/__init__.py +6 -0
- edsl/{surveys/instructions/Instruction.py → instructions/instruction.py} +11 -6
- edsl/{surveys/instructions/InstructionCollection.py → instructions/instruction_collection.py} +10 -5
- edsl/{surveys/InstructionHandler.py → instructions/instruction_handler.py} +3 -3
- edsl/{jobs/interviews → interviews}/ReportErrors.py +2 -2
- edsl/interviews/__init__.py +4 -0
- edsl/{jobs/AnswerQuestionFunctionConstructor.py → interviews/answering_function.py} +45 -18
- edsl/{jobs/interviews/InterviewExceptionEntry.py → interviews/exception_tracking.py} +107 -22
- edsl/interviews/interview.py +638 -0
- edsl/{jobs/interviews/InterviewStatusDictionary.py → interviews/interview_status_dictionary.py} +21 -12
- edsl/{jobs/interviews/InterviewStatusLog.py → interviews/interview_status_log.py} +16 -7
- edsl/{jobs/InterviewTaskManager.py → interviews/interview_task_manager.py} +12 -7
- edsl/{jobs/RequestTokenEstimator.py → interviews/request_token_estimator.py} +8 -3
- edsl/{jobs/interviews/InterviewStatistic.py → interviews/statistics.py} +36 -10
- edsl/invigilators/__init__.py +38 -0
- edsl/invigilators/invigilator_base.py +477 -0
- edsl/{agents/Invigilator.py → invigilators/invigilators.py} +263 -10
- edsl/invigilators/prompt_constructor.py +476 -0
- edsl/{agents → invigilators}/prompt_helpers.py +2 -1
- edsl/{agents/QuestionInstructionPromptBuilder.py → invigilators/question_instructions_prompt_builder.py} +18 -13
- edsl/{agents → invigilators}/question_option_processor.py +96 -21
- edsl/{agents/QuestionTemplateReplacementsBuilder.py → invigilators/question_template_replacements_builder.py} +64 -12
- edsl/jobs/__init__.py +7 -1
- edsl/jobs/async_interview_runner.py +99 -35
- edsl/jobs/check_survey_scenario_compatibility.py +7 -5
- edsl/jobs/data_structures.py +153 -22
- edsl/{exceptions/jobs.py → jobs/exceptions.py} +2 -1
- edsl/jobs/{FetchInvigilator.py → fetch_invigilator.py} +4 -4
- edsl/jobs/{loggers/HTMLTableJobLogger.py → html_table_job_logger.py} +6 -2
- edsl/jobs/{Jobs.py → jobs.py} +313 -167
- edsl/jobs/{JobsChecks.py → jobs_checks.py} +15 -7
- edsl/jobs/{JobsComponentConstructor.py → jobs_component_constructor.py} +19 -17
- edsl/jobs/{InterviewsConstructor.py → jobs_interview_constructor.py} +10 -5
- edsl/jobs/jobs_pricing_estimation.py +347 -0
- edsl/jobs/{JobsRemoteInferenceLogger.py → jobs_remote_inference_logger.py} +4 -3
- edsl/jobs/jobs_runner_asyncio.py +282 -0
- edsl/jobs/{JobsRemoteInferenceHandler.py → remote_inference.py} +19 -22
- edsl/jobs/results_exceptions_handler.py +2 -2
- edsl/key_management/__init__.py +28 -0
- edsl/key_management/key_lookup.py +161 -0
- edsl/{language_models/key_management/KeyLookupBuilder.py → key_management/key_lookup_builder.py} +118 -47
- edsl/key_management/key_lookup_collection.py +82 -0
- edsl/key_management/models.py +218 -0
- edsl/language_models/__init__.py +7 -2
- edsl/language_models/{ComputeCost.py → compute_cost.py} +18 -3
- edsl/{exceptions/language_models.py → language_models/exceptions.py} +2 -1
- edsl/language_models/language_model.py +1080 -0
- edsl/language_models/model.py +10 -25
- edsl/language_models/{ModelList.py → model_list.py} +9 -14
- edsl/language_models/{RawResponseHandler.py → raw_response_handler.py} +1 -1
- edsl/language_models/{RegisterLanguageModelsMeta.py → registry.py} +1 -1
- edsl/language_models/repair.py +4 -4
- edsl/language_models/utilities.py +4 -4
- edsl/notebooks/__init__.py +3 -1
- edsl/notebooks/{Notebook.py → notebook.py} +7 -8
- edsl/prompts/__init__.py +1 -1
- edsl/{exceptions/prompts.py → prompts/exceptions.py} +3 -1
- edsl/prompts/{Prompt.py → prompt.py} +101 -95
- edsl/questions/HTMLQuestion.py +1 -1
- edsl/questions/__init__.py +154 -25
- edsl/questions/answer_validator_mixin.py +1 -1
- edsl/questions/compose_questions.py +4 -3
- edsl/questions/derived/question_likert_five.py +166 -0
- edsl/questions/derived/{QuestionLinearScale.py → question_linear_scale.py} +4 -4
- edsl/questions/derived/{QuestionTopK.py → question_top_k.py} +4 -4
- edsl/questions/derived/{QuestionYesNo.py → question_yes_no.py} +4 -5
- edsl/questions/descriptors.py +24 -30
- edsl/questions/loop_processor.py +65 -19
- edsl/questions/question_base.py +881 -0
- edsl/questions/question_base_gen_mixin.py +15 -16
- edsl/questions/{QuestionBasePromptsMixin.py → question_base_prompts_mixin.py} +2 -2
- edsl/questions/{QuestionBudget.py → question_budget.py} +3 -4
- edsl/questions/{QuestionCheckBox.py → question_check_box.py} +16 -16
- edsl/questions/{QuestionDict.py → question_dict.py} +39 -5
- edsl/questions/{QuestionExtract.py → question_extract.py} +9 -9
- edsl/questions/question_free_text.py +282 -0
- edsl/questions/{QuestionFunctional.py → question_functional.py} +6 -5
- edsl/questions/{QuestionList.py → question_list.py} +6 -7
- edsl/questions/{QuestionMatrix.py → question_matrix.py} +6 -5
- edsl/questions/{QuestionMultipleChoice.py → question_multiple_choice.py} +126 -21
- edsl/questions/{QuestionNumerical.py → question_numerical.py} +5 -5
- edsl/questions/{QuestionRank.py → question_rank.py} +6 -6
- edsl/questions/question_registry.py +4 -9
- edsl/questions/register_questions_meta.py +8 -4
- edsl/questions/response_validator_abc.py +17 -16
- edsl/results/__init__.py +4 -1
- edsl/{exceptions/results.py → results/exceptions.py} +1 -1
- edsl/results/report.py +197 -0
- edsl/results/{Result.py → result.py} +131 -45
- edsl/results/{Results.py → results.py} +365 -220
- edsl/results/results_selector.py +344 -25
- edsl/scenarios/__init__.py +30 -3
- edsl/scenarios/{ConstructDownloadLink.py → construct_download_link.py} +7 -0
- edsl/scenarios/directory_scanner.py +156 -13
- edsl/scenarios/document_chunker.py +186 -0
- edsl/scenarios/exceptions.py +101 -0
- edsl/scenarios/file_methods.py +2 -3
- edsl/scenarios/{FileStore.py → file_store.py} +275 -189
- edsl/scenarios/handlers/__init__.py +14 -14
- edsl/scenarios/handlers/{csv.py → csv_file_store.py} +1 -2
- edsl/scenarios/handlers/{docx.py → docx_file_store.py} +8 -7
- edsl/scenarios/handlers/{html.py → html_file_store.py} +1 -2
- edsl/scenarios/handlers/{jpeg.py → jpeg_file_store.py} +1 -1
- edsl/scenarios/handlers/{json.py → json_file_store.py} +1 -1
- edsl/scenarios/handlers/latex_file_store.py +5 -0
- edsl/scenarios/handlers/{md.py → md_file_store.py} +1 -1
- edsl/scenarios/handlers/{pdf.py → pdf_file_store.py} +2 -2
- edsl/scenarios/handlers/{png.py → png_file_store.py} +1 -1
- edsl/scenarios/handlers/{pptx.py → pptx_file_store.py} +8 -7
- edsl/scenarios/handlers/{py.py → py_file_store.py} +1 -3
- edsl/scenarios/handlers/{sql.py → sql_file_store.py} +2 -1
- edsl/scenarios/handlers/{sqlite.py → sqlite_file_store.py} +2 -3
- edsl/scenarios/handlers/{txt.py → txt_file_store.py} +1 -1
- edsl/scenarios/scenario.py +928 -0
- edsl/scenarios/scenario_join.py +18 -5
- edsl/scenarios/{ScenarioList.py → scenario_list.py} +294 -106
- edsl/scenarios/{ScenarioListPdfMixin.py → scenario_list_pdf_tools.py} +16 -15
- edsl/scenarios/scenario_selector.py +5 -1
- edsl/study/ObjectEntry.py +2 -2
- edsl/study/SnapShot.py +5 -5
- edsl/study/Study.py +18 -19
- edsl/study/__init__.py +6 -4
- edsl/surveys/__init__.py +7 -4
- edsl/surveys/dag/__init__.py +2 -0
- edsl/surveys/{ConstructDAG.py → dag/construct_dag.py} +3 -3
- edsl/surveys/{DAG.py → dag/dag.py} +13 -10
- edsl/surveys/descriptors.py +1 -1
- edsl/surveys/{EditSurvey.py → edit_survey.py} +9 -9
- edsl/{exceptions/surveys.py → surveys/exceptions.py} +1 -2
- edsl/surveys/memory/__init__.py +3 -0
- edsl/surveys/{MemoryPlan.py → memory/memory_plan.py} +10 -9
- edsl/surveys/rules/__init__.py +3 -0
- edsl/surveys/{Rule.py → rules/rule.py} +103 -43
- edsl/surveys/{RuleCollection.py → rules/rule_collection.py} +21 -30
- edsl/surveys/{RuleManager.py → rules/rule_manager.py} +19 -13
- edsl/surveys/survey.py +1743 -0
- edsl/surveys/{SurveyExportMixin.py → survey_export.py} +22 -27
- edsl/surveys/{SurveyFlowVisualization.py → survey_flow_visualization.py} +11 -2
- edsl/surveys/{Simulator.py → survey_simulator.py} +10 -3
- edsl/tasks/__init__.py +32 -0
- edsl/{jobs/tasks/QuestionTaskCreator.py → tasks/question_task_creator.py} +115 -57
- edsl/tasks/task_creators.py +135 -0
- edsl/{jobs/tasks/TaskHistory.py → tasks/task_history.py} +86 -47
- edsl/{jobs/tasks → tasks}/task_status_enum.py +91 -7
- edsl/tasks/task_status_log.py +85 -0
- edsl/tokens/__init__.py +2 -0
- edsl/tokens/interview_token_usage.py +53 -0
- edsl/utilities/PrettyList.py +1 -1
- edsl/utilities/SystemInfo.py +25 -22
- edsl/utilities/__init__.py +29 -21
- edsl/utilities/gcp_bucket/__init__.py +2 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +99 -96
- edsl/utilities/interface.py +44 -536
- edsl/{results/MarkdownToPDF.py → utilities/markdown_to_pdf.py} +13 -5
- edsl/utilities/repair_functions.py +1 -1
- {edsl-0.1.47.dist-info → edsl-0.1.49.dist-info}/METADATA +1 -1
- edsl-0.1.49.dist-info/RECORD +347 -0
- edsl/Base.py +0 -493
- edsl/BaseDiff.py +0 -260
- edsl/agents/InvigilatorBase.py +0 -260
- edsl/agents/PromptConstructor.py +0 -318
- edsl/coop/PriceFetcher.py +0 -54
- edsl/data/Cache.py +0 -582
- edsl/data/CacheEntry.py +0 -238
- edsl/data/SQLiteDict.py +0 -292
- edsl/data/__init__.py +0 -5
- edsl/data/orm.py +0 -10
- edsl/exceptions/cache.py +0 -5
- edsl/exceptions/coop.py +0 -14
- edsl/exceptions/data.py +0 -14
- edsl/exceptions/scenarios.py +0 -29
- edsl/jobs/Answers.py +0 -43
- edsl/jobs/JobsPrompts.py +0 -354
- edsl/jobs/buckets/BucketCollection.py +0 -134
- edsl/jobs/buckets/ModelBuckets.py +0 -65
- edsl/jobs/buckets/TokenBucket.py +0 -283
- edsl/jobs/buckets/TokenBucketClient.py +0 -191
- edsl/jobs/interviews/Interview.py +0 -395
- edsl/jobs/interviews/InterviewExceptionCollection.py +0 -99
- edsl/jobs/interviews/InterviewStatisticsCollection.py +0 -25
- edsl/jobs/runners/JobsRunnerAsyncio.py +0 -163
- edsl/jobs/runners/JobsRunnerStatusData.py +0 -0
- edsl/jobs/tasks/TaskCreators.py +0 -64
- edsl/jobs/tasks/TaskStatusLog.py +0 -23
- edsl/jobs/tokens/InterviewTokenUsage.py +0 -27
- edsl/language_models/LanguageModel.py +0 -635
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/key_management/KeyLookup.py +0 -63
- edsl/language_models/key_management/KeyLookupCollection.py +0 -38
- edsl/language_models/key_management/models.py +0 -137
- edsl/questions/QuestionBase.py +0 -544
- edsl/questions/QuestionFreeText.py +0 -130
- edsl/questions/derived/QuestionLikertFive.py +0 -76
- edsl/results/ResultsExportMixin.py +0 -45
- edsl/results/TextEditor.py +0 -50
- edsl/results/results_fetch_mixin.py +0 -33
- edsl/results/results_tools_mixin.py +0 -98
- edsl/scenarios/DocumentChunker.py +0 -104
- edsl/scenarios/Scenario.py +0 -548
- edsl/scenarios/ScenarioHtmlMixin.py +0 -65
- edsl/scenarios/ScenarioListExportMixin.py +0 -45
- edsl/scenarios/handlers/latex.py +0 -5
- edsl/shared.py +0 -1
- edsl/surveys/Survey.py +0 -1301
- edsl/surveys/SurveyQualtricsImport.py +0 -284
- edsl/surveys/SurveyToApp.py +0 -141
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/tools/__init__.py +0 -1
- edsl/tools/clusters.py +0 -192
- edsl/tools/embeddings.py +0 -27
- edsl/tools/embeddings_plotting.py +0 -118
- edsl/tools/plotting.py +0 -112
- edsl/tools/summarize.py +0 -18
- edsl/utilities/data/Registry.py +0 -6
- edsl/utilities/data/__init__.py +0 -1
- edsl/utilities/data/scooter_results.json +0 -1
- edsl-0.1.47.dist-info/RECORD +0 -354
- /edsl/coop/{CoopFunctionsMixin.py → coop_functions.py} +0 -0
- /edsl/{results → dataset/display}/CSSParameterizer.py +0 -0
- /edsl/{language_models/key_management → dataset/display}/__init__.py +0 -0
- /edsl/{results → dataset/display}/table_data_class.py +0 -0
- /edsl/{results → dataset/display}/table_display.css +0 -0
- /edsl/{results/ResultsGGMixin.py → dataset/r/ggplot.py} +0 -0
- /edsl/{results → dataset}/tree_explore.py +0 -0
- /edsl/{surveys/instructions/ChangeInstruction.py → instructions/change_instruction.py} +0 -0
- /edsl/{jobs/interviews → interviews}/interview_status_enum.py +0 -0
- /edsl/jobs/{runners/JobsRunnerStatus.py → jobs_runner_status.py} +0 -0
- /edsl/language_models/{PriceManager.py → price_manager.py} +0 -0
- /edsl/language_models/{fake_openai_call.py → unused/fake_openai_call.py} +0 -0
- /edsl/language_models/{fake_openai_service.py → unused/fake_openai_service.py} +0 -0
- /edsl/notebooks/{NotebookToLaTeX.py → notebook_to_latex.py} +0 -0
- /edsl/{exceptions/questions.py → questions/exceptions.py} +0 -0
- /edsl/questions/{SimpleAskMixin.py → simple_ask_mixin.py} +0 -0
- /edsl/surveys/{Memory.py → memory/memory.py} +0 -0
- /edsl/surveys/{MemoryManagement.py → memory/memory_management.py} +0 -0
- /edsl/surveys/{SurveyCSS.py → survey_css.py} +0 -0
- /edsl/{jobs/tokens/TokenUsage.py → tokens/token_usage.py} +0 -0
- /edsl/{results/MarkdownToDocx.py → utilities/markdown_to_docx.py} +0 -0
- /edsl/{TemplateLoader.py → utilities/template_loader.py} +0 -0
- {edsl-0.1.47.dist-info → edsl-0.1.49.dist-info}/LICENSE +0 -0
- {edsl-0.1.47.dist-info → edsl-0.1.49.dist-info}/WHEEL +0 -0
edsl/base/data_transfer_models.py
ADDED
@@ -0,0 +1,114 @@
+from collections import UserDict
+from typing import NamedTuple, Dict, List, Optional, Any
+from dataclasses import dataclass, fields
+
+
+class ModelInputs(NamedTuple):
+    "This is what was sent by the agent to the model"
+    user_prompt: str
+    system_prompt: str
+    encoded_image: Optional[str] = None
+
+
+class EDSLOutput(NamedTuple):
+    "This is the edsl dictionary that is returned by the model"
+    answer: Any
+    generated_tokens: str
+    comment: Optional[str] = None
+
+
+class ModelResponse(NamedTuple):
+    "This is the metadata that is returned by the model and includes info about the cache"
+    response: dict
+    cache_used: bool
+    cache_key: str
+    cached_response: Optional[Dict[str, Any]] = None
+    cost: Optional[float] = None
+
+
+class AgentResponseDict(NamedTuple):
+    edsl_dict: EDSLOutput
+    model_inputs: ModelInputs
+    model_outputs: ModelResponse
+
+
+class EDSLResultObjectInput(NamedTuple):
+    generated_tokens: str
+    question_name: str
+    prompts: dict
+    cached_response: str
+    raw_model_response: str
+    cache_used: bool
+    cache_key: str
+    answer: Any
+    comment: str
+    validated: bool = False
+    exception_occurred: Exception = None
+    cost: Optional[float] = None
+
+
+@dataclass
+class ImageInfo:
+    file_path: str
+    file_name: str
+    image_format: str
+    file_size: int
+    encoded_image: str
+
+    def __repr__(self):
+        import reprlib
+
+        reprlib_instance = reprlib.Repr()
+        reprlib_instance.maxstring = 30  # Limit the string length for the encoded image
+
+        # Get all fields except encoded_image
+        field_reprs = [
+            f"{f.name}={getattr(self, f.name)!r}"
+            for f in fields(self)
+            if f.name != "encoded_image"
+        ]
+
+        # Add the reprlib-restricted encoded_image field
+        field_reprs.append(f"encoded_image={reprlib_instance.repr(self.encoded_image)}")
+
+        # Join everything to create the repr
+        return f"{self.__class__.__name__}({', '.join(field_reprs)})"
+
+
+class Answers(UserDict):
+    """Helper class to hold the answers to a survey."""
+
+    def add_answer(
+        self, response: EDSLResultObjectInput, question: "QuestionBase"
+    ) -> None:
+        """Add a response to the answers dictionary."""
+        answer = response.answer
+        comment = response.comment
+        generated_tokens = response.generated_tokens
+        # record the answer
+        if generated_tokens:
+            self[question.question_name + "_generated_tokens"] = generated_tokens
+        self[question.question_name] = answer
+        if comment:
+            self[question.question_name + "_comment"] = comment
+
+    def replace_missing_answers_with_none(self, survey: "Survey") -> None:
+        """Replace missing answers with None. Answers can be missing if the agent skips a question."""
+        for question_name in survey.question_names:
+            if question_name not in self:
+                self[question_name] = None
+
+    def to_dict(self):
+        """Return a dictionary of the answers."""
+        return self.data
+
+    @classmethod
+    def from_dict(cls, d):
+        """Return an Answers object from a dictionary."""
+        return cls(d)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()
edsl/base/enums.py
ADDED
@@ -0,0 +1,215 @@
+"""Enums for the different types of questions, language models, and inference services."""
+
+from enum import Enum
+from typing import Literal
+
+
+class EnumWithChecks(Enum):
+    """Base class for all enums with checks."""
+
+    @classmethod
+    def is_value_valid(cls, value):
+        """Check if the value is valid."""
+        return any(value == item.value for item in cls)
+
+
+class QuestionType(EnumWithChecks):
+    """Enum for the question types."""
+
+    MULTIPLE_CHOICE = "multiple_choice"
+    YES_NO = "yes_no"
+    FREE_TEXT = "free_text"
+    RANK = "rank"
+    BUDGET = "budget"
+    CHECKBOX = "checkbox"
+    EXTRACT = "extract"
+    FUNCTIONAL = "functional"
+    LIST = "list"
+    NUMERICAL = "numerical"
+    TOP_K = "top_k"
+    LIKERT_FIVE = "likert_five"
+    LINEAR_SCALE = "linear_scale"
+
+
+# https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
+
+
+# class LanguageModelType(EnumWithChecks):
+#     """Enum for the language model types."""
+
+#     GPT_4 = "gpt-4-1106-preview"
+#     GPT_3_5_Turbo = "gpt-3.5-turbo"
+#     LLAMA_2_70B_CHAT_HF = "llama-2-70b-chat-hf"
+#     LLAMA_2_13B_CHAT_HF = "llama-2-13b-chat-hf"
+#     GEMINI_PRO = "gemini_pro"
+#     MIXTRAL_8x7B_INSTRUCT = "mixtral-8x7B-instruct-v0.1"
+#     TEST = "test"
+#     ANTHROPIC_3_OPUS = "claude-3-opus-20240229"
+#     ANTHROPIC_3_SONNET = "claude-3-sonnet-20240229"
+#     ANTHROPIC_3_HAIKU = "claude-3-haiku-20240307"
+#     DBRX_INSTRUCT = "dbrx-instruct"
+
+
+class InferenceServiceType(EnumWithChecks):
+    """Enum for the inference service types."""
+
+    BEDROCK = "bedrock"
+    DEEP_INFRA = "deep_infra"
+    REPLICATE = "replicate"
+    OPENAI = "openai"
+    GOOGLE = "google"
+    TEST = "test"
+    ANTHROPIC = "anthropic"
+    GROQ = "groq"
+    AZURE = "azure"
+    OLLAMA = "ollama"
+    MISTRAL = "mistral"
+    TOGETHER = "together"
+    PERPLEXITY = "perplexity"
+    DEEPSEEK = "deepseek"
+    XAI = "xai"
+
+
+# unavoidable violation of the DRY principle but it is necessary
+# checked w/ a unit test to make sure consistent with services in enums.py
+InferenceServiceLiteral = Literal[
+    "bedrock",
+    "deep_infra",
+    "replicate",
+    "openai",
+    "google",
+    "test",
+    "anthropic",
+    "groq",
+    "azure",
+    "ollama",
+    "mistral",
+    "together",
+    "perplexity",
+    "deepseek",
+    "xai",
+]
+
+available_models_urls = {
+    "anthropic": "https://docs.anthropic.com/en/docs/about-claude/models",
+    "openai": "https://platform.openai.com/docs/models/gp",
+    "groq": "https://console.groq.com/docs/models",
+    "google": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models",
+}
+
+
+service_to_api_keyname = {
+    InferenceServiceType.DEEP_INFRA.value: "DEEP_INFRA_API_KEY",
+    InferenceServiceType.REPLICATE.value: "TBD",
+    InferenceServiceType.OPENAI.value: "OPENAI_API_KEY",
+    InferenceServiceType.GOOGLE.value: "GOOGLE_API_KEY",
+    InferenceServiceType.TEST.value: "TBD",
+    InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
+    InferenceServiceType.GROQ.value: "GROQ_API_KEY",
+    InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
+    InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
+    InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
+    InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
+    InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY",
+    InferenceServiceType.XAI.value: "XAI_API_KEY",
+}
+
+
+class TokenPricing:
+    def __init__(
+        self,
+        *,
+        model_name,
+        prompt_token_price_per_k: float,
+        completion_token_price_per_k: float,
+    ):
+        self.model_name = model_name
+        self.prompt_token_price = prompt_token_price_per_k / 1_000.0
+        self.completion_token_price = completion_token_price_per_k / 1_000.0
+
+    def __eq__(self, other):
+        if not isinstance(other, TokenPricing):
+            return False
+        return (
+            self.model_name == other.model_name
+            and self.prompt_token_price == other.prompt_token_price
+            and self.completion_token_price == other.completion_token_price
+        )
+
+    @classmethod
+    def example(cls) -> "TokenPricing":
+        """Return an example TokenPricing object."""
+        return cls(
+            model_name="fake_model",
+            prompt_token_price_per_k=0.01,
+            completion_token_price_per_k=0.03,
+        )
+
+pricing = {
+    "dbrx-instruct": TokenPricing(
+        model_name="dbrx-instruct",
+        prompt_token_price_per_k=0.0,
+        completion_token_price_per_k=0.0,
+    ),
+    "claude-3-opus-20240229": TokenPricing(
+        model_name="claude-3-opus-20240229",
+        prompt_token_price_per_k=0.0,
+        completion_token_price_per_k=0.0,
+    ),
+    "claude-3-haiku-20240307": TokenPricing(
+        model_name="claude-3-haiku-20240307",
+        prompt_token_price_per_k=0.0,
+        completion_token_price_per_k=0.0,
+    ),
+    "claude-3-sonnet-20240229": TokenPricing(
+        model_name="claude-3-sonnet-20240229",
+        prompt_token_price_per_k=0.0,
+        completion_token_price_per_k=0.0,
+    ),
+    "gpt-3.5-turbo": TokenPricing(
+        model_name="gpt-3.5-turbo",
+        prompt_token_price_per_k=0.0005,
+        completion_token_price_per_k=0.0015,
+    ),
+    "gpt-4-1106-preview": TokenPricing(
+        model_name="gpt-4",
+        prompt_token_price_per_k=0.01,
+        completion_token_price_per_k=0.03,
+    ),
+    "test": TokenPricing(
+        model_name="test",
+        prompt_token_price_per_k=0.0,
+        completion_token_price_per_k=0.0,
+    ),
+    "gemini_pro": TokenPricing(
+        model_name="gemini_pro",
+        prompt_token_price_per_k=0.0,
+        completion_token_price_per_k=0.0,
+    ),
+    "llama-2-13b-chat-hf": TokenPricing(
+        model_name="llama-2-13b-chat-hf",
+        prompt_token_price_per_k=0.0,
+        completion_token_price_per_k=0.0,
+    ),
+    "llama-2-70b-chat-hf": TokenPricing(
+        model_name="llama-2-70b-chat-hf",
+        prompt_token_price_per_k=0.0,
+        completion_token_price_per_k=0.0,
+    ),
+    "mixtral-8x7B-instruct-v0.1": TokenPricing(
+        model_name="mixtral-8x7B-instruct-v0.1",
+        prompt_token_price_per_k=0.0,
+        completion_token_price_per_k=0.0,
+    ),
+}
+
+
+def get_token_pricing(model_name):
+    if model_name in pricing:
+        return pricing[model_name]
+    else:
+        return TokenPricing(
+            model_name=model_name,
+            prompt_token_price_per_k=0.0,
+            completion_token_price_per_k=0.0,
+        )
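A short sketch of how these helpers behave, assuming the module is importable as edsl.base.enums (the new path above). is_value_valid checks a raw string against the enum's values, and get_token_pricing falls back to a zero-cost TokenPricing for unknown models; note the per-1k prices are divided by 1,000 at construction time:

# Illustrative sketch, not part of the diff.
from edsl.base.enums import InferenceServiceType, get_token_pricing

assert InferenceServiceType.is_value_valid("openai")
assert not InferenceServiceType.is_value_valid("not-a-service")

# Known model: per-token prices are the per-1k figures divided by 1,000.
tp = get_token_pricing("gpt-4-1106-preview")
print(tp.prompt_token_price)      # 0.01 / 1000 = 1e-05
print(tp.completion_token_price)  # 0.03 / 1000 = 3e-05

# Unknown model: a zero-cost fallback rather than a KeyError.
print(get_token_pricing("mystery-model").prompt_token_price)  # 0.0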
edsl/base.py
ADDED
edsl/buckets/__init__.py
ADDED
@@ -0,0 +1,25 @@
+"""
+Buckets module for managing rate limits of language model API requests.
+
+This module provides a robust rate-limiting system for language model API calls,
+implementing the token bucket algorithm to manage both requests-per-minute and
+tokens-per-minute limits. It supports both local (in-process) and remote
+(distributed) rate limiting through a client-server architecture.
+
+Key components:
+- TokenBucket: Core rate-limiting class implementing the token bucket algorithm
+- ModelBuckets: Manages rate limits for a specific language model, containing
+  separate buckets for requests and tokens
+- BucketCollection: Manages multiple ModelBuckets instances across different
+  language model services
+
+The module also includes a FastAPI server implementation (token_bucket_api) and
+client (token_bucket_client) for distributed rate limiting scenarios where
+multiple processes or machines need to share rate limits.
+"""
+
+from .bucket_collection import BucketCollection
+from .model_buckets import ModelBuckets
+from .token_bucket import TokenBucket
+
+__all__ = ["BucketCollection", "ModelBuckets", "TokenBucket"]
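The layering this docstring describes can be sketched as follows. This is illustrative only: it uses the TokenBucket and ModelBuckets constructor arguments that appear in bucket_collection.py below (token_bucket.py itself is not expanded in this diff), and the 8,000 RPM / 1,600,000 TPM figures are made-up example limits:

# Sketch, not part of the diff: one bucket pair per service, with
# per-second capacities derived from per-minute limits.
from edsl.buckets import BucketCollection, ModelBuckets, TokenBucket

requests_bucket = TokenBucket(
    bucket_name="openai",
    bucket_type="requests",
    capacity=8_000 / 60.0,        # made-up RPM limit, per-second
    refill_rate=8_000 / 60.0,
    remote_url=None,              # a URL here would enable distributed mode
)
tokens_bucket = TokenBucket(
    bucket_name="openai",
    bucket_type="tokens",
    capacity=1_600_000 / 60.0,    # made-up TPM limit, per-second
    refill_rate=1_600_000 / 60.0,
    remote_url=None,
)
model_buckets = ModelBuckets(requests_bucket, tokens_bucket)

# In practice BucketCollection.add_model() builds this pair for you and
# shares it across every model from the same service.
collection = BucketCollection()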
edsl/buckets/bucket_collection.py
ADDED
@@ -0,0 +1,324 @@
+"""
+BucketCollection module for managing rate limits across multiple language models.
+
+This module provides the BucketCollection class, which manages rate limits for
+multiple language models, organizing them by service provider. It ensures that
+API rate limits are respected while allowing models from the same service to
+share the same rate limit buckets.
+"""
+
+from typing import Optional, TYPE_CHECKING, Dict, List, Any, Tuple
+from collections import UserDict
+from threading import RLock
+import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
+
+from .token_bucket import TokenBucket
+from .model_buckets import ModelBuckets
+from ..jobs.decorators import synchronized_class
+
+if TYPE_CHECKING:
+    from ..language_models import LanguageModel
+    from ..key_management import KeyLookup
+
+@synchronized_class
+class BucketCollection(UserDict):
+    """
+    Collection of ModelBuckets for managing rate limits across multiple language models.
+
+    BucketCollection is a thread-safe dictionary-like container that maps language models
+    to their corresponding ModelBuckets objects. It helps manage rate limits for multiple
+    models, organizing them by service provider to ensure that API rate limits are
+    respected across all models using the same service.
+
+    The class maps models to services, and services to buckets, allowing models from
+    the same service to share rate limit buckets. This approach ensures accurate
+    rate limiting when multiple models use the same underlying service.
+
+    Attributes:
+        infinity_buckets (bool): If True, all buckets have infinite capacity and refill rate
+        models_to_services (dict): Maps model names to their service provider names
+        services_to_buckets (dict): Maps service names to their ModelBuckets instances
+        remote_url (str, optional): URL for remote token bucket server if using distributed mode
+
+    Example:
+        >>> from edsl import Model
+        >>> bucket_collection = BucketCollection()
+        >>> model = Model('gpt-4')
+        >>> bucket_collection.add_model(model)
+        >>> # Now rate limits for the model are being tracked
+    """
+
+    def __init__(self, infinity_buckets: bool = False):
+        """
+        Initialize a new BucketCollection.
+
+        Creates a new BucketCollection to manage rate limits across multiple language
+        models. If infinity_buckets is True, all buckets will have unlimited capacity
+        and refill rate, effectively bypassing rate limits.
+
+        Args:
+            infinity_buckets: If True, creates buckets with unlimited capacity
+                and refill rate (default: False)
+
+        Example:
+            >>> # Create a standard bucket collection with rate limiting
+            >>> bucket_collection = BucketCollection()
+            >>> # Create a bucket collection with unlimited capacity (for testing)
+            >>> unlimited_collection = BucketCollection(infinity_buckets=True)
+        """
+        super().__init__()
+        self.infinity_buckets = infinity_buckets
+        self.models_to_services = {}  # Maps model names to service names
+        self.services_to_buckets = {}  # Maps service names to ModelBuckets
+        self._lock = RLock()
+
+        # Check for remote token bucket server URL in environment
+        from edsl.config import CONFIG
+        import os
+
+        url = os.environ.get("EDSL_REMOTE_TOKEN_BUCKET_URL", None)
+
+        if url == "None" or url is None:
+            self.remote_url = None
+        else:
+            self.remote_url = url
+
+    @classmethod
+    def from_models(
+        cls, models_list: List["LanguageModel"], infinity_buckets: bool = False
+    ) -> "BucketCollection":
+        """
+        Create a BucketCollection pre-populated with a list of models.
+
+        This factory method creates a new BucketCollection and adds multiple
+        models to it at initialization time.
+
+        Args:
+            models_list: List of LanguageModel instances to add to the collection
+            infinity_buckets: If True, creates buckets with unlimited capacity
+                and refill rate (default: False)
+
+        Returns:
+            A new BucketCollection containing the specified models
+
+        Example:
+            >>> from edsl import Model
+            >>> models = [Model('gpt-4'), Model('gpt-3.5-turbo')]
+            >>> collection = BucketCollection.from_models(models)
+        """
+        bucket_collection = cls(infinity_buckets=infinity_buckets)
+        for model in models_list:
+            bucket_collection.add_model(model)
+        return bucket_collection
+
+    def get_tokens(
+        self, model: 'LanguageModel', bucket_type: str, num_tokens: int
+    ) -> int:
+        """
+        [DEPRECATED] Get the number of tokens remaining in the bucket.
+
+        This method is deprecated and will raise an exception if called.
+        It is kept for reference purposes only.
+
+        Args:
+            model: The language model to get tokens for
+            bucket_type: The type of bucket ('requests' or 'tokens')
+            num_tokens: The number of tokens to retrieve
+
+        Raises:
+            Exception: This method is deprecated
+
+        Example:
+            >>> bucket_collection = BucketCollection()
+            >>> from edsl import Model
+            >>> m = Model('test')
+            >>> bucket_collection.add_model(m)
+            >>> # The following would raise an exception:
+            >>> # bucket_collection.get_tokens(m, 'tokens', 10)
+        """
+        raise Exception("This method is deprecated and should not be used")
+        # The following code is kept for reference only
+        # relevant_bucket = getattr(self[model], bucket_type)
+        # return relevant_bucket.get_tokens(num_tokens)
+
+    def __repr__(self) -> str:
+        """
+        Generate a string representation of the BucketCollection.
+
+        Returns:
+            String representation showing the collection's contents
+        """
+        return f"BucketCollection({self.data})"
+
+    def add_model(self, model: "LanguageModel") -> None:
+        """
+        Add a language model to the bucket collection.
+
+        This method adds a language model to the BucketCollection, creating the
+        necessary token buckets for its service provider if they don't already exist.
+        Models from the same service share the same buckets.
+
+        Args:
+            model: The LanguageModel instance to add to the collection
+
+        Example:
+            >>> from edsl import Model
+            >>> model = Model('gpt-4')
+            >>> bucket_collection = BucketCollection()
+            >>> bucket_collection.add_model(model)
+        """
+        # Calculate tokens-per-second (TPS) and requests-per-second (RPS) rates
+        if not self.infinity_buckets:
+            TPS = model.tpm / 60.0  # Convert tokens-per-minute to tokens-per-second
+            RPS = model.rpm / 60.0  # Convert requests-per-minute to requests-per-second
+        else:
+            TPS = float("inf")  # Infinite tokens per second
+            RPS = float("inf")  # Infinite requests per second
+
+        # If this is a new model we haven't seen before
+        if model.model not in self.models_to_services:
+            service = model._inference_service_
+
+            # If this is a new service we haven't created buckets for yet
+            if service not in self.services_to_buckets:
+                # Create request rate limiting bucket
+                requests_bucket = TokenBucket(
+                    bucket_name=service,
+                    bucket_type="requests",
+                    capacity=RPS,
+                    refill_rate=RPS,
+                    remote_url=self.remote_url,
+                )
+
+                # Create token rate limiting bucket
+                tokens_bucket = TokenBucket(
+                    bucket_name=service,
+                    bucket_type="tokens",
+                    capacity=TPS,
+                    refill_rate=TPS,
+                    remote_url=self.remote_url,
+                )
+
+                # Store the buckets for this service
+                self.services_to_buckets[service] = ModelBuckets(
+                    requests_bucket, tokens_bucket
+                )
+
+            # Map this model to its service and buckets
+            self.models_to_services[model.model] = service
+            self[model] = self.services_to_buckets[service]
+        else:
+            # Model already exists, just retrieve its existing buckets
+            self[model] = self.services_to_buckets[self.models_to_services[model.model]]
+
+    def update_from_key_lookup(self, key_lookup: "KeyLookup") -> None:
+        """
+        Update bucket rate limits based on information from KeyLookup.
+
+        This method updates the capacity and refill rates of all buckets based on
+        the RPM (requests per minute) and TPM (tokens per minute) limits specified
+        in the provided KeyLookup. This is useful when API keys are rotated or
+        rate limits change.
+
+        Args:
+            key_lookup: KeyLookup object containing service rate limit information
+
+        Example:
+            >>> from edsl.key_management import KeyLookup
+            >>> key_lookup = KeyLookup()  # Assume this has rate limit info
+            >>> bucket_collection = BucketCollection()
+            >>> # Add some models to the collection
+            >>> bucket_collection.update_from_key_lookup(key_lookup)
+            >>> # Now rate limits are updated based on key_lookup
+        """
+        # Skip updates if we're using infinite buckets
+        if self.infinity_buckets:
+            return
+
+        # Update each service with new rate limits
+        for model_name, service in self.models_to_services.items():
+            if service in key_lookup:
+                # Update request rate limits if available
+                if key_lookup[service].rpm is not None:
+                    new_rps = key_lookup[service].rpm / 60.0  # Convert to per-second
+                    new_requests_bucket = TokenBucket(
+                        bucket_name=service,
+                        bucket_type="requests",
+                        capacity=new_rps,
+                        refill_rate=new_rps,
+                        remote_url=self.remote_url,
+                    )
+                    self.services_to_buckets[service].requests_bucket = new_requests_bucket
+
+                # Update token rate limits if available
+                if key_lookup[service].tpm is not None:
+                    new_tps = key_lookup[service].tpm / 60.0  # Convert to per-second
+                    new_tokens_bucket = TokenBucket(
+                        bucket_name=service,
+                        bucket_type="tokens",
+                        capacity=new_tps,
+                        refill_rate=new_tps,
+                        remote_url=self.remote_url,
+                    )
+                    self.services_to_buckets[service].tokens_bucket = new_tokens_bucket
+
+    def visualize(self) -> Dict["LanguageModel", Tuple[Figure, Figure]]:
+        """
+        Visualize the token and request buckets for all models.
+
+        This method generates visualization plots for each model's token and
+        request buckets, which can be useful for monitoring rate limit usage
+        and debugging rate limiting issues.
+
+        Returns:
+            Dictionary mapping language models to tuples of (request_plot, token_plot)
+
+        Example:
+            >>> bucket_collection = BucketCollection()
+            >>> # Add some models
+            >>> plots = bucket_collection.visualize()
+            >>> # Now you can display or save these plots
+        """
+        plots = {}
+        for model in self:
+            plots[model] = self[model].visualize()
+        return plots
+
+
+# Examples and doctests
+if __name__ == "__main__":
+    import doctest
+
+    # Example showing how to use BucketCollection
+    def example_usage():
+        """
+        Example demonstrating how to use BucketCollection:
+
+        >>> from edsl import Model
+        >>> # Create models
+        >>> gpt4 = Model('gpt-4')
+        >>> gpt35 = Model('gpt-3.5-turbo')
+        >>> claude = Model('claude-3-opus-20240229')
+        >>>
+        >>> # Create bucket collection
+        >>> collection = BucketCollection()
+        >>>
+        >>> # Add models to the collection
+        >>> collection.add_model(gpt4)
+        >>> collection.add_model(gpt35)
+        >>> collection.add_model(claude)
+        >>>
+        >>> # Models from the same service share rate limits
+        >>> print(collection[gpt4] is collection[gpt35])  # Both OpenAI
+        True
+        >>> print(collection[gpt4] is collection[claude])  # Different services
+        False
+        >>>
+        >>> # Visualize rate limits
+        >>> # plots = collection.visualize()
+        """
+        pass
+
+    # Run doctests
+    doctest.testmod(optionflags=doctest.ELLIPSIS)