edsl 0.1.38__py3-none-any.whl → 0.1.38.dev2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (53)
  1. edsl/Base.py +31 -60
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Agent.py +9 -18
  4. edsl/agents/AgentList.py +8 -59
  5. edsl/agents/Invigilator.py +7 -18
  6. edsl/agents/InvigilatorBase.py +19 -0
  7. edsl/agents/PromptConstructor.py +4 -5
  8. edsl/config.py +0 -8
  9. edsl/coop/coop.py +7 -74
  10. edsl/data/Cache.py +2 -27
  11. edsl/data/CacheEntry.py +3 -8
  12. edsl/data/RemoteCacheSync.py +19 -0
  13. edsl/enums.py +0 -2
  14. edsl/inference_services/GoogleService.py +15 -7
  15. edsl/inference_services/registry.py +0 -2
  16. edsl/jobs/Jobs.py +548 -88
  17. edsl/jobs/interviews/Interview.py +11 -11
  18. edsl/jobs/runners/JobsRunnerAsyncio.py +35 -140
  19. edsl/jobs/runners/JobsRunnerStatus.py +2 -0
  20. edsl/jobs/tasks/TaskHistory.py +16 -15
  21. edsl/language_models/LanguageModel.py +84 -44
  22. edsl/language_models/ModelList.py +1 -47
  23. edsl/language_models/registry.py +4 -57
  24. edsl/prompts/Prompt.py +3 -8
  25. edsl/questions/QuestionBase.py +16 -20
  26. edsl/questions/QuestionExtract.py +4 -3
  27. edsl/questions/question_registry.py +6 -36
  28. edsl/results/Dataset.py +15 -146
  29. edsl/results/DatasetExportMixin.py +217 -231
  30. edsl/results/DatasetTree.py +4 -134
  31. edsl/results/Result.py +9 -18
  32. edsl/results/Results.py +51 -145
  33. edsl/scenarios/FileStore.py +13 -187
  34. edsl/scenarios/Scenario.py +4 -61
  35. edsl/scenarios/ScenarioList.py +62 -237
  36. edsl/surveys/Survey.py +2 -16
  37. edsl/surveys/SurveyFlowVisualizationMixin.py +9 -67
  38. edsl/surveys/instructions/Instruction.py +0 -12
  39. edsl/templates/error_reporting/interview_details.html +3 -3
  40. edsl/templates/error_reporting/interviews.html +9 -18
  41. edsl/utilities/utilities.py +0 -15
  42. {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/METADATA +1 -2
  43. {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/RECORD +45 -53
  44. edsl/inference_services/PerplexityService.py +0 -163
  45. edsl/jobs/JobsChecks.py +0 -147
  46. edsl/jobs/JobsPrompts.py +0 -268
  47. edsl/jobs/JobsRemoteInferenceHandler.py +0 -239
  48. edsl/results/CSSParameterizer.py +0 -108
  49. edsl/results/TableDisplay.py +0 -198
  50. edsl/results/table_display.css +0 -78
  51. edsl/scenarios/ScenarioJoin.py +0 -127
  52. {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/LICENSE +0 -0
  53. {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import warnings
 import requests
 from itertools import product
-from typing import Literal, Optional, Union, Sequence, Generator, TYPE_CHECKING
+from typing import Literal, Optional, Union, Sequence, Generator
 
 from edsl.Base import Base
 
@@ -11,20 +11,11 @@ from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
-from edsl.utilities.decorators import remove_edsl_version
+from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
 
 from edsl.data.RemoteCacheSync import RemoteCacheSync
 from edsl.exceptions.coop import CoopServerResponseError
 
-if TYPE_CHECKING:
-    from edsl.agents.Agent import Agent
-    from edsl.agents.AgentList import AgentList
-    from edsl.language_models.LanguageModel import LanguageModel
-    from edsl.scenarios.Scenario import Scenario
-    from edsl.surveys.Survey import Survey
-    from edsl.results.Results import Results
-    from edsl.results.Dataset import Dataset
-
 
 class Jobs(Base):
     """
@@ -33,8 +24,6 @@ class Jobs(Base):
     The `JobsRunner` is chosen by the user, and is stored in the `jobs_runner_name` attribute.
     """
 
-    __documentation__ = "https://docs.expectedparrot.com/en/latest/jobs.html"
-
     def __init__(
         self,
         survey: "Survey",
@@ -97,14 +86,8 @@ class Jobs(Base):
     @scenarios.setter
     def scenarios(self, value):
         from edsl import ScenarioList
-        from edsl.results.Dataset import Dataset
 
         if value:
-            if isinstance(
-                value, Dataset
-            ):  # if the user passes in a Dataset, convert it to a ScenarioList
-                value = value.to_scenario_list()
-
             if not isinstance(value, ScenarioList):
                 self._scenarios = ScenarioList(value)
             else:
@@ -144,13 +127,6 @@ class Jobs(Base):
        - scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
        - models: new models overwrite old models.
        """
-        from edsl.results.Dataset import Dataset
-
-        if isinstance(
-            args[0], Dataset
-        ):  # let the user use a Dataset as if it were a ScenarioList
-            args = args[0].to_scenario_list()
-
        passed_objects = self._turn_args_to_list(
            args
        )  # objects can also be passed comma-separated
@@ -175,19 +151,73 @@ class Jobs(Base):
         >>> Jobs.example().prompts()
         Dataset(...)
         """
-        from edsl.jobs.JobsPrompts import JobsPrompts
+        from edsl import Coop
+
+        c = Coop()
+        price_lookup = c.fetch_prices()
+
+        interviews = self.interviews()
+        # data = []
+        interview_indices = []
+        question_names = []
+        user_prompts = []
+        system_prompts = []
+        scenario_indices = []
+        agent_indices = []
+        models = []
+        costs = []
+        from edsl.results.Dataset import Dataset
+
+        for interview_index, interview in enumerate(interviews):
+            invigilators = [
+                interview._get_invigilator(question)
+                for question in self.survey.questions
+            ]
+            for _, invigilator in enumerate(invigilators):
+                prompts = invigilator.get_prompts()
+                user_prompt = prompts["user_prompt"]
+                system_prompt = prompts["system_prompt"]
+                user_prompts.append(user_prompt)
+                system_prompts.append(system_prompt)
+                agent_index = self.agents.index(invigilator.agent)
+                agent_indices.append(agent_index)
+                interview_indices.append(interview_index)
+                scenario_index = self.scenarios.index(invigilator.scenario)
+                scenario_indices.append(scenario_index)
+                models.append(invigilator.model.model)
+                question_names.append(invigilator.question.question_name)
+
+                prompt_cost = self.estimate_prompt_cost(
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    price_lookup=price_lookup,
+                    inference_service=invigilator.model._inference_service_,
+                    model=invigilator.model.model,
+                )
+                costs.append(prompt_cost["cost_usd"])
 
-        j = JobsPrompts(self)
-        return j.prompts()
+        d = Dataset(
+            [
+                {"user_prompt": user_prompts},
+                {"system_prompt": system_prompts},
+                {"interview_index": interview_indices},
+                {"question_name": question_names},
+                {"scenario_index": scenario_indices},
+                {"agent_index": agent_indices},
+                {"model": models},
+                {"estimated_cost": costs},
+            ]
+        )
+        return d
 
-    def show_prompts(self, all=False) -> None:
+    def show_prompts(self, all=False, max_rows: Optional[int] = None) -> None:
         """Print the prompts."""
         if all:
-            return self.prompts().to_scenario_list().table()
+            self.prompts().to_scenario_list().print(format="rich", max_rows=max_rows)
         else:
-            return (
-                self.prompts().to_scenario_list().table("user_prompt", "system_prompt")
-            )
+            self.prompts().select(
+                "user_prompt", "system_prompt"
+            ).to_scenario_list().print(format="rich", max_rows=max_rows)
 
     @staticmethod
     def estimate_prompt_cost(
@@ -196,42 +226,201 @@ class Jobs(Base):
         price_lookup: dict,
         inference_service: str,
         model: str,
+    ) -> dict:
+        """Estimates the cost of a prompt. Takes piping into account."""
+        import math
+
+        def get_piping_multiplier(prompt: str):
+            """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
+
+            if "{{" in prompt and "}}" in prompt:
+                return 2
+            return 1
+
+        # Look up prices per token
+        key = (inference_service, model)
+
+        try:
+            relevant_prices = price_lookup[key]
+
+            service_input_token_price = float(
+                relevant_prices["input"]["service_stated_token_price"]
+            )
+            service_input_token_qty = float(
+                relevant_prices["input"]["service_stated_token_qty"]
+            )
+            input_price_per_token = service_input_token_price / service_input_token_qty
+
+            service_output_token_price = float(
+                relevant_prices["output"]["service_stated_token_price"]
+            )
+            service_output_token_qty = float(
+                relevant_prices["output"]["service_stated_token_qty"]
+            )
+            output_price_per_token = (
+                service_output_token_price / service_output_token_qty
+            )
+
+        except KeyError:
+            # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
+            # Use a sensible default
+
+            import warnings
+
+            warnings.warn(
+                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
+            )
+            input_price_per_token = 0.00000015  # $0.15 / 1M tokens
+            output_price_per_token = 0.00000060  # $0.60 / 1M tokens
+
+        # Compute the number of characters (double if the question involves piping)
+        user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
+            str(user_prompt)
+        )
+        system_prompt_chars = len(str(system_prompt)) * get_piping_multiplier(
+            str(system_prompt)
+        )
+
+        # Convert into tokens (1 token approx. equals 4 characters)
+        input_tokens = (user_prompt_chars + system_prompt_chars) // 4
+
+        output_tokens = math.ceil(0.75 * input_tokens)
+
+        cost = (
+            input_tokens * input_price_per_token
+            + output_tokens * output_price_per_token
+        )
+
+        return {
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
+            "cost_usd": cost,
+        }
+
+    def estimate_job_cost_from_external_prices(
+        self, price_lookup: dict, iterations: int = 1
     ) -> dict:
         """
-        Estimate the cost of running the prompts.
-        :param iterations: the number of iterations to run
+        Estimates the cost of a job according to the following assumptions:
+
+        - 1 token = 4 characters.
+        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
+
+        price_lookup is an external pricing dictionary.
         """
-        from edsl.jobs.JobsPrompts import JobsPrompts
 
-        return JobsPrompts.estimate_prompt_cost(
-            system_prompt, user_prompt, price_lookup, inference_service, model
+        import pandas as pd
+
+        interviews = self.interviews()
+        data = []
+        for interview in interviews:
+            invigilators = [
+                interview._get_invigilator(question)
+                for question in self.survey.questions
+            ]
+            for invigilator in invigilators:
+                prompts = invigilator.get_prompts()
+
+                # By this point, agent and scenario data has already been added to the prompts
+                user_prompt = prompts["user_prompt"]
+                system_prompt = prompts["system_prompt"]
+                inference_service = invigilator.model._inference_service_
+                model = invigilator.model.model
+
+                prompt_cost = self.estimate_prompt_cost(
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    price_lookup=price_lookup,
+                    inference_service=inference_service,
+                    model=model,
+                )
+
+                data.append(
+                    {
+                        "user_prompt": user_prompt,
+                        "system_prompt": system_prompt,
+                        "estimated_input_tokens": prompt_cost["input_tokens"],
+                        "estimated_output_tokens": prompt_cost["output_tokens"],
+                        "estimated_cost_usd": prompt_cost["cost_usd"],
+                        "inference_service": inference_service,
+                        "model": model,
+                    }
+                )
+
+        df = pd.DataFrame.from_records(data)
+
+        df = (
+            df.groupby(["inference_service", "model"])
+            .agg(
+                {
+                    "estimated_cost_usd": "sum",
+                    "estimated_input_tokens": "sum",
+                    "estimated_output_tokens": "sum",
+                }
+            )
+            .reset_index()
+        )
+        df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
+        df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
+        df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations
+
+        estimated_costs_by_model = df.to_dict("records")
+
+        estimated_total_cost = sum(
+            model["estimated_cost_usd"] for model in estimated_costs_by_model
+        )
+        estimated_total_input_tokens = sum(
+            model["estimated_input_tokens"] for model in estimated_costs_by_model
+        )
+        estimated_total_output_tokens = sum(
+            model["estimated_output_tokens"] for model in estimated_costs_by_model
         )
 
+        output = {
+            "estimated_total_cost_usd": estimated_total_cost,
+            "estimated_total_input_tokens": estimated_total_input_tokens,
+            "estimated_total_output_tokens": estimated_total_output_tokens,
+            "model_costs": estimated_costs_by_model,
+        }
+
+        return output
+
     def estimate_job_cost(self, iterations: int = 1) -> dict:
         """
-        Estimate the cost of running the job.
+        Estimates the cost of a job according to the following assumptions:
 
-        :param iterations: the number of iterations to run
-        """
-        from edsl.jobs.JobsPrompts import JobsPrompts
+        - 1 token = 4 characters.
+        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
 
-        j = JobsPrompts(self)
-        return j.estimate_job_cost(iterations)
+        Fetches prices from Coop.
+        """
+        from edsl import Coop
 
-    def estimate_job_cost_from_external_prices(
-        self, price_lookup: dict, iterations: int = 1
-    ) -> dict:
-        from edsl.jobs.JobsPrompts import JobsPrompts
+        c = Coop()
+        price_lookup = c.fetch_prices()
 
-        j = JobsPrompts(self)
-        return j.estimate_job_cost_from_external_prices(price_lookup, iterations)
+        return self.estimate_job_cost_from_external_prices(
+            price_lookup=price_lookup, iterations=iterations
+        )
 
     @staticmethod
-    def compute_job_cost(job_results: Results) -> float:
+    def compute_job_cost(job_results: "Results") -> float:
         """
         Computes the cost of a completed job in USD.
         """
-        return job_results.compute_job_cost()
+        total_cost = 0
+        for result in job_results:
+            for key in result.raw_model_response:
+                if key.endswith("_cost"):
+                    result_cost = result.raw_model_response[key]
+
+                    question_name = key.removesuffix("_cost")
+                    cache_used = result.cache_used_dict[question_name]
+
+                    if isinstance(result_cost, (int, float)) and not cache_used:
+                        total_cost += result_cost
+
+        return total_cost
 
     @staticmethod
     def _get_container_class(object):
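The cost model in this hunk reduces to simple arithmetic: per-token prices come from price_lookup (falling back to $0.15/$0.60 per 1M tokens), 1 token is assumed to equal 4 characters, output tokens are assumed to be 0.75 × input tokens, and Jinja piping braces double a prompt's character count. A standalone sketch of that arithmetic (the prompt strings are made up for illustration):

    import math

    def estimate_cost(user_prompt: str, system_prompt: str,
                      input_price=0.00000015, output_price=0.00000060) -> float:
        def chars(p: str) -> int:
            # Jinja braces signal piping, which doubles the character count
            return len(p) * (2 if "{{" in p and "}}" in p else 1)

        input_tokens = (chars(user_prompt) + chars(system_prompt)) // 4  # 1 token ~ 4 chars
        output_tokens = math.ceil(0.75 * input_tokens)  # assumed output/input ratio
        return input_tokens * input_price + output_tokens * output_price

    # A 400-character piped user prompt plus a short system prompt:
    print(f"${estimate_cost('What is {{ topic }}?' * 20, 'You are an agent.'):.6f}")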
@@ -315,12 +504,17 @@ class Jobs(Base):
 
     @staticmethod
     def _get_empty_container_object(object):
-        from edsl.agents.AgentList import AgentList
-        from edsl.scenarios.ScenarioList import ScenarioList
+        from edsl import AgentList
+        from edsl import Agent
+        from edsl import Scenario
+        from edsl import ScenarioList
 
-        return {"Agent": AgentList([]), "Scenario": ScenarioList([])}.get(
-            object.__class__.__name__, []
-        )
+        if isinstance(object, Agent):
+            return AgentList([])
+        elif isinstance(object, Scenario):
+            return ScenarioList([])
+        else:
+            return []
 
     @staticmethod
     def _merge_objects(passed_objects, current_objects) -> list:
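Replacing the object.__class__.__name__ lookup with isinstance checks also covers subclasses of Agent and Scenario, which the old name-based dict lookup would have sent to the [] fallback. A small sketch (the subclass is hypothetical):

    from edsl import Agent
    from edsl.jobs.Jobs import Jobs

    class PatientAgent(Agent):  # hypothetical Agent subclass
        pass

    # Name lookup: {"Agent": ...}.get("PatientAgent", []) would return []
    # isinstance check: returns an empty AgentList, as intended
    print(Jobs._get_empty_container_object(PatientAgent()))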
@@ -528,6 +722,110 @@ class Jobs(Base):
             return False
         return self._raise_validation_errors
 
+    def create_remote_inference_job(
+        self,
+        iterations: int = 1,
+        remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
+        verbose=False,
+    ):
+        """ """
+        from edsl.coop.coop import Coop
+
+        coop = Coop()
+        self._output("Remote inference activated. Sending job to server...")
+        remote_job_creation_data = coop.remote_inference_create(
+            self,
+            description=remote_inference_description,
+            status="queued",
+            iterations=iterations,
+            initial_results_visibility=remote_inference_results_visibility,
+        )
+        job_uuid = remote_job_creation_data.get("uuid")
+        if self.verbose:
+            print(f"Job sent to server. (Job uuid={job_uuid}).")
+        return remote_job_creation_data
+
+    @staticmethod
+    def check_status(job_uuid):
+        from edsl.coop.coop import Coop
+
+        coop = Coop()
+        return coop.remote_inference_get(job_uuid)
+
+    def poll_remote_inference_job(
+        self, remote_job_creation_data: dict, verbose=False, poll_interval=5
+    ) -> Union[Results, None]:
+        from edsl.coop.coop import Coop
+        import time
+        from datetime import datetime
+        from edsl.config import CONFIG
+
+        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
+
+        job_uuid = remote_job_creation_data.get("uuid")
+
+        coop = Coop()
+        job_in_queue = True
+        while job_in_queue:
+            remote_job_data = coop.remote_inference_get(job_uuid)
+            status = remote_job_data.get("status")
+            if status == "cancelled":
+                if self.verbose:
+                    print("\r" + " " * 80 + "\r", end="")
+                    print("Job cancelled by the user.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
+                    )
+                return None
+            elif status == "failed":
+                if self.verbose:
+                    print("\r" + " " * 80 + "\r", end="")
+                    print("Job failed.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
+                    )
+                return None
+            elif status == "completed":
+                results_uuid = remote_job_data.get("results_uuid")
+                results = coop.get(results_uuid, expected_object_type="results")
+                if self.verbose:
+                    print("\r" + " " * 80 + "\r", end="")
+                    url = f"{expected_parrot_url}/content/{results_uuid}"
+                    print(f"Job completed and Results stored on Coop: {url}.")
+                return results
+            else:
+                duration = poll_interval
+                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
+                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
+                start_time = time.time()
+                i = 0
+                while time.time() - start_time < duration:
+                    if self.verbose:
+                        print(
+                            f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+                            end="",
+                            flush=True,
+                        )
+                    time.sleep(0.1)
+                    i += 1
+
+    def use_remote_inference(self, disable_remote_inference: bool) -> bool:
+        if disable_remote_inference:
+            return False
+        if not disable_remote_inference:
+            try:
+                from edsl import Coop
+
+                user_edsl_settings = Coop().edsl_settings
+                return user_edsl_settings.get("remote_inference", False)
+            except requests.ConnectionError:
+                pass
+            except CoopServerResponseError as e:
+                pass
+
+        return False
+
     def use_remote_cache(self, disable_remote_cache: bool) -> bool:
         if disable_remote_cache:
             return False
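These methods inline what JobsRemoteInferenceHandler previously encapsulated (the handler module is removed in this release; see edsl/jobs/JobsRemoteInferenceHandler.py in the file list above). A hedged usage sketch of the create-then-poll flow; note that both methods read self.verbose, which run() normally sets:

    from edsl import Jobs

    job = Jobs.example()
    job.verbose = False  # create/poll read self.verbose; run() sets this for you
    if job.use_remote_inference(disable_remote_inference=False):
        creation_data = job.create_remote_inference_job(
            iterations=1,
            remote_inference_description="smoke test",  # hypothetical description
        )
        results = job.poll_remote_inference_job(creation_data)  # blocks until done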
@@ -544,6 +842,96 @@ class Jobs(Base):
 
         return False
 
+    def check_api_keys(self) -> None:
+        from edsl import Model
+
+        for model in self.models + [Model()]:
+            if not model.has_valid_api_key():
+                raise MissingAPIKeyError(
+                    model_name=str(model.model),
+                    inference_service=model._inference_service_,
+                )
+
+    def get_missing_api_keys(self) -> set:
+        """
+        Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
+        """
+
+        missing_api_keys = set()
+
+        from edsl import Model
+        from edsl.enums import service_to_api_keyname
+
+        for model in self.models + [Model()]:
+            if not model.has_valid_api_key():
+                key_name = service_to_api_keyname.get(
+                    model._inference_service_, "NOT FOUND"
+                )
+                missing_api_keys.add(key_name)
+
+        return missing_api_keys
+
+    def user_has_all_model_keys(self):
+        """
+        Returns True if the user has all model keys required to run their job.
+
+        Otherwise, returns False.
+        """
+
+        try:
+            self.check_api_keys()
+            return True
+        except MissingAPIKeyError:
+            return False
+        except Exception:
+            raise
+
+    def user_has_ep_api_key(self) -> bool:
+        """
+        Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
+
+        Otherwise, returns False.
+        """
+
+        import os
+
+        coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
+
+        if coop_api_key is not None:
+            return True
+        else:
+            return False
+
+    def needs_external_llms(self) -> bool:
+        """
+        Returns True if the job needs external LLMs to run.
+
+        Otherwise, returns False.
+        """
+        # These cases are necessary to skip the API key check during doctests
+
+        # Accounts for Results.example()
+        all_agents_answer_questions_directly = len(self.agents) > 0 and all(
+            [hasattr(a, "answer_question_directly") for a in self.agents]
+        )
+
+        # Accounts for InterviewExceptionEntry.example()
+        only_model_is_test = set([m.model for m in self.models]) == set(["test"])
+
+        # Accounts for Survey.__call__
+        all_questions_are_functional = set(
+            [q.question_type for q in self.survey.questions]
+        ) == set(["functional"])
+
+        if (
+            all_agents_answer_questions_directly
+            or only_model_is_test
+            or all_questions_are_functional
+        ):
+            return False
+        else:
+            return True
+
     def run(
         self,
         n: int = 1,
@@ -552,7 +940,7 @@ class Jobs(Base):
         cache: Union[Cache, bool] = None,
         check_api_keys: bool = False,
         sidecar_model: Optional[LanguageModel] = None,
-        verbose: bool = True,
+        verbose: bool = False,
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
         remote_inference_description: Optional[str] = None,
@@ -587,28 +975,62 @@ class Jobs(Base):
 
         self.verbose = verbose
 
-        from edsl.jobs.JobsChecks import JobsChecks
+        if (
+            not self.user_has_all_model_keys()
+            and not self.user_has_ep_api_key()
+            and self.needs_external_llms()
+        ):
+            import secrets
+            from dotenv import load_dotenv
+            from edsl import CONFIG
+            from edsl.coop.coop import Coop
+            from edsl.utilities.utilities import write_api_key_to_env
+
+            missing_api_keys = self.get_missing_api_keys()
 
-        jc = JobsChecks(self)
+            edsl_auth_token = secrets.token_urlsafe(16)
 
-        # check if the user has all the keys they need
-        if jc.needs_key_process():
-            jc.key_process()
+            print("You're missing some of the API keys needed to run this job:")
+            for api_key in missing_api_keys:
+                print(f" 🔑 {api_key}")
+            print(
+                "\nYou can either add the missing keys to your .env file, or use remote inference."
+            )
+            print("Remote inference allows you to run jobs on our server.")
+            print("\n🚀 To use remote inference, sign up at the following link:")
+
+            coop = Coop()
+            coop._display_login_url(edsl_auth_token=edsl_auth_token)
 
-        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+            print(
+                "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
+            )
 
-        jh = JobsRemoteInferenceHandler(self, verbose=verbose)
-        if jh.use_remote_inference(disable_remote_inference):
-            jh.create_remote_inference_job(
+            api_key = coop._poll_for_api_key(edsl_auth_token)
+
+            if api_key is None:
+                print("\nTimed out waiting for login. Please try again.")
+                return
+
+            write_api_key_to_env(api_key)
+            print("✨ API key retrieved and written to .env file.\n")
+
+            # Retrieve API key so we can continue running the job
+            load_dotenv()
+
+        if remote_inference := self.use_remote_inference(disable_remote_inference):
+            remote_job_creation_data = self.create_remote_inference_job(
                 iterations=n,
                 remote_inference_description=remote_inference_description,
                 remote_inference_results_visibility=remote_inference_results_visibility,
             )
-            results = jh.poll_remote_inference_job()
+            results = self.poll_remote_inference_job(remote_job_creation_data)
+            if results is None:
+                self._output("Job failed.")
             return results
 
         if check_api_keys:
-            jc.check_api_keys()
+            self.check_api_keys()
 
         # handle cache
         if cache is None or cache is True:
@@ -638,9 +1060,46 @@ class Jobs(Base):
             raise_validation_errors=raise_validation_errors,
         )
 
-        # results.cache = cache.new_entries_cache()
+        results.cache = cache.new_entries_cache()
         return results
 
+    async def create_and_poll_remote_job(
+        self,
+        iterations: int = 1,
+        remote_inference_description: Optional[str] = None,
+        remote_inference_results_visibility: Optional[
+            Literal["private", "public", "unlisted"]
+        ] = "unlisted",
+    ) -> Union[Results, None]:
+        """
+        Creates and polls a remote inference job asynchronously.
+        Reuses existing synchronous methods but runs them in an async context.
+
+        :param iterations: Number of times to run each interview
+        :param remote_inference_description: Optional description for the remote job
+        :param remote_inference_results_visibility: Visibility setting for results
+        :return: Results object if successful, None if job fails or is cancelled
+        """
+        import asyncio
+        from functools import partial
+
+        # Create job using existing method
+        loop = asyncio.get_event_loop()
+        remote_job_creation_data = await loop.run_in_executor(
+            None,
+            partial(
+                self.create_remote_inference_job,
+                iterations=iterations,
+                remote_inference_description=remote_inference_description,
+                remote_inference_results_visibility=remote_inference_results_visibility,
+            ),
+        )
+
+        # Poll using existing method but with async sleep
+        return await loop.run_in_executor(
+            None, partial(self.poll_remote_inference_job, remote_job_creation_data)
+        )
+
     async def run_async(
         self,
         cache=None,
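A minimal sketch of driving the new async helper directly, outside run_async (assuming remote inference is enabled for the account):

    import asyncio
    from edsl import Jobs

    async def main():
        job = Jobs.example()
        job.verbose = False  # the wrapped sync methods read self.verbose
        results = await job.create_and_poll_remote_job(iterations=1)
        print(results)

    asyncio.run(main())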
@@ -663,15 +1122,14 @@ class Jobs(Base):
         :return: Results object
         """
         # Check if we should use remote inference
-        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
-
-        jh = JobsRemoteInferenceHandler(self, verbose=False)
-        if jh.use_remote_inference(disable_remote_inference):
-            results = await jh.create_and_poll_remote_job(
+        if remote_inference := self.use_remote_inference(disable_remote_inference):
+            results = await self.create_and_poll_remote_job(
                 iterations=n,
                 remote_inference_description=remote_inference_description,
                 remote_inference_results_visibility=remote_inference_results_visibility,
             )
+            if results is None:
+                self._output("Job failed.")
             return results
 
         # If not using remote inference, run locally with async
@@ -691,22 +1149,24 @@ class Jobs(Base):
         """
         return set.union(*[question.parameters for question in self.survey.questions])
 
+    #######################
+    # Dunder methods
+    #######################
+    def print(self):
+        from rich import print_json
+        import json
+
+        print_json(json.dumps(self.to_dict()))
+
     def __repr__(self) -> str:
         """Return an eval-able string representation of the Jobs instance."""
         return f"Jobs(survey={repr(self.survey)}, agents={repr(self.agents)}, models={repr(self.models)}, scenarios={repr(self.scenarios)})"
 
-    def _summary(self):
-        return {
-            "EDSL Class": "Jobs",
-            "Number of questions": len(self.survey),
-            "Number of agents": len(self.agents),
-            "Number of models": len(self.models),
-            "Number of scenarios": len(self.scenarios),
-        }
-
     def _repr_html_(self) -> str:
-        footer = f"<a href={self.__documentation__}>(docs)</a>"
-        return str(self.summary(format="html")) + footer
+        from rich import print_json
+        import json
+
+        print_json(json.dumps(self.to_dict()))
 
     def __len__(self) -> int:
         """Return the maximum number of questions that will be asked while running this job.
@@ -776,7 +1236,7 @@ class Jobs(Base):
         True
 
         """
-        return hash(self) == hash(other)
+        return self.to_dict() == other.to_dict()
 
     #######################
     # Example methods