edsl 0.1.38__py3-none-any.whl → 0.1.38.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. edsl/Base.py +34 -63
  2. edsl/BaseDiff.py +7 -7
  3. edsl/__init__.py +1 -2
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +11 -23
  6. edsl/agents/AgentList.py +23 -86
  7. edsl/agents/Invigilator.py +7 -18
  8. edsl/agents/InvigilatorBase.py +19 -0
  9. edsl/agents/PromptConstructor.py +4 -5
  10. edsl/auto/SurveyCreatorPipeline.py +1 -1
  11. edsl/auto/utilities.py +1 -1
  12. edsl/base/Base.py +13 -3
  13. edsl/config.py +0 -8
  14. edsl/conjure/AgentConstructionMixin.py +160 -0
  15. edsl/conjure/Conjure.py +62 -0
  16. edsl/conjure/InputData.py +659 -0
  17. edsl/conjure/InputDataCSV.py +48 -0
  18. edsl/conjure/InputDataMixinQuestionStats.py +182 -0
  19. edsl/conjure/InputDataPyRead.py +91 -0
  20. edsl/conjure/InputDataSPSS.py +8 -0
  21. edsl/conjure/InputDataStata.py +8 -0
  22. edsl/conjure/QuestionOptionMixin.py +76 -0
  23. edsl/conjure/QuestionTypeMixin.py +23 -0
  24. edsl/conjure/RawQuestion.py +65 -0
  25. edsl/conjure/SurveyResponses.py +7 -0
  26. edsl/conjure/__init__.py +9 -0
  27. edsl/conjure/examples/placeholder.txt +0 -0
  28. edsl/{utilities → conjure}/naming_utilities.py +1 -1
  29. edsl/conjure/utilities.py +201 -0
  30. edsl/coop/coop.py +7 -77
  31. edsl/data/Cache.py +17 -45
  32. edsl/data/CacheEntry.py +3 -8
  33. edsl/data/RemoteCacheSync.py +19 -0
  34. edsl/enums.py +0 -2
  35. edsl/exceptions/agents.py +0 -4
  36. edsl/inference_services/GoogleService.py +15 -7
  37. edsl/inference_services/registry.py +0 -2
  38. edsl/jobs/Jobs.py +559 -110
  39. edsl/jobs/buckets/TokenBucket.py +0 -3
  40. edsl/jobs/interviews/Interview.py +7 -7
  41. edsl/jobs/runners/JobsRunnerAsyncio.py +28 -156
  42. edsl/jobs/runners/JobsRunnerStatus.py +196 -194
  43. edsl/jobs/tasks/TaskHistory.py +19 -27
  44. edsl/language_models/LanguageModel.py +90 -52
  45. edsl/language_models/ModelList.py +14 -67
  46. edsl/language_models/registry.py +4 -57
  47. edsl/notebooks/Notebook.py +8 -7
  48. edsl/prompts/Prompt.py +3 -8
  49. edsl/questions/QuestionBase.py +30 -38
  50. edsl/questions/QuestionBaseGenMixin.py +1 -1
  51. edsl/questions/QuestionBasePromptsMixin.py +17 -0
  52. edsl/questions/QuestionExtract.py +4 -3
  53. edsl/questions/QuestionFunctional.py +3 -10
  54. edsl/questions/derived/QuestionTopK.py +0 -2
  55. edsl/questions/question_registry.py +6 -36
  56. edsl/results/Dataset.py +15 -146
  57. edsl/results/DatasetExportMixin.py +217 -231
  58. edsl/results/DatasetTree.py +4 -134
  59. edsl/results/Result.py +16 -31
  60. edsl/results/Results.py +65 -159
  61. edsl/scenarios/FileStore.py +13 -187
  62. edsl/scenarios/Scenario.py +18 -73
  63. edsl/scenarios/ScenarioList.py +76 -251
  64. edsl/surveys/MemoryPlan.py +1 -1
  65. edsl/surveys/Rule.py +5 -1
  66. edsl/surveys/RuleCollection.py +1 -1
  67. edsl/surveys/Survey.py +19 -25
  68. edsl/surveys/SurveyFlowVisualizationMixin.py +9 -67
  69. edsl/surveys/instructions/ChangeInstruction.py +7 -9
  70. edsl/surveys/instructions/Instruction.py +7 -21
  71. edsl/templates/error_reporting/interview_details.html +3 -3
  72. edsl/templates/error_reporting/interviews.html +9 -18
  73. edsl/utilities/utilities.py +0 -15
  74. {edsl-0.1.38.dist-info → edsl-0.1.38.dev1.dist-info}/METADATA +1 -2
  75. {edsl-0.1.38.dist-info → edsl-0.1.38.dev1.dist-info}/RECORD +77 -71
  76. edsl/exceptions/cache.py +0 -5
  77. edsl/inference_services/PerplexityService.py +0 -163
  78. edsl/jobs/JobsChecks.py +0 -147
  79. edsl/jobs/JobsPrompts.py +0 -268
  80. edsl/jobs/JobsRemoteInferenceHandler.py +0 -239
  81. edsl/results/CSSParameterizer.py +0 -108
  82. edsl/results/TableDisplay.py +0 -198
  83. edsl/results/table_display.css +0 -78
  84. edsl/scenarios/ScenarioJoin.py +0 -127
  85. {edsl-0.1.38.dist-info → edsl-0.1.38.dev1.dist-info}/LICENSE +0 -0
  86. {edsl-0.1.38.dist-info → edsl-0.1.38.dev1.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import warnings
4
4
  import requests
5
5
  from itertools import product
6
- from typing import Literal, Optional, Union, Sequence, Generator, TYPE_CHECKING
6
+ from typing import Literal, Optional, Union, Sequence, Generator
7
7
 
8
8
  from edsl.Base import Base
9
9
 
@@ -11,20 +11,11 @@ from edsl.exceptions import MissingAPIKeyError
11
11
  from edsl.jobs.buckets.BucketCollection import BucketCollection
12
12
  from edsl.jobs.interviews.Interview import Interview
13
13
  from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
14
- from edsl.utilities.decorators import remove_edsl_version
14
+ from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
15
15
 
16
16
  from edsl.data.RemoteCacheSync import RemoteCacheSync
17
17
  from edsl.exceptions.coop import CoopServerResponseError
18
18
 
19
- if TYPE_CHECKING:
20
- from edsl.agents.Agent import Agent
21
- from edsl.agents.AgentList import AgentList
22
- from edsl.language_models.LanguageModel import LanguageModel
23
- from edsl.scenarios.Scenario import Scenario
24
- from edsl.surveys.Survey import Survey
25
- from edsl.results.Results import Results
26
- from edsl.results.Dataset import Dataset
27
-
28
19
 
29
20
  class Jobs(Base):
30
21
  """
@@ -33,8 +24,6 @@ class Jobs(Base):
33
24
  The `JobsRunner` is chosen by the user, and is stored in the `jobs_runner_name` attribute.
34
25
  """
35
26
 
36
- __documentation__ = "https://docs.expectedparrot.com/en/latest/jobs.html"
37
-
38
27
  def __init__(
39
28
  self,
40
29
  survey: "Survey",
@@ -97,14 +86,8 @@ class Jobs(Base):
97
86
  @scenarios.setter
98
87
  def scenarios(self, value):
99
88
  from edsl import ScenarioList
100
- from edsl.results.Dataset import Dataset
101
89
 
102
90
  if value:
103
- if isinstance(
104
- value, Dataset
105
- ): # if the user passes in a Dataset, convert it to a ScenarioList
106
- value = value.to_scenario_list()
107
-
108
91
  if not isinstance(value, ScenarioList):
109
92
  self._scenarios = ScenarioList(value)
110
93
  else:
@@ -144,13 +127,6 @@ class Jobs(Base):
144
127
  - scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
145
128
  - models: new models overwrite old models.
146
129
  """
147
- from edsl.results.Dataset import Dataset
148
-
149
- if isinstance(
150
- args[0], Dataset
151
- ): # let the user user a Dataset as if it were a ScenarioList
152
- args = args[0].to_scenario_list()
153
-
154
130
  passed_objects = self._turn_args_to_list(
155
131
  args
156
132
  ) # objects can also be passed comma-separated
@@ -175,19 +151,73 @@ class Jobs(Base):
175
151
  >>> Jobs.example().prompts()
176
152
  Dataset(...)
177
153
  """
178
- from edsl.jobs.JobsPrompts import JobsPrompts
154
+ from edsl import Coop
155
+
156
+ c = Coop()
157
+ price_lookup = c.fetch_prices()
158
+
159
+ interviews = self.interviews()
160
+ # data = []
161
+ interview_indices = []
162
+ question_names = []
163
+ user_prompts = []
164
+ system_prompts = []
165
+ scenario_indices = []
166
+ agent_indices = []
167
+ models = []
168
+ costs = []
169
+ from edsl.results.Dataset import Dataset
179
170
 
180
- j = JobsPrompts(self)
181
- return j.prompts()
171
+ for interview_index, interview in enumerate(interviews):
172
+ invigilators = [
173
+ interview._get_invigilator(question)
174
+ for question in self.survey.questions
175
+ ]
176
+ for _, invigilator in enumerate(invigilators):
177
+ prompts = invigilator.get_prompts()
178
+ user_prompt = prompts["user_prompt"]
179
+ system_prompt = prompts["system_prompt"]
180
+ user_prompts.append(user_prompt)
181
+ system_prompts.append(system_prompt)
182
+ agent_index = self.agents.index(invigilator.agent)
183
+ agent_indices.append(agent_index)
184
+ interview_indices.append(interview_index)
185
+ scenario_index = self.scenarios.index(invigilator.scenario)
186
+ scenario_indices.append(scenario_index)
187
+ models.append(invigilator.model.model)
188
+ question_names.append(invigilator.question.question_name)
189
+
190
+ prompt_cost = self.estimate_prompt_cost(
191
+ system_prompt=system_prompt,
192
+ user_prompt=user_prompt,
193
+ price_lookup=price_lookup,
194
+ inference_service=invigilator.model._inference_service_,
195
+ model=invigilator.model.model,
196
+ )
197
+ costs.append(prompt_cost["cost_usd"])
182
198
 
183
- def show_prompts(self, all=False) -> None:
199
+ d = Dataset(
200
+ [
201
+ {"user_prompt": user_prompts},
202
+ {"system_prompt": system_prompts},
203
+ {"interview_index": interview_indices},
204
+ {"question_name": question_names},
205
+ {"scenario_index": scenario_indices},
206
+ {"agent_index": agent_indices},
207
+ {"model": models},
208
+ {"estimated_cost": costs},
209
+ ]
210
+ )
211
+ return d
212
+
213
+ def show_prompts(self, all=False, max_rows: Optional[int] = None) -> None:
184
214
  """Print the prompts."""
185
215
  if all:
186
- return self.prompts().to_scenario_list().table()
216
+ self.prompts().to_scenario_list().print(format="rich", max_rows=max_rows)
187
217
  else:
188
- return (
189
- self.prompts().to_scenario_list().table("user_prompt", "system_prompt")
190
- )
218
+ self.prompts().select(
219
+ "user_prompt", "system_prompt"
220
+ ).to_scenario_list().print(format="rich", max_rows=max_rows)
191
221
 
192
222
  @staticmethod
193
223
  def estimate_prompt_cost(
@@ -196,42 +226,201 @@ class Jobs(Base):
196
226
  price_lookup: dict,
197
227
  inference_service: str,
198
228
  model: str,
229
+ ) -> dict:
230
+ """Estimates the cost of a prompt. Takes piping into account."""
231
+ import math
232
+
233
+ def get_piping_multiplier(prompt: str):
234
+ """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
235
+
236
+ if "{{" in prompt and "}}" in prompt:
237
+ return 2
238
+ return 1
239
+
240
+ # Look up prices per token
241
+ key = (inference_service, model)
242
+
243
+ try:
244
+ relevant_prices = price_lookup[key]
245
+
246
+ service_input_token_price = float(
247
+ relevant_prices["input"]["service_stated_token_price"]
248
+ )
249
+ service_input_token_qty = float(
250
+ relevant_prices["input"]["service_stated_token_qty"]
251
+ )
252
+ input_price_per_token = service_input_token_price / service_input_token_qty
253
+
254
+ service_output_token_price = float(
255
+ relevant_prices["output"]["service_stated_token_price"]
256
+ )
257
+ service_output_token_qty = float(
258
+ relevant_prices["output"]["service_stated_token_qty"]
259
+ )
260
+ output_price_per_token = (
261
+ service_output_token_price / service_output_token_qty
262
+ )
263
+
264
+ except KeyError:
265
+ # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
266
+ # Use a sensible default
267
+
268
+ import warnings
269
+
270
+ warnings.warn(
271
+ "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
272
+ )
273
+ input_price_per_token = 0.00000015 # $0.15 / 1M tokens
274
+ output_price_per_token = 0.00000060 # $0.60 / 1M tokens
275
+
276
+ # Compute the number of characters (double if the question involves piping)
277
+ user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
278
+ str(user_prompt)
279
+ )
280
+ system_prompt_chars = len(str(system_prompt)) * get_piping_multiplier(
281
+ str(system_prompt)
282
+ )
283
+
284
+ # Convert into tokens (1 token approx. equals 4 characters)
285
+ input_tokens = (user_prompt_chars + system_prompt_chars) // 4
286
+
287
+ output_tokens = math.ceil(0.75 * input_tokens)
288
+
289
+ cost = (
290
+ input_tokens * input_price_per_token
291
+ + output_tokens * output_price_per_token
292
+ )
293
+
294
+ return {
295
+ "input_tokens": input_tokens,
296
+ "output_tokens": output_tokens,
297
+ "cost_usd": cost,
298
+ }
299
+
300
+ def estimate_job_cost_from_external_prices(
301
+ self, price_lookup: dict, iterations: int = 1
199
302
  ) -> dict:
200
303
  """
201
- Estimate the cost of running the prompts.
202
- :param iterations: the number of iterations to run
304
+ Estimates the cost of a job according to the following assumptions:
305
+
306
+ - 1 token = 4 characters.
307
+ - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
308
+
309
+ price_lookup is an external pricing dictionary.
203
310
  """
204
- from edsl.jobs.JobsPrompts import JobsPrompts
205
311
 
206
- return JobsPrompts.estimate_prompt_cost(
207
- system_prompt, user_prompt, price_lookup, inference_service, model
312
+ import pandas as pd
313
+
314
+ interviews = self.interviews()
315
+ data = []
316
+ for interview in interviews:
317
+ invigilators = [
318
+ interview._get_invigilator(question)
319
+ for question in self.survey.questions
320
+ ]
321
+ for invigilator in invigilators:
322
+ prompts = invigilator.get_prompts()
323
+
324
+ # By this point, agent and scenario data has already been added to the prompts
325
+ user_prompt = prompts["user_prompt"]
326
+ system_prompt = prompts["system_prompt"]
327
+ inference_service = invigilator.model._inference_service_
328
+ model = invigilator.model.model
329
+
330
+ prompt_cost = self.estimate_prompt_cost(
331
+ system_prompt=system_prompt,
332
+ user_prompt=user_prompt,
333
+ price_lookup=price_lookup,
334
+ inference_service=inference_service,
335
+ model=model,
336
+ )
337
+
338
+ data.append(
339
+ {
340
+ "user_prompt": user_prompt,
341
+ "system_prompt": system_prompt,
342
+ "estimated_input_tokens": prompt_cost["input_tokens"],
343
+ "estimated_output_tokens": prompt_cost["output_tokens"],
344
+ "estimated_cost_usd": prompt_cost["cost_usd"],
345
+ "inference_service": inference_service,
346
+ "model": model,
347
+ }
348
+ )
349
+
350
+ df = pd.DataFrame.from_records(data)
351
+
352
+ df = (
353
+ df.groupby(["inference_service", "model"])
354
+ .agg(
355
+ {
356
+ "estimated_cost_usd": "sum",
357
+ "estimated_input_tokens": "sum",
358
+ "estimated_output_tokens": "sum",
359
+ }
360
+ )
361
+ .reset_index()
208
362
  )
363
+ df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
364
+ df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
365
+ df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations
366
+
367
+ estimated_costs_by_model = df.to_dict("records")
368
+
369
+ estimated_total_cost = sum(
370
+ model["estimated_cost_usd"] for model in estimated_costs_by_model
371
+ )
372
+ estimated_total_input_tokens = sum(
373
+ model["estimated_input_tokens"] for model in estimated_costs_by_model
374
+ )
375
+ estimated_total_output_tokens = sum(
376
+ model["estimated_output_tokens"] for model in estimated_costs_by_model
377
+ )
378
+
379
+ output = {
380
+ "estimated_total_cost_usd": estimated_total_cost,
381
+ "estimated_total_input_tokens": estimated_total_input_tokens,
382
+ "estimated_total_output_tokens": estimated_total_output_tokens,
383
+ "model_costs": estimated_costs_by_model,
384
+ }
385
+
386
+ return output
209
387
 
210
388
  def estimate_job_cost(self, iterations: int = 1) -> dict:
211
389
  """
212
- Estimate the cost of running the job.
390
+ Estimates the cost of a job according to the following assumptions:
391
+
392
+ - 1 token = 4 characters.
393
+ - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
213
394
 
214
- :param iterations: the number of iterations to run
395
+ Fetches prices from Coop.
215
396
  """
216
- from edsl.jobs.JobsPrompts import JobsPrompts
397
+ from edsl import Coop
217
398
 
218
- j = JobsPrompts(self)
219
- return j.estimate_job_cost(iterations)
399
+ c = Coop()
400
+ price_lookup = c.fetch_prices()
220
401
 
221
- def estimate_job_cost_from_external_prices(
222
- self, price_lookup: dict, iterations: int = 1
223
- ) -> dict:
224
- from edsl.jobs.JobsPrompts import JobsPrompts
225
-
226
- j = JobsPrompts(self)
227
- return j.estimate_job_cost_from_external_prices(price_lookup, iterations)
402
+ return self.estimate_job_cost_from_external_prices(
403
+ price_lookup=price_lookup, iterations=iterations
404
+ )
228
405
 
229
406
  @staticmethod
230
- def compute_job_cost(job_results: Results) -> float:
407
+ def compute_job_cost(job_results: "Results") -> float:
231
408
  """
232
409
  Computes the cost of a completed job in USD.
233
410
  """
234
- return job_results.compute_job_cost()
411
+ total_cost = 0
412
+ for result in job_results:
413
+ for key in result.raw_model_response:
414
+ if key.endswith("_cost"):
415
+ result_cost = result.raw_model_response[key]
416
+
417
+ question_name = key.removesuffix("_cost")
418
+ cache_used = result.cache_used_dict[question_name]
419
+
420
+ if isinstance(result_cost, (int, float)) and not cache_used:
421
+ total_cost += result_cost
422
+
423
+ return total_cost
235
424
 
236
425
  @staticmethod
237
426
  def _get_container_class(object):
@@ -315,12 +504,17 @@ class Jobs(Base):
315
504
 
316
505
  @staticmethod
317
506
  def _get_empty_container_object(object):
318
- from edsl.agents.AgentList import AgentList
319
- from edsl.scenarios.ScenarioList import ScenarioList
507
+ from edsl import AgentList
508
+ from edsl import Agent
509
+ from edsl import Scenario
510
+ from edsl import ScenarioList
320
511
 
321
- return {"Agent": AgentList([]), "Scenario": ScenarioList([])}.get(
322
- object.__class__.__name__, []
323
- )
512
+ if isinstance(object, Agent):
513
+ return AgentList([])
514
+ elif isinstance(object, Scenario):
515
+ return ScenarioList([])
516
+ else:
517
+ return []
324
518
 
325
519
  @staticmethod
326
520
  def _merge_objects(passed_objects, current_objects) -> list:
@@ -447,7 +641,7 @@ class Jobs(Base):
447
641
  """
448
642
  from edsl.utilities.utilities import dict_hash
449
643
 
450
- return dict_hash(self.to_dict(add_edsl_version=False))
644
+ return dict_hash(self._to_dict())
451
645
 
452
646
  def _output(self, message) -> None:
453
647
  """Check if a Job is verbose. If so, print the message."""
@@ -528,6 +722,110 @@ class Jobs(Base):
528
722
  return False
529
723
  return self._raise_validation_errors
530
724
 
725
+ def create_remote_inference_job(
726
+ self,
727
+ iterations: int = 1,
728
+ remote_inference_description: Optional[str] = None,
729
+ remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
730
+ verbose=False,
731
+ ):
732
+ """ """
733
+ from edsl.coop.coop import Coop
734
+
735
+ coop = Coop()
736
+ self._output("Remote inference activated. Sending job to server...")
737
+ remote_job_creation_data = coop.remote_inference_create(
738
+ self,
739
+ description=remote_inference_description,
740
+ status="queued",
741
+ iterations=iterations,
742
+ initial_results_visibility=remote_inference_results_visibility,
743
+ )
744
+ job_uuid = remote_job_creation_data.get("uuid")
745
+ if self.verbose:
746
+ print(f"Job sent to server. (Job uuid={job_uuid}).")
747
+ return remote_job_creation_data
748
+
749
+ @staticmethod
750
+ def check_status(job_uuid):
751
+ from edsl.coop.coop import Coop
752
+
753
+ coop = Coop()
754
+ return coop.remote_inference_get(job_uuid)
755
+
756
+ def poll_remote_inference_job(
757
+ self, remote_job_creation_data: dict, verbose=False, poll_interval=5
758
+ ) -> Union[Results, None]:
759
+ from edsl.coop.coop import Coop
760
+ import time
761
+ from datetime import datetime
762
+ from edsl.config import CONFIG
763
+
764
+ expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
765
+
766
+ job_uuid = remote_job_creation_data.get("uuid")
767
+
768
+ coop = Coop()
769
+ job_in_queue = True
770
+ while job_in_queue:
771
+ remote_job_data = coop.remote_inference_get(job_uuid)
772
+ status = remote_job_data.get("status")
773
+ if status == "cancelled":
774
+ if self.verbose:
775
+ print("\r" + " " * 80 + "\r", end="")
776
+ print("Job cancelled by the user.")
777
+ print(
778
+ f"See {expected_parrot_url}/home/remote-inference for more details."
779
+ )
780
+ return None
781
+ elif status == "failed":
782
+ if self.verbose:
783
+ print("\r" + " " * 80 + "\r", end="")
784
+ print("Job failed.")
785
+ print(
786
+ f"See {expected_parrot_url}/home/remote-inference for more details."
787
+ )
788
+ return None
789
+ elif status == "completed":
790
+ results_uuid = remote_job_data.get("results_uuid")
791
+ results = coop.get(results_uuid, expected_object_type="results")
792
+ if self.verbose:
793
+ print("\r" + " " * 80 + "\r", end="")
794
+ url = f"{expected_parrot_url}/content/{results_uuid}"
795
+ print(f"Job completed and Results stored on Coop: {url}.")
796
+ return results
797
+ else:
798
+ duration = poll_interval
799
+ time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
800
+ frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
801
+ start_time = time.time()
802
+ i = 0
803
+ while time.time() - start_time < duration:
804
+ if self.verbose:
805
+ print(
806
+ f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
807
+ end="",
808
+ flush=True,
809
+ )
810
+ time.sleep(0.1)
811
+ i += 1
812
+
813
+ def use_remote_inference(self, disable_remote_inference: bool) -> bool:
814
+ if disable_remote_inference:
815
+ return False
816
+ if not disable_remote_inference:
817
+ try:
818
+ from edsl import Coop
819
+
820
+ user_edsl_settings = Coop().edsl_settings
821
+ return user_edsl_settings.get("remote_inference", False)
822
+ except requests.ConnectionError:
823
+ pass
824
+ except CoopServerResponseError as e:
825
+ pass
826
+
827
+ return False
828
+
531
829
  def use_remote_cache(self, disable_remote_cache: bool) -> bool:
532
830
  if disable_remote_cache:
533
831
  return False
@@ -544,6 +842,96 @@ class Jobs(Base):
544
842
 
545
843
  return False
546
844
 
845
+ def check_api_keys(self) -> None:
846
+ from edsl import Model
847
+
848
+ for model in self.models + [Model()]:
849
+ if not model.has_valid_api_key():
850
+ raise MissingAPIKeyError(
851
+ model_name=str(model.model),
852
+ inference_service=model._inference_service_,
853
+ )
854
+
855
+ def get_missing_api_keys(self) -> set:
856
+ """
857
+ Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
858
+ """
859
+
860
+ missing_api_keys = set()
861
+
862
+ from edsl import Model
863
+ from edsl.enums import service_to_api_keyname
864
+
865
+ for model in self.models + [Model()]:
866
+ if not model.has_valid_api_key():
867
+ key_name = service_to_api_keyname.get(
868
+ model._inference_service_, "NOT FOUND"
869
+ )
870
+ missing_api_keys.add(key_name)
871
+
872
+ return missing_api_keys
873
+
874
+ def user_has_all_model_keys(self):
875
+ """
876
+ Returns True if the user has all model keys required to run their job.
877
+
878
+ Otherwise, returns False.
879
+ """
880
+
881
+ try:
882
+ self.check_api_keys()
883
+ return True
884
+ except MissingAPIKeyError:
885
+ return False
886
+ except Exception:
887
+ raise
888
+
889
+ def user_has_ep_api_key(self) -> bool:
890
+ """
891
+ Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
892
+
893
+ Otherwise, returns False.
894
+ """
895
+
896
+ import os
897
+
898
+ coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
899
+
900
+ if coop_api_key is not None:
901
+ return True
902
+ else:
903
+ return False
904
+
905
+ def needs_external_llms(self) -> bool:
906
+ """
907
+ Returns True if the job needs external LLMs to run.
908
+
909
+ Otherwise, returns False.
910
+ """
911
+ # These cases are necessary to skip the API key check during doctests
912
+
913
+ # Accounts for Results.example()
914
+ all_agents_answer_questions_directly = len(self.agents) > 0 and all(
915
+ [hasattr(a, "answer_question_directly") for a in self.agents]
916
+ )
917
+
918
+ # Accounts for InterviewExceptionEntry.example()
919
+ only_model_is_test = set([m.model for m in self.models]) == set(["test"])
920
+
921
+ # Accounts for Survey.__call__
922
+ all_questions_are_functional = set(
923
+ [q.question_type for q in self.survey.questions]
924
+ ) == set(["functional"])
925
+
926
+ if (
927
+ all_agents_answer_questions_directly
928
+ or only_model_is_test
929
+ or all_questions_are_functional
930
+ ):
931
+ return False
932
+ else:
933
+ return True
934
+
547
935
  def run(
548
936
  self,
549
937
  n: int = 1,
@@ -552,7 +940,7 @@ class Jobs(Base):
552
940
  cache: Union[Cache, bool] = None,
553
941
  check_api_keys: bool = False,
554
942
  sidecar_model: Optional[LanguageModel] = None,
555
- verbose: bool = True,
943
+ verbose: bool = False,
556
944
  print_exceptions=True,
557
945
  remote_cache_description: Optional[str] = None,
558
946
  remote_inference_description: Optional[str] = None,
@@ -587,28 +975,62 @@ class Jobs(Base):
587
975
 
588
976
  self.verbose = verbose
589
977
 
590
- from edsl.jobs.JobsChecks import JobsChecks
978
+ if (
979
+ not self.user_has_all_model_keys()
980
+ and not self.user_has_ep_api_key()
981
+ and self.needs_external_llms()
982
+ ):
983
+ import secrets
984
+ from dotenv import load_dotenv
985
+ from edsl import CONFIG
986
+ from edsl.coop.coop import Coop
987
+ from edsl.utilities.utilities import write_api_key_to_env
988
+
989
+ missing_api_keys = self.get_missing_api_keys()
591
990
 
592
- jc = JobsChecks(self)
991
+ edsl_auth_token = secrets.token_urlsafe(16)
593
992
 
594
- # check if the user has all the keys they need
595
- if jc.needs_key_process():
596
- jc.key_process()
993
+ print("You're missing some of the API keys needed to run this job:")
994
+ for api_key in missing_api_keys:
995
+ print(f" 🔑 {api_key}")
996
+ print(
997
+ "\nYou can either add the missing keys to your .env file, or use remote inference."
998
+ )
999
+ print("Remote inference allows you to run jobs on our server.")
1000
+ print("\n🚀 To use remote inference, sign up at the following link:")
1001
+
1002
+ coop = Coop()
1003
+ coop._display_login_url(edsl_auth_token=edsl_auth_token)
1004
+
1005
+ print(
1006
+ "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
1007
+ )
1008
+
1009
+ api_key = coop._poll_for_api_key(edsl_auth_token)
1010
+
1011
+ if api_key is None:
1012
+ print("\nTimed out waiting for login. Please try again.")
1013
+ return
597
1014
 
598
- from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
1015
+ write_api_key_to_env(api_key)
1016
+ print("✨ API key retrieved and written to .env file.\n")
599
1017
 
600
- jh = JobsRemoteInferenceHandler(self, verbose=verbose)
601
- if jh.use_remote_inference(disable_remote_inference):
602
- jh.create_remote_inference_job(
1018
+ # Retrieve API key so we can continue running the job
1019
+ load_dotenv()
1020
+
1021
+ if remote_inference := self.use_remote_inference(disable_remote_inference):
1022
+ remote_job_creation_data = self.create_remote_inference_job(
603
1023
  iterations=n,
604
1024
  remote_inference_description=remote_inference_description,
605
1025
  remote_inference_results_visibility=remote_inference_results_visibility,
606
1026
  )
607
- results = jh.poll_remote_inference_job()
1027
+ results = self.poll_remote_inference_job(remote_job_creation_data)
1028
+ if results is None:
1029
+ self._output("Job failed.")
608
1030
  return results
609
1031
 
610
1032
  if check_api_keys:
611
- jc.check_api_keys()
1033
+ self.check_api_keys()
612
1034
 
613
1035
  # handle cache
614
1036
  if cache is None or cache is True:
@@ -638,9 +1060,46 @@ class Jobs(Base):
638
1060
  raise_validation_errors=raise_validation_errors,
639
1061
  )
640
1062
 
641
- # results.cache = cache.new_entries_cache()
1063
+ results.cache = cache.new_entries_cache()
642
1064
  return results
643
1065
 
1066
+ async def create_and_poll_remote_job(
1067
+ self,
1068
+ iterations: int = 1,
1069
+ remote_inference_description: Optional[str] = None,
1070
+ remote_inference_results_visibility: Optional[
1071
+ Literal["private", "public", "unlisted"]
1072
+ ] = "unlisted",
1073
+ ) -> Union[Results, None]:
1074
+ """
1075
+ Creates and polls a remote inference job asynchronously.
1076
+ Reuses existing synchronous methods but runs them in an async context.
1077
+
1078
+ :param iterations: Number of times to run each interview
1079
+ :param remote_inference_description: Optional description for the remote job
1080
+ :param remote_inference_results_visibility: Visibility setting for results
1081
+ :return: Results object if successful, None if job fails or is cancelled
1082
+ """
1083
+ import asyncio
1084
+ from functools import partial
1085
+
1086
+ # Create job using existing method
1087
+ loop = asyncio.get_event_loop()
1088
+ remote_job_creation_data = await loop.run_in_executor(
1089
+ None,
1090
+ partial(
1091
+ self.create_remote_inference_job,
1092
+ iterations=iterations,
1093
+ remote_inference_description=remote_inference_description,
1094
+ remote_inference_results_visibility=remote_inference_results_visibility,
1095
+ ),
1096
+ )
1097
+
1098
+ # Poll using existing method but with async sleep
1099
+ return await loop.run_in_executor(
1100
+ None, partial(self.poll_remote_inference_job, remote_job_creation_data)
1101
+ )
1102
+
644
1103
  async def run_async(
645
1104
  self,
646
1105
  cache=None,
@@ -663,15 +1122,14 @@ class Jobs(Base):
663
1122
  :return: Results object
664
1123
  """
665
1124
  # Check if we should use remote inference
666
- from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
667
-
668
- jh = JobsRemoteInferenceHandler(self, verbose=False)
669
- if jh.use_remote_inference(disable_remote_inference):
670
- results = await jh.create_and_poll_remote_job(
1125
+ if remote_inference := self.use_remote_inference(disable_remote_inference):
1126
+ results = await self.create_and_poll_remote_job(
671
1127
  iterations=n,
672
1128
  remote_inference_description=remote_inference_description,
673
1129
  remote_inference_results_visibility=remote_inference_results_visibility,
674
1130
  )
1131
+ if results is None:
1132
+ self._output("Job failed.")
675
1133
  return results
676
1134
 
677
1135
  # If not using remote inference, run locally with async
@@ -691,22 +1149,24 @@ class Jobs(Base):
691
1149
  """
692
1150
  return set.union(*[question.parameters for question in self.survey.questions])
693
1151
 
1152
+ #######################
1153
+ # Dunder methods
1154
+ #######################
1155
+ def print(self):
1156
+ from rich import print_json
1157
+ import json
1158
+
1159
+ print_json(json.dumps(self.to_dict()))
1160
+
694
1161
  def __repr__(self) -> str:
695
1162
  """Return an eval-able string representation of the Jobs instance."""
696
1163
  return f"Jobs(survey={repr(self.survey)}, agents={repr(self.agents)}, models={repr(self.models)}, scenarios={repr(self.scenarios)})"
697
1164
 
698
- def _summary(self):
699
- return {
700
- "EDSL Class": "Jobs",
701
- "Number of questions": len(self.survey),
702
- "Number of agents": len(self.agents),
703
- "Number of models": len(self.models),
704
- "Number of scenarios": len(self.scenarios),
705
- }
706
-
707
1165
  def _repr_html_(self) -> str:
708
- footer = f"<a href={self.__documentation__}>(docs)</a>"
709
- return str(self.summary(format="html")) + footer
1166
+ from rich import print_json
1167
+ import json
1168
+
1169
+ print_json(json.dumps(self.to_dict()))
710
1170
 
711
1171
  def __len__(self) -> int:
712
1172
  """Return the maximum number of questions that will be asked while running this job.
@@ -728,29 +1188,18 @@ class Jobs(Base):
728
1188
  # Serialization methods
729
1189
  #######################
730
1190
 
731
- def to_dict(self, add_edsl_version=True):
732
- d = {
733
- "survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
734
- "agents": [
735
- agent.to_dict(add_edsl_version=add_edsl_version)
736
- for agent in self.agents
737
- ],
738
- "models": [
739
- model.to_dict(add_edsl_version=add_edsl_version)
740
- for model in self.models
741
- ],
742
- "scenarios": [
743
- scenario.to_dict(add_edsl_version=add_edsl_version)
744
- for scenario in self.scenarios
745
- ],
1191
+ def _to_dict(self):
1192
+ return {
1193
+ "survey": self.survey._to_dict(),
1194
+ "agents": [agent._to_dict() for agent in self.agents],
1195
+ "models": [model._to_dict() for model in self.models],
1196
+ "scenarios": [scenario._to_dict() for scenario in self.scenarios],
746
1197
  }
747
- if add_edsl_version:
748
- from edsl import __version__
749
1198
 
750
- d["edsl_version"] = __version__
751
- d["edsl_class_name"] = "Jobs"
752
-
753
- return d
1199
+ @add_edsl_version
1200
+ def to_dict(self) -> dict:
1201
+ """Convert the Jobs instance to a dictionary."""
1202
+ return self._to_dict()
754
1203
 
755
1204
  @classmethod
756
1205
  @remove_edsl_version
@@ -776,7 +1225,7 @@ class Jobs(Base):
776
1225
  True
777
1226
 
778
1227
  """
779
- return hash(self) == hash(other)
1228
+ return self.to_dict() == other.to_dict()
780
1229
 
781
1230
  #######################
782
1231
  # Example methods