edsl 0.1.31__py3-none-any.whl → 0.1.31.dev2__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +2 -7
- edsl/agents/PromptConstructionMixin.py +4 -9
- edsl/config.py +0 -4
- edsl/conjure/Conjure.py +0 -6
- edsl/coop/coop.py +0 -4
- edsl/data/CacheHandler.py +4 -3
- edsl/enums.py +0 -2
- edsl/inference_services/DeepInfraService.py +91 -6
- edsl/inference_services/InferenceServicesCollection.py +8 -13
- edsl/inference_services/OpenAIService.py +21 -64
- edsl/inference_services/registry.py +1 -2
- edsl/jobs/Jobs.py +5 -29
- edsl/jobs/buckets/TokenBucket.py +4 -12
- edsl/jobs/interviews/Interview.py +9 -31
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +33 -49
- edsl/jobs/interviews/interview_exception_tracking.py +10 -68
- edsl/jobs/runners/JobsRunnerAsyncio.py +81 -112
- edsl/jobs/runners/JobsRunnerStatusData.py +237 -0
- edsl/jobs/runners/JobsRunnerStatusMixin.py +35 -291
- edsl/jobs/tasks/TaskCreators.py +2 -8
- edsl/jobs/tasks/TaskHistory.py +1 -145
- edsl/language_models/LanguageModel.py +32 -49
- edsl/language_models/registry.py +0 -4
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -0
- edsl/results/DatasetExportMixin.py +3 -12
- edsl/scenarios/Scenario.py +0 -14
- edsl/scenarios/ScenarioList.py +2 -15
- edsl/scenarios/ScenarioListExportMixin.py +4 -15
- edsl/scenarios/ScenarioListPdfMixin.py +0 -3
- {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/METADATA +1 -2
- {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/RECORD +35 -37
- edsl/inference_services/GroqService.py +0 -18
- edsl/jobs/interviews/InterviewExceptionEntry.py +0 -101
- {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/WHEEL +0 -0
edsl/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.1.31"
+__version__ = "0.1.31.dev2"
edsl/agents/Invigilator.py
CHANGED
@@ -18,12 +18,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
     """An invigilator that uses an AI model to answer questions."""
 
     async def async_answer_question(self) -> AgentResponseDict:
-        """Answer a question using the AI model.
-
-        >>> i = InvigilatorAI.example()
-        >>> i.answer_question()
-        {'message': '{"answer": "SPAM!"}'}
-        """
+        """Answer a question using the AI model."""
         params = self.get_prompts() | {"iteration": self.iteration}
         raw_response = await self.async_get_response(**params)
         data = {
@@ -34,7 +29,6 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
             "raw_model_response": raw_response["raw_model_response"],
         }
         response = self._format_raw_response(**data)
-        # breakpoint()
         return AgentResponseDict(**response)
 
     async def async_get_response(
@@ -103,6 +97,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
         answer = question._translate_answer_code_to_answer(
             response["answer"], combined_dict
         )
+        # breakpoint()
         data = {
             "answer": answer,
             "comment": response.get(
edsl/agents/PromptConstructionMixin.py
CHANGED
@@ -281,17 +281,12 @@ class PromptConstructorMixin:
         if "question_options" in question_data:
             if isinstance(self.question.data["question_options"], str):
                 from jinja2 import Environment, meta
-
                 env = Environment()
-                parsed_content = env.parse(self.question.data[
-                question_option_key = list(
-
-                )[0]
-                question_data["question_options"] = self.scenario.get(
-                    question_option_key
-                )
+                parsed_content = env.parse(self.question.data['question_options'])
+                question_option_key = list(meta.find_undeclared_variables(parsed_content))[0]
+                question_data["question_options"] = self.scenario.get(question_option_key)
 
-        #
+        #breakpoint()
         rendered_instructions = question_prompt.render(
             question_data | self.scenario | d | {"agent": self.agent}
         )
edsl/config.py
CHANGED
@@ -65,10 +65,6 @@ CONFIG_MAP = {
     # "default": None,
     # "info": "This env var holds your Anthropic API key (https://www.anthropic.com/).",
     # },
-    # "GROQ_API_KEY": {
-    # "default": None,
-    # "info": "This env var holds your GROQ API key (https://console.groq.com/login).",
-    # },
 }
 
 
edsl/conjure/Conjure.py
CHANGED
@@ -35,12 +35,6 @@ class Conjure:
         # The __init__ method in Conjure won't be called because __new__ returns a different class instance.
         pass
 
-    @classmethod
-    def example(cls):
-        from edsl.conjure.InputData import InputDataABC
-
-        return InputDataABC.example()
-
 
 if __name__ == "__main__":
     pass
edsl/coop/coop.py
CHANGED
@@ -465,7 +465,6 @@ class Coop:
         description: Optional[str] = None,
         status: RemoteJobStatus = "queued",
         visibility: Optional[VisibilityType] = "unlisted",
-        iterations: Optional[int] = 1,
     ) -> dict:
         """
         Send a remote inference job to the server.
@@ -474,7 +473,6 @@ class Coop:
         :param optional description: A description for this entry in the remote cache.
         :param status: The status of the job. Should be 'queued', unless you are debugging.
         :param visibility: The visibility of the cache entry.
-        :param iterations: The number of times to run each interview.
 
         >>> job = Jobs.example()
         >>> coop.remote_inference_create(job=job, description="My job")
@@ -490,7 +488,6 @@ class Coop:
                 ),
                 "description": description,
                 "status": status,
-                "iterations": iterations,
                 "visibility": visibility,
                 "version": self._edsl_version,
             },
@@ -501,7 +498,6 @@ class Coop:
             "uuid": response_json.get("jobs_uuid"),
             "description": response_json.get("description"),
             "status": response_json.get("status"),
-            "iterations": response_json.get("iterations"),
             "visibility": response_json.get("visibility"),
             "version": self._edsl_version,
         }
edsl/data/CacheHandler.py
CHANGED
@@ -41,7 +41,7 @@ class CacheHandler:
         old_data = self.from_old_sqlite_cache()
         self.cache.add_from_dict(old_data)
 
-    def create_cache_directory(self
+    def create_cache_directory(self) -> None:
         """
         Create the cache directory if one is required and it does not exist.
         """
@@ -49,8 +49,9 @@ class CacheHandler:
         dir_path = os.path.dirname(path)
         if dir_path and not os.path.exists(dir_path):
             os.makedirs(dir_path)
-
-
+            import warnings
+
+            warnings.warn(f"Created cache directory: {dir_path}")
 
     def gen_cache(self) -> Cache:
         """
edsl/enums.py
CHANGED
@@ -59,7 +59,6 @@ class InferenceServiceType(EnumWithChecks):
     GOOGLE = "google"
     TEST = "test"
     ANTHROPIC = "anthropic"
-    GROQ = "groq"
 
 
 service_to_api_keyname = {
@@ -70,7 +69,6 @@ service_to_api_keyname = {
     InferenceServiceType.GOOGLE.value: "GOOGLE_API_KEY",
     InferenceServiceType.TEST.value: "TBD",
     InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
-    InferenceServiceType.GROQ.value: "GROQ_API_KEY",
 }
 
 
edsl/inference_services/DeepInfraService.py
CHANGED
@@ -2,17 +2,102 @@ import aiohttp
 import json
 import requests
 from typing import Any, List
-
-# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models import LanguageModel
 
-from edsl.inference_services.OpenAIService import OpenAIService
-
 
-class DeepInfraService(
+class DeepInfraService(InferenceServiceABC):
     """DeepInfra service class."""
 
     _inference_service_ = "deep_infra"
     _env_key_name_ = "DEEP_INFRA_API_KEY"
-
+
     _models_list_cache: List[str] = []
+
+    @classmethod
+    def available(cls):
+        text_models = cls.full_details_available()
+        return [m["model_name"] for m in text_models]
+
+    @classmethod
+    def full_details_available(cls, verbose=False):
+        if not cls._models_list_cache:
+            url = "https://api.deepinfra.com/models/list"
+            response = requests.get(url)
+            if response.status_code == 200:
+                text_generation_models = [
+                    r for r in response.json() if r["type"] == "text-generation"
+                ]
+                cls._models_list_cache = text_generation_models
+
+                from rich import print_json
+                import json
+
+                if verbose:
+                    print_json(json.dumps(text_generation_models))
+                return text_generation_models
+            else:
+                return f"Failed to fetch data: Status code {response.status_code}"
+        else:
+            return cls._models_list_cache
+
+    @classmethod
+    def create_model(cls, model_name: str, model_class_name=None) -> LanguageModel:
+        base_url = "https://api.deepinfra.com/v1/inference/"
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+        url = f"{base_url}{model_name}"
+
+        class LLM(LanguageModel):
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.7,
+                "top_p": 0.2,
+                "top_k": 0.1,
+                "max_new_tokens": 512,
+                "stopSequences": [],
+            }
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str = ""
+            ) -> dict[str, Any]:
+                self.url = url
+                headers = {
+                    "Content-Type": "application/json",
+                    "Authorization": f"bearer {self.api_token}",
+                }
+                # don't mess w/ the newlines
+                data = {
+                    "input": f"""
+                [INST]<<SYS>>
+                {system_prompt}
+                <<SYS>>{user_prompt}[/INST]
+                """,
+                    "stream": False,
+                    "temperature": self.temperature,
+                    "top_p": self.top_p,
+                    "top_k": self.top_k,
+                    "max_new_tokens": self.max_new_tokens,
+                }
+                async with aiohttp.ClientSession() as session:
+                    async with session.post(
+                        self.url, headers=headers, data=json.dumps(data)
+                    ) as response:
+                        raw_response_text = await response.text()
+                        return json.loads(raw_response_text)
+
+            def parse_response(self, raw_response: dict[str, Any]) -> str:
+                if "results" not in raw_response:
+                    raise Exception(
+                        f"Deep Infra response does not contain 'results' key: {raw_response}"
+                    )
+                if "generated_text" not in raw_response["results"][0]:
+                    raise Exception(
+                        f"Deep Infra response does not contain 'generate_text' key: {raw_response['results'][0]}"
+                    )
+                return raw_response["results"][0]["generated_text"]
+
+        LLM.__name__ = model_class_name
+
+        return LLM
edsl/inference_services/InferenceServicesCollection.py
CHANGED
@@ -15,19 +15,18 @@ class InferenceServicesCollection:
         cls.added_models[service_name].append(model_name)
 
     @staticmethod
-    def _get_service_available(service
+    def _get_service_available(service) -> list[str]:
         from_api = True
         try:
             service_models = service.available()
         except Exception as e:
-
-
-
-
-
-
-
-            )
+            warnings.warn(
+                f"""Error getting models for {service._inference_service_}.
+                Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
+                See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
+                Relying on cache.""",
+                UserWarning,
+            )
             from edsl.inference_services.models_available_cache import models_available
 
             service_models = models_available.get(service._inference_service_, [])
@@ -61,8 +60,4 @@ class InferenceServicesCollection:
         if service_name is None or service_name == service._inference_service_:
             return service.create_model(model_name)
 
-        # if model_name == "test":
-        #     from edsl.language_models import LanguageModel
-        #     return LanguageModel(test = True)
-
         raise Exception(f"Model {model_name} not found in any of the services")
edsl/inference_services/OpenAIService.py
CHANGED
@@ -1,9 +1,6 @@
 from typing import Any, List
 import re
-import
-
-# from openai import AsyncOpenAI
-import openai
+from openai import AsyncOpenAI
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models import LanguageModel
@@ -15,22 +12,6 @@ class OpenAIService(InferenceServiceABC):
 
     _inference_service_ = "openai"
     _env_key_name_ = "OPENAI_API_KEY"
-    _base_url_ = None
-
-    _sync_client_ = openai.OpenAI
-    _async_client_ = openai.AsyncOpenAI
-
-    @classmethod
-    def sync_client(cls):
-        return cls._sync_client_(
-            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-        )
-
-    @classmethod
-    def async_client(cls):
-        return cls._async_client_(
-            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-        )
 
     # TODO: Make this a coop call
     model_exclude_list = [
@@ -50,24 +31,16 @@ class OpenAIService(InferenceServiceABC):
     ]
     _models_list_cache: List[str] = []
 
-    @classmethod
-    def get_model_list(cls):
-        raw_list = cls.sync_client().models.list()
-        if hasattr(raw_list, "data"):
-            return raw_list.data
-        else:
-            return raw_list
-
     @classmethod
     def available(cls) -> List[str]:
-
+        from openai import OpenAI
 
         if not cls._models_list_cache:
             try:
-
+                client = OpenAI()
                 cls._models_list_cache = [
                     m.id
-                    for m in
+                    for m in client.models.list()
                     if m.id not in cls.model_exclude_list
                 ]
             except Exception as e:
@@ -105,24 +78,15 @@ class OpenAIService(InferenceServiceABC):
                 "top_logprobs": 3,
             }
 
-            def sync_client(self):
-                return cls.sync_client()
-
-            def async_client(self):
-                return cls.async_client()
-
             @classmethod
             def available(cls) -> list[str]:
-
-
-                # return client.models.list()
-                return cls.sync_client().models.list()
+                client = openai.OpenAI()
+                return client.models.list()
 
             def get_headers(self) -> dict[str, Any]:
-
+                from openai import OpenAI
 
-
-                client = self.sync_client()
+                client = OpenAI()
                 response = client.chat.completions.with_raw_response.create(
                     messages=[
                         {
@@ -160,8 +124,8 @@ class OpenAIService(InferenceServiceABC):
                 encoded_image=None,
             ) -> dict[str, Any]:
                 """Calls the OpenAI API and returns the API response."""
+                content = [{"type": "text", "text": user_prompt}]
                 if encoded_image:
-                    content = [{"type": "text", "text": user_prompt}]
                     content.append(
                         {
                             "type": "image_url",
@@ -170,28 +134,21 @@ class OpenAIService(InferenceServiceABC):
                         },
                     }
                 )
-
-
-
-
-                # base_url = cls._base_url_
-                # )
-                client = self.async_client()
-                params = {
-                    "model": self.model,
-                    "messages": [
+                self.client = AsyncOpenAI()
+                response = await self.client.chat.completions.create(
+                    model=self.model,
+                    messages=[
                         {"role": "system", "content": system_prompt},
                         {"role": "user", "content": content},
                     ],
-
-
-
-
-
-
-
-
-                response = await client.chat.completions.create(**params)
+                    temperature=self.temperature,
+                    max_tokens=self.max_tokens,
+                    top_p=self.top_p,
+                    frequency_penalty=self.frequency_penalty,
+                    presence_penalty=self.presence_penalty,
+                    logprobs=self.logprobs,
+                    top_logprobs=self.top_logprobs if self.logprobs else None,
+                )
                 return response.model_dump()
 
     @staticmethod
edsl/inference_services/registry.py
CHANGED
@@ -6,8 +6,7 @@ from edsl.inference_services.OpenAIService import OpenAIService
 from edsl.inference_services.AnthropicService import AnthropicService
 from edsl.inference_services.DeepInfraService import DeepInfraService
 from edsl.inference_services.GoogleService import GoogleService
-from edsl.inference_services.GroqService import GroqService
 
 default = InferenceServicesCollection(
-    [OpenAIService, AnthropicService, DeepInfraService, GoogleService
+    [OpenAIService, AnthropicService, DeepInfraService, GoogleService]
 )
edsl/jobs/Jobs.py
CHANGED
@@ -319,11 +319,7 @@ class Jobs(Base):
         self.scenarios = self.scenarios or [Scenario()]
         for agent, scenario, model in product(self.agents, self.scenarios, self.models):
             yield Interview(
-                survey=self.survey,
-                agent=agent,
-                scenario=scenario,
-                model=model,
-                skip_retry=self.skip_retry,
+                survey=self.survey, agent=agent, scenario=scenario, model=model
             )
 
     def create_bucket_collection(self) -> BucketCollection:
@@ -413,12 +409,6 @@ class Jobs(Base):
         if warn:
             warnings.warn(message)
 
-    @property
-    def skip_retry(self):
-        if not hasattr(self, "_skip_retry"):
-            return False
-        return self._skip_retry
-
     def run(
         self,
         n: int = 1,
@@ -433,7 +423,6 @@ class Jobs(Base):
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
         remote_inference_description: Optional[str] = None,
-        skip_retry: bool = False,
     ) -> Results:
         """
         Runs the Job: conducts Interviews and returns their results.
@@ -452,7 +441,6 @@ class Jobs(Base):
         from edsl.coop.coop import Coop
 
         self._check_parameters()
-        self._skip_retry = skip_retry
 
         if batch_mode is not None:
             raise NotImplementedError(
@@ -487,7 +475,6 @@ class Jobs(Base):
                 self,
                 description=remote_inference_description,
                 status="queued",
-                iterations=n,
             )
             time_queued = datetime.now().strftime("%m/%d/%Y %I:%M:%S %p")
             job_uuid = remote_job_creation_data.get("uuid")
@@ -642,9 +629,9 @@ class Jobs(Base):
         results = JobsRunnerAsyncio(self).run(*args, **kwargs)
         return results
 
-    async def run_async(self, cache=None,
+    async def run_async(self, cache=None, **kwargs):
         """Run the job asynchronously."""
-        results = await JobsRunnerAsyncio(self).run_async(cache=cache,
+        results = await JobsRunnerAsyncio(self).run_async(cache=cache, **kwargs)
         return results
 
     def all_question_parameters(self):
@@ -724,10 +711,7 @@ class Jobs(Base):
     #######################
     @classmethod
     def example(
-        cls,
-        throw_exception_probability: int = 0,
-        randomize: bool = False,
-        test_model=False,
+        cls, throw_exception_probability: int = 0, randomize: bool = False
     ) -> Jobs:
         """Return an example Jobs instance.
 
@@ -745,11 +729,6 @@ class Jobs(Base):
 
         addition = "" if not randomize else str(uuid4())
 
-        if test_model:
-            from edsl.language_models import LanguageModel
-
-            m = LanguageModel.example(test_model=True)
-
         # (status, question, period)
         agent_answers = {
             ("Joyful", "how_feeling", "morning"): "OK",
@@ -797,10 +776,7 @@ class Jobs(Base):
                 Scenario({"period": "afternoon"}),
             ]
         )
-
-            job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
-        else:
-            job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
+        job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
 
         return job
 
edsl/jobs/buckets/TokenBucket.py
CHANGED
@@ -100,9 +100,7 @@ class TokenBucket:
         available_tokens = min(self.capacity, self.tokens + refill_amount)
         return max(0, requested_tokens - available_tokens) / self.refill_rate
 
-    async def get_tokens(
-        self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
-    ) -> None:
+    async def get_tokens(self, amount: Union[int, float] = 1) -> None:
         """Wait for the specified number of tokens to become available.
 
 
@@ -118,20 +116,14 @@ class TokenBucket:
         True
 
         >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)
-        >>> asyncio.run(bucket.get_tokens(11
+        >>> asyncio.run(bucket.get_tokens(11))
         Traceback (most recent call last):
         ...
         ValueError: Requested amount exceeds bucket capacity. Bucket capacity: 10, requested amount: 11. As the bucket never overflows, the requested amount will never be available.
-        >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))
         """
         if amount > self.capacity:
-
-
-            raise ValueError(msg)
-        else:
-            self.tokens = 0  # clear the bucket but let it go through
-            return
-
+            msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
+            raise ValueError(msg)
         while self.tokens < amount:
             self.refill()
             await asyncio.sleep(0.01)  # Sleep briefly to prevent busy waiting
edsl/jobs/interviews/Interview.py
CHANGED
@@ -14,8 +14,8 @@ from edsl.jobs.tasks.TaskCreators import TaskCreators
 from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
 from edsl.jobs.interviews.interview_exception_tracking import (
     InterviewExceptionCollection,
+    InterviewExceptionEntry,
 )
-from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
 from edsl.jobs.interviews.retry_management import retry_strategy
 from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
 from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
@@ -44,7 +44,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         iteration: int = 0,
         cache: Optional["Cache"] = None,
         sidecar_model: Optional["LanguageModel"] = None,
-        skip_retry=False,
     ):
         """Initialize the Interview instance.
 
@@ -88,7 +87,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         self.task_creators = TaskCreators()  # tracks the task creators
         self.exceptions = InterviewExceptionCollection()
         self._task_status_log_dict = InterviewStatusLog()
-        self.skip_retry = skip_retry
 
         # dictionary mapping question names to their index in the survey.
         self.to_index = {
@@ -96,30 +94,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
             for index, question_name in enumerate(self.survey.question_names)
         }
 
-    def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
-        """Return a dictionary representation of the Interview instance.
-        This is just for hashing purposes.
-
-        >>> i = Interview.example()
-        >>> hash(i)
-        1646262796627658719
-        """
-        d = {
-            "agent": self.agent._to_dict(),
-            "survey": self.survey._to_dict(),
-            "scenario": self.scenario._to_dict(),
-            "model": self.model._to_dict(),
-            "iteration": self.iteration,
-            "exceptions": {},
-        }
-        if include_exceptions:
-            d["exceptions"] = self.exceptions.to_dict()
-
-    def __hash__(self) -> int:
-        from edsl.utilities.utilities import dict_hash
-
-        return dict_hash(self._to_dict())
-
     async def async_conduct_interview(
         self,
         *,
@@ -160,7 +134,8 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         <BLANKLINE>
 
         >>> i.exceptions
-        {'q0': ...
+        {'q0': [{'exception': "Exception('This is a test error')", 'time': ..., 'traceback': ...
+
         >>> i = Interview.example()
         >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
         Traceback (most recent call last):
@@ -229,9 +204,13 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         {}
         >>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
         >>> i.exceptions
-        {'q0':
+        {'q0': [{'exception': "Exception('An exception occurred.')", 'time': ..., 'traceback': 'NoneType: None\\n'}]}
         """
-        exception_entry = InterviewExceptionEntry(
+        exception_entry = InterviewExceptionEntry(
+            exception=repr(exception),
+            time=time.time(),
+            traceback=traceback.format_exc(),
+        )
         self.exceptions.add(task.get_name(), exception_entry)
 
     @property
@@ -272,7 +251,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
             model=self.model,
             iteration=iteration,
             cache=cache,
-            skip_retry=self.skip_retry,
         )
 
     @classmethod