edsl 0.1.53__py3-none-any.whl → 0.1.55__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +8 -1
- edsl/__init__original.py +134 -0
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +29 -0
- edsl/agents/agent_list.py +36 -1
- edsl/base/base_class.py +281 -151
- edsl/buckets/__init__.py +8 -3
- edsl/buckets/bucket_collection.py +9 -3
- edsl/buckets/model_buckets.py +4 -2
- edsl/buckets/token_bucket.py +2 -2
- edsl/buckets/token_bucket_client.py +5 -3
- edsl/caching/cache.py +131 -62
- edsl/caching/cache_entry.py +70 -58
- edsl/caching/sql_dict.py +17 -0
- edsl/cli.py +99 -0
- edsl/config/config_class.py +16 -0
- edsl/conversation/__init__.py +31 -0
- edsl/coop/coop.py +276 -242
- edsl/coop/coop_jobs_objects.py +59 -0
- edsl/coop/coop_objects.py +29 -0
- edsl/coop/coop_regular_objects.py +26 -0
- edsl/coop/utils.py +24 -19
- edsl/dataset/dataset.py +338 -101
- edsl/db_list/sqlite_list.py +349 -0
- edsl/inference_services/__init__.py +40 -5
- edsl/inference_services/exceptions.py +11 -0
- edsl/inference_services/services/anthropic_service.py +5 -2
- edsl/inference_services/services/aws_bedrock.py +6 -2
- edsl/inference_services/services/azure_ai.py +6 -2
- edsl/inference_services/services/google_service.py +3 -2
- edsl/inference_services/services/mistral_ai_service.py +6 -2
- edsl/inference_services/services/open_ai_service.py +6 -2
- edsl/inference_services/services/perplexity_service.py +6 -2
- edsl/inference_services/services/test_service.py +105 -7
- edsl/interviews/answering_function.py +167 -59
- edsl/interviews/interview.py +124 -72
- edsl/interviews/interview_task_manager.py +10 -0
- edsl/invigilators/invigilators.py +10 -1
- edsl/jobs/async_interview_runner.py +146 -104
- edsl/jobs/data_structures.py +6 -4
- edsl/jobs/decorators.py +61 -0
- edsl/jobs/fetch_invigilator.py +61 -18
- edsl/jobs/html_table_job_logger.py +14 -2
- edsl/jobs/jobs.py +180 -104
- edsl/jobs/jobs_component_constructor.py +2 -2
- edsl/jobs/jobs_interview_constructor.py +2 -0
- edsl/jobs/jobs_pricing_estimation.py +127 -46
- edsl/jobs/jobs_remote_inference_logger.py +4 -0
- edsl/jobs/jobs_runner_status.py +30 -25
- edsl/jobs/progress_bar_manager.py +79 -0
- edsl/jobs/remote_inference.py +35 -1
- edsl/key_management/key_lookup_builder.py +6 -1
- edsl/language_models/language_model.py +102 -12
- edsl/language_models/model.py +10 -3
- edsl/language_models/price_manager.py +45 -75
- edsl/language_models/registry.py +5 -0
- edsl/language_models/utilities.py +2 -1
- edsl/notebooks/notebook.py +77 -10
- edsl/questions/VALIDATION_README.md +134 -0
- edsl/questions/__init__.py +24 -1
- edsl/questions/exceptions.py +21 -0
- edsl/questions/question_check_box.py +171 -149
- edsl/questions/question_dict.py +243 -51
- edsl/questions/question_multiple_choice_with_other.py +624 -0
- edsl/questions/question_registry.py +2 -1
- edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
- edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
- edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
- edsl/questions/validation_analysis.py +185 -0
- edsl/questions/validation_cli.py +131 -0
- edsl/questions/validation_html_report.py +404 -0
- edsl/questions/validation_logger.py +136 -0
- edsl/results/result.py +63 -16
- edsl/results/results.py +702 -171
- edsl/scenarios/construct_download_link.py +16 -3
- edsl/scenarios/directory_scanner.py +226 -226
- edsl/scenarios/file_methods.py +5 -0
- edsl/scenarios/file_store.py +117 -6
- edsl/scenarios/handlers/__init__.py +5 -1
- edsl/scenarios/handlers/mp4_file_store.py +104 -0
- edsl/scenarios/handlers/webm_file_store.py +104 -0
- edsl/scenarios/scenario.py +120 -101
- edsl/scenarios/scenario_list.py +800 -727
- edsl/scenarios/scenario_list_gc_test.py +146 -0
- edsl/scenarios/scenario_list_memory_test.py +214 -0
- edsl/scenarios/scenario_list_source_refactor.md +35 -0
- edsl/scenarios/scenario_selector.py +5 -4
- edsl/scenarios/scenario_source.py +1990 -0
- edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
- edsl/surveys/survey.py +22 -0
- edsl/tasks/__init__.py +4 -2
- edsl/tasks/task_history.py +198 -36
- edsl/tests/scenarios/test_ScenarioSource.py +51 -0
- edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
- edsl/utilities/__init__.py +2 -1
- edsl/utilities/decorators.py +121 -0
- edsl/utilities/memory_debugger.py +1010 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/METADATA +52 -76
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/RECORD +102 -78
- edsl/jobs/jobs_runner_asyncio.py +0 -281
- edsl/language_models/unused/fake_openai_service.py +0 -60
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/LICENSE +0 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/WHEEL +0 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/entry_points.txt +0 -0
edsl/coop/coop.py
CHANGED
@@ -3,7 +3,7 @@ import base64
|
|
3
3
|
import json
|
4
4
|
import requests
|
5
5
|
|
6
|
-
from typing import Any, Optional, Union, Literal, TypedDict, TYPE_CHECKING
|
6
|
+
from typing import Any, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
|
7
7
|
from uuid import UUID
|
8
8
|
|
9
9
|
from .. import __version__
|
@@ -19,6 +19,7 @@ from .exceptions import (
|
|
19
19
|
CoopInvalidURLError,
|
20
20
|
CoopNoUUIDError,
|
21
21
|
CoopServerResponseError,
|
22
|
+
CoopValueError,
|
22
23
|
)
|
23
24
|
from .utils import (
|
24
25
|
EDSLObject,
|
@@ -29,6 +30,8 @@ from .utils import (
|
|
29
30
|
)
|
30
31
|
|
31
32
|
from .coop_functions import CoopFunctionsMixin
|
33
|
+
from .coop_regular_objects import CoopRegularObjects
|
34
|
+
from .coop_jobs_objects import CoopJobsObjects
|
32
35
|
from .ep_key_handling import ExpectedParrotKeyHandler
|
33
36
|
|
34
37
|
from ..inference_services.data_structures import ServiceToModelsMapping
|
@@ -44,6 +47,7 @@ class RemoteInferenceResponse(TypedDict):
|
|
44
47
|
reason: str
|
45
48
|
credits_consumed: float
|
46
49
|
version: str
|
50
|
+
job_json_string: Optional[str]
|
47
51
|
|
48
52
|
|
49
53
|
class RemoteInferenceCreationInfo(TypedDict):
|
@@ -243,7 +247,7 @@ class Coop(CoopFunctionsMixin):
|
|
243
247
|
|
244
248
|
if response.status_code >= 400:
|
245
249
|
try:
|
246
|
-
message = response.json().get("detail")
|
250
|
+
message = str(response.json().get("detail"))
|
247
251
|
except json.JSONDecodeError:
|
248
252
|
raise CoopServerResponseError(
|
249
253
|
f"Server returned status code {response.status_code}."
|
@@ -499,6 +503,10 @@ class Coop(CoopFunctionsMixin):
|
|
499
503
|
"""
|
500
504
|
object_type = ObjectRegistry.get_object_type_by_edsl_class(object)
|
501
505
|
object_dict = object.to_dict()
|
506
|
+
|
507
|
+
# Get the object hash
|
508
|
+
object_hash = object.get_hash() if hasattr(object, "get_hash") else None
|
509
|
+
|
502
510
|
if object_type == "scenario" and self._scenario_is_file_store(object_dict):
|
503
511
|
file_store_metadata = {
|
504
512
|
"suffix": object_dict["suffix"],
|
@@ -524,6 +532,7 @@ class Coop(CoopFunctionsMixin):
|
|
524
532
|
"file_store_metadata": file_store_metadata,
|
525
533
|
"visibility": visibility,
|
526
534
|
"version": self._edsl_version,
|
535
|
+
"object_hash": object_hash, # Include the object hash in the payload
|
527
536
|
},
|
528
537
|
)
|
529
538
|
self._resolve_server_response(response)
|
@@ -670,42 +679,138 @@ class Coop(CoopFunctionsMixin):
|
|
670
679
|
object.initialize_cache_from_results()
|
671
680
|
return object
|
672
681
|
|
673
|
-
def
|
682
|
+
def _validate_object_types(
|
683
|
+
self, object_type: Union[ObjectType, List[ObjectType]]
|
684
|
+
) -> List[ObjectType]:
|
674
685
|
"""
|
675
|
-
|
686
|
+
Validate object types and return a list of valid types.
|
687
|
+
|
688
|
+
Args:
|
689
|
+
object_type: Single object type or list of object types to validate
|
690
|
+
|
691
|
+
Returns:
|
692
|
+
List of validated object types
|
693
|
+
|
694
|
+
Raises:
|
695
|
+
CoopValueError: If any object type is invalid
|
696
|
+
"""
|
697
|
+
valid_object_types = ObjectRegistry.object_type_to_edsl_class.keys()
|
698
|
+
if isinstance(object_type, list):
|
699
|
+
invalid_types = [t for t in object_type if t not in valid_object_types]
|
700
|
+
if invalid_types:
|
701
|
+
raise CoopValueError(
|
702
|
+
f"Invalid object type(s): {invalid_types}. "
|
703
|
+
f"Valid types are: {list(valid_object_types)}"
|
704
|
+
)
|
705
|
+
return object_type
|
706
|
+
else:
|
707
|
+
if object_type not in valid_object_types:
|
708
|
+
raise CoopValueError(
|
709
|
+
f"Invalid object type: {object_type}. "
|
710
|
+
f"Valid types are: {list(valid_object_types)}"
|
711
|
+
)
|
712
|
+
return [object_type]
|
713
|
+
|
714
|
+
def _validate_visibility_types(
|
715
|
+
self, visibility: Union[VisibilityType, List[VisibilityType]]
|
716
|
+
) -> List[VisibilityType]:
|
676
717
|
"""
|
677
|
-
|
718
|
+
Validate visibility types and return a list of valid types.
|
719
|
+
|
720
|
+
Args:
|
721
|
+
visibility: Single visibility type or list of visibility types to validate
|
722
|
+
|
723
|
+
Returns:
|
724
|
+
List of validated visibility types
|
725
|
+
|
726
|
+
Raises:
|
727
|
+
CoopValueError: If any visibility type is invalid
|
728
|
+
"""
|
729
|
+
valid_visibility_types = ["private", "public", "unlisted"]
|
730
|
+
if isinstance(visibility, list):
|
731
|
+
invalid_visibilities = [
|
732
|
+
v for v in visibility if v not in valid_visibility_types
|
733
|
+
]
|
734
|
+
if invalid_visibilities:
|
735
|
+
raise CoopValueError(
|
736
|
+
f"Invalid visibility type(s): {invalid_visibilities}. "
|
737
|
+
f"Valid types are: {valid_visibility_types}"
|
738
|
+
)
|
739
|
+
return visibility
|
740
|
+
else:
|
741
|
+
if visibility not in valid_visibility_types:
|
742
|
+
raise CoopValueError(
|
743
|
+
f"Invalid visibility type: {visibility}. "
|
744
|
+
f"Valid types are: {valid_visibility_types}"
|
745
|
+
)
|
746
|
+
return [visibility]
|
747
|
+
|
748
|
+
def list(
|
749
|
+
self,
|
750
|
+
object_type: Union[ObjectType, List[ObjectType], None] = None,
|
751
|
+
visibility: Union[VisibilityType, List[VisibilityType], None] = None,
|
752
|
+
search_query: Union[str, None] = None,
|
753
|
+
page: int = 1,
|
754
|
+
page_size: int = 10,
|
755
|
+
sort_ascending: bool = False,
|
756
|
+
) -> "CoopRegularObjects":
|
757
|
+
"""
|
758
|
+
Retrieve objects either owned by the user or shared with them.
|
759
|
+
|
760
|
+
Notes:
|
761
|
+
- search_query only works with the description field.
|
762
|
+
- If sort_ascending is False, then the most recently created objects are returned first.
|
763
|
+
"""
|
764
|
+
from ..scenarios import Scenario
|
765
|
+
|
766
|
+
if page < 1:
|
767
|
+
raise CoopValueError("The page must be greater than or equal to 1.")
|
768
|
+
if page_size < 1:
|
769
|
+
raise CoopValueError("The page size must be greater than or equal to 1.")
|
770
|
+
if page_size > 100:
|
771
|
+
raise CoopValueError("The page size must be less than or equal to 100.")
|
772
|
+
|
773
|
+
params = {
|
774
|
+
"page": page,
|
775
|
+
"page_size": page_size,
|
776
|
+
"sort_ascending": sort_ascending,
|
777
|
+
}
|
778
|
+
if object_type:
|
779
|
+
params["type"] = self._validate_object_types(object_type)
|
780
|
+
if visibility:
|
781
|
+
params["visibility"] = self._validate_visibility_types(visibility)
|
782
|
+
if search_query:
|
783
|
+
params["search_query"] = search_query
|
784
|
+
|
678
785
|
response = self._send_server_request(
|
679
|
-
uri="api/v0/
|
786
|
+
uri="api/v0/object/list",
|
680
787
|
method="GET",
|
681
|
-
params=
|
788
|
+
params=params,
|
682
789
|
)
|
683
790
|
self._resolve_server_response(response)
|
791
|
+
content = response.json()
|
684
792
|
objects = []
|
685
|
-
for o in
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
o.get("owner_username"), o.get("alias")
|
704
|
-
),
|
705
|
-
}
|
793
|
+
for o in content:
|
794
|
+
object = Scenario(
|
795
|
+
{
|
796
|
+
"uuid": o.get("uuid"),
|
797
|
+
"object_type": o.get("object_type"),
|
798
|
+
"alias": o.get("alias"),
|
799
|
+
"owner_username": o.get("owner_username"),
|
800
|
+
"description": o.get("description"),
|
801
|
+
"visibility": o.get("visibility"),
|
802
|
+
"version": o.get("version"),
|
803
|
+
"url": f"{self.url}/content/{o.get('uuid')}",
|
804
|
+
"alias_url": self._get_alias_url(
|
805
|
+
o.get("owner_username"), o.get("alias")
|
806
|
+
),
|
807
|
+
"last_updated_ts": o.get("last_updated_ts"),
|
808
|
+
"created_ts": o.get("created_ts"),
|
809
|
+
}
|
810
|
+
)
|
706
811
|
objects.append(object)
|
707
812
|
|
708
|
-
return objects
|
813
|
+
return CoopRegularObjects(objects)
|
709
814
|
|
710
815
|
def delete(self, url_or_uuid: Union[str, UUID]) -> dict:
|
711
816
|
"""
|
@@ -793,93 +898,10 @@ class Coop(CoopFunctionsMixin):
|
|
793
898
|
################
|
794
899
|
# Remote Cache
|
795
900
|
################
|
796
|
-
# def remote_cache_create(
|
797
|
-
# self,
|
798
|
-
# cache_entry: CacheEntry,
|
799
|
-
# visibility: VisibilityType = "private",
|
800
|
-
# description: Optional[str] = None,
|
801
|
-
# ) -> dict:
|
802
|
-
# """
|
803
|
-
# Create a single remote cache entry.
|
804
|
-
# If an entry with the same key already exists in the database, update it instead.
|
805
|
-
|
806
|
-
# :param cache_entry: The cache entry to send to the server.
|
807
|
-
# :param visibility: The visibility of the cache entry.
|
808
|
-
# :param optional description: A description for this entry in the remote cache.
|
809
|
-
|
810
|
-
# >>> entry = CacheEntry.example()
|
811
|
-
# >>> coop.remote_cache_create(cache_entry=entry)
|
812
|
-
# {'status': 'success', 'created_entry_count': 1, 'updated_entry_count': 0}
|
813
|
-
# """
|
814
|
-
# response = self._send_server_request(
|
815
|
-
# uri="api/v0/remote-cache",
|
816
|
-
# method="POST",
|
817
|
-
# payload={
|
818
|
-
# "json_string": json.dumps(cache_entry.to_dict()),
|
819
|
-
# "version": self._edsl_version,
|
820
|
-
# "visibility": visibility,
|
821
|
-
# "description": description,
|
822
|
-
# },
|
823
|
-
# )
|
824
|
-
# self._resolve_server_response(response)
|
825
|
-
# response_json = response.json()
|
826
|
-
# created_entry_count = response_json.get("created_entry_count", 0)
|
827
|
-
# if created_entry_count > 0:
|
828
|
-
# self.remote_cache_create_log(
|
829
|
-
# response,
|
830
|
-
# description="Upload new cache entries to server",
|
831
|
-
# cache_entry_count=created_entry_count,
|
832
|
-
# )
|
833
|
-
# return response.json()
|
834
|
-
|
835
|
-
# def remote_cache_create_many(
|
836
|
-
# self,
|
837
|
-
# cache_entries: list[CacheEntry],
|
838
|
-
# visibility: VisibilityType = "private",
|
839
|
-
# description: Optional[str] = None,
|
840
|
-
# ) -> dict:
|
841
|
-
# """
|
842
|
-
# Create many remote cache entries.
|
843
|
-
# If an entry with the same key already exists in the database, update it instead.
|
844
|
-
|
845
|
-
# :param cache_entries: The list of cache entries to send to the server.
|
846
|
-
# :param visibility: The visibility of the cache entries.
|
847
|
-
# :param optional description: A description for these entries in the remote cache.
|
848
|
-
|
849
|
-
# >>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
|
850
|
-
# >>> coop.remote_cache_create_many(cache_entries=entries)
|
851
|
-
# {'status': 'success', 'created_entry_count': 10, 'updated_entry_count': 0}
|
852
|
-
# """
|
853
|
-
# payload = [
|
854
|
-
# {
|
855
|
-
# "json_string": json.dumps(c.to_dict()),
|
856
|
-
# "version": self._edsl_version,
|
857
|
-
# "visibility": visibility,
|
858
|
-
# "description": description,
|
859
|
-
# }
|
860
|
-
# for c in cache_entries
|
861
|
-
# ]
|
862
|
-
# response = self._send_server_request(
|
863
|
-
# uri="api/v0/remote-cache/many",
|
864
|
-
# method="POST",
|
865
|
-
# payload=payload,
|
866
|
-
# timeout=40,
|
867
|
-
# )
|
868
|
-
# self._resolve_server_response(response)
|
869
|
-
# response_json = response.json()
|
870
|
-
# created_entry_count = response_json.get("created_entry_count", 0)
|
871
|
-
# if created_entry_count > 0:
|
872
|
-
# self.remote_cache_create_log(
|
873
|
-
# response,
|
874
|
-
# description="Upload new cache entries to server",
|
875
|
-
# cache_entry_count=created_entry_count,
|
876
|
-
# )
|
877
|
-
# return response.json()
|
878
|
-
|
879
901
|
def remote_cache_get(
|
880
902
|
self,
|
881
903
|
job_uuid: Optional[Union[str, UUID]] = None,
|
882
|
-
) ->
|
904
|
+
) -> List[CacheEntry]:
|
883
905
|
"""
|
884
906
|
Get all remote cache entries.
|
885
907
|
|
@@ -908,8 +930,8 @@ class Coop(CoopFunctionsMixin):
|
|
908
930
|
|
909
931
|
def remote_cache_get_by_key(
|
910
932
|
self,
|
911
|
-
select_keys: Optional[
|
912
|
-
) ->
|
933
|
+
select_keys: Optional[List[str]] = None,
|
934
|
+
) -> List[CacheEntry]:
|
913
935
|
"""
|
914
936
|
Get all remote cache entries.
|
915
937
|
|
@@ -936,126 +958,6 @@ class Coop(CoopFunctionsMixin):
|
|
936
958
|
for v in response.json()
|
937
959
|
]
|
938
960
|
|
939
|
-
def legacy_remote_cache_get(
|
940
|
-
self,
|
941
|
-
exclude_keys: Optional[list[str]] = None,
|
942
|
-
select_keys: Optional[list[str]] = None,
|
943
|
-
) -> list[CacheEntry]:
|
944
|
-
"""
|
945
|
-
Get all remote cache entries.
|
946
|
-
|
947
|
-
:param optional select_keys: Only return CacheEntry objects with these keys.
|
948
|
-
:param optional exclude_keys: Exclude CacheEntry objects with these keys.
|
949
|
-
|
950
|
-
>>> coop.legacy_remote_cache_get()
|
951
|
-
[CacheEntry(...), CacheEntry(...), ...]
|
952
|
-
"""
|
953
|
-
if exclude_keys is None:
|
954
|
-
exclude_keys = []
|
955
|
-
if select_keys is None:
|
956
|
-
select_keys = []
|
957
|
-
response = self._send_server_request(
|
958
|
-
uri="api/v0/remote-cache/legacy/get-many",
|
959
|
-
method="POST",
|
960
|
-
payload={"exclude_keys": exclude_keys, "selected_keys": select_keys},
|
961
|
-
timeout=40,
|
962
|
-
)
|
963
|
-
self._resolve_server_response(response)
|
964
|
-
return [
|
965
|
-
CacheEntry.from_dict(json.loads(v.get("json_string")))
|
966
|
-
for v in response.json()
|
967
|
-
]
|
968
|
-
|
969
|
-
def legacy_remote_cache_get_diff(
|
970
|
-
self,
|
971
|
-
client_cacheentry_keys: list[str],
|
972
|
-
) -> dict:
|
973
|
-
"""
|
974
|
-
Get the difference between local and remote cache entries for a user.
|
975
|
-
"""
|
976
|
-
response = self._send_server_request(
|
977
|
-
uri="api/v0/remote-cache/legacy/get-diff",
|
978
|
-
method="POST",
|
979
|
-
payload={"keys": client_cacheentry_keys},
|
980
|
-
timeout=40,
|
981
|
-
)
|
982
|
-
self._resolve_server_response(response)
|
983
|
-
response_json = response.json()
|
984
|
-
response_dict = {
|
985
|
-
"client_missing_cacheentries": [
|
986
|
-
CacheEntry.from_dict(json.loads(c.get("json_string")))
|
987
|
-
for c in response_json.get("client_missing_cacheentries", [])
|
988
|
-
],
|
989
|
-
"server_missing_cacheentry_keys": response_json.get(
|
990
|
-
"server_missing_cacheentry_keys", []
|
991
|
-
),
|
992
|
-
}
|
993
|
-
downloaded_entry_count = len(response_dict["client_missing_cacheentries"])
|
994
|
-
if downloaded_entry_count > 0:
|
995
|
-
self.legacy_remote_cache_create_log(
|
996
|
-
response,
|
997
|
-
description="Download missing cache entries to client",
|
998
|
-
cache_entry_count=downloaded_entry_count,
|
999
|
-
)
|
1000
|
-
return response_dict
|
1001
|
-
|
1002
|
-
def legacy_remote_cache_clear(self) -> dict:
|
1003
|
-
"""
|
1004
|
-
Clear all remote cache entries.
|
1005
|
-
|
1006
|
-
>>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
|
1007
|
-
>>> coop.legacy_remote_cache_create_many(cache_entries=entries)
|
1008
|
-
>>> coop.legacy_remote_cache_clear()
|
1009
|
-
{'status': 'success', 'deleted_entry_count': 10}
|
1010
|
-
"""
|
1011
|
-
response = self._send_server_request(
|
1012
|
-
uri="api/v0/remote-cache/legacy/delete-all",
|
1013
|
-
method="DELETE",
|
1014
|
-
)
|
1015
|
-
self._resolve_server_response(response)
|
1016
|
-
response_json = response.json()
|
1017
|
-
deleted_entry_count = response_json.get("deleted_entry_count", 0)
|
1018
|
-
if deleted_entry_count > 0:
|
1019
|
-
self.legacy_remote_cache_create_log(
|
1020
|
-
response,
|
1021
|
-
description="Clear cache entries",
|
1022
|
-
cache_entry_count=deleted_entry_count,
|
1023
|
-
)
|
1024
|
-
return response.json()
|
1025
|
-
|
1026
|
-
def legacy_remote_cache_create_log(
|
1027
|
-
self, response: requests.Response, description: str, cache_entry_count: int
|
1028
|
-
) -> Union[dict, None]:
|
1029
|
-
"""
|
1030
|
-
If a remote cache action has been completed successfully,
|
1031
|
-
log the action.
|
1032
|
-
"""
|
1033
|
-
if 200 <= response.status_code < 300:
|
1034
|
-
log_response = self._send_server_request(
|
1035
|
-
uri="api/v0/remote-cache-log/legacy",
|
1036
|
-
method="POST",
|
1037
|
-
payload={
|
1038
|
-
"description": description,
|
1039
|
-
"cache_entry_count": cache_entry_count,
|
1040
|
-
},
|
1041
|
-
)
|
1042
|
-
self._resolve_server_response(log_response)
|
1043
|
-
return response.json()
|
1044
|
-
|
1045
|
-
def legacy_remote_cache_clear_log(self) -> dict:
|
1046
|
-
"""
|
1047
|
-
Clear all remote cache log entries.
|
1048
|
-
|
1049
|
-
>>> coop.legacy_remote_cache_clear_log()
|
1050
|
-
{'status': 'success'}
|
1051
|
-
"""
|
1052
|
-
response = self._send_server_request(
|
1053
|
-
uri="api/v0/remote-cache-log/legacy/delete-all",
|
1054
|
-
method="DELETE",
|
1055
|
-
)
|
1056
|
-
self._resolve_server_response(response)
|
1057
|
-
return response.json()
|
1058
|
-
|
1059
961
|
def remote_inference_create(
|
1060
962
|
self,
|
1061
963
|
job: "Jobs",
|
@@ -1142,7 +1044,10 @@ class Coop(CoopFunctionsMixin):
|
|
1142
1044
|
)
|
1143
1045
|
|
1144
1046
|
def remote_inference_get(
|
1145
|
-
self,
|
1047
|
+
self,
|
1048
|
+
job_uuid: Optional[str] = None,
|
1049
|
+
results_uuid: Optional[str] = None,
|
1050
|
+
include_json_string: Optional[bool] = False,
|
1146
1051
|
) -> RemoteInferenceResponse:
|
1147
1052
|
"""
|
1148
1053
|
Get the status and details of a remote inference job.
|
@@ -1154,6 +1059,7 @@ class Coop(CoopFunctionsMixin):
|
|
1154
1059
|
job_uuid (str, optional): The UUID of the remote job to check
|
1155
1060
|
results_uuid (str, optional): The UUID of the results associated with the job
|
1156
1061
|
(can be used if you only have the results UUID)
|
1062
|
+
include_json_string (bool, optional): If True, include the json string for the job in the response
|
1157
1063
|
|
1158
1064
|
Returns:
|
1159
1065
|
RemoteInferenceResponse: Information about the job including:
|
@@ -1166,6 +1072,7 @@ class Coop(CoopFunctionsMixin):
|
|
1166
1072
|
- reason: Reason for failure (if applicable)
|
1167
1073
|
- credits_consumed: Credits used for the job execution
|
1168
1074
|
- version: EDSL version used for the job
|
1075
|
+
- job_json_string: The json string for the job (if include_json_string is True)
|
1169
1076
|
|
1170
1077
|
Raises:
|
1171
1078
|
ValueError: If neither job_uuid nor results_uuid is provided
|
@@ -1222,14 +1129,119 @@ class Coop(CoopFunctionsMixin):
|
|
1222
1129
|
"results_url": results_url,
|
1223
1130
|
"latest_error_report_uuid": latest_error_report_uuid,
|
1224
1131
|
"latest_error_report_url": latest_error_report_url,
|
1132
|
+
"latest_failure_description": data.get("latest_failure_details"),
|
1225
1133
|
"status": data.get("status"),
|
1226
1134
|
"reason": data.get("latest_failure_reason"),
|
1227
1135
|
"credits_consumed": data.get("price"),
|
1228
1136
|
"version": data.get("version"),
|
1137
|
+
"job_json_string": (
|
1138
|
+
data.get("job_json_string") if include_json_string else None
|
1139
|
+
),
|
1229
1140
|
}
|
1230
1141
|
)
|
1231
1142
|
|
1232
|
-
def
|
1143
|
+
def _validate_remote_job_status_types(
|
1144
|
+
self, status: Union[RemoteJobStatus, List[RemoteJobStatus]]
|
1145
|
+
) -> List[RemoteJobStatus]:
|
1146
|
+
"""
|
1147
|
+
Validate visibility types and return a list of valid types.
|
1148
|
+
|
1149
|
+
Args:
|
1150
|
+
visibility: Single visibility type or list of visibility types to validate
|
1151
|
+
|
1152
|
+
Returns:
|
1153
|
+
List of validated visibility types
|
1154
|
+
|
1155
|
+
Raises:
|
1156
|
+
CoopValueError: If any visibility type is invalid
|
1157
|
+
"""
|
1158
|
+
valid_status_types = [
|
1159
|
+
"queued",
|
1160
|
+
"running",
|
1161
|
+
"completed",
|
1162
|
+
"failed",
|
1163
|
+
"cancelled",
|
1164
|
+
"cancelling",
|
1165
|
+
"partial_failed",
|
1166
|
+
]
|
1167
|
+
if isinstance(status, list):
|
1168
|
+
invalid_statuses = [s for s in status if s not in valid_status_types]
|
1169
|
+
if invalid_statuses:
|
1170
|
+
raise CoopValueError(
|
1171
|
+
f"Invalid status type(s): {invalid_statuses}. "
|
1172
|
+
f"Valid types are: {valid_status_types}"
|
1173
|
+
)
|
1174
|
+
return status
|
1175
|
+
else:
|
1176
|
+
if status not in valid_status_types:
|
1177
|
+
raise CoopValueError(
|
1178
|
+
f"Invalid status type: {status}. "
|
1179
|
+
f"Valid types are: {valid_status_types}"
|
1180
|
+
)
|
1181
|
+
return [status]
|
1182
|
+
|
1183
|
+
def remote_inference_list(
|
1184
|
+
self,
|
1185
|
+
status: Union[RemoteJobStatus, List[RemoteJobStatus], None] = None,
|
1186
|
+
search_query: Union[str, None] = None,
|
1187
|
+
page: int = 1,
|
1188
|
+
page_size: int = 10,
|
1189
|
+
sort_ascending: bool = False,
|
1190
|
+
) -> "CoopJobsObjects":
|
1191
|
+
"""
|
1192
|
+
Retrieve jobs owned by the user.
|
1193
|
+
|
1194
|
+
Notes:
|
1195
|
+
- search_query only works with the description field.
|
1196
|
+
- If sort_ascending is False, then the most recently created jobs are returned first.
|
1197
|
+
"""
|
1198
|
+
from ..scenarios import Scenario
|
1199
|
+
|
1200
|
+
if page < 1:
|
1201
|
+
raise CoopValueError("The page must be greater than or equal to 1.")
|
1202
|
+
if page_size < 1:
|
1203
|
+
raise CoopValueError("The page size must be greater than or equal to 1.")
|
1204
|
+
if page_size > 100:
|
1205
|
+
raise CoopValueError("The page size must be less than or equal to 100.")
|
1206
|
+
|
1207
|
+
params = {
|
1208
|
+
"page": page,
|
1209
|
+
"page_size": page_size,
|
1210
|
+
"sort_ascending": sort_ascending,
|
1211
|
+
}
|
1212
|
+
if status:
|
1213
|
+
params["status"] = self._validate_remote_job_status_types(status)
|
1214
|
+
if search_query:
|
1215
|
+
params["search_query"] = search_query
|
1216
|
+
|
1217
|
+
response = self._send_server_request(
|
1218
|
+
uri="api/v0/remote-inference/list",
|
1219
|
+
method="GET",
|
1220
|
+
params=params,
|
1221
|
+
)
|
1222
|
+
self._resolve_server_response(response)
|
1223
|
+
content = response.json()
|
1224
|
+
jobs = []
|
1225
|
+
for o in content:
|
1226
|
+
job = Scenario(
|
1227
|
+
{
|
1228
|
+
"uuid": o.get("uuid"),
|
1229
|
+
"description": o.get("description"),
|
1230
|
+
"status": o.get("status"),
|
1231
|
+
"cost_credits": o.get("cost_credits"),
|
1232
|
+
"iterations": o.get("iterations"),
|
1233
|
+
"results_uuid": o.get("results_uuid"),
|
1234
|
+
"latest_error_report_uuid": o.get("latest_error_report_uuid"),
|
1235
|
+
"latest_failure_reason": o.get("latest_failure_reason"),
|
1236
|
+
"version": o.get("version"),
|
1237
|
+
"created_ts": o.get("created_ts"),
|
1238
|
+
}
|
1239
|
+
)
|
1240
|
+
jobs.append(job)
|
1241
|
+
|
1242
|
+
return CoopJobsObjects(jobs)
|
1243
|
+
|
1244
|
+
def get_running_jobs(self) -> List[str]:
|
1233
1245
|
"""
|
1234
1246
|
Get a list of currently running job IDs.
|
1235
1247
|
|
@@ -1442,7 +1454,7 @@ class Coop(CoopFunctionsMixin):
|
|
1442
1454
|
data = response.json()
|
1443
1455
|
return ServiceToModelsMapping(data)
|
1444
1456
|
|
1445
|
-
def fetch_working_models(self) ->
|
1457
|
+
def fetch_working_models(self) -> List[dict]:
|
1446
1458
|
"""
|
1447
1459
|
Fetch a list of working models from Coop.
|
1448
1460
|
|
@@ -1488,6 +1500,28 @@ class Coop(CoopFunctionsMixin):
|
|
1488
1500
|
data = response.json()
|
1489
1501
|
return data
|
1490
1502
|
|
1503
|
+
def get_uuid_from_hash(self, hash_value: str) -> str:
|
1504
|
+
"""
|
1505
|
+
Retrieve the UUID for an object based on its hash.
|
1506
|
+
|
1507
|
+
This method calls the remote endpoint to get the UUID associated with an object hash.
|
1508
|
+
|
1509
|
+
Args:
|
1510
|
+
hash_value (str): The hash value of the object to look up
|
1511
|
+
|
1512
|
+
Returns:
|
1513
|
+
str: The UUID of the object if found
|
1514
|
+
|
1515
|
+
Raises:
|
1516
|
+
CoopServerResponseError: If the object is not found or there's an error
|
1517
|
+
communicating with the server
|
1518
|
+
"""
|
1519
|
+
response = self._send_server_request(
|
1520
|
+
uri=f"api/v0/object/hash/{hash_value}", method="GET"
|
1521
|
+
)
|
1522
|
+
self._resolve_server_response(response)
|
1523
|
+
return response.json().get("uuid")
|
1524
|
+
|
1491
1525
|
def _display_login_url(
|
1492
1526
|
self, edsl_auth_token: str, link_description: Optional[str] = None
|
1493
1527
|
):
|
@@ -1603,7 +1637,7 @@ def main():
|
|
1603
1637
|
coop.get(response.get("uuid"), expected_object_type="question")
|
1604
1638
|
coop.get(response.get("url"))
|
1605
1639
|
coop.create(QuestionMultipleChoice.example())
|
1606
|
-
coop.
|
1640
|
+
coop.list("question")
|
1607
1641
|
coop.patch(response.get("uuid"), visibility="private")
|
1608
1642
|
coop.patch(response.get("uuid"), description="hey")
|
1609
1643
|
coop.patch(response.get("uuid"), value=QuestionFreeText.example())
|
@@ -1638,7 +1672,7 @@ def main():
|
|
1638
1672
|
for object_type, cls in OBJECTS:
|
1639
1673
|
print(f"Testing {object_type} objects")
|
1640
1674
|
# 1. Delete existing objects
|
1641
|
-
existing_objects = coop.
|
1675
|
+
existing_objects = coop.list(object_type)
|
1642
1676
|
for item in existing_objects:
|
1643
1677
|
coop.delete(item.get("uuid"))
|
1644
1678
|
# 2. Create new objects
|
@@ -1650,7 +1684,7 @@ def main():
|
|
1650
1684
|
cls.example(), visibility="unlisted", description="hey"
|
1651
1685
|
)
|
1652
1686
|
# 3. Retrieve all objects
|
1653
|
-
objects = coop.
|
1687
|
+
objects = coop.list(object_type)
|
1654
1688
|
assert len(objects) == 4
|
1655
1689
|
# 4. Try to retrieve an item that does not exist
|
1656
1690
|
try:
|
@@ -1669,7 +1703,7 @@ def main():
|
|
1669
1703
|
# 7. Delete all objects
|
1670
1704
|
for item in objects:
|
1671
1705
|
coop.delete(item.get("uuid"))
|
1672
|
-
assert len(coop.
|
1706
|
+
assert len(coop.list(object_type)) == 0
|
1673
1707
|
|
1674
1708
|
##############
|
1675
1709
|
# C. Remote Cache
|