edsl 0.1.37.dev5__py3-none-any.whl → 0.1.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. edsl/Base.py +63 -34
  2. edsl/BaseDiff.py +7 -7
  3. edsl/__init__.py +2 -1
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +23 -11
  6. edsl/agents/AgentList.py +86 -23
  7. edsl/agents/Invigilator.py +18 -7
  8. edsl/agents/InvigilatorBase.py +0 -19
  9. edsl/agents/PromptConstructor.py +5 -4
  10. edsl/auto/SurveyCreatorPipeline.py +1 -1
  11. edsl/auto/utilities.py +1 -1
  12. edsl/base/Base.py +3 -13
  13. edsl/config.py +8 -0
  14. edsl/coop/coop.py +89 -19
  15. edsl/data/Cache.py +45 -17
  16. edsl/data/CacheEntry.py +8 -3
  17. edsl/data/RemoteCacheSync.py +0 -19
  18. edsl/enums.py +2 -0
  19. edsl/exceptions/agents.py +4 -0
  20. edsl/exceptions/cache.py +5 -0
  21. edsl/inference_services/GoogleService.py +7 -15
  22. edsl/inference_services/PerplexityService.py +163 -0
  23. edsl/inference_services/registry.py +2 -0
  24. edsl/jobs/Jobs.py +110 -559
  25. edsl/jobs/JobsChecks.py +147 -0
  26. edsl/jobs/JobsPrompts.py +268 -0
  27. edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
  28. edsl/jobs/buckets/TokenBucket.py +3 -0
  29. edsl/jobs/interviews/Interview.py +7 -7
  30. edsl/jobs/runners/JobsRunnerAsyncio.py +156 -28
  31. edsl/jobs/runners/JobsRunnerStatus.py +194 -196
  32. edsl/jobs/tasks/TaskHistory.py +27 -19
  33. edsl/language_models/LanguageModel.py +52 -90
  34. edsl/language_models/ModelList.py +67 -14
  35. edsl/language_models/registry.py +57 -4
  36. edsl/notebooks/Notebook.py +7 -8
  37. edsl/prompts/Prompt.py +8 -3
  38. edsl/questions/QuestionBase.py +38 -30
  39. edsl/questions/QuestionBaseGenMixin.py +1 -1
  40. edsl/questions/QuestionBasePromptsMixin.py +0 -17
  41. edsl/questions/QuestionExtract.py +3 -4
  42. edsl/questions/QuestionFunctional.py +10 -3
  43. edsl/questions/derived/QuestionTopK.py +2 -0
  44. edsl/questions/question_registry.py +36 -6
  45. edsl/results/CSSParameterizer.py +108 -0
  46. edsl/results/Dataset.py +146 -15
  47. edsl/results/DatasetExportMixin.py +231 -217
  48. edsl/results/DatasetTree.py +134 -4
  49. edsl/results/Result.py +31 -16
  50. edsl/results/Results.py +159 -65
  51. edsl/results/TableDisplay.py +198 -0
  52. edsl/results/table_display.css +78 -0
  53. edsl/scenarios/FileStore.py +187 -13
  54. edsl/scenarios/Scenario.py +73 -18
  55. edsl/scenarios/ScenarioJoin.py +127 -0
  56. edsl/scenarios/ScenarioList.py +251 -76
  57. edsl/surveys/MemoryPlan.py +1 -1
  58. edsl/surveys/Rule.py +1 -5
  59. edsl/surveys/RuleCollection.py +1 -1
  60. edsl/surveys/Survey.py +25 -19
  61. edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
  62. edsl/surveys/instructions/ChangeInstruction.py +9 -7
  63. edsl/surveys/instructions/Instruction.py +21 -7
  64. edsl/templates/error_reporting/interview_details.html +3 -3
  65. edsl/templates/error_reporting/interviews.html +18 -9
  66. edsl/{conjure → utilities}/naming_utilities.py +1 -1
  67. edsl/utilities/utilities.py +15 -0
  68. {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/METADATA +2 -1
  69. {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/RECORD +71 -77
  70. edsl/conjure/AgentConstructionMixin.py +0 -160
  71. edsl/conjure/Conjure.py +0 -62
  72. edsl/conjure/InputData.py +0 -659
  73. edsl/conjure/InputDataCSV.py +0 -48
  74. edsl/conjure/InputDataMixinQuestionStats.py +0 -182
  75. edsl/conjure/InputDataPyRead.py +0 -91
  76. edsl/conjure/InputDataSPSS.py +0 -8
  77. edsl/conjure/InputDataStata.py +0 -8
  78. edsl/conjure/QuestionOptionMixin.py +0 -76
  79. edsl/conjure/QuestionTypeMixin.py +0 -23
  80. edsl/conjure/RawQuestion.py +0 -65
  81. edsl/conjure/SurveyResponses.py +0 -7
  82. edsl/conjure/__init__.py +0 -9
  83. edsl/conjure/examples/placeholder.txt +0 -0
  84. edsl/conjure/utilities.py +0 -201
  85. {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/LICENSE +0 -0
  86. {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import warnings
 import requests
 from itertools import product
-from typing import Literal, Optional, Union, Sequence, Generator
+from typing import Literal, Optional, Union, Sequence, Generator, TYPE_CHECKING

 from edsl.Base import Base

@@ -11,11 +11,20 @@ from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
-from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
+from edsl.utilities.decorators import remove_edsl_version

 from edsl.data.RemoteCacheSync import RemoteCacheSync
 from edsl.exceptions.coop import CoopServerResponseError

+if TYPE_CHECKING:
+    from edsl.agents.Agent import Agent
+    from edsl.agents.AgentList import AgentList
+    from edsl.language_models.LanguageModel import LanguageModel
+    from edsl.scenarios.Scenario import Scenario
+    from edsl.surveys.Survey import Survey
+    from edsl.results.Results import Results
+    from edsl.results.Dataset import Dataset
+

 class Jobs(Base):
     """
@@ -24,6 +33,8 @@ class Jobs(Base):
     The `JobsRunner` is chosen by the user, and is stored in the `jobs_runner_name` attribute.
     """

+    __documentation__ = "https://docs.expectedparrot.com/en/latest/jobs.html"
+
     def __init__(
         self,
         survey: "Survey",
@@ -86,8 +97,14 @@ class Jobs(Base):
     @scenarios.setter
     def scenarios(self, value):
         from edsl import ScenarioList
+        from edsl.results.Dataset import Dataset

         if value:
+            if isinstance(
+                value, Dataset
+            ):  # if the user passes in a Dataset, convert it to a ScenarioList
+                value = value.to_scenario_list()
+
             if not isinstance(value, ScenarioList):
                 self._scenarios = ScenarioList(value)
             else:
@@ -127,6 +144,13 @@ class Jobs(Base):
         - scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
         - models: new models overwrite old models.
         """
+        from edsl.results.Dataset import Dataset
+
+        if isinstance(
+            args[0], Dataset
+        ):  # let the user user a Dataset as if it were a ScenarioList
+            args = args[0].to_scenario_list()
+
         passed_objects = self._turn_args_to_list(
             args
         )  # objects can also be passed comma-separated
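Note: the two hunks above teach Jobs to treat a Dataset as if it were a ScenarioList, converting it with to_scenario_list(). A minimal sketch of the new calling convention (the column name is hypothetical; select() on a Results object returns a Dataset):

    from edsl import Survey

    dataset = previous_results.select("answer.how_feeling")  # a Dataset, not a ScenarioList
    job = Survey.example().by(dataset)  # as of 0.1.38, converted to a ScenarioList internally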
@@ -151,73 +175,19 @@ class Jobs(Base):
         >>> Jobs.example().prompts()
         Dataset(...)
         """
-        from edsl import Coop
-
-        c = Coop()
-        price_lookup = c.fetch_prices()
-
-        interviews = self.interviews()
-        # data = []
-        interview_indices = []
-        question_names = []
-        user_prompts = []
-        system_prompts = []
-        scenario_indices = []
-        agent_indices = []
-        models = []
-        costs = []
-        from edsl.results.Dataset import Dataset
+        from edsl.jobs.JobsPrompts import JobsPrompts

-        for interview_index, interview in enumerate(interviews):
-            invigilators = [
-                interview._get_invigilator(question)
-                for question in self.survey.questions
-            ]
-            for _, invigilator in enumerate(invigilators):
-                prompts = invigilator.get_prompts()
-                user_prompt = prompts["user_prompt"]
-                system_prompt = prompts["system_prompt"]
-                user_prompts.append(user_prompt)
-                system_prompts.append(system_prompt)
-                agent_index = self.agents.index(invigilator.agent)
-                agent_indices.append(agent_index)
-                interview_indices.append(interview_index)
-                scenario_index = self.scenarios.index(invigilator.scenario)
-                scenario_indices.append(scenario_index)
-                models.append(invigilator.model.model)
-                question_names.append(invigilator.question.question_name)
-
-                prompt_cost = self.estimate_prompt_cost(
-                    system_prompt=system_prompt,
-                    user_prompt=user_prompt,
-                    price_lookup=price_lookup,
-                    inference_service=invigilator.model._inference_service_,
-                    model=invigilator.model.model,
-                )
-                costs.append(prompt_cost["cost_usd"])
+        j = JobsPrompts(self)
+        return j.prompts()

-        d = Dataset(
-            [
-                {"user_prompt": user_prompts},
-                {"system_prompt": system_prompts},
-                {"interview_index": interview_indices},
-                {"question_name": question_names},
-                {"scenario_index": scenario_indices},
-                {"agent_index": agent_indices},
-                {"model": models},
-                {"estimated_cost": costs},
-            ]
-        )
-        return d
-
-    def show_prompts(self, all=False, max_rows: Optional[int] = None) -> None:
+    def show_prompts(self, all=False) -> None:
         """Print the prompts."""
         if all:
-            self.prompts().to_scenario_list().print(format="rich", max_rows=max_rows)
+            return self.prompts().to_scenario_list().table()
         else:
-            self.prompts().select(
-                "user_prompt", "system_prompt"
-            ).to_scenario_list().print(format="rich", max_rows=max_rows)
+            return (
+                self.prompts().to_scenario_list().table("user_prompt", "system_prompt")
+            )

     @staticmethod
     def estimate_prompt_cost(
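Note: despite the retained "-> None" annotation, show_prompts now returns a rendered table rather than printing via rich, and the max_rows parameter is gone. A minimal usage sketch:

    job = Jobs.example()
    job.show_prompts()            # user_prompt and system_prompt columns only
    job.show_prompts(all=True)    # every column of the prompts() Dataset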
@@ -226,201 +196,42 @@ class Jobs(Base):
         price_lookup: dict,
         inference_service: str,
         model: str,
-    ) -> dict:
-        """Estimates the cost of a prompt. Takes piping into account."""
-        import math
-
-        def get_piping_multiplier(prompt: str):
-            """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
-
-            if "{{" in prompt and "}}" in prompt:
-                return 2
-            return 1
-
-        # Look up prices per token
-        key = (inference_service, model)
-
-        try:
-            relevant_prices = price_lookup[key]
-
-            service_input_token_price = float(
-                relevant_prices["input"]["service_stated_token_price"]
-            )
-            service_input_token_qty = float(
-                relevant_prices["input"]["service_stated_token_qty"]
-            )
-            input_price_per_token = service_input_token_price / service_input_token_qty
-
-            service_output_token_price = float(
-                relevant_prices["output"]["service_stated_token_price"]
-            )
-            service_output_token_qty = float(
-                relevant_prices["output"]["service_stated_token_qty"]
-            )
-            output_price_per_token = (
-                service_output_token_price / service_output_token_qty
-            )
-
-        except KeyError:
-            # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
-            # Use a sensible default
-
-            import warnings
-
-            warnings.warn(
-                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
-            )
-            input_price_per_token = 0.00000015  # $0.15 / 1M tokens
-            output_price_per_token = 0.00000060  # $0.60 / 1M tokens
-
-        # Compute the number of characters (double if the question involves piping)
-        user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
-            str(user_prompt)
-        )
-        system_prompt_chars = len(str(system_prompt)) * get_piping_multiplier(
-            str(system_prompt)
-        )
-
-        # Convert into tokens (1 token approx. equals 4 characters)
-        input_tokens = (user_prompt_chars + system_prompt_chars) // 4
-
-        output_tokens = math.ceil(0.75 * input_tokens)
-
-        cost = (
-            input_tokens * input_price_per_token
-            + output_tokens * output_price_per_token
-        )
-
-        return {
-            "input_tokens": input_tokens,
-            "output_tokens": output_tokens,
-            "cost_usd": cost,
-        }
-
-    def estimate_job_cost_from_external_prices(
-        self, price_lookup: dict, iterations: int = 1
     ) -> dict:
         """
-        Estimates the cost of a job according to the following assumptions:
-
-        - 1 token = 4 characters.
-        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
-
-        price_lookup is an external pricing dictionary.
+        Estimate the cost of running the prompts.
+        :param iterations: the number of iterations to run
         """
+        from edsl.jobs.JobsPrompts import JobsPrompts

-        import pandas as pd
-
-        interviews = self.interviews()
-        data = []
-        for interview in interviews:
-            invigilators = [
-                interview._get_invigilator(question)
-                for question in self.survey.questions
-            ]
-            for invigilator in invigilators:
-                prompts = invigilator.get_prompts()
-
-                # By this point, agent and scenario data has already been added to the prompts
-                user_prompt = prompts["user_prompt"]
-                system_prompt = prompts["system_prompt"]
-                inference_service = invigilator.model._inference_service_
-                model = invigilator.model.model
-
-                prompt_cost = self.estimate_prompt_cost(
-                    system_prompt=system_prompt,
-                    user_prompt=user_prompt,
-                    price_lookup=price_lookup,
-                    inference_service=inference_service,
-                    model=model,
-                )
-
-                data.append(
-                    {
-                        "user_prompt": user_prompt,
-                        "system_prompt": system_prompt,
-                        "estimated_input_tokens": prompt_cost["input_tokens"],
-                        "estimated_output_tokens": prompt_cost["output_tokens"],
-                        "estimated_cost_usd": prompt_cost["cost_usd"],
-                        "inference_service": inference_service,
-                        "model": model,
-                    }
-                )
-
-        df = pd.DataFrame.from_records(data)
-
-        df = (
-            df.groupby(["inference_service", "model"])
-            .agg(
-                {
-                    "estimated_cost_usd": "sum",
-                    "estimated_input_tokens": "sum",
-                    "estimated_output_tokens": "sum",
-                }
-            )
-            .reset_index()
+        return JobsPrompts.estimate_prompt_cost(
+            system_prompt, user_prompt, price_lookup, inference_service, model
         )
-        df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
-        df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
-        df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations
-
-        estimated_costs_by_model = df.to_dict("records")
-
-        estimated_total_cost = sum(
-            model["estimated_cost_usd"] for model in estimated_costs_by_model
-        )
-        estimated_total_input_tokens = sum(
-            model["estimated_input_tokens"] for model in estimated_costs_by_model
-        )
-        estimated_total_output_tokens = sum(
-            model["estimated_output_tokens"] for model in estimated_costs_by_model
-        )
-
-        output = {
-            "estimated_total_cost_usd": estimated_total_cost,
-            "estimated_total_input_tokens": estimated_total_input_tokens,
-            "estimated_total_output_tokens": estimated_total_output_tokens,
-            "model_costs": estimated_costs_by_model,
-        }
-
-        return output

     def estimate_job_cost(self, iterations: int = 1) -> dict:
         """
-        Estimates the cost of a job according to the following assumptions:
-
-        - 1 token = 4 characters.
-        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.
+        Estimate the cost of running the job.

-        Fetches prices from Coop.
+        :param iterations: the number of iterations to run
         """
-        from edsl import Coop
+        from edsl.jobs.JobsPrompts import JobsPrompts

-        c = Coop()
-        price_lookup = c.fetch_prices()
+        j = JobsPrompts(self)
+        return j.estimate_job_cost(iterations)

-        return self.estimate_job_cost_from_external_prices(
-            price_lookup=price_lookup, iterations=iterations
-        )
+    def estimate_job_cost_from_external_prices(
+        self, price_lookup: dict, iterations: int = 1
+    ) -> dict:
+        from edsl.jobs.JobsPrompts import JobsPrompts
+
+        j = JobsPrompts(self)
+        return j.estimate_job_cost_from_external_prices(price_lookup, iterations)

     @staticmethod
-    def compute_job_cost(job_results: "Results") -> float:
+    def compute_job_cost(job_results: Results) -> float:
         """
         Computes the cost of a completed job in USD.
         """
-        total_cost = 0
-        for result in job_results:
-            for key in result.raw_model_response:
-                if key.endswith("_cost"):
-                    result_cost = result.raw_model_response[key]
-
-                    question_name = key.removesuffix("_cost")
-                    cache_used = result.cache_used_dict[question_name]
-
-                    if isinstance(result_cost, (int, float)) and not cache_used:
-                        total_cost += result_cost
-
-        return total_cost
+        return job_results.compute_job_cost()

     @staticmethod
     def _get_container_class(object):
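Note: the pricing logic deleted here moves into the new edsl/jobs/JobsPrompts.py (+268 lines in the file list). The heuristic itself is unchanged: 1 token is approximated as 4 characters, output tokens = ceil(0.75 x input tokens), and any prompt containing Jinja braces is counted twice to allow for piped values. A minimal sketch of the arithmetic, using the fallback prices from the removed code ($0.15 / 1M input tokens, $0.60 / 1M output tokens):

    import math

    def rough_prompt_cost(system_prompt: str, user_prompt: str) -> dict:
        # Fallback prices from the removed code.
        input_price, output_price = 0.15e-6, 0.60e-6
        # Prompts with Jinja braces are double-counted to allow for piped values.
        chars = sum(
            len(p) * (2 if "{{" in p and "}}" in p else 1)
            for p in (system_prompt, user_prompt)
        )
        input_tokens = chars // 4                       # 1 token ~ 4 characters
        output_tokens = math.ceil(0.75 * input_tokens)  # assumed output/input ratio
        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost_usd": input_tokens * input_price + output_tokens * output_price,
        }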
@@ -504,17 +315,12 @@ class Jobs(Base):

     @staticmethod
     def _get_empty_container_object(object):
-        from edsl import AgentList
-        from edsl import Agent
-        from edsl import Scenario
-        from edsl import ScenarioList
+        from edsl.agents.AgentList import AgentList
+        from edsl.scenarios.ScenarioList import ScenarioList

-        if isinstance(object, Agent):
-            return AgentList([])
-        elif isinstance(object, Scenario):
-            return ScenarioList([])
-        else:
-            return []
+        return {"Agent": AgentList([]), "Scenario": ScenarioList([])}.get(
+            object.__class__.__name__, []
+        )

     @staticmethod
     def _merge_objects(passed_objects, current_objects) -> list:
@@ -641,7 +447,7 @@ class Jobs(Base):
         """
         from edsl.utilities.utilities import dict_hash

-        return dict_hash(self._to_dict())
+        return dict_hash(self.to_dict(add_edsl_version=False))

     def _output(self, message) -> None:
         """Check if a Job is verbose. If so, print the message."""
@@ -722,110 +528,6 @@ class Jobs(Base):
             return False
         return self._raise_validation_errors

-    def create_remote_inference_job(
-        self,
-        iterations: int = 1,
-        remote_inference_description: Optional[str] = None,
-        remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
-        verbose=False,
-    ):
-        """ """
-        from edsl.coop.coop import Coop
-
-        coop = Coop()
-        self._output("Remote inference activated. Sending job to server...")
-        remote_job_creation_data = coop.remote_inference_create(
-            self,
-            description=remote_inference_description,
-            status="queued",
-            iterations=iterations,
-            initial_results_visibility=remote_inference_results_visibility,
-        )
-        job_uuid = remote_job_creation_data.get("uuid")
-        if self.verbose:
-            print(f"Job sent to server. (Job uuid={job_uuid}).")
-        return remote_job_creation_data
-
-    @staticmethod
-    def check_status(job_uuid):
-        from edsl.coop.coop import Coop
-
-        coop = Coop()
-        return coop.remote_inference_get(job_uuid)
-
-    def poll_remote_inference_job(
-        self, remote_job_creation_data: dict, verbose=False, poll_interval=5
-    ) -> Union[Results, None]:
-        from edsl.coop.coop import Coop
-        import time
-        from datetime import datetime
-        from edsl.config import CONFIG
-
-        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
-
-        job_uuid = remote_job_creation_data.get("uuid")
-
-        coop = Coop()
-        job_in_queue = True
-        while job_in_queue:
-            remote_job_data = coop.remote_inference_get(job_uuid)
-            status = remote_job_data.get("status")
-            if status == "cancelled":
-                if self.verbose:
-                    print("\r" + " " * 80 + "\r", end="")
-                    print("Job cancelled by the user.")
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                return None
-            elif status == "failed":
-                if self.verbose:
-                    print("\r" + " " * 80 + "\r", end="")
-                    print("Job failed.")
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                return None
-            elif status == "completed":
-                results_uuid = remote_job_data.get("results_uuid")
-                results = coop.get(results_uuid, expected_object_type="results")
-                if self.verbose:
-                    print("\r" + " " * 80 + "\r", end="")
-                    url = f"{expected_parrot_url}/content/{results_uuid}"
-                    print(f"Job completed and Results stored on Coop: {url}.")
-                return results
-            else:
-                duration = poll_interval
-                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
-                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
-                start_time = time.time()
-                i = 0
-                while time.time() - start_time < duration:
-                    if self.verbose:
-                        print(
-                            f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
-                            end="",
-                            flush=True,
-                        )
-                    time.sleep(0.1)
-                    i += 1
-
-    def use_remote_inference(self, disable_remote_inference: bool) -> bool:
-        if disable_remote_inference:
-            return False
-        if not disable_remote_inference:
-            try:
-                from edsl import Coop
-
-                user_edsl_settings = Coop().edsl_settings
-                return user_edsl_settings.get("remote_inference", False)
-            except requests.ConnectionError:
-                pass
-            except CoopServerResponseError as e:
-                pass
-
-        return False
-
     def use_remote_cache(self, disable_remote_cache: bool) -> bool:
         if disable_remote_cache:
             return False
@@ -842,96 +544,6 @@ class Jobs(Base):

         return False

-    def check_api_keys(self) -> None:
-        from edsl import Model
-
-        for model in self.models + [Model()]:
-            if not model.has_valid_api_key():
-                raise MissingAPIKeyError(
-                    model_name=str(model.model),
-                    inference_service=model._inference_service_,
-                )
-
-    def get_missing_api_keys(self) -> set:
-        """
-        Returns a list of the api keys that a user needs to run this job, but does not currently have in their .env file.
-        """
-
-        missing_api_keys = set()
-
-        from edsl import Model
-        from edsl.enums import service_to_api_keyname
-
-        for model in self.models + [Model()]:
-            if not model.has_valid_api_key():
-                key_name = service_to_api_keyname.get(
-                    model._inference_service_, "NOT FOUND"
-                )
-                missing_api_keys.add(key_name)
-
-        return missing_api_keys
-
-    def user_has_all_model_keys(self):
-        """
-        Returns True if the user has all model keys required to run their job.
-
-        Otherwise, returns False.
-        """
-
-        try:
-            self.check_api_keys()
-            return True
-        except MissingAPIKeyError:
-            return False
-        except Exception:
-            raise
-
-    def user_has_ep_api_key(self) -> bool:
-        """
-        Returns True if the user has an EXPECTED_PARROT_API_KEY in their env.
-
-        Otherwise, returns False.
-        """
-
-        import os
-
-        coop_api_key = os.getenv("EXPECTED_PARROT_API_KEY")
-
-        if coop_api_key is not None:
-            return True
-        else:
-            return False
-
-    def needs_external_llms(self) -> bool:
-        """
-        Returns True if the job needs external LLMs to run.
-
-        Otherwise, returns False.
-        """
-        # These cases are necessary to skip the API key check during doctests
-
-        # Accounts for Results.example()
-        all_agents_answer_questions_directly = len(self.agents) > 0 and all(
-            [hasattr(a, "answer_question_directly") for a in self.agents]
-        )
-
-        # Accounts for InterviewExceptionEntry.example()
-        only_model_is_test = set([m.model for m in self.models]) == set(["test"])
-
-        # Accounts for Survey.__call__
-        all_questions_are_functional = set(
-            [q.question_type for q in self.survey.questions]
-        ) == set(["functional"])
-
-        if (
-            all_agents_answer_questions_directly
-            or only_model_is_test
-            or all_questions_are_functional
-        ):
-            return False
-        else:
-            return True
-
     def run(
         self,
         n: int = 1,
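Note: the API-key helpers removed above now live in the new edsl/jobs/JobsChecks.py (+147 lines in the file list). A minimal sketch of the replacement entry points, using only the method names visible in this diff (needs_key_process and key_process appear in the run() hunk below):

    from edsl.jobs.JobsChecks import JobsChecks

    jc = JobsChecks(job)           # wraps a Jobs instance
    jc.check_api_keys()            # raises MissingAPIKeyError on a missing key
    if jc.needs_key_process():     # missing keys and no Expected Parrot key?
        jc.key_process()           # interactive login / key-retrieval flow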
@@ -940,7 +552,7 @@ class Jobs(Base):
         cache: Union[Cache, bool] = None,
         check_api_keys: bool = False,
         sidecar_model: Optional[LanguageModel] = None,
-        verbose: bool = False,
+        verbose: bool = True,
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
         remote_inference_description: Optional[str] = None,
@@ -975,62 +587,28 @@ class Jobs(Base):

         self.verbose = verbose

-        if (
-            not self.user_has_all_model_keys()
-            and not self.user_has_ep_api_key()
-            and self.needs_external_llms()
-        ):
-            import secrets
-            from dotenv import load_dotenv
-            from edsl import CONFIG
-            from edsl.coop.coop import Coop
-            from edsl.utilities.utilities import write_api_key_to_env
-
-            missing_api_keys = self.get_missing_api_keys()
+        from edsl.jobs.JobsChecks import JobsChecks

-            edsl_auth_token = secrets.token_urlsafe(16)
+        jc = JobsChecks(self)

-            print("You're missing some of the API keys needed to run this job:")
-            for api_key in missing_api_keys:
-                print(f" 🔑 {api_key}")
-            print(
-                "\nYou can either add the missing keys to your .env file, or use remote inference."
-            )
-            print("Remote inference allows you to run jobs on our server.")
-            print("\n🚀 To use remote inference, sign up at the following link:")
-
-            coop = Coop()
-            coop._display_login_url(edsl_auth_token=edsl_auth_token)
-
-            print(
-                "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
-            )
-
-            api_key = coop._poll_for_api_key(edsl_auth_token)
-
-            if api_key is None:
-                print("\nTimed out waiting for login. Please try again.")
-                return
+        # check if the user has all the keys they need
+        if jc.needs_key_process():
+            jc.key_process()

-            write_api_key_to_env(api_key)
-            print("✨ API key retrieved and written to .env file.\n")
+        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler

-            # Retrieve API key so we can continue running the job
-            load_dotenv()
-
-        if remote_inference := self.use_remote_inference(disable_remote_inference):
-            remote_job_creation_data = self.create_remote_inference_job(
+        jh = JobsRemoteInferenceHandler(self, verbose=verbose)
+        if jh.use_remote_inference(disable_remote_inference):
+            jh.create_remote_inference_job(
                 iterations=n,
                 remote_inference_description=remote_inference_description,
                 remote_inference_results_visibility=remote_inference_results_visibility,
             )
-            results = self.poll_remote_inference_job(remote_job_creation_data)
-            if results is None:
-                self._output("Job failed.")
+            results = jh.poll_remote_inference_job()
             return results

         if check_api_keys:
-            self.check_api_keys()
+            jc.check_api_keys()

         # handle cache
         if cache is None or cache is True:
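Note: the remote-inference branch of run() is likewise delegated to the new edsl/jobs/JobsRemoteInferenceHandler.py (+239 lines in the file list). A sketch of the flow, using only the methods named in this hunk, with behavior inferred from the removed code above:

    from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler

    jh = JobsRemoteInferenceHandler(job, verbose=True)
    if jh.use_remote_inference(disable_remote_inference=False):
        jh.create_remote_inference_job(iterations=1)
        results = jh.poll_remote_inference_job()  # polls Coop until completed/failed/cancelled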
@@ -1060,46 +638,9 @@ class Jobs(Base):
             raise_validation_errors=raise_validation_errors,
         )

-        results.cache = cache.new_entries_cache()
+        # results.cache = cache.new_entries_cache()
         return results

-    async def create_and_poll_remote_job(
-        self,
-        iterations: int = 1,
-        remote_inference_description: Optional[str] = None,
-        remote_inference_results_visibility: Optional[
-            Literal["private", "public", "unlisted"]
-        ] = "unlisted",
-    ) -> Union[Results, None]:
-        """
-        Creates and polls a remote inference job asynchronously.
-        Reuses existing synchronous methods but runs them in an async context.
-
-        :param iterations: Number of times to run each interview
-        :param remote_inference_description: Optional description for the remote job
-        :param remote_inference_results_visibility: Visibility setting for results
-        :return: Results object if successful, None if job fails or is cancelled
-        """
-        import asyncio
-        from functools import partial
-
-        # Create job using existing method
-        loop = asyncio.get_event_loop()
-        remote_job_creation_data = await loop.run_in_executor(
-            None,
-            partial(
-                self.create_remote_inference_job,
-                iterations=iterations,
-                remote_inference_description=remote_inference_description,
-                remote_inference_results_visibility=remote_inference_results_visibility,
-            ),
-        )
-
-        # Poll using existing method but with async sleep
-        return await loop.run_in_executor(
-            None, partial(self.poll_remote_inference_job, remote_job_creation_data)
-        )
-
     async def run_async(
         self,
         cache=None,
@@ -1122,14 +663,15 @@ class Jobs(Base):
         :return: Results object
         """
         # Check if we should use remote inference
-        if remote_inference := self.use_remote_inference(disable_remote_inference):
-            results = await self.create_and_poll_remote_job(
+        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+
+        jh = JobsRemoteInferenceHandler(self, verbose=False)
+        if jh.use_remote_inference(disable_remote_inference):
+            results = await jh.create_and_poll_remote_job(
                 iterations=n,
                 remote_inference_description=remote_inference_description,
                 remote_inference_results_visibility=remote_inference_results_visibility,
             )
-            if results is None:
-                self._output("Job failed.")
             return results

         # If not using remote inference, run locally with async
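Note: a minimal usage sketch of the async path (parameter name taken from this hunk; the local fallback still runs through JobsRunnerAsyncio):

    import asyncio
    from edsl.jobs.Jobs import Jobs

    async def main():
        # disable_remote_inference forces the local async runner
        return await Jobs.example().run_async(disable_remote_inference=True)

    results = asyncio.run(main())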
@@ -1149,24 +691,22 @@ class Jobs(Base):
         """
         return set.union(*[question.parameters for question in self.survey.questions])

-    #######################
-    # Dunder methods
-    #######################
-    def print(self):
-        from rich import print_json
-        import json
-
-        print_json(json.dumps(self.to_dict()))
-
     def __repr__(self) -> str:
         """Return an eval-able string representation of the Jobs instance."""
         return f"Jobs(survey={repr(self.survey)}, agents={repr(self.agents)}, models={repr(self.models)}, scenarios={repr(self.scenarios)})"

-    def _repr_html_(self) -> str:
-        from rich import print_json
-        import json
+    def _summary(self):
+        return {
+            "EDSL Class": "Jobs",
+            "Number of questions": len(self.survey),
+            "Number of agents": len(self.agents),
+            "Number of models": len(self.models),
+            "Number of scenarios": len(self.scenarios),
+        }

-        print_json(json.dumps(self.to_dict()))
+    def _repr_html_(self) -> str:
+        footer = f"<a href={self.__documentation__}>(docs)</a>"
+        return str(self.summary(format="html")) + footer

     def __len__(self) -> int:
         """Return the maximum number of questions that will be asked while running this job.
@@ -1188,18 +728,29 @@ class Jobs(Base):
     # Serialization methods
     #######################

-    def _to_dict(self):
-        return {
-            "survey": self.survey._to_dict(),
-            "agents": [agent._to_dict() for agent in self.agents],
-            "models": [model._to_dict() for model in self.models],
-            "scenarios": [scenario._to_dict() for scenario in self.scenarios],
+    def to_dict(self, add_edsl_version=True):
+        d = {
+            "survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
+            "agents": [
+                agent.to_dict(add_edsl_version=add_edsl_version)
+                for agent in self.agents
+            ],
+            "models": [
+                model.to_dict(add_edsl_version=add_edsl_version)
+                for model in self.models
+            ],
+            "scenarios": [
+                scenario.to_dict(add_edsl_version=add_edsl_version)
+                for scenario in self.scenarios
+            ],
         }
+        if add_edsl_version:
+            from edsl import __version__

-    @add_edsl_version
-    def to_dict(self) -> dict:
-        """Convert the Jobs instance to a dictionary."""
-        return self._to_dict()
+            d["edsl_version"] = __version__
+            d["edsl_class_name"] = "Jobs"
+
+        return d

     @classmethod
     @remove_edsl_version
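Note: the private _to_dict / decorated to_dict pair collapses into a single to_dict(add_edsl_version=...), mirroring the same change across Base.py, Survey.py and the other files in the list above. A round-trip sketch (the from_dict classmethod is assumed from the @remove_edsl_version context that follows):

    job = Jobs.example()
    d = job.to_dict()                            # stamped with edsl_version and edsl_class_name
    bare = job.to_dict(add_edsl_version=False)   # stable payload, used by __hash__ above
    restored = Jobs.from_dict(d)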
@@ -1225,7 +776,7 @@ class Jobs(Base):
         True

         """
-        return self.to_dict() == other.to_dict()
+        return hash(self) == hash(other)

     #######################
     # Example methods
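Note: with __eq__ now comparing hashes and __hash__ (earlier hunk) hashing to_dict(add_edsl_version=False), equality ignores the version stamp, so two identically-configured jobs compare equal even across EDSL versions:

    assert Jobs.example() == Jobs.example()
    assert hash(Jobs.example()) == hash(Jobs.example())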