edsl 0.1.39__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl

This diff compares the contents of two package versions as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
Files changed (85)
  1. edsl/Base.py +0 -28
  2. edsl/__init__.py +1 -1
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +17 -9
  5. edsl/agents/Invigilator.py +14 -13
  6. edsl/agents/InvigilatorBase.py +1 -4
  7. edsl/agents/PromptConstructor.py +22 -42
  8. edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
  9. edsl/auto/AutoStudy.py +5 -18
  10. edsl/auto/StageBase.py +40 -53
  11. edsl/auto/StageQuestions.py +1 -2
  12. edsl/auto/utilities.py +6 -0
  13. edsl/coop/coop.py +5 -21
  14. edsl/data/Cache.py +18 -29
  15. edsl/data/CacheHandler.py +2 -0
  16. edsl/data/RemoteCacheSync.py +46 -154
  17. edsl/enums.py +0 -7
  18. edsl/inference_services/AnthropicService.py +16 -38
  19. edsl/inference_services/AvailableModelFetcher.py +1 -7
  20. edsl/inference_services/GoogleService.py +1 -5
  21. edsl/inference_services/InferenceServicesCollection.py +2 -18
  22. edsl/inference_services/OpenAIService.py +31 -46
  23. edsl/inference_services/TestService.py +3 -1
  24. edsl/inference_services/TogetherAIService.py +3 -5
  25. edsl/inference_services/data_structures.py +2 -74
  26. edsl/jobs/AnswerQuestionFunctionConstructor.py +113 -148
  27. edsl/jobs/FetchInvigilator.py +3 -10
  28. edsl/jobs/InterviewsConstructor.py +4 -6
  29. edsl/jobs/Jobs.py +233 -299
  30. edsl/jobs/JobsChecks.py +2 -2
  31. edsl/jobs/JobsPrompts.py +1 -1
  32. edsl/jobs/JobsRemoteInferenceHandler.py +136 -160
  33. edsl/jobs/interviews/Interview.py +42 -80
  34. edsl/jobs/runners/JobsRunnerAsyncio.py +358 -88
  35. edsl/jobs/runners/JobsRunnerStatus.py +165 -133
  36. edsl/jobs/tasks/TaskHistory.py +3 -24
  37. edsl/language_models/LanguageModel.py +4 -59
  38. edsl/language_models/ModelList.py +8 -19
  39. edsl/language_models/__init__.py +1 -1
  40. edsl/language_models/registry.py +180 -0
  41. edsl/language_models/repair.py +1 -1
  42. edsl/questions/QuestionBase.py +26 -35
  43. edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +49 -52
  44. edsl/questions/QuestionBasePromptsMixin.py +1 -1
  45. edsl/questions/QuestionBudget.py +1 -1
  46. edsl/questions/QuestionCheckBox.py +2 -2
  47. edsl/questions/QuestionExtract.py +7 -5
  48. edsl/questions/QuestionFreeText.py +1 -1
  49. edsl/questions/QuestionList.py +15 -9
  50. edsl/questions/QuestionMatrix.py +1 -1
  51. edsl/questions/QuestionMultipleChoice.py +1 -1
  52. edsl/questions/QuestionNumerical.py +1 -1
  53. edsl/questions/QuestionRank.py +1 -1
  54. edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +18 -6
  55. edsl/questions/{response_validator_factory.py → ResponseValidatorFactory.py} +1 -7
  56. edsl/questions/SimpleAskMixin.py +1 -1
  57. edsl/questions/__init__.py +1 -1
  58. edsl/results/DatasetExportMixin.py +119 -60
  59. edsl/results/Result.py +3 -109
  60. edsl/results/Results.py +39 -50
  61. edsl/scenarios/FileStore.py +0 -32
  62. edsl/scenarios/ScenarioList.py +7 -35
  63. edsl/scenarios/handlers/csv.py +0 -11
  64. edsl/surveys/Survey.py +20 -71
  65. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +1 -1
  66. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/RECORD +78 -84
  67. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +1 -1
  68. edsl/jobs/async_interview_runner.py +0 -138
  69. edsl/jobs/check_survey_scenario_compatibility.py +0 -85
  70. edsl/jobs/data_structures.py +0 -120
  71. edsl/jobs/results_exceptions_handler.py +0 -98
  72. edsl/language_models/model.py +0 -256
  73. edsl/questions/data_structures.py +0 -20
  74. edsl/results/file_exports.py +0 -252
  75. /edsl/agents/{question_option_processor.py → QuestionOptionProcessor.py} +0 -0
  76. /edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +0 -0
  77. /edsl/questions/{loop_processor.py → LoopProcessor.py} +0 -0
  78. /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
  79. /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
  80. /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
  81. /edsl/results/{results_selector.py → Selector.py} +0 -0
  82. /edsl/scenarios/{directory_scanner.py → DirectoryScanner.py} +0 -0
  83. /edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +0 -0
  84. /edsl/scenarios/{scenario_selector.py → ScenarioSelector.py} +0 -0
  85. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
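The largest change below is to Jobs.run(): in 0.1.39 the method takes a single keyword-only config argument that the with_config decorator assembles from RunParameters and RunEnvironment fields, while in 0.1.39.dev2 the same options are explicit keyword parameters of run() itself. The sketch below is illustrative only — it is not part of either package, and the argument values are assumptions — but the keywords shown appear in the old run() docstring and in the new run() signature, so a call shaped like this is accepted by both versions.

# Minimal sketch of a run() call site, assuming local execution is possible
# (e.g. a test model or valid API keys). Not taken from the package itself.
from edsl.jobs import Jobs

job = Jobs.example()
results = job.run(
    n=1,                            # iterations per interview
    progress_bar=False,             # whether to show a progress bar
    disable_remote_inference=True,  # skip remote inference, run locally
    disable_remote_cache=True,      # skip the remote cache sync
)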
edsl/jobs/Jobs.py CHANGED
@@ -1,17 +1,7 @@
 # """The Jobs class is a collection of agents, scenarios and models and one survey."""
 from __future__ import annotations
-import asyncio
-from inspect import signature
-from typing import (
-    Literal,
-    Optional,
-    Union,
-    Sequence,
-    Generator,
-    TYPE_CHECKING,
-    Callable,
-    Tuple,
-)
+import warnings
+from typing import Literal, Optional, Union, Sequence, Generator, TYPE_CHECKING
 
 from edsl.Base import Base
 
@@ -19,13 +9,10 @@ from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.jobs.JobsPrompts import JobsPrompts
 from edsl.jobs.interviews.Interview import Interview
 from edsl.utilities.remove_edsl_version import remove_edsl_version
-from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
+
 from edsl.data.RemoteCacheSync import RemoteCacheSync
 from edsl.exceptions.coop import CoopServerResponseError
 
-from edsl.jobs.JobsChecks import JobsChecks
-from edsl.jobs.data_structures import RunEnvironment, RunParameters, RunConfig
-
 if TYPE_CHECKING:
     from edsl.agents.Agent import Agent
     from edsl.agents.AgentList import AgentList
@@ -36,66 +23,6 @@ if TYPE_CHECKING:
     from edsl.results.Results import Results
     from edsl.results.Dataset import Dataset
     from edsl.language_models.ModelList import ModelList
-    from edsl.data.Cache import Cache
-    from edsl.language_models.key_management.KeyLookup import KeyLookup
-
-VisibilityType = Literal["private", "public", "unlisted"]
-
-from dataclasses import dataclass
-from typing import Optional, Union, TypeVar, Callable, cast
-from functools import wraps
-
-try:
-    from typing import ParamSpec
-except ImportError:
-    from typing_extensions import ParamSpec
-
-
-P = ParamSpec("P")
-T = TypeVar("T")
-
-
-from edsl.jobs.check_survey_scenario_compatibility import (
-    CheckSurveyScenarioCompatibility,
-)
-
-
-def with_config(f: Callable[P, T]) -> Callable[P, T]:
-    "This decorator make it so that the run function parameters match the RunConfig dataclass."
-    parameter_fields = {
-        name: field.default
-        for name, field in RunParameters.__dataclass_fields__.items()
-    }
-    environment_fields = {
-        name: field.default
-        for name, field in RunEnvironment.__dataclass_fields__.items()
-    }
-    combined = {**parameter_fields, **environment_fields}
-
-    @wraps(f)
-    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
-        environment = RunEnvironment(
-            **{k: v for k, v in kwargs.items() if k in environment_fields}
-        )
-        parameters = RunParameters(
-            **{k: v for k, v in kwargs.items() if k in parameter_fields}
-        )
-        config = RunConfig(environment=environment, parameters=parameters)
-        return f(*args, config=config)
-
-    # Update the wrapper's signature to include all RunConfig parameters
-    # old_sig = signature(f)
-    # wrapper.__signature__ = old_sig.replace(
-    #     parameters=list(old_sig.parameters.values())[:-1]
-    #     + [
-    #         old_sig.parameters["config"].replace(
-    #             default=parameter_fields[name], name=name
-    #         )
-    #         for name in combined
-    #     ]
-    # )
-
-    return cast(Callable[P, T], wrapper)
 
 
 class Jobs(Base):
@@ -119,62 +46,15 @@ class Jobs(Base):
         :param models: a list of models
         :param scenarios: a list of scenarios
         """
-        self.run_config = RunConfig(
-            environment=RunEnvironment(), parameters=RunParameters()
-        )
-
         self.survey = survey
         self.agents: AgentList = agents
         self.scenarios: ScenarioList = scenarios
-        self.models: ModelList = models
-
-    def add_running_env(self, running_env: RunEnvironment):
-        self.run_config.add_environment(running_env)
-        return self
+        self.models = models
 
-    def using_cache(self, cache: "Cache") -> Jobs:
-        """
-        Add a Cache to the job.
-
-        :param cache: the cache to add
-        """
-        self.run_config.add_cache(cache)
-        return self
-
-    def using_bucket_collection(self, bucket_collection: BucketCollection) -> Jobs:
-        """
-        Add a BucketCollection to the job.
-
-        :param bucket_collection: the bucket collection to add
-        """
-        self.run_config.add_bucket_collection(bucket_collection)
-        return self
-
-    def using_key_lookup(self, key_lookup: KeyLookup) -> Jobs:
-        """
-        Add a KeyLookup to the job.
-
-        :param key_lookup: the key lookup to add
-        """
-        self.run_config.add_key_lookup(key_lookup)
-        return self
-
-    def using(self, obj: Union[Cache, BucketCollection, KeyLookup]) -> Jobs:
-        """
-        Add a Cache, BucketCollection, or KeyLookup to the job.
-
-        :param obj: the object to add
-        """
-        from edsl.data.Cache import Cache
-        from edsl.language_models.key_management.KeyLookup import KeyLookup
+        self.__bucket_collection = None
 
-        if isinstance(obj, Cache):
-            self.using_cache(obj)
-        elif isinstance(obj, BucketCollection):
-            self.using_bucket_collection(obj)
-        elif isinstance(obj, KeyLookup):
-            self.using_key_lookup(obj)
-        return self
+    # these setters and getters are used to ensure that the agents, models, and scenarios
+    # are stored as AgentList, ModelList, and ScenarioList objects.
 
     @property
     def models(self):
@@ -192,12 +72,6 @@ class Jobs(Base):
         else:
             self._models = ModelList([])
 
-        # update the bucket collection if it exists
-        if self.run_config.environment.bucket_collection is None:
-            self.run_config.environment.bucket_collection = (
-                self.create_bucket_collection()
-            )
-
     @property
     def agents(self):
         return self._agents
@@ -340,29 +214,13 @@
 
     def replace_missing_objects(self) -> None:
         from edsl.agents.Agent import Agent
-        from edsl.language_models.model import Model
+        from edsl.language_models.registry import Model
        from edsl.scenarios.Scenario import Scenario
 
         self.agents = self.agents or [Agent()]
         self.models = self.models or [Model()]
         self.scenarios = self.scenarios or [Scenario()]
 
-    def generate_interviews(self) -> Generator[Interview, None, None]:
-        """
-        Generate interviews.
-
-        Note that this sets the agents, model and scenarios if they have not been set. This is a side effect of the method.
-        This is useful because a user can create a job without setting the agents, models, or scenarios, and the job will still run,
-        with us filling in defaults.
-
-        """
-        from edsl.jobs.InterviewsConstructor import InterviewsConstructor
-
-        self.replace_missing_objects()
-        yield from InterviewsConstructor(
-            self, cache=self.run_config.environment.cache
-        ).create_interviews()
-
     def interviews(self) -> list[Interview]:
         """
         Return a list of :class:`edsl.jobs.interviews.Interview` objects.
@@ -377,10 +235,18 @@
         >>> j.interviews()[0]
         Interview(agent = Agent(traits = {'status': 'Joyful'}), survey = Survey(...), scenario = Scenario({'period': 'morning'}), model = Model(...))
         """
-        return list(self.generate_interviews())
+        if hasattr(self, "_interviews"):
+            return self._interviews
+        else:
+            self.replace_missing_objects()
+            from edsl.jobs.InterviewsConstructor import InterviewsConstructor
+
+            self._interviews = list(InterviewsConstructor(self).create_interviews())
+
+            return self._interviews
 
     @classmethod
-    def from_interviews(cls, interview_list) -> "Jobs":
+    def from_interviews(cls, interview_list):
         """Return a Jobs instance from a list of interviews.
 
         This is useful when you have, say, a list of failed interviews and you want to create
@@ -407,8 +273,16 @@
         >>> bc
         BucketCollection(...)
         """
+        self.replace_missing_objects()  # ensure that all objects are present
         return BucketCollection.from_models(self.models)
 
+    @property
+    def bucket_collection(self) -> BucketCollection:
+        """Return the bucket collection. If it does not exist, create it."""
+        if self.__bucket_collection is None:
+            self.__bucket_collection = self.create_bucket_collection()
+        return self.__bucket_collection
+
     def html(self):
         """Return the HTML representations for each scenario"""
         links = []
@@ -434,12 +308,10 @@
 
     def _output(self, message) -> None:
         """Check if a Job is verbose. If so, print the message."""
-        if self.run_config.parameters.verbose:
+        if hasattr(self, "verbose") and self.verbose:
             print(message)
-        # if hasattr(self, "verbose") and self.verbose:
-        #     print(message)
 
-    def all_question_parameters(self) -> set:
+    def all_question_parameters(self):
         """Return all the fields in the questions in the survey.
         >>> from edsl.jobs import Jobs
         >>> Jobs.example().all_question_parameters()
@@ -447,12 +319,86 @@
         """
         return set.union(*[question.parameters for question in self.survey.questions])
 
-    def use_remote_cache(self) -> bool:
+    def _check_parameters(self, strict=False, warn=False) -> None:
+        """Check if the parameters in the survey and scenarios are consistent.
+
+        >>> from edsl.questions.QuestionFreeText import QuestionFreeText
+        >>> from edsl.surveys.Survey import Survey
+        >>> from edsl.scenarios.Scenario import Scenario
+        >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
+        >>> j = Jobs(survey = Survey(questions=[q]))
+        >>> with warnings.catch_warnings(record=True) as w:
+        ...     j._check_parameters(warn = True)
+        ...     assert len(w) == 1
+        ...     assert issubclass(w[-1].category, UserWarning)
+        ...     assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
+
+        >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
+        >>> s = Scenario({'plop': "A", 'poo': "B"})
+        >>> j = Jobs(survey = Survey(questions=[q])).by(s)
+        >>> j._check_parameters(strict = True)
+        Traceback (most recent call last):
+        ...
+        ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
+
+        >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
+        >>> s = Scenario({'ugly_question': "B"})
+        >>> j = Jobs(survey = Survey(questions=[q])).by(s)
+        >>> j._check_parameters()
+        Traceback (most recent call last):
+        ...
+        ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
+        """
+        survey_parameters: set = self.survey.parameters
+        scenario_parameters: set = self.scenarios.parameters
+
+        msg0, msg1, msg2 = None, None, None
+
+        # look for key issues
+        if intersection := set(self.scenarios.parameters) & set(
+            self.survey.question_names
+        ):
+            msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
+
+            raise ValueError(msg0)
+
+        if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
+            msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
+        if in_scenarios_but_not_in_survey := scenario_parameters - survey_parameters:
+            msg2 = f"The following parameters are in the scenarios but not in the survey: {in_scenarios_but_not_in_survey}"
+
+        if msg1 or msg2:
+            message = "\n".join(filter(None, [msg1, msg2]))
+            if strict:
+                raise ValueError(message)
+            else:
+                if warn:
+                    warnings.warn(message)
+
+        if self.scenarios.has_jinja_braces:
+            warnings.warn(
+                "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
+            )
+            self.scenarios = self.scenarios._convert_jinja_braces()
+
+    @property
+    def skip_retry(self):
+        if not hasattr(self, "_skip_retry"):
+            return False
+        return self._skip_retry
+
+    @property
+    def raise_validation_errors(self):
+        if not hasattr(self, "_raise_validation_errors"):
+            return False
+        return self._raise_validation_errors
+
+    def use_remote_cache(self, disable_remote_cache: bool) -> bool:
         import requests
 
-        if self.run_config.parameters.disable_remote_cache:
+        if disable_remote_cache:
             return False
-        if not self.run_config.parameters.disable_remote_cache:
+        if not disable_remote_cache:
             try:
                 from edsl.coop.coop import Coop
 
@@ -465,173 +411,154 @@
 
         return False
 
-    def _remote_results(
+    def run(
         self,
-    ) -> Union["Results", None]:
-        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+        n: int = 1,
+        progress_bar: bool = False,
+        stop_on_exception: bool = False,
+        cache: Union["Cache", bool] = None,
+        check_api_keys: bool = False,
+        sidecar_model: Optional[LanguageModel] = None,
+        verbose: bool = True,
+        print_exceptions=True,
+        remote_cache_description: Optional[str] = None,
+        remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional[
+            Literal["private", "public", "unlisted"]
+        ] = "unlisted",
+        skip_retry: bool = False,
+        raise_validation_errors: bool = False,
+        disable_remote_cache: bool = False,
+        disable_remote_inference: bool = False,
+        bucket_collection: Optional[BucketCollection] = None,
+    ) -> Results:
+        """
+        Runs the Job: conducts Interviews and returns their results.
 
-        jh = JobsRemoteInferenceHandler(
-            self, verbose=self.run_config.parameters.verbose
-        )
-        if jh.use_remote_inference(self.run_config.parameters.disable_remote_inference):
-            job_info = jh.create_remote_inference_job(
-                iterations=self.run_config.parameters.n,
-                remote_inference_description=self.run_config.parameters.remote_inference_description,
-                remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
-            )
-            results = jh.poll_remote_inference_job(job_info)
-            return results
-        else:
-            return None
+        :param n: How many times to run each interview
+        :param progress_bar: Whether to show a progress bar
+        :param stop_on_exception: Stops the job if an exception is raised
+        :param cache: A Cache object to store results
+        :param check_api_keys: Raises an error if API keys are invalid
+        :param verbose: Prints extra messages
+        :param remote_cache_description: Specifies a description for this group of entries in the remote cache
+        :param remote_inference_description: Specifies a description for the remote inference job
+        :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
+        :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
+        :param disable_remote_inference: If True, the job will not use remote inference
+        """
+        from edsl.coop.coop import Coop
+        from edsl.jobs.JobsChecks import JobsChecks
+        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
 
-    def _prepare_to_run(self) -> None:
-        "This makes sure that the job is ready to run and that keys are in place for a remote job."
-        CheckSurveyScenarioCompatibility(self.survey, self.scenarios).check()
+        self._check_parameters()
+        self._skip_retry = skip_retry
+        self._raise_validation_errors = raise_validation_errors
+        self.verbose = verbose
 
-    def _check_if_remote_keys_ok(self):
         jc = JobsChecks(self)
+
+        # check if the user has all the keys they need
         if jc.needs_key_process():
             jc.key_process()
 
-    def _check_if_local_keys_ok(self):
-        jc = JobsChecks(self)
-        if self.run_config.parameters.check_api_keys:
+        jh = JobsRemoteInferenceHandler(self, verbose=verbose)
+        if jh.use_remote_inference(disable_remote_inference):
+            jh.create_remote_inference_job(
+                iterations=n,
+                remote_inference_description=remote_inference_description,
+                remote_inference_results_visibility=remote_inference_results_visibility,
+            )
+            results = jh.poll_remote_inference_job()
+            return results
+
+        if check_api_keys:
             jc.check_api_keys()
 
-    async def _execute_with_remote_cache(self, run_job_async: bool) -> Results:
+        # handle cache
+        if cache is None or cache is True:
+            from edsl.data.CacheHandler import CacheHandler
 
-        use_remote_cache = self.use_remote_cache()
+            cache = CacheHandler().get_cache()
+        if cache is False:
+            from edsl.data.Cache import Cache
 
-        from edsl.coop.coop import Coop
-        from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
-        from edsl.data.Cache import Cache
+            cache = Cache()
 
-        assert isinstance(self.run_config.environment.cache, Cache)
+        if bucket_collection is None:
+            bucket_collection = self.create_bucket_collection()
 
+        remote_cache = self.use_remote_cache(disable_remote_cache)
         with RemoteCacheSync(
             coop=Coop(),
-            cache=self.run_config.environment.cache,
+            cache=cache,
             output_func=self._output,
-            remote_cache=use_remote_cache,
-            remote_cache_description=self.run_config.parameters.remote_cache_description,
-        ):
-            runner = JobsRunnerAsyncio(self, environment=self.run_config.environment)
-            if run_job_async:
-                results = await runner.run_async(self.run_config.parameters)
-            else:
-                results = runner.run(self.run_config.parameters)
+            remote_cache=remote_cache,
+            remote_cache_description=remote_cache_description,
+        ) as r:
+            results = self._run_local(
+                n=n,
+                progress_bar=progress_bar,
+                cache=cache,
+                stop_on_exception=stop_on_exception,
+                sidecar_model=sidecar_model,
+                print_exceptions=print_exceptions,
+                raise_validation_errors=raise_validation_errors,
+                bucket_collection=bucket_collection,
+            )
         return results
 
-    def _setup_and_check(self) -> Tuple[RunConfig, Optional[Results]]:
-
-        self._prepare_to_run()
-        self._check_if_remote_keys_ok()
-
-        # first try to run the job remotely
-        if results := self._remote_results():
-            return results
-
-        self._check_if_local_keys_ok()
-        return None
-
-    @property
-    def num_interviews(self):
-        if self.run_config.parameters.n is None:
-            return len(self)
-        else:
-            len(self) * self.run_config.parameters.n
+    async def run_async(
+        self,
+        cache=None,
+        n=1,
+        disable_remote_inference: bool = False,
+        remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional[
+            Literal["private", "public", "unlisted"]
+        ] = "unlisted",
+        bucket_collection: Optional[BucketCollection] = None,
+        **kwargs,
+    ):
+        """Run the job asynchronously, either locally or remotely.
 
-    def _run(self, config: RunConfig):
-        "Shared code for run and run_async"
-        if config.environment.cache is not None:
-            self.run_config.environment.cache = config.environment.cache
+        :param cache: Cache object or boolean
+        :param n: Number of iterations
+        :param disable_remote_inference: If True, forces local execution
+        :param remote_inference_description: Description for remote jobs
+        :param remote_inference_results_visibility: Visibility setting for remote results
+        :param kwargs: Additional arguments passed to local execution
+        :return: Results object
+        """
+        # Check if we should use remote inference
+        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+        from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
 
-        if config.environment.bucket_collection is not None:
-            self.run_config.environment.bucket_collection = (
-                config.environment.bucket_collection
+        jh = JobsRemoteInferenceHandler(self, verbose=False)
+        if jh.use_remote_inference(disable_remote_inference):
+            results = await jh.create_and_poll_remote_job(
+                iterations=n,
+                remote_inference_description=remote_inference_description,
+                remote_inference_results_visibility=remote_inference_results_visibility,
             )
-
-        if config.environment.key_lookup is not None:
-            self.run_config.environment.key_lookup = config.environment.key_lookup
-
-        # replace the parameters with the ones from the config
-        self.run_config.parameters = config.parameters
-
-        self.replace_missing_objects()
-
-        # try to run remotely first
-        self._prepare_to_run()
-        self._check_if_remote_keys_ok()
-
-        if (
-            self.run_config.environment.cache is None
-            or self.run_config.environment.cache is True
-        ):
-            from edsl.data.CacheHandler import CacheHandler
-
-            self.run_config.environment.cache = CacheHandler().get_cache()
-
-        if self.run_config.environment.cache is False:
-            from edsl.data.Cache import Cache
-
-            self.run_config.environment.cache = Cache(immediate_write=False)
-
-        # first try to run the job remotely
-        if results := self._remote_results():
             return results
 
-        self._check_if_local_keys_ok()
-
-        if config.environment.bucket_collection is None:
-            self.run_config.environment.bucket_collection = (
-                self.create_bucket_collection()
-            )
-
-    @with_config
-    def run(self, *, config: RunConfig) -> "Results":
-        """
-        Runs the Job: conducts Interviews and returns their results.
-
-        :param n: How many times to run each interview
-        :param progress_bar: Whether to show a progress bar
-        :param stop_on_exception: Stops the job if an exception is raised
-        :param check_api_keys: Raises an error if API keys are invalid
-        :param verbose: Prints extra messages
-        :param remote_cache_description: Specifies a description for this group of entries in the remote cache
-        :param remote_inference_description: Specifies a description for the remote inference job
-        :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
-        :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
-        :param disable_remote_inference: If True, the job will not use remote inference
-        :param cache: A Cache object to store results
-        :param bucket_collection: A BucketCollection object to track API calls
-        :param key_lookup: A KeyLookup object to manage API keys
-        """
-        self._run(config)
-
-        return asyncio.run(self._execute_with_remote_cache(run_job_async=False))
+        if bucket_collection is None:
+            bucket_collection = self.create_bucket_collection()
 
-    @with_config
-    async def run_async(self, *, config: RunConfig) -> "Results":
-        """
-        Runs the Job: conducts Interviews and returns their results.
+        # If not using remote inference, run locally with async
+        return await JobsRunnerAsyncio(
+            self, bucket_collection=bucket_collection
+        ).run_async(cache=cache, n=n, **kwargs)
 
-        :param n: How many times to run each interview
-        :param progress_bar: Whether to show a progress bar
-        :param stop_on_exception: Stops the job if an exception is raised
-        :param check_api_keys: Raises an error if API keys are invalid
-        :param verbose: Prints extra messages
-        :param remote_cache_description: Specifies a description for this group of entries in the remote cache
-        :param remote_inference_description: Specifies a description for the remote inference job
-        :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
-        :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
-        :param disable_remote_inference: If True, the job will not use remote inference
-        :param cache: A Cache object to store results
-        :param bucket_collection: A BucketCollection object to track API calls
-        :param key_lookup: A KeyLookup object to manage API keys
-        """
-        self._run(config)
+    def _run_local(self, bucket_collection, *args, **kwargs):
+        """Run the job locally."""
+        from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
 
-        return await self._execute_with_remote_cache(run_job_async=True)
+        results = JobsRunnerAsyncio(self, bucket_collection=bucket_collection).run(
+            *args, **kwargs
+        )
+        return results
 
     def __repr__(self) -> str:
         """Return an eval-able string representation of the Jobs instance."""
@@ -661,6 +588,10 @@ class Jobs(Base):
         )
         return number_of_questions
 
+    #######################
+    # Serialization methods
+    #######################
+
     def to_dict(self, add_edsl_version=True):
         d = {
             "survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
@@ -714,6 +645,9 @@
         """
         return hash(self) == hash(other)
 
+    #######################
+    # Example methods
+    #######################
     @classmethod
     def example(
         cls,