edsl 0.1.36.dev5__py3-none-any.whl → 0.1.37__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
Files changed (55)
  1. edsl/__init__.py +1 -0
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Agent.py +92 -41
  4. edsl/agents/AgentList.py +15 -2
  5. edsl/agents/InvigilatorBase.py +15 -25
  6. edsl/agents/PromptConstructor.py +149 -108
  7. edsl/agents/descriptors.py +17 -4
  8. edsl/conjure/AgentConstructionMixin.py +11 -3
  9. edsl/conversation/Conversation.py +66 -14
  10. edsl/conversation/chips.py +95 -0
  11. edsl/coop/coop.py +148 -39
  12. edsl/data/Cache.py +1 -1
  13. edsl/data/RemoteCacheSync.py +25 -12
  14. edsl/exceptions/BaseException.py +21 -0
  15. edsl/exceptions/__init__.py +7 -3
  16. edsl/exceptions/agents.py +17 -19
  17. edsl/exceptions/results.py +11 -8
  18. edsl/exceptions/scenarios.py +22 -0
  19. edsl/exceptions/surveys.py +13 -10
  20. edsl/inference_services/AwsBedrock.py +7 -2
  21. edsl/inference_services/InferenceServicesCollection.py +42 -13
  22. edsl/inference_services/models_available_cache.py +25 -1
  23. edsl/jobs/Jobs.py +306 -71
  24. edsl/jobs/interviews/Interview.py +24 -14
  25. edsl/jobs/interviews/InterviewExceptionCollection.py +1 -1
  26. edsl/jobs/interviews/InterviewExceptionEntry.py +17 -13
  27. edsl/jobs/interviews/ReportErrors.py +2 -2
  28. edsl/jobs/runners/JobsRunnerAsyncio.py +10 -9
  29. edsl/jobs/tasks/TaskHistory.py +1 -0
  30. edsl/language_models/KeyLookup.py +30 -0
  31. edsl/language_models/LanguageModel.py +47 -59
  32. edsl/language_models/__init__.py +1 -0
  33. edsl/prompts/Prompt.py +11 -12
  34. edsl/questions/QuestionBase.py +53 -13
  35. edsl/questions/QuestionBasePromptsMixin.py +1 -33
  36. edsl/questions/QuestionFreeText.py +1 -0
  37. edsl/questions/QuestionFunctional.py +2 -2
  38. edsl/questions/descriptors.py +23 -28
  39. edsl/results/DatasetExportMixin.py +25 -1
  40. edsl/results/Result.py +27 -10
  41. edsl/results/Results.py +34 -121
  42. edsl/results/ResultsDBMixin.py +1 -1
  43. edsl/results/Selector.py +18 -1
  44. edsl/scenarios/FileStore.py +20 -5
  45. edsl/scenarios/Scenario.py +52 -13
  46. edsl/scenarios/ScenarioHtmlMixin.py +7 -2
  47. edsl/scenarios/ScenarioList.py +12 -1
  48. edsl/scenarios/__init__.py +2 -0
  49. edsl/surveys/Rule.py +10 -4
  50. edsl/surveys/Survey.py +100 -77
  51. edsl/utilities/utilities.py +18 -0
  52. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/METADATA +1 -1
  53. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/RECORD +55 -51
  54. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/LICENSE +0 -0
  55. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations
  import warnings
  import requests
  from itertools import product
- from typing import Optional, Union, Sequence, Generator
+ from typing import Literal, Optional, Union, Sequence, Generator

  from edsl.Base import Base
+
  from edsl.exceptions import MissingAPIKeyError
  from edsl.jobs.buckets.BucketCollection import BucketCollection
  from edsl.jobs.interviews.Interview import Interview
@@ -193,7 +194,7 @@ class Jobs(Base):
  inference_service=invigilator.model._inference_service_,
  model=invigilator.model.model,
  )
- costs.append(prompt_cost["cost"])
+ costs.append(prompt_cost["cost_usd"])

  d = Dataset(
  [
@@ -209,14 +210,14 @@
  )
  return d

- def show_prompts(self, all=False) -> None:
+ def show_prompts(self, all=False, max_rows: Optional[int] = None) -> None:
  """Print the prompts."""
  if all:
- self.prompts().to_scenario_list().print(format="rich")
+ self.prompts().to_scenario_list().print(format="rich", max_rows=max_rows)
  else:
  self.prompts().select(
  "user_prompt", "system_prompt"
- ).to_scenario_list().print(format="rich")
+ ).to_scenario_list().print(format="rich", max_rows=max_rows)

  @staticmethod
  def estimate_prompt_cost(
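Editorial note, not part of the diff: a usage sketch for the new max_rows parameter on show_prompts (assumes Jobs.example() as a stand-in job):

    from edsl.jobs import Jobs

    job = Jobs.example()
    job.show_prompts(max_rows=10)            # cap the rich-format table at 10 rows
    job.show_prompts(all=True, max_rows=10)  # all prompt columns, same row cap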
@@ -227,6 +228,7 @@
  model: str,
  ) -> dict:
  """Estimates the cost of a prompt. Takes piping into account."""
+ import math

  def get_piping_multiplier(prompt: str):
  """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
@@ -240,10 +242,25 @@

  try:
  relevant_prices = price_lookup[key]
- output_price_per_token = 1 / float(
- relevant_prices["output"]["one_usd_buys"]
+
+ service_input_token_price = float(
+ relevant_prices["input"]["service_stated_token_price"]
+ )
+ service_input_token_qty = float(
+ relevant_prices["input"]["service_stated_token_qty"]
+ )
+ input_price_per_token = service_input_token_price / service_input_token_qty
+
+ service_output_token_price = float(
+ relevant_prices["output"]["service_stated_token_price"]
+ )
+ service_output_token_qty = float(
+ relevant_prices["output"]["service_stated_token_qty"]
+ )
+ output_price_per_token = (
+ service_output_token_price / service_output_token_qty
  )
- input_price_per_token = 1 / float(relevant_prices["input"]["one_usd_buys"])
+
  except KeyError:
  # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
  # Use a sensible default
@@ -253,9 +270,8 @@
  warnings.warn(
  "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
  )
-
- output_price_per_token = 0.00000015 # $0.15 / 1M tokens
- input_price_per_token = 0.00000060 # $0.60 / 1M tokens
+ input_price_per_token = 0.00000015 # $0.15 / 1M tokens
+ output_price_per_token = 0.00000060 # $0.60 / 1M tokens

  # Compute the number of characters (double if the question involves piping)
  user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
@@ -267,7 +283,8 @@

  # Convert into tokens (1 token approx. equals 4 characters)
  input_tokens = (user_prompt_chars + system_prompt_chars) // 4
- output_tokens = input_tokens
+
+ output_tokens = math.ceil(0.75 * input_tokens)

  cost = (
  input_tokens * input_price_per_token
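Editorial note, not part of the diff: a worked sketch of the revised arithmetic using the default prices from the warning above. Per-token prices are now derived from the service-stated price and quantity, and output tokens are estimated at 75% of input tokens instead of 100%:

    import math

    input_price_per_token = 0.15 / 1_000_000   # $0.15 per 1M input tokens
    output_price_per_token = 0.60 / 1_000_000  # $0.60 per 1M output tokens

    prompt_chars = 2_000                       # user + system prompt characters
    input_tokens = prompt_chars // 4           # 500 (1 token ~ 4 characters)
    output_tokens = math.ceil(0.75 * input_tokens)  # 375, replacing output == input

    cost_usd = (
        input_tokens * input_price_per_token
        + output_tokens * output_price_per_token
    )
    print(f"{cost_usd:.6f}")  # 0.000300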
@@ -277,15 +294,17 @@
  return {
  "input_tokens": input_tokens,
  "output_tokens": output_tokens,
- "cost": cost,
+ "cost_usd": cost,
  }

- def estimate_job_cost_from_external_prices(self, price_lookup: dict) -> dict:
+ def estimate_job_cost_from_external_prices(
+ self, price_lookup: dict, iterations: int = 1
+ ) -> dict:
  """
  Estimates the cost of a job according to the following assumptions:

  - 1 token = 4 characters.
- - Input tokens = output tokens.
+ - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.

  price_lookup is an external pricing dictionary.
  """
@@ -322,7 +341,7 @@
  "system_prompt": system_prompt,
  "estimated_input_tokens": prompt_cost["input_tokens"],
  "estimated_output_tokens": prompt_cost["output_tokens"],
- "estimated_cost": prompt_cost["cost"],
+ "estimated_cost_usd": prompt_cost["cost_usd"],
  "inference_service": inference_service,
  "model": model,
  }
@@ -334,18 +353,21 @@
  df.groupby(["inference_service", "model"])
  .agg(
  {
- "estimated_cost": "sum",
+ "estimated_cost_usd": "sum",
  "estimated_input_tokens": "sum",
  "estimated_output_tokens": "sum",
  }
  )
  .reset_index()
  )
+ df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
+ df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
+ df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations

  estimated_costs_by_model = df.to_dict("records")

  estimated_total_cost = sum(
- model["estimated_cost"] for model in estimated_costs_by_model
+ model["estimated_cost_usd"] for model in estimated_costs_by_model
  )
  estimated_total_input_tokens = sum(
  model["estimated_input_tokens"] for model in estimated_costs_by_model
@@ -355,7 +377,7 @@
  )

  output = {
- "estimated_total_cost": estimated_total_cost,
+ "estimated_total_cost_usd": estimated_total_cost,
  "estimated_total_input_tokens": estimated_total_input_tokens,
  "estimated_total_output_tokens": estimated_total_output_tokens,
  "model_costs": estimated_costs_by_model,
@@ -363,12 +385,12 @@

  return output

- def estimate_job_cost(self) -> dict:
+ def estimate_job_cost(self, iterations: int = 1) -> dict:
  """
  Estimates the cost of a job according to the following assumptions:

  - 1 token = 4 characters.
- - Input tokens = output tokens.
+ - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.

  Fetches prices from Coop.
  """
@@ -377,7 +399,9 @@
  c = Coop()
  price_lookup = c.fetch_prices()

- return self.estimate_job_cost_from_external_prices(price_lookup=price_lookup)
+ return self.estimate_job_cost_from_external_prices(
+ price_lookup=price_lookup, iterations=iterations
+ )

  @staticmethod
  def compute_job_cost(job_results: "Results") -> float:
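Editorial note, not part of the diff: a sketch of the new iterations parameter (assumes network access to Coop for fetch_prices()):

    from edsl.jobs import Jobs

    job = Jobs.example()
    estimate = job.estimate_job_cost(iterations=3)  # totals scale with the iteration count
    print(estimate["estimated_total_cost_usd"])
    print(estimate["model_costs"])                  # per (inference_service, model) breakdown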
@@ -699,7 +723,11 @@
  return self._raise_validation_errors

  def create_remote_inference_job(
- self, iterations: int = 1, remote_inference_description: Optional[str] = None
+ self,
+ iterations: int = 1,
+ remote_inference_description: Optional[str] = None,
+ remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
+ verbose=False,
  ):
  """ """
  from edsl.coop.coop import Coop
@@ -711,9 +739,11 @@
  description=remote_inference_description,
  status="queued",
  iterations=iterations,
+ initial_results_visibility=remote_inference_results_visibility,
  )
  job_uuid = remote_job_creation_data.get("uuid")
- print(f"Job sent to server. (Job uuid={job_uuid}).")
+ if self.verbose:
+ print(f"Job sent to server. (Job uuid={job_uuid}).")
  return remote_job_creation_data

  @staticmethod
@@ -724,7 +754,7 @@
  return coop.remote_inference_get(job_uuid)

  def poll_remote_inference_job(
- self, remote_job_creation_data: dict
+ self, remote_job_creation_data: dict, verbose=False, poll_interval=5
  ) -> Union[Results, None]:
  from edsl.coop.coop import Coop
  import time
@@ -741,42 +771,46 @@
  remote_job_data = coop.remote_inference_get(job_uuid)
  status = remote_job_data.get("status")
  if status == "cancelled":
- print("\r" + " " * 80 + "\r", end="")
- print("Job cancelled by the user.")
- print(
- f"See {expected_parrot_url}/home/remote-inference for more details."
- )
+ if self.verbose:
+ print("\r" + " " * 80 + "\r", end="")
+ print("Job cancelled by the user.")
+ print(
+ f"See {expected_parrot_url}/home/remote-inference for more details."
+ )
  return None
  elif status == "failed":
- print("\r" + " " * 80 + "\r", end="")
- print("Job failed.")
- print(
- f"See {expected_parrot_url}/home/remote-inference for more details."
- )
+ if self.verbose:
+ print("\r" + " " * 80 + "\r", end="")
+ print("Job failed.")
+ print(
+ f"See {expected_parrot_url}/home/remote-inference for more details."
+ )
  return None
  elif status == "completed":
  results_uuid = remote_job_data.get("results_uuid")
  results = coop.get(results_uuid, expected_object_type="results")
- print("\r" + " " * 80 + "\r", end="")
- url = f"{expected_parrot_url}/content/{results_uuid}"
- print(f"Job completed and Results stored on Coop: {url}.")
+ if self.verbose:
+ print("\r" + " " * 80 + "\r", end="")
+ url = f"{expected_parrot_url}/content/{results_uuid}"
+ print(f"Job completed and Results stored on Coop: {url}.")
  return results
  else:
- duration = 5
+ duration = poll_interval
  time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
  frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
  start_time = time.time()
  i = 0
  while time.time() - start_time < duration:
- print(
- f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
- end="",
- flush=True,
- )
+ if self.verbose:
+ print(
+ f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+ end="",
+ flush=True,
+ )
  time.sleep(0.1)
  i += 1

- def use_remote_inference(self, disable_remote_inference: bool):
+ def use_remote_inference(self, disable_remote_inference: bool) -> bool:
  if disable_remote_inference:
  return False
  if not disable_remote_inference:
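Editorial note, not part of the diff: a polling sketch using the new poll_interval knob. The added verbose parameters are not read by these method bodies, which consult self.verbose (normally set by run()), so the sketch sets it explicitly; assumes remote inference is enabled for the account:

    from edsl.jobs import Jobs

    job = Jobs.example()
    job.verbose = False  # the bodies gate all printing on self.verbose

    creation_data = job.create_remote_inference_job(
        iterations=1,
        remote_inference_description="hypothetical description",
        remote_inference_results_visibility="unlisted",
    )
    results = job.poll_remote_inference_job(creation_data, poll_interval=10)  # 10s instead of 5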
@@ -792,20 +826,23 @@

  return False

- def use_remote_cache(self):
- try:
- from edsl import Coop
+ def use_remote_cache(self, disable_remote_cache: bool) -> bool:
+ if disable_remote_cache:
+ return False
+ if not disable_remote_cache:
+ try:
+ from edsl import Coop

- user_edsl_settings = Coop().edsl_settings
- return user_edsl_settings.get("remote_caching", False)
- except requests.ConnectionError:
- pass
- except CoopServerResponseError as e:
- pass
+ user_edsl_settings = Coop().edsl_settings
+ return user_edsl_settings.get("remote_caching", False)
+ except requests.ConnectionError:
+ pass
+ except CoopServerResponseError as e:
+ pass

  return False

- def check_api_keys(self):
+ def check_api_keys(self) -> None:
  from edsl import Model

  for model in self.models + [Model()]:
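Editorial note, not part of the diff: the new disable_remote_cache flag pairs with the existing disable_remote_inference flag to force a fully local run:

    from edsl.jobs import Jobs

    job = Jobs.example()
    # Skip both the remote cache lookup and remote inference entirely.
    results = job.run(disable_remote_cache=True, disable_remote_inference=True)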
@@ -815,6 +852,86 @@
  inference_service=model._inference_service_,
  )

+ def get_missing_api_keys(self) -> set:
+ """
+ Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
+ """
+
+ missing_api_keys = set()
+
+ from edsl import Model
+ from edsl.enums import service_to_api_keyname
+
+ for model in self.models + [Model()]:
+ if not model.has_valid_api_key():
+ key_name = service_to_api_keyname.get(
+ model._inference_service_, "NOT FOUND"
+ )
+ missing_api_keys.add(key_name)
+
+ return missing_api_keys
+
+ def user_has_all_model_keys(self):
+ """
+ Returns True if the user has all model keys required to run their job.
+
+ Otherwise, returns False.
+ """
+
+ try:
+ self.check_api_keys()
+ return True
+ except MissingAPIKeyError:
+ return False
+ except Exception:
+ raise
+
+ def user_has_ep_api_key(self) -> bool:
+ """
+ Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
+
+ Otherwise, returns False.
+ """
+
+ import os
+
+ coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
+
+ if coop_api_key is not None:
+ return True
+ else:
+ return False
+
+ def needs_external_llms(self) -> bool:
+ """
+ Returns True if the job needs external LLMs to run.
+
+ Otherwise, returns False.
+ """
+ # These cases are necessary to skip the API key check during doctests
+
+ # Accounts for Results.example()
+ all_agents_answer_questions_directly = len(self.agents) > 0 and all(
+ [hasattr(a, "answer_question_directly") for a in self.agents]
+ )
+
+ # Accounts for InterviewExceptionEntry.example()
+ only_model_is_test = set([m.model for m in self.models]) == set(["test"])
+
+ # Accounts for Survey.__call__
+ all_questions_are_functional = set(
+ [q.question_type for q in self.survey.questions]
+ ) == set(["functional"])
+
+ if (
+ all_agents_answer_questions_directly
+ or only_model_is_test
+ or all_questions_are_functional
+ ):
+ return False
+ else:
+ return True
+
  def run(
  self,
  n: int = 1,
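Editorial note, not part of the diff: a sketch of how the new key-inspection helpers compose (they back the interactive login fallback added to run() below):

    from edsl.jobs import Jobs

    job = Jobs.example()
    if job.needs_external_llms() and not job.user_has_all_model_keys():
        # Key names are resolved via service_to_api_keyname, e.g. {"OPENAI_API_KEY"}
        print(job.get_missing_api_keys())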
@@ -827,22 +944,28 @@
  print_exceptions=True,
  remote_cache_description: Optional[str] = None,
  remote_inference_description: Optional[str] = None,
+ remote_inference_results_visibility: Optional[
+ Literal["private", "public", "unlisted"]
+ ] = "unlisted",
  skip_retry: bool = False,
  raise_validation_errors: bool = False,
+ disable_remote_cache: bool = False,
  disable_remote_inference: bool = False,
  ) -> Results:
  """
  Runs the Job: conducts Interviews and returns their results.

- :param n: how many times to run each interview
- :param progress_bar: shows a progress bar
- :param stop_on_exception: stops the job if an exception is raised
- :param cache: a cache object to store results
- :param check_api_keys: check if the API keys are valid
- :param batch_mode: run the job in batch mode i.e., no expecation of interaction with the user
- :param verbose: prints messages
- :param remote_cache_description: specifies a description for this group of entries in the remote cache
- :param remote_inference_description: specifies a description for the remote inference job
+ :param n: How many times to run each interview
+ :param progress_bar: Whether to show a progress bar
+ :param stop_on_exception: Stops the job if an exception is raised
+ :param cache: A Cache object to store results
+ :param check_api_keys: Raises an error if API keys are invalid
+ :param verbose: Prints extra messages
+ :param remote_cache_description: Specifies a description for this group of entries in the remote cache
+ :param remote_inference_description: Specifies a description for the remote inference job
+ :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
+ :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
+ :param disable_remote_inference: If True, the job will not use remote inference
  """
  from edsl.coop.coop import Coop

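Editorial note, not part of the diff: a sketch of the new visibility parameter; it only takes effect when the job actually executes remotely:

    from edsl.jobs import Jobs

    job = Jobs.example()
    results = job.run(
        remote_inference_description="hypothetical description",
        remote_inference_results_visibility="private",  # or "public" / "unlisted" (the default)
    )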
@@ -852,9 +975,54 @@

  self.verbose = verbose

+ if (
+ not self.user_has_all_model_keys()
+ and not self.user_has_ep_api_key()
+ and self.needs_external_llms()
+ ):
+ import secrets
+ from dotenv import load_dotenv
+ from edsl import CONFIG
+ from edsl.coop.coop import Coop
+ from edsl.utilities.utilities import write_api_key_to_env
+
+ missing_api_keys = self.get_missing_api_keys()
+
+ edsl_auth_token = secrets.token_urlsafe(16)
+
+ print("You're missing some of the API keys needed to run this job:")
+ for api_key in missing_api_keys:
+ print(f" 🔑 {api_key}")
+ print(
+ "\nYou can either add the missing keys to your .env file, or use remote inference."
+ )
+ print("Remote inference allows you to run jobs on our server.")
+ print("\n🚀 To use remote inference, sign up at the following link:")
+
+ coop = Coop()
+ coop._display_login_url(edsl_auth_token=edsl_auth_token)
+
+ print(
+ "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
+ )
+
+ api_key = coop._poll_for_api_key(edsl_auth_token)
+
+ if api_key is None:
+ print("\nTimed out waiting for login. Please try again.")
+ return
+
+ write_api_key_to_env(api_key)
+ print("✨ API key retrieved and written to .env file.\n")
+
+ # Retrieve API key so we can continue running the job
+ load_dotenv()
+
  if remote_inference := self.use_remote_inference(disable_remote_inference):
  remote_job_creation_data = self.create_remote_inference_job(
- iterations=n, remote_inference_description=remote_inference_description
+ iterations=n,
+ remote_inference_description=remote_inference_description,
+ remote_inference_results_visibility=remote_inference_results_visibility,
  )
  results = self.poll_remote_inference_job(remote_job_creation_data)
  if results is None:
@@ -874,7 +1042,7 @@

  cache = Cache()

- remote_cache = self.use_remote_cache()
+ remote_cache = self.use_remote_cache(disable_remote_cache)
  with RemoteCacheSync(
  coop=Coop(),
  cache=cache,
@@ -895,17 +1063,84 @@
  results.cache = cache.new_entries_cache()
  return results

+ async def create_and_poll_remote_job(
+ self,
+ iterations: int = 1,
+ remote_inference_description: Optional[str] = None,
+ remote_inference_results_visibility: Optional[
+ Literal["private", "public", "unlisted"]
+ ] = "unlisted",
+ ) -> Union[Results, None]:
+ """
+ Creates and polls a remote inference job asynchronously.
+ Reuses existing synchronous methods but runs them in an async context.
+
+ :param iterations: Number of times to run each interview
+ :param remote_inference_description: Optional description for the remote job
+ :param remote_inference_results_visibility: Visibility setting for results
+ :return: Results object if successful, None if job fails or is cancelled
+ """
+ import asyncio
+ from functools import partial
+
+ # Create job using existing method
+ loop = asyncio.get_event_loop()
+ remote_job_creation_data = await loop.run_in_executor(
+ None,
+ partial(
+ self.create_remote_inference_job,
+ iterations=iterations,
+ remote_inference_description=remote_inference_description,
+ remote_inference_results_visibility=remote_inference_results_visibility,
+ ),
+ )
+
+ # Poll using existing method but with async sleep
+ return await loop.run_in_executor(
+ None, partial(self.poll_remote_inference_job, remote_job_creation_data)
+ )
+
+ async def run_async(
+ self,
+ cache=None,
+ n=1,
+ disable_remote_inference: bool = False,
+ remote_inference_description: Optional[str] = None,
+ remote_inference_results_visibility: Optional[
+ Literal["private", "public", "unlisted"]
+ ] = "unlisted",
+ **kwargs,
+ ):
+ """Run the job asynchronously, either locally or remotely.
+
+ :param cache: Cache object or boolean
+ :param n: Number of iterations
+ :param disable_remote_inference: If True, forces local execution
+ :param remote_inference_description: Description for remote jobs
+ :param remote_inference_results_visibility: Visibility setting for remote results
+ :param kwargs: Additional arguments passed to local execution
+ :return: Results object
+ """
+ # Check if we should use remote inference
+ if remote_inference := self.use_remote_inference(disable_remote_inference):
+ results = await self.create_and_poll_remote_job(
+ iterations=n,
+ remote_inference_description=remote_inference_description,
+ remote_inference_results_visibility=remote_inference_results_visibility,
+ )
+ if results is None:
+ self._output("Job failed.")
+ return results
+
+ # If not using remote inference, run locally with async
+ return await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
+
  def _run_local(self, *args, **kwargs):
  """Run the job locally."""

  results = JobsRunnerAsyncio(self).run(*args, **kwargs)
  return results

- async def run_async(self, cache=None, n=1, **kwargs):
- """Run asynchronously."""
- results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
- return results
-
  def all_question_parameters(self):
  """Return all the fields in the questions in the survey.
  >>> from edsl.jobs import Jobs
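Editorial note, not part of the diff: a sketch of the reworked run_async, which now branches between the remote path (create_and_poll_remote_job) and the local JobsRunnerAsyncio path; here the flag forces local execution:

    import asyncio

    from edsl.jobs import Jobs

    async def main():
        job = Jobs.example()
        return await job.run_async(n=1, disable_remote_inference=True)

    results = asyncio.run(main())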
edsl/jobs/interviews/Interview.py CHANGED
@@ -110,9 +110,9 @@ class Interview:
  self.debug = debug
  self.iteration = iteration
  self.cache = cache
- self.answers: dict[str, str] = (
- Answers()
- ) # will get filled in as interview progresses
+ self.answers: dict[
+ str, str
+ ] = Answers() # will get filled in as interview progresses
  self.sidecar_model = sidecar_model

  # Trackers
@@ -143,9 +143,9 @@
  The keys are the question names; the values are the lists of status log changes for each task.
  """
  for task_creator in self.task_creators.values():
- self._task_status_log_dict[task_creator.question.question_name] = (
- task_creator.status_log
- )
+ self._task_status_log_dict[
+ task_creator.question.question_name
+ ] = task_creator.status_log
  return self._task_status_log_dict

  @property
@@ -178,7 +178,7 @@
  if include_exceptions:
  d["exceptions"] = self.exceptions.to_dict()
  return d
-
+
  @classmethod
  def from_dict(cls, d: dict[str, Any]) -> "Interview":
  """Return an Interview instance from a dictionary."""
@@ -187,13 +187,23 @@
  scenario = Scenario.from_dict(d["scenario"])
  model = LanguageModel.from_dict(d["model"])
  iteration = d["iteration"]
- return cls(agent=agent, survey=survey, scenario=scenario, model=model, iteration=iteration)
+ interview = cls(
+ agent=agent,
+ survey=survey,
+ scenario=scenario,
+ model=model,
+ iteration=iteration,
+ )
+ if "exceptions" in d:
+ exceptions = InterviewExceptionCollection.from_dict(d["exceptions"])
+ interview.exceptions = exceptions
+ return interview

  def __hash__(self) -> int:
  from edsl.utilities.utilities import dict_hash

  return dict_hash(self._to_dict(include_exceptions=False))
-
+
  def __eq__(self, other: "Interview") -> bool:
  """
  >>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i._to_dict(); i2 = Interview.from_dict(d); i == i2
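Editorial note, not part of the diff: a minimal round-trip sketch of the from_dict change; exceptions serialized by _to_dict(include_exceptions=True) are now restored instead of being dropped:

    from edsl.jobs.interviews.Interview import Interview

    i = Interview.example()
    d = i._to_dict(include_exceptions=True)
    i2 = Interview.from_dict(d)
    assert i2.exceptions == i.exceptions  # previously lost on deserialization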
@@ -476,11 +486,11 @@
  """
  current_question_index: int = self.to_index[current_question.question_name]

- next_question: Union[int, EndOfSurvey] = (
- self.survey.rule_collection.next_question(
- q_now=current_question_index,
- answers=self.answers | self.scenario | self.agent["traits"],
- )
+ next_question: Union[
+ int, EndOfSurvey
+ ] = self.survey.rule_collection.next_question(
+ q_now=current_question_index,
+ answers=self.answers | self.scenario | self.agent["traits"],
  )

  next_question_index = next_question.next_q
edsl/jobs/interviews/InterviewExceptionCollection.py CHANGED
@@ -33,7 +33,7 @@ class InterviewExceptionCollection(UserDict):
  """Return the collection of exceptions as a dictionary."""
  newdata = {k: [e.to_dict() for e in v] for k, v in self.data.items()}
  return newdata
-
+
  @classmethod
  def from_dict(cls, data: dict) -> "InterviewExceptionCollection":
  """Create an InterviewExceptionCollection from a dictionary."""