edsl 0.1.45__py3-none-any.whl → 0.1.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +7 -3
- edsl/__version__.py +1 -1
- edsl/agents/PromptConstructor.py +26 -79
- edsl/agents/QuestionInstructionPromptBuilder.py +70 -32
- edsl/agents/QuestionTemplateReplacementsBuilder.py +12 -2
- edsl/coop/coop.py +155 -94
- edsl/data/RemoteCacheSync.py +10 -9
- edsl/inference_services/AvailableModelFetcher.py +1 -1
- edsl/jobs/AnswerQuestionFunctionConstructor.py +12 -1
- edsl/jobs/Jobs.py +15 -17
- edsl/jobs/JobsPrompts.py +49 -26
- edsl/jobs/JobsRemoteInferenceHandler.py +4 -5
- edsl/jobs/data_structures.py +3 -0
- edsl/jobs/interviews/Interview.py +6 -3
- edsl/language_models/LanguageModel.py +6 -0
- edsl/questions/question_base_gen_mixin.py +2 -0
- edsl/results/DatasetExportMixin.py +25 -4
- edsl/scenarios/ScenarioList.py +153 -21
- {edsl-0.1.45.dist-info → edsl-0.1.46.dist-info}/METADATA +2 -2
- {edsl-0.1.45.dist-info → edsl-0.1.46.dist-info}/RECORD +22 -22
- {edsl-0.1.45.dist-info → edsl-0.1.46.dist-info}/LICENSE +0 -0
- {edsl-0.1.45.dist-info → edsl-0.1.46.dist-info}/WHEEL +0 -0
edsl/coop/coop.py
CHANGED
@@ -504,90 +504,146 @@ class Coop(CoopFunctionsMixin):
|
|
504
504
|
################
|
505
505
|
# Remote Cache
|
506
506
|
################
|
507
|
-
def remote_cache_create(
|
507
|
+
# def remote_cache_create(
|
508
|
+
# self,
|
509
|
+
# cache_entry: CacheEntry,
|
510
|
+
# visibility: VisibilityType = "private",
|
511
|
+
# description: Optional[str] = None,
|
512
|
+
# ) -> dict:
|
513
|
+
# """
|
514
|
+
# Create a single remote cache entry.
|
515
|
+
# If an entry with the same key already exists in the database, update it instead.
|
516
|
+
|
517
|
+
# :param cache_entry: The cache entry to send to the server.
|
518
|
+
# :param visibility: The visibility of the cache entry.
|
519
|
+
# :param optional description: A description for this entry in the remote cache.
|
520
|
+
|
521
|
+
# >>> entry = CacheEntry.example()
|
522
|
+
# >>> coop.remote_cache_create(cache_entry=entry)
|
523
|
+
# {'status': 'success', 'created_entry_count': 1, 'updated_entry_count': 0}
|
524
|
+
# """
|
525
|
+
# response = self._send_server_request(
|
526
|
+
# uri="api/v0/remote-cache",
|
527
|
+
# method="POST",
|
528
|
+
# payload={
|
529
|
+
# "json_string": json.dumps(cache_entry.to_dict()),
|
530
|
+
# "version": self._edsl_version,
|
531
|
+
# "visibility": visibility,
|
532
|
+
# "description": description,
|
533
|
+
# },
|
534
|
+
# )
|
535
|
+
# self._resolve_server_response(response)
|
536
|
+
# response_json = response.json()
|
537
|
+
# created_entry_count = response_json.get("created_entry_count", 0)
|
538
|
+
# if created_entry_count > 0:
|
539
|
+
# self.remote_cache_create_log(
|
540
|
+
# response,
|
541
|
+
# description="Upload new cache entries to server",
|
542
|
+
# cache_entry_count=created_entry_count,
|
543
|
+
# )
|
544
|
+
# return response.json()
|
545
|
+
|
546
|
+
# def remote_cache_create_many(
|
547
|
+
# self,
|
548
|
+
# cache_entries: list[CacheEntry],
|
549
|
+
# visibility: VisibilityType = "private",
|
550
|
+
# description: Optional[str] = None,
|
551
|
+
# ) -> dict:
|
552
|
+
# """
|
553
|
+
# Create many remote cache entries.
|
554
|
+
# If an entry with the same key already exists in the database, update it instead.
|
555
|
+
|
556
|
+
# :param cache_entries: The list of cache entries to send to the server.
|
557
|
+
# :param visibility: The visibility of the cache entries.
|
558
|
+
# :param optional description: A description for these entries in the remote cache.
|
559
|
+
|
560
|
+
# >>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
|
561
|
+
# >>> coop.remote_cache_create_many(cache_entries=entries)
|
562
|
+
# {'status': 'success', 'created_entry_count': 10, 'updated_entry_count': 0}
|
563
|
+
# """
|
564
|
+
# payload = [
|
565
|
+
# {
|
566
|
+
# "json_string": json.dumps(c.to_dict()),
|
567
|
+
# "version": self._edsl_version,
|
568
|
+
# "visibility": visibility,
|
569
|
+
# "description": description,
|
570
|
+
# }
|
571
|
+
# for c in cache_entries
|
572
|
+
# ]
|
573
|
+
# response = self._send_server_request(
|
574
|
+
# uri="api/v0/remote-cache/many",
|
575
|
+
# method="POST",
|
576
|
+
# payload=payload,
|
577
|
+
# timeout=40,
|
578
|
+
# )
|
579
|
+
# self._resolve_server_response(response)
|
580
|
+
# response_json = response.json()
|
581
|
+
# created_entry_count = response_json.get("created_entry_count", 0)
|
582
|
+
# if created_entry_count > 0:
|
583
|
+
# self.remote_cache_create_log(
|
584
|
+
# response,
|
585
|
+
# description="Upload new cache entries to server",
|
586
|
+
# cache_entry_count=created_entry_count,
|
587
|
+
# )
|
588
|
+
# return response.json()
|
589
|
+
|
590
|
+
def remote_cache_get(
|
508
591
|
self,
|
509
|
-
|
510
|
-
|
511
|
-
description: Optional[str] = None,
|
512
|
-
) -> dict:
|
592
|
+
job_uuid: Optional[Union[str, UUID]] = None,
|
593
|
+
) -> list[CacheEntry]:
|
513
594
|
"""
|
514
|
-
|
515
|
-
If an entry with the same key already exists in the database, update it instead.
|
595
|
+
Get all remote cache entries.
|
516
596
|
|
517
|
-
:param
|
518
|
-
:param visibility: The visibility of the cache entry.
|
519
|
-
:param optional description: A description for this entry in the remote cache.
|
597
|
+
:param optional select_keys: Only return CacheEntry objects with these keys.
|
520
598
|
|
521
|
-
>>>
|
522
|
-
|
523
|
-
{'status': 'success', 'created_entry_count': 1, 'updated_entry_count': 0}
|
599
|
+
>>> coop.remote_cache_get(job_uuid="...")
|
600
|
+
[CacheEntry(...), CacheEntry(...), ...]
|
524
601
|
"""
|
602
|
+
if job_uuid is None:
|
603
|
+
raise ValueError("Must provide a job_uuid.")
|
525
604
|
response = self._send_server_request(
|
526
|
-
uri="api/v0/remote-cache",
|
605
|
+
uri="api/v0/remote-cache/get-many-by-job",
|
527
606
|
method="POST",
|
528
607
|
payload={
|
529
|
-
"
|
530
|
-
"version": self._edsl_version,
|
531
|
-
"visibility": visibility,
|
532
|
-
"description": description,
|
608
|
+
"job_uuid": str(job_uuid),
|
533
609
|
},
|
610
|
+
timeout=40,
|
534
611
|
)
|
535
612
|
self._resolve_server_response(response)
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
response,
|
541
|
-
description="Upload new cache entries to server",
|
542
|
-
cache_entry_count=created_entry_count,
|
543
|
-
)
|
544
|
-
return response.json()
|
613
|
+
return [
|
614
|
+
CacheEntry.from_dict(json.loads(v.get("json_string")))
|
615
|
+
for v in response.json()
|
616
|
+
]
|
545
617
|
|
546
|
-
def
|
618
|
+
def remote_cache_get_by_key(
|
547
619
|
self,
|
548
|
-
|
549
|
-
|
550
|
-
description: Optional[str] = None,
|
551
|
-
) -> dict:
|
620
|
+
select_keys: Optional[list[str]] = None,
|
621
|
+
) -> list[CacheEntry]:
|
552
622
|
"""
|
553
|
-
|
554
|
-
If an entry with the same key already exists in the database, update it instead.
|
623
|
+
Get all remote cache entries.
|
555
624
|
|
556
|
-
:param
|
557
|
-
:param visibility: The visibility of the cache entries.
|
558
|
-
:param optional description: A description for these entries in the remote cache.
|
625
|
+
:param optional select_keys: Only return CacheEntry objects with these keys.
|
559
626
|
|
560
|
-
>>>
|
561
|
-
|
562
|
-
{'status': 'success', 'created_entry_count': 10, 'updated_entry_count': 0}
|
627
|
+
>>> coop.remote_cache_get_by_key(selected_keys=["..."])
|
628
|
+
[CacheEntry(...), CacheEntry(...), ...]
|
563
629
|
"""
|
564
|
-
|
565
|
-
|
566
|
-
"json_string": json.dumps(c.to_dict()),
|
567
|
-
"version": self._edsl_version,
|
568
|
-
"visibility": visibility,
|
569
|
-
"description": description,
|
570
|
-
}
|
571
|
-
for c in cache_entries
|
572
|
-
]
|
630
|
+
if select_keys is None or len(select_keys) == 0:
|
631
|
+
raise ValueError("Must provide a non-empty list of select_keys.")
|
573
632
|
response = self._send_server_request(
|
574
|
-
uri="api/v0/remote-cache/many",
|
633
|
+
uri="api/v0/remote-cache/get-many-by-key",
|
575
634
|
method="POST",
|
576
|
-
payload=
|
635
|
+
payload={
|
636
|
+
"selected_keys": select_keys,
|
637
|
+
},
|
577
638
|
timeout=40,
|
578
639
|
)
|
579
640
|
self._resolve_server_response(response)
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
response,
|
585
|
-
description="Upload new cache entries to server",
|
586
|
-
cache_entry_count=created_entry_count,
|
587
|
-
)
|
588
|
-
return response.json()
|
641
|
+
return [
|
642
|
+
CacheEntry.from_dict(json.loads(v.get("json_string")))
|
643
|
+
for v in response.json()
|
644
|
+
]
|
589
645
|
|
590
|
-
def
|
646
|
+
def legacy_remote_cache_get(
|
591
647
|
self,
|
592
648
|
exclude_keys: Optional[list[str]] = None,
|
593
649
|
select_keys: Optional[list[str]] = None,
|
@@ -595,9 +651,10 @@ class Coop(CoopFunctionsMixin):
|
|
595
651
|
"""
|
596
652
|
Get all remote cache entries.
|
597
653
|
|
654
|
+
:param optional select_keys: Only return CacheEntry objects with these keys.
|
598
655
|
:param optional exclude_keys: Exclude CacheEntry objects with these keys.
|
599
656
|
|
600
|
-
>>> coop.
|
657
|
+
>>> coop.legacy_remote_cache_get()
|
601
658
|
[CacheEntry(...), CacheEntry(...), ...]
|
602
659
|
"""
|
603
660
|
if exclude_keys is None:
|
@@ -605,9 +662,9 @@ class Coop(CoopFunctionsMixin):
|
|
605
662
|
if select_keys is None:
|
606
663
|
select_keys = []
|
607
664
|
response = self._send_server_request(
|
608
|
-
uri="api/v0/remote-cache/get-many",
|
665
|
+
uri="api/v0/remote-cache/legacy/get-many",
|
609
666
|
method="POST",
|
610
|
-
payload={"
|
667
|
+
payload={"exclude_keys": exclude_keys, "selected_keys": select_keys},
|
611
668
|
timeout=40,
|
612
669
|
)
|
613
670
|
self._resolve_server_response(response)
|
@@ -616,7 +673,7 @@ class Coop(CoopFunctionsMixin):
|
|
616
673
|
for v in response.json()
|
617
674
|
]
|
618
675
|
|
619
|
-
def
|
676
|
+
def legacy_remote_cache_get_diff(
|
620
677
|
self,
|
621
678
|
client_cacheentry_keys: list[str],
|
622
679
|
) -> dict:
|
@@ -624,7 +681,7 @@ class Coop(CoopFunctionsMixin):
|
|
624
681
|
Get the difference between local and remote cache entries for a user.
|
625
682
|
"""
|
626
683
|
response = self._send_server_request(
|
627
|
-
uri="api/v0/remote-cache/get-diff",
|
684
|
+
uri="api/v0/remote-cache/legacy/get-diff",
|
628
685
|
method="POST",
|
629
686
|
payload={"keys": client_cacheentry_keys},
|
630
687
|
timeout=40,
|
@@ -642,38 +699,38 @@ class Coop(CoopFunctionsMixin):
|
|
642
699
|
}
|
643
700
|
downloaded_entry_count = len(response_dict["client_missing_cacheentries"])
|
644
701
|
if downloaded_entry_count > 0:
|
645
|
-
self.
|
702
|
+
self.legacy_remote_cache_create_log(
|
646
703
|
response,
|
647
704
|
description="Download missing cache entries to client",
|
648
705
|
cache_entry_count=downloaded_entry_count,
|
649
706
|
)
|
650
707
|
return response_dict
|
651
708
|
|
652
|
-
def
|
709
|
+
def legacy_remote_cache_clear(self) -> dict:
|
653
710
|
"""
|
654
711
|
Clear all remote cache entries.
|
655
712
|
|
656
713
|
>>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
|
657
|
-
>>> coop.
|
658
|
-
>>> coop.
|
714
|
+
>>> coop.legacy_remote_cache_create_many(cache_entries=entries)
|
715
|
+
>>> coop.legacy_remote_cache_clear()
|
659
716
|
{'status': 'success', 'deleted_entry_count': 10}
|
660
717
|
"""
|
661
718
|
response = self._send_server_request(
|
662
|
-
uri="api/v0/remote-cache/delete-all",
|
719
|
+
uri="api/v0/remote-cache/legacy/delete-all",
|
663
720
|
method="DELETE",
|
664
721
|
)
|
665
722
|
self._resolve_server_response(response)
|
666
723
|
response_json = response.json()
|
667
724
|
deleted_entry_count = response_json.get("deleted_entry_count", 0)
|
668
725
|
if deleted_entry_count > 0:
|
669
|
-
self.
|
726
|
+
self.legacy_remote_cache_create_log(
|
670
727
|
response,
|
671
728
|
description="Clear cache entries",
|
672
729
|
cache_entry_count=deleted_entry_count,
|
673
730
|
)
|
674
731
|
return response.json()
|
675
732
|
|
676
|
-
def
|
733
|
+
def legacy_remote_cache_create_log(
|
677
734
|
self, response: requests.Response, description: str, cache_entry_count: int
|
678
735
|
) -> Union[dict, None]:
|
679
736
|
"""
|
@@ -682,7 +739,7 @@ class Coop(CoopFunctionsMixin):
|
|
682
739
|
"""
|
683
740
|
if 200 <= response.status_code < 300:
|
684
741
|
log_response = self._send_server_request(
|
685
|
-
uri="api/v0/remote-cache-log",
|
742
|
+
uri="api/v0/remote-cache-log/legacy",
|
686
743
|
method="POST",
|
687
744
|
payload={
|
688
745
|
"description": description,
|
@@ -692,15 +749,15 @@ class Coop(CoopFunctionsMixin):
|
|
692
749
|
self._resolve_server_response(log_response)
|
693
750
|
return response.json()
|
694
751
|
|
695
|
-
def
|
752
|
+
def legacy_remote_cache_clear_log(self) -> dict:
|
696
753
|
"""
|
697
754
|
Clear all remote cache log entries.
|
698
755
|
|
699
|
-
>>> coop.
|
756
|
+
>>> coop.legacy_remote_cache_clear_log()
|
700
757
|
{'status': 'success'}
|
701
758
|
"""
|
702
759
|
response = self._send_server_request(
|
703
|
-
uri="api/v0/remote-cache-log/delete-all",
|
760
|
+
uri="api/v0/remote-cache-log/legacy/delete-all",
|
704
761
|
method="DELETE",
|
705
762
|
)
|
706
763
|
self._resolve_server_response(response)
|
@@ -714,6 +771,7 @@ class Coop(CoopFunctionsMixin):
|
|
714
771
|
visibility: Optional[VisibilityType] = "unlisted",
|
715
772
|
initial_results_visibility: Optional[VisibilityType] = "unlisted",
|
716
773
|
iterations: Optional[int] = 1,
|
774
|
+
fresh: Optional[bool] = False,
|
717
775
|
) -> RemoteInferenceCreationInfo:
|
718
776
|
"""
|
719
777
|
Send a remote inference job to the server.
|
@@ -742,6 +800,7 @@ class Coop(CoopFunctionsMixin):
|
|
742
800
|
"visibility": visibility,
|
743
801
|
"version": self._edsl_version,
|
744
802
|
"initial_results_visibility": initial_results_visibility,
|
803
|
+
"fresh": fresh,
|
745
804
|
},
|
746
805
|
)
|
747
806
|
self._resolve_server_response(response)
|
@@ -1037,19 +1096,21 @@ class Coop(CoopFunctionsMixin):
|
|
1037
1096
|
if console.is_terminal:
|
1038
1097
|
# Running in a standard terminal, show the full URL
|
1039
1098
|
if link_description:
|
1040
|
-
rich_print(
|
1099
|
+
rich_print(
|
1100
|
+
"{link_description}\n[#38bdf8][link={url}]{url}[/link][/#38bdf8]"
|
1101
|
+
)
|
1041
1102
|
else:
|
1042
1103
|
rich_print(f"[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
|
1043
1104
|
else:
|
1044
1105
|
# Running in an interactive environment (e.g., Jupyter Notebook), hide the URL
|
1045
1106
|
if link_description:
|
1046
|
-
rich_print(
|
1107
|
+
rich_print(
|
1108
|
+
f"{link_description}\n[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]"
|
1109
|
+
)
|
1047
1110
|
else:
|
1048
|
-
rich_print(
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
|
1111
|
+
rich_print(
|
1112
|
+
f"[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]"
|
1113
|
+
)
|
1053
1114
|
|
1054
1115
|
def _get_api_key(self, edsl_auth_token: str):
|
1055
1116
|
"""
|
@@ -1204,24 +1265,24 @@ def main():
|
|
1204
1265
|
# C. Remote Cache
|
1205
1266
|
##############
|
1206
1267
|
# clear
|
1207
|
-
coop.
|
1208
|
-
assert coop.
|
1268
|
+
coop.legacy_remote_cache_clear()
|
1269
|
+
assert coop.legacy_remote_cache_get() == []
|
1209
1270
|
# create one remote cache entry
|
1210
1271
|
cache_entry = CacheEntry.example()
|
1211
1272
|
cache_entry.to_dict()
|
1212
|
-
coop.remote_cache_create(cache_entry)
|
1273
|
+
# coop.remote_cache_create(cache_entry)
|
1213
1274
|
# create many remote cache entries
|
1214
1275
|
cache_entries = [CacheEntry.example(randomize=True) for _ in range(10)]
|
1215
|
-
coop.remote_cache_create_many(cache_entries)
|
1276
|
+
# coop.remote_cache_create_many(cache_entries)
|
1216
1277
|
# get all remote cache entries
|
1217
|
-
coop.
|
1218
|
-
coop.
|
1219
|
-
coop.
|
1278
|
+
coop.legacy_remote_cache_get()
|
1279
|
+
coop.legacy_remote_cache_get(exclude_keys=[])
|
1280
|
+
coop.legacy_remote_cache_get(exclude_keys=["a"])
|
1220
1281
|
exclude_keys = [cache_entry.key for cache_entry in cache_entries]
|
1221
|
-
coop.
|
1282
|
+
coop.legacy_remote_cache_get(exclude_keys)
|
1222
1283
|
# clear
|
1223
|
-
coop.
|
1224
|
-
coop.
|
1284
|
+
coop.legacy_remote_cache_clear()
|
1285
|
+
coop.legacy_remote_cache_get()
|
1225
1286
|
|
1226
1287
|
##############
|
1227
1288
|
# D. Remote Inference
|
edsl/data/RemoteCacheSync.py
CHANGED
@@ -100,7 +100,7 @@ class RemoteCacheSync(AbstractContextManager):
|
|
100
100
|
|
101
101
|
def _get_cache_difference(self) -> CacheDifference:
|
102
102
|
"""Retrieves differences between local and remote caches."""
|
103
|
-
diff = self.coop.
|
103
|
+
diff = self.coop.legacy_remote_cache_get_diff(self.cache.keys())
|
104
104
|
return CacheDifference(
|
105
105
|
client_missing_entries=diff.get("client_missing_cacheentries", []),
|
106
106
|
server_missing_keys=diff.get("server_missing_cacheentry_keys", []),
|
@@ -112,7 +112,7 @@ class RemoteCacheSync(AbstractContextManager):
|
|
112
112
|
missing_count = len(diff.client_missing_entries)
|
113
113
|
|
114
114
|
if missing_count == 0:
|
115
|
-
|
115
|
+
# self._output("No new entries to add to local cache.")
|
116
116
|
return
|
117
117
|
|
118
118
|
# self._output(
|
@@ -154,22 +154,23 @@ class RemoteCacheSync(AbstractContextManager):
|
|
154
154
|
upload_count = len(entries_to_upload)
|
155
155
|
|
156
156
|
if upload_count > 0:
|
157
|
+
pass
|
157
158
|
# self._output(
|
158
159
|
# f"Updating remote cache with {upload_count:,} new "
|
159
160
|
# f"{'entry' if upload_count == 1 else 'entries'}..."
|
160
161
|
# )
|
161
162
|
|
162
|
-
self.coop.remote_cache_create_many(
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
)
|
163
|
+
# self.coop.remote_cache_create_many(
|
164
|
+
# entries_to_upload,
|
165
|
+
# visibility="private",
|
166
|
+
# description=self.remote_cache_description,
|
167
|
+
# )
|
167
168
|
# self._output("Remote cache updated!")
|
168
169
|
# else:
|
169
|
-
|
170
|
+
# self._output("No new entries to add to remote cache.")
|
170
171
|
|
171
172
|
# self._output(
|
172
|
-
|
173
|
+
# f"There are {len(self.cache.keys()):,} entries in the local cache."
|
173
174
|
# )
|
174
175
|
|
175
176
|
|
@@ -69,7 +69,7 @@ class AvailableModelFetcher:
|
|
69
69
|
|
70
70
|
Returns a list of [model, service_name, index] entries.
|
71
71
|
"""
|
72
|
-
if service == "azure":
|
72
|
+
if service == "azure" or service == "bedrock":
|
73
73
|
force_refresh = True # Azure models are listed inside the .env AZURE_ENDPOINT_URL_AND_KEY variable
|
74
74
|
|
75
75
|
if service: # they passed a specific service
|
@@ -66,10 +66,14 @@ class SkipHandler:
|
|
66
66
|
)
|
67
67
|
)
|
68
68
|
|
69
|
+
|
69
70
|
def cancel_between(start, end):
|
70
71
|
"""Cancel the tasks for questions between the start and end indices."""
|
71
72
|
for i in range(start, end):
|
72
|
-
|
73
|
+
#print(f"Cancelling task {i}")
|
74
|
+
#self.interview.tasks[i].cancel()
|
75
|
+
#self.interview.tasks[i].set_result("skipped")
|
76
|
+
self.interview.skip_flags[self.interview.survey.questions[i].question_name] = True
|
73
77
|
|
74
78
|
if (next_question_index := next_question.next_q) == EndOfSurvey:
|
75
79
|
cancel_between(
|
@@ -80,6 +84,8 @@ class SkipHandler:
|
|
80
84
|
if next_question_index > (current_question_index + 1):
|
81
85
|
cancel_between(current_question_index + 1, next_question_index)
|
82
86
|
|
87
|
+
|
88
|
+
|
83
89
|
|
84
90
|
class AnswerQuestionFunctionConstructor:
|
85
91
|
"""Constructs a function that answers a question and records the answer."""
|
@@ -161,6 +167,11 @@ class AnswerQuestionFunctionConstructor:
|
|
161
167
|
async def attempt_answer():
|
162
168
|
invigilator = self.invigilator_fetcher(question)
|
163
169
|
|
170
|
+
if self.interview.skip_flags.get(question.question_name, False):
|
171
|
+
return invigilator.get_failed_task_result(
|
172
|
+
failure_reason="Question skipped."
|
173
|
+
)
|
174
|
+
|
164
175
|
if self.skip_handler.should_skip(question):
|
165
176
|
return invigilator.get_failed_task_result(
|
166
177
|
failure_reason="Question skipped."
|
edsl/jobs/Jobs.py
CHANGED
@@ -277,7 +277,7 @@ class Jobs(Base):
|
|
277
277
|
|
278
278
|
return JobsComponentConstructor(self).by(*args)
|
279
279
|
|
280
|
-
def prompts(self) -> "Dataset":
|
280
|
+
def prompts(self, iterations=1) -> "Dataset":
|
281
281
|
"""Return a Dataset of prompts that will be used.
|
282
282
|
|
283
283
|
|
@@ -285,7 +285,7 @@ class Jobs(Base):
|
|
285
285
|
>>> Jobs.example().prompts()
|
286
286
|
Dataset(...)
|
287
287
|
"""
|
288
|
-
return JobsPrompts(self).prompts()
|
288
|
+
return JobsPrompts(self).prompts(iterations=iterations)
|
289
289
|
|
290
290
|
def show_prompts(self, all: bool = False) -> None:
|
291
291
|
"""Print the prompts."""
|
@@ -418,11 +418,9 @@ class Jobs(Base):
|
|
418
418
|
BucketCollection(...)
|
419
419
|
"""
|
420
420
|
bc = BucketCollection.from_models(self.models)
|
421
|
-
|
421
|
+
|
422
422
|
if self.run_config.environment.key_lookup is not None:
|
423
|
-
bc.update_from_key_lookup(
|
424
|
-
self.run_config.environment.key_lookup
|
425
|
-
)
|
423
|
+
bc.update_from_key_lookup(self.run_config.environment.key_lookup)
|
426
424
|
return bc
|
427
425
|
|
428
426
|
def html(self):
|
@@ -484,25 +482,24 @@ class Jobs(Base):
|
|
484
482
|
def _start_remote_inference_job(
|
485
483
|
self, job_handler: Optional[JobsRemoteInferenceHandler] = None
|
486
484
|
) -> Union["Results", None]:
|
487
|
-
|
488
485
|
if job_handler is None:
|
489
486
|
job_handler = self._create_remote_inference_handler()
|
490
|
-
|
487
|
+
|
491
488
|
job_info = job_handler.create_remote_inference_job(
|
492
|
-
|
493
|
-
|
494
|
-
|
489
|
+
iterations=self.run_config.parameters.n,
|
490
|
+
remote_inference_description=self.run_config.parameters.remote_inference_description,
|
491
|
+
remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
|
492
|
+
fresh=self.run_config.parameters.fresh,
|
495
493
|
)
|
496
494
|
return job_info
|
497
|
-
|
498
|
-
def _create_remote_inference_handler(self) -> JobsRemoteInferenceHandler:
|
499
495
|
|
496
|
+
def _create_remote_inference_handler(self) -> JobsRemoteInferenceHandler:
|
500
497
|
from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
|
501
|
-
|
498
|
+
|
502
499
|
return JobsRemoteInferenceHandler(
|
503
500
|
self, verbose=self.run_config.parameters.verbose
|
504
501
|
)
|
505
|
-
|
502
|
+
|
506
503
|
def _remote_results(
|
507
504
|
self,
|
508
505
|
config: RunConfig,
|
@@ -516,7 +513,8 @@ class Jobs(Base):
|
|
516
513
|
if jh.use_remote_inference(self.run_config.parameters.disable_remote_inference):
|
517
514
|
job_info: RemoteJobInfo = self._start_remote_inference_job(jh)
|
518
515
|
if background:
|
519
|
-
from edsl.results.Results import Results
|
516
|
+
from edsl.results.Results import Results
|
517
|
+
|
520
518
|
results = Results.from_job_info(job_info)
|
521
519
|
return results
|
522
520
|
else:
|
@@ -603,7 +601,7 @@ class Jobs(Base):
|
|
603
601
|
# first try to run the job remotely
|
604
602
|
if (results := self._remote_results(config)) is not None:
|
605
603
|
return results
|
606
|
-
|
604
|
+
|
607
605
|
self._check_if_local_keys_ok()
|
608
606
|
|
609
607
|
if config.environment.bucket_collection is None:
|