edsl 0.1.35__py3-none-any.whl → 0.1.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. edsl/Base.py +5 -0
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +37 -9
  5. edsl/agents/Invigilator.py +2 -1
  6. edsl/agents/InvigilatorBase.py +5 -1
  7. edsl/agents/PromptConstructor.py +31 -67
  8. edsl/conversation/Conversation.py +1 -1
  9. edsl/coop/PriceFetcher.py +14 -18
  10. edsl/coop/coop.py +42 -8
  11. edsl/data/RemoteCacheSync.py +97 -0
  12. edsl/exceptions/coop.py +8 -0
  13. edsl/inference_services/InferenceServiceABC.py +28 -0
  14. edsl/inference_services/InferenceServicesCollection.py +10 -4
  15. edsl/inference_services/models_available_cache.py +25 -1
  16. edsl/inference_services/registry.py +24 -16
  17. edsl/jobs/Jobs.py +327 -206
  18. edsl/jobs/interviews/Interview.py +65 -10
  19. edsl/jobs/interviews/InterviewExceptionCollection.py +9 -0
  20. edsl/jobs/interviews/InterviewExceptionEntry.py +31 -9
  21. edsl/jobs/runners/JobsRunnerAsyncio.py +8 -13
  22. edsl/jobs/tasks/QuestionTaskCreator.py +1 -5
  23. edsl/jobs/tasks/TaskHistory.py +23 -7
  24. edsl/language_models/LanguageModel.py +3 -0
  25. edsl/prompts/Prompt.py +24 -38
  26. edsl/prompts/__init__.py +1 -1
  27. edsl/questions/QuestionBasePromptsMixin.py +18 -18
  28. edsl/questions/QuestionFunctional.py +7 -3
  29. edsl/questions/descriptors.py +24 -24
  30. edsl/results/Dataset.py +12 -0
  31. edsl/results/Result.py +2 -0
  32. edsl/results/Results.py +13 -1
  33. edsl/scenarios/FileStore.py +20 -5
  34. edsl/scenarios/Scenario.py +15 -1
  35. edsl/scenarios/__init__.py +2 -0
  36. edsl/surveys/Survey.py +3 -0
  37. edsl/surveys/instructions/Instruction.py +20 -3
  38. {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/METADATA +1 -1
  39. {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/RECORD +41 -57
  40. edsl/jobs/FailedQuestion.py +0 -78
  41. edsl/jobs/interviews/InterviewStatusMixin.py +0 -33
  42. edsl/jobs/tasks/task_management.py +0 -13
  43. edsl/prompts/QuestionInstructionsBase.py +0 -10
  44. edsl/prompts/library/agent_instructions.py +0 -38
  45. edsl/prompts/library/agent_persona.py +0 -21
  46. edsl/prompts/library/question_budget.py +0 -30
  47. edsl/prompts/library/question_checkbox.py +0 -38
  48. edsl/prompts/library/question_extract.py +0 -23
  49. edsl/prompts/library/question_freetext.py +0 -18
  50. edsl/prompts/library/question_linear_scale.py +0 -24
  51. edsl/prompts/library/question_list.py +0 -26
  52. edsl/prompts/library/question_multiple_choice.py +0 -54
  53. edsl/prompts/library/question_numerical.py +0 -35
  54. edsl/prompts/library/question_rank.py +0 -25
  55. edsl/prompts/prompt_config.py +0 -37
  56. edsl/prompts/registry.py +0 -202
  57. {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/LICENSE +0 -0
  58. {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/WHEEL +0 -0
edsl/inference_services/models_available_cache.py CHANGED
@@ -65,7 +65,31 @@ models_available = {
         "meta-llama/Meta-Llama-3-70B-Instruct",
         "openchat/openchat_3.5",
     ],
-    "google": ["gemini-pro"],
+    "google": [
+        "gemini-1.0-pro",
+        "gemini-1.0-pro-001",
+        "gemini-1.0-pro-latest",
+        "gemini-1.0-pro-vision-latest",
+        "gemini-1.5-flash",
+        "gemini-1.5-flash-001",
+        "gemini-1.5-flash-001-tuning",
+        "gemini-1.5-flash-002",
+        "gemini-1.5-flash-8b",
+        "gemini-1.5-flash-8b-001",
+        "gemini-1.5-flash-8b-exp-0827",
+        "gemini-1.5-flash-8b-exp-0924",
+        "gemini-1.5-flash-8b-latest",
+        "gemini-1.5-flash-exp-0827",
+        "gemini-1.5-flash-latest",
+        "gemini-1.5-pro",
+        "gemini-1.5-pro-001",
+        "gemini-1.5-pro-002",
+        "gemini-1.5-pro-exp-0801",
+        "gemini-1.5-pro-exp-0827",
+        "gemini-1.5-pro-latest",
+        "gemini-pro",
+        "gemini-pro-vision",
+    ],
     "bedrock": [
         "amazon.titan-tg1-large",
         "amazon.titan-text-lite-v1",
edsl/inference_services/registry.py CHANGED
@@ -11,21 +11,29 @@ from edsl.inference_services.AwsBedrock import AwsBedrockService
 from edsl.inference_services.AzureAI import AzureAIService
 from edsl.inference_services.OllamaService import OllamaService
 from edsl.inference_services.TestService import TestService
-from edsl.inference_services.MistralAIService import MistralAIService
 from edsl.inference_services.TogetherAIService import TogetherAIService

-default = InferenceServicesCollection(
-    [
-        OpenAIService,
-        AnthropicService,
-        DeepInfraService,
-        GoogleService,
-        GroqService,
-        AwsBedrockService,
-        AzureAIService,
-        OllamaService,
-        TestService,
-        MistralAIService,
-        TogetherAIService,
-    ]
-)
+try:
+    from edsl.inference_services.MistralAIService import MistralAIService
+
+    mistral_available = True
+except Exception as e:
+    mistral_available = False
+
+services = [
+    OpenAIService,
+    AnthropicService,
+    DeepInfraService,
+    GoogleService,
+    GroqService,
+    AwsBedrockService,
+    AzureAIService,
+    OllamaService,
+    TestService,
+    TogetherAIService,
+]
+
+if mistral_available:
+    services.append(MistralAIService)
+
+default = InferenceServicesCollection(services)
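The registry now guards an import that can fail (for example, if the Mistral SDK is not installed) so that the rest of the service registry still loads. A generic sketch of the same guard pattern, using hypothetical placeholder names (`optional_sdk`, `OptionalService`):

    # Generic optional-dependency guard, same shape as the registry change above.
    # `optional_sdk` and `OptionalService` are hypothetical placeholders.
    services = []

    try:
        from optional_sdk import OptionalService  # may be absent from the environment
        optional_available = True
    except ImportError:
        optional_available = False

    if optional_available:
        services.append(OptionalService)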
edsl/jobs/Jobs.py CHANGED
@@ -1,8 +1,10 @@
 # """The Jobs class is a collection of agents, scenarios and models and one survey."""
 from __future__ import annotations
 import warnings
+import requests
 from itertools import product
 from typing import Optional, Union, Sequence, Generator
+
 from edsl.Base import Base
 from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
@@ -10,6 +12,9 @@ from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
 from edsl.utilities.decorators import add_edsl_version, remove_edsl_version

+from edsl.data.RemoteCacheSync import RemoteCacheSync
+from edsl.exceptions.coop import CoopServerResponseError
+

 class Jobs(Base):
     """
@@ -180,17 +185,15 @@ class Jobs(Base):
                 scenario_indices.append(scenario_index)
                 models.append(invigilator.model.model)
                 question_names.append(invigilator.question.question_name)
-                # cost calculation
-                key = (invigilator.model._inference_service_, invigilator.model.model)
-                relevant_prices = price_lookup[key]
-                inverse_output_price = relevant_prices["output"]["one_usd_buys"]
-                inverse_input_price = relevant_prices["input"]["one_usd_buys"]
-                input_tokens = len(str(user_prompt) + str(system_prompt)) // 4
-                output_tokens = len(str(user_prompt) + str(system_prompt)) // 4
-                cost = input_tokens / float(
-                    inverse_input_price
-                ) + output_tokens / float(inverse_output_price)
-                costs.append(cost)
+
+                prompt_cost = self.estimate_prompt_cost(
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    price_lookup=price_lookup,
+                    inference_service=invigilator.model._inference_service_,
+                    model=invigilator.model.model,
+                )
+                costs.append(prompt_cost["cost"])

         d = Dataset(
             [
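With this change the Dataset returned by `prompts()` carries a per-prompt cost computed by `estimate_prompt_cost` (added in the next hunk), and `show_prompts()` gains an `all` flag. A usage sketch, assuming the usual `Jobs.example()` constructor:

    # Sketch: inspecting prompts and their estimated costs on a Jobs instance.
    from edsl.jobs.Jobs import Jobs

    job = Jobs.example()
    prompt_data = job.prompts()   # Dataset: prompts, indices, and estimated cost per prompt
    job.show_prompts()            # new default: only the user_prompt and system_prompt columns
    job.show_prompts(all=True)    # prints every column, as before this release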
@@ -205,59 +208,195 @@
             ]
         )
         return d
-        # if table:
-        #     d.to_scenario_list().print(format="rich")
-        # else:
-        #     return d

-    def show_prompts(self) -> None:
+    def show_prompts(self, all=False) -> None:
         """Print the prompts."""
-        self.prompts().to_scenario_list().print(format="rich")
+        if all:
+            self.prompts().to_scenario_list().print(format="rich")
+        else:
+            self.prompts().select(
+                "user_prompt", "system_prompt"
+            ).to_scenario_list().print(format="rich")
+
+    @staticmethod
+    def estimate_prompt_cost(
+        system_prompt: str,
+        user_prompt: str,
+        price_lookup: dict,
+        inference_service: str,
+        model: str,
+    ) -> dict:
+        """Estimates the cost of a prompt. Takes piping into account."""
+
+        def get_piping_multiplier(prompt: str):
+            """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
+
+            if "{{" in prompt and "}}" in prompt:
+                return 2
+            return 1
+
+        # Look up prices per token
+        key = (inference_service, model)
+
+        try:
+            relevant_prices = price_lookup[key]
+            output_price_per_token = 1 / float(
+                relevant_prices["output"]["one_usd_buys"]
+            )
+            input_price_per_token = 1 / float(relevant_prices["input"]["one_usd_buys"])
+        except KeyError:
+            # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
+            # Use a sensible default
+
+            import warnings
+
+            warnings.warn(
+                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
+            )
+
+            output_price_per_token = 0.00000015  # $0.15 / 1M tokens
+            input_price_per_token = 0.00000060  # $0.60 / 1M tokens
+
+        # Compute the number of characters (double if the question involves piping)
+        user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
+            str(user_prompt)
+        )
+        system_prompt_chars = len(str(system_prompt)) * get_piping_multiplier(
+            str(system_prompt)
+        )
+
+        # Convert into tokens (1 token approx. equals 4 characters)
+        input_tokens = (user_prompt_chars + system_prompt_chars) // 4
+        output_tokens = input_tokens
+
+        cost = (
+            input_tokens * input_price_per_token
+            + output_tokens * output_price_per_token
+        )
+
+        return {
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
+            "cost": cost,
+        }
+
+    def estimate_job_cost_from_external_prices(self, price_lookup: dict) -> dict:
+        """
+        Estimates the cost of a job according to the following assumptions:
+
+        - 1 token = 4 characters.
+        - Input tokens = output tokens.
+
+        price_lookup is an external pricing dictionary.
+        """
+
+        import pandas as pd
+
+        interviews = self.interviews()
+        data = []
+        for interview in interviews:
+            invigilators = [
+                interview._get_invigilator(question)
+                for question in self.survey.questions
+            ]
+            for invigilator in invigilators:
+                prompts = invigilator.get_prompts()

-    def estimate_job_cost(self):
+                # By this point, agent and scenario data has already been added to the prompts
+                user_prompt = prompts["user_prompt"]
+                system_prompt = prompts["system_prompt"]
+                inference_service = invigilator.model._inference_service_
+                model = invigilator.model.model
+
+                prompt_cost = self.estimate_prompt_cost(
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    price_lookup=price_lookup,
+                    inference_service=inference_service,
+                    model=model,
+                )
+
+                data.append(
+                    {
+                        "user_prompt": user_prompt,
+                        "system_prompt": system_prompt,
+                        "estimated_input_tokens": prompt_cost["input_tokens"],
+                        "estimated_output_tokens": prompt_cost["output_tokens"],
+                        "estimated_cost": prompt_cost["cost"],
+                        "inference_service": inference_service,
+                        "model": model,
+                    }
+                )
+
+        df = pd.DataFrame.from_records(data)
+
+        df = (
+            df.groupby(["inference_service", "model"])
+            .agg(
+                {
+                    "estimated_cost": "sum",
+                    "estimated_input_tokens": "sum",
+                    "estimated_output_tokens": "sum",
+                }
+            )
+            .reset_index()
+        )
+
+        estimated_costs_by_model = df.to_dict("records")
+
+        estimated_total_cost = sum(
+            model["estimated_cost"] for model in estimated_costs_by_model
+        )
+        estimated_total_input_tokens = sum(
+            model["estimated_input_tokens"] for model in estimated_costs_by_model
+        )
+        estimated_total_output_tokens = sum(
+            model["estimated_output_tokens"] for model in estimated_costs_by_model
+        )
+
+        output = {
+            "estimated_total_cost": estimated_total_cost,
+            "estimated_total_input_tokens": estimated_total_input_tokens,
+            "estimated_total_output_tokens": estimated_total_output_tokens,
+            "model_costs": estimated_costs_by_model,
+        }
+
+        return output
+
+    def estimate_job_cost(self) -> dict:
+        """
+        Estimates the cost of a job according to the following assumptions:
+
+        - 1 token = 4 characters.
+        - Input tokens = output tokens.
+
+        Fetches prices from Coop.
+        """
         from edsl import Coop

         c = Coop()
         price_lookup = c.fetch_prices()

-        prompts = self.prompts()
+        return self.estimate_job_cost_from_external_prices(price_lookup=price_lookup)

-        text_len = 0
-        for prompt in prompts:
-            text_len += len(str(prompt))
+    @staticmethod
+    def compute_job_cost(job_results: "Results") -> float:
+        """
+        Computes the cost of a completed job in USD.
+        """
+        total_cost = 0
+        for result in job_results:
+            for key in result.raw_model_response:
+                if key.endswith("_cost"):
+                    result_cost = result.raw_model_response[key]

-        input_token_aproximations = text_len // 4
+                    question_name = key.removesuffix("_cost")
+                    cache_used = result.cache_used_dict[question_name]

-        aproximation_cost = {}
-        total_cost = 0
-        for model in self.models:
-            key = (model._inference_service_, model.model)
-            relevant_prices = price_lookup[key]
-            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
-            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
-
-            aproximation_cost[key] = {
-                "input": input_token_aproximations / float(inverse_input_price),
-                "output": input_token_aproximations / float(inverse_output_price),
-            }
-            ##TODO curenlty we approximate the number of output tokens with the number
-            # of input tokens. A better solution will be to compute the quesiton answer options length and sum them
-            # to compute the output tokens
-
-            total_cost += input_token_aproximations / float(inverse_input_price)
-            total_cost += input_token_aproximations / float(inverse_output_price)
-
-        # multiply_factor = len(self.agents or [1]) * len(self.scenarios or [1])
-        multiply_factor = 1
-        out = {
-            "input_token_aproximations": input_token_aproximations,
-            "models_costs": aproximation_cost,
-            "estimated_total_cost": total_cost * multiply_factor,
-            "multiply_factor": multiply_factor,
-            "single_config_cost": total_cost,
-        }
+                    if isinstance(result_cost, (int, float)) and not cache_used:
+                        total_cost += result_cost

-        return out
+        return total_cost

     @staticmethod
     def _get_container_class(object):
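To make the new estimate concrete, a worked example of the `estimate_prompt_cost` arithmetic against a hand-built `price_lookup` (the prices below are invented for illustration; real prices come from `Coop().fetch_prices()`):

    # Worked example of the estimate_prompt_cost arithmetic shown above.
    # The price_lookup values are illustrative, not real Coop prices.
    from edsl.jobs.Jobs import Jobs

    price_lookup = {
        ("openai", "gpt-4o"): {
            "input": {"one_usd_buys": 4_000_000},   # $0.25 per 1M input tokens
            "output": {"one_usd_buys": 1_000_000},  # $1.00 per 1M output tokens
        }
    }

    estimate = Jobs.estimate_prompt_cost(
        system_prompt="You are a helpful agent.",          # 24 chars, no piping
        user_prompt="How do you feel about {{ topic }}?",  # 34 chars, doubled for piping
        price_lookup=price_lookup,
        inference_service="openai",
        model="gpt-4o",
    )
    # characters = 24 + 34 * 2 = 92  ->  92 // 4 = 23 input tokens, 23 output tokens
    # cost = 23 * $0.00000025 + 23 * $0.000001 = $0.00002875
    print(estimate)  # {'input_tokens': 23, 'output_tokens': 23, 'cost': 2.875e-05}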
@@ -482,7 +621,7 @@ class Jobs(Base):

     def _output(self, message) -> None:
         """Check if a Job is verbose. If so, print the message."""
-        if self.verbose:
+        if hasattr(self, "verbose") and self.verbose:
            print(message)

     def _check_parameters(self, strict=False, warn=False) -> None:
@@ -559,6 +698,123 @@
             return False
         return self._raise_validation_errors

+    def create_remote_inference_job(
+        self, iterations: int = 1, remote_inference_description: Optional[str] = None
+    ):
+        """ """
+        from edsl.coop.coop import Coop
+
+        coop = Coop()
+        self._output("Remote inference activated. Sending job to server...")
+        remote_job_creation_data = coop.remote_inference_create(
+            self,
+            description=remote_inference_description,
+            status="queued",
+            iterations=iterations,
+        )
+        job_uuid = remote_job_creation_data.get("uuid")
+        print(f"Job sent to server. (Job uuid={job_uuid}).")
+        return remote_job_creation_data
+
+    @staticmethod
+    def check_status(job_uuid):
+        from edsl.coop.coop import Coop
+
+        coop = Coop()
+        return coop.remote_inference_get(job_uuid)
+
+    def poll_remote_inference_job(
+        self, remote_job_creation_data: dict
+    ) -> Union[Results, None]:
+        from edsl.coop.coop import Coop
+        import time
+        from datetime import datetime
+        from edsl.config import CONFIG
+
+        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
+
+        job_uuid = remote_job_creation_data.get("uuid")
+
+        coop = Coop()
+        job_in_queue = True
+        while job_in_queue:
+            remote_job_data = coop.remote_inference_get(job_uuid)
+            status = remote_job_data.get("status")
+            if status == "cancelled":
+                print("\r" + " " * 80 + "\r", end="")
+                print("Job cancelled by the user.")
+                print(
+                    f"See {expected_parrot_url}/home/remote-inference for more details."
+                )
+                return None
+            elif status == "failed":
+                print("\r" + " " * 80 + "\r", end="")
+                print("Job failed.")
+                print(
+                    f"See {expected_parrot_url}/home/remote-inference for more details."
+                )
+                return None
+            elif status == "completed":
+                results_uuid = remote_job_data.get("results_uuid")
+                results = coop.get(results_uuid, expected_object_type="results")
+                print("\r" + " " * 80 + "\r", end="")
+                url = f"{expected_parrot_url}/content/{results_uuid}"
+                print(f"Job completed and Results stored on Coop: {url}.")
+                return results
+            else:
+                duration = 5
+                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
+                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
+                start_time = time.time()
+                i = 0
+                while time.time() - start_time < duration:
+                    print(
+                        f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+                        end="",
+                        flush=True,
+                    )
+                    time.sleep(0.1)
+                    i += 1
+
+    def use_remote_inference(self, disable_remote_inference: bool):
+        if disable_remote_inference:
+            return False
+        if not disable_remote_inference:
+            try:
+                from edsl import Coop
+
+                user_edsl_settings = Coop().edsl_settings
+                return user_edsl_settings.get("remote_inference", False)
+            except requests.ConnectionError:
+                pass
+            except CoopServerResponseError as e:
+                pass
+
+        return False
+
+    def use_remote_cache(self):
+        try:
+            from edsl import Coop
+
+            user_edsl_settings = Coop().edsl_settings
+            return user_edsl_settings.get("remote_caching", False)
+        except requests.ConnectionError:
+            pass
+        except CoopServerResponseError as e:
+            pass
+
+        return False
+
+    def check_api_keys(self):
+        from edsl import Model
+
+        for model in self.models + [Model()]:
+            if not model.has_valid_api_key():
+                raise MissingAPIKeyError(
+                    model_name=str(model.model),
+                    inference_service=model._inference_service_,
+                )
+
     def run(
         self,
         n: int = 1,
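These helpers can also be driven directly, outside of `run()`. A sketch, assuming `Jobs.example()` and an account with remote inference enabled on Coop:

    # Sketch: using the new remote-inference helpers directly.
    from edsl.jobs.Jobs import Jobs

    job = Jobs.example()
    if job.use_remote_inference(disable_remote_inference=False):
        creation_data = job.create_remote_inference_job(
            iterations=1,
            remote_inference_description="example run",
        )
        # One-off, non-blocking status lookup by job uuid:
        status_data = Jobs.check_status(creation_data["uuid"])
        # Or block (with a spinner) until the job is completed, failed, or cancelled:
        results = job.poll_remote_inference_job(creation_data)

    # Passing disable_remote_inference=True to run() skips this path and runs locally.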
@@ -596,91 +852,17 @@ class Jobs(Base):

         self.verbose = verbose

-        remote_cache = False
-        remote_inference = False
-
-        if not disable_remote_inference:
-            try:
-                coop = Coop()
-                user_edsl_settings = Coop().edsl_settings
-                remote_cache = user_edsl_settings.get("remote_caching", False)
-                remote_inference = user_edsl_settings.get("remote_inference", False)
-            except Exception:
-                pass
-
-        if remote_inference:
-            import time
-            from datetime import datetime
-            from edsl.config import CONFIG
-
-            expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
-
-            self._output("Remote inference activated. Sending job to server...")
-            if remote_cache:
-                self._output(
-                    "Remote caching activated. The remote cache will be used for this job."
-                )
-
-            remote_job_creation_data = coop.remote_inference_create(
-                self,
-                description=remote_inference_description,
-                status="queued",
-                iterations=n,
+        if remote_inference := self.use_remote_inference(disable_remote_inference):
+            remote_job_creation_data = self.create_remote_inference_job(
+                iterations=n, remote_inference_description=remote_inference_description
             )
-            time_queued = datetime.now().strftime("%m/%d/%Y %I:%M:%S %p")
-            job_uuid = remote_job_creation_data.get("uuid")
-            print(f"Remote inference started (Job uuid={job_uuid}).")
-            # print(f"Job queued at {time_queued}.")
-            job_in_queue = True
-            while job_in_queue:
-                remote_job_data = coop.remote_inference_get(job_uuid)
-                status = remote_job_data.get("status")
-                if status == "cancelled":
-                    print("\r" + " " * 80 + "\r", end="")
-                    print("Job cancelled by the user.")
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                    return None
-                elif status == "failed":
-                    print("\r" + " " * 80 + "\r", end="")
-                    print("Job failed.")
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                    return None
-                elif status == "completed":
-                    results_uuid = remote_job_data.get("results_uuid")
-                    results = coop.get(results_uuid, expected_object_type="results")
-                    print("\r" + " " * 80 + "\r", end="")
-                    print(
-                        f"Job completed and Results stored on Coop (Results uuid={results_uuid})."
-                    )
-                    return results
-                else:
-                    duration = 5
-                    time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
-                    frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
-                    start_time = time.time()
-                    i = 0
-                    while time.time() - start_time < duration:
-                        print(
-                            f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
-                            end="",
-                            flush=True,
-                        )
-                        time.sleep(0.1)
-                        i += 1
-        else:
-            if check_api_keys:
-                from edsl import Model
+            results = self.poll_remote_inference_job(remote_job_creation_data)
+            if results is None:
+                self._output("Job failed.")
+            return results

-                for model in self.models + [Model()]:
-                    if not model.has_valid_api_key():
-                        raise MissingAPIKeyError(
-                            model_name=str(model.model),
-                            inference_service=model._inference_service_,
-                        )
+        if check_api_keys:
+            self.check_api_keys()

         # handle cache
         if cache is None or cache is True:
@@ -692,51 +874,14 @@

             cache = Cache()

-        if not remote_cache:
-            results = self._run_local(
-                n=n,
-                progress_bar=progress_bar,
-                cache=cache,
-                stop_on_exception=stop_on_exception,
-                sidecar_model=sidecar_model,
-                print_exceptions=print_exceptions,
-                raise_validation_errors=raise_validation_errors,
-            )
-
-            results.cache = cache.new_entries_cache()
-
-            self._output(f"There are {len(cache.keys()):,} entries in the local cache.")
-        else:
-            cache_difference = coop.remote_cache_get_diff(cache.keys())
-
-            client_missing_cacheentries = cache_difference.get(
-                "client_missing_cacheentries", []
-            )
-
-            missing_entry_count = len(client_missing_cacheentries)
-            if missing_entry_count > 0:
-                self._output(
-                    f"Updating local cache with {missing_entry_count:,} new "
-                    f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
-                )
-                cache.add_from_dict(
-                    {entry.key: entry for entry in client_missing_cacheentries}
-                )
-                self._output("Local cache updated!")
-            else:
-                self._output("No new entries to add to local cache.")
-
-            server_missing_cacheentry_keys = cache_difference.get(
-                "server_missing_cacheentry_keys", []
-            )
-            server_missing_cacheentries = [
-                entry
-                for key in server_missing_cacheentry_keys
-                if (entry := cache.data.get(key)) is not None
-            ]
-            old_entry_keys = [key for key in cache.keys()]
-
-            self._output("Running job...")
+        remote_cache = self.use_remote_cache()
+        with RemoteCacheSync(
+            coop=Coop(),
+            cache=cache,
+            output_func=self._output,
+            remote_cache=remote_cache,
+            remote_cache_description=remote_cache_description,
+        ) as r:
             results = self._run_local(
                 n=n,
                 progress_bar=progress_bar,
@@ -746,32 +891,8 @@
                 print_exceptions=print_exceptions,
                 raise_validation_errors=raise_validation_errors,
             )
-            self._output("Job completed!")
-
-            new_cache_entries = list(
-                [entry for entry in cache.values() if entry.key not in old_entry_keys]
-            )
-            server_missing_cacheentries.extend(new_cache_entries)
-
-            new_entry_count = len(server_missing_cacheentries)
-            if new_entry_count > 0:
-                self._output(
-                    f"Updating remote cache with {new_entry_count:,} new "
-                    f"{'entry' if new_entry_count == 1 else 'entries'}..."
-                )
-                coop.remote_cache_create_many(
-                    server_missing_cacheentries,
-                    visibility="private",
-                    description=remote_cache_description,
-                )
-                self._output("Remote cache updated!")
-            else:
-                self._output("No new entries to add to remote cache.")
-
-            results.cache = cache.new_entries_cache()
-
-            self._output(f"There are {len(cache.keys()):,} entries in the local cache.")

+        results.cache = cache.new_entries_cache()
         return results

     def _run_local(self, *args, **kwargs):
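The `with RemoteCacheSync(...)` block above replaces the inline remote-cache push/pull logic; the class itself lives in the new edsl/data/RemoteCacheSync.py (+97 lines) and is not shown in this diff. Once a job has finished, the realized (as opposed to estimated) spend can be read back with the new `compute_job_cost`; a short sketch, assuming `Jobs.example()` and a locally runnable model configuration:

    # Sketch: estimated cost before the run vs. actual cost after the run.
    from edsl.jobs.Jobs import Jobs

    job = Jobs.example()
    print(job.estimate_job_cost())                  # estimate, using prices fetched from Coop
    results = job.run(disable_remote_inference=True)
    print(Jobs.compute_job_cost(results))           # USD actually spent, excluding cached answers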