edsl 0.1.36.dev1__py3-none-any.whl → 0.1.36.dev5__py3-none-any.whl
- edsl/Base.py +5 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +5 -1
- edsl/agents/PromptConstructor.py +4 -3
- edsl/coop/PriceFetcher.py +14 -18
- edsl/coop/coop.py +42 -8
- edsl/data/RemoteCacheSync.py +84 -0
- edsl/exceptions/coop.py +8 -0
- edsl/inference_services/InferenceServiceABC.py +28 -0
- edsl/inference_services/registry.py +24 -16
- edsl/jobs/Jobs.py +190 -167
- edsl/jobs/interviews/Interview.py +21 -3
- edsl/jobs/interviews/InterviewExceptionCollection.py +9 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +24 -6
- edsl/jobs/runners/JobsRunnerAsyncio.py +17 -23
- edsl/jobs/tasks/TaskHistory.py +23 -7
- edsl/questions/QuestionFunctional.py +7 -3
- edsl/results/Dataset.py +12 -0
- edsl/results/Result.py +11 -9
- edsl/results/Results.py +13 -1
- edsl/scenarios/Scenario.py +12 -1
- edsl/surveys/Survey.py +3 -0
- edsl/surveys/instructions/Instruction.py +20 -3
- {edsl-0.1.36.dev1.dist-info → edsl-0.1.36.dev5.dist-info}/METADATA +1 -1
- {edsl-0.1.36.dev1.dist-info → edsl-0.1.36.dev5.dist-info}/RECORD +27 -26
- {edsl-0.1.36.dev1.dist-info → edsl-0.1.36.dev5.dist-info}/LICENSE +0 -0
- {edsl-0.1.36.dev1.dist-info → edsl-0.1.36.dev5.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -1,8 +1,10 @@
 # """The Jobs class is a collection of agents, scenarios and models and one survey."""
 from __future__ import annotations
 import warnings
+import requests
 from itertools import product
 from typing import Optional, Union, Sequence, Generator
+
 from edsl.Base import Base
 from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
@@ -10,6 +12,9 @@ from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
 from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
 
+from edsl.data.RemoteCacheSync import RemoteCacheSync
+from edsl.exceptions.coop import CoopServerResponseError
+
 
 class Jobs(Base):
     """
@@ -203,14 +208,15 @@ class Jobs(Base):
             ]
         )
         return d
-        # if table:
-        #     d.to_scenario_list().print(format="rich")
-        # else:
-        #     return d
 
-    def show_prompts(self) -> None:
+    def show_prompts(self, all=False) -> None:
         """Print the prompts."""
-
+        if all:
+            self.prompts().to_scenario_list().print(format="rich")
+        else:
+            self.prompts().select(
+                "user_prompt", "system_prompt"
+            ).to_scenario_list().print(format="rich")
 
     @staticmethod
    def estimate_prompt_cost(
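The default view is now trimmed to the two prompt columns, with all=True restoring the full dataset. A quick usage sketch (Jobs.example() is an assumed example constructor, in the style edsl uses elsewhere):

    from edsl.jobs.Jobs import Jobs

    job = Jobs.example()        # assumed example constructor
    job.show_prompts()          # user_prompt and system_prompt columns only
    job.show_prompts(all=True)  # every column of the prompt dataset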
@@ -219,11 +225,11 @@ class Jobs(Base):
         price_lookup: dict,
         inference_service: str,
         model: str,
-    ):
+    ) -> dict:
         """Estimates the cost of a prompt. Takes piping into account."""
 
         def get_piping_multiplier(prompt: str):
-            """Returns 2 if a prompt includes Jinja
+            """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""
 
             if "{{" in prompt and "}}" in prompt:
                 return 2
@@ -231,9 +237,25 @@ class Jobs(Base):
 
         # Look up prices per token
         key = (inference_service, model)
-
-
-
+
+        try:
+            relevant_prices = price_lookup[key]
+            output_price_per_token = 1 / float(
+                relevant_prices["output"]["one_usd_buys"]
+            )
+            input_price_per_token = 1 / float(relevant_prices["input"]["one_usd_buys"])
+        except KeyError:
+            # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
+            # Use a sensible default
+
+            import warnings
+
+            warnings.warn(
+                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
+            )
+
+            output_price_per_token = 0.00000015  # $0.15 / 1M tokens
+            input_price_per_token = 0.00000060  # $0.60 / 1M tokens
 
         # Compute the number of characters (double if the question involves piping)
         user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
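The piping heuristic simply doubles the character count whenever a prompt contains Jinja braces, on the assumption that the rendered text can be roughly twice the template length. (Note that, as committed, the fallback constants assign the $0.15/1M figure to output_price_per_token and $0.60/1M to input_price_per_token, the reverse of the warning's wording.) A standalone sketch of the piping rule:

    def get_piping_multiplier(prompt: str) -> int:
        # Prompts that pipe a prior answer via {{ ... }} are counted double.
        return 2 if "{{" in prompt and "}}" in prompt else 1

    assert get_piping_multiplier("Summarize {{ q1.answer }} briefly.") == 2
    assert get_piping_multiplier("What is your name?") == 1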
@@ -258,7 +280,7 @@ class Jobs(Base):
             "cost": cost,
         }
 
-    def estimate_job_cost_from_external_prices(self, price_lookup: dict):
+    def estimate_job_cost_from_external_prices(self, price_lookup: dict) -> dict:
         """
         Estimates the cost of a job according to the following assumptions:
 
@@ -341,7 +363,7 @@ class Jobs(Base):
 
         return output
 
-    def estimate_job_cost(self):
+    def estimate_job_cost(self) -> dict:
         """
         Estimates the cost of a job according to the following assumptions:
 
@@ -357,6 +379,25 @@ class Jobs(Base):
 
         return self.estimate_job_cost_from_external_prices(price_lookup=price_lookup)
 
+    @staticmethod
+    def compute_job_cost(job_results: "Results") -> float:
+        """
+        Computes the cost of a completed job in USD.
+        """
+        total_cost = 0
+        for result in job_results:
+            for key in result.raw_model_response:
+                if key.endswith("_cost"):
+                    result_cost = result.raw_model_response[key]
+
+                    question_name = key.removesuffix("_cost")
+                    cache_used = result.cache_used_dict[question_name]
+
+                    if isinstance(result_cost, (int, float)) and not cache_used:
+                        total_cost += result_cost
+
+        return total_cost
+
     @staticmethod
     def _get_container_class(object):
         from edsl.agents.AgentList import AgentList
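Unlike the estimate_* methods, compute_job_cost sums the realized per-question costs out of a finished Results object, skipping any response that was served from the cache. A usage sketch (assumes a runnable job and API access; Jobs.example() is assumed):

    job = Jobs.example()                         # assumed example constructor
    results = job.run()
    actual_usd = Jobs.compute_job_cost(results)  # excludes cached responses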
@@ -580,7 +621,7 @@ class Jobs(Base):
 
     def _output(self, message) -> None:
         """Check if a Job is verbose. If so, print the message."""
-        if self.verbose:
+        if hasattr(self, "verbose") and self.verbose:
             print(message)
 
     def _check_parameters(self, strict=False, warn=False) -> None:
@@ -657,6 +698,123 @@ class Jobs(Base):
             return False
         return self._raise_validation_errors
 
+    def create_remote_inference_job(
+        self, iterations: int = 1, remote_inference_description: Optional[str] = None
+    ):
+        """ """
+        from edsl.coop.coop import Coop
+
+        coop = Coop()
+        self._output("Remote inference activated. Sending job to server...")
+        remote_job_creation_data = coop.remote_inference_create(
+            self,
+            description=remote_inference_description,
+            status="queued",
+            iterations=iterations,
+        )
+        job_uuid = remote_job_creation_data.get("uuid")
+        print(f"Job sent to server. (Job uuid={job_uuid}).")
+        return remote_job_creation_data
+
+    @staticmethod
+    def check_status(job_uuid):
+        from edsl.coop.coop import Coop
+
+        coop = Coop()
+        return coop.remote_inference_get(job_uuid)
+
+    def poll_remote_inference_job(
+        self, remote_job_creation_data: dict
+    ) -> Union[Results, None]:
+        from edsl.coop.coop import Coop
+        import time
+        from datetime import datetime
+        from edsl.config import CONFIG
+
+        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
+
+        job_uuid = remote_job_creation_data.get("uuid")
+
+        coop = Coop()
+        job_in_queue = True
+        while job_in_queue:
+            remote_job_data = coop.remote_inference_get(job_uuid)
+            status = remote_job_data.get("status")
+            if status == "cancelled":
+                print("\r" + " " * 80 + "\r", end="")
+                print("Job cancelled by the user.")
+                print(
+                    f"See {expected_parrot_url}/home/remote-inference for more details."
+                )
+                return None
+            elif status == "failed":
+                print("\r" + " " * 80 + "\r", end="")
+                print("Job failed.")
+                print(
+                    f"See {expected_parrot_url}/home/remote-inference for more details."
+                )
+                return None
+            elif status == "completed":
+                results_uuid = remote_job_data.get("results_uuid")
+                results = coop.get(results_uuid, expected_object_type="results")
+                print("\r" + " " * 80 + "\r", end="")
+                url = f"{expected_parrot_url}/content/{results_uuid}"
+                print(f"Job completed and Results stored on Coop: {url}.")
+                return results
+            else:
+                duration = 5
+                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
+                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
+                start_time = time.time()
+                i = 0
+                while time.time() - start_time < duration:
+                    print(
+                        f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+                        end="",
+                        flush=True,
+                    )
+                    time.sleep(0.1)
+                    i += 1
+
+    def use_remote_inference(self, disable_remote_inference: bool):
+        if disable_remote_inference:
+            return False
+        if not disable_remote_inference:
+            try:
+                from edsl import Coop
+
+                user_edsl_settings = Coop().edsl_settings
+                return user_edsl_settings.get("remote_inference", False)
+            except requests.ConnectionError:
+                pass
+            except CoopServerResponseError as e:
+                pass
+
+        return False
+
+    def use_remote_cache(self):
+        try:
+            from edsl import Coop
+
+            user_edsl_settings = Coop().edsl_settings
+            return user_edsl_settings.get("remote_caching", False)
+        except requests.ConnectionError:
+            pass
+        except CoopServerResponseError as e:
+            pass
+
+        return False
+
+    def check_api_keys(self):
+        from edsl import Model
+
+        for model in self.models + [Model()]:
+            if not model.has_valid_api_key():
+                raise MissingAPIKeyError(
+                    model_name=str(model.model),
+                    inference_service=model._inference_service_,
+                )
+
     def run(
         self,
         n: int = 1,
@@ -694,91 +852,17 @@ class Jobs(Base):
 
         self.verbose = verbose
 
-
-
-
-        if not disable_remote_inference:
-            try:
-                coop = Coop()
-                user_edsl_settings = Coop().edsl_settings
-                remote_cache = user_edsl_settings.get("remote_caching", False)
-                remote_inference = user_edsl_settings.get("remote_inference", False)
-            except Exception:
-                pass
-
-        if remote_inference:
-            import time
-            from datetime import datetime
-            from edsl.config import CONFIG
-
-            expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
-
-            self._output("Remote inference activated. Sending job to server...")
-            if remote_cache:
-                self._output(
-                    "Remote caching activated. The remote cache will be used for this job."
-                )
-
-            remote_job_creation_data = coop.remote_inference_create(
-                self,
-                description=remote_inference_description,
-                status="queued",
-                iterations=n,
+        if remote_inference := self.use_remote_inference(disable_remote_inference):
+            remote_job_creation_data = self.create_remote_inference_job(
+                iterations=n, remote_inference_description=remote_inference_description
             )
-
-
-
-
-            job_in_queue = True
-            while job_in_queue:
-                remote_job_data = coop.remote_inference_get(job_uuid)
-                status = remote_job_data.get("status")
-                if status == "cancelled":
-                    print("\r" + " " * 80 + "\r", end="")
-                    print("Job cancelled by the user.")
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                    return None
-                elif status == "failed":
-                    print("\r" + " " * 80 + "\r", end="")
-                    print("Job failed.")
-                    print(
-                        f"See {expected_parrot_url}/home/remote-inference for more details."
-                    )
-                    return None
-                elif status == "completed":
-                    results_uuid = remote_job_data.get("results_uuid")
-                    results = coop.get(results_uuid, expected_object_type="results")
-                    print("\r" + " " * 80 + "\r", end="")
-                    print(
-                        f"Job completed and Results stored on Coop (Results uuid={results_uuid})."
-                    )
-                    return results
-                else:
-                    duration = 5
-                    time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
-                    frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
-                    start_time = time.time()
-                    i = 0
-                    while time.time() - start_time < duration:
-                        print(
-                            f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
-                            end="",
-                            flush=True,
-                        )
-                        time.sleep(0.1)
-                        i += 1
-        else:
-            if check_api_keys:
-                from edsl import Model
+            results = self.poll_remote_inference_job(remote_job_creation_data)
+            if results is None:
+                self._output("Job failed.")
+            return results
 
-
-
-                raise MissingAPIKeyError(
-                    model_name=str(model.model),
-                    inference_service=model._inference_service_,
-                )
+        if check_api_keys:
+            self.check_api_keys()
 
         # handle cache
         if cache is None or cache is True:
@@ -790,51 +874,14 @@ class Jobs(Base):
 
             cache = Cache()
 
-
-
-
-
-
-
-
-
-                raise_validation_errors=raise_validation_errors,
-            )
-
-            results.cache = cache.new_entries_cache()
-
-            self._output(f"There are {len(cache.keys()):,} entries in the local cache.")
-        else:
-            cache_difference = coop.remote_cache_get_diff(cache.keys())
-
-            client_missing_cacheentries = cache_difference.get(
-                "client_missing_cacheentries", []
-            )
-
-            missing_entry_count = len(client_missing_cacheentries)
-            if missing_entry_count > 0:
-                self._output(
-                    f"Updating local cache with {missing_entry_count:,} new "
-                    f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
-                )
-                cache.add_from_dict(
-                    {entry.key: entry for entry in client_missing_cacheentries}
-                )
-                self._output("Local cache updated!")
-            else:
-                self._output("No new entries to add to local cache.")
-
-            server_missing_cacheentry_keys = cache_difference.get(
-                "server_missing_cacheentry_keys", []
-            )
-            server_missing_cacheentries = [
-                entry
-                for key in server_missing_cacheentry_keys
-                if (entry := cache.data.get(key)) is not None
-            ]
-            old_entry_keys = [key for key in cache.keys()]
-
-            self._output("Running job...")
+        remote_cache = self.use_remote_cache()
+        with RemoteCacheSync(
+            coop=Coop(),
+            cache=cache,
+            output_func=self._output,
+            remote_cache=remote_cache,
+            remote_cache_description=remote_cache_description,
+        ) as r:
             results = self._run_local(
                 n=n,
                 progress_bar=progress_bar,
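The inline pull/push logic removed above now lives behind the RemoteCacheSync context manager (added in edsl/data/RemoteCacheSync.py, +84 lines, not shown in this diff). A hypothetical sketch of the pattern it encapsulates, inferred from the removed code rather than from the new file:

    from contextlib import contextmanager

    @contextmanager
    def remote_cache_sync(coop, cache, output_func, remote_cache=True):
        # Hypothetical stand-in for RemoteCacheSync; names and behavior inferred.
        if remote_cache:
            diff = coop.remote_cache_get_diff(cache.keys())
            cache.add_from_dict(
                {e.key: e for e in diff.get("client_missing_cacheentries", [])}
            )  # pull entries the client is missing
        old_keys = set(cache.keys())
        yield
        if remote_cache:
            new_entries = [e for e in cache.values() if e.key not in old_keys]
            if new_entries:
                coop.remote_cache_create_many(new_entries, visibility="private")
                output_func(f"Pushed {len(new_entries):,} new entries to the remote cache.")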
@@ -844,32 +891,8 @@ class Jobs(Base):
                 print_exceptions=print_exceptions,
                 raise_validation_errors=raise_validation_errors,
             )
-            self._output("Job completed!")
-
-            new_cache_entries = list(
-                [entry for entry in cache.values() if entry.key not in old_entry_keys]
-            )
-            server_missing_cacheentries.extend(new_cache_entries)
-
-            new_entry_count = len(server_missing_cacheentries)
-            if new_entry_count > 0:
-                self._output(
-                    f"Updating remote cache with {new_entry_count:,} new "
-                    f"{'entry' if new_entry_count == 1 else 'entries'}..."
-                )
-                coop.remote_cache_create_many(
-                    server_missing_cacheentries,
-                    visibility="private",
-                    description=remote_cache_description,
-                )
-                self._output("Remote cache updated!")
-            else:
-                self._output("No new entries to add to remote cache.")
-
-            results.cache = cache.new_entries_cache()
-
-            self._output(f"There are {len(cache.keys()):,} entries in the local cache.")
 
+        results.cache = cache.new_entries_cache()
         return results
 
     def _run_local(self, *args, **kwargs):
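With the helpers in place, a local run reads as a plain composition. A sketch (assumes Jobs.example() and a configured environment):

    from edsl.data.Cache import Cache

    job = Jobs.example()                      # assumed example constructor
    results = job.run(
        cache=Cache(),
        disable_remote_inference=True,        # force the local path
        remote_cache_description="demo run",  # only used if remote caching is on
    )
    results.cache                             # cache of entries created by this run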
edsl/jobs/interviews/Interview.py
CHANGED
@@ -159,13 +159,13 @@ class Interview:
         return self.task_creators.interview_status
 
     # region: Serialization
-    def _to_dict(self, include_exceptions=
+    def _to_dict(self, include_exceptions=True) -> dict[str, Any]:
        """Return a dictionary representation of the Interview instance.
         This is just for hashing purposes.
 
         >>> i = Interview.example()
         >>> hash(i)
-
+        1217840301076717434
         """
         d = {
             "agent": self.agent._to_dict(),
@@ -177,11 +177,29 @@ class Interview:
         }
         if include_exceptions:
             d["exceptions"] = self.exceptions.to_dict()
+        return d
+
+    @classmethod
+    def from_dict(cls, d: dict[str, Any]) -> "Interview":
+        """Return an Interview instance from a dictionary."""
+        agent = Agent.from_dict(d["agent"])
+        survey = Survey.from_dict(d["survey"])
+        scenario = Scenario.from_dict(d["scenario"])
+        model = LanguageModel.from_dict(d["model"])
+        iteration = d["iteration"]
+        return cls(agent=agent, survey=survey, scenario=scenario, model=model, iteration=iteration)
 
     def __hash__(self) -> int:
         from edsl.utilities.utilities import dict_hash
 
-        return dict_hash(self._to_dict())
+        return dict_hash(self._to_dict(include_exceptions=False))
+
+    def __eq__(self, other: "Interview") -> bool:
+        """
+        >>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i._to_dict(); i2 = Interview.from_dict(d); i == i2
+        True
+        """
+        return hash(self) == hash(other)
 
     # endregion
 
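Spelled out, the round trip that the new doctest exercises:

    from edsl.jobs.interviews.Interview import Interview

    i = Interview.example()
    i2 = Interview.from_dict(i._to_dict())
    assert i == i2  # equality is hash-based and excludes exceptions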
edsl/jobs/interviews/InterviewExceptionCollection.py
CHANGED
@@ -33,6 +33,15 @@ class InterviewExceptionCollection(UserDict):
         """Return the collection of exceptions as a dictionary."""
         newdata = {k: [e.to_dict() for e in v] for k, v in self.data.items()}
         return newdata
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "InterviewExceptionCollection":
+        """Create an InterviewExceptionCollection from a dictionary."""
+        collection = cls()
+        for question_name, entries in data.items():
+            for entry in entries:
+                collection.add(question_name, InterviewExceptionEntry.from_dict(entry))
+        return collection
 
     def _repr_html_(self) -> str:
         from edsl.utilities.utilities import data_to_html
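A round-trip sketch, assuming an existing collection named exceptions:

    data = exceptions.to_dict()                              # {question_name: [entry_dict, ...]}
    restored = InterviewExceptionCollection.from_dict(data)  # rebuilds entries via from_dict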
edsl/jobs/interviews/InterviewExceptionEntry.py
CHANGED
@@ -9,7 +9,6 @@ class InterviewExceptionEntry:
         self,
         *,
         exception: Exception,
-        # failed_question: FailedQuestion,
         invigilator: "Invigilator",
         traceback_format="text",
         answers=None,
@@ -133,22 +132,41 @@ class InterviewExceptionEntry:
         )
         console.print(tb)
         return html_output.getvalue()
+
+    @staticmethod
+    def serialize_exception(exception: Exception) -> dict:
+        return {
+            "type": type(exception).__name__,
+            "message": str(exception),
+            "traceback": "".join(traceback.format_exception(type(exception), exception, exception.__traceback__)),
+        }
+
+    @staticmethod
+    def deserialize_exception(data: dict) -> Exception:
+        exception_class = globals()[data["type"]]
+        return exception_class(data["message"])
 
     def to_dict(self) -> dict:
         """Return the exception as a dictionary.
 
         >>> entry = InterviewExceptionEntry.example()
-        >>> entry.to_dict()
-        ValueError()
-
+        >>> _ = entry.to_dict()
         """
         return {
-            "exception": self.exception,
+            "exception": self.serialize_exception(self.exception),
             "time": self.time,
             "traceback": self.traceback,
-            # "failed_question": self.failed_question.to_dict(),
             "invigilator": self.invigilator.to_dict(),
         }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "InterviewExceptionEntry":
+        """Create an InterviewExceptionEntry from a dictionary."""
+        from edsl.agents.Invigilator import InvigilatorAI
+
+        exception = cls.deserialize_exception(data["exception"])
+        invigilator = InvigilatorAI.from_dict(data["invigilator"])
+        return cls(exception=exception, invigilator=invigilator)
 
     def push(self):
         from edsl import Coop
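serialize_exception captures the type name, message, and formatted traceback, while deserialize_exception rebuilds the exception by looking the type name up in the module's globals(), so only exception classes visible in that namespace will resolve. A self-contained sketch of the same pattern; the builtins fallback is this sketch's addition, not the committed code:

    import builtins
    import traceback

    def serialize_exception(exception: Exception) -> dict:
        # Same shape as InterviewExceptionEntry.serialize_exception.
        return {
            "type": type(exception).__name__,
            "message": str(exception),
            "traceback": "".join(
                traceback.format_exception(type(exception), exception, exception.__traceback__)
            ),
        }

    def deserialize_exception(data: dict) -> Exception:
        # The committed code uses globals()[data["type"]]; resolving against
        # builtins keeps this sketch self-contained.
        exception_class = getattr(builtins, data["type"], Exception)
        return exception_class(data["message"])

    restored = deserialize_exception(serialize_exception(ValueError("bad input")))
    assert isinstance(restored, ValueError) and str(restored) == "bad input"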
edsl/jobs/runners/JobsRunnerAsyncio.py
CHANGED
@@ -1,18 +1,12 @@
 from __future__ import annotations
 import time
-import math
 import asyncio
-import functools
 import threading
 from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
 from contextlib import contextmanager
 from collections import UserList
 
-from rich.live import Live
-from rich.console import Console
-
 from edsl.results.Results import Results
-from edsl import shared_globals
 from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
@@ -25,7 +19,6 @@ from edsl.results.Results import Results
 from edsl.language_models.LanguageModel import LanguageModel
 from edsl.data.Cache import Cache
 
-
 class StatusTracker(UserList):
     def __init__(self, total_tasks: int):
         self.total_tasks = total_tasks
@@ -48,8 +41,6 @@ class JobsRunnerAsyncio:
         self.bucket_collection: "BucketCollection" = jobs.bucket_collection
         self.total_interviews: List["Interview"] = []
 
-        # self.jobs_runner_status = JobsRunnerStatus(self, n=1)
-
     async def run_async_generator(
         self,
         cache: Cache,
@@ -173,19 +164,20 @@ class JobsRunnerAsyncio:
 
         prompt_dictionary = {}
         for answer_key_name in answer_key_names:
-            prompt_dictionary[
-                answer_key_name
-
-            prompt_dictionary[
-                answer_key_name
-
+            prompt_dictionary[answer_key_name + "_user_prompt"] = (
+                question_name_to_prompts[answer_key_name]["user_prompt"]
+            )
+            prompt_dictionary[answer_key_name + "_system_prompt"] = (
+                question_name_to_prompts[answer_key_name]["system_prompt"]
+            )
 
         raw_model_results_dictionary = {}
+        cache_used_dictionary = {}
         for result in valid_results:
             question_name = result.question_name
-            raw_model_results_dictionary[
-
-
+            raw_model_results_dictionary[question_name + "_raw_model_response"] = (
+                result.raw_model_response
+            )
             raw_model_results_dictionary[question_name + "_cost"] = result.cost
             one_use_buys = (
                 "NA"
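The reflowed loops produce flat, per-question keys. A toy illustration with invented values:

    question_name_to_prompts = {
        "how_feeling": {"user_prompt": "How do you feel?", "system_prompt": "You are an agent."}
    }
    prompt_dictionary = {}
    for answer_key_name in question_name_to_prompts:
        prompt_dictionary[answer_key_name + "_user_prompt"] = (
            question_name_to_prompts[answer_key_name]["user_prompt"]
        )
        prompt_dictionary[answer_key_name + "_system_prompt"] = (
            question_name_to_prompts[answer_key_name]["system_prompt"]
        )
    assert set(prompt_dictionary) == {"how_feeling_user_prompt", "how_feeling_system_prompt"}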
@@ -195,6 +187,7 @@ class JobsRunnerAsyncio:
                 else 1.0 / result.cost
             )
             raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
+            cache_used_dictionary[question_name] = result.cache_used
 
         result = Result(
             agent=interview.agent,
@@ -207,6 +200,7 @@ class JobsRunnerAsyncio:
             survey=interview.survey,
             generated_tokens=generated_tokens_dict,
             comments_dict=comments_dict,
+            cache_used_dict=cache_used_dictionary,
         )
         result.interview_hash = hash(interview)
 
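The cache_used_dict recorded here is what the new Jobs.compute_job_cost reads back (via result.cache_used_dict[question_name]) to exclude cached responses from the realized cost total.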
@@ -225,17 +219,16 @@ class JobsRunnerAsyncio:
         }
         interview_hashes = list(interview_lookup.keys())
 
+        task_history = TaskHistory(self.total_interviews, include_traceback=False)
+
         results = Results(
             survey=self.jobs.survey,
             data=sorted(
                 raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
             ),
+            task_history=task_history,
+            cache=cache,
         )
-        results.cache = cache
-        results.task_history = TaskHistory(
-            self.total_interviews, include_traceback=False
-        )
-        results.has_unfixed_exceptions = results.task_history.has_unfixed_exceptions
         results.bucket_collection = self.bucket_collection
 
         if results.has_unfixed_exceptions and print_exceptions:
@@ -263,6 +256,7 @@ class JobsRunnerAsyncio:
         except Exception as e:
             print(e)
             remote_logging = False
+
         if remote_logging:
             filestore = HTMLFileStore(filepath)
             coop_details = filestore.push(description="Error report")