edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. edsl/Base.py +3 -9
  2. edsl/__init__.py +3 -8
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +8 -40
  5. edsl/agents/AgentList.py +0 -43
  6. edsl/agents/Invigilator.py +219 -135
  7. edsl/agents/InvigilatorBase.py +59 -148
  8. edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
  9. edsl/agents/__init__.py +0 -1
  10. edsl/config.py +56 -47
  11. edsl/coop/coop.py +7 -50
  12. edsl/data/Cache.py +1 -35
  13. edsl/data_transfer_models.py +38 -73
  14. edsl/enums.py +0 -4
  15. edsl/exceptions/language_models.py +1 -25
  16. edsl/exceptions/questions.py +5 -62
  17. edsl/exceptions/results.py +0 -4
  18. edsl/inference_services/AnthropicService.py +11 -13
  19. edsl/inference_services/AwsBedrock.py +17 -19
  20. edsl/inference_services/AzureAI.py +20 -37
  21. edsl/inference_services/GoogleService.py +12 -16
  22. edsl/inference_services/GroqService.py +0 -2
  23. edsl/inference_services/InferenceServiceABC.py +3 -58
  24. edsl/inference_services/OpenAIService.py +54 -48
  25. edsl/inference_services/models_available_cache.py +6 -0
  26. edsl/inference_services/registry.py +0 -6
  27. edsl/jobs/Answers.py +12 -10
  28. edsl/jobs/Jobs.py +21 -36
  29. edsl/jobs/buckets/BucketCollection.py +15 -24
  30. edsl/jobs/buckets/TokenBucket.py +14 -93
  31. edsl/jobs/interviews/Interview.py +78 -366
  32. edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
  33. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
  34. edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
  35. edsl/jobs/interviews/retry_management.py +37 -0
  36. edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
  37. edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
  38. edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
  39. edsl/jobs/tasks/TaskHistory.py +213 -148
  40. edsl/language_models/LanguageModel.py +156 -261
  41. edsl/language_models/ModelList.py +2 -2
  42. edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
  43. edsl/language_models/registry.py +6 -23
  44. edsl/language_models/repair.py +19 -0
  45. edsl/prompts/Prompt.py +2 -52
  46. edsl/questions/AnswerValidatorMixin.py +26 -23
  47. edsl/questions/QuestionBase.py +249 -329
  48. edsl/questions/QuestionBudget.py +41 -99
  49. edsl/questions/QuestionCheckBox.py +35 -227
  50. edsl/questions/QuestionExtract.py +27 -98
  51. edsl/questions/QuestionFreeText.py +29 -52
  52. edsl/questions/QuestionFunctional.py +0 -7
  53. edsl/questions/QuestionList.py +22 -141
  54. edsl/questions/QuestionMultipleChoice.py +65 -159
  55. edsl/questions/QuestionNumerical.py +46 -88
  56. edsl/questions/QuestionRank.py +24 -182
  57. edsl/questions/RegisterQuestionsMeta.py +12 -31
  58. edsl/questions/__init__.py +4 -3
  59. edsl/questions/derived/QuestionLikertFive.py +5 -10
  60. edsl/questions/derived/QuestionLinearScale.py +2 -15
  61. edsl/questions/derived/QuestionTopK.py +1 -10
  62. edsl/questions/derived/QuestionYesNo.py +3 -24
  63. edsl/questions/descriptors.py +7 -43
  64. edsl/questions/question_registry.py +2 -6
  65. edsl/results/Dataset.py +0 -20
  66. edsl/results/DatasetExportMixin.py +48 -46
  67. edsl/results/Result.py +5 -32
  68. edsl/results/Results.py +46 -135
  69. edsl/results/ResultsDBMixin.py +3 -3
  70. edsl/scenarios/FileStore.py +10 -71
  71. edsl/scenarios/Scenario.py +25 -96
  72. edsl/scenarios/ScenarioImageMixin.py +2 -2
  73. edsl/scenarios/ScenarioList.py +39 -361
  74. edsl/scenarios/ScenarioListExportMixin.py +0 -9
  75. edsl/scenarios/ScenarioListPdfMixin.py +4 -150
  76. edsl/study/SnapShot.py +1 -8
  77. edsl/study/Study.py +0 -32
  78. edsl/surveys/Rule.py +1 -10
  79. edsl/surveys/RuleCollection.py +5 -21
  80. edsl/surveys/Survey.py +310 -636
  81. edsl/surveys/SurveyExportMixin.py +9 -71
  82. edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
  83. edsl/surveys/SurveyQualtricsImport.py +4 -75
  84. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  85. edsl/utilities/utilities.py +1 -9
  86. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
  87. edsl-0.1.33.dev1.dist-info/RECORD +209 -0
  88. edsl/TemplateLoader.py +0 -24
  89. edsl/auto/AutoStudy.py +0 -117
  90. edsl/auto/StageBase.py +0 -230
  91. edsl/auto/StageGenerateSurvey.py +0 -178
  92. edsl/auto/StageLabelQuestions.py +0 -125
  93. edsl/auto/StagePersona.py +0 -61
  94. edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
  95. edsl/auto/StagePersonaDimensionValues.py +0 -74
  96. edsl/auto/StagePersonaDimensions.py +0 -69
  97. edsl/auto/StageQuestions.py +0 -73
  98. edsl/auto/SurveyCreatorPipeline.py +0 -21
  99. edsl/auto/utilities.py +0 -224
  100. edsl/coop/PriceFetcher.py +0 -58
  101. edsl/inference_services/MistralAIService.py +0 -120
  102. edsl/inference_services/TestService.py +0 -80
  103. edsl/inference_services/TogetherAIService.py +0 -170
  104. edsl/jobs/FailedQuestion.py +0 -78
  105. edsl/jobs/runners/JobsRunnerStatus.py +0 -331
  106. edsl/language_models/fake_openai_call.py +0 -15
  107. edsl/language_models/fake_openai_service.py +0 -61
  108. edsl/language_models/utilities.py +0 -61
  109. edsl/questions/QuestionBaseGenMixin.py +0 -133
  110. edsl/questions/QuestionBasePromptsMixin.py +0 -266
  111. edsl/questions/Quick.py +0 -41
  112. edsl/questions/ResponseValidatorABC.py +0 -170
  113. edsl/questions/decorators.py +0 -21
  114. edsl/questions/prompt_templates/question_budget.jinja +0 -13
  115. edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
  116. edsl/questions/prompt_templates/question_extract.jinja +0 -11
  117. edsl/questions/prompt_templates/question_free_text.jinja +0 -3
  118. edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
  119. edsl/questions/prompt_templates/question_list.jinja +0 -17
  120. edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
  121. edsl/questions/prompt_templates/question_numerical.jinja +0 -37
  122. edsl/questions/templates/__init__.py +0 -0
  123. edsl/questions/templates/budget/__init__.py +0 -0
  124. edsl/questions/templates/budget/answering_instructions.jinja +0 -7
  125. edsl/questions/templates/budget/question_presentation.jinja +0 -7
  126. edsl/questions/templates/checkbox/__init__.py +0 -0
  127. edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
  128. edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
  129. edsl/questions/templates/extract/__init__.py +0 -0
  130. edsl/questions/templates/extract/answering_instructions.jinja +0 -7
  131. edsl/questions/templates/extract/question_presentation.jinja +0 -1
  132. edsl/questions/templates/free_text/__init__.py +0 -0
  133. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  134. edsl/questions/templates/free_text/question_presentation.jinja +0 -1
  135. edsl/questions/templates/likert_five/__init__.py +0 -0
  136. edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
  137. edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
  138. edsl/questions/templates/linear_scale/__init__.py +0 -0
  139. edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
  140. edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
  141. edsl/questions/templates/list/__init__.py +0 -0
  142. edsl/questions/templates/list/answering_instructions.jinja +0 -4
  143. edsl/questions/templates/list/question_presentation.jinja +0 -5
  144. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  145. edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
  146. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  147. edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
  148. edsl/questions/templates/numerical/__init__.py +0 -0
  149. edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
  150. edsl/questions/templates/numerical/question_presentation.jinja +0 -7
  151. edsl/questions/templates/rank/__init__.py +0 -0
  152. edsl/questions/templates/rank/answering_instructions.jinja +0 -11
  153. edsl/questions/templates/rank/question_presentation.jinja +0 -15
  154. edsl/questions/templates/top_k/__init__.py +0 -0
  155. edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
  156. edsl/questions/templates/top_k/question_presentation.jinja +0 -22
  157. edsl/questions/templates/yes_no/__init__.py +0 -0
  158. edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
  159. edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
  160. edsl/results/DatasetTree.py +0 -145
  161. edsl/results/Selector.py +0 -118
  162. edsl/results/tree_explore.py +0 -115
  163. edsl/surveys/instructions/ChangeInstruction.py +0 -47
  164. edsl/surveys/instructions/Instruction.py +0 -34
  165. edsl/surveys/instructions/InstructionCollection.py +0 -77
  166. edsl/surveys/instructions/__init__.py +0 -0
  167. edsl/templates/error_reporting/base.html +0 -24
  168. edsl/templates/error_reporting/exceptions_by_model.html +0 -35
  169. edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
  170. edsl/templates/error_reporting/exceptions_by_type.html +0 -17
  171. edsl/templates/error_reporting/interview_details.html +0 -116
  172. edsl/templates/error_reporting/interviews.html +0 -10
  173. edsl/templates/error_reporting/overview.html +0 -5
  174. edsl/templates/error_reporting/performance_plot.html +0 -2
  175. edsl/templates/error_reporting/report.css +0 -74
  176. edsl/templates/error_reporting/report.html +0 -118
  177. edsl/templates/error_reporting/report.js +0 -25
  178. edsl-0.1.33.dist-info/RECORD +0 -295
  179. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
  180. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py CHANGED
@@ -156,11 +156,7 @@ class Jobs(Base):
156
156
  from edsl.results.Dataset import Dataset
157
157
 
158
158
  for interview_index, interview in enumerate(interviews):
159
- invigilators = [
160
- interview._get_invigilator(question)
161
- for question in self.survey.questions
162
- ]
163
- # list(interview._build_invigilators(debug=False))
159
+ invigilators = list(interview._build_invigilators(debug=False))
164
160
  for _, invigilator in enumerate(invigilators):
165
161
  prompts = invigilator.get_prompts()
166
162
  user_prompts.append(prompts["user_prompt"])
@@ -348,7 +344,6 @@ class Jobs(Base):
348
344
  scenario=scenario,
349
345
  model=model,
350
346
  skip_retry=self.skip_retry,
351
- raise_validation_errors=self.raise_validation_errors,
352
347
  )
353
348
 
354
349
  def create_bucket_collection(self) -> BucketCollection:
@@ -460,44 +455,33 @@ class Jobs(Base):
460
455
  if warn:
461
456
  warnings.warn(message)
462
457
 
463
- if self.scenarios.has_jinja_braces:
464
- warnings.warn(
465
- "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
466
- )
467
- self.scenarios = self.scenarios.convert_jinja_braces()
468
-
469
458
  @property
470
459
  def skip_retry(self):
471
460
  if not hasattr(self, "_skip_retry"):
472
461
  return False
473
462
  return self._skip_retry
474
463
 
475
- @property
476
- def raise_validation_errors(self):
477
- if not hasattr(self, "_raise_validation_errors"):
478
- return False
479
- return self._raise_validation_errors
480
-
481
464
  def run(
482
465
  self,
483
466
  n: int = 1,
467
+ debug: bool = False,
484
468
  progress_bar: bool = False,
485
469
  stop_on_exception: bool = False,
486
470
  cache: Union[Cache, bool] = None,
487
471
  check_api_keys: bool = False,
488
472
  sidecar_model: Optional[LanguageModel] = None,
473
+ batch_mode: Optional[bool] = None,
489
474
  verbose: bool = False,
490
475
  print_exceptions=True,
491
476
  remote_cache_description: Optional[str] = None,
492
477
  remote_inference_description: Optional[str] = None,
493
478
  skip_retry: bool = False,
494
- raise_validation_errors: bool = False,
495
- disable_remote_inference: bool = False,
496
479
  ) -> Results:
497
480
  """
498
481
  Runs the Job: conducts Interviews and returns their results.
499
482
 
500
483
  :param n: how many times to run each interview
484
+ :param debug: prints debug messages
501
485
  :param progress_bar: shows a progress bar
502
486
  :param stop_on_exception: stops the job if an exception is raised
503
487
  :param cache: a cache object to store results
@@ -511,21 +495,22 @@ class Jobs(Base):
511
495
 
512
496
  self._check_parameters()
513
497
  self._skip_retry = skip_retry
514
- self._raise_validation_errors = raise_validation_errors
515
498
 
516
- self.verbose = verbose
499
+ if batch_mode is not None:
500
+ raise NotImplementedError(
501
+ "Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
502
+ )
517
503
 
518
- remote_cache = False
519
- remote_inference = False
504
+ self.verbose = verbose
520
505
 
521
- if not disable_remote_inference:
522
- try:
523
- coop = Coop()
524
- user_edsl_settings = Coop().edsl_settings
525
- remote_cache = user_edsl_settings.get("remote_caching", False)
526
- remote_inference = user_edsl_settings.get("remote_inference", False)
527
- except Exception:
528
- pass
506
+ try:
507
+ coop = Coop()
508
+ user_edsl_settings = coop.edsl_settings
509
+ remote_cache = user_edsl_settings["remote_caching"]
510
+ remote_inference = user_edsl_settings["remote_inference"]
511
+ except Exception:
512
+ remote_cache = False
513
+ remote_inference = False
529
514
 
530
515
  if remote_inference:
531
516
  import time
@@ -602,7 +587,7 @@ class Jobs(Base):
602
587
  )
603
588
 
604
589
  # handle cache
605
- if cache is None or cache is True:
590
+ if cache is None:
606
591
  from edsl.data.CacheHandler import CacheHandler
607
592
 
608
593
  cache = CacheHandler().get_cache()
@@ -614,12 +599,12 @@ class Jobs(Base):
614
599
  if not remote_cache:
615
600
  results = self._run_local(
616
601
  n=n,
602
+ debug=debug,
617
603
  progress_bar=progress_bar,
618
604
  cache=cache,
619
605
  stop_on_exception=stop_on_exception,
620
606
  sidecar_model=sidecar_model,
621
607
  print_exceptions=print_exceptions,
622
- raise_validation_errors=raise_validation_errors,
623
608
  )
624
609
 
625
610
  results.cache = cache.new_entries_cache()
@@ -658,12 +643,12 @@ class Jobs(Base):
658
643
  self._output("Running job...")
659
644
  results = self._run_local(
660
645
  n=n,
646
+ debug=debug,
661
647
  progress_bar=progress_bar,
662
648
  cache=cache,
663
649
  stop_on_exception=stop_on_exception,
664
650
  sidecar_model=sidecar_model,
665
651
  print_exceptions=print_exceptions,
666
- raise_validation_errors=raise_validation_errors,
667
652
  )
668
653
  self._output("Job completed!")
669
654
 
@@ -898,7 +883,7 @@ def main():
898
883
 
899
884
  job = Jobs.example()
900
885
  len(job) == 8
901
- results = job.run(cache=Cache())
886
+ results = job.run(debug=True, cache=Cache())
902
887
  len(results) == 8
903
888
  results
904
889
 
edsl/jobs/buckets/BucketCollection.py CHANGED
@@ -13,8 +13,6 @@ class BucketCollection(UserDict):
13
13
  def __init__(self, infinity_buckets=False):
14
14
  super().__init__()
15
15
  self.infinity_buckets = infinity_buckets
16
- self.models_to_services = {}
17
- self.services_to_buckets = {}
18
16
 
19
17
  def __repr__(self):
20
18
  return f"BucketCollection({self.data})"
@@ -23,7 +21,6 @@ class BucketCollection(UserDict):
23
21
  """Adds a model to the bucket collection.
24
22
 
25
23
  This will create the token and request buckets for the model."""
26
-
27
24
  # compute the TPS and RPS from the model
28
25
  if not self.infinity_buckets:
29
26
  TPS = model.TPM / 60.0
@@ -32,28 +29,22 @@ class BucketCollection(UserDict):
32
29
  TPS = float("inf")
33
30
  RPS = float("inf")
34
31
 
35
- if model.model not in self.models_to_services:
36
- service = model._inference_service_
37
- if service not in self.services_to_buckets:
38
- requests_bucket = TokenBucket(
39
- bucket_name=service,
40
- bucket_type="requests",
41
- capacity=RPS,
42
- refill_rate=RPS,
43
- )
44
- tokens_bucket = TokenBucket(
45
- bucket_name=service,
46
- bucket_type="tokens",
47
- capacity=TPS,
48
- refill_rate=TPS,
49
- )
50
- self.services_to_buckets[service] = ModelBuckets(
51
- requests_bucket, tokens_bucket
52
- )
53
- self.models_to_services[model.model] = service
54
- self[model] = self.services_to_buckets[service]
32
+ # create the buckets
33
+ requests_bucket = TokenBucket(
34
+ bucket_name=model.model,
35
+ bucket_type="requests",
36
+ capacity=RPS,
37
+ refill_rate=RPS,
38
+ )
39
+ tokens_bucket = TokenBucket(
40
+ bucket_name=model.model, bucket_type="tokens", capacity=TPS, refill_rate=TPS
41
+ )
42
+ model_buckets = ModelBuckets(requests_bucket, tokens_bucket)
43
+ if model in self:
44
+ # it if already exists, combine the buckets
45
+ self[model] += model_buckets
55
46
  else:
56
- self[model] = self.services_to_buckets[self.models_to_services[model.model]]
47
+ self[model] = model_buckets
57
48
 
58
49
  def visualize(self) -> dict:
59
50
  """Visualize the token and request buckets for each model."""
edsl/jobs/buckets/TokenBucket.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Union, List, Any, Optional
1
+ from typing import Union, List, Any
2
2
  import asyncio
3
3
  import time
4
4
 
@@ -17,12 +17,6 @@ class TokenBucket:
17
17
  self.bucket_name = bucket_name
18
18
  self.bucket_type = bucket_type
19
19
  self.capacity = capacity # Maximum number of tokens
20
- self.added_tokens = 0
21
-
22
- self.target_rate = (
23
- capacity * 60
24
- ) # set this here because it can change with turbo mode
25
-
26
20
  self._old_capacity = capacity
27
21
  self.tokens = capacity # Current number of available tokens
28
22
  self.refill_rate = refill_rate # Rate at which tokens are refilled
@@ -31,12 +25,6 @@ class TokenBucket:
31
25
  self.log: List[Any] = []
32
26
  self.turbo_mode = False
33
27
 
34
- self.creation_time = time.monotonic()
35
-
36
- self.num_requests = 0
37
- self.num_released = 0
38
- self.tokens_returned = 0
39
-
40
28
  def turbo_mode_on(self):
41
29
  """Set the refill rate to infinity."""
42
30
  if self.turbo_mode:
@@ -81,7 +69,6 @@ class TokenBucket:
81
69
  >>> bucket.tokens
82
70
  10
83
71
  """
84
- self.tokens_returned += tokens
85
72
  self.tokens = min(self.capacity, self.tokens + tokens)
86
73
  self.log.append((time.monotonic(), self.tokens))
87
74
 
@@ -95,30 +82,23 @@ class TokenBucket:
95
82
  >>> bucket.refill()
96
83
  >>> bucket.tokens > 0
97
84
  True
85
+
98
86
  """
99
- """Refill the bucket with new tokens based on elapsed time."""
100
87
  now = time.monotonic()
101
- # print(f"Time is now: {now}; Last refill time: {self.last_refill}")
102
88
  elapsed = now - self.last_refill
103
- # print("Elapsed time: ", elapsed)
104
89
  refill_amount = elapsed * self.refill_rate
105
90
  self.tokens = min(self.capacity, self.tokens + refill_amount)
106
91
  self.last_refill = now
107
92
 
108
- if self.tokens < self.capacity:
109
- pass
110
- # print(f"Refilled. Current tokens: {self.tokens:.4f}")
111
- # print(f"Elapsed time: {elapsed:.4f} seconds")
112
- # print(f"Refill amount: {refill_amount:.4f}")
113
-
114
93
  self.log.append((now, self.tokens))
115
94
 
116
95
  def wait_time(self, requested_tokens: Union[float, int]) -> float:
117
96
  """Calculate the time to wait for the requested number of tokens."""
118
- # self.refill() # Update the current token count
119
- if self.tokens >= requested_tokens:
120
- return 0
121
- return (requested_tokens - self.tokens) / self.refill_rate
97
+ now = time.monotonic()
98
+ elapsed = now - self.last_refill
99
+ refill_amount = elapsed * self.refill_rate
100
+ available_tokens = min(self.capacity, self.tokens + refill_amount)
101
+ return max(0, requested_tokens - available_tokens) / self.refill_rate
122
102
 
123
103
  async def get_tokens(
124
104
  self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
@@ -143,33 +123,22 @@ class TokenBucket:
143
123
  ...
144
124
  ValueError: Requested amount exceeds bucket capacity. Bucket capacity: 10, requested amount: 11. As the bucket never overflows, the requested amount will never be available.
145
125
  >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))
146
- >>> bucket.capacity
147
- 12.100000000000001
148
126
  """
149
- self.num_requests += amount
150
- if amount >= self.capacity:
127
+ if amount > self.capacity:
151
128
  if not cheat_bucket_capacity:
152
129
  msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
153
130
  raise ValueError(msg)
154
131
  else:
155
- self.capacity = amount * 1.10
156
- self._old_capacity = self.capacity
132
+ self.tokens = 0 # clear the bucket but let it go through
133
+ return
157
134
 
158
- start_time = time.monotonic()
159
- while True:
160
- self.refill() # Refill based on elapsed time
161
- if self.tokens >= amount:
162
- self.tokens -= amount
163
- break
135
+ while self.tokens < amount:
136
+ self.refill()
137
+ await asyncio.sleep(0.01) # Sleep briefly to prevent busy waiting
138
+ self.tokens -= amount
164
139
 
165
- wait_time = self.wait_time(amount)
166
- if wait_time > 0:
167
- await asyncio.sleep(wait_time)
168
-
169
- self.num_released += amount
170
140
  now = time.monotonic()
171
141
  self.log.append((now, self.tokens))
172
- return None
173
142
 
174
143
  def get_log(self) -> list[tuple]:
175
144
  return self.log
@@ -193,54 +162,6 @@ class TokenBucket:
193
162
  plt.tight_layout()
194
163
  plt.show()
195
164
 
196
- def get_throughput(self, time_window: Optional[float] = None) -> float:
197
- """
198
- Calculate the empirical bucket throughput in tokens per minute for the specified time window.
199
-
200
- :param time_window: The time window in seconds to calculate the throughput for.
201
- :return: The throughput in tokens per minute.
202
-
203
- >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=100, refill_rate=10)
204
- >>> asyncio.run(bucket.get_tokens(50))
205
- >>> time.sleep(1) # Wait for 1 second
206
- >>> asyncio.run(bucket.get_tokens(30))
207
- >>> throughput = bucket.get_throughput(1)
208
- >>> 4750 < throughput < 4850
209
- True
210
- """
211
- now = time.monotonic()
212
-
213
- if time_window is None:
214
- start_time = self.creation_time
215
- else:
216
- start_time = now - time_window
217
-
218
- if start_time < self.creation_time:
219
- start_time = self.creation_time
220
-
221
- elapsed_time = now - start_time
222
-
223
- return (self.num_released / elapsed_time) * 60
224
-
225
- # # Filter log entries within the time window
226
- # relevant_log = [(t, tokens) for t, tokens in self.log if t >= start_time]
227
-
228
- # if len(relevant_log) < 2:
229
- # return 0 # Not enough data points to calculate throughput
230
-
231
- # # Calculate total tokens used
232
- # initial_tokens = relevant_log[0][1]
233
- # final_tokens = relevant_log[-1][1]
234
- # tokens_used = self.num_released - (final_tokens - initial_tokens)
235
-
236
- # # Calculate actual time elapsed
237
- # actual_time_elapsed = relevant_log[-1][0] - relevant_log[0][0]
238
-
239
- # # Calculate throughput in tokens per minute
240
- # throughput = (tokens_used / actual_time_elapsed) * 60
241
-
242
- # return throughput
243
-
244
165
 
245
166
  if __name__ == "__main__":
246
167
  import doctest