edsl 0.1.36__py3-none-any.whl → 0.1.36.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +0 -5
- edsl/__init__.py +0 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +7 -11
- edsl/agents/InvigilatorBase.py +1 -5
- edsl/agents/PromptConstructor.py +18 -27
- edsl/conversation/Conversation.py +1 -1
- edsl/coop/PriceFetcher.py +18 -14
- edsl/coop/coop.py +8 -42
- edsl/exceptions/coop.py +0 -8
- edsl/inference_services/InferenceServiceABC.py +0 -28
- edsl/inference_services/InferenceServicesCollection.py +4 -10
- edsl/inference_services/models_available_cache.py +1 -25
- edsl/jobs/Jobs.py +167 -190
- edsl/jobs/interviews/Interview.py +14 -42
- edsl/jobs/interviews/InterviewExceptionCollection.py +0 -9
- edsl/jobs/interviews/InterviewExceptionEntry.py +6 -31
- edsl/jobs/runners/JobsRunnerAsyncio.py +13 -8
- edsl/jobs/tasks/TaskHistory.py +7 -23
- edsl/questions/QuestionFunctional.py +3 -7
- edsl/results/Dataset.py +0 -12
- edsl/results/Result.py +0 -2
- edsl/results/Results.py +1 -13
- edsl/scenarios/FileStore.py +5 -20
- edsl/scenarios/Scenario.py +1 -15
- edsl/scenarios/__init__.py +0 -2
- edsl/surveys/Survey.py +0 -3
- edsl/surveys/instructions/Instruction.py +3 -20
- {edsl-0.1.36.dist-info → edsl-0.1.36.dev2.dist-info}/METADATA +1 -1
- {edsl-0.1.36.dist-info → edsl-0.1.36.dev2.dist-info}/RECORD +32 -33
- edsl/data/RemoteCacheSync.py +0 -97
- {edsl-0.1.36.dist-info → edsl-0.1.36.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.36.dist-info → edsl-0.1.36.dev2.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -1,10 +1,8 @@
|
|
1
1
|
# """The Jobs class is a collection of agents, scenarios and models and one survey."""
|
2
2
|
from __future__ import annotations
|
3
3
|
import warnings
|
4
|
-
import requests
|
5
4
|
from itertools import product
|
6
5
|
from typing import Optional, Union, Sequence, Generator
|
7
|
-
|
8
6
|
from edsl.Base import Base
|
9
7
|
from edsl.exceptions import MissingAPIKeyError
|
10
8
|
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
@@ -12,9 +10,6 @@ from edsl.jobs.interviews.Interview import Interview
|
|
12
10
|
from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
|
13
11
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
14
12
|
|
15
|
-
from edsl.data.RemoteCacheSync import RemoteCacheSync
|
16
|
-
from edsl.exceptions.coop import CoopServerResponseError
|
17
|
-
|
18
13
|
|
19
14
|
class Jobs(Base):
|
20
15
|
"""
|
@@ -208,15 +203,14 @@ class Jobs(Base):
|
|
208
203
|
]
|
209
204
|
)
|
210
205
|
return d
|
206
|
+
# if table:
|
207
|
+
# d.to_scenario_list().print(format="rich")
|
208
|
+
# else:
|
209
|
+
# return d
|
211
210
|
|
212
|
-
def show_prompts(self
|
211
|
+
def show_prompts(self) -> None:
|
213
212
|
"""Print the prompts."""
|
214
|
-
|
215
|
-
self.prompts().to_scenario_list().print(format="rich")
|
216
|
-
else:
|
217
|
-
self.prompts().select(
|
218
|
-
"user_prompt", "system_prompt"
|
219
|
-
).to_scenario_list().print(format="rich")
|
213
|
+
self.prompts().to_scenario_list().print(format="rich")
|
220
214
|
|
221
215
|
@staticmethod
|
222
216
|
def estimate_prompt_cost(
|
@@ -225,11 +219,11 @@ class Jobs(Base):
|
|
225
219
|
price_lookup: dict,
|
226
220
|
inference_service: str,
|
227
221
|
model: str,
|
228
|
-
)
|
222
|
+
):
|
229
223
|
"""Estimates the cost of a prompt. Takes piping into account."""
|
230
224
|
|
231
225
|
def get_piping_multiplier(prompt: str):
|
232
|
-
"""Returns 2 if a prompt includes Jinja
|
226
|
+
"""Returns 2 if a prompt includes Jinja brances, and 1 otherwise."""
|
233
227
|
|
234
228
|
if "{{" in prompt and "}}" in prompt:
|
235
229
|
return 2
|
@@ -237,25 +231,9 @@ class Jobs(Base):
|
|
237
231
|
|
238
232
|
# Look up prices per token
|
239
233
|
key = (inference_service, model)
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
output_price_per_token = 1 / float(
|
244
|
-
relevant_prices["output"]["one_usd_buys"]
|
245
|
-
)
|
246
|
-
input_price_per_token = 1 / float(relevant_prices["input"]["one_usd_buys"])
|
247
|
-
except KeyError:
|
248
|
-
# A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
|
249
|
-
# Use a sensible default
|
250
|
-
|
251
|
-
import warnings
|
252
|
-
|
253
|
-
warnings.warn(
|
254
|
-
"Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
|
255
|
-
)
|
256
|
-
|
257
|
-
output_price_per_token = 0.00000015 # $0.15 / 1M tokens
|
258
|
-
input_price_per_token = 0.00000060 # $0.60 / 1M tokens
|
234
|
+
relevant_prices = price_lookup[key]
|
235
|
+
output_price_per_token = 1 / float(relevant_prices["output"]["one_usd_buys"])
|
236
|
+
input_price_per_token = 1 / float(relevant_prices["input"]["one_usd_buys"])
|
259
237
|
|
260
238
|
# Compute the number of characters (double if the question involves piping)
|
261
239
|
user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
|
@@ -280,7 +258,7 @@ class Jobs(Base):
|
|
280
258
|
"cost": cost,
|
281
259
|
}
|
282
260
|
|
283
|
-
def estimate_job_cost_from_external_prices(self, price_lookup: dict)
|
261
|
+
def estimate_job_cost_from_external_prices(self, price_lookup: dict):
|
284
262
|
"""
|
285
263
|
Estimates the cost of a job according to the following assumptions:
|
286
264
|
|
@@ -363,7 +341,7 @@ class Jobs(Base):
|
|
363
341
|
|
364
342
|
return output
|
365
343
|
|
366
|
-
def estimate_job_cost(self)
|
344
|
+
def estimate_job_cost(self):
|
367
345
|
"""
|
368
346
|
Estimates the cost of a job according to the following assumptions:
|
369
347
|
|
@@ -379,25 +357,6 @@ class Jobs(Base):
|
|
379
357
|
|
380
358
|
return self.estimate_job_cost_from_external_prices(price_lookup=price_lookup)
|
381
359
|
|
382
|
-
@staticmethod
|
383
|
-
def compute_job_cost(job_results: "Results") -> float:
|
384
|
-
"""
|
385
|
-
Computes the cost of a completed job in USD.
|
386
|
-
"""
|
387
|
-
total_cost = 0
|
388
|
-
for result in job_results:
|
389
|
-
for key in result.raw_model_response:
|
390
|
-
if key.endswith("_cost"):
|
391
|
-
result_cost = result.raw_model_response[key]
|
392
|
-
|
393
|
-
question_name = key.removesuffix("_cost")
|
394
|
-
cache_used = result.cache_used_dict[question_name]
|
395
|
-
|
396
|
-
if isinstance(result_cost, (int, float)) and not cache_used:
|
397
|
-
total_cost += result_cost
|
398
|
-
|
399
|
-
return total_cost
|
400
|
-
|
401
360
|
@staticmethod
|
402
361
|
def _get_container_class(object):
|
403
362
|
from edsl.agents.AgentList import AgentList
|
@@ -621,7 +580,7 @@ class Jobs(Base):
|
|
621
580
|
|
622
581
|
def _output(self, message) -> None:
|
623
582
|
"""Check if a Job is verbose. If so, print the message."""
|
624
|
-
if
|
583
|
+
if self.verbose:
|
625
584
|
print(message)
|
626
585
|
|
627
586
|
def _check_parameters(self, strict=False, warn=False) -> None:
|
@@ -698,123 +657,6 @@ class Jobs(Base):
|
|
698
657
|
return False
|
699
658
|
return self._raise_validation_errors
|
700
659
|
|
701
|
-
def create_remote_inference_job(
|
702
|
-
self, iterations: int = 1, remote_inference_description: Optional[str] = None
|
703
|
-
):
|
704
|
-
""" """
|
705
|
-
from edsl.coop.coop import Coop
|
706
|
-
|
707
|
-
coop = Coop()
|
708
|
-
self._output("Remote inference activated. Sending job to server...")
|
709
|
-
remote_job_creation_data = coop.remote_inference_create(
|
710
|
-
self,
|
711
|
-
description=remote_inference_description,
|
712
|
-
status="queued",
|
713
|
-
iterations=iterations,
|
714
|
-
)
|
715
|
-
job_uuid = remote_job_creation_data.get("uuid")
|
716
|
-
print(f"Job sent to server. (Job uuid={job_uuid}).")
|
717
|
-
return remote_job_creation_data
|
718
|
-
|
719
|
-
@staticmethod
|
720
|
-
def check_status(job_uuid):
|
721
|
-
from edsl.coop.coop import Coop
|
722
|
-
|
723
|
-
coop = Coop()
|
724
|
-
return coop.remote_inference_get(job_uuid)
|
725
|
-
|
726
|
-
def poll_remote_inference_job(
|
727
|
-
self, remote_job_creation_data: dict
|
728
|
-
) -> Union[Results, None]:
|
729
|
-
from edsl.coop.coop import Coop
|
730
|
-
import time
|
731
|
-
from datetime import datetime
|
732
|
-
from edsl.config import CONFIG
|
733
|
-
|
734
|
-
expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
|
735
|
-
|
736
|
-
job_uuid = remote_job_creation_data.get("uuid")
|
737
|
-
|
738
|
-
coop = Coop()
|
739
|
-
job_in_queue = True
|
740
|
-
while job_in_queue:
|
741
|
-
remote_job_data = coop.remote_inference_get(job_uuid)
|
742
|
-
status = remote_job_data.get("status")
|
743
|
-
if status == "cancelled":
|
744
|
-
print("\r" + " " * 80 + "\r", end="")
|
745
|
-
print("Job cancelled by the user.")
|
746
|
-
print(
|
747
|
-
f"See {expected_parrot_url}/home/remote-inference for more details."
|
748
|
-
)
|
749
|
-
return None
|
750
|
-
elif status == "failed":
|
751
|
-
print("\r" + " " * 80 + "\r", end="")
|
752
|
-
print("Job failed.")
|
753
|
-
print(
|
754
|
-
f"See {expected_parrot_url}/home/remote-inference for more details."
|
755
|
-
)
|
756
|
-
return None
|
757
|
-
elif status == "completed":
|
758
|
-
results_uuid = remote_job_data.get("results_uuid")
|
759
|
-
results = coop.get(results_uuid, expected_object_type="results")
|
760
|
-
print("\r" + " " * 80 + "\r", end="")
|
761
|
-
url = f"{expected_parrot_url}/content/{results_uuid}"
|
762
|
-
print(f"Job completed and Results stored on Coop: {url}.")
|
763
|
-
return results
|
764
|
-
else:
|
765
|
-
duration = 5
|
766
|
-
time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
|
767
|
-
frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
|
768
|
-
start_time = time.time()
|
769
|
-
i = 0
|
770
|
-
while time.time() - start_time < duration:
|
771
|
-
print(
|
772
|
-
f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
|
773
|
-
end="",
|
774
|
-
flush=True,
|
775
|
-
)
|
776
|
-
time.sleep(0.1)
|
777
|
-
i += 1
|
778
|
-
|
779
|
-
def use_remote_inference(self, disable_remote_inference: bool):
|
780
|
-
if disable_remote_inference:
|
781
|
-
return False
|
782
|
-
if not disable_remote_inference:
|
783
|
-
try:
|
784
|
-
from edsl import Coop
|
785
|
-
|
786
|
-
user_edsl_settings = Coop().edsl_settings
|
787
|
-
return user_edsl_settings.get("remote_inference", False)
|
788
|
-
except requests.ConnectionError:
|
789
|
-
pass
|
790
|
-
except CoopServerResponseError as e:
|
791
|
-
pass
|
792
|
-
|
793
|
-
return False
|
794
|
-
|
795
|
-
def use_remote_cache(self):
|
796
|
-
try:
|
797
|
-
from edsl import Coop
|
798
|
-
|
799
|
-
user_edsl_settings = Coop().edsl_settings
|
800
|
-
return user_edsl_settings.get("remote_caching", False)
|
801
|
-
except requests.ConnectionError:
|
802
|
-
pass
|
803
|
-
except CoopServerResponseError as e:
|
804
|
-
pass
|
805
|
-
|
806
|
-
return False
|
807
|
-
|
808
|
-
def check_api_keys(self):
|
809
|
-
from edsl import Model
|
810
|
-
|
811
|
-
for model in self.models + [Model()]:
|
812
|
-
if not model.has_valid_api_key():
|
813
|
-
raise MissingAPIKeyError(
|
814
|
-
model_name=str(model.model),
|
815
|
-
inference_service=model._inference_service_,
|
816
|
-
)
|
817
|
-
|
818
660
|
def run(
|
819
661
|
self,
|
820
662
|
n: int = 1,
|
@@ -852,17 +694,91 @@ class Jobs(Base):
|
|
852
694
|
|
853
695
|
self.verbose = verbose
|
854
696
|
|
855
|
-
|
856
|
-
|
857
|
-
|
697
|
+
remote_cache = False
|
698
|
+
remote_inference = False
|
699
|
+
|
700
|
+
if not disable_remote_inference:
|
701
|
+
try:
|
702
|
+
coop = Coop()
|
703
|
+
user_edsl_settings = Coop().edsl_settings
|
704
|
+
remote_cache = user_edsl_settings.get("remote_caching", False)
|
705
|
+
remote_inference = user_edsl_settings.get("remote_inference", False)
|
706
|
+
except Exception:
|
707
|
+
pass
|
708
|
+
|
709
|
+
if remote_inference:
|
710
|
+
import time
|
711
|
+
from datetime import datetime
|
712
|
+
from edsl.config import CONFIG
|
713
|
+
|
714
|
+
expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
|
715
|
+
|
716
|
+
self._output("Remote inference activated. Sending job to server...")
|
717
|
+
if remote_cache:
|
718
|
+
self._output(
|
719
|
+
"Remote caching activated. The remote cache will be used for this job."
|
720
|
+
)
|
721
|
+
|
722
|
+
remote_job_creation_data = coop.remote_inference_create(
|
723
|
+
self,
|
724
|
+
description=remote_inference_description,
|
725
|
+
status="queued",
|
726
|
+
iterations=n,
|
858
727
|
)
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
728
|
+
time_queued = datetime.now().strftime("%m/%d/%Y %I:%M:%S %p")
|
729
|
+
job_uuid = remote_job_creation_data.get("uuid")
|
730
|
+
print(f"Remote inference started (Job uuid={job_uuid}).")
|
731
|
+
# print(f"Job queued at {time_queued}.")
|
732
|
+
job_in_queue = True
|
733
|
+
while job_in_queue:
|
734
|
+
remote_job_data = coop.remote_inference_get(job_uuid)
|
735
|
+
status = remote_job_data.get("status")
|
736
|
+
if status == "cancelled":
|
737
|
+
print("\r" + " " * 80 + "\r", end="")
|
738
|
+
print("Job cancelled by the user.")
|
739
|
+
print(
|
740
|
+
f"See {expected_parrot_url}/home/remote-inference for more details."
|
741
|
+
)
|
742
|
+
return None
|
743
|
+
elif status == "failed":
|
744
|
+
print("\r" + " " * 80 + "\r", end="")
|
745
|
+
print("Job failed.")
|
746
|
+
print(
|
747
|
+
f"See {expected_parrot_url}/home/remote-inference for more details."
|
748
|
+
)
|
749
|
+
return None
|
750
|
+
elif status == "completed":
|
751
|
+
results_uuid = remote_job_data.get("results_uuid")
|
752
|
+
results = coop.get(results_uuid, expected_object_type="results")
|
753
|
+
print("\r" + " " * 80 + "\r", end="")
|
754
|
+
print(
|
755
|
+
f"Job completed and Results stored on Coop (Results uuid={results_uuid})."
|
756
|
+
)
|
757
|
+
return results
|
758
|
+
else:
|
759
|
+
duration = 5
|
760
|
+
time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
|
761
|
+
frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
|
762
|
+
start_time = time.time()
|
763
|
+
i = 0
|
764
|
+
while time.time() - start_time < duration:
|
765
|
+
print(
|
766
|
+
f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
|
767
|
+
end="",
|
768
|
+
flush=True,
|
769
|
+
)
|
770
|
+
time.sleep(0.1)
|
771
|
+
i += 1
|
772
|
+
else:
|
773
|
+
if check_api_keys:
|
774
|
+
from edsl import Model
|
863
775
|
|
864
|
-
|
865
|
-
|
776
|
+
for model in self.models + [Model()]:
|
777
|
+
if not model.has_valid_api_key():
|
778
|
+
raise MissingAPIKeyError(
|
779
|
+
model_name=str(model.model),
|
780
|
+
inference_service=model._inference_service_,
|
781
|
+
)
|
866
782
|
|
867
783
|
# handle cache
|
868
784
|
if cache is None or cache is True:
|
@@ -874,14 +790,51 @@ class Jobs(Base):
|
|
874
790
|
|
875
791
|
cache = Cache()
|
876
792
|
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
793
|
+
if not remote_cache:
|
794
|
+
results = self._run_local(
|
795
|
+
n=n,
|
796
|
+
progress_bar=progress_bar,
|
797
|
+
cache=cache,
|
798
|
+
stop_on_exception=stop_on_exception,
|
799
|
+
sidecar_model=sidecar_model,
|
800
|
+
print_exceptions=print_exceptions,
|
801
|
+
raise_validation_errors=raise_validation_errors,
|
802
|
+
)
|
803
|
+
|
804
|
+
results.cache = cache.new_entries_cache()
|
805
|
+
|
806
|
+
self._output(f"There are {len(cache.keys()):,} entries in the local cache.")
|
807
|
+
else:
|
808
|
+
cache_difference = coop.remote_cache_get_diff(cache.keys())
|
809
|
+
|
810
|
+
client_missing_cacheentries = cache_difference.get(
|
811
|
+
"client_missing_cacheentries", []
|
812
|
+
)
|
813
|
+
|
814
|
+
missing_entry_count = len(client_missing_cacheentries)
|
815
|
+
if missing_entry_count > 0:
|
816
|
+
self._output(
|
817
|
+
f"Updating local cache with {missing_entry_count:,} new "
|
818
|
+
f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
|
819
|
+
)
|
820
|
+
cache.add_from_dict(
|
821
|
+
{entry.key: entry for entry in client_missing_cacheentries}
|
822
|
+
)
|
823
|
+
self._output("Local cache updated!")
|
824
|
+
else:
|
825
|
+
self._output("No new entries to add to local cache.")
|
826
|
+
|
827
|
+
server_missing_cacheentry_keys = cache_difference.get(
|
828
|
+
"server_missing_cacheentry_keys", []
|
829
|
+
)
|
830
|
+
server_missing_cacheentries = [
|
831
|
+
entry
|
832
|
+
for key in server_missing_cacheentry_keys
|
833
|
+
if (entry := cache.data.get(key)) is not None
|
834
|
+
]
|
835
|
+
old_entry_keys = [key for key in cache.keys()]
|
836
|
+
|
837
|
+
self._output("Running job...")
|
885
838
|
results = self._run_local(
|
886
839
|
n=n,
|
887
840
|
progress_bar=progress_bar,
|
@@ -891,8 +844,32 @@ class Jobs(Base):
|
|
891
844
|
print_exceptions=print_exceptions,
|
892
845
|
raise_validation_errors=raise_validation_errors,
|
893
846
|
)
|
847
|
+
self._output("Job completed!")
|
848
|
+
|
849
|
+
new_cache_entries = list(
|
850
|
+
[entry for entry in cache.values() if entry.key not in old_entry_keys]
|
851
|
+
)
|
852
|
+
server_missing_cacheentries.extend(new_cache_entries)
|
853
|
+
|
854
|
+
new_entry_count = len(server_missing_cacheentries)
|
855
|
+
if new_entry_count > 0:
|
856
|
+
self._output(
|
857
|
+
f"Updating remote cache with {new_entry_count:,} new "
|
858
|
+
f"{'entry' if new_entry_count == 1 else 'entries'}..."
|
859
|
+
)
|
860
|
+
coop.remote_cache_create_many(
|
861
|
+
server_missing_cacheentries,
|
862
|
+
visibility="private",
|
863
|
+
description=remote_cache_description,
|
864
|
+
)
|
865
|
+
self._output("Remote cache updated!")
|
866
|
+
else:
|
867
|
+
self._output("No new entries to add to remote cache.")
|
868
|
+
|
869
|
+
results.cache = cache.new_entries_cache()
|
870
|
+
|
871
|
+
self._output(f"There are {len(cache.keys()):,} entries in the local cache.")
|
894
872
|
|
895
|
-
results.cache = cache.new_entries_cache()
|
896
873
|
return results
|
897
874
|
|
898
875
|
def _run_local(self, *args, **kwargs):
|
@@ -110,9 +110,9 @@ class Interview:
|
|
110
110
|
self.debug = debug
|
111
111
|
self.iteration = iteration
|
112
112
|
self.cache = cache
|
113
|
-
self.answers: dict[
|
114
|
-
|
115
|
-
|
113
|
+
self.answers: dict[str, str] = (
|
114
|
+
Answers()
|
115
|
+
) # will get filled in as interview progresses
|
116
116
|
self.sidecar_model = sidecar_model
|
117
117
|
|
118
118
|
# Trackers
|
@@ -143,9 +143,9 @@ class Interview:
|
|
143
143
|
The keys are the question names; the values are the lists of status log changes for each task.
|
144
144
|
"""
|
145
145
|
for task_creator in self.task_creators.values():
|
146
|
-
self._task_status_log_dict[
|
147
|
-
task_creator.
|
148
|
-
|
146
|
+
self._task_status_log_dict[task_creator.question.question_name] = (
|
147
|
+
task_creator.status_log
|
148
|
+
)
|
149
149
|
return self._task_status_log_dict
|
150
150
|
|
151
151
|
@property
|
@@ -159,13 +159,13 @@ class Interview:
|
|
159
159
|
return self.task_creators.interview_status
|
160
160
|
|
161
161
|
# region: Serialization
|
162
|
-
def _to_dict(self, include_exceptions=
|
162
|
+
def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
|
163
163
|
"""Return a dictionary representation of the Interview instance.
|
164
164
|
This is just for hashing purposes.
|
165
165
|
|
166
166
|
>>> i = Interview.example()
|
167
167
|
>>> hash(i)
|
168
|
-
|
168
|
+
1646262796627658719
|
169
169
|
"""
|
170
170
|
d = {
|
171
171
|
"agent": self.agent._to_dict(),
|
@@ -177,39 +177,11 @@ class Interview:
|
|
177
177
|
}
|
178
178
|
if include_exceptions:
|
179
179
|
d["exceptions"] = self.exceptions.to_dict()
|
180
|
-
return d
|
181
|
-
|
182
|
-
@classmethod
|
183
|
-
def from_dict(cls, d: dict[str, Any]) -> "Interview":
|
184
|
-
"""Return an Interview instance from a dictionary."""
|
185
|
-
agent = Agent.from_dict(d["agent"])
|
186
|
-
survey = Survey.from_dict(d["survey"])
|
187
|
-
scenario = Scenario.from_dict(d["scenario"])
|
188
|
-
model = LanguageModel.from_dict(d["model"])
|
189
|
-
iteration = d["iteration"]
|
190
|
-
interview = cls(
|
191
|
-
agent=agent,
|
192
|
-
survey=survey,
|
193
|
-
scenario=scenario,
|
194
|
-
model=model,
|
195
|
-
iteration=iteration,
|
196
|
-
)
|
197
|
-
if "exceptions" in d:
|
198
|
-
exceptions = InterviewExceptionCollection.from_dict(d["exceptions"])
|
199
|
-
interview.exceptions = exceptions
|
200
|
-
return interview
|
201
180
|
|
202
181
|
def __hash__(self) -> int:
|
203
182
|
from edsl.utilities.utilities import dict_hash
|
204
183
|
|
205
|
-
return dict_hash(self._to_dict(
|
206
|
-
|
207
|
-
def __eq__(self, other: "Interview") -> bool:
|
208
|
-
"""
|
209
|
-
>>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i._to_dict(); i2 = Interview.from_dict(d); i == i2
|
210
|
-
True
|
211
|
-
"""
|
212
|
-
return hash(self) == hash(other)
|
184
|
+
return dict_hash(self._to_dict())
|
213
185
|
|
214
186
|
# endregion
|
215
187
|
|
@@ -486,11 +458,11 @@ class Interview:
|
|
486
458
|
"""
|
487
459
|
current_question_index: int = self.to_index[current_question.question_name]
|
488
460
|
|
489
|
-
next_question: Union[
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
461
|
+
next_question: Union[int, EndOfSurvey] = (
|
462
|
+
self.survey.rule_collection.next_question(
|
463
|
+
q_now=current_question_index,
|
464
|
+
answers=self.answers | self.scenario | self.agent["traits"],
|
465
|
+
)
|
494
466
|
)
|
495
467
|
|
496
468
|
next_question_index = next_question.next_q
|
@@ -34,15 +34,6 @@ class InterviewExceptionCollection(UserDict):
|
|
34
34
|
newdata = {k: [e.to_dict() for e in v] for k, v in self.data.items()}
|
35
35
|
return newdata
|
36
36
|
|
37
|
-
@classmethod
|
38
|
-
def from_dict(cls, data: dict) -> "InterviewExceptionCollection":
|
39
|
-
"""Create an InterviewExceptionCollection from a dictionary."""
|
40
|
-
collection = cls()
|
41
|
-
for question_name, entries in data.items():
|
42
|
-
for entry in entries:
|
43
|
-
collection.add(question_name, InterviewExceptionEntry.from_dict(entry))
|
44
|
-
return collection
|
45
|
-
|
46
37
|
def _repr_html_(self) -> str:
|
47
38
|
from edsl.utilities.utilities import data_to_html
|
48
39
|
|
@@ -9,6 +9,7 @@ class InterviewExceptionEntry:
|
|
9
9
|
self,
|
10
10
|
*,
|
11
11
|
exception: Exception,
|
12
|
+
# failed_question: FailedQuestion,
|
12
13
|
invigilator: "Invigilator",
|
13
14
|
traceback_format="text",
|
14
15
|
answers=None,
|
@@ -133,48 +134,22 @@ class InterviewExceptionEntry:
|
|
133
134
|
console.print(tb)
|
134
135
|
return html_output.getvalue()
|
135
136
|
|
136
|
-
@staticmethod
|
137
|
-
def serialize_exception(exception: Exception) -> dict:
|
138
|
-
return {
|
139
|
-
"type": type(exception).__name__,
|
140
|
-
"message": str(exception),
|
141
|
-
"traceback": "".join(
|
142
|
-
traceback.format_exception(
|
143
|
-
type(exception), exception, exception.__traceback__
|
144
|
-
)
|
145
|
-
),
|
146
|
-
}
|
147
|
-
|
148
|
-
@staticmethod
|
149
|
-
def deserialize_exception(data: dict) -> Exception:
|
150
|
-
try:
|
151
|
-
exception_class = globals()[data["type"]]
|
152
|
-
except KeyError:
|
153
|
-
exception_class = Exception
|
154
|
-
return exception_class(data["message"])
|
155
|
-
|
156
137
|
def to_dict(self) -> dict:
|
157
138
|
"""Return the exception as a dictionary.
|
158
139
|
|
159
140
|
>>> entry = InterviewExceptionEntry.example()
|
160
|
-
>>>
|
141
|
+
>>> entry.to_dict()['exception']
|
142
|
+
ValueError()
|
143
|
+
|
161
144
|
"""
|
162
145
|
return {
|
163
|
-
"exception": self.
|
146
|
+
"exception": self.exception,
|
164
147
|
"time": self.time,
|
165
148
|
"traceback": self.traceback,
|
149
|
+
# "failed_question": self.failed_question.to_dict(),
|
166
150
|
"invigilator": self.invigilator.to_dict(),
|
167
151
|
}
|
168
152
|
|
169
|
-
@classmethod
|
170
|
-
def from_dict(cls, data: dict) -> "InterviewExceptionEntry":
|
171
|
-
"""Create an InterviewExceptionEntry from a dictionary."""
|
172
|
-
from edsl.agents.Invigilator import InvigilatorAI
|
173
|
-
|
174
|
-
exception = cls.deserialize_exception(data["exception"])
|
175
|
-
invigilator = InvigilatorAI.from_dict(data["invigilator"])
|
176
|
-
return cls(exception=exception, invigilator=invigilator)
|
177
|
-
|
178
153
|
def push(self):
|
179
154
|
from edsl import Coop
|
180
155
|
|