edsl 0.1.54__py3-none-any.whl → 0.1.55__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. edsl/__init__.py +8 -1
  2. edsl/__init__original.py +134 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/agent.py +29 -0
  5. edsl/agents/agent_list.py +36 -1
  6. edsl/base/base_class.py +281 -151
  7. edsl/buckets/__init__.py +8 -3
  8. edsl/buckets/bucket_collection.py +9 -3
  9. edsl/buckets/model_buckets.py +4 -2
  10. edsl/buckets/token_bucket.py +2 -2
  11. edsl/buckets/token_bucket_client.py +5 -3
  12. edsl/caching/cache.py +131 -62
  13. edsl/caching/cache_entry.py +70 -58
  14. edsl/caching/sql_dict.py +17 -0
  15. edsl/cli.py +99 -0
  16. edsl/config/config_class.py +16 -0
  17. edsl/conversation/__init__.py +31 -0
  18. edsl/coop/coop.py +276 -242
  19. edsl/coop/coop_jobs_objects.py +59 -0
  20. edsl/coop/coop_objects.py +29 -0
  21. edsl/coop/coop_regular_objects.py +26 -0
  22. edsl/coop/utils.py +24 -19
  23. edsl/dataset/dataset.py +338 -101
  24. edsl/db_list/sqlite_list.py +349 -0
  25. edsl/inference_services/__init__.py +40 -5
  26. edsl/inference_services/exceptions.py +11 -0
  27. edsl/inference_services/services/anthropic_service.py +5 -2
  28. edsl/inference_services/services/aws_bedrock.py +6 -2
  29. edsl/inference_services/services/azure_ai.py +6 -2
  30. edsl/inference_services/services/google_service.py +3 -2
  31. edsl/inference_services/services/mistral_ai_service.py +6 -2
  32. edsl/inference_services/services/open_ai_service.py +6 -2
  33. edsl/inference_services/services/perplexity_service.py +6 -2
  34. edsl/inference_services/services/test_service.py +94 -5
  35. edsl/interviews/answering_function.py +167 -59
  36. edsl/interviews/interview.py +124 -72
  37. edsl/interviews/interview_task_manager.py +10 -0
  38. edsl/invigilators/invigilators.py +9 -0
  39. edsl/jobs/async_interview_runner.py +146 -104
  40. edsl/jobs/data_structures.py +6 -4
  41. edsl/jobs/decorators.py +61 -0
  42. edsl/jobs/fetch_invigilator.py +61 -18
  43. edsl/jobs/html_table_job_logger.py +14 -2
  44. edsl/jobs/jobs.py +180 -104
  45. edsl/jobs/jobs_component_constructor.py +2 -2
  46. edsl/jobs/jobs_interview_constructor.py +2 -0
  47. edsl/jobs/jobs_remote_inference_logger.py +4 -0
  48. edsl/jobs/jobs_runner_status.py +30 -25
  49. edsl/jobs/progress_bar_manager.py +79 -0
  50. edsl/jobs/remote_inference.py +35 -1
  51. edsl/key_management/key_lookup_builder.py +6 -1
  52. edsl/language_models/language_model.py +86 -6
  53. edsl/language_models/model.py +10 -3
  54. edsl/language_models/price_manager.py +45 -75
  55. edsl/language_models/registry.py +5 -0
  56. edsl/notebooks/notebook.py +77 -10
  57. edsl/questions/VALIDATION_README.md +134 -0
  58. edsl/questions/__init__.py +24 -1
  59. edsl/questions/exceptions.py +21 -0
  60. edsl/questions/question_dict.py +201 -16
  61. edsl/questions/question_multiple_choice_with_other.py +624 -0
  62. edsl/questions/question_registry.py +2 -1
  63. edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
  64. edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
  65. edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
  66. edsl/questions/validation_analysis.py +185 -0
  67. edsl/questions/validation_cli.py +131 -0
  68. edsl/questions/validation_html_report.py +404 -0
  69. edsl/questions/validation_logger.py +136 -0
  70. edsl/results/result.py +63 -16
  71. edsl/results/results.py +702 -171
  72. edsl/scenarios/construct_download_link.py +16 -3
  73. edsl/scenarios/directory_scanner.py +226 -226
  74. edsl/scenarios/file_methods.py +5 -0
  75. edsl/scenarios/file_store.py +117 -6
  76. edsl/scenarios/handlers/__init__.py +5 -1
  77. edsl/scenarios/handlers/mp4_file_store.py +104 -0
  78. edsl/scenarios/handlers/webm_file_store.py +104 -0
  79. edsl/scenarios/scenario.py +120 -101
  80. edsl/scenarios/scenario_list.py +800 -727
  81. edsl/scenarios/scenario_list_gc_test.py +146 -0
  82. edsl/scenarios/scenario_list_memory_test.py +214 -0
  83. edsl/scenarios/scenario_list_source_refactor.md +35 -0
  84. edsl/scenarios/scenario_selector.py +5 -4
  85. edsl/scenarios/scenario_source.py +1990 -0
  86. edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
  87. edsl/surveys/survey.py +22 -0
  88. edsl/tasks/__init__.py +4 -2
  89. edsl/tasks/task_history.py +198 -36
  90. edsl/tests/scenarios/test_ScenarioSource.py +51 -0
  91. edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
  92. edsl/utilities/__init__.py +2 -1
  93. edsl/utilities/decorators.py +121 -0
  94. edsl/utilities/memory_debugger.py +1010 -0
  95. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/METADATA +51 -76
  96. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/RECORD +99 -75
  97. edsl/jobs/jobs_runner_asyncio.py +0 -281
  98. edsl/language_models/unused/fake_openai_service.py +0 -60
  99. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/LICENSE +0 -0
  100. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/WHEEL +0 -0
  101. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/entry_points.txt +0 -0
edsl/coop/coop.py CHANGED
@@ -3,7 +3,7 @@ import base64
3
3
  import json
4
4
  import requests
5
5
 
6
- from typing import Any, Optional, Union, Literal, TypedDict, TYPE_CHECKING
6
+ from typing import Any, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
7
7
  from uuid import UUID
8
8
 
9
9
  from .. import __version__
@@ -19,6 +19,7 @@ from .exceptions import (
19
19
  CoopInvalidURLError,
20
20
  CoopNoUUIDError,
21
21
  CoopServerResponseError,
22
+ CoopValueError,
22
23
  )
23
24
  from .utils import (
24
25
  EDSLObject,
@@ -29,6 +30,8 @@ from .utils import (
29
30
  )
30
31
 
31
32
  from .coop_functions import CoopFunctionsMixin
33
+ from .coop_regular_objects import CoopRegularObjects
34
+ from .coop_jobs_objects import CoopJobsObjects
32
35
  from .ep_key_handling import ExpectedParrotKeyHandler
33
36
 
34
37
  from ..inference_services.data_structures import ServiceToModelsMapping
@@ -44,6 +47,7 @@ class RemoteInferenceResponse(TypedDict):
44
47
  reason: str
45
48
  credits_consumed: float
46
49
  version: str
50
+ job_json_string: Optional[str]
47
51
 
48
52
 
49
53
  class RemoteInferenceCreationInfo(TypedDict):
@@ -243,7 +247,7 @@ class Coop(CoopFunctionsMixin):
243
247
 
244
248
  if response.status_code >= 400:
245
249
  try:
246
- message = response.json().get("detail")
250
+ message = str(response.json().get("detail"))
247
251
  except json.JSONDecodeError:
248
252
  raise CoopServerResponseError(
249
253
  f"Server returned status code {response.status_code}."
@@ -499,6 +503,10 @@ class Coop(CoopFunctionsMixin):
499
503
  """
500
504
  object_type = ObjectRegistry.get_object_type_by_edsl_class(object)
501
505
  object_dict = object.to_dict()
506
+
507
+ # Get the object hash
508
+ object_hash = object.get_hash() if hasattr(object, "get_hash") else None
509
+
502
510
  if object_type == "scenario" and self._scenario_is_file_store(object_dict):
503
511
  file_store_metadata = {
504
512
  "suffix": object_dict["suffix"],
@@ -524,6 +532,7 @@ class Coop(CoopFunctionsMixin):
524
532
  "file_store_metadata": file_store_metadata,
525
533
  "visibility": visibility,
526
534
  "version": self._edsl_version,
535
+ "object_hash": object_hash, # Include the object hash in the payload
527
536
  },
528
537
  )
529
538
  self._resolve_server_response(response)
@@ -670,42 +679,138 @@ class Coop(CoopFunctionsMixin):
670
679
  object.initialize_cache_from_results()
671
680
  return object
672
681
 
673
- def get_all(self, object_type: ObjectType) -> list[dict[str, Any]]:
682
+ def _validate_object_types(
683
+ self, object_type: Union[ObjectType, List[ObjectType]]
684
+ ) -> List[ObjectType]:
674
685
  """
675
- Retrieve all objects of a certain type associated with the user.
686
+ Validate object types and return a list of valid types.
687
+
688
+ Args:
689
+ object_type: Single object type or list of object types to validate
690
+
691
+ Returns:
692
+ List of validated object types
693
+
694
+ Raises:
695
+ CoopValueError: If any object type is invalid
696
+ """
697
+ valid_object_types = ObjectRegistry.object_type_to_edsl_class.keys()
698
+ if isinstance(object_type, list):
699
+ invalid_types = [t for t in object_type if t not in valid_object_types]
700
+ if invalid_types:
701
+ raise CoopValueError(
702
+ f"Invalid object type(s): {invalid_types}. "
703
+ f"Valid types are: {list(valid_object_types)}"
704
+ )
705
+ return object_type
706
+ else:
707
+ if object_type not in valid_object_types:
708
+ raise CoopValueError(
709
+ f"Invalid object type: {object_type}. "
710
+ f"Valid types are: {list(valid_object_types)}"
711
+ )
712
+ return [object_type]
713
+
714
+ def _validate_visibility_types(
715
+ self, visibility: Union[VisibilityType, List[VisibilityType]]
716
+ ) -> List[VisibilityType]:
676
717
  """
677
- edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
718
+ Validate visibility types and return a list of valid types.
719
+
720
+ Args:
721
+ visibility: Single visibility type or list of visibility types to validate
722
+
723
+ Returns:
724
+ List of validated visibility types
725
+
726
+ Raises:
727
+ CoopValueError: If any visibility type is invalid
728
+ """
729
+ valid_visibility_types = ["private", "public", "unlisted"]
730
+ if isinstance(visibility, list):
731
+ invalid_visibilities = [
732
+ v for v in visibility if v not in valid_visibility_types
733
+ ]
734
+ if invalid_visibilities:
735
+ raise CoopValueError(
736
+ f"Invalid visibility type(s): {invalid_visibilities}. "
737
+ f"Valid types are: {valid_visibility_types}"
738
+ )
739
+ return visibility
740
+ else:
741
+ if visibility not in valid_visibility_types:
742
+ raise CoopValueError(
743
+ f"Invalid visibility type: {visibility}. "
744
+ f"Valid types are: {valid_visibility_types}"
745
+ )
746
+ return [visibility]
747
+
748
+ def list(
749
+ self,
750
+ object_type: Union[ObjectType, List[ObjectType], None] = None,
751
+ visibility: Union[VisibilityType, List[VisibilityType], None] = None,
752
+ search_query: Union[str, None] = None,
753
+ page: int = 1,
754
+ page_size: int = 10,
755
+ sort_ascending: bool = False,
756
+ ) -> "CoopRegularObjects":
757
+ """
758
+ Retrieve objects either owned by the user or shared with them.
759
+
760
+ Notes:
761
+ - search_query only works with the description field.
762
+ - If sort_ascending is False, then the most recently created objects are returned first.
763
+ """
764
+ from ..scenarios import Scenario
765
+
766
+ if page < 1:
767
+ raise CoopValueError("The page must be greater than or equal to 1.")
768
+ if page_size < 1:
769
+ raise CoopValueError("The page size must be greater than or equal to 1.")
770
+ if page_size > 100:
771
+ raise CoopValueError("The page size must be less than or equal to 100.")
772
+
773
+ params = {
774
+ "page": page,
775
+ "page_size": page_size,
776
+ "sort_ascending": sort_ascending,
777
+ }
778
+ if object_type:
779
+ params["type"] = self._validate_object_types(object_type)
780
+ if visibility:
781
+ params["visibility"] = self._validate_visibility_types(visibility)
782
+ if search_query:
783
+ params["search_query"] = search_query
784
+
678
785
  response = self._send_server_request(
679
- uri="api/v0/objects",
786
+ uri="api/v0/object/list",
680
787
  method="GET",
681
- params={"type": object_type},
788
+ params=params,
682
789
  )
683
790
  self._resolve_server_response(response)
791
+ content = response.json()
684
792
  objects = []
685
- for o in response.json():
686
- json_string = o.get("json_string")
687
- ## check if load from bucket needed.
688
- if "load_from:" in json_string[0:12]:
689
- load_link = json_string.split("load_from:")[1]
690
- object_data = requests.get(load_link)
691
- self._resolve_gcs_response(object_data)
692
- json_string = object_data.text
693
-
694
- json_string = json.loads(json_string)
695
- object = {
696
- "object": edsl_class.from_dict(json_string),
697
- "uuid": o.get("uuid"),
698
- "version": o.get("version"),
699
- "description": o.get("description"),
700
- "visibility": o.get("visibility"),
701
- "url": f"{self.url}/content/{o.get('uuid')}",
702
- "alias_url": self._get_alias_url(
703
- o.get("owner_username"), o.get("alias")
704
- ),
705
- }
793
+ for o in content:
794
+ object = Scenario(
795
+ {
796
+ "uuid": o.get("uuid"),
797
+ "object_type": o.get("object_type"),
798
+ "alias": o.get("alias"),
799
+ "owner_username": o.get("owner_username"),
800
+ "description": o.get("description"),
801
+ "visibility": o.get("visibility"),
802
+ "version": o.get("version"),
803
+ "url": f"{self.url}/content/{o.get('uuid')}",
804
+ "alias_url": self._get_alias_url(
805
+ o.get("owner_username"), o.get("alias")
806
+ ),
807
+ "last_updated_ts": o.get("last_updated_ts"),
808
+ "created_ts": o.get("created_ts"),
809
+ }
810
+ )
706
811
  objects.append(object)
707
812
 
708
- return objects
813
+ return CoopRegularObjects(objects)
709
814
 
710
815
  def delete(self, url_or_uuid: Union[str, UUID]) -> dict:
711
816
  """
@@ -793,93 +898,10 @@ class Coop(CoopFunctionsMixin):
793
898
  ################
794
899
  # Remote Cache
795
900
  ################
796
- # def remote_cache_create(
797
- # self,
798
- # cache_entry: CacheEntry,
799
- # visibility: VisibilityType = "private",
800
- # description: Optional[str] = None,
801
- # ) -> dict:
802
- # """
803
- # Create a single remote cache entry.
804
- # If an entry with the same key already exists in the database, update it instead.
805
-
806
- # :param cache_entry: The cache entry to send to the server.
807
- # :param visibility: The visibility of the cache entry.
808
- # :param optional description: A description for this entry in the remote cache.
809
-
810
- # >>> entry = CacheEntry.example()
811
- # >>> coop.remote_cache_create(cache_entry=entry)
812
- # {'status': 'success', 'created_entry_count': 1, 'updated_entry_count': 0}
813
- # """
814
- # response = self._send_server_request(
815
- # uri="api/v0/remote-cache",
816
- # method="POST",
817
- # payload={
818
- # "json_string": json.dumps(cache_entry.to_dict()),
819
- # "version": self._edsl_version,
820
- # "visibility": visibility,
821
- # "description": description,
822
- # },
823
- # )
824
- # self._resolve_server_response(response)
825
- # response_json = response.json()
826
- # created_entry_count = response_json.get("created_entry_count", 0)
827
- # if created_entry_count > 0:
828
- # self.remote_cache_create_log(
829
- # response,
830
- # description="Upload new cache entries to server",
831
- # cache_entry_count=created_entry_count,
832
- # )
833
- # return response.json()
834
-
835
- # def remote_cache_create_many(
836
- # self,
837
- # cache_entries: list[CacheEntry],
838
- # visibility: VisibilityType = "private",
839
- # description: Optional[str] = None,
840
- # ) -> dict:
841
- # """
842
- # Create many remote cache entries.
843
- # If an entry with the same key already exists in the database, update it instead.
844
-
845
- # :param cache_entries: The list of cache entries to send to the server.
846
- # :param visibility: The visibility of the cache entries.
847
- # :param optional description: A description for these entries in the remote cache.
848
-
849
- # >>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
850
- # >>> coop.remote_cache_create_many(cache_entries=entries)
851
- # {'status': 'success', 'created_entry_count': 10, 'updated_entry_count': 0}
852
- # """
853
- # payload = [
854
- # {
855
- # "json_string": json.dumps(c.to_dict()),
856
- # "version": self._edsl_version,
857
- # "visibility": visibility,
858
- # "description": description,
859
- # }
860
- # for c in cache_entries
861
- # ]
862
- # response = self._send_server_request(
863
- # uri="api/v0/remote-cache/many",
864
- # method="POST",
865
- # payload=payload,
866
- # timeout=40,
867
- # )
868
- # self._resolve_server_response(response)
869
- # response_json = response.json()
870
- # created_entry_count = response_json.get("created_entry_count", 0)
871
- # if created_entry_count > 0:
872
- # self.remote_cache_create_log(
873
- # response,
874
- # description="Upload new cache entries to server",
875
- # cache_entry_count=created_entry_count,
876
- # )
877
- # return response.json()
878
-
879
901
  def remote_cache_get(
880
902
  self,
881
903
  job_uuid: Optional[Union[str, UUID]] = None,
882
- ) -> list[CacheEntry]:
904
+ ) -> List[CacheEntry]:
883
905
  """
884
906
  Get all remote cache entries.
885
907
 
@@ -908,8 +930,8 @@ class Coop(CoopFunctionsMixin):
908
930
 
909
931
  def remote_cache_get_by_key(
910
932
  self,
911
- select_keys: Optional[list[str]] = None,
912
- ) -> list[CacheEntry]:
933
+ select_keys: Optional[List[str]] = None,
934
+ ) -> List[CacheEntry]:
913
935
  """
914
936
  Get all remote cache entries.
915
937
 
@@ -936,126 +958,6 @@ class Coop(CoopFunctionsMixin):
936
958
  for v in response.json()
937
959
  ]
938
960
 
939
- def legacy_remote_cache_get(
940
- self,
941
- exclude_keys: Optional[list[str]] = None,
942
- select_keys: Optional[list[str]] = None,
943
- ) -> list[CacheEntry]:
944
- """
945
- Get all remote cache entries.
946
-
947
- :param optional select_keys: Only return CacheEntry objects with these keys.
948
- :param optional exclude_keys: Exclude CacheEntry objects with these keys.
949
-
950
- >>> coop.legacy_remote_cache_get()
951
- [CacheEntry(...), CacheEntry(...), ...]
952
- """
953
- if exclude_keys is None:
954
- exclude_keys = []
955
- if select_keys is None:
956
- select_keys = []
957
- response = self._send_server_request(
958
- uri="api/v0/remote-cache/legacy/get-many",
959
- method="POST",
960
- payload={"exclude_keys": exclude_keys, "selected_keys": select_keys},
961
- timeout=40,
962
- )
963
- self._resolve_server_response(response)
964
- return [
965
- CacheEntry.from_dict(json.loads(v.get("json_string")))
966
- for v in response.json()
967
- ]
968
-
969
- def legacy_remote_cache_get_diff(
970
- self,
971
- client_cacheentry_keys: list[str],
972
- ) -> dict:
973
- """
974
- Get the difference between local and remote cache entries for a user.
975
- """
976
- response = self._send_server_request(
977
- uri="api/v0/remote-cache/legacy/get-diff",
978
- method="POST",
979
- payload={"keys": client_cacheentry_keys},
980
- timeout=40,
981
- )
982
- self._resolve_server_response(response)
983
- response_json = response.json()
984
- response_dict = {
985
- "client_missing_cacheentries": [
986
- CacheEntry.from_dict(json.loads(c.get("json_string")))
987
- for c in response_json.get("client_missing_cacheentries", [])
988
- ],
989
- "server_missing_cacheentry_keys": response_json.get(
990
- "server_missing_cacheentry_keys", []
991
- ),
992
- }
993
- downloaded_entry_count = len(response_dict["client_missing_cacheentries"])
994
- if downloaded_entry_count > 0:
995
- self.legacy_remote_cache_create_log(
996
- response,
997
- description="Download missing cache entries to client",
998
- cache_entry_count=downloaded_entry_count,
999
- )
1000
- return response_dict
1001
-
1002
- def legacy_remote_cache_clear(self) -> dict:
1003
- """
1004
- Clear all remote cache entries.
1005
-
1006
- >>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
1007
- >>> coop.legacy_remote_cache_create_many(cache_entries=entries)
1008
- >>> coop.legacy_remote_cache_clear()
1009
- {'status': 'success', 'deleted_entry_count': 10}
1010
- """
1011
- response = self._send_server_request(
1012
- uri="api/v0/remote-cache/legacy/delete-all",
1013
- method="DELETE",
1014
- )
1015
- self._resolve_server_response(response)
1016
- response_json = response.json()
1017
- deleted_entry_count = response_json.get("deleted_entry_count", 0)
1018
- if deleted_entry_count > 0:
1019
- self.legacy_remote_cache_create_log(
1020
- response,
1021
- description="Clear cache entries",
1022
- cache_entry_count=deleted_entry_count,
1023
- )
1024
- return response.json()
1025
-
1026
- def legacy_remote_cache_create_log(
1027
- self, response: requests.Response, description: str, cache_entry_count: int
1028
- ) -> Union[dict, None]:
1029
- """
1030
- If a remote cache action has been completed successfully,
1031
- log the action.
1032
- """
1033
- if 200 <= response.status_code < 300:
1034
- log_response = self._send_server_request(
1035
- uri="api/v0/remote-cache-log/legacy",
1036
- method="POST",
1037
- payload={
1038
- "description": description,
1039
- "cache_entry_count": cache_entry_count,
1040
- },
1041
- )
1042
- self._resolve_server_response(log_response)
1043
- return response.json()
1044
-
1045
- def legacy_remote_cache_clear_log(self) -> dict:
1046
- """
1047
- Clear all remote cache log entries.
1048
-
1049
- >>> coop.legacy_remote_cache_clear_log()
1050
- {'status': 'success'}
1051
- """
1052
- response = self._send_server_request(
1053
- uri="api/v0/remote-cache-log/legacy/delete-all",
1054
- method="DELETE",
1055
- )
1056
- self._resolve_server_response(response)
1057
- return response.json()
1058
-
1059
961
  def remote_inference_create(
1060
962
  self,
1061
963
  job: "Jobs",
@@ -1142,7 +1044,10 @@ class Coop(CoopFunctionsMixin):
1142
1044
  )
1143
1045
 
1144
1046
  def remote_inference_get(
1145
- self, job_uuid: Optional[str] = None, results_uuid: Optional[str] = None
1047
+ self,
1048
+ job_uuid: Optional[str] = None,
1049
+ results_uuid: Optional[str] = None,
1050
+ include_json_string: Optional[bool] = False,
1146
1051
  ) -> RemoteInferenceResponse:
1147
1052
  """
1148
1053
  Get the status and details of a remote inference job.
@@ -1154,6 +1059,7 @@ class Coop(CoopFunctionsMixin):
1154
1059
  job_uuid (str, optional): The UUID of the remote job to check
1155
1060
  results_uuid (str, optional): The UUID of the results associated with the job
1156
1061
  (can be used if you only have the results UUID)
1062
+ include_json_string (bool, optional): If True, include the json string for the job in the response
1157
1063
 
1158
1064
  Returns:
1159
1065
  RemoteInferenceResponse: Information about the job including:
@@ -1166,6 +1072,7 @@ class Coop(CoopFunctionsMixin):
1166
1072
  - reason: Reason for failure (if applicable)
1167
1073
  - credits_consumed: Credits used for the job execution
1168
1074
  - version: EDSL version used for the job
1075
+ - job_json_string: The json string for the job (if include_json_string is True)
1169
1076
 
1170
1077
  Raises:
1171
1078
  ValueError: If neither job_uuid nor results_uuid is provided
@@ -1222,14 +1129,119 @@ class Coop(CoopFunctionsMixin):
1222
1129
  "results_url": results_url,
1223
1130
  "latest_error_report_uuid": latest_error_report_uuid,
1224
1131
  "latest_error_report_url": latest_error_report_url,
1132
+ "latest_failure_description": data.get("latest_failure_details"),
1225
1133
  "status": data.get("status"),
1226
1134
  "reason": data.get("latest_failure_reason"),
1227
1135
  "credits_consumed": data.get("price"),
1228
1136
  "version": data.get("version"),
1137
+ "job_json_string": (
1138
+ data.get("job_json_string") if include_json_string else None
1139
+ ),
1229
1140
  }
1230
1141
  )
1231
1142
 
1232
- def get_running_jobs(self) -> list[str]:
1143
+ def _validate_remote_job_status_types(
1144
+ self, status: Union[RemoteJobStatus, List[RemoteJobStatus]]
1145
+ ) -> List[RemoteJobStatus]:
1146
+ """
1147
+ Validate visibility types and return a list of valid types.
1148
+
1149
+ Args:
1150
+ visibility: Single visibility type or list of visibility types to validate
1151
+
1152
+ Returns:
1153
+ List of validated visibility types
1154
+
1155
+ Raises:
1156
+ CoopValueError: If any visibility type is invalid
1157
+ """
1158
+ valid_status_types = [
1159
+ "queued",
1160
+ "running",
1161
+ "completed",
1162
+ "failed",
1163
+ "cancelled",
1164
+ "cancelling",
1165
+ "partial_failed",
1166
+ ]
1167
+ if isinstance(status, list):
1168
+ invalid_statuses = [s for s in status if s not in valid_status_types]
1169
+ if invalid_statuses:
1170
+ raise CoopValueError(
1171
+ f"Invalid status type(s): {invalid_statuses}. "
1172
+ f"Valid types are: {valid_status_types}"
1173
+ )
1174
+ return status
1175
+ else:
1176
+ if status not in valid_status_types:
1177
+ raise CoopValueError(
1178
+ f"Invalid status type: {status}. "
1179
+ f"Valid types are: {valid_status_types}"
1180
+ )
1181
+ return [status]
1182
+
1183
+ def remote_inference_list(
1184
+ self,
1185
+ status: Union[RemoteJobStatus, List[RemoteJobStatus], None] = None,
1186
+ search_query: Union[str, None] = None,
1187
+ page: int = 1,
1188
+ page_size: int = 10,
1189
+ sort_ascending: bool = False,
1190
+ ) -> "CoopJobsObjects":
1191
+ """
1192
+ Retrieve jobs owned by the user.
1193
+
1194
+ Notes:
1195
+ - search_query only works with the description field.
1196
+ - If sort_ascending is False, then the most recently created jobs are returned first.
1197
+ """
1198
+ from ..scenarios import Scenario
1199
+
1200
+ if page < 1:
1201
+ raise CoopValueError("The page must be greater than or equal to 1.")
1202
+ if page_size < 1:
1203
+ raise CoopValueError("The page size must be greater than or equal to 1.")
1204
+ if page_size > 100:
1205
+ raise CoopValueError("The page size must be less than or equal to 100.")
1206
+
1207
+ params = {
1208
+ "page": page,
1209
+ "page_size": page_size,
1210
+ "sort_ascending": sort_ascending,
1211
+ }
1212
+ if status:
1213
+ params["status"] = self._validate_remote_job_status_types(status)
1214
+ if search_query:
1215
+ params["search_query"] = search_query
1216
+
1217
+ response = self._send_server_request(
1218
+ uri="api/v0/remote-inference/list",
1219
+ method="GET",
1220
+ params=params,
1221
+ )
1222
+ self._resolve_server_response(response)
1223
+ content = response.json()
1224
+ jobs = []
1225
+ for o in content:
1226
+ job = Scenario(
1227
+ {
1228
+ "uuid": o.get("uuid"),
1229
+ "description": o.get("description"),
1230
+ "status": o.get("status"),
1231
+ "cost_credits": o.get("cost_credits"),
1232
+ "iterations": o.get("iterations"),
1233
+ "results_uuid": o.get("results_uuid"),
1234
+ "latest_error_report_uuid": o.get("latest_error_report_uuid"),
1235
+ "latest_failure_reason": o.get("latest_failure_reason"),
1236
+ "version": o.get("version"),
1237
+ "created_ts": o.get("created_ts"),
1238
+ }
1239
+ )
1240
+ jobs.append(job)
1241
+
1242
+ return CoopJobsObjects(jobs)
1243
+
1244
+ def get_running_jobs(self) -> List[str]:
1233
1245
  """
1234
1246
  Get a list of currently running job IDs.
1235
1247
 
@@ -1442,7 +1454,7 @@ class Coop(CoopFunctionsMixin):
1442
1454
  data = response.json()
1443
1455
  return ServiceToModelsMapping(data)
1444
1456
 
1445
- def fetch_working_models(self) -> list[dict]:
1457
+ def fetch_working_models(self) -> List[dict]:
1446
1458
  """
1447
1459
  Fetch a list of working models from Coop.
1448
1460
 
@@ -1488,6 +1500,28 @@ class Coop(CoopFunctionsMixin):
1488
1500
  data = response.json()
1489
1501
  return data
1490
1502
 
1503
+ def get_uuid_from_hash(self, hash_value: str) -> str:
1504
+ """
1505
+ Retrieve the UUID for an object based on its hash.
1506
+
1507
+ This method calls the remote endpoint to get the UUID associated with an object hash.
1508
+
1509
+ Args:
1510
+ hash_value (str): The hash value of the object to look up
1511
+
1512
+ Returns:
1513
+ str: The UUID of the object if found
1514
+
1515
+ Raises:
1516
+ CoopServerResponseError: If the object is not found or there's an error
1517
+ communicating with the server
1518
+ """
1519
+ response = self._send_server_request(
1520
+ uri=f"api/v0/object/hash/{hash_value}", method="GET"
1521
+ )
1522
+ self._resolve_server_response(response)
1523
+ return response.json().get("uuid")
1524
+
1491
1525
  def _display_login_url(
1492
1526
  self, edsl_auth_token: str, link_description: Optional[str] = None
1493
1527
  ):
@@ -1603,7 +1637,7 @@ def main():
1603
1637
  coop.get(response.get("uuid"), expected_object_type="question")
1604
1638
  coop.get(response.get("url"))
1605
1639
  coop.create(QuestionMultipleChoice.example())
1606
- coop.get_all("question")
1640
+ coop.list("question")
1607
1641
  coop.patch(response.get("uuid"), visibility="private")
1608
1642
  coop.patch(response.get("uuid"), description="hey")
1609
1643
  coop.patch(response.get("uuid"), value=QuestionFreeText.example())
@@ -1638,7 +1672,7 @@ def main():
1638
1672
  for object_type, cls in OBJECTS:
1639
1673
  print(f"Testing {object_type} objects")
1640
1674
  # 1. Delete existing objects
1641
- existing_objects = coop.get_all(object_type)
1675
+ existing_objects = coop.list(object_type)
1642
1676
  for item in existing_objects:
1643
1677
  coop.delete(item.get("uuid"))
1644
1678
  # 2. Create new objects
@@ -1650,7 +1684,7 @@ def main():
1650
1684
  cls.example(), visibility="unlisted", description="hey"
1651
1685
  )
1652
1686
  # 3. Retrieve all objects
1653
- objects = coop.get_all(object_type)
1687
+ objects = coop.list(object_type)
1654
1688
  assert len(objects) == 4
1655
1689
  # 4. Try to retrieve an item that does not exist
1656
1690
  try:
@@ -1669,7 +1703,7 @@ def main():
1669
1703
  # 7. Delete all objects
1670
1704
  for item in objects:
1671
1705
  coop.delete(item.get("uuid"))
1672
- assert len(coop.get_all(object_type)) == 0
1706
+ assert len(coop.list(object_type)) == 0
1673
1707
 
1674
1708
  ##############
1675
1709
  # C. Remote Cache