edsl 0.1.33.dev2__py3-none-any.whl → 0.1.33.dev3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- edsl/Base.py +9 -3
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +6 -6
- edsl/agents/Invigilator.py +6 -3
- edsl/agents/InvigilatorBase.py +8 -27
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +101 -29
- edsl/config.py +26 -34
- edsl/coop/coop.py +11 -2
- edsl/data_transfer_models.py +27 -73
- edsl/enums.py +2 -0
- edsl/inference_services/GoogleService.py +1 -1
- edsl/inference_services/InferenceServiceABC.py +44 -13
- edsl/inference_services/OpenAIService.py +7 -4
- edsl/inference_services/TestService.py +24 -15
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +18 -8
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +64 -10
- edsl/jobs/interviews/Interview.py +115 -47
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +86 -161
- edsl/jobs/runners/JobsRunnerStatus.py +331 -0
- edsl/jobs/tasks/TaskHistory.py +17 -0
- edsl/language_models/LanguageModel.py +26 -31
- edsl/language_models/registry.py +13 -9
- edsl/questions/QuestionBase.py +64 -16
- edsl/questions/QuestionBudget.py +93 -41
- edsl/questions/QuestionFreeText.py +6 -0
- edsl/questions/QuestionMultipleChoice.py +11 -26
- edsl/questions/QuestionNumerical.py +5 -4
- edsl/questions/Quick.py +41 -0
- edsl/questions/ResponseValidatorABC.py +6 -5
- edsl/questions/derived/QuestionLinearScale.py +4 -1
- edsl/questions/derived/QuestionTopK.py +4 -1
- edsl/questions/derived/QuestionYesNo.py +8 -2
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/results/DatasetExportMixin.py +5 -1
- edsl/results/Result.py +1 -1
- edsl/results/Results.py +4 -1
- edsl/scenarios/FileStore.py +71 -10
- edsl/scenarios/Scenario.py +86 -21
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +13 -0
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/Study.py +32 -0
- edsl/surveys/Rule.py +10 -1
- edsl/surveys/RuleCollection.py +19 -3
- edsl/surveys/Survey.py +7 -0
- edsl/templates/error_reporting/interview_details.html +6 -1
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/METADATA +2 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/RECORD +61 -55
- edsl/jobs/interviews/retry_management.py +0 -39
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/WHEEL +0 -0
edsl/inference_services/GoogleService.py CHANGED

@@ -64,7 +64,7 @@ class GoogleService(InferenceServiceABC):
                 "stopSequences": self.stopSequences,
             },
         }
-        print(combined_prompt)
+        # print(combined_prompt)
         async with aiohttp.ClientSession() as session:
             async with session.post(
                 url, headers=headers, data=json.dumps(data)
edsl/inference_services/InferenceServiceABC.py CHANGED

@@ -1,14 +1,27 @@
 from abc import abstractmethod, ABC
-
+import os
 import re
 from edsl.config import CONFIG
 
 
 class InferenceServiceABC(ABC):
-    """
+    """
+    Abstract class for inference services.
+    Anthropic: https://docs.anthropic.com/en/api/rate-limits
+    """
+
+    default_levels = {
+        "google": {"tpm": 2_000_000, "rpm": 15},
+        "openai": {"tpm": 2_000_000, "rpm": 10_000},
+        "anthropic": {"tpm": 2_000_000, "rpm": 500},
+    }
 
-    # check if child class has cls attribute "key_sequence"
     def __init_subclass__(cls):
+        """
+        Check that the subclass has the required attributes.
+        - `key_sequence` attribute determines...
+        - `model_exclude_list` attribute determines...
+        """
         if not hasattr(cls, "key_sequence"):
             raise NotImplementedError(
                 f"Class {cls.__name__} must have a 'key_sequence' attribute."

@@ -18,29 +31,47 @@ class InferenceServiceABC(ABC):
                 f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
             )
 
-
-
-
-
-
+    @classmethod
+    def _get_limt(cls, limit_type: str) -> int:
+        key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
+        if key in os.environ:
+            return int(os.getenv(key))
+
+        if cls._inference_service_ in cls.default_levels:
+            return int(cls.default_levels[cls._inference_service_][limit_type])
+
+        return int(CONFIG.get(f"EDSL_SERVICE_{limit_type.upper()}_BASELINE"))
+
+    def get_tpm(cls) -> int:
+        """
+        Returns the TPM for the service. If the service is not defined in the environment variables, it will return the baseline TPM.
+        """
+        return cls._get_limt(limit_type="tpm")
 
     def get_rpm(cls):
-
-
-
-        return
+        """
+        Returns the RPM for the service. If the service is not defined in the environment variables, it will return the baseline RPM.
+        """
+        return cls._get_limt(limit_type="rpm")
 
     @abstractmethod
     def available() -> list[str]:
+        """
+        Returns a list of available models for the service.
+        """
         pass
 
     @abstractmethod
     def create_model():
+        """
+        Returns a LanguageModel object.
+        """
         pass
 
     @staticmethod
     def to_class_name(s):
-        """
+        """
+        Converts a string to a valid class name.
 
         >>> InferenceServiceABC.to_class_name("hello world")
         'HelloWorld'
edsl/inference_services/OpenAIService.py CHANGED

@@ -187,12 +187,15 @@ class OpenAIService(InferenceServiceABC):
         else:
             content = user_prompt
         client = self.async_client()
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": content},
+        ]
+        if system_prompt == "" and self.omit_system_prompt_if_empty:
+            messages = messages[1:]
         params = {
             "model": self.model,
-            "messages":
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": content},
-            ],
+            "messages": messages,
             "temperature": self.temperature,
             "max_tokens": self.max_tokens,
             "top_p": self.top_p,
edsl/inference_services/TestService.py CHANGED

@@ -7,14 +7,25 @@ from edsl.inference_services.rate_limits_cache import rate_limits
 from edsl.utilities.utilities import fix_partial_correct_response
 
 from edsl.enums import InferenceServiceType
+import random
 
 
 class TestService(InferenceServiceABC):
     """OpenAI service class."""
 
+    _inference_service_ = "test"
+    _env_key_name_ = None
+    _base_url_ = None
+
+    _sync_client_ = None
+    _async_client_ = None
+
+    _sync_client_instance = None
+    _async_client_instance = None
+
     key_sequence = None
+    usage_sequence = None
     model_exclude_list = []
-    _inference_service_ = "test"
     input_token_name = "prompt_tokens"
     output_token_name = "completion_tokens"
 

@@ -45,27 +56,25 @@
                 return "Hello, world"
 
             async def async_execute_model_call(
-                self,
+                self,
+                user_prompt: str,
+                system_prompt: str,
+                encoded_image=None,
             ) -> dict[str, Any]:
                 await asyncio.sleep(0.1)
                 # return {"message": """{"answer": "Hello, world"}"""}
+
                 if hasattr(self, "throw_exception") and self.throw_exception:
-
+                    if hasattr(self, "exception_probability"):
+                        p = self.exception_probability
+                    else:
+                        p = 1
+
+                    if random.random() < p:
+                        raise Exception("This is a test error")
                 return {
                     "message": [{"text": f"{self._canned_response}"}],
                     "usage": {"prompt_tokens": 1, "completion_tokens": 1},
                 }
 
         return TestServiceLanguageModel
-
-        # _inference_service_ = "openai"
-        # _env_key_name_ = "OPENAI_API_KEY"
-        # _base_url_ = None
-
-        # _sync_client_ = openai.OpenAI
-        # _async_client_ = openai.AsyncOpenAI
-
-        # _sync_client_instance = None
-        # _async_client_instance = None
-
-        # key_sequence = ["choices", 0, "message", "content"]
edsl/inference_services/TogetherAIService.py ADDED

@@ -0,0 +1,170 @@
+import aiohttp
+import json
+import requests
+from typing import Any, List
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+
+from edsl.inference_services.OpenAIService import OpenAIService
+import openai
+
+
+class TogetherAIService(OpenAIService):
+    """DeepInfra service class."""
+
+    _inference_service_ = "together"
+    _env_key_name_ = "TOGETHER_API_KEY"
+    _base_url_ = "https://api.together.xyz/v1"
+    _models_list_cache: List[str] = []
+
+    # These are non-serverless models. There was no api param to filter them
+    model_exclude_list = [
+        "EleutherAI/llemma_7b",
+        "HuggingFaceH4/zephyr-7b-beta",
+        "Nexusflow/NexusRaven-V2-13B",
+        "NousResearch/Hermes-2-Theta-Llama-3-70B",
+        "NousResearch/Nous-Capybara-7B-V1p9",
+        "NousResearch/Nous-Hermes-13b",
+        "NousResearch/Nous-Hermes-2-Mistral-7B-DPO",
+        "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT",
+        "NousResearch/Nous-Hermes-Llama2-13b",
+        "NousResearch/Nous-Hermes-Llama2-70b",
+        "NousResearch/Nous-Hermes-llama-2-7b",
+        "NumbersStation/nsql-llama-2-7B",
+        "Open-Orca/Mistral-7B-OpenOrca",
+        "Phind/Phind-CodeLlama-34B-Python-v1",
+        "Phind/Phind-CodeLlama-34B-v2",
+        "Qwen/Qwen1.5-0.5B",
+        "Qwen/Qwen1.5-0.5B-Chat",
+        "Qwen/Qwen1.5-1.8B",
+        "Qwen/Qwen1.5-1.8B-Chat",
+        "Qwen/Qwen1.5-14B",
+        "Qwen/Qwen1.5-14B-Chat",
+        "Qwen/Qwen1.5-32B",
+        "Qwen/Qwen1.5-32B-Chat",
+        "Qwen/Qwen1.5-4B",
+        "Qwen/Qwen1.5-4B-Chat",
+        "Qwen/Qwen1.5-72B",
+        "Qwen/Qwen1.5-7B",
+        "Qwen/Qwen1.5-7B-Chat",
+        "Qwen/Qwen2-1.5B",
+        "Qwen/Qwen2-1.5B-Instruct",
+        "Qwen/Qwen2-72B",
+        "Qwen/Qwen2-7B",
+        "Qwen/Qwen2-7B-Instruct",
+        "SG161222/Realistic_Vision_V3.0_VAE",
+        "Snowflake/snowflake-arctic-instruct",
+        "Undi95/ReMM-SLERP-L2-13B",
+        "Undi95/Toppy-M-7B",
+        "WizardLM/WizardCoder-Python-34B-V1.0",
+        "WizardLM/WizardLM-13B-V1.2",
+        "WizardLM/WizardLM-70B-V1.0",
+        "allenai/OLMo-7B",
+        "allenai/OLMo-7B-Instruct",
+        "bert-base-uncased",
+        "codellama/CodeLlama-13b-Instruct-hf",
+        "codellama/CodeLlama-13b-Python-hf",
+        "codellama/CodeLlama-13b-hf",
+        "codellama/CodeLlama-34b-Python-hf",
+        "codellama/CodeLlama-34b-hf",
+        "codellama/CodeLlama-70b-Instruct-hf",
+        "codellama/CodeLlama-70b-Python-hf",
+        "codellama/CodeLlama-70b-hf",
+        "codellama/CodeLlama-7b-Instruct-hf",
+        "codellama/CodeLlama-7b-Python-hf",
+        "codellama/CodeLlama-7b-hf",
+        "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
+        "deepseek-ai/deepseek-coder-33b-instruct",
+        "garage-bAInd/Platypus2-70B-instruct",
+        "google/gemma-2b",
+        "google/gemma-7b",
+        "google/gemma-7b-it",
+        "gradientai/Llama-3-70B-Instruct-Gradient-1048k",
+        "hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1",
+        "huggyllama/llama-13b",
+        "huggyllama/llama-30b",
+        "huggyllama/llama-65b",
+        "huggyllama/llama-7b",
+        "lmsys/vicuna-13b-v1.3",
+        "lmsys/vicuna-13b-v1.5",
+        "lmsys/vicuna-13b-v1.5-16k",
+        "lmsys/vicuna-7b-v1.3",
+        "lmsys/vicuna-7b-v1.5",
+        "meta-llama/Llama-2-13b-hf",
+        "meta-llama/Llama-2-70b-chat-hf",
+        "meta-llama/Llama-2-7b-hf",
+        "meta-llama/Llama-3-70b-hf",
+        "meta-llama/Llama-3-8b-hf",
+        "meta-llama/Meta-Llama-3-70B",
+        "meta-llama/Meta-Llama-3-70B-Instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Meta-Llama-3.1-70B-Instruct-Reference",
+        "meta-llama/Meta-Llama-3.1-70B-Reference",
+        "meta-llama/Meta-Llama-3.1-8B-Reference",
+        "microsoft/phi-2",
+        "mistralai/Mixtral-8x22B",
+        "openchat/openchat-3.5-1210",
+        "prompthero/openjourney",
+        "runwayml/stable-diffusion-v1-5",
+        "sentence-transformers/msmarco-bert-base-dot-v5",
+        "snorkelai/Snorkel-Mistral-PairRM-DPO",
+        "stabilityai/stable-diffusion-2-1",
+        "teknium/OpenHermes-2-Mistral-7B",
+        "teknium/OpenHermes-2p5-Mistral-7B",
+        "togethercomputer/CodeLlama-13b-Instruct",
+        "togethercomputer/CodeLlama-13b-Python",
+        "togethercomputer/CodeLlama-34b",
+        "togethercomputer/CodeLlama-34b-Python",
+        "togethercomputer/CodeLlama-7b-Instruct",
+        "togethercomputer/CodeLlama-7b-Python",
+        "togethercomputer/Koala-13B",
+        "togethercomputer/Koala-7B",
+        "togethercomputer/LLaMA-2-7B-32K",
+        "togethercomputer/SOLAR-10.7B-Instruct-v1.0-int4",
+        "togethercomputer/StripedHyena-Hessian-7B",
+        "togethercomputer/alpaca-7b",
+        "togethercomputer/evo-1-131k-base",
+        "togethercomputer/evo-1-8k-base",
+        "togethercomputer/guanaco-13b",
+        "togethercomputer/guanaco-33b",
+        "togethercomputer/guanaco-65b",
+        "togethercomputer/guanaco-7b",
+        "togethercomputer/llama-2-13b",
+        "togethercomputer/llama-2-70b-chat",
+        "togethercomputer/llama-2-7b",
+        "wavymulder/Analog-Diffusion",
+        "zero-one-ai/Yi-34B",
+        "zero-one-ai/Yi-34B-Chat",
+        "zero-one-ai/Yi-6B",
+    ]
+
+    _sync_client_ = openai.OpenAI
+    _async_client_ = openai.AsyncOpenAI
+
+    @classmethod
+    def get_model_list(cls):
+        # Togheter.ai has a different response in model list then openai
+        # and the OpenAI class returns an error when calling .models.list()
+        import requests
+        import os
+
+        url = "https://api.together.xyz/v1/models?filter=serverless"
+        token = os.getenv(cls._env_key_name_)
+        headers = {"accept": "application/json", "authorization": f"Bearer {token}"}
+
+        response = requests.get(url, headers=headers)
+        return response.json()
+
+    @classmethod
+    def available(cls) -> List[str]:
+        if not cls._models_list_cache:
+            try:
+                cls._models_list_cache = [
+                    m["id"]
+                    for m in cls.get_model_list()
+                    if m["id"] not in cls.model_exclude_list
+                ]
+            except Exception as e:
+                raise
+        return cls._models_list_cache
edsl/inference_services/registry.py CHANGED

@@ -12,6 +12,7 @@ from edsl.inference_services.AzureAI import AzureAIService
 from edsl.inference_services.OllamaService import OllamaService
 from edsl.inference_services.TestService import TestService
 from edsl.inference_services.MistralAIService import MistralAIService
+from edsl.inference_services.TogetherAIService import TogetherAIService
 
 default = InferenceServicesCollection(
     [

@@ -25,5 +26,6 @@ default = InferenceServicesCollection(
         OllamaService,
         TestService,
         MistralAIService,
+        TogetherAIService,
     ]
 )
edsl/jobs/Jobs.py CHANGED

@@ -460,6 +460,12 @@ class Jobs(Base):
         if warn:
             warnings.warn(message)
 
+        if self.scenarios.has_jinja_braces:
+            warnings.warn(
+                "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
+            )
+            self.scenarios = self.scenarios.convert_jinja_braces()
+
     @property
     def skip_retry(self):
         if not hasattr(self, "_skip_retry"):

@@ -486,6 +492,7 @@ class Jobs(Base):
         remote_inference_description: Optional[str] = None,
         skip_retry: bool = False,
         raise_validation_errors: bool = False,
+        disable_remote_inference: bool = False,
     ) -> Results:
         """
         Runs the Job: conducts Interviews and returns their results.

@@ -508,14 +515,17 @@
 
         self.verbose = verbose
 
-
-
-
-
-
-
-
+        remote_cache = False
+        remote_inference = False
+
+        if not disable_remote_inference:
+            try:
+                coop = Coop()
+                user_edsl_settings = Coop().edsl_settings
+                remote_cache = user_edsl_settings.get("remote_caching", False)
+                remote_inference = user_edsl_settings.get("remote_inference", False)
+            except Exception:
+                pass
 
         if remote_inference:
             import time
edsl/jobs/buckets/BucketCollection.py CHANGED

@@ -13,6 +13,8 @@ class BucketCollection(UserDict):
     def __init__(self, infinity_buckets=False):
         super().__init__()
         self.infinity_buckets = infinity_buckets
+        self.models_to_services = {}
+        self.services_to_buckets = {}
 
     def __repr__(self):
         return f"BucketCollection({self.data})"

@@ -21,6 +23,7 @@ class BucketCollection(UserDict):
         """Adds a model to the bucket collection.
 
         This will create the token and request buckets for the model."""
+
         # compute the TPS and RPS from the model
         if not self.infinity_buckets:
             TPS = model.TPM / 60.0

@@ -29,22 +32,28 @@
             TPS = float("inf")
             RPS = float("inf")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if model.model not in self.models_to_services:
+            service = model._inference_service_
+            if service not in self.services_to_buckets:
+                requests_bucket = TokenBucket(
+                    bucket_name=service,
+                    bucket_type="requests",
+                    capacity=RPS,
+                    refill_rate=RPS,
+                )
+                tokens_bucket = TokenBucket(
+                    bucket_name=service,
+                    bucket_type="tokens",
+                    capacity=TPS,
+                    refill_rate=TPS,
+                )
+                self.services_to_buckets[service] = ModelBuckets(
+                    requests_bucket, tokens_bucket
+                )
+            self.models_to_services[model.model] = service
+            self[model] = self.services_to_buckets[service]
         else:
-            self[model] =
+            self[model] = self.services_to_buckets[self.models_to_services[model.model]]
 
     def visualize(self) -> dict:
         """Visualize the token and request buckets for each model."""
edsl/jobs/buckets/TokenBucket.py CHANGED

@@ -1,4 +1,4 @@
-from typing import Union, List, Any
+from typing import Union, List, Any, Optional
 import asyncio
 import time
 

@@ -17,6 +17,12 @@ class TokenBucket:
         self.bucket_name = bucket_name
         self.bucket_type = bucket_type
         self.capacity = capacity  # Maximum number of tokens
+        self.added_tokens = 0
+
+        self.target_rate = (
+            capacity * 60
+        )  # set this here because it can change with turbo mode
+
         self._old_capacity = capacity
         self.tokens = capacity  # Current number of available tokens
         self.refill_rate = refill_rate  # Rate at which tokens are refilled

@@ -25,6 +31,12 @@ class TokenBucket:
         self.log: List[Any] = []
         self.turbo_mode = False
 
+        self.creation_time = time.monotonic()
+
+        self.num_requests = 0
+        self.num_released = 0
+        self.tokens_returned = 0
+
     def turbo_mode_on(self):
         """Set the refill rate to infinity."""
         if self.turbo_mode:

@@ -69,6 +81,7 @@ class TokenBucket:
         >>> bucket.tokens
         10
         """
+        self.tokens_returned += tokens
         self.tokens = min(self.capacity, self.tokens + tokens)
         self.log.append((time.monotonic(), self.tokens))
 

@@ -133,15 +146,12 @@
         >>> bucket.capacity
         12.100000000000001
         """
+        self.num_requests += amount
         if amount >= self.capacity:
             if not cheat_bucket_capacity:
                 msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
                 raise ValueError(msg)
             else:
-                # self.tokens = 0 # clear the bucket but let it go through
-                # print(
-                #     f"""The requested amount, {amount}, exceeds the current bucket capacity of {self.capacity}.Increasing bucket capacity to {amount} * 1.10 accommodate the requested amount."""
-                # )
                 self.capacity = amount * 1.10
                 self._old_capacity = self.capacity
 

@@ -153,14 +163,10 @@
                 break
 
             wait_time = self.wait_time(amount)
-            # print(f"Waiting for {wait_time:.4f} seconds")
             if wait_time > 0:
-                # print(f"Waiting for {wait_time:.4f} seconds")
                 await asyncio.sleep(wait_time)
 
-
-        # print(f"Total time to acquire tokens: {total_elapsed:.4f} seconds")
-
+        self.num_released += amount
         now = time.monotonic()
         self.log.append((now, self.tokens))
         return None

@@ -187,6 +193,54 @@
         plt.tight_layout()
         plt.show()
 
+    def get_throughput(self, time_window: Optional[float] = None) -> float:
+        """
+        Calculate the empirical bucket throughput in tokens per minute for the specified time window.
+
+        :param time_window: The time window in seconds to calculate the throughput for.
+        :return: The throughput in tokens per minute.
+
+        >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=100, refill_rate=10)
+        >>> asyncio.run(bucket.get_tokens(50))
+        >>> time.sleep(1)  # Wait for 1 second
+        >>> asyncio.run(bucket.get_tokens(30))
+        >>> throughput = bucket.get_throughput(1)
+        >>> 4750 < throughput < 4850
+        True
+        """
+        now = time.monotonic()
+
+        if time_window is None:
+            start_time = self.creation_time
+        else:
+            start_time = now - time_window
+
+        if start_time < self.creation_time:
+            start_time = self.creation_time
+
+        elapsed_time = now - start_time
+
+        return (self.num_released / elapsed_time) * 60
+
+        # # Filter log entries within the time window
+        # relevant_log = [(t, tokens) for t, tokens in self.log if t >= start_time]
+
+        # if len(relevant_log) < 2:
+        #     return 0  # Not enough data points to calculate throughput
+
+        # # Calculate total tokens used
+        # initial_tokens = relevant_log[0][1]
+        # final_tokens = relevant_log[-1][1]
+        # tokens_used = self.num_released - (final_tokens - initial_tokens)
+
+        # # Calculate actual time elapsed
+        # actual_time_elapsed = relevant_log[-1][0] - relevant_log[0][0]
+
+        # # Calculate throughput in tokens per minute
+        # throughput = (tokens_used / actual_time_elapsed) * 60
+
+        # return throughput
+
 
 if __name__ == "__main__":
     import doctest