edsl 0.1.36__py3-none-any.whl → 0.1.36.dev2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
edsl/jobs/Jobs.py CHANGED
@@ -1,10 +1,8 @@
 # """The Jobs class is a collection of agents, scenarios and models and one survey."""
 from __future__ import annotations
 import warnings
-import requests
 from itertools import product
 from typing import Optional, Union, Sequence, Generator
-
 from edsl.Base import Base
 from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
@@ -12,9 +10,6 @@ from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
 from edsl.utilities.decorators import add_edsl_version, remove_edsl_version

-from edsl.data.RemoteCacheSync import RemoteCacheSync
-from edsl.exceptions.coop import CoopServerResponseError
-

 class Jobs(Base):
     """
@@ -208,15 +203,14 @@ class Jobs(Base):
             ]
         )
         return d
+        # if table:
+        #     d.to_scenario_list().print(format="rich")
+        # else:
+        #     return d

-    def show_prompts(self, all=False) -> None:
+    def show_prompts(self) -> None:
         """Print the prompts."""
-        if all:
-            self.prompts().to_scenario_list().print(format="rich")
-        else:
-            self.prompts().select(
-                "user_prompt", "system_prompt"
-            ).to_scenario_list().print(format="rich")
+        self.prompts().to_scenario_list().print(format="rich")

     @staticmethod
     def estimate_prompt_cost(
@@ -225,11 +219,11 @@ class Jobs(Base):
         price_lookup: dict,
         inference_service: str,
         model: str,
-    ) -> dict:
+    ):
         """Estimates the cost of a prompt. Takes piping into account."""

         def get_piping_multiplier(prompt: str):
-            """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
+            """Returns 2 if a prompt includes Jinja brances, and 1 otherwise."""

             if "{{" in prompt and "}}" in prompt:
                 return 2
@@ -237,25 +231,9 @@ class Jobs(Base):

         # Look up prices per token
         key = (inference_service, model)
-
-        try:
-            relevant_prices = price_lookup[key]
-            output_price_per_token = 1 / float(
-                relevant_prices["output"]["one_usd_buys"]
-            )
-            input_price_per_token = 1 / float(relevant_prices["input"]["one_usd_buys"])
-        except KeyError:
-            # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
-            # Use a sensible default
-
-            import warnings
-
-            warnings.warn(
-                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
-            )
-
-            output_price_per_token = 0.00000015  # $0.15 / 1M tokens
-            input_price_per_token = 0.00000060  # $0.60 / 1M tokens
+        relevant_prices = price_lookup[key]
+        output_price_per_token = 1 / float(relevant_prices["output"]["one_usd_buys"])
+        input_price_per_token = 1 / float(relevant_prices["input"]["one_usd_buys"])

         # Compute the number of characters (double if the question involves piping)
         user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
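
Taken together, the two hunks above show the estimation logic both versions share: prompt characters are counted (doubled when Jinja piping braces are present), converted to tokens, and priced with per-token rates derived from the `one_usd_buys` fields. A minimal self-contained sketch of that math; the 4-chars-per-token ratio, the 1:1 output-token heuristic, and the return shape are illustrative assumptions, not values taken from this diff:

# Hedged sketch of the shared cost-estimation math; token ratios below
# are assumptions for illustration, not values from the diff.
def sketch_estimate_prompt_cost(
    system_prompt: str,
    user_prompt: str,
    input_price_per_token: float,
    output_price_per_token: float,
) -> dict:
    def get_piping_multiplier(prompt: str) -> int:
        # Prompts containing Jinja braces are counted twice, as in the diff.
        return 2 if "{{" in prompt and "}}" in prompt else 1

    chars = len(str(user_prompt)) * get_piping_multiplier(str(user_prompt))
    chars += len(str(system_prompt)) * get_piping_multiplier(str(system_prompt))
    input_tokens = chars // 4  # assumed ~4 characters per token
    output_tokens = input_tokens  # assumed output roughly equal to input
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": input_tokens * input_price_per_token
        + output_tokens * output_price_per_token,
    }
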
@@ -280,7 +258,7 @@ class Jobs(Base):
             "cost": cost,
         }

-    def estimate_job_cost_from_external_prices(self, price_lookup: dict) -> dict:
+    def estimate_job_cost_from_external_prices(self, price_lookup: dict):
         """
         Estimates the cost of a job according to the following assumptions:

@@ -363,7 +341,7 @@ class Jobs(Base):

         return output

-    def estimate_job_cost(self) -> dict:
+    def estimate_job_cost(self):
         """
         Estimates the cost of a job according to the following assumptions:

@@ -379,25 +357,6 @@ class Jobs(Base):

         return self.estimate_job_cost_from_external_prices(price_lookup=price_lookup)

-    @staticmethod
-    def compute_job_cost(job_results: "Results") -> float:
-        """
-        Computes the cost of a completed job in USD.
-        """
-        total_cost = 0
-        for result in job_results:
-            for key in result.raw_model_response:
-                if key.endswith("_cost"):
-                    result_cost = result.raw_model_response[key]
-
-                    question_name = key.removesuffix("_cost")
-                    cache_used = result.cache_used_dict[question_name]
-
-                    if isinstance(result_cost, (int, float)) and not cache_used:
-                        total_cost += result_cost
-
-        return total_cost
-
     @staticmethod
     def _get_container_class(object):
         from edsl.agents.AgentList import AgentList
@@ -621,7 +580,7 @@ class Jobs(Base):

     def _output(self, message) -> None:
         """Check if a Job is verbose. If so, print the message."""
-        if hasattr(self, "verbose") and self.verbose:
+        if self.verbose:
             print(message)

     def _check_parameters(self, strict=False, warn=False) -> None:
@@ -698,123 +657,6 @@ class Jobs(Base):
             return False
         return self._raise_validation_errors

-    def create_remote_inference_job(
-        self, iterations: int = 1, remote_inference_description: Optional[str] = None
-    ):
-        """ """
-        from edsl.coop.coop import Coop
-
-        coop = Coop()
-        self._output("Remote inference activated. Sending job to server...")
-        remote_job_creation_data = coop.remote_inference_create(
-            self,
-            description=remote_inference_description,
-            status="queued",
-            iterations=iterations,
-        )
-        job_uuid = remote_job_creation_data.get("uuid")
-        print(f"Job sent to server. (Job uuid={job_uuid}).")
-        return remote_job_creation_data
-
-    @staticmethod
-    def check_status(job_uuid):
-        from edsl.coop.coop import Coop
-
-        coop = Coop()
-        return coop.remote_inference_get(job_uuid)
-
-    def poll_remote_inference_job(
-        self, remote_job_creation_data: dict
-    ) -> Union[Results, None]:
-        from edsl.coop.coop import Coop
-        import time
-        from datetime import datetime
-        from edsl.config import CONFIG
-
-        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
-
-        job_uuid = remote_job_creation_data.get("uuid")
-
-        coop = Coop()
-        job_in_queue = True
-        while job_in_queue:
-            remote_job_data = coop.remote_inference_get(job_uuid)
-            status = remote_job_data.get("status")
-            if status == "cancelled":
-                print("\r" + " " * 80 + "\r", end="")
-                print("Job cancelled by the user.")
-                print(
-                    f"See {expected_parrot_url}/home/remote-inference for more details."
-                )
-                return None
-            elif status == "failed":
-                print("\r" + " " * 80 + "\r", end="")
-                print("Job failed.")
-                print(
-                    f"See {expected_parrot_url}/home/remote-inference for more details."
-                )
-                return None
-            elif status == "completed":
-                results_uuid = remote_job_data.get("results_uuid")
-                results = coop.get(results_uuid, expected_object_type="results")
-                print("\r" + " " * 80 + "\r", end="")
-                url = f"{expected_parrot_url}/content/{results_uuid}"
-                print(f"Job completed and Results stored on Coop: {url}.")
-                return results
-            else:
-                duration = 5
-                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
-                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
-                start_time = time.time()
-                i = 0
-                while time.time() - start_time < duration:
-                    print(
-                        f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
-                        end="",
-                        flush=True,
-                    )
-                    time.sleep(0.1)
-                    i += 1
-
-    def use_remote_inference(self, disable_remote_inference: bool):
-        if disable_remote_inference:
-            return False
-        if not disable_remote_inference:
-            try:
-                from edsl import Coop
-
-                user_edsl_settings = Coop().edsl_settings
-                return user_edsl_settings.get("remote_inference", False)
-            except requests.ConnectionError:
-                pass
-            except CoopServerResponseError as e:
-                pass
-
-        return False
-
-    def use_remote_cache(self):
-        try:
-            from edsl import Coop
-
-            user_edsl_settings = Coop().edsl_settings
-            return user_edsl_settings.get("remote_caching", False)
-        except requests.ConnectionError:
-            pass
-        except CoopServerResponseError as e:
-            pass
-
-        return False
-
-    def check_api_keys(self):
-        from edsl import Model
-
-        for model in self.models + [Model()]:
-            if not model.has_valid_api_key():
-                raise MissingAPIKeyError(
-                    model_name=str(model.model),
-                    inference_service=model._inference_service_,
-                )
-
     def run(
         self,
         n: int = 1,
@@ -852,17 +694,91 @@ class Jobs(Base):

         self.verbose = verbose

-        if remote_inference := self.use_remote_inference(disable_remote_inference):
-            remote_job_creation_data = self.create_remote_inference_job(
-                iterations=n, remote_inference_description=remote_inference_description
+        remote_cache = False
+        remote_inference = False
+
+        if not disable_remote_inference:
+            try:
+                coop = Coop()
+                user_edsl_settings = Coop().edsl_settings
+                remote_cache = user_edsl_settings.get("remote_caching", False)
+                remote_inference = user_edsl_settings.get("remote_inference", False)
+            except Exception:
+                pass
+
+        if remote_inference:
+            import time
+            from datetime import datetime
+            from edsl.config import CONFIG
+
+            expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
+
+            self._output("Remote inference activated. Sending job to server...")
+            if remote_cache:
+                self._output(
+                    "Remote caching activated. The remote cache will be used for this job."
+                )
+
+            remote_job_creation_data = coop.remote_inference_create(
+                self,
+                description=remote_inference_description,
+                status="queued",
+                iterations=n,
             )
-            results = self.poll_remote_inference_job(remote_job_creation_data)
-            if results is None:
-                self._output("Job failed.")
-            return results
+            time_queued = datetime.now().strftime("%m/%d/%Y %I:%M:%S %p")
+            job_uuid = remote_job_creation_data.get("uuid")
+            print(f"Remote inference started (Job uuid={job_uuid}).")
+            # print(f"Job queued at {time_queued}.")
+            job_in_queue = True
+            while job_in_queue:
+                remote_job_data = coop.remote_inference_get(job_uuid)
+                status = remote_job_data.get("status")
+                if status == "cancelled":
+                    print("\r" + " " * 80 + "\r", end="")
+                    print("Job cancelled by the user.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
+                    )
+                    return None
+                elif status == "failed":
+                    print("\r" + " " * 80 + "\r", end="")
+                    print("Job failed.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
+                    )
+                    return None
+                elif status == "completed":
+                    results_uuid = remote_job_data.get("results_uuid")
+                    results = coop.get(results_uuid, expected_object_type="results")
+                    print("\r" + " " * 80 + "\r", end="")
+                    print(
+                        f"Job completed and Results stored on Coop (Results uuid={results_uuid})."
+                    )
+                    return results
+                else:
+                    duration = 5
+                    time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
+                    frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
+                    start_time = time.time()
+                    i = 0
+                    while time.time() - start_time < duration:
+                        print(
+                            f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+                            end="",
+                            flush=True,
+                        )
+                        time.sleep(0.1)
+                        i += 1
+        else:
+            if check_api_keys:
+                from edsl import Model

-        if check_api_keys:
-            self.check_api_keys()
+                for model in self.models + [Model()]:
+                    if not model.has_valid_api_key():
+                        raise MissingAPIKeyError(
+                            model_name=str(model.model),
+                            inference_service=model._inference_service_,
+                        )

         # handle cache
         if cache is None or cache is True:
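
The hunk above inlines the polling loop that the deleted `poll_remote_inference_job` used to provide: query the job status every few seconds, animate a spinner between polls, and stop on a terminal status. A generic sketch of that pattern, with `fetch_status` standing in for `coop.remote_inference_get` and the rest illustrative:

import itertools
import time

def poll_until_terminal(fetch_status, job_uuid: str, interval: float = 5.0) -> dict:
    """Poll fetch_status(job_uuid) until it reports a terminal status."""
    frames = itertools.cycle("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏")
    while True:
        data = fetch_status(job_uuid)
        status = data.get("status")
        if status in ("cancelled", "failed", "completed"):
            print("\r" + " " * 80 + "\r", end="")  # clear the spinner line
            return data
        deadline = time.time() + interval
        while time.time() < deadline:
            print(f"\r{next(frames)} Job status: {status}", end="", flush=True)
            time.sleep(0.1)
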
@@ -874,14 +790,51 @@
 
             cache = Cache()

-        remote_cache = self.use_remote_cache()
-        with RemoteCacheSync(
-            coop=Coop(),
-            cache=cache,
-            output_func=self._output,
-            remote_cache=remote_cache,
-            remote_cache_description=remote_cache_description,
-        ) as r:
+        if not remote_cache:
+            results = self._run_local(
+                n=n,
+                progress_bar=progress_bar,
+                cache=cache,
+                stop_on_exception=stop_on_exception,
+                sidecar_model=sidecar_model,
+                print_exceptions=print_exceptions,
+                raise_validation_errors=raise_validation_errors,
+            )
+
+            results.cache = cache.new_entries_cache()
+
+            self._output(f"There are {len(cache.keys()):,} entries in the local cache.")
+        else:
+            cache_difference = coop.remote_cache_get_diff(cache.keys())
+
+            client_missing_cacheentries = cache_difference.get(
+                "client_missing_cacheentries", []
+            )
+
+            missing_entry_count = len(client_missing_cacheentries)
+            if missing_entry_count > 0:
+                self._output(
+                    f"Updating local cache with {missing_entry_count:,} new "
+                    f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
+                )
+                cache.add_from_dict(
+                    {entry.key: entry for entry in client_missing_cacheentries}
+                )
+                self._output("Local cache updated!")
+            else:
+                self._output("No new entries to add to local cache.")
+
+            server_missing_cacheentry_keys = cache_difference.get(
+                "server_missing_cacheentry_keys", []
+            )
+            server_missing_cacheentries = [
+                entry
+                for key in server_missing_cacheentry_keys
+                if (entry := cache.data.get(key)) is not None
+            ]
+            old_entry_keys = [key for key in cache.keys()]
+
+            self._output("Running job...")
             results = self._run_local(
                 n=n,
                 progress_bar=progress_bar,
@@ -891,8 +844,32 @@
                 print_exceptions=print_exceptions,
                 raise_validation_errors=raise_validation_errors,
             )
+            self._output("Job completed!")
+
+            new_cache_entries = list(
+                [entry for entry in cache.values() if entry.key not in old_entry_keys]
+            )
+            server_missing_cacheentries.extend(new_cache_entries)
+
+            new_entry_count = len(server_missing_cacheentries)
+            if new_entry_count > 0:
+                self._output(
+                    f"Updating remote cache with {new_entry_count:,} new "
+                    f"{'entry' if new_entry_count == 1 else 'entries'}..."
+                )
+                coop.remote_cache_create_many(
+                    server_missing_cacheentries,
+                    visibility="private",
+                    description=remote_cache_description,
+                )
+                self._output("Remote cache updated!")
+            else:
+                self._output("No new entries to add to remote cache.")
+
+            results.cache = cache.new_entries_cache()
+
+            self._output(f"There are {len(cache.keys()):,} entries in the local cache.")

-        results.cache = cache.new_entries_cache()
         return results

     def _run_local(self, *args, **kwargs):
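
With `RemoteCacheSync` gone, the `else:` branch above performs the synchronization inline: diff local keys against the server, pull entries the client is missing, run the job, then push entries the server was missing plus anything the run created. A compact sketch of that flow; the `coop` method names mirror the calls in the diff, everything else is illustrative:

def sketch_sync_and_run(coop, cache, run_job, description=None):
    diff = coop.remote_cache_get_diff(cache.keys())

    # Pull: add entries the client is missing to the local cache.
    pulled = diff.get("client_missing_cacheentries", [])
    cache.add_from_dict({entry.key: entry for entry in pulled})

    # Remember pre-run state so entries created by the run can be found later.
    to_push = [
        entry
        for key in diff.get("server_missing_cacheentry_keys", [])
        if (entry := cache.data.get(key)) is not None
    ]
    old_keys = set(cache.keys())

    results = run_job(cache)

    # Push: pre-existing gaps plus entries created by this run.
    to_push.extend(e for e in cache.values() if e.key not in old_keys)
    if to_push:
        coop.remote_cache_create_many(
            to_push, visibility="private", description=description
        )
    return results
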
edsl/jobs/interviews/Interview.py CHANGED
@@ -110,9 +110,9 @@ class Interview:
         self.debug = debug
         self.iteration = iteration
         self.cache = cache
-        self.answers: dict[
-            str, str
-        ] = Answers()  # will get filled in as interview progresses
+        self.answers: dict[str, str] = (
+            Answers()
+        )  # will get filled in as interview progresses
         self.sidecar_model = sidecar_model

         # Trackers
@@ -143,9 +143,9 @@ class Interview:
         The keys are the question names; the values are the lists of status log changes for each task.
         """
         for task_creator in self.task_creators.values():
-            self._task_status_log_dict[
-                task_creator.question.question_name
-            ] = task_creator.status_log
+            self._task_status_log_dict[task_creator.question.question_name] = (
+                task_creator.status_log
+            )
         return self._task_status_log_dict

     @property
@@ -159,13 +159,13 @@ class Interview:
         return self.task_creators.interview_status

     # region: Serialization
-    def _to_dict(self, include_exceptions=True) -> dict[str, Any]:
+    def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
         """Return a dictionary representation of the Interview instance.
         This is just for hashing purposes.

         >>> i = Interview.example()
         >>> hash(i)
-        1217840301076717434
+        1646262796627658719
         """
         d = {
             "agent": self.agent._to_dict(),
@@ -177,39 +177,11 @@ class Interview:
         }
         if include_exceptions:
             d["exceptions"] = self.exceptions.to_dict()
-        return d
-
-    @classmethod
-    def from_dict(cls, d: dict[str, Any]) -> "Interview":
-        """Return an Interview instance from a dictionary."""
-        agent = Agent.from_dict(d["agent"])
-        survey = Survey.from_dict(d["survey"])
-        scenario = Scenario.from_dict(d["scenario"])
-        model = LanguageModel.from_dict(d["model"])
-        iteration = d["iteration"]
-        interview = cls(
-            agent=agent,
-            survey=survey,
-            scenario=scenario,
-            model=model,
-            iteration=iteration,
-        )
-        if "exceptions" in d:
-            exceptions = InterviewExceptionCollection.from_dict(d["exceptions"])
-            interview.exceptions = exceptions
-        return interview

     def __hash__(self) -> int:
         from edsl.utilities.utilities import dict_hash

-        return dict_hash(self._to_dict(include_exceptions=False))
-
-    def __eq__(self, other: "Interview") -> bool:
-        """
-        >>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i._to_dict(); i2 = Interview.from_dict(d); i == i2
-        True
-        """
-        return hash(self) == hash(other)
+        return dict_hash(self._to_dict())

     # endregion

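The removed `__eq__` defines Interview equality as equality of content hashes: two instances compare equal when their serialized dictionaries hash to the same value. A toy sketch of the pattern, with a stand-in for `dict_hash` (whose real implementation is not shown in this diff):

import hashlib
import json

def dict_hash(d: dict) -> int:
    # Stand-in: any stable hash of a canonical serialization works here.
    digest = hashlib.sha256(json.dumps(d, sort_keys=True).encode()).digest()
    return int.from_bytes(digest[:8], "big")

class ContentHashed:
    def _to_dict(self) -> dict:
        raise NotImplementedError

    def __hash__(self) -> int:
        return dict_hash(self._to_dict())

    def __eq__(self, other) -> bool:
        # Equality by content hash, as in the removed Interview.__eq__.
        return hash(self) == hash(other)
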
@@ -486,11 +458,11 @@ class Interview:
         """
         current_question_index: int = self.to_index[current_question.question_name]

-        next_question: Union[
-            int, EndOfSurvey
-        ] = self.survey.rule_collection.next_question(
-            q_now=current_question_index,
-            answers=self.answers | self.scenario | self.agent["traits"],
+        next_question: Union[int, EndOfSurvey] = (
+            self.survey.rule_collection.next_question(
+                q_now=current_question_index,
+                answers=self.answers | self.scenario | self.agent["traits"],
+            )
         )

         next_question_index = next_question.next_q
edsl/jobs/interviews/InterviewExceptionCollection.py CHANGED
@@ -34,15 +34,6 @@ class InterviewExceptionCollection(UserDict):
         newdata = {k: [e.to_dict() for e in v] for k, v in self.data.items()}
         return newdata

-    @classmethod
-    def from_dict(cls, data: dict) -> "InterviewExceptionCollection":
-        """Create an InterviewExceptionCollection from a dictionary."""
-        collection = cls()
-        for question_name, entries in data.items():
-            for entry in entries:
-                collection.add(question_name, InterviewExceptionEntry.from_dict(entry))
-        return collection
-
     def _repr_html_(self) -> str:
         from edsl.utilities.utilities import data_to_html

edsl/jobs/interviews/InterviewExceptionEntry.py CHANGED
@@ -9,6 +9,7 @@ class InterviewExceptionEntry:
         self,
         *,
         exception: Exception,
+        # failed_question: FailedQuestion,
         invigilator: "Invigilator",
         traceback_format="text",
         answers=None,
@@ -133,48 +134,22 @@ class InterviewExceptionEntry:
         console.print(tb)
         return html_output.getvalue()

-    @staticmethod
-    def serialize_exception(exception: Exception) -> dict:
-        return {
-            "type": type(exception).__name__,
-            "message": str(exception),
-            "traceback": "".join(
-                traceback.format_exception(
-                    type(exception), exception, exception.__traceback__
-                )
-            ),
-        }
-
-    @staticmethod
-    def deserialize_exception(data: dict) -> Exception:
-        try:
-            exception_class = globals()[data["type"]]
-        except KeyError:
-            exception_class = Exception
-        return exception_class(data["message"])
-
     def to_dict(self) -> dict:
         """Return the exception as a dictionary.

         >>> entry = InterviewExceptionEntry.example()
-        >>> _ = entry.to_dict()
+        >>> entry.to_dict()['exception']
+        ValueError()
+
         """
         return {
-            "exception": self.serialize_exception(self.exception),
+            "exception": self.exception,
             "time": self.time,
             "traceback": self.traceback,
+            # "failed_question": self.failed_question.to_dict(),
             "invigilator": self.invigilator.to_dict(),
         }

-    @classmethod
-    def from_dict(cls, data: dict) -> "InterviewExceptionEntry":
-        """Create an InterviewExceptionEntry from a dictionary."""
-        from edsl.agents.Invigilator import InvigilatorAI
-
-        exception = cls.deserialize_exception(data["exception"])
-        invigilator = InvigilatorAI.from_dict(data["invigilator"])
-        return cls(exception=exception, invigilator=invigilator)
-
     def push(self):
         from edsl import Coop
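
The deleted helpers implement a common exception round-trip: serialize the type name, message, and formatted traceback, then rebuild a best-effort instance on load, falling back to plain `Exception` when the original type is not in scope. A standalone sketch of the same pattern, taken almost verbatim from the removed code:

import traceback

def serialize_exception(exception: Exception) -> dict:
    return {
        "type": type(exception).__name__,
        "message": str(exception),
        "traceback": "".join(
            traceback.format_exception(
                type(exception), exception, exception.__traceback__
            )
        ),
    }

def deserialize_exception(data: dict) -> Exception:
    # Fall back to Exception when the recorded type isn't resolvable here
    # (like the removed code, this only rebuilds the message, not the state).
    exception_class = globals().get(data["type"], Exception)
    return exception_class(data["message"])
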