edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +3 -9
- edsl/__init__.py +3 -8
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -40
- edsl/agents/AgentList.py +0 -43
- edsl/agents/Invigilator.py +219 -135
- edsl/agents/InvigilatorBase.py +59 -148
- edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
- edsl/agents/__init__.py +0 -1
- edsl/config.py +56 -47
- edsl/coop/coop.py +7 -50
- edsl/data/Cache.py +1 -35
- edsl/data_transfer_models.py +38 -73
- edsl/enums.py +0 -4
- edsl/exceptions/language_models.py +1 -25
- edsl/exceptions/questions.py +5 -62
- edsl/exceptions/results.py +0 -4
- edsl/inference_services/AnthropicService.py +11 -13
- edsl/inference_services/AwsBedrock.py +17 -19
- edsl/inference_services/AzureAI.py +20 -37
- edsl/inference_services/GoogleService.py +12 -16
- edsl/inference_services/GroqService.py +0 -2
- edsl/inference_services/InferenceServiceABC.py +3 -58
- edsl/inference_services/OpenAIService.py +54 -48
- edsl/inference_services/models_available_cache.py +6 -0
- edsl/inference_services/registry.py +0 -6
- edsl/jobs/Answers.py +12 -10
- edsl/jobs/Jobs.py +21 -36
- edsl/jobs/buckets/BucketCollection.py +15 -24
- edsl/jobs/buckets/TokenBucket.py +14 -93
- edsl/jobs/interviews/Interview.py +78 -366
- edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
- edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
- edsl/jobs/interviews/retry_management.py +37 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
- edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
- edsl/jobs/tasks/TaskHistory.py +213 -148
- edsl/language_models/LanguageModel.py +156 -261
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
- edsl/language_models/registry.py +6 -23
- edsl/language_models/repair.py +19 -0
- edsl/prompts/Prompt.py +2 -52
- edsl/questions/AnswerValidatorMixin.py +26 -23
- edsl/questions/QuestionBase.py +249 -329
- edsl/questions/QuestionBudget.py +41 -99
- edsl/questions/QuestionCheckBox.py +35 -227
- edsl/questions/QuestionExtract.py +27 -98
- edsl/questions/QuestionFreeText.py +29 -52
- edsl/questions/QuestionFunctional.py +0 -7
- edsl/questions/QuestionList.py +22 -141
- edsl/questions/QuestionMultipleChoice.py +65 -159
- edsl/questions/QuestionNumerical.py +46 -88
- edsl/questions/QuestionRank.py +24 -182
- edsl/questions/RegisterQuestionsMeta.py +12 -31
- edsl/questions/__init__.py +4 -3
- edsl/questions/derived/QuestionLikertFive.py +5 -10
- edsl/questions/derived/QuestionLinearScale.py +2 -15
- edsl/questions/derived/QuestionTopK.py +1 -10
- edsl/questions/derived/QuestionYesNo.py +3 -24
- edsl/questions/descriptors.py +7 -43
- edsl/questions/question_registry.py +2 -6
- edsl/results/Dataset.py +0 -20
- edsl/results/DatasetExportMixin.py +48 -46
- edsl/results/Result.py +5 -32
- edsl/results/Results.py +46 -135
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/scenarios/FileStore.py +10 -71
- edsl/scenarios/Scenario.py +25 -96
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +39 -361
- edsl/scenarios/ScenarioListExportMixin.py +0 -9
- edsl/scenarios/ScenarioListPdfMixin.py +4 -150
- edsl/study/SnapShot.py +1 -8
- edsl/study/Study.py +0 -32
- edsl/surveys/Rule.py +1 -10
- edsl/surveys/RuleCollection.py +5 -21
- edsl/surveys/Survey.py +310 -636
- edsl/surveys/SurveyExportMixin.py +9 -71
- edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
- edsl/surveys/SurveyQualtricsImport.py +4 -75
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/utilities.py +1 -9
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
- edsl-0.1.33.dev1.dist-info/RECORD +209 -0
- edsl/TemplateLoader.py +0 -24
- edsl/auto/AutoStudy.py +0 -117
- edsl/auto/StageBase.py +0 -230
- edsl/auto/StageGenerateSurvey.py +0 -178
- edsl/auto/StageLabelQuestions.py +0 -125
- edsl/auto/StagePersona.py +0 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
- edsl/auto/StagePersonaDimensionValues.py +0 -74
- edsl/auto/StagePersonaDimensions.py +0 -69
- edsl/auto/StageQuestions.py +0 -73
- edsl/auto/SurveyCreatorPipeline.py +0 -21
- edsl/auto/utilities.py +0 -224
- edsl/coop/PriceFetcher.py +0 -58
- edsl/inference_services/MistralAIService.py +0 -120
- edsl/inference_services/TestService.py +0 -80
- edsl/inference_services/TogetherAIService.py +0 -170
- edsl/jobs/FailedQuestion.py +0 -78
- edsl/jobs/runners/JobsRunnerStatus.py +0 -331
- edsl/language_models/fake_openai_call.py +0 -15
- edsl/language_models/fake_openai_service.py +0 -61
- edsl/language_models/utilities.py +0 -61
- edsl/questions/QuestionBaseGenMixin.py +0 -133
- edsl/questions/QuestionBasePromptsMixin.py +0 -266
- edsl/questions/Quick.py +0 -41
- edsl/questions/ResponseValidatorABC.py +0 -170
- edsl/questions/decorators.py +0 -21
- edsl/questions/prompt_templates/question_budget.jinja +0 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
- edsl/questions/prompt_templates/question_extract.jinja +0 -11
- edsl/questions/prompt_templates/question_free_text.jinja +0 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
- edsl/questions/prompt_templates/question_list.jinja +0 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
- edsl/questions/prompt_templates/question_numerical.jinja +0 -37
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +0 -7
- edsl/questions/templates/budget/question_presentation.jinja +0 -7
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +0 -7
- edsl/questions/templates/extract/question_presentation.jinja +0 -1
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +0 -1
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +0 -4
- edsl/questions/templates/list/question_presentation.jinja +0 -5
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
- edsl/questions/templates/numerical/question_presentation.jinja +0 -7
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +0 -11
- edsl/questions/templates/rank/question_presentation.jinja +0 -15
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
- edsl/questions/templates/top_k/question_presentation.jinja +0 -22
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
- edsl/results/DatasetTree.py +0 -145
- edsl/results/Selector.py +0 -118
- edsl/results/tree_explore.py +0 -115
- edsl/surveys/instructions/ChangeInstruction.py +0 -47
- edsl/surveys/instructions/Instruction.py +0 -34
- edsl/surveys/instructions/InstructionCollection.py +0 -77
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +0 -24
- edsl/templates/error_reporting/exceptions_by_model.html +0 -35
- edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
- edsl/templates/error_reporting/exceptions_by_type.html +0 -17
- edsl/templates/error_reporting/interview_details.html +0 -116
- edsl/templates/error_reporting/interviews.html +0 -10
- edsl/templates/error_reporting/overview.html +0 -5
- edsl/templates/error_reporting/performance_plot.html +0 -2
- edsl/templates/error_reporting/report.css +0 -74
- edsl/templates/error_reporting/report.html +0 -118
- edsl/templates/error_reporting/report.js +0 -25
- edsl-0.1.33.dist-info/RECORD +0 -295
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -156,11 +156,7 @@ class Jobs(Base):
|
|
156
156
|
from edsl.results.Dataset import Dataset
|
157
157
|
|
158
158
|
for interview_index, interview in enumerate(interviews):
|
159
|
-
invigilators =
|
160
|
-
interview._get_invigilator(question)
|
161
|
-
for question in self.survey.questions
|
162
|
-
]
|
163
|
-
# list(interview._build_invigilators(debug=False))
|
159
|
+
invigilators = list(interview._build_invigilators(debug=False))
|
164
160
|
for _, invigilator in enumerate(invigilators):
|
165
161
|
prompts = invigilator.get_prompts()
|
166
162
|
user_prompts.append(prompts["user_prompt"])
|
@@ -348,7 +344,6 @@ class Jobs(Base):
|
|
348
344
|
scenario=scenario,
|
349
345
|
model=model,
|
350
346
|
skip_retry=self.skip_retry,
|
351
|
-
raise_validation_errors=self.raise_validation_errors,
|
352
347
|
)
|
353
348
|
|
354
349
|
def create_bucket_collection(self) -> BucketCollection:
|
@@ -460,44 +455,33 @@ class Jobs(Base):
|
|
460
455
|
if warn:
|
461
456
|
warnings.warn(message)
|
462
457
|
|
463
|
-
if self.scenarios.has_jinja_braces:
|
464
|
-
warnings.warn(
|
465
|
-
"The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
|
466
|
-
)
|
467
|
-
self.scenarios = self.scenarios.convert_jinja_braces()
|
468
|
-
|
469
458
|
@property
|
470
459
|
def skip_retry(self):
|
471
460
|
if not hasattr(self, "_skip_retry"):
|
472
461
|
return False
|
473
462
|
return self._skip_retry
|
474
463
|
|
475
|
-
@property
|
476
|
-
def raise_validation_errors(self):
|
477
|
-
if not hasattr(self, "_raise_validation_errors"):
|
478
|
-
return False
|
479
|
-
return self._raise_validation_errors
|
480
|
-
|
481
464
|
def run(
|
482
465
|
self,
|
483
466
|
n: int = 1,
|
467
|
+
debug: bool = False,
|
484
468
|
progress_bar: bool = False,
|
485
469
|
stop_on_exception: bool = False,
|
486
470
|
cache: Union[Cache, bool] = None,
|
487
471
|
check_api_keys: bool = False,
|
488
472
|
sidecar_model: Optional[LanguageModel] = None,
|
473
|
+
batch_mode: Optional[bool] = None,
|
489
474
|
verbose: bool = False,
|
490
475
|
print_exceptions=True,
|
491
476
|
remote_cache_description: Optional[str] = None,
|
492
477
|
remote_inference_description: Optional[str] = None,
|
493
478
|
skip_retry: bool = False,
|
494
|
-
raise_validation_errors: bool = False,
|
495
|
-
disable_remote_inference: bool = False,
|
496
479
|
) -> Results:
|
497
480
|
"""
|
498
481
|
Runs the Job: conducts Interviews and returns their results.
|
499
482
|
|
500
483
|
:param n: how many times to run each interview
|
484
|
+
:param debug: prints debug messages
|
501
485
|
:param progress_bar: shows a progress bar
|
502
486
|
:param stop_on_exception: stops the job if an exception is raised
|
503
487
|
:param cache: a cache object to store results
|
@@ -511,21 +495,22 @@ class Jobs(Base):
|
|
511
495
|
|
512
496
|
self._check_parameters()
|
513
497
|
self._skip_retry = skip_retry
|
514
|
-
self._raise_validation_errors = raise_validation_errors
|
515
498
|
|
516
|
-
|
499
|
+
if batch_mode is not None:
|
500
|
+
raise NotImplementedError(
|
501
|
+
"Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
|
502
|
+
)
|
517
503
|
|
518
|
-
|
519
|
-
remote_inference = False
|
504
|
+
self.verbose = verbose
|
520
505
|
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
506
|
+
try:
|
507
|
+
coop = Coop()
|
508
|
+
user_edsl_settings = coop.edsl_settings
|
509
|
+
remote_cache = user_edsl_settings["remote_caching"]
|
510
|
+
remote_inference = user_edsl_settings["remote_inference"]
|
511
|
+
except Exception:
|
512
|
+
remote_cache = False
|
513
|
+
remote_inference = False
|
529
514
|
|
530
515
|
if remote_inference:
|
531
516
|
import time
|
@@ -602,7 +587,7 @@ class Jobs(Base):
|
|
602
587
|
)
|
603
588
|
|
604
589
|
# handle cache
|
605
|
-
if cache is None
|
590
|
+
if cache is None:
|
606
591
|
from edsl.data.CacheHandler import CacheHandler
|
607
592
|
|
608
593
|
cache = CacheHandler().get_cache()
|
@@ -614,12 +599,12 @@ class Jobs(Base):
|
|
614
599
|
if not remote_cache:
|
615
600
|
results = self._run_local(
|
616
601
|
n=n,
|
602
|
+
debug=debug,
|
617
603
|
progress_bar=progress_bar,
|
618
604
|
cache=cache,
|
619
605
|
stop_on_exception=stop_on_exception,
|
620
606
|
sidecar_model=sidecar_model,
|
621
607
|
print_exceptions=print_exceptions,
|
622
|
-
raise_validation_errors=raise_validation_errors,
|
623
608
|
)
|
624
609
|
|
625
610
|
results.cache = cache.new_entries_cache()
|
@@ -658,12 +643,12 @@ class Jobs(Base):
|
|
658
643
|
self._output("Running job...")
|
659
644
|
results = self._run_local(
|
660
645
|
n=n,
|
646
|
+
debug=debug,
|
661
647
|
progress_bar=progress_bar,
|
662
648
|
cache=cache,
|
663
649
|
stop_on_exception=stop_on_exception,
|
664
650
|
sidecar_model=sidecar_model,
|
665
651
|
print_exceptions=print_exceptions,
|
666
|
-
raise_validation_errors=raise_validation_errors,
|
667
652
|
)
|
668
653
|
self._output("Job completed!")
|
669
654
|
|
@@ -898,7 +883,7 @@ def main():
|
|
898
883
|
|
899
884
|
job = Jobs.example()
|
900
885
|
len(job) == 8
|
901
|
-
results = job.run(cache=Cache())
|
886
|
+
results = job.run(debug=True, cache=Cache())
|
902
887
|
len(results) == 8
|
903
888
|
results
|
904
889
|
|
@@ -13,8 +13,6 @@ class BucketCollection(UserDict):
|
|
13
13
|
def __init__(self, infinity_buckets=False):
|
14
14
|
super().__init__()
|
15
15
|
self.infinity_buckets = infinity_buckets
|
16
|
-
self.models_to_services = {}
|
17
|
-
self.services_to_buckets = {}
|
18
16
|
|
19
17
|
def __repr__(self):
|
20
18
|
return f"BucketCollection({self.data})"
|
@@ -23,7 +21,6 @@ class BucketCollection(UserDict):
|
|
23
21
|
"""Adds a model to the bucket collection.
|
24
22
|
|
25
23
|
This will create the token and request buckets for the model."""
|
26
|
-
|
27
24
|
# compute the TPS and RPS from the model
|
28
25
|
if not self.infinity_buckets:
|
29
26
|
TPS = model.TPM / 60.0
|
@@ -32,28 +29,22 @@ class BucketCollection(UserDict):
|
|
32
29
|
TPS = float("inf")
|
33
30
|
RPS = float("inf")
|
34
31
|
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
)
|
50
|
-
self.services_to_buckets[service] = ModelBuckets(
|
51
|
-
requests_bucket, tokens_bucket
|
52
|
-
)
|
53
|
-
self.models_to_services[model.model] = service
|
54
|
-
self[model] = self.services_to_buckets[service]
|
32
|
+
# create the buckets
|
33
|
+
requests_bucket = TokenBucket(
|
34
|
+
bucket_name=model.model,
|
35
|
+
bucket_type="requests",
|
36
|
+
capacity=RPS,
|
37
|
+
refill_rate=RPS,
|
38
|
+
)
|
39
|
+
tokens_bucket = TokenBucket(
|
40
|
+
bucket_name=model.model, bucket_type="tokens", capacity=TPS, refill_rate=TPS
|
41
|
+
)
|
42
|
+
model_buckets = ModelBuckets(requests_bucket, tokens_bucket)
|
43
|
+
if model in self:
|
44
|
+
# it if already exists, combine the buckets
|
45
|
+
self[model] += model_buckets
|
55
46
|
else:
|
56
|
-
self[model] =
|
47
|
+
self[model] = model_buckets
|
57
48
|
|
58
49
|
def visualize(self) -> dict:
|
59
50
|
"""Visualize the token and request buckets for each model."""
|
edsl/jobs/buckets/TokenBucket.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Union, List, Any
|
1
|
+
from typing import Union, List, Any
|
2
2
|
import asyncio
|
3
3
|
import time
|
4
4
|
|
@@ -17,12 +17,6 @@ class TokenBucket:
|
|
17
17
|
self.bucket_name = bucket_name
|
18
18
|
self.bucket_type = bucket_type
|
19
19
|
self.capacity = capacity # Maximum number of tokens
|
20
|
-
self.added_tokens = 0
|
21
|
-
|
22
|
-
self.target_rate = (
|
23
|
-
capacity * 60
|
24
|
-
) # set this here because it can change with turbo mode
|
25
|
-
|
26
20
|
self._old_capacity = capacity
|
27
21
|
self.tokens = capacity # Current number of available tokens
|
28
22
|
self.refill_rate = refill_rate # Rate at which tokens are refilled
|
@@ -31,12 +25,6 @@ class TokenBucket:
|
|
31
25
|
self.log: List[Any] = []
|
32
26
|
self.turbo_mode = False
|
33
27
|
|
34
|
-
self.creation_time = time.monotonic()
|
35
|
-
|
36
|
-
self.num_requests = 0
|
37
|
-
self.num_released = 0
|
38
|
-
self.tokens_returned = 0
|
39
|
-
|
40
28
|
def turbo_mode_on(self):
|
41
29
|
"""Set the refill rate to infinity."""
|
42
30
|
if self.turbo_mode:
|
@@ -81,7 +69,6 @@ class TokenBucket:
|
|
81
69
|
>>> bucket.tokens
|
82
70
|
10
|
83
71
|
"""
|
84
|
-
self.tokens_returned += tokens
|
85
72
|
self.tokens = min(self.capacity, self.tokens + tokens)
|
86
73
|
self.log.append((time.monotonic(), self.tokens))
|
87
74
|
|
@@ -95,30 +82,23 @@ class TokenBucket:
|
|
95
82
|
>>> bucket.refill()
|
96
83
|
>>> bucket.tokens > 0
|
97
84
|
True
|
85
|
+
|
98
86
|
"""
|
99
|
-
"""Refill the bucket with new tokens based on elapsed time."""
|
100
87
|
now = time.monotonic()
|
101
|
-
# print(f"Time is now: {now}; Last refill time: {self.last_refill}")
|
102
88
|
elapsed = now - self.last_refill
|
103
|
-
# print("Elapsed time: ", elapsed)
|
104
89
|
refill_amount = elapsed * self.refill_rate
|
105
90
|
self.tokens = min(self.capacity, self.tokens + refill_amount)
|
106
91
|
self.last_refill = now
|
107
92
|
|
108
|
-
if self.tokens < self.capacity:
|
109
|
-
pass
|
110
|
-
# print(f"Refilled. Current tokens: {self.tokens:.4f}")
|
111
|
-
# print(f"Elapsed time: {elapsed:.4f} seconds")
|
112
|
-
# print(f"Refill amount: {refill_amount:.4f}")
|
113
|
-
|
114
93
|
self.log.append((now, self.tokens))
|
115
94
|
|
116
95
|
def wait_time(self, requested_tokens: Union[float, int]) -> float:
|
117
96
|
"""Calculate the time to wait for the requested number of tokens."""
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
97
|
+
now = time.monotonic()
|
98
|
+
elapsed = now - self.last_refill
|
99
|
+
refill_amount = elapsed * self.refill_rate
|
100
|
+
available_tokens = min(self.capacity, self.tokens + refill_amount)
|
101
|
+
return max(0, requested_tokens - available_tokens) / self.refill_rate
|
122
102
|
|
123
103
|
async def get_tokens(
|
124
104
|
self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
|
@@ -143,33 +123,22 @@ class TokenBucket:
|
|
143
123
|
...
|
144
124
|
ValueError: Requested amount exceeds bucket capacity. Bucket capacity: 10, requested amount: 11. As the bucket never overflows, the requested amount will never be available.
|
145
125
|
>>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))
|
146
|
-
>>> bucket.capacity
|
147
|
-
12.100000000000001
|
148
126
|
"""
|
149
|
-
self.
|
150
|
-
if amount >= self.capacity:
|
127
|
+
if amount > self.capacity:
|
151
128
|
if not cheat_bucket_capacity:
|
152
129
|
msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
|
153
130
|
raise ValueError(msg)
|
154
131
|
else:
|
155
|
-
self.
|
156
|
-
|
132
|
+
self.tokens = 0 # clear the bucket but let it go through
|
133
|
+
return
|
157
134
|
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
self.tokens -= amount
|
163
|
-
break
|
135
|
+
while self.tokens < amount:
|
136
|
+
self.refill()
|
137
|
+
await asyncio.sleep(0.01) # Sleep briefly to prevent busy waiting
|
138
|
+
self.tokens -= amount
|
164
139
|
|
165
|
-
wait_time = self.wait_time(amount)
|
166
|
-
if wait_time > 0:
|
167
|
-
await asyncio.sleep(wait_time)
|
168
|
-
|
169
|
-
self.num_released += amount
|
170
140
|
now = time.monotonic()
|
171
141
|
self.log.append((now, self.tokens))
|
172
|
-
return None
|
173
142
|
|
174
143
|
def get_log(self) -> list[tuple]:
|
175
144
|
return self.log
|
@@ -193,54 +162,6 @@ class TokenBucket:
|
|
193
162
|
plt.tight_layout()
|
194
163
|
plt.show()
|
195
164
|
|
196
|
-
def get_throughput(self, time_window: Optional[float] = None) -> float:
|
197
|
-
"""
|
198
|
-
Calculate the empirical bucket throughput in tokens per minute for the specified time window.
|
199
|
-
|
200
|
-
:param time_window: The time window in seconds to calculate the throughput for.
|
201
|
-
:return: The throughput in tokens per minute.
|
202
|
-
|
203
|
-
>>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=100, refill_rate=10)
|
204
|
-
>>> asyncio.run(bucket.get_tokens(50))
|
205
|
-
>>> time.sleep(1) # Wait for 1 second
|
206
|
-
>>> asyncio.run(bucket.get_tokens(30))
|
207
|
-
>>> throughput = bucket.get_throughput(1)
|
208
|
-
>>> 4750 < throughput < 4850
|
209
|
-
True
|
210
|
-
"""
|
211
|
-
now = time.monotonic()
|
212
|
-
|
213
|
-
if time_window is None:
|
214
|
-
start_time = self.creation_time
|
215
|
-
else:
|
216
|
-
start_time = now - time_window
|
217
|
-
|
218
|
-
if start_time < self.creation_time:
|
219
|
-
start_time = self.creation_time
|
220
|
-
|
221
|
-
elapsed_time = now - start_time
|
222
|
-
|
223
|
-
return (self.num_released / elapsed_time) * 60
|
224
|
-
|
225
|
-
# # Filter log entries within the time window
|
226
|
-
# relevant_log = [(t, tokens) for t, tokens in self.log if t >= start_time]
|
227
|
-
|
228
|
-
# if len(relevant_log) < 2:
|
229
|
-
# return 0 # Not enough data points to calculate throughput
|
230
|
-
|
231
|
-
# # Calculate total tokens used
|
232
|
-
# initial_tokens = relevant_log[0][1]
|
233
|
-
# final_tokens = relevant_log[-1][1]
|
234
|
-
# tokens_used = self.num_released - (final_tokens - initial_tokens)
|
235
|
-
|
236
|
-
# # Calculate actual time elapsed
|
237
|
-
# actual_time_elapsed = relevant_log[-1][0] - relevant_log[0][0]
|
238
|
-
|
239
|
-
# # Calculate throughput in tokens per minute
|
240
|
-
# throughput = (tokens_used / actual_time_elapsed) * 60
|
241
|
-
|
242
|
-
# return throughput
|
243
|
-
|
244
165
|
|
245
166
|
if __name__ == "__main__":
|
246
167
|
import doctest
|