edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. edsl/Base.py +24 -14
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +28 -6
  6. edsl/agents/InvigilatorBase.py +8 -27
  7. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
  8. edsl/agents/prompt_helpers.py +129 -0
  9. edsl/config.py +26 -34
  10. edsl/coop/coop.py +14 -4
  11. edsl/data_transfer_models.py +26 -73
  12. edsl/enums.py +2 -0
  13. edsl/inference_services/AnthropicService.py +5 -2
  14. edsl/inference_services/AwsBedrock.py +5 -2
  15. edsl/inference_services/AzureAI.py +5 -2
  16. edsl/inference_services/GoogleService.py +108 -33
  17. edsl/inference_services/InferenceServiceABC.py +44 -13
  18. edsl/inference_services/MistralAIService.py +5 -2
  19. edsl/inference_services/OpenAIService.py +10 -6
  20. edsl/inference_services/TestService.py +34 -16
  21. edsl/inference_services/TogetherAIService.py +170 -0
  22. edsl/inference_services/registry.py +2 -0
  23. edsl/jobs/Jobs.py +109 -18
  24. edsl/jobs/buckets/BucketCollection.py +24 -15
  25. edsl/jobs/buckets/TokenBucket.py +64 -10
  26. edsl/jobs/interviews/Interview.py +130 -49
  27. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
  28. edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
  29. edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
  30. edsl/jobs/runners/JobsRunnerStatus.py +332 -0
  31. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  32. edsl/jobs/tasks/TaskHistory.py +17 -0
  33. edsl/language_models/LanguageModel.py +36 -38
  34. edsl/language_models/registry.py +13 -9
  35. edsl/language_models/utilities.py +5 -2
  36. edsl/questions/QuestionBase.py +74 -16
  37. edsl/questions/QuestionBaseGenMixin.py +28 -0
  38. edsl/questions/QuestionBudget.py +93 -41
  39. edsl/questions/QuestionCheckBox.py +1 -1
  40. edsl/questions/QuestionFreeText.py +6 -0
  41. edsl/questions/QuestionMultipleChoice.py +13 -24
  42. edsl/questions/QuestionNumerical.py +5 -4
  43. edsl/questions/Quick.py +41 -0
  44. edsl/questions/ResponseValidatorABC.py +11 -6
  45. edsl/questions/derived/QuestionLinearScale.py +4 -1
  46. edsl/questions/derived/QuestionTopK.py +4 -1
  47. edsl/questions/derived/QuestionYesNo.py +8 -2
  48. edsl/questions/descriptors.py +12 -11
  49. edsl/questions/templates/budget/__init__.py +0 -0
  50. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  51. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  52. edsl/questions/templates/extract/__init__.py +0 -0
  53. edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
  54. edsl/questions/templates/rank/__init__.py +0 -0
  55. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  56. edsl/results/DatasetExportMixin.py +5 -1
  57. edsl/results/Result.py +1 -1
  58. edsl/results/Results.py +4 -1
  59. edsl/scenarios/FileStore.py +178 -34
  60. edsl/scenarios/Scenario.py +76 -37
  61. edsl/scenarios/ScenarioList.py +19 -2
  62. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  63. edsl/study/Study.py +32 -0
  64. edsl/surveys/DAG.py +62 -0
  65. edsl/surveys/MemoryPlan.py +26 -0
  66. edsl/surveys/Rule.py +34 -1
  67. edsl/surveys/RuleCollection.py +55 -5
  68. edsl/surveys/Survey.py +189 -10
  69. edsl/surveys/base.py +4 -0
  70. edsl/templates/error_reporting/interview_details.html +6 -1
  71. edsl/utilities/utilities.py +9 -1
  72. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
  73. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
  74. edsl/jobs/interviews/retry_management.py +0 -39
  75. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  76. edsl/scenarios/ScenarioImageMixin.py +0 -100
  77. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
  78. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/inference_services/TestService.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Any, List
+ from typing import Any, List, Optional
  import os
  import asyncio
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -7,14 +7,25 @@ from edsl.inference_services.rate_limits_cache import rate_limits
  from edsl.utilities.utilities import fix_partial_correct_response

  from edsl.enums import InferenceServiceType
+ import random


  class TestService(InferenceServiceABC):
  """OpenAI service class."""

+ _inference_service_ = "test"
+ _env_key_name_ = None
+ _base_url_ = None
+
+ _sync_client_ = None
+ _async_client_ = None
+
+ _sync_client_instance = None
+ _async_client_instance = None
+
  key_sequence = None
+ usage_sequence = None
  model_exclude_list = []
- _inference_service_ = "test"
  input_token_name = "prompt_tokens"
  output_token_name = "completion_tokens"

@@ -45,27 +56,34 @@ class TestService(InferenceServiceABC):
  return "Hello, world"

  async def async_execute_model_call(
- self, user_prompt: str, system_prompt: str
+ self,
+ user_prompt: str,
+ system_prompt: str,
+ # func: Optional[callable] = None,
+ files_list: Optional[List["File"]] = None,
  ) -> dict[str, Any]:
  await asyncio.sleep(0.1)
  # return {"message": """{"answer": "Hello, world"}"""}
+
+ if hasattr(self, "func"):
+ return {
+ "message": [
+ {"text": self.func(user_prompt, system_prompt, files_list)}
+ ],
+ "usage": {"prompt_tokens": 1, "completion_tokens": 1},
+ }
+
  if hasattr(self, "throw_exception") and self.throw_exception:
- raise Exception("This is a test error")
+ if hasattr(self, "exception_probability"):
+ p = self.exception_probability
+ else:
+ p = 1
+
+ if random.random() < p:
+ raise Exception("This is a test error")
  return {
  "message": [{"text": f"{self._canned_response}"}],
  "usage": {"prompt_tokens": 1, "completion_tokens": 1},
  }

  return TestServiceLanguageModel
-
- # _inference_service_ = "openai"
- # _env_key_name_ = "OPENAI_API_KEY"
- # _base_url_ = None
-
- # _sync_client_ = openai.OpenAI
- # _async_client_ = openai.AsyncOpenAI
-
- # _sync_client_instance = None
- # _async_client_instance = None
-
- # key_sequence = ["choices", 0, "message", "content"]
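For orientation, a minimal standalone sketch (illustrative, not code from the package) of the failure gate the new TestService logic implements: when throw_exception is set, an exception fires with probability exception_probability (default 1), otherwise a canned response in the same shape is returned.

    import random

    def canned_or_fail(p: float = 1.0) -> dict:
        # Fail with probability p, mirroring the gate in the hunk above.
        if random.random() < p:
            raise Exception("This is a test error")
        # Otherwise return a canned response in the same shape as the diff shows.
        return {
            "message": [{"text": "Hello, world"}],
            "usage": {"prompt_tokens": 1, "completion_tokens": 1},
        }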
edsl/inference_services/TogetherAIService.py ADDED
@@ -0,0 +1,170 @@
+ import aiohttp
+ import json
+ import requests
+ from typing import Any, List, Optional
+
+ # from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+ from edsl.language_models import LanguageModel
+
+ from edsl.inference_services.OpenAIService import OpenAIService
+ import openai
+
+
+ class TogetherAIService(OpenAIService):
+ """DeepInfra service class."""
+
+ _inference_service_ = "together"
+ _env_key_name_ = "TOGETHER_API_KEY"
+ _base_url_ = "https://api.together.xyz/v1"
+ _models_list_cache: List[str] = []
+
+ # These are non-serverless models. There was no api param to filter them
+ model_exclude_list = [
+ "EleutherAI/llemma_7b",
+ "HuggingFaceH4/zephyr-7b-beta",
+ "Nexusflow/NexusRaven-V2-13B",
+ "NousResearch/Hermes-2-Theta-Llama-3-70B",
+ "NousResearch/Nous-Capybara-7B-V1p9",
+ "NousResearch/Nous-Hermes-13b",
+ "NousResearch/Nous-Hermes-2-Mistral-7B-DPO",
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT",
+ "NousResearch/Nous-Hermes-Llama2-13b",
+ "NousResearch/Nous-Hermes-Llama2-70b",
+ "NousResearch/Nous-Hermes-llama-2-7b",
+ "NumbersStation/nsql-llama-2-7B",
+ "Open-Orca/Mistral-7B-OpenOrca",
+ "Phind/Phind-CodeLlama-34B-Python-v1",
+ "Phind/Phind-CodeLlama-34B-v2",
+ "Qwen/Qwen1.5-0.5B",
+ "Qwen/Qwen1.5-0.5B-Chat",
+ "Qwen/Qwen1.5-1.8B",
+ "Qwen/Qwen1.5-1.8B-Chat",
+ "Qwen/Qwen1.5-14B",
+ "Qwen/Qwen1.5-14B-Chat",
+ "Qwen/Qwen1.5-32B",
+ "Qwen/Qwen1.5-32B-Chat",
+ "Qwen/Qwen1.5-4B",
+ "Qwen/Qwen1.5-4B-Chat",
+ "Qwen/Qwen1.5-72B",
+ "Qwen/Qwen1.5-7B",
+ "Qwen/Qwen1.5-7B-Chat",
+ "Qwen/Qwen2-1.5B",
+ "Qwen/Qwen2-1.5B-Instruct",
+ "Qwen/Qwen2-72B",
+ "Qwen/Qwen2-7B",
+ "Qwen/Qwen2-7B-Instruct",
+ "SG161222/Realistic_Vision_V3.0_VAE",
+ "Snowflake/snowflake-arctic-instruct",
+ "Undi95/ReMM-SLERP-L2-13B",
+ "Undi95/Toppy-M-7B",
+ "WizardLM/WizardCoder-Python-34B-V1.0",
+ "WizardLM/WizardLM-13B-V1.2",
+ "WizardLM/WizardLM-70B-V1.0",
+ "allenai/OLMo-7B",
+ "allenai/OLMo-7B-Instruct",
+ "bert-base-uncased",
+ "codellama/CodeLlama-13b-Instruct-hf",
+ "codellama/CodeLlama-13b-Python-hf",
+ "codellama/CodeLlama-13b-hf",
+ "codellama/CodeLlama-34b-Python-hf",
+ "codellama/CodeLlama-34b-hf",
+ "codellama/CodeLlama-70b-Instruct-hf",
+ "codellama/CodeLlama-70b-Python-hf",
+ "codellama/CodeLlama-70b-hf",
+ "codellama/CodeLlama-7b-Instruct-hf",
+ "codellama/CodeLlama-7b-Python-hf",
+ "codellama/CodeLlama-7b-hf",
+ "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
+ "deepseek-ai/deepseek-coder-33b-instruct",
+ "garage-bAInd/Platypus2-70B-instruct",
+ "google/gemma-2b",
+ "google/gemma-7b",
+ "google/gemma-7b-it",
+ "gradientai/Llama-3-70B-Instruct-Gradient-1048k",
+ "hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1",
+ "huggyllama/llama-13b",
+ "huggyllama/llama-30b",
+ "huggyllama/llama-65b",
+ "huggyllama/llama-7b",
+ "lmsys/vicuna-13b-v1.3",
+ "lmsys/vicuna-13b-v1.5",
+ "lmsys/vicuna-13b-v1.5-16k",
+ "lmsys/vicuna-7b-v1.3",
+ "lmsys/vicuna-7b-v1.5",
+ "meta-llama/Llama-2-13b-hf",
+ "meta-llama/Llama-2-70b-chat-hf",
+ "meta-llama/Llama-2-7b-hf",
+ "meta-llama/Llama-3-70b-hf",
+ "meta-llama/Llama-3-8b-hf",
+ "meta-llama/Meta-Llama-3-70B",
+ "meta-llama/Meta-Llama-3-70B-Instruct",
+ "meta-llama/Meta-Llama-3-8B-Instruct",
+ "meta-llama/Meta-Llama-3.1-70B-Instruct-Reference",
+ "meta-llama/Meta-Llama-3.1-70B-Reference",
+ "meta-llama/Meta-Llama-3.1-8B-Reference",
+ "microsoft/phi-2",
+ "mistralai/Mixtral-8x22B",
+ "openchat/openchat-3.5-1210",
+ "prompthero/openjourney",
+ "runwayml/stable-diffusion-v1-5",
+ "sentence-transformers/msmarco-bert-base-dot-v5",
+ "snorkelai/Snorkel-Mistral-PairRM-DPO",
+ "stabilityai/stable-diffusion-2-1",
+ "teknium/OpenHermes-2-Mistral-7B",
+ "teknium/OpenHermes-2p5-Mistral-7B",
+ "togethercomputer/CodeLlama-13b-Instruct",
+ "togethercomputer/CodeLlama-13b-Python",
+ "togethercomputer/CodeLlama-34b",
+ "togethercomputer/CodeLlama-34b-Python",
+ "togethercomputer/CodeLlama-7b-Instruct",
+ "togethercomputer/CodeLlama-7b-Python",
+ "togethercomputer/Koala-13B",
+ "togethercomputer/Koala-7B",
+ "togethercomputer/LLaMA-2-7B-32K",
+ "togethercomputer/SOLAR-10.7B-Instruct-v1.0-int4",
+ "togethercomputer/StripedHyena-Hessian-7B",
+ "togethercomputer/alpaca-7b",
+ "togethercomputer/evo-1-131k-base",
+ "togethercomputer/evo-1-8k-base",
+ "togethercomputer/guanaco-13b",
+ "togethercomputer/guanaco-33b",
+ "togethercomputer/guanaco-65b",
+ "togethercomputer/guanaco-7b",
+ "togethercomputer/llama-2-13b",
+ "togethercomputer/llama-2-70b-chat",
+ "togethercomputer/llama-2-7b",
+ "wavymulder/Analog-Diffusion",
+ "zero-one-ai/Yi-34B",
+ "zero-one-ai/Yi-34B-Chat",
+ "zero-one-ai/Yi-6B",
+ ]
+
+ _sync_client_ = openai.OpenAI
+ _async_client_ = openai.AsyncOpenAI
+
+ @classmethod
+ def get_model_list(cls):
+ # Togheter.ai has a different response in model list then openai
+ # and the OpenAI class returns an error when calling .models.list()
+ import requests
+ import os
+
+ url = "https://api.together.xyz/v1/models?filter=serverless"
+ token = os.getenv(cls._env_key_name_)
+ headers = {"accept": "application/json", "authorization": f"Bearer {token}"}
+
+ response = requests.get(url, headers=headers)
+ return response.json()
+
+ @classmethod
+ def available(cls) -> List[str]:
+ if not cls._models_list_cache:
+ try:
+ cls._models_list_cache = [
+ m["id"]
+ for m in cls.get_model_list()
+ if m["id"] not in cls.model_exclude_list
+ ]
+ except Exception as e:
+ raise
+ return cls._models_list_cache
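As a usage sketch (not part of the diff), the new service's cached model listing can be exercised directly; this assumes a valid TOGETHER_API_KEY is present in the environment.

    import os
    from edsl.inference_services.TogetherAIService import TogetherAIService

    assert os.getenv("TOGETHER_API_KEY"), "a real key is assumed to be set"
    models = TogetherAIService.available()  # serverless models minus the exclude list
    print(models[:5])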
edsl/inference_services/registry.py CHANGED
@@ -12,6 +12,7 @@ from edsl.inference_services.AzureAI import AzureAIService
  from edsl.inference_services.OllamaService import OllamaService
  from edsl.inference_services.TestService import TestService
  from edsl.inference_services.MistralAIService import MistralAIService
+ from edsl.inference_services.TogetherAIService import TogetherAIService

  default = InferenceServicesCollection(
  [
@@ -25,5 +26,6 @@ default = InferenceServicesCollection(
  OllamaService,
  TestService,
  MistralAIService,
+ TogetherAIService,
  ]
  )
edsl/jobs/Jobs.py CHANGED
@@ -145,14 +145,21 @@ class Jobs(Base):
  >>> Jobs.example().prompts()
  Dataset(...)
  """
+ from edsl import Coop
+
+ c = Coop()
+ price_lookup = c.fetch_prices()

  interviews = self.interviews()
  # data = []
  interview_indices = []
- question_indices = []
+ question_names = []
  user_prompts = []
  system_prompts = []
  scenario_indices = []
+ agent_indices = []
+ models = []
+ costs = []
  from edsl.results.Dataset import Dataset

  for interview_index, interview in enumerate(interviews):
@@ -160,23 +167,97 @@ class Jobs(Base):
  interview._get_invigilator(question)
  for question in self.survey.questions
  ]
- # list(interview._build_invigilators(debug=False))
  for _, invigilator in enumerate(invigilators):
  prompts = invigilator.get_prompts()
- user_prompts.append(prompts["user_prompt"])
- system_prompts.append(prompts["system_prompt"])
+ user_prompt = prompts["user_prompt"]
+ system_prompt = prompts["system_prompt"]
+ user_prompts.append(user_prompt)
+ system_prompts.append(system_prompt)
+ agent_index = self.agents.index(invigilator.agent)
+ agent_indices.append(agent_index)
  interview_indices.append(interview_index)
- scenario_indices.append(invigilator.scenario)
- question_indices.append(invigilator.question.question_name)
- return Dataset(
+ scenario_index = self.scenarios.index(invigilator.scenario)
+ scenario_indices.append(scenario_index)
+ models.append(invigilator.model.model)
+ question_names.append(invigilator.question.question_name)
+ # cost calculation
+ key = (invigilator.model._inference_service_, invigilator.model.model)
+ relevant_prices = price_lookup[key]
+ inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+ inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+ input_tokens = len(str(user_prompt) + str(system_prompt)) // 4
+ output_tokens = len(str(user_prompt) + str(system_prompt)) // 4
+ cost = input_tokens / float(
+ inverse_input_price
+ ) + output_tokens / float(inverse_output_price)
+ costs.append(cost)
+
+ d = Dataset(
  [
- {"interview_index": interview_indices},
- {"question_index": question_indices},
  {"user_prompt": user_prompts},
- {"scenario_index": scenario_indices},
  {"system_prompt": system_prompts},
+ {"interview_index": interview_indices},
+ {"question_name": question_names},
+ {"scenario_index": scenario_indices},
+ {"agent_index": agent_indices},
+ {"model": models},
+ {"estimated_cost": costs},
  ]
  )
+ return d
+ # if table:
+ # d.to_scenario_list().print(format="rich")
+ # else:
+ # return d
+
+ def show_prompts(self) -> None:
+ """Print the prompts."""
+ self.prompts().to_scenario_list().print(format="rich")
+
+ def estimate_job_cost(self):
+ from edsl import Coop
+
+ c = Coop()
+ price_lookup = c.fetch_prices()
+
+ prompts = self.prompts()
+
+ text_len = 0
+ for prompt in prompts:
+ text_len += len(str(prompt))
+
+ input_token_aproximations = text_len // 4
+
+ aproximation_cost = {}
+ total_cost = 0
+ for model in self.models:
+ key = (model._inference_service_, model.model)
+ relevant_prices = price_lookup[key]
+ inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+ inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+
+ aproximation_cost[key] = {
+ "input": input_token_aproximations / float(inverse_input_price),
+ "output": input_token_aproximations / float(inverse_output_price),
+ }
+ ##TODO curenlty we approximate the number of output tokens with the number
+ # of input tokens. A better solution will be to compute the quesiton answer options length and sum them
+ # to compute the output tokens
+
+ total_cost += input_token_aproximations / float(inverse_input_price)
+ total_cost += input_token_aproximations / float(inverse_output_price)
+
+ # multiply_factor = len(self.agents or [1]) * len(self.scenarios or [1])
+ multiply_factor = 1
+ out = {
+ "input_token_aproximations": input_token_aproximations,
+ "models_costs": aproximation_cost,
+ "estimated_total_cost": total_cost * multiply_factor,
+ "multiply_factor": multiply_factor,
+ "single_config_cost": total_cost,
+ }
+
+ return out


  @staticmethod
  def _get_container_class(object):
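The cost arithmetic added in prompts() and estimate_job_cost() follows the same approximation: token counts are taken as character length divided by 4, and prices come back from Coop().fetch_prices() as one_usd_buys rates (tokens per USD), so cost is tokens divided by that rate. A worked example with made-up prices:

    # Hypothetical price entry in the shape the code above indexes into.
    relevant_prices = {
        "input": {"one_usd_buys": 1_000_000},  # 1 USD buys 1M input tokens (made-up)
        "output": {"one_usd_buys": 500_000},   # 1 USD buys 500k output tokens (made-up)
    }

    input_tokens = 8_000 // 4        # an 8,000-character prompt ~ 2,000 tokens
    output_tokens = input_tokens     # output approximated by input, per the TODO above

    cost = input_tokens / float(relevant_prices["input"]["one_usd_buys"]) \
        + output_tokens / float(relevant_prices["output"]["one_usd_buys"])
    # 2000/1,000,000 + 2000/500,000 = 0.002 + 0.004 = 0.006 USD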
@@ -460,6 +541,12 @@ class Jobs(Base):
  if warn:
  warnings.warn(message)

+ if self.scenarios.has_jinja_braces:
+ warnings.warn(
+ "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
+ )
+ self.scenarios = self.scenarios.convert_jinja_braces()
+
  @property
  def skip_retry(self):
  if not hasattr(self, "_skip_retry"):
@@ -486,6 +573,7 @@ class Jobs(Base):
  remote_inference_description: Optional[str] = None,
  skip_retry: bool = False,
  raise_validation_errors: bool = False,
+ disable_remote_inference: bool = False,
  ) -> Results:
  """
  Runs the Job: conducts Interviews and returns their results.
@@ -508,14 +596,17 @@

  self.verbose = verbose

- try:
- coop = Coop()
- user_edsl_settings = coop.edsl_settings
- remote_cache = user_edsl_settings["remote_caching"]
- remote_inference = user_edsl_settings["remote_inference"]
- except Exception:
- remote_cache = False
- remote_inference = False
+ remote_cache = False
+ remote_inference = False
+
+ if not disable_remote_inference:
+ try:
+ coop = Coop()
+ user_edsl_settings = Coop().edsl_settings
+ remote_cache = user_edsl_settings.get("remote_caching", False)
+ remote_inference = user_edsl_settings.get("remote_inference", False)
+ except Exception:
+ pass

  if remote_inference:
  import time
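A short usage sketch (illustrative, not from the diff): the new disable_remote_inference flag skips the Coop settings lookup entirely and forces a local run.

    from edsl import Jobs

    results = Jobs.example().run(disable_remote_inference=True)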
edsl/jobs/buckets/BucketCollection.py CHANGED
@@ -13,6 +13,8 @@ class BucketCollection(UserDict):
  def __init__(self, infinity_buckets=False):
  super().__init__()
  self.infinity_buckets = infinity_buckets
+ self.models_to_services = {}
+ self.services_to_buckets = {}

  def __repr__(self):
  return f"BucketCollection({self.data})"
@@ -21,6 +23,7 @@ class BucketCollection(UserDict):
  """Adds a model to the bucket collection.

  This will create the token and request buckets for the model."""
+
  # compute the TPS and RPS from the model
  if not self.infinity_buckets:
  TPS = model.TPM / 60.0
@@ -29,22 +32,28 @@ class BucketCollection(UserDict):
  TPS = float("inf")
  RPS = float("inf")

- # create the buckets
- requests_bucket = TokenBucket(
- bucket_name=model.model,
- bucket_type="requests",
- capacity=RPS,
- refill_rate=RPS,
- )
- tokens_bucket = TokenBucket(
- bucket_name=model.model, bucket_type="tokens", capacity=TPS, refill_rate=TPS
- )
- model_buckets = ModelBuckets(requests_bucket, tokens_bucket)
- if model in self:
- # it if already exists, combine the buckets
- self[model] += model_buckets
+ if model.model not in self.models_to_services:
+ service = model._inference_service_
+ if service not in self.services_to_buckets:
+ requests_bucket = TokenBucket(
+ bucket_name=service,
+ bucket_type="requests",
+ capacity=RPS,
+ refill_rate=RPS,
+ )
+ tokens_bucket = TokenBucket(
+ bucket_name=service,
+ bucket_type="tokens",
+ capacity=TPS,
+ refill_rate=TPS,
+ )
+ self.services_to_buckets[service] = ModelBuckets(
+ requests_bucket, tokens_bucket
+ )
+ self.models_to_services[model.model] = service
+ self[model] = self.services_to_buckets[service]
  else:
- self[model] = model_buckets
+ self[model] = self.services_to_buckets[self.models_to_services[model.model]]

  def visualize(self) -> dict:
  """Visualize the token and request buckets for each model."""
edsl/jobs/buckets/TokenBucket.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Union, List, Any
+ from typing import Union, List, Any, Optional
  import asyncio
  import time

@@ -17,6 +17,12 @@ class TokenBucket:
  self.bucket_name = bucket_name
  self.bucket_type = bucket_type
  self.capacity = capacity # Maximum number of tokens
+ self.added_tokens = 0
+
+ self.target_rate = (
+ capacity * 60
+ ) # set this here because it can change with turbo mode
+
  self._old_capacity = capacity
  self.tokens = capacity # Current number of available tokens
  self.refill_rate = refill_rate # Rate at which tokens are refilled
@@ -25,6 +31,12 @@ class TokenBucket:
  self.log: List[Any] = []
  self.turbo_mode = False

+ self.creation_time = time.monotonic()
+
+ self.num_requests = 0
+ self.num_released = 0
+ self.tokens_returned = 0
+
  def turbo_mode_on(self):
  """Set the refill rate to infinity."""
  if self.turbo_mode:
@@ -69,6 +81,7 @@ class TokenBucket:
  >>> bucket.tokens
  10
  """
+ self.tokens_returned += tokens
  self.tokens = min(self.capacity, self.tokens + tokens)
  self.log.append((time.monotonic(), self.tokens))

@@ -133,15 +146,12 @@ class TokenBucket:
  >>> bucket.capacity
  12.100000000000001
  """
+ self.num_requests += amount
  if amount >= self.capacity:
  if not cheat_bucket_capacity:
  msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
  raise ValueError(msg)
  else:
- # self.tokens = 0 # clear the bucket but let it go through
- # print(
- # f"""The requested amount, {amount}, exceeds the current bucket capacity of {self.capacity}.Increasing bucket capacity to {amount} * 1.10 accommodate the requested amount."""
- # )
  self.capacity = amount * 1.10
  self._old_capacity = self.capacity

@@ -153,14 +163,10 @@ class TokenBucket:
  break

  wait_time = self.wait_time(amount)
- # print(f"Waiting for {wait_time:.4f} seconds")
  if wait_time > 0:
- # print(f"Waiting for {wait_time:.4f} seconds")
  await asyncio.sleep(wait_time)

- # total_elapsed = time.monotonic() - start_time
- # print(f"Total time to acquire tokens: {total_elapsed:.4f} seconds")
-
+ self.num_released += amount
  now = time.monotonic()
  self.log.append((now, self.tokens))
  return None
@@ -187,6 +193,54 @@ class TokenBucket:
  plt.tight_layout()
  plt.show()

+ def get_throughput(self, time_window: Optional[float] = None) -> float:
+ """
+ Calculate the empirical bucket throughput in tokens per minute for the specified time window.
+
+ :param time_window: The time window in seconds to calculate the throughput for.
+ :return: The throughput in tokens per minute.
+
+ >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=100, refill_rate=10)
+ >>> asyncio.run(bucket.get_tokens(50))
+ >>> time.sleep(1) # Wait for 1 second
+ >>> asyncio.run(bucket.get_tokens(30))
+ >>> throughput = bucket.get_throughput(1)
+ >>> 4750 < throughput < 4850
+ True
+ """
+ now = time.monotonic()
+
+ if time_window is None:
+ start_time = self.creation_time
+ else:
+ start_time = now - time_window
+
+ if start_time < self.creation_time:
+ start_time = self.creation_time
+
+ elapsed_time = now - start_time
+
+ return (self.num_released / elapsed_time) * 60
+
+ # # Filter log entries within the time window
+ # relevant_log = [(t, tokens) for t, tokens in self.log if t >= start_time]
+
+ # if len(relevant_log) < 2:
+ # return 0 # Not enough data points to calculate throughput
+
+ # # Calculate total tokens used
+ # initial_tokens = relevant_log[0][1]
+ # final_tokens = relevant_log[-1][1]
+ # tokens_used = self.num_released - (final_tokens - initial_tokens)
+
+ # # Calculate actual time elapsed
+ # actual_time_elapsed = relevant_log[-1][0] - relevant_log[0][0]
+
+ # # Calculate throughput in tokens per minute
+ # throughput = (tokens_used / actual_time_elapsed) * 60
+
+ # return throughput
+

  if __name__ == "__main__":
  import doctest
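A quick check on the get_throughput doctest above (not part of the diff): the bucket releases 50 + 30 = 80 tokens inside the 1-second window, and the method scales that to a per-minute rate, so the expected figure is roughly 80 * 60 = 4800 tokens per minute, which is why the assertion bounds the result between 4750 and 4850.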