edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. edsl/Base.py +28 -0
  2. edsl/__init__.py +1 -1
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +8 -16
  5. edsl/agents/Invigilator.py +13 -14
  6. edsl/agents/InvigilatorBase.py +4 -1
  7. edsl/agents/PromptConstructor.py +42 -22
  8. edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
  9. edsl/auto/AutoStudy.py +18 -5
  10. edsl/auto/StageBase.py +53 -40
  11. edsl/auto/StageQuestions.py +2 -1
  12. edsl/auto/utilities.py +0 -6
  13. edsl/coop/coop.py +21 -5
  14. edsl/data/Cache.py +29 -18
  15. edsl/data/CacheHandler.py +0 -2
  16. edsl/data/RemoteCacheSync.py +154 -46
  17. edsl/data/hack.py +10 -0
  18. edsl/enums.py +7 -0
  19. edsl/inference_services/AnthropicService.py +38 -16
  20. edsl/inference_services/AvailableModelFetcher.py +7 -1
  21. edsl/inference_services/GoogleService.py +5 -1
  22. edsl/inference_services/InferenceServicesCollection.py +18 -2
  23. edsl/inference_services/OpenAIService.py +46 -31
  24. edsl/inference_services/TestService.py +1 -3
  25. edsl/inference_services/TogetherAIService.py +5 -3
  26. edsl/inference_services/data_structures.py +74 -2
  27. edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
  28. edsl/jobs/FetchInvigilator.py +10 -3
  29. edsl/jobs/InterviewsConstructor.py +6 -4
  30. edsl/jobs/Jobs.py +299 -233
  31. edsl/jobs/JobsChecks.py +2 -2
  32. edsl/jobs/JobsPrompts.py +1 -1
  33. edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
  34. edsl/jobs/async_interview_runner.py +138 -0
  35. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  36. edsl/jobs/data_structures.py +120 -0
  37. edsl/jobs/interviews/Interview.py +80 -42
  38. edsl/jobs/results_exceptions_handler.py +98 -0
  39. edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
  40. edsl/jobs/runners/JobsRunnerStatus.py +131 -164
  41. edsl/jobs/tasks/TaskHistory.py +24 -3
  42. edsl/language_models/LanguageModel.py +59 -4
  43. edsl/language_models/ModelList.py +19 -8
  44. edsl/language_models/__init__.py +1 -1
  45. edsl/language_models/model.py +256 -0
  46. edsl/language_models/repair.py +1 -1
  47. edsl/questions/QuestionBase.py +35 -26
  48. edsl/questions/QuestionBasePromptsMixin.py +1 -1
  49. edsl/questions/QuestionBudget.py +1 -1
  50. edsl/questions/QuestionCheckBox.py +2 -2
  51. edsl/questions/QuestionExtract.py +5 -7
  52. edsl/questions/QuestionFreeText.py +1 -1
  53. edsl/questions/QuestionList.py +9 -15
  54. edsl/questions/QuestionMatrix.py +1 -1
  55. edsl/questions/QuestionMultipleChoice.py +1 -1
  56. edsl/questions/QuestionNumerical.py +1 -1
  57. edsl/questions/QuestionRank.py +1 -1
  58. edsl/questions/SimpleAskMixin.py +1 -1
  59. edsl/questions/__init__.py +1 -1
  60. edsl/questions/data_structures.py +20 -0
  61. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
  62. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
  63. edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
  64. edsl/results/DatasetExportMixin.py +60 -119
  65. edsl/results/Result.py +109 -3
  66. edsl/results/Results.py +50 -39
  67. edsl/results/file_exports.py +252 -0
  68. edsl/scenarios/ScenarioList.py +35 -7
  69. edsl/surveys/Survey.py +71 -20
  70. edsl/test_h +1 -0
  71. edsl/utilities/gcp_bucket/example.py +50 -0
  72. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
  73. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
  74. edsl/language_models/registry.py +0 -180
  75. /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
  76. /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
  77. /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
  78. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  79. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  80. /edsl/results/{Selector.py → results_selector.py} +0 -0
  81. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  82. /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
  83. /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
  84. /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
  85. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
  86. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py CHANGED
@@ -1,7 +1,17 @@
  # """The Jobs class is a collection of agents, scenarios and models and one survey."""
  from __future__ import annotations
- import warnings
- from typing import Literal, Optional, Union, Sequence, Generator, TYPE_CHECKING
+ import asyncio
+ from inspect import signature
+ from typing import (
+     Literal,
+     Optional,
+     Union,
+     Sequence,
+     Generator,
+     TYPE_CHECKING,
+     Callable,
+     Tuple,
+ )

  from edsl.Base import Base

@@ -9,10 +19,13 @@ from edsl.jobs.buckets.BucketCollection import BucketCollection
  from edsl.jobs.JobsPrompts import JobsPrompts
  from edsl.jobs.interviews.Interview import Interview
  from edsl.utilities.remove_edsl_version import remove_edsl_version
-
+ from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
  from edsl.data.RemoteCacheSync import RemoteCacheSync
  from edsl.exceptions.coop import CoopServerResponseError

+ from edsl.jobs.JobsChecks import JobsChecks
+ from edsl.jobs.data_structures import RunEnvironment, RunParameters, RunConfig
+
  if TYPE_CHECKING:
      from edsl.agents.Agent import Agent
      from edsl.agents.AgentList import AgentList
@@ -23,6 +36,66 @@ if TYPE_CHECKING:
      from edsl.results.Results import Results
      from edsl.results.Dataset import Dataset
      from edsl.language_models.ModelList import ModelList
+     from edsl.data.Cache import Cache
+     from edsl.language_models.key_management.KeyLookup import KeyLookup
+
+ VisibilityType = Literal["private", "public", "unlisted"]
+
+ from dataclasses import dataclass
+ from typing import Optional, Union, TypeVar, Callable, cast
+ from functools import wraps
+
+ try:
+     from typing import ParamSpec
+ except ImportError:
+     from typing_extensions import ParamSpec
+
+
+ P = ParamSpec("P")
+ T = TypeVar("T")
+
+
+ from edsl.jobs.check_survey_scenario_compatibility import (
+     CheckSurveyScenarioCompatibility,
+ )
+
+
+ def with_config(f: Callable[P, T]) -> Callable[P, T]:
+     "This decorator make it so that the run function parameters match the RunConfig dataclass."
+     parameter_fields = {
+         name: field.default
+         for name, field in RunParameters.__dataclass_fields__.items()
+     }
+     environment_fields = {
+         name: field.default
+         for name, field in RunEnvironment.__dataclass_fields__.items()
+     }
+     combined = {**parameter_fields, **environment_fields}
+
+     @wraps(f)
+     def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
+         environment = RunEnvironment(
+             **{k: v for k, v in kwargs.items() if k in environment_fields}
+         )
+         parameters = RunParameters(
+             **{k: v for k, v in kwargs.items() if k in parameter_fields}
+         )
+         config = RunConfig(environment=environment, parameters=parameters)
+         return f(*args, config=config)
+
+     # Update the wrapper's signature to include all RunConfig parameters
+     # old_sig = signature(f)
+     # wrapper.__signature__ = old_sig.replace(
+     #     parameters=list(old_sig.parameters.values())[:-1]
+     #     + [
+     #         old_sig.parameters["config"].replace(
+     #             default=parameter_fields[name], name=name
+     #         )
+     #         for name in combined
+     #     ]
+     # )
+
+     return cast(Callable[P, T], wrapper)


  class Jobs(Base):
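
The with_config decorator added in this hunk splits the keyword arguments passed to run()/run_async() into a RunParameters half and a RunEnvironment half, then hands the combined RunConfig to the wrapped function. A minimal self-contained sketch of that mechanism is shown below; the stand-in dataclasses and field names are illustrative only, not the real definitions in edsl.jobs.data_structures.

    # Sketch only: stand-in dataclasses; the real fields live in edsl.jobs.data_structures.
    from dataclasses import dataclass
    from functools import wraps

    @dataclass
    class RunParameters:
        n: int = 1
        progress_bar: bool = False

    @dataclass
    class RunEnvironment:
        cache: object = None
        key_lookup: object = None

    @dataclass
    class RunConfig:
        environment: RunEnvironment
        parameters: RunParameters

    def with_config(f):
        param_fields = set(RunParameters.__dataclass_fields__)
        env_fields = set(RunEnvironment.__dataclass_fields__)

        @wraps(f)
        def wrapper(*args, **kwargs):
            # Route each keyword to the dataclass that declares it.
            environment = RunEnvironment(**{k: v for k, v in kwargs.items() if k in env_fields})
            parameters = RunParameters(**{k: v for k, v in kwargs.items() if k in param_fields})
            return f(*args, config=RunConfig(environment=environment, parameters=parameters))

        return wrapper

    @with_config
    def run(*, config):
        return config

    print(run(n=2, progress_bar=True).parameters.n)  # prints 2
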
@@ -46,15 +119,62 @@ class Jobs(Base):
          :param models: a list of models
          :param scenarios: a list of scenarios
          """
+         self.run_config = RunConfig(
+             environment=RunEnvironment(), parameters=RunParameters()
+         )
+
          self.survey = survey
          self.agents: AgentList = agents
          self.scenarios: ScenarioList = scenarios
-         self.models = models
+         self.models: ModelList = models
+
+     def add_running_env(self, running_env: RunEnvironment):
+         self.run_config.add_environment(running_env)
+         return self

-         self.__bucket_collection = None
+     def using_cache(self, cache: "Cache") -> Jobs:
+         """
+         Add a Cache to the job.
+
+         :param cache: the cache to add
+         """
+         self.run_config.add_cache(cache)
+         return self
+
+     def using_bucket_collection(self, bucket_collection: BucketCollection) -> Jobs:
+         """
+         Add a BucketCollection to the job.
+
+         :param bucket_collection: the bucket collection to add
+         """
+         self.run_config.add_bucket_collection(bucket_collection)
+         return self
+
+     def using_key_lookup(self, key_lookup: KeyLookup) -> Jobs:
+         """
+         Add a KeyLookup to the job.
+
+         :param key_lookup: the key lookup to add
+         """
+         self.run_config.add_key_lookup(key_lookup)
+         return self
+
+     def using(self, obj: Union[Cache, BucketCollection, KeyLookup]) -> Jobs:
+         """
+         Add a Cache, BucketCollection, or KeyLookup to the job.
+
+         :param obj: the object to add
+         """
+         from edsl.data.Cache import Cache
+         from edsl.language_models.key_management.KeyLookup import KeyLookup

-         # these setters and getters are used to ensure that the agents, models, and scenarios
-         # are stored as AgentList, ModelList, and ScenarioList objects.
+         if isinstance(obj, Cache):
+             self.using_cache(obj)
+         elif isinstance(obj, BucketCollection):
+             self.using_bucket_collection(obj)
+         elif isinstance(obj, KeyLookup):
+             self.using_key_lookup(obj)
+         return self

      @property
      def models(self):
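
The using_cache/using_bucket_collection/using_key_lookup methods and the using() dispatcher added above give Jobs a fluent way to attach run-time resources before calling run(). A hedged usage sketch, assuming edsl 0.1.39.dev4 is installed and model credentials are configured:

    # Illustrative sketch; requires a working edsl installation and API keys.
    from edsl.jobs import Jobs
    from edsl.data.Cache import Cache

    job = Jobs.example()          # example survey with default agents/scenarios/models
    job = job.using(Cache())      # dispatches to using_cache(); BucketCollection and
                                  # KeyLookup objects can be attached the same way
    results = job.run(disable_remote_inference=True)
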
@@ -72,6 +192,12 @@ class Jobs(Base):
          else:
              self._models = ModelList([])

+         # update the bucket collection if it exists
+         if self.run_config.environment.bucket_collection is None:
+             self.run_config.environment.bucket_collection = (
+                 self.create_bucket_collection()
+             )
+
      @property
      def agents(self):
          return self._agents
@@ -214,13 +340,29 @@ class Jobs(Base):

      def replace_missing_objects(self) -> None:
          from edsl.agents.Agent import Agent
-         from edsl.language_models.registry import Model
+         from edsl.language_models.model import Model
          from edsl.scenarios.Scenario import Scenario

          self.agents = self.agents or [Agent()]
          self.models = self.models or [Model()]
          self.scenarios = self.scenarios or [Scenario()]

+     def generate_interviews(self) -> Generator[Interview, None, None]:
+         """
+         Generate interviews.
+
+         Note that this sets the agents, model and scenarios if they have not been set. This is a side effect of the method.
+         This is useful because a user can create a job without setting the agents, models, or scenarios, and the job will still run,
+         with us filling in defaults.
+
+         """
+         from edsl.jobs.InterviewsConstructor import InterviewsConstructor
+
+         self.replace_missing_objects()
+         yield from InterviewsConstructor(
+             self, cache=self.run_config.environment.cache
+         ).create_interviews()
+
      def interviews(self) -> list[Interview]:
          """
          Return a list of :class:`edsl.jobs.interviews.Interview` objects.
@@ -235,18 +377,10 @@ class Jobs(Base):
          >>> j.interviews()[0]
          Interview(agent = Agent(traits = {'status': 'Joyful'}), survey = Survey(...), scenario = Scenario({'period': 'morning'}), model = Model(...))
          """
-         if hasattr(self, "_interviews"):
-             return self._interviews
-         else:
-             self.replace_missing_objects()
-             from edsl.jobs.InterviewsConstructor import InterviewsConstructor
-
-             self._interviews = list(InterviewsConstructor(self).create_interviews())
-
-             return self._interviews
+         return list(self.generate_interviews())

      @classmethod
-     def from_interviews(cls, interview_list):
+     def from_interviews(cls, interview_list) -> "Jobs":
          """Return a Jobs instance from a list of interviews.

          This is useful when you have, say, a list of failed interviews and you want to create
@@ -273,16 +407,8 @@ class Jobs(Base):
          >>> bc
          BucketCollection(...)
          """
-         self.replace_missing_objects()  # ensure that all objects are present
          return BucketCollection.from_models(self.models)

-     @property
-     def bucket_collection(self) -> BucketCollection:
-         """Return the bucket collection. If it does not exist, create it."""
-         if self.__bucket_collection is None:
-             self.__bucket_collection = self.create_bucket_collection()
-         return self.__bucket_collection
-
      def html(self):
          """Return the HTML representations for each scenario"""
          links = []
@@ -308,10 +434,12 @@ class Jobs(Base):

      def _output(self, message) -> None:
          """Check if a Job is verbose. If so, print the message."""
-         if hasattr(self, "verbose") and self.verbose:
+         if self.run_config.parameters.verbose:
              print(message)
+         # if hasattr(self, "verbose") and self.verbose:
+         #     print(message)

-     def all_question_parameters(self):
+     def all_question_parameters(self) -> set:
          """Return all the fields in the questions in the survey.
          >>> from edsl.jobs import Jobs
          >>> Jobs.example().all_question_parameters()
@@ -319,86 +447,12 @@ class Jobs(Base):
          """
          return set.union(*[question.parameters for question in self.survey.questions])

-     def _check_parameters(self, strict=False, warn=False) -> None:
-         """Check if the parameters in the survey and scenarios are consistent.
-
-         >>> from edsl.questions.QuestionFreeText import QuestionFreeText
-         >>> from edsl.surveys.Survey import Survey
-         >>> from edsl.scenarios.Scenario import Scenario
-         >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
-         >>> j = Jobs(survey = Survey(questions=[q]))
-         >>> with warnings.catch_warnings(record=True) as w:
-         ...     j._check_parameters(warn = True)
-         ...     assert len(w) == 1
-         ...     assert issubclass(w[-1].category, UserWarning)
-         ...     assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
-
-         >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
-         >>> s = Scenario({'plop': "A", 'poo': "B"})
-         >>> j = Jobs(survey = Survey(questions=[q])).by(s)
-         >>> j._check_parameters(strict = True)
-         Traceback (most recent call last):
-         ...
-         ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
-
-         >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
-         >>> s = Scenario({'ugly_question': "B"})
-         >>> j = Jobs(survey = Survey(questions=[q])).by(s)
-         >>> j._check_parameters()
-         Traceback (most recent call last):
-         ...
-         ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
-         """
-         survey_parameters: set = self.survey.parameters
-         scenario_parameters: set = self.scenarios.parameters
-
-         msg0, msg1, msg2 = None, None, None
-
-         # look for key issues
-         if intersection := set(self.scenarios.parameters) & set(
-             self.survey.question_names
-         ):
-             msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
-
-             raise ValueError(msg0)
-
-         if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
-             msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
-         if in_scenarios_but_not_in_survey := scenario_parameters - survey_parameters:
-             msg2 = f"The following parameters are in the scenarios but not in the survey: {in_scenarios_but_not_in_survey}"
-
-         if msg1 or msg2:
-             message = "\n".join(filter(None, [msg1, msg2]))
-             if strict:
-                 raise ValueError(message)
-             else:
-                 if warn:
-                     warnings.warn(message)
-
-         if self.scenarios.has_jinja_braces:
-             warnings.warn(
-                 "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
-             )
-             self.scenarios = self.scenarios._convert_jinja_braces()
-
-     @property
-     def skip_retry(self):
-         if not hasattr(self, "_skip_retry"):
-             return False
-         return self._skip_retry
-
-     @property
-     def raise_validation_errors(self):
-         if not hasattr(self, "_raise_validation_errors"):
-             return False
-         return self._raise_validation_errors
-
-     def use_remote_cache(self, disable_remote_cache: bool) -> bool:
+     def use_remote_cache(self) -> bool:
          import requests

-         if disable_remote_cache:
+         if self.run_config.parameters.disable_remote_cache:
              return False
-         if not disable_remote_cache:
+         if not self.run_config.parameters.disable_remote_cache:
              try:
                  from edsl.coop.coop import Coop

@@ -411,154 +465,173 @@ class Jobs(Base):

          return False

-     def run(
+     def _remote_results(
          self,
-         n: int = 1,
-         progress_bar: bool = False,
-         stop_on_exception: bool = False,
-         cache: Union["Cache", bool] = None,
-         check_api_keys: bool = False,
-         sidecar_model: Optional[LanguageModel] = None,
-         verbose: bool = True,
-         print_exceptions=True,
-         remote_cache_description: Optional[str] = None,
-         remote_inference_description: Optional[str] = None,
-         remote_inference_results_visibility: Optional[
-             Literal["private", "public", "unlisted"]
-         ] = "unlisted",
-         skip_retry: bool = False,
-         raise_validation_errors: bool = False,
-         disable_remote_cache: bool = False,
-         disable_remote_inference: bool = False,
-         bucket_collection: Optional[BucketCollection] = None,
-     ) -> Results:
-         """
-         Runs the Job: conducts Interviews and returns their results.
-
-         :param n: How many times to run each interview
-         :param progress_bar: Whether to show a progress bar
-         :param stop_on_exception: Stops the job if an exception is raised
-         :param cache: A Cache object to store results
-         :param check_api_keys: Raises an error if API keys are invalid
-         :param verbose: Prints extra messages
-         :param remote_cache_description: Specifies a description for this group of entries in the remote cache
-         :param remote_inference_description: Specifies a description for the remote inference job
-         :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
-         :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
-         :param disable_remote_inference: If True, the job will not use remote inference
-         """
-         from edsl.coop.coop import Coop
-         from edsl.jobs.JobsChecks import JobsChecks
+     ) -> Union["Results", None]:
          from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler

-         self._check_parameters()
-         self._skip_retry = skip_retry
-         self._raise_validation_errors = raise_validation_errors
-         self.verbose = verbose
+         jh = JobsRemoteInferenceHandler(
+             self, verbose=self.run_config.parameters.verbose
+         )
+         if jh.use_remote_inference(self.run_config.parameters.disable_remote_inference):
+             job_info = jh.create_remote_inference_job(
+                 iterations=self.run_config.parameters.n,
+                 remote_inference_description=self.run_config.parameters.remote_inference_description,
+                 remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
+             )
+             results = jh.poll_remote_inference_job(job_info)
+             return results
+         else:
+             return None
+
+     def _prepare_to_run(self) -> None:
+         "This makes sure that the job is ready to run and that keys are in place for a remote job."
+         CheckSurveyScenarioCompatibility(self.survey, self.scenarios).check()

+     def _check_if_remote_keys_ok(self):
          jc = JobsChecks(self)
-
-         # check if the user has all the keys they need
          if jc.needs_key_process():
              jc.key_process()

-         jh = JobsRemoteInferenceHandler(self, verbose=verbose)
-         if jh.use_remote_inference(disable_remote_inference):
-             jh.create_remote_inference_job(
-                 iterations=n,
-                 remote_inference_description=remote_inference_description,
-                 remote_inference_results_visibility=remote_inference_results_visibility,
-             )
-             results = jh.poll_remote_inference_job()
-             return results
-
-         if check_api_keys:
+     def _check_if_local_keys_ok(self):
+         jc = JobsChecks(self)
+         if self.run_config.parameters.check_api_keys:
              jc.check_api_keys()

-         # handle cache
-         if cache is None or cache is True:
-             from edsl.data.CacheHandler import CacheHandler
+     async def _execute_with_remote_cache(self, run_job_async: bool) -> Results:

-             cache = CacheHandler().get_cache()
-         if cache is False:
-             from edsl.data.Cache import Cache
+         use_remote_cache = self.use_remote_cache()

-             cache = Cache()
+         from edsl.coop.coop import Coop
+         from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
+         from edsl.data.Cache import Cache

-         if bucket_collection is None:
-             bucket_collection = self.create_bucket_collection()
+         assert isinstance(self.run_config.environment.cache, Cache)

-         remote_cache = self.use_remote_cache(disable_remote_cache)
          with RemoteCacheSync(
              coop=Coop(),
-             cache=cache,
+             cache=self.run_config.environment.cache,
              output_func=self._output,
-             remote_cache=remote_cache,
-             remote_cache_description=remote_cache_description,
-         ) as r:
-             results = self._run_local(
-                 n=n,
-                 progress_bar=progress_bar,
-                 cache=cache,
-                 stop_on_exception=stop_on_exception,
-                 sidecar_model=sidecar_model,
-                 print_exceptions=print_exceptions,
-                 raise_validation_errors=raise_validation_errors,
-                 bucket_collection=bucket_collection,
-             )
+             remote_cache=use_remote_cache,
+             remote_cache_description=self.run_config.parameters.remote_cache_description,
+         ):
+             runner = JobsRunnerAsyncio(self, environment=self.run_config.environment)
+             if run_job_async:
+                 results = await runner.run_async(self.run_config.parameters)
+             else:
+                 results = runner.run(self.run_config.parameters)
          return results

-     async def run_async(
-         self,
-         cache=None,
-         n=1,
-         disable_remote_inference: bool = False,
-         remote_inference_description: Optional[str] = None,
-         remote_inference_results_visibility: Optional[
-             Literal["private", "public", "unlisted"]
-         ] = "unlisted",
-         bucket_collection: Optional[BucketCollection] = None,
-         **kwargs,
-     ):
-         """Run the job asynchronously, either locally or remotely.
+     def _setup_and_check(self) -> Tuple[RunConfig, Optional[Results]]:

-         :param cache: Cache object or boolean
-         :param n: Number of iterations
-         :param disable_remote_inference: If True, forces local execution
-         :param remote_inference_description: Description for remote jobs
-         :param remote_inference_results_visibility: Visibility setting for remote results
-         :param kwargs: Additional arguments passed to local execution
-         :return: Results object
-         """
-         # Check if we should use remote inference
-         from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
-         from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
+         self._prepare_to_run()
+         self._check_if_remote_keys_ok()
+
+         # first try to run the job remotely
+         if results := self._remote_results():
+             return results
+
+         self._check_if_local_keys_ok()
+         return None
+
+     @property
+     def num_interviews(self):
+         if self.run_config.parameters.n is None:
+             return len(self)
+         else:
+             len(self) * self.run_config.parameters.n

-         jh = JobsRemoteInferenceHandler(self, verbose=False)
-         if jh.use_remote_inference(disable_remote_inference):
-             results = await jh.create_and_poll_remote_job(
-                 iterations=n,
-                 remote_inference_description=remote_inference_description,
-                 remote_inference_results_visibility=remote_inference_results_visibility,
+     def _run(self, config: RunConfig):
+         "Shared code for run and run_async"
+         if config.environment.cache is not None:
+             self.run_config.environment.cache = config.environment.cache
+
+         if config.environment.bucket_collection is not None:
+             self.run_config.environment.bucket_collection = (
+                 config.environment.bucket_collection
              )
+
+         if config.environment.key_lookup is not None:
+             self.run_config.environment.key_lookup = config.environment.key_lookup
+
+         # replace the parameters with the ones from the config
+         self.run_config.parameters = config.parameters
+
+         self.replace_missing_objects()
+
+         # try to run remotely first
+         self._prepare_to_run()
+         self._check_if_remote_keys_ok()
+
+         if (
+             self.run_config.environment.cache is None
+             or self.run_config.environment.cache is True
+         ):
+             from edsl.data.CacheHandler import CacheHandler
+
+             self.run_config.environment.cache = CacheHandler().get_cache()
+
+         if self.run_config.environment.cache is False:
+             from edsl.data.Cache import Cache
+
+             self.run_config.environment.cache = Cache(immediate_write=False)
+
+         # first try to run the job remotely
+         if results := self._remote_results():
              return results

-         if bucket_collection is None:
-             bucket_collection = self.create_bucket_collection()
+         self._check_if_local_keys_ok()

-         # If not using remote inference, run locally with async
-         return await JobsRunnerAsyncio(
-             self, bucket_collection=bucket_collection
-         ).run_async(cache=cache, n=n, **kwargs)
+         if config.environment.bucket_collection is None:
+             self.run_config.environment.bucket_collection = (
+                 self.create_bucket_collection()
+             )

-     def _run_local(self, bucket_collection, *args, **kwargs):
-         """Run the job locally."""
-         from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
+     @with_config
+     def run(self, *, config: RunConfig) -> "Results":
+         """
+         Runs the Job: conducts Interviews and returns their results.

-         results = JobsRunnerAsyncio(self, bucket_collection=bucket_collection).run(
-             *args, **kwargs
-         )
-         return results
+         :param n: How many times to run each interview
+         :param progress_bar: Whether to show a progress bar
+         :param stop_on_exception: Stops the job if an exception is raised
+         :param check_api_keys: Raises an error if API keys are invalid
+         :param verbose: Prints extra messages
+         :param remote_cache_description: Specifies a description for this group of entries in the remote cache
+         :param remote_inference_description: Specifies a description for the remote inference job
+         :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
+         :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
+         :param disable_remote_inference: If True, the job will not use remote inference
+         :param cache: A Cache object to store results
+         :param bucket_collection: A BucketCollection object to track API calls
+         :param key_lookup: A KeyLookup object to manage API keys
+         """
+         self._run(config)
+
+         return asyncio.run(self._execute_with_remote_cache(run_job_async=False))
+
+     @with_config
+     async def run_async(self, *, config: RunConfig) -> "Results":
+         """
+         Runs the Job: conducts Interviews and returns their results.
+
+         :param n: How many times to run each interview
+         :param progress_bar: Whether to show a progress bar
+         :param stop_on_exception: Stops the job if an exception is raised
+         :param check_api_keys: Raises an error if API keys are invalid
+         :param verbose: Prints extra messages
+         :param remote_cache_description: Specifies a description for this group of entries in the remote cache
+         :param remote_inference_description: Specifies a description for the remote inference job
+         :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
+         :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
+         :param disable_remote_inference: If True, the job will not use remote inference
+         :param cache: A Cache object to store results
+         :param bucket_collection: A BucketCollection object to track API calls
+         :param key_lookup: A KeyLookup object to manage API keys
+         """
+         self._run(config)
+
+         return await self._execute_with_remote_cache(run_job_async=True)

      def __repr__(self) -> str:
          """Return an eval-able string representation of the Jobs instance."""
@@ -588,10 +661,6 @@ class Jobs(Base):
          )
          return number_of_questions

-     #######################
-     # Serialization methods
-     #######################
-
      def to_dict(self, add_edsl_version=True):
          d = {
              "survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
@@ -645,9 +714,6 @@ class Jobs(Base):
          """
          return hash(self) == hash(other)

-     #######################
-     # Example methods
-     #######################
      @classmethod
      def example(
          cls,
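
With these changes, run() and run_async() keep the old keyword interface (mapped through with_config into a RunConfig) while sharing a single _run()/_execute_with_remote_cache code path. A hedged usage sketch, assuming edsl 0.1.39.dev4, configured API keys, and network access for any remote features; the keyword parameters shown are the ones documented in the run() docstring above:

    # Illustrative only; requires a working edsl installation.
    import asyncio
    from edsl.jobs import Jobs
    from edsl.data.Cache import Cache

    job = Jobs.example()

    # Synchronous entry point: keyword arguments are packed into a RunConfig by @with_config.
    results = job.run(n=1, progress_bar=True, cache=Cache(), disable_remote_inference=True)

    # Asynchronous entry point takes the same keyword arguments.
    results = asyncio.run(job.run_async(n=1, disable_remote_inference=True))
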
edsl/jobs/JobsChecks.py CHANGED
@@ -8,7 +8,7 @@ class JobsChecks:
          self.jobs = jobs

      def check_api_keys(self) -> None:
-         from edsl.language_models.registry import Model
+         from edsl.language_models.model import Model

          if len(self.jobs.models) == 0:
              models = [Model()]
@@ -28,7 +28,7 @@ class JobsChecks:
          """
          missing_api_keys = set()

-         from edsl.language_models.registry import Model
+         from edsl.language_models.model import Model
          from edsl.enums import service_to_api_keyname

          for model in self.jobs.models + [Model()]:
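
Both JobsChecks hunks track the module rename visible in the file list (edsl/language_models/registry.py removed, edsl/language_models/model.py added). Any downstream code that imported the Model factory from the old path needs the same one-line change:

    # Before (0.1.39.dev2)
    from edsl.language_models.registry import Model
    # After (0.1.39.dev4)
    from edsl.language_models.model import Model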