edsl 0.1.31__py3-none-any.whl → 0.1.31.dev1__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +2 -7
- edsl/agents/PromptConstructionMixin.py +1 -18
- edsl/config.py +0 -4
- edsl/conjure/Conjure.py +0 -6
- edsl/coop/coop.py +0 -4
- edsl/data/CacheHandler.py +4 -3
- edsl/enums.py +0 -2
- edsl/inference_services/DeepInfraService.py +91 -6
- edsl/inference_services/InferenceServicesCollection.py +5 -13
- edsl/inference_services/OpenAIService.py +21 -64
- edsl/inference_services/registry.py +1 -2
- edsl/jobs/Jobs.py +33 -80
- edsl/jobs/buckets/TokenBucket.py +4 -12
- edsl/jobs/interviews/Interview.py +9 -31
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +33 -49
- edsl/jobs/interviews/interview_exception_tracking.py +10 -68
- edsl/jobs/runners/JobsRunnerAsyncio.py +81 -112
- edsl/jobs/runners/JobsRunnerStatusData.py +237 -0
- edsl/jobs/runners/JobsRunnerStatusMixin.py +35 -291
- edsl/jobs/tasks/TaskCreators.py +2 -8
- edsl/jobs/tasks/TaskHistory.py +1 -145
- edsl/language_models/LanguageModel.py +74 -127
- edsl/language_models/registry.py +0 -4
- edsl/questions/QuestionMultipleChoice.py +0 -1
- edsl/questions/QuestionNumerical.py +1 -0
- edsl/results/DatasetExportMixin.py +3 -12
- edsl/scenarios/Scenario.py +0 -14
- edsl/scenarios/ScenarioList.py +2 -15
- edsl/scenarios/ScenarioListExportMixin.py +4 -15
- edsl/scenarios/ScenarioListPdfMixin.py +0 -3
- {edsl-0.1.31.dist-info → edsl-0.1.31.dev1.dist-info}/METADATA +2 -3
- {edsl-0.1.31.dist-info → edsl-0.1.31.dev1.dist-info}/RECORD +35 -37
- edsl/inference_services/GroqService.py +0 -18
- edsl/jobs/interviews/InterviewExceptionEntry.py +0 -101
- {edsl-0.1.31.dist-info → edsl-0.1.31.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dist-info → edsl-0.1.31.dev1.dist-info}/WHEEL +0 -0
edsl/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.1.31"
+__version__ = "0.1.31.dev1"
edsl/agents/Invigilator.py
CHANGED
@@ -18,12 +18,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
     """An invigilator that uses an AI model to answer questions."""
 
     async def async_answer_question(self) -> AgentResponseDict:
-        """Answer a question using the AI model.
-
-        >>> i = InvigilatorAI.example()
-        >>> i.answer_question()
-        {'message': '{"answer": "SPAM!"}'}
-        """
+        """Answer a question using the AI model."""
         params = self.get_prompts() | {"iteration": self.iteration}
         raw_response = await self.async_get_response(**params)
         data = {
@@ -34,7 +29,6 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
             "raw_model_response": raw_response["raw_model_response"],
         }
         response = self._format_raw_response(**data)
-        # breakpoint()
         return AgentResponseDict(**response)
 
     async def async_get_response(
@@ -103,6 +97,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
         answer = question._translate_answer_code_to_answer(
            response["answer"], combined_dict
        )
+        # breakpoint()
        data = {
            "answer": answer,
            "comment": response.get(
edsl/agents/PromptConstructionMixin.py
CHANGED
@@ -275,25 +275,8 @@ class PromptConstructorMixin:
             if (new_question := question.split("_comment")[0]) in d:
                 d[new_question].comment = answer
 
-        question_data = self.question.data.copy()
-
-        # check to see if the questio_options is actuall a string
-        if "question_options" in question_data:
-            if isinstance(self.question.data["question_options"], str):
-                from jinja2 import Environment, meta
-
-                env = Environment()
-                parsed_content = env.parse(self.question.data["question_options"])
-                question_option_key = list(
-                    meta.find_undeclared_variables(parsed_content)
-                )[0]
-                question_data["question_options"] = self.scenario.get(
-                    question_option_key
-                )
-
-        # breakpoint()
         rendered_instructions = question_prompt.render(
-            question_data | self.scenario | d | {"agent": self.agent}
+            self.question.data | self.scenario | d | {"agent": self.agent}
        )
 
        undefined_template_variables = (
edsl/config.py
CHANGED
@@ -65,10 +65,6 @@ CONFIG_MAP = {
     # "default": None,
     # "info": "This env var holds your Anthropic API key (https://www.anthropic.com/).",
     # },
-    # "GROQ_API_KEY": {
-    # "default": None,
-    # "info": "This env var holds your GROQ API key (https://console.groq.com/login).",
-    # },
 }
 
 
edsl/conjure/Conjure.py
CHANGED
@@ -35,12 +35,6 @@ class Conjure:
         # The __init__ method in Conjure won't be called because __new__ returns a different class instance.
         pass
 
-    @classmethod
-    def example(cls):
-        from edsl.conjure.InputData import InputDataABC
-
-        return InputDataABC.example()
-
 
 if __name__ == "__main__":
     pass
edsl/coop/coop.py
CHANGED
@@ -465,7 +465,6 @@ class Coop:
         description: Optional[str] = None,
         status: RemoteJobStatus = "queued",
         visibility: Optional[VisibilityType] = "unlisted",
-        iterations: Optional[int] = 1,
     ) -> dict:
         """
         Send a remote inference job to the server.
@@ -474,7 +473,6 @@ class Coop:
         :param optional description: A description for this entry in the remote cache.
         :param status: The status of the job. Should be 'queued', unless you are debugging.
         :param visibility: The visibility of the cache entry.
-        :param iterations: The number of times to run each interview.
 
         >>> job = Jobs.example()
         >>> coop.remote_inference_create(job=job, description="My job")
@@ -490,7 +488,6 @@ class Coop:
                 ),
                 "description": description,
                 "status": status,
-                "iterations": iterations,
                 "visibility": visibility,
                 "version": self._edsl_version,
             },
@@ -501,7 +498,6 @@ class Coop:
             "uuid": response_json.get("jobs_uuid"),
             "description": response_json.get("description"),
             "status": response_json.get("status"),
-            "iterations": response_json.get("iterations"),
             "visibility": response_json.get("visibility"),
             "version": self._edsl_version,
         }
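
After this change, callers simply omit `iterations`, and the response dict no longer includes it. A hedged sketch of the simplified call, following the docstring example above (Coop construction details are not shown in this diff):

    from edsl.coop.coop import Coop
    from edsl.jobs.Jobs import Jobs

    coop = Coop()
    job = Jobs.example()
    # No iterations argument anymore; returns uuid/description/status/visibility/version.
    info = coop.remote_inference_create(job, description="My job")
    print(info["uuid"], info["status"])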
edsl/data/CacheHandler.py
CHANGED
@@ -41,7 +41,7 @@ class CacheHandler:
         old_data = self.from_old_sqlite_cache()
         self.cache.add_from_dict(old_data)
 
-    def create_cache_directory(self):
+    def create_cache_directory(self) -> None:
         """
         Create the cache directory if one is required and it does not exist.
         """
@@ -49,8 +49,9 @@ class CacheHandler:
         dir_path = os.path.dirname(path)
         if dir_path and not os.path.exists(dir_path):
             os.makedirs(dir_path)
-
-
+            import warnings
+
+            warnings.warn(f"Created cache directory: {dir_path}")
 
     def gen_cache(self) -> Cache:
         """
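
The behavioral change here is small but user-visible: creating a missing cache directory now emits a warning instead of happening silently. A self-contained sketch of the new behavior (the path below is illustrative):

    import os
    import warnings

    def create_cache_directory(path: str) -> None:
        # Create the parent directory if needed, and warn so the user
        # knows a new location was written to.
        dir_path = os.path.dirname(path)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
            warnings.warn(f"Created cache directory: {dir_path}")

    create_cache_directory("/tmp/edsl_cache/data.db")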
edsl/enums.py
CHANGED
@@ -59,7 +59,6 @@ class InferenceServiceType(EnumWithChecks):
     GOOGLE = "google"
     TEST = "test"
     ANTHROPIC = "anthropic"
-    GROQ = "groq"
 
 
 service_to_api_keyname = {
@@ -70,7 +69,6 @@ service_to_api_keyname = {
     InferenceServiceType.GOOGLE.value: "GOOGLE_API_KEY",
     InferenceServiceType.TEST.value: "TBD",
     InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
-    InferenceServiceType.GROQ.value: "GROQ_API_KEY",
 }
 
 
edsl/inference_services/DeepInfraService.py
CHANGED
@@ -2,17 +2,102 @@ import aiohttp
 import json
 import requests
 from typing import Any, List
-
-# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models import LanguageModel
 
-from edsl.inference_services.OpenAIService import OpenAIService
-
 
-class DeepInfraService(OpenAIService):
+class DeepInfraService(InferenceServiceABC):
     """DeepInfra service class."""
 
     _inference_service_ = "deep_infra"
     _env_key_name_ = "DEEP_INFRA_API_KEY"
-
+
     _models_list_cache: List[str] = []
+
+    @classmethod
+    def available(cls):
+        text_models = cls.full_details_available()
+        return [m["model_name"] for m in text_models]
+
+    @classmethod
+    def full_details_available(cls, verbose=False):
+        if not cls._models_list_cache:
+            url = "https://api.deepinfra.com/models/list"
+            response = requests.get(url)
+            if response.status_code == 200:
+                text_generation_models = [
+                    r for r in response.json() if r["type"] == "text-generation"
+                ]
+                cls._models_list_cache = text_generation_models
+
+                from rich import print_json
+                import json
+
+                if verbose:
+                    print_json(json.dumps(text_generation_models))
+                return text_generation_models
+            else:
+                return f"Failed to fetch data: Status code {response.status_code}"
+        else:
+            return cls._models_list_cache
+
+    @classmethod
+    def create_model(cls, model_name: str, model_class_name=None) -> LanguageModel:
+        base_url = "https://api.deepinfra.com/v1/inference/"
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+        url = f"{base_url}{model_name}"
+
+        class LLM(LanguageModel):
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.7,
+                "top_p": 0.2,
+                "top_k": 0.1,
+                "max_new_tokens": 512,
+                "stopSequences": [],
+            }
+
+            async def async_execute_model_call(
+                self, user_prompt: str, system_prompt: str = ""
+            ) -> dict[str, Any]:
+                self.url = url
+                headers = {
+                    "Content-Type": "application/json",
+                    "Authorization": f"bearer {self.api_token}",
+                }
+                # don't mess w/ the newlines
+                data = {
+                    "input": f"""
+                [INST]<<SYS>>
+                {system_prompt}
+                <<SYS>>{user_prompt}[/INST]
+                """,
+                    "stream": False,
+                    "temperature": self.temperature,
+                    "top_p": self.top_p,
+                    "top_k": self.top_k,
+                    "max_new_tokens": self.max_new_tokens,
+                }
+                async with aiohttp.ClientSession() as session:
+                    async with session.post(
+                        self.url, headers=headers, data=json.dumps(data)
+                    ) as response:
+                        raw_response_text = await response.text()
+                        return json.loads(raw_response_text)
+
+            def parse_response(self, raw_response: dict[str, Any]) -> str:
+                if "results" not in raw_response:
+                    raise Exception(
+                        f"Deep Infra response does not contain 'results' key: {raw_response}"
+                    )
+                if "generated_text" not in raw_response["results"][0]:
+                    raise Exception(
+                        f"Deep Infra response does not contain 'generate_text' key: {raw_response['results'][0]}"
+                    )
+                return raw_response["results"][0]["generated_text"]
+
+        LLM.__name__ = model_class_name
+
+        return LLM
edsl/inference_services/InferenceServicesCollection.py
CHANGED
@@ -15,19 +15,15 @@ class InferenceServicesCollection:
         cls.added_models[service_name].append(model_name)
 
     @staticmethod
-    def _get_service_available(service):
+    def _get_service_available(service) -> list[str]:
         from_api = True
         try:
             service_models = service.available()
         except Exception as e:
-            warnings.warn(
-                f"""Error getting models for {service._inference_service_}.
-
-
-                See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
-                Relying on cache.""",
-                UserWarning,
-            )
+            warnings.warn(
+                f"Error getting models for {service._inference_service_}. Relying on cache.",
+                UserWarning,
+            )
             from edsl.inference_services.models_available_cache import models_available
 
             service_models = models_available.get(service._inference_service_, [])
@@ -61,8 +57,4 @@ class InferenceServicesCollection:
             if service_name is None or service_name == service._inference_service_:
                 return service.create_model(model_name)
 
-        # if model_name == "test":
-        #     from edsl.language_models import LanguageModel
-        #     return LanguageModel(test = True)
-
         raise Exception(f"Model {model_name} not found in any of the services")
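
The warning text is now a single line, but the control flow is unchanged: try the live API, warn and fall back to a static cache on failure. The pattern in isolation (names below are illustrative, not the edsl internals):

    import warnings

    CACHED_MODELS = {"openai": ["gpt-4o"]}  # stand-in for models_available_cache

    def models_for(service_name: str, fetch) -> list[str]:
        try:
            return fetch()
        except Exception:
            warnings.warn(
                f"Error getting models for {service_name}. Relying on cache.",
                UserWarning,
            )
            return CACHED_MODELS.get(service_name, [])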
edsl/inference_services/OpenAIService.py
CHANGED
@@ -1,9 +1,6 @@
 from typing import Any, List
 import re
-import os
-
-# from openai import AsyncOpenAI
-import openai
+from openai import AsyncOpenAI
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models import LanguageModel
@@ -15,22 +12,6 @@ class OpenAIService(InferenceServiceABC):
 
     _inference_service_ = "openai"
     _env_key_name_ = "OPENAI_API_KEY"
-    _base_url_ = None
-
-    _sync_client_ = openai.OpenAI
-    _async_client_ = openai.AsyncOpenAI
-
-    @classmethod
-    def sync_client(cls):
-        return cls._sync_client_(
-            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-        )
-
-    @classmethod
-    def async_client(cls):
-        return cls._async_client_(
-            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-        )
 
     # TODO: Make this a coop call
     model_exclude_list = [
@@ -50,24 +31,16 @@ class OpenAIService(InferenceServiceABC):
     ]
     _models_list_cache: List[str] = []
 
-    @classmethod
-    def get_model_list(cls):
-        raw_list = cls.sync_client().models.list()
-        if hasattr(raw_list, "data"):
-            return raw_list.data
-        else:
-            return raw_list
-
     @classmethod
     def available(cls) -> List[str]:
-
+        from openai import OpenAI
 
         if not cls._models_list_cache:
             try:
-
+                client = OpenAI()
                 cls._models_list_cache = [
                     m.id
-                    for m in cls.get_model_list()
+                    for m in client.models.list()
                     if m.id not in cls.model_exclude_list
                 ]
             except Exception as e:
@@ -105,24 +78,15 @@ class OpenAIService(InferenceServiceABC):
                 "top_logprobs": 3,
             }
 
-            def sync_client(self):
-                return cls.sync_client()
-
-            def async_client(self):
-                return cls.async_client()
-
            @classmethod
            def available(cls) -> list[str]:
-
-
-                # return client.models.list()
-                return cls.sync_client().models.list()
+                client = openai.OpenAI()
+                return client.models.list()
 
            def get_headers(self) -> dict[str, Any]:
-
+                from openai import OpenAI
 
-
-                client = self.sync_client()
+                client = OpenAI()
                response = client.chat.completions.with_raw_response.create(
                    messages=[
                        {
@@ -160,8 +124,8 @@ class OpenAIService(InferenceServiceABC):
                encoded_image=None,
            ) -> dict[str, Any]:
                """Calls the OpenAI API and returns the API response."""
+                content = [{"type": "text", "text": user_prompt}]
                if encoded_image:
-                    content = [{"type": "text", "text": user_prompt}]
                    content.append(
                        {
                            "type": "image_url",
@@ -170,28 +134,21 @@ class OpenAIService(InferenceServiceABC):
                        },
                    }
                )
-
-
-
-
-                # base_url = cls._base_url_
-                # )
-                client = self.async_client()
-                params = {
-                    "model": self.model,
-                    "messages": [
+                self.client = AsyncOpenAI()
+                response = await self.client.chat.completions.create(
+                    model=self.model,
+                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": content},
                    ],
-
-
-
-
-
-
-
-
-                response = await client.chat.completions.create(**params)
+                    temperature=self.temperature,
+                    max_tokens=self.max_tokens,
+                    top_p=self.top_p,
+                    frequency_penalty=self.frequency_penalty,
+                    presence_penalty=self.presence_penalty,
+                    logprobs=self.logprobs,
+                    top_logprobs=self.top_logprobs if self.logprobs else None,
+                )
                return response.model_dump()
 
    @staticmethod
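
One detail worth noting in `get_headers`: it now builds a throwaway client inline and uses the openai SDK's raw-response interface, which exposes HTTP headers alongside the parsed completion. A hedged sketch of that pattern (model name illustrative; requires OPENAI_API_KEY):

    from openai import OpenAI

    client = OpenAI()
    raw = client.chat.completions.with_raw_response.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Say hi"}],
    )
    print(dict(raw.headers))   # rate-limit and request-id headers live here
    completion = raw.parse()   # the usual ChatCompletion object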
edsl/inference_services/registry.py
CHANGED
@@ -6,8 +6,7 @@ from edsl.inference_services.OpenAIService import OpenAIService
 from edsl.inference_services.AnthropicService import AnthropicService
 from edsl.inference_services.DeepInfraService import DeepInfraService
 from edsl.inference_services.GoogleService import GoogleService
-from edsl.inference_services.GroqService import GroqService
 
 default = InferenceServicesCollection(
-    [OpenAIService, AnthropicService, DeepInfraService, GoogleService, GroqService]
+    [OpenAIService, AnthropicService, DeepInfraService, GoogleService]
 )
edsl/jobs/Jobs.py
CHANGED
@@ -3,7 +3,9 @@ from __future__ import annotations
 import warnings
 from itertools import product
 from typing import Optional, Union, Sequence, Generator
+
 from edsl.Base import Base
+
 from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.jobs.interviews.Interview import Interview
@@ -319,11 +321,7 @@ class Jobs(Base):
         self.scenarios = self.scenarios or [Scenario()]
         for agent, scenario, model in product(self.agents, self.scenarios, self.models):
             yield Interview(
-                survey=self.survey,
-                agent=agent,
-                scenario=scenario,
-                model=model,
-                skip_retry=self.skip_retry,
+                survey=self.survey, agent=agent, scenario=scenario, model=model
             )
 
     def create_bucket_collection(self) -> BucketCollection:
@@ -413,12 +411,6 @@ class Jobs(Base):
         if warn:
             warnings.warn(message)
 
-    @property
-    def skip_retry(self):
-        if not hasattr(self, "_skip_retry"):
-            return False
-        return self._skip_retry
-
     def run(
         self,
         n: int = 1,
@@ -433,7 +425,6 @@ class Jobs(Base):
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
         remote_inference_description: Optional[str] = None,
-        skip_retry: bool = False,
     ) -> Results:
         """
         Runs the Job: conducts Interviews and returns their results.
@@ -452,7 +443,6 @@ class Jobs(Base):
         from edsl.coop.coop import Coop
 
         self._check_parameters()
-        self._skip_retry = skip_retry
 
         if batch_mode is not None:
             raise NotImplementedError(
@@ -471,11 +461,12 @@ class Jobs(Base):
             remote_inference = False
 
         if remote_inference:
-            import
-            from
-            from edsl.
-
-
+            from edsl.agents.Agent import Agent
+            from edsl.language_models.registry import Model
+            from edsl.results.Result import Result
+            from edsl.results.Results import Results
+            from edsl.scenarios.Scenario import Scenario
+            from edsl.surveys.Survey import Survey
 
             self._output("Remote inference activated. Sending job to server...")
             if remote_cache:
@@ -483,60 +474,33 @@ class Jobs(Base):
                 "Remote caching activated. The remote cache will be used for this job."
             )
 
-
+            remote_job_data = coop.remote_inference_create(
                 self,
                 description=remote_inference_description,
                 status="queued",
-                iterations=n,
             )
-
-
-
-
-
-
-
-
-
-
-
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                    return None
-                elif status == "failed":
-                    print("\r" + " " * 80 + "\r", end="")
-                    print("Job failed.")
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                    return None
-                elif status == "completed":
-                    results_uuid = remote_job_data.get("results_uuid")
-                    results = coop.get(results_uuid, expected_object_type="results")
-                    print("\r" + " " * 80 + "\r", end="")
-                    print(
-                        f"Job completed and Results stored on Coop (Results uuid={results_uuid})."
+            self._output("Job sent!")
+            # Create mock results object to store job data
+            results = Results(
+                survey=Survey(),
+                data=[
+                    Result(
+                        agent=Agent.example(),
+                        scenario=Scenario.example(),
+                        model=Model(),
+                        iteration=1,
+                        answer={"info": "Remote job details"},
                     )
-
-
-
-
-
-
-
-
-                    print(
-                        f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
-                        end="",
-                        flush=True,
-                    )
-                    time.sleep(0.1)
-                    i += 1
+                ],
+            )
+            results.add_columns_from_dict([remote_job_data])
+            if self.verbose:
+                results.select(["info", "uuid", "status", "version"]).print(
+                    format="rich"
+                )
+            return results
         else:
             if check_api_keys:
-                from edsl import Model
-
                 for model in self.models + [Model()]:
                     if not model.has_valid_api_key():
                         raise MissingAPIKeyError(
@@ -642,9 +606,9 @@ class Jobs(Base):
         results = JobsRunnerAsyncio(self).run(*args, **kwargs)
         return results
 
-    async def run_async(self, cache=None,
+    async def run_async(self, cache=None, **kwargs):
         """Run the job asynchronously."""
-        results = await JobsRunnerAsyncio(self).run_async(cache=cache,
+        results = await JobsRunnerAsyncio(self).run_async(cache=cache, **kwargs)
         return results
 
     def all_question_parameters(self):
@@ -724,10 +688,7 @@ class Jobs(Base):
     #######################
     @classmethod
     def example(
-        cls,
-        throw_exception_probability: int = 0,
-        randomize: bool = False,
-        test_model=False,
+        cls, throw_exception_probability: int = 0, randomize: bool = False
     ) -> Jobs:
         """Return an example Jobs instance.
 
@@ -745,11 +706,6 @@ class Jobs(Base):
 
         addition = "" if not randomize else str(uuid4())
 
-        if test_model:
-            from edsl.language_models import LanguageModel
-
-            m = LanguageModel.example(test_model=True)
-
         # (status, question, period)
         agent_answers = {
             ("Joyful", "how_feeling", "morning"): "OK",
@@ -797,10 +753,7 @@ class Jobs(Base):
                 Scenario({"period": "afternoon"}),
             ]
         )
-        if test_model:
-            job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
-        else:
-            job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
+        job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
 
         return job
 
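
The `run_async` signature now forwards arbitrary keyword arguments straight through to `JobsRunnerAsyncio.run_async` instead of naming each one. A hedged sketch of a call site (the extra kwargs shown are illustrative):

    import asyncio
    from edsl.jobs.Jobs import Jobs

    async def main():
        job = Jobs.example()
        # cache plus any other runner options ride along via **kwargs
        return await job.run_async(cache=None, n=1)

    # asyncio.run(main())  # requires configured API keys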