edsl 0.1.31.dev3__py3-none-any.whl → 0.1.32__py3-none-any.whl
This diff shows the contents of two publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +7 -2
- edsl/agents/PromptConstructionMixin.py +35 -15
- edsl/config.py +15 -1
- edsl/conjure/Conjure.py +6 -0
- edsl/coop/coop.py +4 -0
- edsl/data/CacheHandler.py +3 -4
- edsl/enums.py +5 -0
- edsl/exceptions/general.py +10 -8
- edsl/inference_services/AwsBedrock.py +110 -0
- edsl/inference_services/AzureAI.py +197 -0
- edsl/inference_services/DeepInfraService.py +6 -91
- edsl/inference_services/GroqService.py +18 -0
- edsl/inference_services/InferenceServicesCollection.py +13 -8
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +68 -21
- edsl/inference_services/models_available_cache.py +31 -0
- edsl/inference_services/registry.py +14 -1
- edsl/jobs/Jobs.py +103 -21
- edsl/jobs/buckets/TokenBucket.py +12 -4
- edsl/jobs/interviews/Interview.py +31 -9
- edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +49 -33
- edsl/jobs/interviews/interview_exception_tracking.py +68 -10
- edsl/jobs/runners/JobsRunnerAsyncio.py +112 -81
- edsl/jobs/runners/JobsRunnerStatusData.py +0 -237
- edsl/jobs/runners/JobsRunnerStatusMixin.py +291 -35
- edsl/jobs/tasks/TaskCreators.py +8 -2
- edsl/jobs/tasks/TaskHistory.py +145 -1
- edsl/language_models/LanguageModel.py +62 -41
- edsl/language_models/registry.py +4 -0
- edsl/questions/QuestionBudget.py +0 -1
- edsl/questions/QuestionCheckBox.py +0 -1
- edsl/questions/QuestionExtract.py +0 -1
- edsl/questions/QuestionFreeText.py +2 -9
- edsl/questions/QuestionList.py +0 -1
- edsl/questions/QuestionMultipleChoice.py +1 -2
- edsl/questions/QuestionNumerical.py +0 -1
- edsl/questions/QuestionRank.py +0 -1
- edsl/results/DatasetExportMixin.py +33 -3
- edsl/scenarios/Scenario.py +14 -0
- edsl/scenarios/ScenarioList.py +216 -13
- edsl/scenarios/ScenarioListExportMixin.py +15 -4
- edsl/scenarios/ScenarioListPdfMixin.py +3 -0
- edsl/surveys/Rule.py +5 -2
- edsl/surveys/Survey.py +84 -1
- edsl/surveys/SurveyQualtricsImport.py +213 -0
- edsl/utilities/utilities.py +31 -0
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/METADATA +5 -1
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/RECORD +52 -46
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/WHEEL +0 -0
edsl/inference_services/GroqService.py
ADDED
@@ -0,0 +1,18 @@
+from typing import Any, List
+from edsl.inference_services.OpenAIService import OpenAIService
+
+import groq
+
+
+class GroqService(OpenAIService):
+    """DeepInfra service class."""
+
+    _inference_service_ = "groq"
+    _env_key_name_ = "GROQ_API_KEY"
+
+    _sync_client_ = groq.Groq
+    _async_client_ = groq.AsyncGroq
+
+    # _base_url_ = "https://api.deepinfra.com/v1/openai"
+    _base_url_ = None
+    _models_list_cache: List[str] = []
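
Because GroqService only overrides class attributes on OpenAIService, adding the provider reduces to swapping the client classes and the API-key environment variable (the "DeepInfra service class." docstring is a copy-paste leftover). A minimal sketch of that factory pattern, assuming the groq package is installed and GROQ_API_KEY is set; the class name here is illustrative, not part of edsl:

    import os

    import groq

    class GroqClientFactory:
        """Stand-in showing the client construction GroqService inherits."""

        _env_key_name_ = "GROQ_API_KEY"
        _base_url_ = None  # None falls through to the SDK's default endpoint

        @classmethod
        def sync_client(cls):
            # Mirrors OpenAIService.sync_client(): build the client from the
            # class-level key name and base URL.
            return groq.Groq(api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_)

    client = GroqClientFactory.sync_client()
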
edsl/inference_services/InferenceServicesCollection.py
CHANGED
@@ -15,18 +15,19 @@ class InferenceServicesCollection:
         cls.added_models[service_name].append(model_name)
 
     @staticmethod
-    def _get_service_available(service) -> list[str]:
+    def _get_service_available(service, warn: bool = False) -> list[str]:
         from_api = True
         try:
             service_models = service.available()
         except Exception as e:
-
-
-
-
-
-
-
+            if warn:
+                warnings.warn(
+                    f"""Error getting models for {service._inference_service_}.
+                    Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
+                    See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
+                    Relying on cache.""",
+                    UserWarning,
+                )
             from edsl.inference_services.models_available_cache import models_available
 
             service_models = models_available.get(service._inference_service_, [])
@@ -60,4 +61,8 @@ class InferenceServicesCollection:
            if service_name is None or service_name == service._inference_service_:
                return service.create_model(model_name)
 
+        # if model_name == "test":
+        #     from edsl.language_models import LanguageModel
+        #     return LanguageModel(test = True)
+
        raise Exception(f"Model {model_name} not found in any of the services")
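
The new warn flag makes the cache fallback opt-in noisy rather than silent; note that the hunk itself shows no import of warnings, so that import presumably lives elsewhere in the module. The shape of the pattern, as a self-contained sketch with an illustrative cache:

    import warnings

    # Illustrative stand-in for edsl.inference_services.models_available_cache.
    MODELS_CACHE = {"openai": ["gpt-4o"], "groq": ["llama3-70b-8192"]}

    def get_service_available(service_name: str, fetch, warn: bool = False) -> list:
        """Try the live model listing; on failure, optionally warn, then use the cache."""
        try:
            return fetch()
        except Exception:
            if warn:
                warnings.warn(
                    f"Error getting models for {service_name}. Relying on cache.",
                    UserWarning,
                )
            return MODELS_CACHE.get(service_name, [])

    def broken_fetch():
        raise RuntimeError("network down")

    print(get_service_available("groq", broken_fetch, warn=True))  # -> ['llama3-70b-8192']
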
edsl/inference_services/OllamaService.py
ADDED
@@ -0,0 +1,18 @@
+import aiohttp
+import json
+import requests
+from typing import Any, List
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+
+from edsl.inference_services.OpenAIService import OpenAIService
+
+
+class OllamaService(OpenAIService):
+    """DeepInfra service class."""
+
+    _inference_service_ = "ollama"
+    _env_key_name_ = "DEEP_INFRA_API_KEY"
+    _base_url_ = "http://localhost:11434/v1"
+    _models_list_cache: List[str] = []
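
OllamaService works because Ollama exposes an OpenAI-compatible endpoint at /v1; the DEEP_INFRA_API_KEY name (like the docstring) is a copy-paste leftover, and Ollama ignores the key value anyway. A hedged sketch of talking to that endpoint directly with the stock openai client, assuming a local `ollama serve` and a pulled model such as llama3:

    import openai

    # The key is required by the client but ignored by the Ollama server.
    client = openai.OpenAI(base_url="http://localhost:11434/v1", api_key="unused")
    reply = client.chat.completions.create(
        model="llama3",  # any locally pulled model tag
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(reply.choices[0].message.content)
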
edsl/inference_services/OpenAIService.py
CHANGED
@@ -1,10 +1,14 @@
 from typing import Any, List
 import re
-
+import os
+
+# from openai import AsyncOpenAI
+import openai
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models import LanguageModel
 from edsl.inference_services.rate_limits_cache import rate_limits
+from edsl.utilities.utilities import fix_partial_correct_response
 
 
 class OpenAIService(InferenceServiceABC):
@@ -12,6 +16,22 @@ class OpenAIService(InferenceServiceABC):
 
     _inference_service_ = "openai"
     _env_key_name_ = "OPENAI_API_KEY"
+    _base_url_ = None
+
+    _sync_client_ = openai.OpenAI
+    _async_client_ = openai.AsyncOpenAI
+
+    @classmethod
+    def sync_client(cls):
+        return cls._sync_client_(
+            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+        )
+
+    @classmethod
+    def async_client(cls):
+        return cls._async_client_(
+            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+        )
 
     # TODO: Make this a coop call
     model_exclude_list = [
@@ -31,16 +51,24 @@ class OpenAIService(InferenceServiceABC):
     ]
     _models_list_cache: List[str] = []
 
+    @classmethod
+    def get_model_list(cls):
+        raw_list = cls.sync_client().models.list()
+        if hasattr(raw_list, "data"):
+            return raw_list.data
+        else:
+            return raw_list
+
     @classmethod
     def available(cls) -> List[str]:
-        from openai import OpenAI
+        # from openai import OpenAI
 
         if not cls._models_list_cache:
             try:
-                client = OpenAI()
+                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                 cls._models_list_cache = [
                     m.id
-                    for m in client.models.list()
+                    for m in cls.get_model_list()
                     if m.id not in cls.model_exclude_list
                 ]
             except Exception as e:
@@ -78,15 +106,24 @@ class OpenAIService(InferenceServiceABC):
                 "top_logprobs": 3,
             }
 
+            def sync_client(self):
+                return cls.sync_client()
+
+            def async_client(self):
+                return cls.async_client()
+
             @classmethod
             def available(cls) -> list[str]:
-
-
+                # import openai
+                # client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                # return client.models.list()
+                return cls.sync_client().models.list()
 
             def get_headers(self) -> dict[str, Any]:
-                from openai import OpenAI
+                # from openai import OpenAI
 
-                client = OpenAI()
+                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                client = self.sync_client()
                 response = client.chat.completions.with_raw_response.create(
                     messages=[
                         {
@@ -124,8 +161,8 @@ class OpenAIService(InferenceServiceABC):
                 encoded_image=None,
             ) -> dict[str, Any]:
                 """Calls the OpenAI API and returns the API response."""
-                content = [{"type": "text", "text": user_prompt}]
                 if encoded_image:
+                    content = [{"type": "text", "text": user_prompt}]
                     content.append(
                         {
                             "type": "image_url",
@@ -134,21 +171,28 @@ class OpenAIService(InferenceServiceABC):
                         },
                     }
                 )
-
-
-
-                    messages=[
+                else:
+                    content = user_prompt
+                # self.client = AsyncOpenAI(
+                #     api_key = os.getenv(cls._env_key_name_),
+                #     base_url = cls._base_url_
+                # )
+                client = self.async_client()
+                params = {
+                    "model": self.model,
+                    "messages": [
                         {"role": "system", "content": system_prompt},
                         {"role": "user", "content": content},
                     ],
-                    temperature=self.temperature,
-                    max_tokens=self.max_tokens,
-                    top_p=self.top_p,
-                    frequency_penalty=self.frequency_penalty,
-                    presence_penalty=self.presence_penalty,
-                    logprobs=self.logprobs,
-                    top_logprobs=self.top_logprobs,
-                )
+                    "temperature": self.temperature,
+                    "max_tokens": self.max_tokens,
+                    "top_p": self.top_p,
+                    "frequency_penalty": self.frequency_penalty,
+                    "presence_penalty": self.presence_penalty,
+                    "logprobs": self.logprobs,
+                    "top_logprobs": self.top_logprobs if self.logprobs else None,
+                }
+                response = await client.chat.completions.create(**params)
                 return response.model_dump()
 
             @staticmethod
@@ -164,6 +208,9 @@ class OpenAIService(InferenceServiceABC):
             if match:
                 return match.group(1)
             else:
+                out = fix_partial_correct_response(response)
+                if "error" not in out:
+                    response = out["extracted_json"]
                 return response
 
     LLM.__name__ = "LanguageModel"
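
The net effect of these OpenAIService changes is that every request path now goes through the class-level client factories (so subclasses like GroqService inherit working sync/async clients) and the chat-completion arguments are gathered into a params dict before the call. A condensed, self-contained sketch of that call path, assuming OPENAI_API_KEY is set; the model name and prompts are illustrative:

    import asyncio
    import os

    import openai

    async def ask(user_prompt: str, system_prompt: str = "You are terse.") -> dict:
        # base_url=None falls through to the SDK default, as in the service class.
        client = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=None)
        params = {
            "model": "gpt-4o-mini",
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": 0.5,
            "max_tokens": 100,
        }
        response = await client.chat.completions.create(**params)
        return response.model_dump()

    # asyncio.run(ask("Name one planet."))

One oddity worth noting: the instance-level sync_client/async_client wrappers inside the nested model class return cls.sync_client() by closing over the enclosing classmethod's cls rather than taking their own, which works in Python but reads oddly.
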
edsl/inference_services/models_available_cache.py
CHANGED
@@ -66,4 +66,35 @@ models_available = {
         "openchat/openchat_3.5",
     ],
     "google": ["gemini-pro"],
+    "bedrock": [
+        "amazon.titan-tg1-large",
+        "amazon.titan-text-lite-v1",
+        "amazon.titan-text-express-v1",
+        "ai21.j2-grande-instruct",
+        "ai21.j2-jumbo-instruct",
+        "ai21.j2-mid",
+        "ai21.j2-mid-v1",
+        "ai21.j2-ultra",
+        "ai21.j2-ultra-v1",
+        "anthropic.claude-instant-v1",
+        "anthropic.claude-v2:1",
+        "anthropic.claude-v2",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-3-haiku-20240307-v1:0",
+        "anthropic.claude-3-opus-20240229-v1:0",
+        "anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "cohere.command-text-v14",
+        "cohere.command-r-v1:0",
+        "cohere.command-r-plus-v1:0",
+        "cohere.command-light-text-v14",
+        "meta.llama3-8b-instruct-v1:0",
+        "meta.llama3-70b-instruct-v1:0",
+        "meta.llama3-1-8b-instruct-v1:0",
+        "meta.llama3-1-70b-instruct-v1:0",
+        "meta.llama3-1-405b-instruct-v1:0",
+        "mistral.mistral-7b-instruct-v0:2",
+        "mistral.mixtral-8x7b-instruct-v0:1",
+        "mistral.mistral-large-2402-v1:0",
+        "mistral.mistral-large-2407-v1:0",
+    ],
 }
edsl/inference_services/registry.py
CHANGED
@@ -6,7 +6,20 @@ from edsl.inference_services.OpenAIService import OpenAIService
 from edsl.inference_services.AnthropicService import AnthropicService
 from edsl.inference_services.DeepInfraService import DeepInfraService
 from edsl.inference_services.GoogleService import GoogleService
+from edsl.inference_services.GroqService import GroqService
+from edsl.inference_services.AwsBedrock import AwsBedrockService
+from edsl.inference_services.AzureAI import AzureAIService
+from edsl.inference_services.OllamaService import OllamaService
 
 default = InferenceServicesCollection(
-    [OpenAIService, AnthropicService, DeepInfraService, GoogleService]
+    [
+        OpenAIService,
+        AnthropicService,
+        DeepInfraService,
+        GoogleService,
+        GroqService,
+        AwsBedrockService,
+        AzureAIService,
+        OllamaService,
+    ]
 )
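
For orientation, the collection these eight services are registered into resolves a model by walking the service classes in order and matching on _inference_service_. A simplified sketch of that lookup; the real InferenceServicesCollection also routes availability through the cache-backed logic shown earlier:

    class ServicesCollection:
        def __init__(self, services):
            self.services = services

        def create_model(self, model_name, service_name=None):
            # The first service that matches the (optional) service name
            # and knows the model wins.
            for service in self.services:
                if service_name is None or service_name == service._inference_service_:
                    if model_name in service.available():
                        return service.create_model(model_name)
            raise Exception(f"Model {model_name} not found in any of the services")
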
edsl/jobs/Jobs.py
CHANGED
@@ -39,6 +39,8 @@ class Jobs(Base):
 
         self.__bucket_collection = None
 
+    # these setters and getters are used to ensure that the agents, models, and scenarios are stored as AgentList, ModelList, and ScenarioList objects
+
     @property
     def models(self):
         return self._models
@@ -119,7 +121,9 @@ class Jobs(Base):
         - scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
         - models: new models overwrite old models.
         """
-        passed_objects = self._turn_args_to_list(args)
+        passed_objects = self._turn_args_to_list(
+            args
+        )  # objects can also be passed comma-separated
 
         current_objects, objects_key = self._get_current_objects_of_this_type(
             passed_objects[0]
@@ -176,17 +180,27 @@ class Jobs(Base):
         from edsl.agents.Agent import Agent
         from edsl.scenarios.Scenario import Scenario
         from edsl.scenarios.ScenarioList import ScenarioList
+        from edsl.language_models.ModelList import ModelList
 
         if isinstance(object, Agent):
             return AgentList
         elif isinstance(object, Scenario):
             return ScenarioList
+        elif isinstance(object, ModelList):
+            return ModelList
         else:
             return list
 
     @staticmethod
     def _turn_args_to_list(args):
-        """Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments."""
+        """Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments.
+
+        Example:
+
+        >>> Jobs._turn_args_to_list([1,2,3])
+        [1, 2, 3]
+
+        """
 
         def did_user_pass_a_sequence(args):
             """Return True if the user passed a sequence, False otherwise.
@@ -209,7 +223,7 @@ class Jobs(Base):
         return container_class(args)
 
     def _get_current_objects_of_this_type(
-        self, object: Union[Agent, Scenario, LanguageModel]
+        self, object: Union["Agent", "Scenario", "LanguageModel"]
     ) -> tuple[list, str]:
         from edsl.agents.Agent import Agent
         from edsl.scenarios.Scenario import Scenario
@@ -292,7 +306,11 @@ class Jobs(Base):
 
     @classmethod
     def from_interviews(cls, interview_list):
-        """Return a Jobs instance from a list of interviews."""
+        """Return a Jobs instance from a list of interviews.
+
+        This is useful when you have, say, a list of failed interviews and you want to create
+        a new job with only those interviews.
+        """
         survey = interview_list[0].survey
         # get all the models
         models = list(set([interview.model for interview in interview_list]))
@@ -308,6 +326,8 @@ class Jobs(Base):
         Note that this sets the agents, model and scenarios if they have not been set. This is a side effect of the method.
         This is useful because a user can create a job without setting the agents, models, or scenarios, and the job will still run,
         with us filling in defaults.
+
+
         """
         # if no agents, models, or scenarios are set, set them to defaults
         from edsl.agents.Agent import Agent
@@ -319,7 +339,11 @@ class Jobs(Base):
         self.scenarios = self.scenarios or [Scenario()]
         for agent, scenario, model in product(self.agents, self.scenarios, self.models):
             yield Interview(
-                survey=self.survey, agent=agent, scenario=scenario, model=model
+                survey=self.survey,
+                agent=agent,
+                scenario=scenario,
+                model=model,
+                skip_retry=self.skip_retry,
             )
 
     def create_bucket_collection(self) -> BucketCollection:
@@ -359,10 +383,16 @@ class Jobs(Base):
         return links
 
     def __hash__(self):
-        """Allow the model to be used as a key in a dictionary."""
+        """Allow the model to be used as a key in a dictionary.
+
+        >>> from edsl.jobs import Jobs
+        >>> hash(Jobs.example())
+        846655441787442972
+
+        """
         from edsl.utilities.utilities import dict_hash
 
-        return dict_hash(self.to_dict())
+        return dict_hash(self._to_dict())
 
     def _output(self, message) -> None:
         """Check if a Job is verbose. If so, print the message."""
@@ -390,11 +420,27 @@ class Jobs(Base):
         Traceback (most recent call last):
         ...
         ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
+
+        >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
+        >>> s = Scenario({'ugly_question': "B"})
+        >>> j = Jobs(survey = Survey(questions=[q])).by(s)
+        >>> j._check_parameters()
+        Traceback (most recent call last):
+        ...
+        ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
         """
         survey_parameters: set = self.survey.parameters
         scenario_parameters: set = self.scenarios.parameters
 
-        msg1, msg2 = None, None
+        msg0, msg1, msg2 = None, None, None
+
+        # look for key issues
+        if intersection := set(self.scenarios.parameters) & set(
+            self.survey.question_names
+        ):
+            msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
+
+            raise ValueError(msg0)
 
         if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
             msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
@@ -409,6 +455,12 @@ class Jobs(Base):
         if warn:
             warnings.warn(message)
 
+    @property
+    def skip_retry(self):
+        if not hasattr(self, "_skip_retry"):
+            return False
+        return self._skip_retry
+
     def run(
         self,
         n: int = 1,
@@ -423,6 +475,7 @@ class Jobs(Base):
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
         remote_inference_description: Optional[str] = None,
+        skip_retry: bool = False,
     ) -> Results:
         """
         Runs the Job: conducts Interviews and returns their results.
@@ -441,6 +494,7 @@ class Jobs(Base):
         from edsl.coop.coop import Coop
 
         self._check_parameters()
+        self._skip_retry = skip_retry
 
         if batch_mode is not None:
             raise NotImplementedError(
@@ -475,6 +529,7 @@ class Jobs(Base):
                     self,
                     description=remote_inference_description,
                     status="queued",
+                    iterations=n,
                 )
                 time_queued = datetime.now().strftime("%m/%d/%Y %I:%M:%S %p")
                 job_uuid = remote_job_creation_data.get("uuid")
@@ -629,13 +684,17 @@ class Jobs(Base):
         results = JobsRunnerAsyncio(self).run(*args, **kwargs)
         return results
 
-    async def run_async(self, cache=None, **kwargs):
-        """Run
-        results = await JobsRunnerAsyncio(self).run_async(cache=cache, **kwargs)
+    async def run_async(self, cache=None, n=1, **kwargs):
+        """Run asynchronously."""
+        results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
         return results
 
     def all_question_parameters(self):
-        """Return all the fields in the questions in the survey."""
+        """Return all the fields in the questions in the survey.
+        >>> from edsl.jobs import Jobs
+        >>> Jobs.example().all_question_parameters()
+        {'period'}
+        """
         return set.union(*[question.parameters for question in self.survey.questions])
 
     #######################
@@ -676,15 +735,19 @@ class Jobs(Base):
     #######################
     # Serialization methods
     #######################
+
+    def _to_dict(self):
+        return {
+            "survey": self.survey._to_dict(),
+            "agents": [agent._to_dict() for agent in self.agents],
+            "models": [model._to_dict() for model in self.models],
+            "scenarios": [scenario._to_dict() for scenario in self.scenarios],
+        }
+
     @add_edsl_version
     def to_dict(self) -> dict:
         """Convert the Jobs instance to a dictionary."""
-        return {
-            "survey": self.survey.to_dict(),
-            "agents": [agent.to_dict() for agent in self.agents],
-            "models": [model.to_dict() for model in self.models],
-            "scenarios": [scenario.to_dict() for scenario in self.scenarios],
-        }
+        return self._to_dict()
 
     @classmethod
     @remove_edsl_version
@@ -703,7 +766,13 @@ class Jobs(Base):
         )
 
     def __eq__(self, other: Jobs) -> bool:
-        """Return True if the Jobs instance is equal to another Jobs instance."""
+        """Return True if the Jobs instance is equal to another Jobs instance.
+
+        >>> from edsl.jobs import Jobs
+        >>> Jobs.example() == Jobs.example()
+        True
+
+        """
         return self.to_dict() == other.to_dict()
 
     #######################
@@ -711,11 +780,16 @@ class Jobs(Base):
     #######################
     @classmethod
     def example(
-        cls, throw_exception_probability: float = 0.0, randomize: bool = False
+        cls,
+        throw_exception_probability: float = 0.0,
+        randomize: bool = False,
+        test_model=False,
     ) -> Jobs:
         """Return an example Jobs instance.
 
         :param throw_exception_probability: the probability that an exception will be thrown when answering a question. This is useful for testing error handling.
+        :param randomize: whether to randomize the job by adding a random string to the period
+        :param test_model: whether to use a test model
 
         >>> Jobs.example()
         Jobs(...)
@@ -729,6 +803,11 @@ class Jobs(Base):
 
         addition = "" if not randomize else str(uuid4())
 
+        if test_model:
+            from edsl.language_models import LanguageModel
+
+            m = LanguageModel.example(test_model=True)
+
         # (status, question, period)
         agent_answers = {
             ("Joyful", "how_feeling", "morning"): "OK",
@@ -776,7 +855,10 @@ class Jobs(Base):
                 Scenario({"period": "afternoon"}),
             ]
         )
-        job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
+        if test_model:
+            job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
+        else:
+            job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
 
         return job
 
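
Taken together, the new Jobs surface is small: a skip_retry flag threaded from run() into each Interview, an n iteration count on run_async, and a test_model option on example(). A hedged usage sketch; the option names come from the diff, while credentials and output are environment-dependent:

    import asyncio

    from edsl.jobs import Jobs

    job = Jobs.example(test_model=True)  # bundles a canned test LanguageModel
    results = job.run(skip_retry=True)   # failed questions are not retried

    # The async path now accepts an iteration count:
    # results = asyncio.run(job.run_async(n=2))
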
edsl/jobs/buckets/TokenBucket.py
CHANGED
@@ -100,7 +100,9 @@ class TokenBucket:
         available_tokens = min(self.capacity, self.tokens + refill_amount)
         return max(0, requested_tokens - available_tokens) / self.refill_rate
 
-    async def get_tokens(self, amount: Union[int, float] = 1) -> None:
+    async def get_tokens(
+        self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
+    ) -> None:
         """Wait for the specified number of tokens to become available.
 
 
@@ -116,14 +118,20 @@ class TokenBucket:
         True
 
         >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)
-        >>> asyncio.run(bucket.get_tokens(11))
+        >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=False))
         Traceback (most recent call last):
         ...
         ValueError: Requested amount exceeds bucket capacity. Bucket capacity: 10, requested amount: 11. As the bucket never overflows, the requested amount will never be available.
+        >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))
         """
         if amount > self.capacity:
-            msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
-            raise ValueError(msg)
+            if not cheat_bucket_capacity:
+                msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
+                raise ValueError(msg)
+            else:
+                self.tokens = 0  # clear the bucket but let it go through
+                return
+
         while self.tokens < amount:
             self.refill()
             await asyncio.sleep(0.01)  # Sleep briefly to prevent busy waiting
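
The cheat_bucket_capacity flag (default True) changes what happens when a single request exceeds the bucket's total capacity: instead of raising, the bucket is drained and the request is waved through, which keeps one oversized request from waiting forever. Mirroring the doctest above:

    import asyncio

    from edsl.jobs.buckets.TokenBucket import TokenBucket

    bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)
    try:
        asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=False))
    except ValueError as e:
        print(e)  # the request can never be satisfied by a 10-token bucket

    asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))  # drains to 0 and returns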