edsl 0.1.33.dev2__py3-none-any.whl → 0.1.33.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. edsl/Base.py +9 -3
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +6 -3
  6. edsl/agents/InvigilatorBase.py +8 -27
  7. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +101 -29
  8. edsl/config.py +26 -34
  9. edsl/coop/coop.py +11 -2
  10. edsl/data_transfer_models.py +27 -73
  11. edsl/enums.py +2 -0
  12. edsl/inference_services/GoogleService.py +1 -1
  13. edsl/inference_services/InferenceServiceABC.py +44 -13
  14. edsl/inference_services/OpenAIService.py +7 -4
  15. edsl/inference_services/TestService.py +24 -15
  16. edsl/inference_services/TogetherAIService.py +170 -0
  17. edsl/inference_services/registry.py +2 -0
  18. edsl/jobs/Jobs.py +18 -8
  19. edsl/jobs/buckets/BucketCollection.py +24 -15
  20. edsl/jobs/buckets/TokenBucket.py +64 -10
  21. edsl/jobs/interviews/Interview.py +115 -47
  22. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
  23. edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
  24. edsl/jobs/runners/JobsRunnerAsyncio.py +86 -161
  25. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  26. edsl/jobs/tasks/TaskHistory.py +17 -0
  27. edsl/language_models/LanguageModel.py +26 -31
  28. edsl/language_models/registry.py +13 -9
  29. edsl/questions/QuestionBase.py +64 -16
  30. edsl/questions/QuestionBudget.py +93 -41
  31. edsl/questions/QuestionFreeText.py +6 -0
  32. edsl/questions/QuestionMultipleChoice.py +11 -26
  33. edsl/questions/QuestionNumerical.py +5 -4
  34. edsl/questions/Quick.py +41 -0
  35. edsl/questions/ResponseValidatorABC.py +6 -5
  36. edsl/questions/derived/QuestionLinearScale.py +4 -1
  37. edsl/questions/derived/QuestionTopK.py +4 -1
  38. edsl/questions/derived/QuestionYesNo.py +8 -2
  39. edsl/questions/templates/budget/__init__.py +0 -0
  40. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  41. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  42. edsl/questions/templates/extract/__init__.py +0 -0
  43. edsl/questions/templates/rank/__init__.py +0 -0
  44. edsl/results/DatasetExportMixin.py +5 -1
  45. edsl/results/Result.py +1 -1
  46. edsl/results/Results.py +4 -1
  47. edsl/scenarios/FileStore.py +71 -10
  48. edsl/scenarios/Scenario.py +86 -21
  49. edsl/scenarios/ScenarioImageMixin.py +2 -2
  50. edsl/scenarios/ScenarioList.py +13 -0
  51. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  52. edsl/study/Study.py +32 -0
  53. edsl/surveys/Rule.py +10 -1
  54. edsl/surveys/RuleCollection.py +19 -3
  55. edsl/surveys/Survey.py +7 -0
  56. edsl/templates/error_reporting/interview_details.html +6 -1
  57. edsl/utilities/utilities.py +9 -1
  58. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/METADATA +2 -1
  59. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/RECORD +61 -55
  60. edsl/jobs/interviews/retry_management.py +0 -39
  61. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  62. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/LICENSE +0 -0
  63. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/WHEEL +0 -0
edsl/inference_services/GoogleService.py CHANGED
@@ -64,7 +64,7 @@ class GoogleService(InferenceServiceABC):
                 "stopSequences": self.stopSequences,
             },
         }
-        print(combined_prompt)
+        # print(combined_prompt)
         async with aiohttp.ClientSession() as session:
             async with session.post(
                 url, headers=headers, data=json.dumps(data)
edsl/inference_services/InferenceServiceABC.py CHANGED
@@ -1,14 +1,27 @@
 from abc import abstractmethod, ABC
-from typing import Any
+import os
 import re
 from edsl.config import CONFIG


 class InferenceServiceABC(ABC):
-    """Abstract class for inference services."""
+    """
+    Abstract class for inference services.
+    Anthropic: https://docs.anthropic.com/en/api/rate-limits
+    """
+
+    default_levels = {
+        "google": {"tpm": 2_000_000, "rpm": 15},
+        "openai": {"tpm": 2_000_000, "rpm": 10_000},
+        "anthropic": {"tpm": 2_000_000, "rpm": 500},
+    }

-    # check if child class has cls attribute "key_sequence"
     def __init_subclass__(cls):
+        """
+        Check that the subclass has the required attributes.
+        - `key_sequence` attribute determines...
+        - `model_exclude_list` attribute determines...
+        """
         if not hasattr(cls, "key_sequence"):
             raise NotImplementedError(
                 f"Class {cls.__name__} must have a 'key_sequence' attribute."
@@ -18,29 +31,47 @@ class InferenceServiceABC(ABC):
                 f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
             )

-    def get_tpm(cls):
-        key = f"EDSL_SERVICE_TPM_{cls._inference_service_.upper()}"
-        if key not in CONFIG:
-            key = "EDSL_SERVICE_TPM_BASELINE"
-        return int(CONFIG.get(key))
+    @classmethod
+    def _get_limt(cls, limit_type: str) -> int:
+        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
+        if key in os.environ:
+            return int(os.getenv(key))
+
+        if cls._inference_service_ in cls.default_levels:
+            return int(cls.default_levels[cls._inference_service_][limit_type])
+
+        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
+
+    def get_tpm(cls) -> int:
+        """
+        Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
+        """
+        return cls._get_limt(limit_type="tpm")

     def get_rpm(cls):
-        key = f"EDSL_SERVICE_RPM_{cls._inference_service_.upper()}"
-        if key not in CONFIG:
-            key = "EDSL_SERVICE_RPM_BASELINE"
-        return int(CONFIG.get(key))
+        """
+        Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
+        """
+        return cls._get_limt(limit_type="rpm")

     @abstractmethod
     def available() -> list[str]:
+        """
+        Returns a list of available models for the service.
+        """
         pass

     @abstractmethod
     def create_model():
+        """
+        Returns a LanguageModel object.
+        """
         pass

     @staticmethod
     def to_class_name(s):
-        """Convert a string to a valid class name.
+        """
+        Converts a string to a valid class name.

         >>> InferenceServiceABC.to_class_name("hello world")
         'HelloWorld'
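
Note: the following is a minimal standalone sketch (not edsl code) of the lookup order the new `_get_limt` classmethod introduces above: an `EDSL_SERVICE_<TPM|RPM>_<SERVICE>` environment variable wins, then the hard-coded `default_levels` entry for the service, then the configured baseline. The `BASELINE` values below are illustrative placeholders, not the actual CONFIG defaults.

```python
import os

DEFAULT_LEVELS = {
    "google": {"tpm": 2_000_000, "rpm": 15},
    "openai": {"tpm": 2_000_000, "rpm": 10_000},
    "anthropic": {"tpm": 2_000_000, "rpm": 500},
}
BASELINE = {"tpm": 2_000_000, "rpm": 100}  # placeholder, not the real EDSL_SERVICE_*_BASELINE values


def resolve_limit(service: str, limit_type: str) -> int:
    """Resolve a rate limit: environment override, then service default, then baseline."""
    key = f"EDSL_SERVICE_{limit_type.upper()}_{service.upper()}"
    if key in os.environ:              # 1. explicit override via environment variable
        return int(os.environ[key])
    if service in DEFAULT_LEVELS:      # 2. hard-coded per-service default
        return int(DEFAULT_LEVELS[service][limit_type])
    return int(BASELINE[limit_type])   # 3. baseline fallback


print(resolve_limit("openai", "rpm"))   # 10000 unless EDSL_SERVICE_RPM_OPENAI is set
print(resolve_limit("mistral", "rpm"))  # no default entry, falls through to the baseline
```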
edsl/inference_services/OpenAIService.py CHANGED
@@ -187,12 +187,15 @@ class OpenAIService(InferenceServiceABC):
             else:
                 content = user_prompt
             client = self.async_client()
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": content},
+            ]
+            if system_prompt == "" and self.omit_system_prompt_if_empty:
+                messages = messages[1:]
             params = {
                 "model": self.model,
-                "messages": [
-                    {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": content},
-                ],
+                "messages": messages,
                 "temperature": self.temperature,
                 "max_tokens": self.max_tokens,
                 "top_p": self.top_p,
edsl/inference_services/TestService.py CHANGED
@@ -7,14 +7,25 @@ from edsl.inference_services.rate_limits_cache import rate_limits
 from edsl.utilities.utilities import fix_partial_correct_response

 from edsl.enums import InferenceServiceType
+import random


 class TestService(InferenceServiceABC):
     """OpenAI service class."""

+    _inference_service_ = "test"
+    _env_key_name_ = None
+    _base_url_ = None
+
+    _sync_client_ = None
+    _async_client_ = None
+
+    _sync_client_instance = None
+    _async_client_instance = None
+
     key_sequence = None
+    usage_sequence = None
     model_exclude_list = []
-    _inference_service_ = "test"
     input_token_name = "prompt_tokens"
     output_token_name = "completion_tokens"

@@ -45,27 +56,25 @@ class TestService(InferenceServiceABC):
             return "Hello, world"

         async def async_execute_model_call(
-            self, user_prompt: str, system_prompt: str
+            self,
+            user_prompt: str,
+            system_prompt: str,
+            encoded_image=None,
         ) -> dict[str, Any]:
             await asyncio.sleep(0.1)
             # return {"message": """{"answer": "Hello, world"}"""}
+
             if hasattr(self, "throw_exception") and self.throw_exception:
-                raise Exception("This is a test error")
+                if hasattr(self, "exception_probability"):
+                    p = self.exception_probability
+                else:
+                    p = 1
+
+                if random.random() < p:
+                    raise Exception("This is a test error")
             return {
                 "message": [{"text": f"{self._canned_response}"}],
                 "usage": {"prompt_tokens": 1, "completion_tokens": 1},
             }

         return TestServiceLanguageModel
-
-    # _inference_service_ = "openai"
-    # _env_key_name_ = "OPENAI_API_KEY"
-    # _base_url_ = None
-
-    # _sync_client_ = openai.OpenAI
-    # _async_client_ = openai.AsyncOpenAI
-
-    # _sync_client_instance = None
-    # _async_client_instance = None
-
-    # key_sequence = ["choices", 0, "message", "content"]
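
Note: a standalone sketch (not edsl code) of the probabilistic failure injection added to the test model above: with `throw_exception` set, each call raises with probability `exception_probability`, which defaults to 1.

```python
import random


def maybe_raise(throw_exception: bool, exception_probability: float = 1.0) -> str:
    """Mirror of the diff's logic in TestService.async_execute_model_call."""
    if throw_exception and random.random() < exception_probability:
        raise Exception("This is a test error")
    return "Hello, world"


random.seed(0)
failures = 0
for _ in range(1_000):
    try:
        maybe_raise(throw_exception=True, exception_probability=0.2)
    except Exception:
        failures += 1
print(failures)  # roughly 200 of 1,000 calls fail with p = 0.2
```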
edsl/inference_services/TogetherAIService.py ADDED
@@ -0,0 +1,170 @@
+import aiohttp
+import json
+import requests
+from typing import Any, List
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+
+from edsl.inference_services.OpenAIService import OpenAIService
+import openai
+
+
+class TogetherAIService(OpenAIService):
+    """DeepInfra service class."""
+
+    _inference_service_ = "together"
+    _env_key_name_ = "TOGETHER_API_KEY"
+    _base_url_ = "https://api.together.xyz/v1"
+    _models_list_cache: List[str] = []
+
+    # These are non-serverless models. There was no api param to filter them
+    model_exclude_list = [
+        "EleutherAI/llemma_7b",
+        "HuggingFaceH4/zephyr-7b-beta",
+        "Nexusflow/NexusRaven-V2-13B",
+        "NousResearch/Hermes-2-Theta-Llama-3-70B",
+        "NousResearch/Nous-Capybara-7B-V1p9",
+        "NousResearch/Nous-Hermes-13b",
+        "NousResearch/Nous-Hermes-2-Mistral-7B-DPO",
+        "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT",
+        "NousResearch/Nous-Hermes-Llama2-13b",
+        "NousResearch/Nous-Hermes-Llama2-70b",
+        "NousResearch/Nous-Hermes-llama-2-7b",
+        "NumbersStation/nsql-llama-2-7B",
+        "Open-Orca/Mistral-7B-OpenOrca",
+        "Phind/Phind-CodeLlama-34B-Python-v1",
+        "Phind/Phind-CodeLlama-34B-v2",
+        "Qwen/Qwen1.5-0.5B",
+        "Qwen/Qwen1.5-0.5B-Chat",
+        "Qwen/Qwen1.5-1.8B",
+        "Qwen/Qwen1.5-1.8B-Chat",
+        "Qwen/Qwen1.5-14B",
+        "Qwen/Qwen1.5-14B-Chat",
+        "Qwen/Qwen1.5-32B",
+        "Qwen/Qwen1.5-32B-Chat",
+        "Qwen/Qwen1.5-4B",
+        "Qwen/Qwen1.5-4B-Chat",
+        "Qwen/Qwen1.5-72B",
+        "Qwen/Qwen1.5-7B",
+        "Qwen/Qwen1.5-7B-Chat",
+        "Qwen/Qwen2-1.5B",
+        "Qwen/Qwen2-1.5B-Instruct",
+        "Qwen/Qwen2-72B",
+        "Qwen/Qwen2-7B",
+        "Qwen/Qwen2-7B-Instruct",
+        "SG161222/Realistic_Vision_V3.0_VAE",
+        "Snowflake/snowflake-arctic-instruct",
+        "Undi95/ReMM-SLERP-L2-13B",
+        "Undi95/Toppy-M-7B",
+        "WizardLM/WizardCoder-Python-34B-V1.0",
+        "WizardLM/WizardLM-13B-V1.2",
+        "WizardLM/WizardLM-70B-V1.0",
+        "allenai/OLMo-7B",
+        "allenai/OLMo-7B-Instruct",
+        "bert-base-uncased",
+        "codellama/CodeLlama-13b-Instruct-hf",
+        "codellama/CodeLlama-13b-Python-hf",
+        "codellama/CodeLlama-13b-hf",
+        "codellama/CodeLlama-34b-Python-hf",
+        "codellama/CodeLlama-34b-hf",
+        "codellama/CodeLlama-70b-Instruct-hf",
+        "codellama/CodeLlama-70b-Python-hf",
+        "codellama/CodeLlama-70b-hf",
+        "codellama/CodeLlama-7b-Instruct-hf",
+        "codellama/CodeLlama-7b-Python-hf",
+        "codellama/CodeLlama-7b-hf",
+        "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
+        "deepseek-ai/deepseek-coder-33b-instruct",
+        "garage-bAInd/Platypus2-70B-instruct",
+        "google/gemma-2b",
+        "google/gemma-7b",
+        "google/gemma-7b-it",
+        "gradientai/Llama-3-70B-Instruct-Gradient-1048k",
+        "hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1",
+        "huggyllama/llama-13b",
+        "huggyllama/llama-30b",
+        "huggyllama/llama-65b",
+        "huggyllama/llama-7b",
+        "lmsys/vicuna-13b-v1.3",
+        "lmsys/vicuna-13b-v1.5",
+        "lmsys/vicuna-13b-v1.5-16k",
+        "lmsys/vicuna-7b-v1.3",
+        "lmsys/vicuna-7b-v1.5",
+        "meta-llama/Llama-2-13b-hf",
+        "meta-llama/Llama-2-70b-chat-hf",
+        "meta-llama/Llama-2-7b-hf",
+        "meta-llama/Llama-3-70b-hf",
+        "meta-llama/Llama-3-8b-hf",
+        "meta-llama/Meta-Llama-3-70B",
+        "meta-llama/Meta-Llama-3-70B-Instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Meta-Llama-3.1-70B-Instruct-Reference",
+        "meta-llama/Meta-Llama-3.1-70B-Reference",
+        "meta-llama/Meta-Llama-3.1-8B-Reference",
+        "microsoft/phi-2",
+        "mistralai/Mixtral-8x22B",
+        "openchat/openchat-3.5-1210",
+        "prompthero/openjourney",
+        "runwayml/stable-diffusion-v1-5",
+        "sentence-transformers/msmarco-bert-base-dot-v5",
+        "snorkelai/Snorkel-Mistral-PairRM-DPO",
+        "stabilityai/stable-diffusion-2-1",
+        "teknium/OpenHermes-2-Mistral-7B",
+        "teknium/OpenHermes-2p5-Mistral-7B",
+        "togethercomputer/CodeLlama-13b-Instruct",
+        "togethercomputer/CodeLlama-13b-Python",
+        "togethercomputer/CodeLlama-34b",
+        "togethercomputer/CodeLlama-34b-Python",
+        "togethercomputer/CodeLlama-7b-Instruct",
+        "togethercomputer/CodeLlama-7b-Python",
+        "togethercomputer/Koala-13B",
+        "togethercomputer/Koala-7B",
+        "togethercomputer/LLaMA-2-7B-32K",
+        "togethercomputer/SOLAR-10.7B-Instruct-v1.0-int4",
+        "togethercomputer/StripedHyena-Hessian-7B",
+        "togethercomputer/alpaca-7b",
+        "togethercomputer/evo-1-131k-base",
+        "togethercomputer/evo-1-8k-base",
+        "togethercomputer/guanaco-13b",
+        "togethercomputer/guanaco-33b",
+        "togethercomputer/guanaco-65b",
+        "togethercomputer/guanaco-7b",
+        "togethercomputer/llama-2-13b",
+        "togethercomputer/llama-2-70b-chat",
+        "togethercomputer/llama-2-7b",
+        "wavymulder/Analog-Diffusion",
+        "zero-one-ai/Yi-34B",
+        "zero-one-ai/Yi-34B-Chat",
+        "zero-one-ai/Yi-6B",
+    ]
+
+    _sync_client_ = openai.OpenAI
+    _async_client_ = openai.AsyncOpenAI
+
+    @classmethod
+    def get_model_list(cls):
+        # Togheter.ai has a different response in model list then openai
+        # and the OpenAI class returns an error when calling .models.list()
+        import requests
+        import os
+
+        url = "https://api.together.xyz/v1/models?filter=serverless"
+        token = os.getenv(cls._env_key_name_)
+        headers = {"accept": "application/json", "authorization": f"Bearer {token}"}
+
+        response = requests.get(url, headers=headers)
+        return response.json()
+
+    @classmethod
+    def available(cls) -> List[str]:
+        if not cls._models_list_cache:
+            try:
+                cls._models_list_cache = [
+                    m["id"]
+                    for m in cls.get_model_list()
+                    if m["id"] not in cls.model_exclude_list
+                ]
+            except Exception as e:
+                raise
+        return cls._models_list_cache
edsl/inference_services/registry.py CHANGED
@@ -12,6 +12,7 @@ from edsl.inference_services.AzureAI import AzureAIService
 from edsl.inference_services.OllamaService import OllamaService
 from edsl.inference_services.TestService import TestService
 from edsl.inference_services.MistralAIService import MistralAIService
+from edsl.inference_services.TogetherAIService import TogetherAIService

 default = InferenceServicesCollection(
     [
@@ -25,5 +26,6 @@ default = InferenceServicesCollection(
         OllamaService,
         TestService,
         MistralAIService,
+        TogetherAIService,
     ]
 )
edsl/jobs/Jobs.py CHANGED
@@ -460,6 +460,12 @@ class Jobs(Base):
         if warn:
             warnings.warn(message)

+        if self.scenarios.has_jinja_braces:
+            warnings.warn(
+                "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
+            )
+            self.scenarios = self.scenarios.convert_jinja_braces()
+
     @property
     def skip_retry(self):
         if not hasattr(self, "_skip_retry"):
@@ -486,6 +492,7 @@ class Jobs(Base):
         remote_inference_description: Optional[str] = None,
         skip_retry: bool = False,
         raise_validation_errors: bool = False,
+        disable_remote_inference: bool = False,
     ) -> Results:
         """
         Runs the Job: conducts Interviews and returns their results.
@@ -508,14 +515,17 @@ class Jobs(Base):

         self.verbose = verbose

-        try:
-            coop = Coop()
-            user_edsl_settings = coop.edsl_settings
-            remote_cache = user_edsl_settings["remote_caching"]
-            remote_inference = user_edsl_settings["remote_inference"]
-        except Exception:
-            remote_cache = False
-            remote_inference = False
+        remote_cache = False
+        remote_inference = False
+
+        if not disable_remote_inference:
+            try:
+                coop = Coop()
+                user_edsl_settings = Coop().edsl_settings
+                remote_cache = user_edsl_settings.get("remote_caching", False)
+                remote_inference = user_edsl_settings.get("remote_inference", False)
+            except Exception:
+                pass

         if remote_inference:
             import time
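
Note: a standalone sketch of the remote-settings resolution that `Jobs.run` now performs; the new `disable_remote_inference` flag skips the Coop lookup entirely, and any lookup failure quietly leaves both flags False. Here `fetch_settings` is a stand-in for `Coop().edsl_settings`.

```python
def resolve_remote_flags(disable_remote_inference: bool, fetch_settings) -> tuple:
    """Mirror of the diff: decide the remote caching / remote inference flags."""
    remote_cache = False
    remote_inference = False
    if not disable_remote_inference:
        try:
            settings = fetch_settings()  # stands in for Coop().edsl_settings
            remote_cache = settings.get("remote_caching", False)
            remote_inference = settings.get("remote_inference", False)
        except Exception:
            pass  # offline or unauthenticated: keep both flags False
    return remote_cache, remote_inference


print(resolve_remote_flags(True, lambda: {"remote_inference": True}))   # (False, False)
print(resolve_remote_flags(False, lambda: {"remote_inference": True}))  # (False, True)
```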
edsl/jobs/buckets/BucketCollection.py CHANGED
@@ -13,6 +13,8 @@ class BucketCollection(UserDict):
     def __init__(self, infinity_buckets=False):
         super().__init__()
         self.infinity_buckets = infinity_buckets
+        self.models_to_services = {}
+        self.services_to_buckets = {}

     def __repr__(self):
         return f"BucketCollection({self.data})"
@@ -21,6 +23,7 @@ class BucketCollection(UserDict):
         """Adds a model to the bucket collection.

         This will create the token and request buckets for the model."""
+
         # compute the TPS and RPS from the model
         if not self.infinity_buckets:
             TPS = model.TPM / 60.0
@@ -29,22 +32,28 @@ class BucketCollection(UserDict):
             TPS = float("inf")
             RPS = float("inf")

-        # create the buckets
-        requests_bucket = TokenBucket(
-            bucket_name=model.model,
-            bucket_type="requests",
-            capacity=RPS,
-            refill_rate=RPS,
-        )
-        tokens_bucket = TokenBucket(
-            bucket_name=model.model, bucket_type="tokens", capacity=TPS, refill_rate=TPS
-        )
-        model_buckets = ModelBuckets(requests_bucket, tokens_bucket)
-        if model in self:
-            # it if already exists, combine the buckets
-            self[model] += model_buckets
+        if model.model not in self.models_to_services:
+            service = model._inference_service_
+            if service not in self.services_to_buckets:
+                requests_bucket = TokenBucket(
+                    bucket_name=service,
+                    bucket_type="requests",
+                    capacity=RPS,
+                    refill_rate=RPS,
+                )
+                tokens_bucket = TokenBucket(
+                    bucket_name=service,
+                    bucket_type="tokens",
+                    capacity=TPS,
+                    refill_rate=TPS,
+                )
+                self.services_to_buckets[service] = ModelBuckets(
+                    requests_bucket, tokens_bucket
+                )
+            self.models_to_services[model.model] = service
+            self[model] = self.services_to_buckets[service]
         else:
-            self[model] = model_buckets
+            self[model] = self.services_to_buckets[self.models_to_services[model.model]]

     def visualize(self) -> dict:
         """Visualize the token and request buckets for each model."""
edsl/jobs/buckets/TokenBucket.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Union, List, Any
+from typing import Union, List, Any, Optional
 import asyncio
 import time

@@ -17,6 +17,12 @@ class TokenBucket:
         self.bucket_name = bucket_name
         self.bucket_type = bucket_type
         self.capacity = capacity  # Maximum number of tokens
+        self.added_tokens = 0
+
+        self.target_rate = (
+            capacity * 60
+        )  # set this here because it can change with turbo mode
+
         self._old_capacity = capacity
         self.tokens = capacity  # Current number of available tokens
         self.refill_rate = refill_rate  # Rate at which tokens are refilled
@@ -25,6 +31,12 @@ class TokenBucket:
         self.log: List[Any] = []
         self.turbo_mode = False

+        self.creation_time = time.monotonic()
+
+        self.num_requests = 0
+        self.num_released = 0
+        self.tokens_returned = 0
+
     def turbo_mode_on(self):
         """Set the refill rate to infinity."""
         if self.turbo_mode:
@@ -69,6 +81,7 @@ class TokenBucket:
         >>> bucket.tokens
         10
         """
+        self.tokens_returned += tokens
         self.tokens = min(self.capacity, self.tokens + tokens)
         self.log.append((time.monotonic(), self.tokens))

@@ -133,15 +146,12 @@ class TokenBucket:
         >>> bucket.capacity
         12.100000000000001
         """
+        self.num_requests += amount
         if amount >= self.capacity:
             if not cheat_bucket_capacity:
                 msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
                 raise ValueError(msg)
             else:
-                # self.tokens = 0 # clear the bucket but let it go through
-                # print(
-                #     f"""The requested amount, {amount}, exceeds the current bucket capacity of {self.capacity}.Increasing bucket capacity to {amount} * 1.10 accommodate the requested amount."""
-                # )
                 self.capacity = amount * 1.10
                 self._old_capacity = self.capacity

@@ -153,14 +163,10 @@ class TokenBucket:
                 break

             wait_time = self.wait_time(amount)
-            # print(f"Waiting for {wait_time:.4f} seconds")
             if wait_time > 0:
-                # print(f"Waiting for {wait_time:.4f} seconds")
                 await asyncio.sleep(wait_time)

-        # total_elapsed = time.monotonic() - start_time
-        # print(f"Total time to acquire tokens: {total_elapsed:.4f} seconds")
-
+        self.num_released += amount
         now = time.monotonic()
         self.log.append((now, self.tokens))
         return None
@@ -187,6 +193,54 @@ class TokenBucket:
         plt.tight_layout()
         plt.show()

+    def get_throughput(self, time_window: Optional[float] = None) -> float:
+        """
+        Calculate the empirical bucket throughput in tokens per minute for the specified time window.
+
+        :param time_window: The time window in seconds to calculate the throughput for.
+        :return: The throughput in tokens per minute.
+
+        >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=100, refill_rate=10)
+        >>> asyncio.run(bucket.get_tokens(50))
+        >>> time.sleep(1)  # Wait for 1 second
+        >>> asyncio.run(bucket.get_tokens(30))
+        >>> throughput = bucket.get_throughput(1)
+        >>> 4750 < throughput < 4850
+        True
+        """
+        now = time.monotonic()
+
+        if time_window is None:
+            start_time = self.creation_time
+        else:
+            start_time = now - time_window
+
+        if start_time < self.creation_time:
+            start_time = self.creation_time
+
+        elapsed_time = now - start_time
+
+        return (self.num_released / elapsed_time) * 60
+
+        # # Filter log entries within the time window
+        # relevant_log = [(t, tokens) for t, tokens in self.log if t >= start_time]
+
+        # if len(relevant_log) < 2:
+        #     return 0  # Not enough data points to calculate throughput
+
+        # # Calculate total tokens used
+        # initial_tokens = relevant_log[0][1]
+        # final_tokens = relevant_log[-1][1]
+        # tokens_used = self.num_released - (final_tokens - initial_tokens)
+
+        # # Calculate actual time elapsed
+        # actual_time_elapsed = relevant_log[-1][0] - relevant_log[0][0]
+
+        # # Calculate throughput in tokens per minute
+        # throughput = (tokens_used / actual_time_elapsed) * 60
+
+        # return throughput
+

 if __name__ == "__main__":
     import doctest
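
Note on the `get_throughput` doctest above: the two calls release 50 + 30 = 80 tokens over roughly one second, so the empirical rate is about 80 * 60 = 4800 tokens per minute, which is what the 4750-4850 bounds check for.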