edsl 0.1.37__py3-none-any.whl → 0.1.37.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Agent.py +35 -86
  3. edsl/agents/AgentList.py +0 -5
  4. edsl/agents/InvigilatorBase.py +23 -2
  5. edsl/agents/PromptConstructor.py +105 -148
  6. edsl/agents/descriptors.py +4 -17
  7. edsl/conjure/AgentConstructionMixin.py +3 -11
  8. edsl/conversation/Conversation.py +14 -66
  9. edsl/coop/coop.py +14 -148
  10. edsl/data/Cache.py +1 -1
  11. edsl/exceptions/__init__.py +3 -7
  12. edsl/exceptions/agents.py +19 -17
  13. edsl/exceptions/results.py +8 -11
  14. edsl/exceptions/surveys.py +10 -13
  15. edsl/inference_services/AwsBedrock.py +2 -7
  16. edsl/inference_services/InferenceServicesCollection.py +9 -32
  17. edsl/jobs/Jobs.py +71 -306
  18. edsl/jobs/interviews/InterviewExceptionEntry.py +1 -5
  19. edsl/jobs/tasks/TaskHistory.py +0 -1
  20. edsl/language_models/LanguageModel.py +59 -47
  21. edsl/language_models/__init__.py +0 -1
  22. edsl/prompts/Prompt.py +4 -11
  23. edsl/questions/QuestionBase.py +13 -53
  24. edsl/questions/QuestionBasePromptsMixin.py +33 -1
  25. edsl/questions/QuestionFreeText.py +0 -1
  26. edsl/questions/QuestionFunctional.py +2 -2
  27. edsl/questions/descriptors.py +28 -23
  28. edsl/results/DatasetExportMixin.py +1 -25
  29. edsl/results/Result.py +1 -16
  30. edsl/results/Results.py +120 -31
  31. edsl/results/ResultsDBMixin.py +1 -1
  32. edsl/results/Selector.py +1 -18
  33. edsl/scenarios/Scenario.py +12 -48
  34. edsl/scenarios/ScenarioHtmlMixin.py +2 -7
  35. edsl/scenarios/ScenarioList.py +1 -12
  36. edsl/surveys/Rule.py +4 -10
  37. edsl/surveys/Survey.py +77 -100
  38. edsl/utilities/utilities.py +0 -18
  39. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/METADATA +1 -1
  40. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/RECORD +42 -46
  41. edsl/conversation/chips.py +0 -95
  42. edsl/exceptions/BaseException.py +0 -21
  43. edsl/exceptions/scenarios.py +0 -22
  44. edsl/language_models/KeyLookup.py +0 -30
  45. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/LICENSE +0 -0
  46. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/WHEEL +0 -0
@@ -16,48 +16,25 @@ class InferenceServicesCollection:
16
16
 
17
17
  @staticmethod
18
18
  def _get_service_available(service, warn: bool = False) -> list[str]:
19
+ from_api = True
19
20
  try:
20
21
  service_models = service.available()
21
- except Exception:
22
+ except Exception as e:
22
23
  if warn:
23
24
  warnings.warn(
24
25
  f"""Error getting models for {service._inference_service_}.
25
26
  Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
26
27
  See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
27
- Relying on Coop.""",
28
+ Relying on cache.""",
28
29
  UserWarning,
29
30
  )
31
+ from edsl.inference_services.models_available_cache import models_available
30
32
 
31
- # Use the list of models on Coop as a fallback
32
- try:
33
- from edsl import Coop
34
-
35
- c = Coop()
36
- models_from_coop = c.fetch_models()
37
- service_models = models_from_coop.get(service._inference_service_, [])
38
-
39
- # cache results
40
- service._models_list_cache = service_models
41
-
42
- # Finally, use the available models cache from the Python file
43
- except Exception:
44
- if warn:
45
- warnings.warn(
46
- f"""Error getting models for {service._inference_service_}.
47
- Relying on EDSL cache.""",
48
- UserWarning,
49
- )
50
-
51
- from edsl.inference_services.models_available_cache import (
52
- models_available,
53
- )
54
-
55
- service_models = models_available.get(service._inference_service_, [])
56
-
57
- # cache results
58
- service._models_list_cache = service_models
59
-
60
- return service_models
33
+ service_models = models_available.get(service._inference_service_, [])
34
+ # cache results
35
+ service._models_list_cache = service_models
36
+ from_api = False
37
+ return service_models # , from_api
61
38
 
62
39
  def available(self):
63
40
  total_models = []
edsl/jobs/Jobs.py CHANGED
@@ -3,10 +3,9 @@ from __future__ import annotations
3
3
  import warnings
4
4
  import requests
5
5
  from itertools import product
6
- from typing import Literal, Optional, Union, Sequence, Generator
6
+ from typing import Optional, Union, Sequence, Generator
7
7
 
8
8
  from edsl.Base import Base
9
-
10
9
  from edsl.exceptions import MissingAPIKeyError
11
10
  from edsl.jobs.buckets.BucketCollection import BucketCollection
12
11
  from edsl.jobs.interviews.Interview import Interview
@@ -194,7 +193,7 @@ class Jobs(Base):
194
193
  inference_service=invigilator.model._inference_service_,
195
194
  model=invigilator.model.model,
196
195
  )
197
- costs.append(prompt_cost["cost_usd"])
196
+ costs.append(prompt_cost["cost"])
198
197
 
199
198
  d = Dataset(
200
199
  [
@@ -210,14 +209,14 @@ class Jobs(Base):
210
209
  )
211
210
  return d
212
211
 
213
- def show_prompts(self, all=False, max_rows: Optional[int] = None) -> None:
212
+ def show_prompts(self, all=False) -> None:
214
213
  """Print the prompts."""
215
214
  if all:
216
- self.prompts().to_scenario_list().print(format="rich", max_rows=max_rows)
215
+ self.prompts().to_scenario_list().print(format="rich")
217
216
  else:
218
217
  self.prompts().select(
219
218
  "user_prompt", "system_prompt"
220
- ).to_scenario_list().print(format="rich", max_rows=max_rows)
219
+ ).to_scenario_list().print(format="rich")
221
220
 
222
221
  @staticmethod
223
222
  def estimate_prompt_cost(
@@ -228,7 +227,6 @@ class Jobs(Base):
228
227
  model: str,
229
228
  ) -> dict:
230
229
  """Estimates the cost of a prompt. Takes piping into account."""
231
- import math
232
230
 
233
231
  def get_piping_multiplier(prompt: str):
234
232
  """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
@@ -242,25 +240,10 @@ class Jobs(Base):
242
240
 
243
241
  try:
244
242
  relevant_prices = price_lookup[key]
245
-
246
- service_input_token_price = float(
247
- relevant_prices["input"]["service_stated_token_price"]
248
- )
249
- service_input_token_qty = float(
250
- relevant_prices["input"]["service_stated_token_qty"]
251
- )
252
- input_price_per_token = service_input_token_price / service_input_token_qty
253
-
254
- service_output_token_price = float(
255
- relevant_prices["output"]["service_stated_token_price"]
256
- )
257
- service_output_token_qty = float(
258
- relevant_prices["output"]["service_stated_token_qty"]
259
- )
260
- output_price_per_token = (
261
- service_output_token_price / service_output_token_qty
243
+ output_price_per_token = 1 / float(
244
+ relevant_prices["output"]["one_usd_buys"]
262
245
  )
263
-
246
+ input_price_per_token = 1 / float(relevant_prices["input"]["one_usd_buys"])
264
247
  except KeyError:
265
248
  # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
266
249
  # Use a sensible default
@@ -270,8 +253,9 @@ class Jobs(Base):
270
253
  warnings.warn(
271
254
  "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
272
255
  )
273
- input_price_per_token = 0.00000015 # $0.15 / 1M tokens
274
- output_price_per_token = 0.00000060 # $0.60 / 1M tokens
256
+
257
+ output_price_per_token = 0.00000015 # $0.15 / 1M tokens
258
+ input_price_per_token = 0.00000060 # $0.60 / 1M tokens
275
259
 
276
260
  # Compute the number of characters (double if the question involves piping)
277
261
  user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
@@ -283,8 +267,7 @@ class Jobs(Base):
283
267
 
284
268
  # Convert into tokens (1 token approx. equals 4 characters)
285
269
  input_tokens = (user_prompt_chars + system_prompt_chars) // 4
286
-
287
- output_tokens = math.ceil(0.75 * input_tokens)
270
+ output_tokens = input_tokens
288
271
 
289
272
  cost = (
290
273
  input_tokens * input_price_per_token
@@ -294,17 +277,15 @@ class Jobs(Base):
294
277
  return {
295
278
  "input_tokens": input_tokens,
296
279
  "output_tokens": output_tokens,
297
- "cost_usd": cost,
280
+ "cost": cost,
298
281
  }
299
282
 
300
- def estimate_job_cost_from_external_prices(
301
- self, price_lookup: dict, iterations: int = 1
302
- ) -> dict:
283
+ def estimate_job_cost_from_external_prices(self, price_lookup: dict) -> dict:
303
284
  """
304
285
  Estimates the cost of a job according to the following assumptions:
305
286
 
306
287
  - 1 token = 4 characters.
307
- - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
288
+ - Input tokens = output tokens.
308
289
 
309
290
  price_lookup is an external pricing dictionary.
310
291
  """
@@ -341,7 +322,7 @@ class Jobs(Base):
341
322
  "system_prompt": system_prompt,
342
323
  "estimated_input_tokens": prompt_cost["input_tokens"],
343
324
  "estimated_output_tokens": prompt_cost["output_tokens"],
344
- "estimated_cost_usd": prompt_cost["cost_usd"],
325
+ "estimated_cost": prompt_cost["cost"],
345
326
  "inference_service": inference_service,
346
327
  "model": model,
347
328
  }
@@ -353,21 +334,18 @@ class Jobs(Base):
353
334
  df.groupby(["inference_service", "model"])
354
335
  .agg(
355
336
  {
356
- "estimated_cost_usd": "sum",
337
+ "estimated_cost": "sum",
357
338
  "estimated_input_tokens": "sum",
358
339
  "estimated_output_tokens": "sum",
359
340
  }
360
341
  )
361
342
  .reset_index()
362
343
  )
363
- df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
364
- df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
365
- df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations
366
344
 
367
345
  estimated_costs_by_model = df.to_dict("records")
368
346
 
369
347
  estimated_total_cost = sum(
370
- model["estimated_cost_usd"] for model in estimated_costs_by_model
348
+ model["estimated_cost"] for model in estimated_costs_by_model
371
349
  )
372
350
  estimated_total_input_tokens = sum(
373
351
  model["estimated_input_tokens"] for model in estimated_costs_by_model
@@ -377,7 +355,7 @@ class Jobs(Base):
377
355
  )
378
356
 
379
357
  output = {
380
- "estimated_total_cost_usd": estimated_total_cost,
358
+ "estimated_total_cost": estimated_total_cost,
381
359
  "estimated_total_input_tokens": estimated_total_input_tokens,
382
360
  "estimated_total_output_tokens": estimated_total_output_tokens,
383
361
  "model_costs": estimated_costs_by_model,
@@ -385,12 +363,12 @@ class Jobs(Base):
385
363
 
386
364
  return output
387
365
 
388
- def estimate_job_cost(self, iterations: int = 1) -> dict:
366
+ def estimate_job_cost(self) -> dict:
389
367
  """
390
368
  Estimates the cost of a job according to the following assumptions:
391
369
 
392
370
  - 1 token = 4 characters.
393
- - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
371
+ - Input tokens = output tokens.
394
372
 
395
373
  Fetches prices from Coop.
396
374
  """
@@ -399,9 +377,7 @@ class Jobs(Base):
399
377
  c = Coop()
400
378
  price_lookup = c.fetch_prices()
401
379
 
402
- return self.estimate_job_cost_from_external_prices(
403
- price_lookup=price_lookup, iterations=iterations
404
- )
380
+ return self.estimate_job_cost_from_external_prices(price_lookup=price_lookup)
405
381
 
406
382
  @staticmethod
407
383
  def compute_job_cost(job_results: "Results") -> float:
@@ -723,11 +699,7 @@ class Jobs(Base):
723
699
  return self._raise_validation_errors
724
700
 
725
701
  def create_remote_inference_job(
726
- self,
727
- iterations: int = 1,
728
- remote_inference_description: Optional[str] = None,
729
- remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
730
- verbose=False,
702
+ self, iterations: int = 1, remote_inference_description: Optional[str] = None
731
703
  ):
732
704
  """ """
733
705
  from edsl.coop.coop import Coop
@@ -739,11 +711,9 @@ class Jobs(Base):
739
711
  description=remote_inference_description,
740
712
  status="queued",
741
713
  iterations=iterations,
742
- initial_results_visibility=remote_inference_results_visibility,
743
714
  )
744
715
  job_uuid = remote_job_creation_data.get("uuid")
745
- if self.verbose:
746
- print(f"Job sent to server. (Job uuid={job_uuid}).")
716
+ print(f"Job sent to server. (Job uuid={job_uuid}).")
747
717
  return remote_job_creation_data
748
718
 
749
719
  @staticmethod
@@ -754,7 +724,7 @@ class Jobs(Base):
754
724
  return coop.remote_inference_get(job_uuid)
755
725
 
756
726
  def poll_remote_inference_job(
757
- self, remote_job_creation_data: dict, verbose=False, poll_interval=5
727
+ self, remote_job_creation_data: dict
758
728
  ) -> Union[Results, None]:
759
729
  from edsl.coop.coop import Coop
760
730
  import time
@@ -771,46 +741,42 @@ class Jobs(Base):
771
741
  remote_job_data = coop.remote_inference_get(job_uuid)
772
742
  status = remote_job_data.get("status")
773
743
  if status == "cancelled":
774
- if self.verbose:
775
- print("\r" + " " * 80 + "\r", end="")
776
- print("Job cancelled by the user.")
777
- print(
778
- f"See {expected_parrot_url}/home/remote-inference for more details."
779
- )
744
+ print("\r" + " " * 80 + "\r", end="")
745
+ print("Job cancelled by the user.")
746
+ print(
747
+ f"See {expected_parrot_url}/home/remote-inference for more details."
748
+ )
780
749
  return None
781
750
  elif status == "failed":
782
- if self.verbose:
783
- print("\r" + " " * 80 + "\r", end="")
784
- print("Job failed.")
785
- print(
786
- f"See {expected_parrot_url}/home/remote-inference for more details."
787
- )
751
+ print("\r" + " " * 80 + "\r", end="")
752
+ print("Job failed.")
753
+ print(
754
+ f"See {expected_parrot_url}/home/remote-inference for more details."
755
+ )
788
756
  return None
789
757
  elif status == "completed":
790
758
  results_uuid = remote_job_data.get("results_uuid")
791
759
  results = coop.get(results_uuid, expected_object_type="results")
792
- if self.verbose:
793
- print("\r" + " " * 80 + "\r", end="")
794
- url = f"{expected_parrot_url}/content/{results_uuid}"
795
- print(f"Job completed and Results stored on Coop: {url}.")
760
+ print("\r" + " " * 80 + "\r", end="")
761
+ url = f"{expected_parrot_url}/content/{results_uuid}"
762
+ print(f"Job completed and Results stored on Coop: {url}.")
796
763
  return results
797
764
  else:
798
- duration = poll_interval
765
+ duration = 5
799
766
  time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
800
767
  frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
801
768
  start_time = time.time()
802
769
  i = 0
803
770
  while time.time() - start_time < duration:
804
- if self.verbose:
805
- print(
806
- f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
807
- end="",
808
- flush=True,
809
- )
771
+ print(
772
+ f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
773
+ end="",
774
+ flush=True,
775
+ )
810
776
  time.sleep(0.1)
811
777
  i += 1
812
778
 
813
- def use_remote_inference(self, disable_remote_inference: bool) -> bool:
779
+ def use_remote_inference(self, disable_remote_inference: bool):
814
780
  if disable_remote_inference:
815
781
  return False
816
782
  if not disable_remote_inference:
@@ -826,23 +792,20 @@ class Jobs(Base):
826
792
 
827
793
  return False
828
794
 
829
- def use_remote_cache(self, disable_remote_cache: bool) -> bool:
830
- if disable_remote_cache:
831
- return False
832
- if not disable_remote_cache:
833
- try:
834
- from edsl import Coop
795
+ def use_remote_cache(self):
796
+ try:
797
+ from edsl import Coop
835
798
 
836
- user_edsl_settings = Coop().edsl_settings
837
- return user_edsl_settings.get("remote_caching", False)
838
- except requests.ConnectionError:
839
- pass
840
- except CoopServerResponseError as e:
841
- pass
799
+ user_edsl_settings = Coop().edsl_settings
800
+ return user_edsl_settings.get("remote_caching", False)
801
+ except requests.ConnectionError:
802
+ pass
803
+ except CoopServerResponseError as e:
804
+ pass
842
805
 
843
806
  return False
844
807
 
845
- def check_api_keys(self) -> None:
808
+ def check_api_keys(self):
846
809
  from edsl import Model
847
810
 
848
811
  for model in self.models + [Model()]:
@@ -852,86 +815,6 @@ class Jobs(Base):
852
815
  inference_service=model._inference_service_,
853
816
  )
854
817
 
855
- def get_missing_api_keys(self) -> set:
856
- """
857
- Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
858
- """
859
-
860
- missing_api_keys = set()
861
-
862
- from edsl import Model
863
- from edsl.enums import service_to_api_keyname
864
-
865
- for model in self.models + [Model()]:
866
- if not model.has_valid_api_key():
867
- key_name = service_to_api_keyname.get(
868
- model._inference_service_, "NOT FOUND"
869
- )
870
- missing_api_keys.add(key_name)
871
-
872
- return missing_api_keys
873
-
874
- def user_has_all_model_keys(self):
875
- """
876
- Returns True if the user has all model keys required to run their job.
877
-
878
- Otherwise, returns False.
879
- """
880
-
881
- try:
882
- self.check_api_keys()
883
- return True
884
- except MissingAPIKeyError:
885
- return False
886
- except Exception:
887
- raise
888
-
889
- def user_has_ep_api_key(self) -> bool:
890
- """
891
- Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
892
-
893
- Otherwise, returns False.
894
- """
895
-
896
- import os
897
-
898
- coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
899
-
900
- if coop_api_key is not None:
901
- return True
902
- else:
903
- return False
904
-
905
- def needs_external_llms(self) -> bool:
906
- """
907
- Returns True if the job needs external LLMs to run.
908
-
909
- Otherwise, returns False.
910
- """
911
- # These cases are necessary to skip the API key check during doctests
912
-
913
- # Accounts for Results.example()
914
- all_agents_answer_questions_directly = len(self.agents) > 0 and all(
915
- [hasattr(a, "answer_question_directly") for a in self.agents]
916
- )
917
-
918
- # Accounts for InterviewExceptionEntry.example()
919
- only_model_is_test = set([m.model for m in self.models]) == set(["test"])
920
-
921
- # Accounts for Survey.__call__
922
- all_questions_are_functional = set(
923
- [q.question_type for q in self.survey.questions]
924
- ) == set(["functional"])
925
-
926
- if (
927
- all_agents_answer_questions_directly
928
- or only_model_is_test
929
- or all_questions_are_functional
930
- ):
931
- return False
932
- else:
933
- return True
934
-
935
818
  def run(
936
819
  self,
937
820
  n: int = 1,
@@ -944,28 +827,22 @@ class Jobs(Base):
944
827
  print_exceptions=True,
945
828
  remote_cache_description: Optional[str] = None,
946
829
  remote_inference_description: Optional[str] = None,
947
- remote_inference_results_visibility: Optional[
948
- Literal["private", "public", "unlisted"]
949
- ] = "unlisted",
950
830
  skip_retry: bool = False,
951
831
  raise_validation_errors: bool = False,
952
- disable_remote_cache: bool = False,
953
832
  disable_remote_inference: bool = False,
954
833
  ) -> Results:
955
834
  """
956
835
  Runs the Job: conducts Interviews and returns their results.
957
836
 
958
- :param n: How many times to run each interview
959
- :param progress_bar: Whether to show a progress bar
960
- :param stop_on_exception: Stops the job if an exception is raised
961
- :param cache: A Cache object to store results
962
- :param check_api_keys: Raises an error if API keys are invalid
963
- :param verbose: Prints extra messages
964
- :param remote_cache_description: Specifies a description for this group of entries in the remote cache
965
- :param remote_inference_description: Specifies a description for the remote inference job
966
- :param remote_inference_results_visibility: The initial visibility of the Results object on Coop. This will only be used for remote jobs!
967
- :param disable_remote_cache: If True, the job will not use remote cache. This only works for local jobs!
968
- :param disable_remote_inference: If True, the job will not use remote inference
837
+ :param n: how many times to run each interview
838
+ :param progress_bar: shows a progress bar
839
+ :param stop_on_exception: stops the job if an exception is raised
840
+ :param cache: a cache object to store results
841
+ :param check_api_keys: check if the API keys are valid
842
+ :param batch_mode: run the job in batch mode i.e., no expectation of interaction with the user
843
+ :param verbose: prints messages
844
+ :param remote_cache_description: specifies a description for this group of entries in the remote cache
845
+ :param remote_inference_description: specifies a description for the remote inference job
969
846
  """
970
847
  from edsl.coop.coop import Coop
971
848
 
@@ -975,54 +852,9 @@ class Jobs(Base):
975
852
 
976
853
  self.verbose = verbose
977
854
 
978
- if (
979
- not self.user_has_all_model_keys()
980
- and not self.user_has_ep_api_key()
981
- and self.needs_external_llms()
982
- ):
983
- import secrets
984
- from dotenv import load_dotenv
985
- from edsl import CONFIG
986
- from edsl.coop.coop import Coop
987
- from edsl.utilities.utilities import write_api_key_to_env
988
-
989
- missing_api_keys = self.get_missing_api_keys()
990
-
991
- edsl_auth_token = secrets.token_urlsafe(16)
992
-
993
- print("You're missing some of the API keys needed to run this job:")
994
- for api_key in missing_api_keys:
995
- print(f" 🔑 {api_key}")
996
- print(
997
- "\nYou can either add the missing keys to your .env file, or use remote inference."
998
- )
999
- print("Remote inference allows you to run jobs on our server.")
1000
- print("\n🚀 To use remote inference, sign up at the following link:")
1001
-
1002
- coop = Coop()
1003
- coop._display_login_url(edsl_auth_token=edsl_auth_token)
1004
-
1005
- print(
1006
- "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
1007
- )
1008
-
1009
- api_key = coop._poll_for_api_key(edsl_auth_token)
1010
-
1011
- if api_key is None:
1012
- print("\nTimed out waiting for login. Please try again.")
1013
- return
1014
-
1015
- write_api_key_to_env(api_key)
1016
- print("✨ API key retrieved and written to .env file.\n")
1017
-
1018
- # Retrieve API key so we can continue running the job
1019
- load_dotenv()
1020
-
1021
855
  if remote_inference := self.use_remote_inference(disable_remote_inference):
1022
856
  remote_job_creation_data = self.create_remote_inference_job(
1023
- iterations=n,
1024
- remote_inference_description=remote_inference_description,
1025
- remote_inference_results_visibility=remote_inference_results_visibility,
857
+ iterations=n, remote_inference_description=remote_inference_description
1026
858
  )
1027
859
  results = self.poll_remote_inference_job(remote_job_creation_data)
1028
860
  if results is None:
@@ -1042,7 +874,7 @@ class Jobs(Base):
1042
874
 
1043
875
  cache = Cache()
1044
876
 
1045
- remote_cache = self.use_remote_cache(disable_remote_cache)
877
+ remote_cache = self.use_remote_cache()
1046
878
  with RemoteCacheSync(
1047
879
  coop=Coop(),
1048
880
  cache=cache,
@@ -1063,84 +895,17 @@ class Jobs(Base):
1063
895
  results.cache = cache.new_entries_cache()
1064
896
  return results
1065
897
 
1066
- async def create_and_poll_remote_job(
1067
- self,
1068
- iterations: int = 1,
1069
- remote_inference_description: Optional[str] = None,
1070
- remote_inference_results_visibility: Optional[
1071
- Literal["private", "public", "unlisted"]
1072
- ] = "unlisted",
1073
- ) -> Union[Results, None]:
1074
- """
1075
- Creates and polls a remote inference job asynchronously.
1076
- Reuses existing synchronous methods but runs them in an async context.
1077
-
1078
- :param iterations: Number of times to run each interview
1079
- :param remote_inference_description: Optional description for the remote job
1080
- :param remote_inference_results_visibility: Visibility setting for results
1081
- :return: Results object if successful, None if job fails or is cancelled
1082
- """
1083
- import asyncio
1084
- from functools import partial
1085
-
1086
- # Create job using existing method
1087
- loop = asyncio.get_event_loop()
1088
- remote_job_creation_data = await loop.run_in_executor(
1089
- None,
1090
- partial(
1091
- self.create_remote_inference_job,
1092
- iterations=iterations,
1093
- remote_inference_description=remote_inference_description,
1094
- remote_inference_results_visibility=remote_inference_results_visibility,
1095
- ),
1096
- )
1097
-
1098
- # Poll using existing method but with async sleep
1099
- return await loop.run_in_executor(
1100
- None, partial(self.poll_remote_inference_job, remote_job_creation_data)
1101
- )
1102
-
1103
- async def run_async(
1104
- self,
1105
- cache=None,
1106
- n=1,
1107
- disable_remote_inference: bool = False,
1108
- remote_inference_description: Optional[str] = None,
1109
- remote_inference_results_visibility: Optional[
1110
- Literal["private", "public", "unlisted"]
1111
- ] = "unlisted",
1112
- **kwargs,
1113
- ):
1114
- """Run the job asynchronously, either locally or remotely.
1115
-
1116
- :param cache: Cache object or boolean
1117
- :param n: Number of iterations
1118
- :param disable_remote_inference: If True, forces local execution
1119
- :param remote_inference_description: Description for remote jobs
1120
- :param remote_inference_results_visibility: Visibility setting for remote results
1121
- :param kwargs: Additional arguments passed to local execution
1122
- :return: Results object
1123
- """
1124
- # Check if we should use remote inference
1125
- if remote_inference := self.use_remote_inference(disable_remote_inference):
1126
- results = await self.create_and_poll_remote_job(
1127
- iterations=n,
1128
- remote_inference_description=remote_inference_description,
1129
- remote_inference_results_visibility=remote_inference_results_visibility,
1130
- )
1131
- if results is None:
1132
- self._output("Job failed.")
1133
- return results
1134
-
1135
- # If not using remote inference, run locally with async
1136
- return await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
1137
-
1138
898
  def _run_local(self, *args, **kwargs):
1139
899
  """Run the job locally."""
1140
900
 
1141
901
  results = JobsRunnerAsyncio(self).run(*args, **kwargs)
1142
902
  return results
1143
903
 
904
+ async def run_async(self, cache=None, n=1, **kwargs):
905
+ """Run asynchronously."""
906
+ results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
907
+ return results
908
+
1144
909
  def all_question_parameters(self):
1145
910
  """Return all the fields in the questions in the survey.
1146
911
  >>> from edsl.jobs import Jobs
@@ -67,11 +67,7 @@ class InterviewExceptionEntry:
67
67
  m = LanguageModel.example(test_model=True)
68
68
  q = QuestionFreeText.example(exception_to_throw=ValueError)
69
69
  results = q.by(m).run(
70
- skip_retry=True,
71
- print_exceptions=False,
72
- raise_validation_errors=True,
73
- disable_remote_cache=True,
74
- disable_remote_inference=True,
70
+ skip_retry=True, print_exceptions=False, raise_validation_errors=True
75
71
  )
76
72
  return results.task_history.exceptions[0]["how_are_you"][0]
77
73
 
@@ -39,7 +39,6 @@ class TaskHistory:
39
39
  skip_retry=True,
40
40
  cache=False,
41
41
  raise_validation_errors=True,
42
- disable_remote_cache=True,
43
42
  disable_remote_inference=True,
44
43
  )
45
44