edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +28 -0
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -16
- edsl/agents/Invigilator.py +13 -14
- edsl/agents/InvigilatorBase.py +4 -1
- edsl/agents/PromptConstructor.py +42 -22
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/coop/coop.py +21 -5
- edsl/data/Cache.py +29 -18
- edsl/data/CacheHandler.py +0 -2
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/hack.py +10 -0
- edsl/enums.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -16
- edsl/inference_services/AvailableModelFetcher.py +7 -1
- edsl/inference_services/GoogleService.py +5 -1
- edsl/inference_services/InferenceServicesCollection.py +18 -2
- edsl/inference_services/OpenAIService.py +46 -31
- edsl/inference_services/TestService.py +1 -3
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +74 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
- edsl/jobs/FetchInvigilator.py +10 -3
- edsl/jobs/InterviewsConstructor.py +6 -4
- edsl/jobs/Jobs.py +299 -233
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/interviews/Interview.py +80 -42
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
- edsl/jobs/runners/JobsRunnerStatus.py +131 -164
- edsl/jobs/tasks/TaskHistory.py +24 -3
- edsl/language_models/LanguageModel.py +59 -4
- edsl/language_models/ModelList.py +19 -8
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +35 -26
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +9 -15
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/questions/data_structures.py +20 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
- edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
- edsl/results/DatasetExportMixin.py +60 -119
- edsl/results/Result.py +109 -3
- edsl/results/Results.py +50 -39
- edsl/results/file_exports.py +252 -0
- edsl/scenarios/ScenarioList.py +35 -7
- edsl/surveys/Survey.py +71 -20
- edsl/test_h +1 -0
- edsl/utilities/gcp_bucket/example.py +50 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
- edsl/language_models/registry.py +0 -180
- /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
- /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
- /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{Selector.py → results_selector.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
- /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
- /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -1,7 +1,17 @@
|
|
1
1
|
# """The Jobs class is a collection of agents, scenarios and models and one survey."""
|
2
2
|
from __future__ import annotations
|
3
|
-
import
|
4
|
-
from
|
3
|
+
import asyncio
|
4
|
+
from inspect import signature
|
5
|
+
from typing import (
|
6
|
+
Literal,
|
7
|
+
Optional,
|
8
|
+
Union,
|
9
|
+
Sequence,
|
10
|
+
Generator,
|
11
|
+
TYPE_CHECKING,
|
12
|
+
Callable,
|
13
|
+
Tuple,
|
14
|
+
)
|
5
15
|
|
6
16
|
from edsl.Base import Base
|
7
17
|
|
@@ -9,10 +19,13 @@ from edsl.jobs.buckets.BucketCollection import BucketCollection
|
|
9
19
|
from edsl.jobs.JobsPrompts import JobsPrompts
|
10
20
|
from edsl.jobs.interviews.Interview import Interview
|
11
21
|
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
12
|
-
|
22
|
+
from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
|
13
23
|
from edsl.data.RemoteCacheSync import RemoteCacheSync
|
14
24
|
from edsl.exceptions.coop import CoopServerResponseError
|
15
25
|
|
26
|
+
from edsl.jobs.JobsChecks import JobsChecks
|
27
|
+
from edsl.jobs.data_structures import RunEnvironment, RunParameters, RunConfig
|
28
|
+
|
16
29
|
if TYPE_CHECKING:
|
17
30
|
from edsl.agents.Agent import Agent
|
18
31
|
from edsl.agents.AgentList import AgentList
|
@@ -23,6 +36,66 @@ if TYPE_CHECKING:
|
|
23
36
|
from edsl.results.Results import Results
|
24
37
|
from edsl.results.Dataset import Dataset
|
25
38
|
from edsl.language_models.ModelList import ModelList
|
39
|
+
from edsl.data.Cache import Cache
|
40
|
+
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
41
|
+
|
42
|
+
VisibilityType = Literal["private", "public", "unlisted"]
|
43
|
+
|
44
|
+
from dataclasses import dataclass
|
45
|
+
from typing import Optional, Union, TypeVar, Callable, cast
|
46
|
+
from functools import wraps
|
47
|
+
|
48
|
+
try:
|
49
|
+
from typing import ParamSpec
|
50
|
+
except ImportError:
|
51
|
+
from typing_extensions import ParamSpec
|
52
|
+
|
53
|
+
|
54
|
+
P = ParamSpec("P")
|
55
|
+
T = TypeVar("T")
|
56
|
+
|
57
|
+
|
58
|
+
from edsl.jobs.check_survey_scenario_compatibility import (
|
59
|
+
CheckSurveyScenarioCompatibility,
|
60
|
+
)
|
61
|
+
|
62
|
+
|
63
|
+
def with_config(f: Callable[P, T]) -> Callable[P, T]:
    """Decorate a run-style callable so it accepts flat keyword arguments.

    The wrapped callable is invoked as ``f(*args, config=config)``. Keyword
    arguments given to the wrapper are routed by name: those matching a
    ``RunEnvironment`` field build the environment, those matching a
    ``RunParameters`` field build the parameters, and both are bundled into
    a ``RunConfig``. Positional arguments are forwarded unchanged.

    :param f: the callable to wrap; it must accept a ``config`` keyword
    :return: the wrapper, carrying the metadata of ``f`` via ``functools.wraps``
    """
    # Only the field names are needed to route keyword arguments.
    parameter_fields = frozenset(RunParameters.__dataclass_fields__)
    environment_fields = frozenset(RunEnvironment.__dataclass_fields__)

    @wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        # NOTE(review): keywords matching neither dataclass are silently
        # dropped here -- confirm that this tolerance is intended.
        environment = RunEnvironment(
            **{k: v for k, v in kwargs.items() if k in environment_fields}
        )
        parameters = RunParameters(
            **{k: v for k, v in kwargs.items() if k in parameter_fields}
        )
        config = RunConfig(environment=environment, parameters=parameters)
        return f(*args, config=config)

    return cast(Callable[P, T], wrapper)
|
26
99
|
|
27
100
|
|
28
101
|
class Jobs(Base):
|
@@ -46,15 +119,62 @@ class Jobs(Base):
|
|
46
119
|
:param models: a list of models
|
47
120
|
:param scenarios: a list of scenarios
|
48
121
|
"""
|
122
|
+
self.run_config = RunConfig(
|
123
|
+
environment=RunEnvironment(), parameters=RunParameters()
|
124
|
+
)
|
125
|
+
|
49
126
|
self.survey = survey
|
50
127
|
self.agents: AgentList = agents
|
51
128
|
self.scenarios: ScenarioList = scenarios
|
52
|
-
self.models = models
|
129
|
+
self.models: ModelList = models
|
130
|
+
|
131
|
+
def add_running_env(self, running_env: RunEnvironment) -> Jobs:
    """
    Add a RunEnvironment to the job.

    :param running_env: the environment passed to ``run_config.add_environment``
    :return: this job, so calls can be chained
    """
    run_config = self.run_config
    run_config.add_environment(running_env)
    return self
|
53
134
|
|
54
|
-
|
135
|
+
def using_cache(self, cache: "Cache") -> Jobs:
    """
    Attach a Cache to this job's run configuration.

    :param cache: the cache passed to ``run_config.add_cache``
    :return: this job, so calls can be chained
    """
    run_config = self.run_config
    run_config.add_cache(cache)
    return self
|
143
|
+
|
144
|
+
def using_bucket_collection(self, bucket_collection: BucketCollection) -> Jobs:
    """
    Attach a BucketCollection to this job's run configuration.

    :param bucket_collection: the collection passed to ``run_config.add_bucket_collection``
    :return: this job, so calls can be chained
    """
    run_config = self.run_config
    run_config.add_bucket_collection(bucket_collection)
    return self
|
152
|
+
|
153
|
+
def using_key_lookup(self, key_lookup: KeyLookup) -> Jobs:
    """
    Attach a KeyLookup to this job's run configuration.

    :param key_lookup: the key lookup passed to ``run_config.add_key_lookup``
    :return: this job, so calls can be chained
    """
    run_config = self.run_config
    run_config.add_key_lookup(key_lookup)
    return self
|
161
|
+
|
162
|
+
def using(self, obj: Union[Cache, BucketCollection, KeyLookup]) -> Jobs:
    """
    Attach a Cache, BucketCollection, or KeyLookup to the job.

    The object is handed to the matching ``using_*`` method based on its
    type. An object of any other type is ignored.

    :param obj: the object to attach
    :return: this job, so calls can be chained
    """
    from edsl.data.Cache import Cache
    from edsl.language_models.key_management.KeyLookup import KeyLookup

    # Checked in this order; the first matching type wins.
    handlers = (
        (Cache, self.using_cache),
        (BucketCollection, self.using_bucket_collection),
        (KeyLookup, self.using_key_lookup),
    )
    for expected_type, attach in handlers:
        if isinstance(obj, expected_type):
            attach(obj)
            break
    return self
|
58
178
|
|
59
179
|
@property
|
60
180
|
def models(self):
|
@@ -72,6 +192,12 @@ class Jobs(Base):
|
|
72
192
|
else:
|
73
193
|
self._models = ModelList([])
|
74
194
|
|
195
|
+
# update the bucket collection if it exists
|
196
|
+
if self.run_config.environment.bucket_collection is None:
|
197
|
+
self.run_config.environment.bucket_collection = (
|
198
|
+
self.create_bucket_collection()
|
199
|
+
)
|
200
|
+
|
75
201
|
@property
|
76
202
|
def agents(self):
|
77
203
|
return self._agents
|
@@ -214,13 +340,29 @@ class Jobs(Base):
|
|
214
340
|
|
215
341
|
def replace_missing_objects(self) -> None:
|
216
342
|
from edsl.agents.Agent import Agent
|
217
|
-
from edsl.language_models.
|
343
|
+
from edsl.language_models.model import Model
|
218
344
|
from edsl.scenarios.Scenario import Scenario
|
219
345
|
|
220
346
|
self.agents = self.agents or [Agent()]
|
221
347
|
self.models = self.models or [Model()]
|
222
348
|
self.scenarios = self.scenarios or [Scenario()]
|
223
349
|
|
350
|
+
def generate_interviews(self) -> Generator[Interview, None, None]:
    """
    Yield the interviews for this job one at a time.

    Side effect: ``replace_missing_objects`` is called first, so a job
    created without agents, models, or scenarios gets defaults filled in
    and can still run.

    The interviews are built by ``InterviewsConstructor`` using the cache
    held in the job's run environment.
    """
    from edsl.jobs.InterviewsConstructor import InterviewsConstructor

    self.replace_missing_objects()
    constructor = InterviewsConstructor(
        self, cache=self.run_config.environment.cache
    )
    yield from constructor.create_interviews()
|
365
|
+
|
224
366
|
def interviews(self) -> list[Interview]:
|
225
367
|
"""
|
226
368
|
Return a list of :class:`edsl.jobs.interviews.Interview` objects.
|
@@ -235,18 +377,10 @@ class Jobs(Base):
|
|
235
377
|
>>> j.interviews()[0]
|
236
378
|
Interview(agent = Agent(traits = {'status': 'Joyful'}), survey = Survey(...), scenario = Scenario({'period': 'morning'}), model = Model(...))
|
237
379
|
"""
|
238
|
-
|
239
|
-
return self._interviews
|
240
|
-
else:
|
241
|
-
self.replace_missing_objects()
|
242
|
-
from edsl.jobs.InterviewsConstructor import InterviewsConstructor
|
243
|
-
|
244
|
-
self._interviews = list(InterviewsConstructor(self).create_interviews())
|
245
|
-
|
246
|
-
return self._interviews
|
380
|
+
return list(self.generate_interviews())
|
247
381
|
|
248
382
|
@classmethod
|
249
|
-
def from_interviews(cls, interview_list):
|
383
|
+
def from_interviews(cls, interview_list) -> "Jobs":
|
250
384
|
"""Return a Jobs instance from a list of interviews.
|
251
385
|
|
252
386
|
This is useful when you have, say, a list of failed interviews and you want to create
|
@@ -273,16 +407,8 @@ class Jobs(Base):
|
|
273
407
|
>>> bc
|
274
408
|
BucketCollection(...)
|
275
409
|
"""
|
276
|
-
self.replace_missing_objects() # ensure that all objects are present
|
277
410
|
return BucketCollection.from_models(self.models)
|
278
411
|
|
279
|
-
@property
|
280
|
-
def bucket_collection(self) -> BucketCollection:
|
281
|
-
"""Return the bucket collection. If it does not exist, create it."""
|
282
|
-
if self.__bucket_collection is None:
|
283
|
-
self.__bucket_collection = self.create_bucket_collection()
|
284
|
-
return self.__bucket_collection
|
285
|
-
|
286
412
|
def html(self):
|
287
413
|
"""Return the HTML representations for each scenario"""
|
288
414
|
links = []
|
@@ -308,10 +434,12 @@ class Jobs(Base):
|
|
308
434
|
|
309
435
|
def _output(self, message) -> None:
|
310
436
|
"""Check if a Job is verbose. If so, print the message."""
|
311
|
-
if
|
437
|
+
if self.run_config.parameters.verbose:
|
312
438
|
print(message)
|
439
|
+
# if hasattr(self, "verbose") and self.verbose:
|
440
|
+
# print(message)
|
313
441
|
|
314
|
-
def all_question_parameters(self):
|
442
|
+
def all_question_parameters(self) -> set:
|
315
443
|
"""Return all the fields in the questions in the survey.
|
316
444
|
>>> from edsl.jobs import Jobs
|
317
445
|
>>> Jobs.example().all_question_parameters()
|
@@ -319,86 +447,12 @@ class Jobs(Base):
|
|
319
447
|
"""
|
320
448
|
return set.union(*[question.parameters for question in self.survey.questions])
|
321
449
|
|
322
|
-
def
|
323
|
-
"""Check if the parameters in the survey and scenarios are consistent.
|
324
|
-
|
325
|
-
>>> from edsl.questions.QuestionFreeText import QuestionFreeText
|
326
|
-
>>> from edsl.surveys.Survey import Survey
|
327
|
-
>>> from edsl.scenarios.Scenario import Scenario
|
328
|
-
>>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
|
329
|
-
>>> j = Jobs(survey = Survey(questions=[q]))
|
330
|
-
>>> with warnings.catch_warnings(record=True) as w:
|
331
|
-
... j._check_parameters(warn = True)
|
332
|
-
... assert len(w) == 1
|
333
|
-
... assert issubclass(w[-1].category, UserWarning)
|
334
|
-
... assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
|
335
|
-
|
336
|
-
>>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
|
337
|
-
>>> s = Scenario({'plop': "A", 'poo': "B"})
|
338
|
-
>>> j = Jobs(survey = Survey(questions=[q])).by(s)
|
339
|
-
>>> j._check_parameters(strict = True)
|
340
|
-
Traceback (most recent call last):
|
341
|
-
...
|
342
|
-
ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
|
343
|
-
|
344
|
-
>>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
|
345
|
-
>>> s = Scenario({'ugly_question': "B"})
|
346
|
-
>>> j = Jobs(survey = Survey(questions=[q])).by(s)
|
347
|
-
>>> j._check_parameters()
|
348
|
-
Traceback (most recent call last):
|
349
|
-
...
|
350
|
-
ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
|
351
|
-
"""
|
352
|
-
survey_parameters: set = self.survey.parameters
|
353
|
-
scenario_parameters: set = self.scenarios.parameters
|
354
|
-
|
355
|
-
msg0, msg1, msg2 = None, None, None
|
356
|
-
|
357
|
-
# look for key issues
|
358
|
-
if intersection := set(self.scenarios.parameters) & set(
|
359
|
-
self.survey.question_names
|
360
|
-
):
|
361
|
-
msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
|
362
|
-
|
363
|
-
raise ValueError(msg0)
|
364
|
-
|
365
|
-
if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
|
366
|
-
msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
|
367
|
-
if in_scenarios_but_not_in_survey := scenario_parameters - survey_parameters:
|
368
|
-
msg2 = f"The following parameters are in the scenarios but not in the survey: {in_scenarios_but_not_in_survey}"
|
369
|
-
|
370
|
-
if msg1 or msg2:
|
371
|
-
message = "\n".join(filter(None, [msg1, msg2]))
|
372
|
-
if strict:
|
373
|
-
raise ValueError(message)
|
374
|
-
else:
|
375
|
-
if warn:
|
376
|
-
warnings.warn(message)
|
377
|
-
|
378
|
-
if self.scenarios.has_jinja_braces:
|
379
|
-
warnings.warn(
|
380
|
-
"The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
|
381
|
-
)
|
382
|
-
self.scenarios = self.scenarios._convert_jinja_braces()
|
383
|
-
|
384
|
-
@property
|
385
|
-
def skip_retry(self):
|
386
|
-
if not hasattr(self, "_skip_retry"):
|
387
|
-
return False
|
388
|
-
return self._skip_retry
|
389
|
-
|
390
|
-
@property
|
391
|
-
def raise_validation_errors(self):
|
392
|
-
if not hasattr(self, "_raise_validation_errors"):
|
393
|
-
return False
|
394
|
-
return self._raise_validation_errors
|
395
|
-
|
396
|
-
def use_remote_cache(self, disable_remote_cache: bool) -> bool:
|
450
|
+
def use_remote_cache(self) -> bool:
|
397
451
|
import requests
|
398
452
|
|
399
|
-
if disable_remote_cache:
|
453
|
+
if self.run_config.parameters.disable_remote_cache:
|
400
454
|
return False
|
401
|
-
if not disable_remote_cache:
|
455
|
+
if not self.run_config.parameters.disable_remote_cache:
|
402
456
|
try:
|
403
457
|
from edsl.coop.coop import Coop
|
404
458
|
|
@@ -411,154 +465,173 @@ class Jobs(Base):
|
|
411
465
|
|
412
466
|
return False
|
413
467
|
|
414
|
-
def
|
468
|
+
def _remote_results(
|
415
469
|
self,
|
416
|
-
|
417
|
-
progress_bar: bool = False,
|
418
|
-
stop_on_exception: bool = False,
|
419
|
-
cache: Union["Cache", bool] = None,
|
420
|
-
check_api_keys: bool = False,
|
421
|
-
sidecar_model: Optional[LanguageModel] = None,
|
422
|
-
verbose: bool = True,
|
423
|
-
print_exceptions=True,
|
424
|
-
remote_cache_description: Optional[str] = None,
|
425
|
-
remote_inference_description: Optional[str] = None,
|
426
|
-
remote_inference_results_visibility: Optional[
|
427
|
-
Literal["private", "public", "unlisted"]
|
428
|
-
] = "unlisted",
|
429
|
-
skip_retry: bool = False,
|
430
|
-
raise_validation_errors: bool = False,
|
431
|
-
disable_remote_cache: bool = False,
|
432
|
-
disable_remote_inference: bool = False,
|
433
|
-
bucket_collection: Optional[BucketCollection] = None,
|
434
|
-
) -> Results:
|
435
|
-
"""
|
436
|
-
Runs the Job: conducts Interviews and returns their results.
|
437
|
-
|
438
|
-
:param n: How many times to run each interview
|
439
|
-
:param progress_bar: Whether to show a progress bar
|
440
|
-
:param stop_on_exception: Stops the job if an exception is raised
|
441
|
-
:param cache: A Cache object to store results
|
442
|
-
:param check_api_keys: Raises an error if API keys are invalid
|
443
|
-
:param verbose: Prints extra messages
|
444
|
-
:param remote_cache_description: Specifies a description for this group of entries in the remote cache
|
445
|
-
:param remote_inference_description: Specifies a description for the remote inference job
|
446
|
-
:param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
|
447
|
-
:param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
|
448
|
-
:param disable_remote_inference: If True, the job will not use remote inference
|
449
|
-
"""
|
450
|
-
from edsl.coop.coop import Coop
|
451
|
-
from edsl.jobs.JobsChecks import JobsChecks
|
470
|
+
) -> Union["Results", None]:
|
452
471
|
from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
|
453
472
|
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
self.
|
473
|
+
jh = JobsRemoteInferenceHandler(
|
474
|
+
self, verbose=self.run_config.parameters.verbose
|
475
|
+
)
|
476
|
+
if jh.use_remote_inference(self.run_config.parameters.disable_remote_inference):
|
477
|
+
job_info = jh.create_remote_inference_job(
|
478
|
+
iterations=self.run_config.parameters.n,
|
479
|
+
remote_inference_description=self.run_config.parameters.remote_inference_description,
|
480
|
+
remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
|
481
|
+
)
|
482
|
+
results = jh.poll_remote_inference_job(job_info)
|
483
|
+
return results
|
484
|
+
else:
|
485
|
+
return None
|
486
|
+
|
487
|
+
def _prepare_to_run(self) -> None:
    """Check that the survey and the scenarios are compatible before a run."""
    compatibility = CheckSurveyScenarioCompatibility(self.survey, self.scenarios)
    compatibility.check()
|
458
490
|
|
491
|
+
def _check_if_remote_keys_ok(self):
    """Run the JobsChecks key process when the checker reports it is needed."""
    checker = JobsChecks(self)
    if not checker.needs_key_process():
        return
    checker.key_process()
|
464
495
|
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
iterations=n,
|
469
|
-
remote_inference_description=remote_inference_description,
|
470
|
-
remote_inference_results_visibility=remote_inference_results_visibility,
|
471
|
-
)
|
472
|
-
results = jh.poll_remote_inference_job()
|
473
|
-
return results
|
474
|
-
|
475
|
-
if check_api_keys:
|
496
|
+
def _check_if_local_keys_ok(self):
    """Check API keys through JobsChecks when ``check_api_keys`` is enabled."""
    checker = JobsChecks(self)
    if not self.run_config.parameters.check_api_keys:
        return
    checker.check_api_keys()
|
477
500
|
|
478
|
-
|
479
|
-
if cache is None or cache is True:
|
480
|
-
from edsl.data.CacheHandler import CacheHandler
|
501
|
+
async def _execute_with_remote_cache(self, run_job_async: bool) -> Results:
|
481
502
|
|
482
|
-
|
483
|
-
if cache is False:
|
484
|
-
from edsl.data.Cache import Cache
|
503
|
+
use_remote_cache = self.use_remote_cache()
|
485
504
|
|
486
|
-
|
505
|
+
from edsl.coop.coop import Coop
|
506
|
+
from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
|
507
|
+
from edsl.data.Cache import Cache
|
487
508
|
|
488
|
-
|
489
|
-
bucket_collection = self.create_bucket_collection()
|
509
|
+
assert isinstance(self.run_config.environment.cache, Cache)
|
490
510
|
|
491
|
-
remote_cache = self.use_remote_cache(disable_remote_cache)
|
492
511
|
with RemoteCacheSync(
|
493
512
|
coop=Coop(),
|
494
|
-
cache=cache,
|
513
|
+
cache=self.run_config.environment.cache,
|
495
514
|
output_func=self._output,
|
496
|
-
remote_cache=
|
497
|
-
remote_cache_description=remote_cache_description,
|
498
|
-
)
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
sidecar_model=sidecar_model,
|
505
|
-
print_exceptions=print_exceptions,
|
506
|
-
raise_validation_errors=raise_validation_errors,
|
507
|
-
bucket_collection=bucket_collection,
|
508
|
-
)
|
515
|
+
remote_cache=use_remote_cache,
|
516
|
+
remote_cache_description=self.run_config.parameters.remote_cache_description,
|
517
|
+
):
|
518
|
+
runner = JobsRunnerAsyncio(self, environment=self.run_config.environment)
|
519
|
+
if run_job_async:
|
520
|
+
results = await runner.run_async(self.run_config.parameters)
|
521
|
+
else:
|
522
|
+
results = runner.run(self.run_config.parameters)
|
509
523
|
return results
|
510
524
|
|
511
|
-
|
512
|
-
self,
|
513
|
-
cache=None,
|
514
|
-
n=1,
|
515
|
-
disable_remote_inference: bool = False,
|
516
|
-
remote_inference_description: Optional[str] = None,
|
517
|
-
remote_inference_results_visibility: Optional[
|
518
|
-
Literal["private", "public", "unlisted"]
|
519
|
-
] = "unlisted",
|
520
|
-
bucket_collection: Optional[BucketCollection] = None,
|
521
|
-
**kwargs,
|
522
|
-
):
|
523
|
-
"""Run the job asynchronously, either locally or remotely.
|
525
|
+
def _setup_and_check(self) -> Tuple[RunConfig, Optional[Results]]:
|
524
526
|
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
527
|
+
self._prepare_to_run()
|
528
|
+
self._check_if_remote_keys_ok()
|
529
|
+
|
530
|
+
# first try to run the job remotely
|
531
|
+
if results := self._remote_results():
|
532
|
+
return results
|
533
|
+
|
534
|
+
self._check_if_local_keys_ok()
|
535
|
+
return None
|
536
|
+
|
537
|
+
@property
def num_interviews(self) -> int:
    """Return ``len(self)``, multiplied by the run parameter ``n`` when it is set."""
    iterations = self.run_config.parameters.n
    if iterations is None:
        return len(self)
    # The original computed this product but never returned it.
    return len(self) * iterations
|
536
543
|
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
544
|
+
def _run(self, config: RunConfig):
|
545
|
+
"Shared code for run and run_async"
|
546
|
+
if config.environment.cache is not None:
|
547
|
+
self.run_config.environment.cache = config.environment.cache
|
548
|
+
|
549
|
+
if config.environment.bucket_collection is not None:
|
550
|
+
self.run_config.environment.bucket_collection = (
|
551
|
+
config.environment.bucket_collection
|
543
552
|
)
|
553
|
+
|
554
|
+
if config.environment.key_lookup is not None:
|
555
|
+
self.run_config.environment.key_lookup = config.environment.key_lookup
|
556
|
+
|
557
|
+
# replace the parameters with the ones from the config
|
558
|
+
self.run_config.parameters = config.parameters
|
559
|
+
|
560
|
+
self.replace_missing_objects()
|
561
|
+
|
562
|
+
# try to run remotely first
|
563
|
+
self._prepare_to_run()
|
564
|
+
self._check_if_remote_keys_ok()
|
565
|
+
|
566
|
+
if (
|
567
|
+
self.run_config.environment.cache is None
|
568
|
+
or self.run_config.environment.cache is True
|
569
|
+
):
|
570
|
+
from edsl.data.CacheHandler import CacheHandler
|
571
|
+
|
572
|
+
self.run_config.environment.cache = CacheHandler().get_cache()
|
573
|
+
|
574
|
+
if self.run_config.environment.cache is False:
|
575
|
+
from edsl.data.Cache import Cache
|
576
|
+
|
577
|
+
self.run_config.environment.cache = Cache(immediate_write=False)
|
578
|
+
|
579
|
+
# first try to run the job remotely
|
580
|
+
if results := self._remote_results():
|
544
581
|
return results
|
545
582
|
|
546
|
-
|
547
|
-
bucket_collection = self.create_bucket_collection()
|
583
|
+
self._check_if_local_keys_ok()
|
548
584
|
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
585
|
+
if config.environment.bucket_collection is None:
|
586
|
+
self.run_config.environment.bucket_collection = (
|
587
|
+
self.create_bucket_collection()
|
588
|
+
)
|
553
589
|
|
554
|
-
|
555
|
-
|
556
|
-
|
590
|
+
@with_config
def run(self, *, config: RunConfig) -> "Results":
    """
    Runs the Job: conducts Interviews and returns their results.

    :param n: How many times to run each interview
    :param progress_bar: Whether to show a progress bar
    :param stop_on_exception: Stops the job if an exception is raised
    :param check_api_keys: Raises an error if API keys are invalid
    :param verbose: Prints extra messages
    :param remote_cache_description: Specifies a description for this group of entries in the remote cache
    :param remote_inference_description: Specifies a description for the remote inference job
    :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
    :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
    :param disable_remote_inference: If True, the job will not use remote inference
    :param cache: A Cache object to store results
    :param bucket_collection: A BucketCollection object to track API calls
    :param key_lookup: A KeyLookup object to manage API keys
    :return: the remote results when ``_run`` returns them, otherwise the results of the local run
    """
    # _run returns the results when the job completed remotely. They were
    # previously discarded and the job was then executed locally as well.
    remote_results = self._run(config)
    if remote_results is not None:
        return remote_results

    return asyncio.run(self._execute_with_remote_cache(run_job_async=False))
|
612
|
+
|
613
|
+
@with_config
async def run_async(self, *, config: RunConfig) -> "Results":
    """
    Runs the Job asynchronously: conducts Interviews and returns their results.

    :param n: How many times to run each interview
    :param progress_bar: Whether to show a progress bar
    :param stop_on_exception: Stops the job if an exception is raised
    :param check_api_keys: Raises an error if API keys are invalid
    :param verbose: Prints extra messages
    :param remote_cache_description: Specifies a description for this group of entries in the remote cache
    :param remote_inference_description: Specifies a description for the remote inference job
    :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
    :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
    :param disable_remote_inference: If True, the job will not use remote inference
    :param cache: A Cache object to store results
    :param bucket_collection: A BucketCollection object to track API calls
    :param key_lookup: A KeyLookup object to manage API keys
    :return: the remote results when ``_run`` returns them, otherwise the results of the local run
    """
    # _run returns the results when the job completed remotely. They were
    # previously discarded and the job was then executed locally as well.
    remote_results = self._run(config)
    if remote_results is not None:
        return remote_results

    return await self._execute_with_remote_cache(run_job_async=True)
|
562
635
|
|
563
636
|
def __repr__(self) -> str:
|
564
637
|
"""Return an eval-able string representation of the Jobs instance."""
|
@@ -588,10 +661,6 @@ class Jobs(Base):
|
|
588
661
|
)
|
589
662
|
return number_of_questions
|
590
663
|
|
591
|
-
#######################
|
592
|
-
# Serialization methods
|
593
|
-
#######################
|
594
|
-
|
595
664
|
def to_dict(self, add_edsl_version=True):
|
596
665
|
d = {
|
597
666
|
"survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
|
@@ -645,9 +714,6 @@ class Jobs(Base):
|
|
645
714
|
"""
|
646
715
|
return hash(self) == hash(other)
|
647
716
|
|
648
|
-
#######################
|
649
|
-
# Example methods
|
650
|
-
#######################
|
651
717
|
@classmethod
|
652
718
|
def example(
|
653
719
|
cls,
|
edsl/jobs/JobsChecks.py
CHANGED
@@ -8,7 +8,7 @@ class JobsChecks:
|
|
8
8
|
self.jobs = jobs
|
9
9
|
|
10
10
|
def check_api_keys(self) -> None:
|
11
|
-
from edsl.language_models.
|
11
|
+
from edsl.language_models.model import Model
|
12
12
|
|
13
13
|
if len(self.jobs.models) == 0:
|
14
14
|
models = [Model()]
|
@@ -28,7 +28,7 @@ class JobsChecks:
|
|
28
28
|
"""
|
29
29
|
missing_api_keys = set()
|
30
30
|
|
31
|
-
from edsl.language_models.
|
31
|
+
from edsl.language_models.model import Model
|
32
32
|
from edsl.enums import service_to_api_keyname
|
33
33
|
|
34
34
|
for model in self.jobs.models + [Model()]:
|