edsl 0.1.60__py3-none-any.whl → 0.1.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edsl/coop/coop.py CHANGED
@@ -3,8 +3,9 @@ import base64
3
3
  import json
4
4
  import requests
5
5
  import time
6
+ import os
6
7
 
7
- from typing import Any, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
8
+ from typing import Any, Dict, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
8
9
  from uuid import UUID
9
10
 
10
11
  from .. import __version__
@@ -35,6 +36,7 @@ from .utils import (
35
36
  from .coop_functions import CoopFunctionsMixin
36
37
  from .coop_regular_objects import CoopRegularObjects
37
38
  from .coop_jobs_objects import CoopJobsObjects
39
+ from .coop_prolific_filters import CoopProlificFilters
38
40
  from .ep_key_handling import ExpectedParrotKeyHandler
39
41
 
40
42
  from ..inference_services.data_structures import ServiceToModelsMapping
@@ -66,7 +68,6 @@ class JobRunInterviewDetails(TypedDict):
66
68
 
67
69
 
68
70
  class LatestJobRunDetails(TypedDict):
69
-
70
71
  # For running, completed, and partially failed jobs
71
72
  interview_details: Optional[JobRunInterviewDetails] = None
72
73
 
@@ -296,9 +297,9 @@ class Coop(CoopFunctionsMixin):
296
297
  message = str(response.json().get("detail"))
297
298
  except json.JSONDecodeError:
298
299
  raise CoopServerResponseError(
299
- f"Server returned status code {response.status_code}."
300
- "JSON response could not be decoded.",
301
- "The server response was: " + response.text,
300
+ f"Server returned status code {response.status_code}. "
301
+ f"JSON response could not be decoded. "
302
+ f"The server response was: {response.text}"
302
303
  )
303
304
  # print(response.text)
304
305
  if "The API key you provided is invalid" in message and check_api_key:
@@ -597,7 +598,7 @@ class Coop(CoopFunctionsMixin):
597
598
  else:
598
599
  from .exceptions import CoopResponseError
599
600
 
600
- raise CoopResponseError("No signed url was provided received")
601
+ raise CoopResponseError("No signed url was provided.")
601
602
 
602
603
  response = requests.put(
603
604
  signed_url, data=json_data.encode(), headers=headers
@@ -694,6 +695,35 @@ class Coop(CoopFunctionsMixin):
694
695
  """
695
696
  obj_uuid, owner_username, alias = self._resolve_uuid_or_alias(url_or_uuid)
696
697
 
698
+ # Handle alias-based retrieval with new/old format detection
699
+ if not obj_uuid and owner_username and alias:
700
+ # First, get object info to determine format and UUID
701
+ info_response = self._send_server_request(
702
+ uri="api/v0/object/alias/info",
703
+ method="GET",
704
+ params={"owner_username": owner_username, "alias": alias},
705
+ )
706
+ self._resolve_server_response(info_response)
707
+ info_data = info_response.json()
708
+
709
+ obj_uuid = info_data.get("uuid")
710
+ is_new_format = info_data.get("is_new_format", False)
711
+
712
+ # Validate object type if expected
713
+ if expected_object_type:
714
+ actual_object_type = info_data.get("object_type")
715
+ if actual_object_type != expected_object_type:
716
+ from .exceptions import CoopObjectTypeError
717
+
718
+ raise CoopObjectTypeError(
719
+ f"Expected {expected_object_type=} but got {actual_object_type=}"
720
+ )
721
+
722
+ # Use pull method for new format objects
723
+ if is_new_format:
724
+ return self.pull(obj_uuid, expected_object_type)
725
+
726
+ # Handle UUID-based retrieval or legacy alias objects
697
727
  if obj_uuid:
698
728
  response = self._send_server_request(
699
729
  uri="api/v0/object",
@@ -915,6 +945,39 @@ class Coop(CoopFunctionsMixin):
915
945
 
916
946
  obj_uuid, owner_username, obj_alias = self._resolve_uuid_or_alias(url_or_uuid)
917
947
 
948
+ # If we're updating the value, we need to check the storage format
949
+ if value:
950
+ # If we don't have a UUID but have an alias, get the UUID and format info first
951
+ if not obj_uuid and owner_username and obj_alias:
952
+ # Get object info including UUID and format
953
+ info_response = self._send_server_request(
954
+ uri="api/v0/object/alias/info",
955
+ method="GET",
956
+ params={"owner_username": owner_username, "alias": obj_alias},
957
+ )
958
+ self._resolve_server_response(info_response)
959
+ info_data = info_response.json()
960
+
961
+ obj_uuid = info_data.get("uuid")
962
+ is_new_format = info_data.get("is_new_format", False)
963
+ else:
964
+ # We have a UUID, check the format
965
+ format_check_response = self._send_server_request(
966
+ uri="api/v0/object/check-format",
967
+ method="POST",
968
+ payload={"object_uuid": str(obj_uuid)},
969
+ )
970
+ self._resolve_server_response(format_check_response)
971
+ format_data = format_check_response.json()
972
+ is_new_format = format_data.get("is_new_format", False)
973
+
974
+ if is_new_format:
975
+ # Handle new format objects: update metadata first, then upload content
976
+ return self._patch_new_format_object(
977
+ obj_uuid, description, alias, value, visibility
978
+ )
979
+
980
+ # Handle traditional format objects or metadata-only updates
918
981
  if obj_uuid:
919
982
  uri = "api/v0/object"
920
983
  params = {"uuid": obj_uuid}
@@ -944,6 +1007,80 @@ class Coop(CoopFunctionsMixin):
944
1007
  self._resolve_server_response(response)
945
1008
  return response.json()
946
1009
 
1010
+ def _patch_new_format_object(
1011
+ self,
1012
+ obj_uuid: UUID,
1013
+ description: Optional[str],
1014
+ alias: Optional[str],
1015
+ value: EDSLObject,
1016
+ visibility: Optional[VisibilityType],
1017
+ ) -> dict:
1018
+ """
1019
+ Handle patching of objects stored in the new format (GCS).
1020
+ """
1021
+ # Step 1: Update metadata only (no json_string)
1022
+ if description is not None or alias is not None or visibility is not None:
1023
+ metadata_response = self._send_server_request(
1024
+ uri="api/v0/object",
1025
+ method="PATCH",
1026
+ params={"uuid": obj_uuid},
1027
+ payload={
1028
+ "description": description,
1029
+ "alias": alias,
1030
+ "json_string": None, # Don't send content to traditional endpoint
1031
+ "visibility": visibility,
1032
+ },
1033
+ )
1034
+ self._resolve_server_response(metadata_response)
1035
+
1036
+ # Step 2: Get signed upload URL for content update
1037
+ upload_url_response = self._send_server_request(
1038
+ uri="api/v0/object/upload-url",
1039
+ method="POST",
1040
+ payload={"object_uuid": str(obj_uuid)},
1041
+ )
1042
+ self._resolve_server_response(upload_url_response)
1043
+ upload_data = upload_url_response.json()
1044
+
1045
+ # Step 3: Upload the object content to GCS
1046
+ signed_url = upload_data.get("signed_url")
1047
+ if not signed_url:
1048
+ raise CoopServerResponseError("Failed to get signed upload URL")
1049
+
1050
+ json_content = json.dumps(
1051
+ value.to_dict(),
1052
+ default=self._json_handle_none,
1053
+ allow_nan=False,
1054
+ )
1055
+
1056
+ # Upload to GCS using signed URL
1057
+ gcs_response = requests.put(
1058
+ signed_url,
1059
+ data=json_content,
1060
+ headers={"Content-Type": "application/json"},
1061
+ )
1062
+
1063
+ if gcs_response.status_code != 200:
1064
+ raise CoopServerResponseError(
1065
+ f"Failed to upload object to GCS: {gcs_response.status_code}"
1066
+ )
1067
+
1068
+ # Step 4: Confirm upload and trigger queue worker processing
1069
+ confirm_response = self._send_server_request(
1070
+ uri="api/v0/object/confirm-upload",
1071
+ method="POST",
1072
+ payload={"object_uuid": str(obj_uuid)},
1073
+ )
1074
+ self._resolve_server_response(confirm_response)
1075
+ confirm_data = confirm_response.json()
1076
+
1077
+ return {
1078
+ "status": "success",
1079
+ "message": "Object updated successfully (new format - uploaded to GCS and processing triggered)",
1080
+ "object_uuid": str(obj_uuid),
1081
+ "processing_started": confirm_data.get("processing_started", False),
1082
+ }
1083
+
947
1084
  ################
948
1085
  # Remote Cache
949
1086
  ################
@@ -1025,6 +1162,115 @@ class Coop(CoopFunctionsMixin):
1025
1162
  is handled by Expected Parrot's infrastructure, and you can check the status
1026
1163
  and retrieve results later.
1027
1164
 
1165
+ Parameters:
1166
+ job (Jobs): The EDSL job to run in the cloud
1167
+ description (str, optional): A human-readable description of the job
1168
+ status (RemoteJobStatus): Initial status, should be "queued" for normal use
1169
+ Possible values: "queued", "running", "completed", "failed"
1170
+ visibility (VisibilityType): Access level for the job information. One of:
1171
+ - "private": Only accessible by the owner
1172
+ - "public": Accessible by anyone
1173
+ - "unlisted": Accessible with the link, but not listed publicly
1174
+ initial_results_visibility (VisibilityType): Access level for the job results
1175
+ iterations (int): Number of times to run each interview (default: 1)
1176
+ fresh (bool): If True, ignore existing cache entries and generate new results
1177
+
1178
+ Returns:
1179
+ RemoteInferenceCreationInfo: Information about the created job including:
1180
+ - uuid: The unique identifier for the job
1181
+ - description: The job description
1182
+ - status: Current status of the job
1183
+ - iterations: Number of iterations for each interview
1184
+ - visibility: Access level for the job
1185
+ - version: EDSL version used to create the job
1186
+
1187
+ Raises:
1188
+ CoopServerResponseError: If there's an error communicating with the server
1189
+
1190
+ Notes:
1191
+ - Remote jobs run asynchronously and may take time to complete
1192
+ - Use remote_inference_get() with the returned UUID to check status
1193
+ - Credits are consumed based on the complexity of the job
1194
+
1195
+ Example:
1196
+ >>> from edsl.jobs import Jobs
1197
+ >>> job = Jobs.example()
1198
+ >>> job_info = coop.remote_inference_create(job=job, description="My job")
1199
+ >>> print(f"Job created with UUID: {job_info['uuid']}")
1200
+ """
1201
+ response = self._send_server_request(
1202
+ uri="api/v0/new-remote-inference",
1203
+ method="POST",
1204
+ payload={
1205
+ "json_string": "offloaded",
1206
+ "description": description,
1207
+ "status": status,
1208
+ "iterations": iterations,
1209
+ "visibility": visibility,
1210
+ "version": self._edsl_version,
1211
+ "initial_results_visibility": initial_results_visibility,
1212
+ "fresh": fresh,
1213
+ },
1214
+ )
1215
+ self._resolve_server_response(response)
1216
+ response_json = response.json()
1217
+ upload_signed_url = response_json.get("upload_signed_url")
1218
+ if not upload_signed_url:
1219
+ from .exceptions import CoopResponseError
1220
+
1221
+ raise CoopResponseError("No signed url was provided.")
1222
+
1223
+ response = requests.put(
1224
+ upload_signed_url,
1225
+ data=json.dumps(
1226
+ job.to_dict(),
1227
+ default=self._json_handle_none,
1228
+ ).encode(),
1229
+ headers={"Content-Type": "application/json"},
1230
+ )
1231
+ self._resolve_gcs_response(response)
1232
+
1233
+ job_uuid = response_json.get("job_uuid")
1234
+
1235
+ response = self._send_server_request(
1236
+ uri="api/v0/new-remote-inference/uploaded",
1237
+ method="POST",
1238
+ payload={
1239
+ "job_uuid": job_uuid,
1240
+ "message": "Job uploaded successfully",
1241
+ },
1242
+ )
1243
+ response_json = response.json()
1244
+
1245
+ return RemoteInferenceCreationInfo(
1246
+ **{
1247
+ "uuid": response_json.get("job_uuid"),
1248
+ "description": response_json.get("description", ""),
1249
+ "status": response_json.get("status"),
1250
+ "iterations": response_json.get("iterations", ""),
1251
+ "visibility": response_json.get("visibility", ""),
1252
+ "version": self._edsl_version,
1253
+ }
1254
+ )
1255
+
1256
+ def old_remote_inference_create(
1257
+ self,
1258
+ job: "Jobs",
1259
+ description: Optional[str] = None,
1260
+ status: RemoteJobStatus = "queued",
1261
+ visibility: Optional[VisibilityType] = "unlisted",
1262
+ initial_results_visibility: Optional[VisibilityType] = "unlisted",
1263
+ iterations: Optional[int] = 1,
1264
+ fresh: Optional[bool] = False,
1265
+ ) -> RemoteInferenceCreationInfo:
1266
+ """
1267
+ Create a remote inference job for execution in the Expected Parrot cloud.
1268
+
1269
+ This method sends a job to be executed in the cloud, which can be more efficient
1270
+ for large jobs or when you want to run jobs in the background. The job execution
1271
+ is handled by Expected Parrot's infrastructure, and you can check the status
1272
+ and retrieve results later.
1273
+
1028
1274
  Parameters:
1029
1275
  job (Jobs): The EDSL job to run in the cloud
1030
1276
  description (str, optional): A human-readable description of the job
@@ -1208,79 +1454,233 @@ class Coop(CoopFunctionsMixin):
1208
1454
  }
1209
1455
  )
1210
1456
 
1211
- def _validate_remote_job_status_types(
1212
- self, status: Union[RemoteJobStatus, List[RemoteJobStatus]]
1213
- ) -> List[RemoteJobStatus]:
1457
+ def new_remote_inference_get(
1458
+ self,
1459
+ job_uuid: Optional[str] = None,
1460
+ results_uuid: Optional[str] = None,
1461
+ include_json_string: Optional[bool] = False,
1462
+ ) -> RemoteInferenceResponse:
1214
1463
  """
1215
- Validate visibility types and return a list of valid types.
1464
+ Get the status and details of a remote inference job.
1216
1465
 
1217
- Args:
1218
- visibility: Single visibility type or list of visibility types to validate
1466
+ This method retrieves the current status and information about a remote job,
1467
+ including links to results if the job has completed successfully.
1468
+
1469
+ Parameters:
1470
+ job_uuid (str, optional): The UUID of the remote job to check
1471
+ results_uuid (str, optional): The UUID of the results associated with the job
1472
+ (can be used if you only have the results UUID)
1473
+ include_json_string (bool, optional): If True, include the json string for the job in the response
1219
1474
 
1220
1475
  Returns:
1221
- List of validated visibility types
1476
+ RemoteInferenceResponse: Information about the job including:
1477
+ job_uuid: The unique identifier for the job
1478
+ results_uuid: The UUID of the results
1479
+ results_url: URL to access the results
1480
+ status: Current status ("queued", "running", "completed", "failed")
1481
+ version: EDSL version used for the job
1482
+ job_json_string: The json string for the job (if include_json_string is True)
1483
+ latest_job_run_details: Metadata about the job status
1484
+ interview_details: Metadata about the job interview status (for jobs that have reached running status)
1485
+ total_interviews: The total number of interviews in the job
1486
+ completed_interviews: The number of completed interviews
1487
+ interviews_with_exceptions: The number of completed interviews that have exceptions
1488
+ exception_counters: A list of exception counts for the job
1489
+ exception_type: The type of exception
1490
+ inference_service: The inference service
1491
+ model: The model
1492
+ question_name: The name of the question
1493
+ exception_count: The number of exceptions
1494
+ failure_reason: The reason the job failed (failed jobs only)
1495
+ failure_description: The description of the failure (failed jobs only)
1496
+ error_report_uuid: The UUID of the error report (partially failed jobs only)
1497
+ cost_credits: The cost of the job run in credits
1498
+ cost_usd: The cost of the job run in USD
1499
+ expenses: The expenses incurred by the job run
1500
+ service: The service
1501
+ model: The model
1502
+ token_type: The type of token (input or output)
1503
+ price_per_million_tokens: The price per million tokens
1504
+ tokens_count: The number of tokens consumed
1505
+ cost_credits: The cost of the service/model/token type combination in credits
1506
+ cost_usd: The cost of the service/model/token type combination in USD
1222
1507
 
1223
1508
  Raises:
1224
- CoopValueError: If any visibility type is invalid
1225
- """
1226
- valid_status_types = [
1227
- "queued",
1228
- "running",
1229
- "completed",
1230
- "failed",
1231
- "cancelled",
1232
- "cancelling",
1233
- "partial_failed",
1234
- ]
1235
- if isinstance(status, list):
1236
- invalid_statuses = [s for s in status if s not in valid_status_types]
1237
- if invalid_statuses:
1238
- raise CoopValueError(
1239
- f"Invalid status type(s): {invalid_statuses}. "
1240
- f"Valid types are: {valid_status_types}"
1241
- )
1242
- return status
1243
- else:
1244
- if status not in valid_status_types:
1245
- raise CoopValueError(
1246
- f"Invalid status type: {status}. "
1247
- f"Valid types are: {valid_status_types}"
1248
- )
1249
- return [status]
1250
-
1251
- def remote_inference_list(
1252
- self,
1253
- status: Union[RemoteJobStatus, List[RemoteJobStatus], None] = None,
1254
- search_query: Union[str, None] = None,
1255
- page: int = 1,
1256
- page_size: int = 10,
1257
- sort_ascending: bool = False,
1258
- ) -> "CoopJobsObjects":
1259
- """
1260
- Retrieve jobs owned by the user.
1509
+ ValueError: If neither job_uuid nor results_uuid is provided
1510
+ CoopServerResponseError: If there's an error communicating with the server
1261
1511
 
1262
1512
  Notes:
1263
- - search_query only works with the description field.
1264
- - If sort_ascending is False, then the most recently created jobs are returned first.
1513
+ - Either job_uuid or results_uuid must be provided
1514
+ - If both are provided, job_uuid takes precedence
1515
+ - For completed jobs, you can use the results_url to view or download results
1516
+ - For failed jobs, check the latest_error_report_url for debugging information
1517
+
1518
+ Example:
1519
+ >>> job_status = coop.new_remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
1520
+ >>> print(f"Job status: {job_status['status']}")
1521
+ >>> if job_status['status'] == 'completed':
1522
+ ... print(f"Results available at: {job_status['results_url']}")
1265
1523
  """
1266
- from ..scenarios import Scenario
1524
+ if job_uuid is None and results_uuid is None:
1525
+ from .exceptions import CoopValueError
1267
1526
 
1268
- if page < 1:
1269
- raise CoopValueError("The page must be greater than or equal to 1.")
1270
- if page_size < 1:
1271
- raise CoopValueError("The page size must be greater than or equal to 1.")
1272
- if page_size > 100:
1273
- raise CoopValueError("The page size must be less than or equal to 100.")
1527
+ raise CoopValueError("Either job_uuid or results_uuid must be provided.")
1528
+ elif job_uuid is not None:
1529
+ params = {"job_uuid": job_uuid}
1530
+ else:
1531
+ params = {"results_uuid": results_uuid}
1532
+ if include_json_string:
1533
+ params["include_json_string"] = include_json_string
1274
1534
 
1275
- params = {
1276
- "page": page,
1277
- "page_size": page_size,
1278
- "sort_ascending": sort_ascending,
1279
- }
1280
- if status:
1281
- params["status"] = self._validate_remote_job_status_types(status)
1282
- if search_query:
1283
- params["search_query"] = search_query
1535
+ response = self._send_server_request(
1536
+ uri="api/v0/remote-inference",
1537
+ method="GET",
1538
+ params=params,
1539
+ )
1540
+ self._resolve_server_response(response)
1541
+ data = response.json()
1542
+
1543
+ results_uuid = data.get("results_uuid")
1544
+
1545
+ if results_uuid is None:
1546
+ results_url = None
1547
+ else:
1548
+ results_url = f"{self.url}/content/{results_uuid}"
1549
+
1550
+ latest_job_run_details = data.get("latest_job_run_details", {})
1551
+ if data.get("status") == "partial_failed":
1552
+ latest_error_report_uuid = latest_job_run_details.get("error_report_uuid")
1553
+ if latest_error_report_uuid is None:
1554
+ latest_job_run_details["error_report_url"] = None
1555
+ else:
1556
+ latest_error_report_url = (
1557
+ f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
1558
+ )
1559
+ latest_job_run_details["error_report_url"] = latest_error_report_url
1560
+
1561
+ json_string = data.get("job_json_string")
1562
+
1563
+ # The job has been offloaded to GCS
1564
+ if include_json_string and json_string == "offloaded":
1565
+
1566
+ # Attempt to fetch JSON string from GCS
1567
+ response = self._send_server_request(
1568
+ uri="api/v0/remote-inference/pull",
1569
+ method="POST",
1570
+ payload={"job_uuid": job_uuid},
1571
+ )
1572
+ # Handle any errors in the response
1573
+ self._resolve_server_response(response)
1574
+ if "signed_url" not in response.json():
1575
+ from .exceptions import CoopResponseError
1576
+
1577
+ raise CoopResponseError("No signed url was provided.")
1578
+ signed_url = response.json().get("signed_url")
1579
+
1580
+ if signed_url == "": # The job is in legacy format
1581
+ job_json = json_string
1582
+
1583
+ try:
1584
+ response = requests.get(signed_url)
1585
+ self._resolve_gcs_response(response)
1586
+ job_json = json.dumps(response.json())
1587
+ except Exception:
1588
+ job_json = json_string
1589
+
1590
+ # If the job is in legacy format, we should already have the JSON string
1591
+ # from the first API call
1592
+ elif include_json_string and not json_string == "offloaded":
1593
+ job_json = json_string
1594
+
1595
+ # If include_json_string is False, we don't need the JSON string at all
1596
+ else:
1597
+ job_json = None
1598
+
1599
+ return RemoteInferenceResponse(
1600
+ **{
1601
+ "job_uuid": data.get("job_uuid"),
1602
+ "results_uuid": results_uuid,
1603
+ "results_url": results_url,
1604
+ "status": data.get("status"),
1605
+ "version": data.get("version"),
1606
+ "job_json_string": job_json,
1607
+ "latest_job_run_details": latest_job_run_details,
1608
+ }
1609
+ )
1610
+
1611
+ def _validate_remote_job_status_types(
1612
+ self, status: Union[RemoteJobStatus, List[RemoteJobStatus]]
1613
+ ) -> List[RemoteJobStatus]:
1614
+ """
1615
+ Validate visibility types and return a list of valid types.
1616
+
1617
+ Args:
1618
+ visibility: Single visibility type or list of visibility types to validate
1619
+
1620
+ Returns:
1621
+ List of validated visibility types
1622
+
1623
+ Raises:
1624
+ CoopValueError: If any visibility type is invalid
1625
+ """
1626
+ valid_status_types = [
1627
+ "queued",
1628
+ "running",
1629
+ "completed",
1630
+ "failed",
1631
+ "cancelled",
1632
+ "cancelling",
1633
+ "partial_failed",
1634
+ ]
1635
+ if isinstance(status, list):
1636
+ invalid_statuses = [s for s in status if s not in valid_status_types]
1637
+ if invalid_statuses:
1638
+ raise CoopValueError(
1639
+ f"Invalid status type(s): {invalid_statuses}. "
1640
+ f"Valid types are: {valid_status_types}"
1641
+ )
1642
+ return status
1643
+ else:
1644
+ if status not in valid_status_types:
1645
+ raise CoopValueError(
1646
+ f"Invalid status type: {status}. "
1647
+ f"Valid types are: {valid_status_types}"
1648
+ )
1649
+ return [status]
1650
+
1651
+ def remote_inference_list(
1652
+ self,
1653
+ status: Union[RemoteJobStatus, List[RemoteJobStatus], None] = None,
1654
+ search_query: Union[str, None] = None,
1655
+ page: int = 1,
1656
+ page_size: int = 10,
1657
+ sort_ascending: bool = False,
1658
+ ) -> "CoopJobsObjects":
1659
+ """
1660
+ Retrieve jobs owned by the user.
1661
+
1662
+ Notes:
1663
+ - search_query only works with the description field.
1664
+ - If sort_ascending is False, then the most recently created jobs are returned first.
1665
+ """
1666
+ from ..scenarios import Scenario
1667
+
1668
+ if page < 1:
1669
+ raise CoopValueError("The page must be greater than or equal to 1.")
1670
+ if page_size < 1:
1671
+ raise CoopValueError("The page size must be greater than or equal to 1.")
1672
+ if page_size > 100:
1673
+ raise CoopValueError("The page size must be less than or equal to 100.")
1674
+
1675
+ params = {
1676
+ "page": page,
1677
+ "page_size": page_size,
1678
+ "sort_ascending": sort_ascending,
1679
+ }
1680
+ if status:
1681
+ params["status"] = self._validate_remote_job_status_types(status)
1682
+ if search_query:
1683
+ params["search_query"] = search_query
1284
1684
 
1285
1685
  response = self._send_server_request(
1286
1686
  uri="api/v0/remote-inference/list",
@@ -1368,25 +1768,57 @@ class Coop(CoopFunctionsMixin):
1368
1768
  def create_project(
1369
1769
  self,
1370
1770
  survey: "Survey",
1771
+ scenario_list: Optional["ScenarioList"] = None,
1772
+ scenario_list_method: Optional[
1773
+ Literal["randomize", "loop", "single_scenario"]
1774
+ ] = None,
1371
1775
  project_name: str = "Project",
1372
1776
  survey_description: Optional[str] = None,
1373
1777
  survey_alias: Optional[str] = None,
1374
1778
  survey_visibility: Optional[VisibilityType] = "unlisted",
1779
+ scenario_list_description: Optional[str] = None,
1780
+ scenario_list_alias: Optional[str] = None,
1781
+ scenario_list_visibility: Optional[VisibilityType] = "unlisted",
1375
1782
  ):
1376
1783
  """
1377
1784
  Create a survey object on Coop, then create a project from the survey.
1378
1785
  """
1379
- survey_details = self.create(
1786
+ if scenario_list is None and scenario_list_method is not None:
1787
+ raise CoopValueError(
1788
+ "You must specify both a scenario list and a scenario list method to use scenarios with your survey."
1789
+ )
1790
+ elif scenario_list is not None and scenario_list_method is None:
1791
+ raise CoopValueError(
1792
+ "You must specify both a scenario list and a scenario list method to use scenarios with your survey."
1793
+ )
1794
+ survey_details = self.push(
1380
1795
  object=survey,
1381
1796
  description=survey_description,
1382
1797
  alias=survey_alias,
1383
1798
  visibility=survey_visibility,
1384
1799
  )
1385
1800
  survey_uuid = survey_details.get("uuid")
1801
+ if scenario_list is not None:
1802
+ scenario_list_details = self.push(
1803
+ object=scenario_list,
1804
+ description=scenario_list_description,
1805
+ alias=scenario_list_alias,
1806
+ visibility=scenario_list_visibility,
1807
+ )
1808
+ scenario_list_uuid = scenario_list_details.get("uuid")
1809
+ else:
1810
+ scenario_list_uuid = None
1386
1811
  response = self._send_server_request(
1387
1812
  uri="api/v0/projects/create-from-survey",
1388
1813
  method="POST",
1389
- payload={"project_name": project_name, "survey_uuid": str(survey_uuid)},
1814
+ payload={
1815
+ "project_name": project_name,
1816
+ "survey_uuid": str(survey_uuid),
1817
+ "scenario_list_uuid": (
1818
+ str(scenario_list_uuid) if scenario_list_uuid is not None else None
1819
+ ),
1820
+ "scenario_list_method": scenario_list_method,
1821
+ },
1390
1822
  )
1391
1823
  self._resolve_server_response(response)
1392
1824
  response_json = response.json()
@@ -1413,14 +1845,26 @@ class Coop(CoopFunctionsMixin):
1413
1845
  return {
1414
1846
  "project_name": response_json.get("project_name"),
1415
1847
  "project_job_uuids": response_json.get("job_uuids"),
1848
+ "project_prolific_studies": [
1849
+ {
1850
+ "study_id": study.get("id"),
1851
+ "name": study.get("name"),
1852
+ "status": study.get("status"),
1853
+ "num_participants": study.get("total_available_places"),
1854
+ "places_taken": study.get("places_taken"),
1855
+ }
1856
+ for study in response_json.get("prolific_studies", [])
1857
+ ],
1416
1858
  }
1417
1859
 
1418
- def get_project_human_responses(
1860
+ def _turn_human_responses_into_results(
1419
1861
  self,
1420
- project_uuid: str,
1862
+ human_responses: List[dict],
1863
+ survey_json_string: str,
1864
+ scenario_list_json_string: Optional[str] = None,
1421
1865
  ) -> Union["Results", "ScenarioList"]:
1422
1866
  """
1423
- Return a Results object with the human responses for a project.
1867
+ Turn a list of human responses into a Results object.
1424
1868
 
1425
1869
  If generating the Results object fails, a ScenarioList will be returned instead.
1426
1870
  """
@@ -1430,16 +1874,19 @@ class Coop(CoopFunctionsMixin):
1430
1874
  from ..scenarios import Scenario, ScenarioList
1431
1875
  from ..surveys import Survey
1432
1876
 
1433
- response = self._send_server_request(
1434
- uri=f"api/v0/projects/{project_uuid}/human-responses",
1435
- method="GET",
1436
- )
1437
- self._resolve_server_response(response)
1438
- response_json = response.json()
1439
- human_responses = response_json.get("human_responses", [])
1440
-
1441
1877
  try:
1442
- agent_list = AgentList()
1878
+ survey = Survey.from_dict(json.loads(survey_json_string))
1879
+
1880
+ model = Model("test")
1881
+
1882
+ if scenario_list_json_string is not None:
1883
+ scenario_list = ScenarioList.from_dict(
1884
+ json.loads(scenario_list_json_string)
1885
+ )
1886
+ else:
1887
+ scenario_list = ScenarioList()
1888
+
1889
+ results = None
1443
1890
 
1444
1891
  for response in human_responses:
1445
1892
  response_uuid = response.get("response_uuid")
@@ -1449,8 +1896,14 @@ class Coop(CoopFunctionsMixin):
1449
1896
  )
1450
1897
 
1451
1898
  response_dict = json.loads(response.get("response_json_string"))
1899
+ agent_traits_json_string = response.get("agent_traits_json_string")
1900
+ scenario_uuid = response.get("scenario_uuid")
1901
+ if agent_traits_json_string is not None:
1902
+ agent_traits = json.loads(agent_traits_json_string)
1903
+ else:
1904
+ agent_traits = {}
1452
1905
 
1453
- a = Agent(name=response_uuid, instruction="")
1906
+ a = Agent(name=response_uuid, instruction="", traits=agent_traits)
1454
1907
 
1455
1908
  def create_answer_function(response_data):
1456
1909
  def f(self, question, scenario):
@@ -1458,27 +1911,38 @@ class Coop(CoopFunctionsMixin):
1458
1911
 
1459
1912
  return f
1460
1913
 
1914
+ scenario = None
1915
+ if scenario_uuid is not None:
1916
+ for s in scenario_list:
1917
+ if s.get("uuid") == scenario_uuid:
1918
+ scenario = s
1919
+ break
1920
+
1921
+ if scenario is None:
1922
+ raise RuntimeError("Scenario not found.")
1923
+
1461
1924
  a.add_direct_question_answering_method(
1462
1925
  create_answer_function(response_dict)
1463
1926
  )
1464
- agent_list.append(a)
1465
1927
 
1466
- survey_json_string = response_json.get("survey_json_string")
1467
- survey = Survey.from_dict(json.loads(survey_json_string))
1928
+ job = survey.by(a).by(model)
1468
1929
 
1469
- model = Model("test")
1470
- results = (
1471
- survey.by(agent_list)
1472
- .by(model)
1473
- .run(
1930
+ if scenario is not None:
1931
+ job = job.by(scenario)
1932
+
1933
+ question_results = job.run(
1474
1934
  cache=Cache(),
1475
1935
  disable_remote_cache=True,
1476
1936
  disable_remote_inference=True,
1477
1937
  print_exceptions=False,
1478
1938
  )
1479
- )
1939
+
1940
+ if results is None:
1941
+ results = question_results
1942
+ else:
1943
+ results = results + question_results
1480
1944
  return results
1481
- except Exception:
1945
+ except Exception as e:
1482
1946
  human_response_scenarios = []
1483
1947
  for response in human_responses:
1484
1948
  response_uuid = response.get("response_uuid")
@@ -1493,6 +1957,427 @@ class Coop(CoopFunctionsMixin):
1493
1957
  human_response_scenarios.append(scenario)
1494
1958
  return ScenarioList(human_response_scenarios)
1495
1959
 
1960
def get_project_human_responses(
    self,
    project_uuid: str,
) -> Union["Results", "ScenarioList"]:
    """
    Fetch every human response collected for a project and convert it to Results.

    Parameters:
        project_uuid (str): UUID of the project whose responses are fetched.

    Returns:
        A Results object built from the responses. If generating the Results
        object fails, a ScenarioList of the raw responses is returned instead.
    """
    server_response = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/human-responses",
        method="GET",
    )
    self._resolve_server_response(server_response)
    payload = server_response.json()

    # The survey and scenario list travel as JSON strings and are decoded by
    # the conversion helper.
    return self._turn_human_responses_into_results(
        payload.get("human_responses", []),
        payload.get("survey_json_string"),
        payload.get("scenario_list_json_string"),
    )
1982
+
1983
def list_prolific_filters(self) -> "CoopProlificFilters":
    """
    Return the supported Prolific filters, one Scenario per filter.

    The returned CoopProlificFilters object has several methods that you can
    use to create valid filter dicts for use with Coop.create_prolific_study().

    Each Scenario has the keys ``filter_id``, ``title``, ``question``,
    ``type``, ``range_filter_min``/``range_filter_max`` (set only for
    ``"range"`` filters), and ``select_filter_num_options``/
    ``select_filter_options`` (set only for ``"select"`` filters).

    Call find() to examine a specific filter by ID:
    >>> filters = coop.list_prolific_filters()
    >>> filters.find("age")
    Scenario(
        {
            "filter_id": "age",
            "type": "range",
            "range_filter_min": 18,
            "range_filter_max": 100,
            ...
        }
    )

    Call create_study_filter() to create a valid filter dict:
    >>> filters.create_study_filter("age", min=30, max=40)
    {
        "filter_id": "age",
        "selected_range": {
            "lower": 30,
            "upper": 40,
        },
    }
    """
    from ..scenarios import Scenario

    response = self._send_server_request(
        uri="api/v0/prolific-filters",
        method="GET",
    )
    self._resolve_server_response(response)
    response_json = response.json()
    # `or []` also covers an explicit null in the payload.
    filters = response_json.get("prolific_filters") or []

    filter_scenarios = []
    # Named `prolific_filter` to avoid shadowing the builtin `filter`.
    for prolific_filter in filters:
        filter_type = prolific_filter.get("type")
        question = prolific_filter.get("question")
        is_range = filter_type == "range"
        is_select = filter_type == "select"
        choices = prolific_filter.get("choices") if is_select else None
        scenario = Scenario(
            {
                "filter_id": prolific_filter.get("filter_id"),
                "title": prolific_filter.get("title"),
                "question": (
                    f"Participants were asked the following: {question}"
                    if question
                    else None
                ),
                "type": filter_type,
                "range_filter_min": prolific_filter.get("min") if is_range else None,
                "range_filter_max": prolific_filter.get("max") if is_range else None,
                # `choices` may be missing or null; count it as zero options
                # instead of raising TypeError from len(None).
                "select_filter_num_options": (
                    len(choices or []) if is_select else None
                ),
                "select_filter_options": choices,
            }
        )
        filter_scenarios.append(scenario)
    return CoopProlificFilters(filter_scenarios)
2052
+
2053
@staticmethod
def _validate_prolific_study_cost(
    estimated_completion_time_minutes: int,
    participant_payment_cents: int,
    min_usd_per_hour: float = 8.0,
) -> tuple[bool, float]:
    """
    Check whether a Prolific study pays participants below an hourly minimum.

    Parameters:
        estimated_completion_time_minutes (int): Expected completion time in minutes.
        participant_payment_cents (int): Payment per participant in US cents.
        min_usd_per_hour (float): Minimum hourly pay in USD. Defaults to 8.0,
            the minimum amount for using Prolific.

    Returns:
        tuple[bool, float]: ``(is_underpayment, cost_usd_per_hour)``. The first
        value is True when the hourly pay is below ``min_usd_per_hour``. The
        second value is the pay in USD per hour.

    Raises:
        ValueError: If ``estimated_completion_time_minutes`` is not positive,
            because no hourly rate can be computed. Previously a zero time
            raised a bare ZeroDivisionError.
    """
    if estimated_completion_time_minutes <= 0:
        raise ValueError(
            "estimated_completion_time_minutes must be a positive number of minutes."
        )
    hours = estimated_completion_time_minutes / 60
    payment_usd = participant_payment_cents / 100
    cost_usd_per_hour = payment_usd / hours
    return cost_usd_per_hour < min_usd_per_hour, cost_usd_per_hour
2071
+
2072
def create_prolific_study(
    self,
    project_uuid: str,
    name: str,
    description: str,
    num_participants: int,
    estimated_completion_time_minutes: int,
    participant_payment_cents: int,
    device_compatibility: Optional[
        List[Literal["desktop", "tablet", "mobile"]]
    ] = None,
    peripheral_requirements: Optional[
        List[Literal["audio", "camera", "download", "microphone"]]
    ] = None,
    filters: Optional[List[Dict]] = None,
) -> dict:
    """
    Create a Prolific study attached to a project.

    To add filters to your study, first pull the list of supported filters
    with Coop.list_prolific_filters(). Then use the create_study_filter
    method of the returned CoopProlificFilters object to build valid
    filter dicts.

    Parameters:
        project_uuid (str): UUID of the project the study belongs to.
        name (str): Study name.
        description (str): Study description.
        num_participants (int): Number of places available in the study.
        estimated_completion_time_minutes (int): Expected completion time, in minutes.
        participant_payment_cents (int): Reward per participant, in US cents.
        device_compatibility (list, optional): Allowed devices. All of
            "desktop", "tablet" and "mobile" are allowed when omitted.
        peripheral_requirements (list, optional): Required peripherals.
            None are required when omitted.
        filters (list of dict, optional): Participant filters. Empty when omitted.

    Returns:
        dict: The created study's details (IDs, URLs, settings and total cost).

    Raises:
        CoopValueError: If the implied hourly pay is below $8.00 USD.
    """
    is_underpayment, cost_usd_per_hour = self._validate_prolific_study_cost(
        estimated_completion_time_minutes, participant_payment_cents
    )
    if is_underpayment:
        raise CoopValueError(
            f"The current participant payment of ${cost_usd_per_hour:.2f} USD per hour is below the minimum payment for using Prolific ($8.00 USD per hour)."
        )

    # Fill in defaults only for arguments that were not supplied at all.
    if device_compatibility is None:
        device_compatibility = ["desktop", "tablet", "mobile"]
    if peripheral_requirements is None:
        peripheral_requirements = []
    if filters is None:
        filters = []

    study_payload = {
        "name": name,
        "description": description,
        "total_available_places": num_participants,
        "estimated_completion_time": estimated_completion_time_minutes,
        "reward": participant_payment_cents,
        "device_compatibility": device_compatibility,
        "peripheral_requirements": peripheral_requirements,
        "filters": filters,
    }
    server_response = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/prolific-studies",
        method="POST",
        payload=study_payload,
    )
    self._resolve_server_response(server_response)
    study_json = server_response.json()

    # (key in the returned dict, key in the server response)
    field_map = (
        ("study_id", "study_id"),
        ("status", "status"),
        ("admin_url", "admin_url"),
        ("respondent_url", "respondent_url"),
        ("name", "name"),
        ("description", "description"),
        ("num_participants", "total_available_places"),
        ("estimated_completion_time_minutes", "estimated_completion_time"),
        ("participant_payment_cents", "reward"),
        ("total_cost_cents", "total_cost"),
        ("device_compatibility", "device_compatibility"),
        ("peripheral_requirements", "peripheral_requirements"),
        ("filters", "filters"),
    )
    return {out_key: study_json.get(server_key) for out_key, server_key in field_map}
2143
+
2144
def update_prolific_study(
    self,
    project_uuid: str,
    study_id: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    num_participants: Optional[int] = None,
    estimated_completion_time_minutes: Optional[int] = None,
    participant_payment_cents: Optional[int] = None,
    device_compatibility: Optional[
        List[Literal["desktop", "tablet", "mobile"]]
    ] = None,
    peripheral_requirements: Optional[
        List[Literal["audio", "camera", "download", "microphone"]]
    ] = None,
    filters: Optional[List[Dict]] = None,
) -> dict:
    """
    Update a Prolific study. Returns a dict with the study details.

    Only arguments that are not None are sent to the server.

    When the completion time or the payment changes, the resulting hourly pay
    is validated locally first. The study's current values fill in whichever
    of the two is not being changed.

    Parameters:
        project_uuid (str): UUID of the project the study belongs to.
        study_id (str): ID of the Prolific study to update.
        name, description, num_participants, estimated_completion_time_minutes,
        participant_payment_cents, device_compatibility, peripheral_requirements,
        filters: New values. Pass None to leave a field unchanged.

    Raises:
        CoopValueError: If the update would pay below $8.00 USD per hour.
    """
    if (
        estimated_completion_time_minutes is not None
        or participant_payment_cents is not None
    ):
        study = self.get_prolific_study(project_uuid, study_id)
        # Compare against None rather than truthiness. An explicit 0 must be
        # validated (and rejected), not replaced by the current value while
        # still being sent to the server below.
        updated_completion_time = (
            study.get("estimated_completion_time_minutes")
            if estimated_completion_time_minutes is None
            else estimated_completion_time_minutes
        )
        updated_payment = (
            study.get("participant_payment_cents")
            if participant_payment_cents is None
            else participant_payment_cents
        )

        is_underpayment, cost_usd_per_hour = self._validate_prolific_study_cost(
            updated_completion_time, updated_payment
        )
        if is_underpayment:
            raise CoopValueError(
                f"This update would result in a participant payment of ${cost_usd_per_hour:.2f} USD per hour, which is below the minimum payment for using Prolific ($8.00 USD per hour)."
            )

    candidate_fields = {
        "name": name,
        "description": description,
        "total_available_places": num_participants,
        "estimated_completion_time": estimated_completion_time_minutes,
        "reward": participant_payment_cents,
        "device_compatibility": device_compatibility,
        "peripheral_requirements": peripheral_requirements,
        "filters": filters,
    }
    payload = {
        key: value for key, value in candidate_fields.items() if value is not None
    }

    response = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}",
        method="PATCH",
        payload=payload,
    )
    self._resolve_server_response(response)
    response_json = response.json()
    return {
        "study_id": response_json.get("study_id"),
        "status": response_json.get("status"),
        "admin_url": response_json.get("admin_url"),
        "respondent_url": response_json.get("respondent_url"),
        "name": response_json.get("name"),
        "description": response_json.get("description"),
        "num_participants": response_json.get("total_available_places"),
        "estimated_completion_time_minutes": response_json.get(
            "estimated_completion_time"
        ),
        "participant_payment_cents": response_json.get("reward"),
        "total_cost_cents": response_json.get("total_cost"),
        "device_compatibility": response_json.get("device_compatibility"),
        "peripheral_requirements": response_json.get("peripheral_requirements"),
        "filters": response_json.get("filters"),
    }
2224
+
2225
def publish_prolific_study(
    self,
    project_uuid: str,
    study_id: str,
) -> dict:
    """
    Publish a Prolific study.

    Parameters:
        project_uuid (str): UUID of the project the study belongs to.
        study_id (str): ID of the Prolific study to publish.

    Returns:
        dict: The server's JSON response.
    """
    publish_uri = (
        f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/publish"
    )
    publish_response = self._send_server_request(uri=publish_uri, method="POST")
    self._resolve_server_response(publish_response)
    return publish_response.json()
2239
+
2240
def get_prolific_study(self, project_uuid: str, study_id: str) -> dict:
    """
    Fetch a Prolific study and return its details as a dict.

    Server field names are translated to the client-side names used across the
    Prolific helpers. For example, ``total_available_places`` becomes
    ``num_participants`` and ``reward`` becomes ``participant_payment_cents``.
    """
    study_response = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}",
        method="GET",
    )
    self._resolve_server_response(study_response)
    study_json = study_response.json()

    # (key in the returned dict, key in the server response)
    field_map = (
        ("study_id", "study_id"),
        ("status", "status"),
        ("admin_url", "admin_url"),
        ("respondent_url", "respondent_url"),
        ("name", "name"),
        ("description", "description"),
        ("num_participants", "total_available_places"),
        ("estimated_completion_time_minutes", "estimated_completion_time"),
        ("participant_payment_cents", "reward"),
        ("total_cost_cents", "total_cost"),
        ("device_compatibility", "device_compatibility"),
        ("peripheral_requirements", "peripheral_requirements"),
        ("filters", "filters"),
    )
    return {out_key: study_json.get(server_key) for out_key, server_key in field_map}
2267
+
2268
def get_prolific_study_responses(
    self,
    project_uuid: str,
    study_id: str,
) -> Union["Results", "ScenarioList"]:
    """
    Return a Results object with the human responses for a Prolific study.

    If generating the Results object fails, a ScenarioList will be returned instead.

    Parameters:
        project_uuid (str): UUID of the project the study belongs to.
        study_id (str): ID of the Prolific study whose responses are fetched.
    """
    response = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/responses",
        method="GET",
    )
    self._resolve_server_response(response)
    response_json = response.json()
    human_responses = response_json.get("human_responses", [])
    survey_json_string = response_json.get("survey_json_string")
    # Forward the scenario list the same way get_project_human_responses does.
    # Presumably it is used to resolve each response's scenario_uuid. It is
    # None when the server omits it.
    scenario_list_json_string = response_json.get("scenario_list_json_string")

    return self._turn_human_responses_into_results(
        human_responses, survey_json_string, scenario_list_json_string
    )
2290
+
2291
def delete_prolific_study(
    self,
    project_uuid: str,
    study_id: str,
) -> dict:
    """
    Delete a Prolific study and return the server's JSON response.

    Note: Only draft studies can be deleted. Once you publish a study, it cannot be deleted.
    """
    study_uri = f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}"
    delete_response = self._send_server_request(uri=study_uri, method="DELETE")
    self._resolve_server_response(delete_response)
    return delete_response.json()
2307
+
2308
def approve_prolific_study_submission(
    self,
    project_uuid: str,
    study_id: str,
    submission_id: str,
) -> dict:
    """
    Approve one submission to a Prolific study.

    Parameters:
        project_uuid (str): UUID of the project the study belongs to.
        study_id (str): ID of the Prolific study.
        submission_id (str): ID of the submission to approve.

    Returns:
        dict: The server's JSON response.
    """
    approve_uri = (
        f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}"
        f"/submissions/{submission_id}/approve"
    )
    approve_response = self._send_server_request(uri=approve_uri, method="POST")
    self._resolve_server_response(approve_response)
    return approve_response.json()
2323
+
2324
def reject_prolific_study_submission(
    self,
    project_uuid: str,
    study_id: str,
    submission_id: str,
    reason: Literal[
        "TOO_QUICKLY",
        "TOO_SLOWLY",
        "FAILED_INSTRUCTIONS",
        "INCOMP_LONGITUDINAL",
        "FAILED_CHECK",
        "LOW_EFFORT",
        "MALINGERING",
        "NO_CODE",
        "BAD_CODE",
        "NO_DATA",
        "UNSUPP_DEVICE",
        "OTHER",
    ],
    explanation: str,
) -> dict:
    """
    Reject one submission to a Prolific study.

    Parameters:
        project_uuid (str): UUID of the project the study belongs to.
        study_id (str): ID of the Prolific study.
        submission_id (str): ID of the submission to reject.
        reason (str): One of the rejection codes listed in the signature.
        explanation (str): Justification for the rejection. Must be at least
            100 characters long.

    Returns:
        dict: The server's JSON response.

    Raises:
        CoopValueError: If ``reason`` is not a known code or ``explanation``
            is shorter than 100 characters. Both checks run before any request
            is sent.
    """
    # Runtime mirror of the Literal above. Kept as a list because it is
    # interpolated into the error message.
    valid_rejection_reasons = [
        "TOO_QUICKLY", "TOO_SLOWLY", "FAILED_INSTRUCTIONS",
        "INCOMP_LONGITUDINAL", "FAILED_CHECK", "LOW_EFFORT",
        "MALINGERING", "NO_CODE", "BAD_CODE",
        "NO_DATA", "UNSUPP_DEVICE", "OTHER",
    ]
    if reason not in valid_rejection_reasons:
        raise CoopValueError(
            f"Invalid rejection reason. Please use one of the following: {valid_rejection_reasons}."
        )
    if len(explanation) < 100:
        raise CoopValueError("Rejection explanation must be at least 100 characters.")

    reject_uri = (
        f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}"
        f"/submissions/{submission_id}/reject"
    )
    reject_response = self._send_server_request(
        uri=reject_uri,
        method="POST",
        payload={"reason": reason, "explanation": explanation},
    )
    self._resolve_server_response(reject_response)
    return reject_response.json()
2380
+
1496
2381
  def __repr__(self):
1497
2382
  """Return a string representation of the client."""
1498
2383
  return f"Client(api_key='{self.api_key}', url='{self.url}')"
@@ -1686,6 +2571,235 @@ class Coop(CoopFunctionsMixin):
1686
2571
  self._resolve_server_response(response)
1687
2572
  return response.json().get("uuid")
1688
2573
 
2574
def pull(
    self,
    url_or_uuid: Optional[Union[str, UUID]] = None,
    expected_object_type: Optional[ObjectType] = None,
) -> dict:
    """
    Pull an object, downloading its content directly from Google Cloud Storage.

    The server is asked for a signed download URL, and the object's JSON is
    fetched from it, which is more efficient for large files. Coop.get is used
    instead in these cases:
    - the object is stored in the old format;
    - the object's type cannot be determined;
    - the direct download fails.

    Parameters:
        url_or_uuid (Union[str, UUID], optional): Identifier for the object to retrieve.
            Can be one of:
            - UUID string (e.g., "123e4567-e89b-12d3-a456-426614174000")
            - Full URL (e.g., "https://expectedparrot.com/content/123e4567...")
            - Alias URL (e.g., "https://expectedparrot.com/content/username/my-survey")
        expected_object_type (ObjectType, optional): Type of the object
            (e.g., "survey", "agent"). It selects the EDSL class to
            deserialize into. For alias URLs it is also checked against the
            type reported by the server, and when omitted that reported type
            is used.

    Returns:
        The deserialized EDSL object, or whatever Coop.get returns on the
        fallback paths.

    Raises:
        CoopObjectTypeError: If an alias resolves to an object whose type differs
            from ``expected_object_type``.
        CoopResponseError: If the server does not provide a signed URL.
        CoopServerResponseError: If there's an error communicating with the server.
        Errors raised while resolving ``url_or_uuid`` (e.g. a missing or
        malformed identifier) propagate from ``_resolve_uuid_or_alias``.

    Example:
        >>> survey = coop.pull(
        ...     "123e4567-e89b-12d3-a456-426614174000", expected_object_type="survey"
        ... )
        >>> survey = coop.pull("https://expectedparrot.com/content/username/my-survey")
    """
    obj_uuid, owner_username, alias = self._resolve_uuid_or_alias(url_or_uuid)
    object_type = expected_object_type

    # Handle alias-based retrieval with new/old format detection
    if not obj_uuid and owner_username and alias:
        # First, get object info to determine format, UUID and type
        info_response = self._send_server_request(
            uri="api/v0/object/alias/info",
            method="GET",
            params={"owner_username": owner_username, "alias": alias},
        )
        self._resolve_server_response(info_response)
        info_data = info_response.json()

        obj_uuid = info_data.get("uuid")
        is_new_format = info_data.get("is_new_format", False)
        actual_object_type = info_data.get("object_type")

        # Validate object type if expected
        if expected_object_type and actual_object_type != expected_object_type:
            from .exceptions import CoopObjectTypeError

            raise CoopObjectTypeError(
                f"Expected {expected_object_type=} but got {actual_object_type=}"
            )

        # Use get method for old format objects
        if not is_new_format:
            return self.get(url_or_uuid, expected_object_type)

        # The alias info reports the type, so use it when none was given.
        object_type = object_type or actual_object_type

    if object_type is None:
        # Without a type, the EDSL class to deserialize into cannot be chosen
        # here (the original code hit an UnboundLocalError). Use the same
        # get() fallback as legacy-format objects.
        return self.get(url_or_uuid, expected_object_type)

    # Send the request to the API endpoint with the resolved UUID
    response = self._send_server_request(
        uri="api/v0/object/pull",
        method="POST",
        payload={"object_uuid": obj_uuid},
    )
    # Handle any errors in the response
    self._resolve_server_response(response)
    pull_json = response.json()
    if "signed_url" not in pull_json:
        from .exceptions import CoopResponseError

        raise CoopResponseError("No signed url was provided.")
    signed_url = pull_json.get("signed_url")

    if signed_url == "":  # an empty signed URL means the object is in the old format
        return self.get(url_or_uuid, expected_object_type)

    try:
        gcs_response = requests.get(signed_url)
        self._resolve_gcs_response(gcs_response)
    except Exception:
        # The direct download is a best-effort fast path; fall back to get().
        return self.get(url_or_uuid, expected_object_type)

    object_dict = gcs_response.json()
    edsl_class = ObjectRegistry.get_edsl_class_by_object_type(object_type)
    return edsl_class.from_dict(object_dict)
2671
+
2672
def get_upload_url(self, object_uuid: str) -> dict:
    """
    Request a signed URL for replacing the content of an existing object.

    The returned URL allows a direct upload to Google Cloud Storage while the
    object keeps its UUID.

    Parameters:
        object_uuid (str): The UUID of the object to get an upload URL for.

    Returns:
        dict: The server's JSON response. Per the API it contains
        ``signed_url``, ``object_uuid`` and a ``message``.

    Raises:
        CoopServerResponseError: If there's an error communicating with the server.

    Notes (server-side behavior, per the API documentation):
        - Only works with objects stored in the new format (transition table).
        - User must be the owner of the object.
        - The signed URL expires after 60 minutes.

    Example:
        >>> response = coop.get_upload_url("123e4567-e89b-12d3-a456-426614174000")
        >>> upload_url = response['signed_url']
        >>> # Use the upload_url to PUT new content directly to GCS
    """
    request_body = {"object_uuid": object_uuid}
    upload_url_response = self._send_server_request(
        uri="api/v0/object/upload-url",
        method="POST",
        payload=request_body,
    )
    self._resolve_server_response(upload_url_response)
    return upload_url_response.json()
2709
+
2710
def push(
    self,
    object: EDSLObject,
    description: Optional[str] = None,
    alias: Optional[str] = None,
    visibility: Optional[VisibilityType] = "unlisted",
) -> dict:
    """
    Upload an EDSL object directly to Google Cloud Storage.

    The server is asked for a signed upload URL, the object's JSON is PUT to
    that URL, and the upload is then confirmed with the server. This is more
    efficient for large files.

    Parameters:
        object (EDSLObject): The EDSL object to upload.
        description (str, optional): Description stored with the object.
        alias (str, optional): Alias for the object, used to build its alias URL.
        visibility (VisibilityType, optional): Visibility of the object.
            Defaults to "unlisted".

    Returns:
        dict: ``description``, ``object_type``, ``url``, ``alias_url``,
        ``uuid``, ``version`` and ``visibility`` of the pushed object.

    Raises:
        CoopServerResponseError: If the server returns an error response.
        CoopResponseError: If the server response lacks a signed URL or an
            object UUID.

    Example:
        >>> response = coop.push(survey, description="My survey")
        >>> print(f"Object URL: {response['url']}")
    """
    object_type = ObjectRegistry.get_object_type_by_edsl_class(object)
    object_dict = object.to_dict()
    object_hash = object.get_hash() if hasattr(object, "get_hash") else None

    # Send the request to the API endpoint
    response = self._send_server_request(
        uri="api/v0/object/push",
        method="POST",
        payload={
            "object_type": object_type,
            "description": description,
            "alias": alias,
            "visibility": visibility,
            "object_hash": object_hash,
            "version": self._edsl_version,
        },
    )
    # Surface server errors (with their detail message) before reading the body,
    # like the other Coop methods do.
    self._resolve_server_response(response)
    response_json = response.json()

    signed_url = response_json.get("signed_url")
    if signed_url is None:
        from .exceptions import CoopResponseError

        raise CoopResponseError(response.text)

    object_uuid = response_json.get("object_uuid", None)
    owner_username = response_json.get("owner_username", None)
    object_alias = response_json.get("alias", None)

    # Check before uploading, so no content is written to storage that can
    # never be confirmed.
    if object_uuid is None:
        from .exceptions import CoopResponseError

        raise CoopResponseError("No object uuid was received.")

    json_data = json.dumps(
        object_dict,
        default=self._json_handle_none,
        allow_nan=False,
    )
    upload_response = requests.put(
        signed_url,
        data=json_data.encode(),
        headers={"Content-Type": "application/json"},
    )
    self._resolve_gcs_response(upload_response)

    # Confirm the upload completion
    confirm_response = self._send_server_request(
        uri="api/v0/object/confirm-upload",
        method="POST",
        payload={"object_uuid": object_uuid},
    )
    self._resolve_server_response(confirm_response)

    return {
        "description": response_json.get("description"),
        "object_type": object_type,
        "url": f"{self.url}/content/{object_uuid}",
        "alias_url": self._get_alias_url(owner_username, object_alias),
        "uuid": object_uuid,
        "version": self._edsl_version,
        "visibility": response_json.get("visibility"),
    }
2802
+
1689
2803
  def _display_login_url(
1690
2804
  self, edsl_auth_token: str, link_description: Optional[str] = None
1691
2805
  ):
@@ -1769,6 +2883,125 @@ class Coop(CoopFunctionsMixin):
1769
2883
  # Add API key to environment
1770
2884
  load_dotenv()
1771
2885
 
2886
def login_streamlit(self, timeout: int = 120):
    """
    Start the EDSL auth token login flow inside a Streamlit application.

    This helper is functionally equivalent to ``Coop.login`` but renders the
    login link and status updates directly in the Streamlit UI. Each script
    run asks the Expected Parrot server once for the API-key associated with
    the generated auth-token. While no key is available, it schedules a rerun
    so the check repeats.

    Once received, the key is stored via ``ExpectedParrotKeyHandler``,
    exported to ``os.environ`` and written to the local ``.env`` file, so
    subsequent sessions pick it up automatically.

    Parameters
    ----------
    timeout : int, default 120
        How many seconds to wait for the user to complete the login before
        giving up and showing an error in the Streamlit app. The next run after
        a timeout starts a fresh login attempt with a new auth-token.

    Returns
    -------
    str | None
        The API-key if the user logged-in successfully, otherwise ``None``
        (still waiting, or timed out).

    Raises
    ------
    ImportError
        If Streamlit is not installed.
    RuntimeError
        If called outside a running Streamlit app.
    """
    try:
        import streamlit as st
        from streamlit.runtime.scriptrunner import get_script_run_ctx
    except ModuleNotFoundError as exc:
        raise ImportError(
            "Streamlit is required for `login_streamlit`. Install it with `pip install streamlit`."
        ) from exc

    # Ensure we are actually running inside a Streamlit script. If not, give a
    # clear error message instead of crashing when a rerun is requested
    # outside the Streamlit runtime.
    if get_script_run_ctx() is None:
        raise RuntimeError(
            "`login_streamlit` must be invoked from within a running Streamlit "
            "app (use `streamlit run your_script.py`). If you need to obtain an "
            "API-key in a regular Python script or notebook, use `Coop.login()` "
            "instead."
        )

    # `time`, `os` and ExpectedParrotKeyHandler are imported at module level.
    import secrets
    from dotenv import load_dotenv
    from ..utilities.utilities import write_api_key_to_env

    # ------------------------------------------------------------------
    # 1. Prepare auth-token and store state across reruns
    # ------------------------------------------------------------------
    if "edsl_auth_token" not in st.session_state:
        st.session_state.edsl_auth_token = secrets.token_urlsafe(16)
        st.session_state.login_start_time = time.time()

    edsl_auth_token: str = st.session_state.edsl_auth_token
    login_url = (
        f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
    )

    # ------------------------------------------------------------------
    # 2. Render clickable login link (plain markdown; no raw HTML needed)
    # ------------------------------------------------------------------
    st.markdown(f"🔗 **Log in to Expected Parrot** → [click here]({login_url})")

    # ------------------------------------------------------------------
    # 3. Poll server for API-key (runs once per Streamlit execution)
    # ------------------------------------------------------------------
    api_key = self._get_api_key(edsl_auth_token)
    if api_key is None:
        elapsed = time.time() - st.session_state.login_start_time
        if elapsed > timeout:
            # Drop the expired token so the next run starts a fresh login flow.
            # Otherwise every rerun in this session would time out immediately.
            del st.session_state["edsl_auth_token"]
            del st.session_state["login_start_time"]
            st.error(
                "Timed-out waiting for login. Please rerun the app to try again."
            )
            return None

        remaining = int(timeout - elapsed)
        st.info(f"Waiting for login… ({remaining}s left)")
        # Short delay before the rerun that continues polling
        time.sleep(1)

        # Prefer `st.rerun`; `st.experimental_rerun` is its deprecated
        # predecessor and is kept only for older Streamlit releases.
        try:
            if hasattr(st, "rerun"):
                st.rerun()
            elif hasattr(st, "experimental_rerun"):
                st.experimental_rerun()
            else:
                # Fallback – advise the user to update Streamlit for automatic polling.
                st.warning(
                    "Please refresh the page to continue the login flow. "
                    "(Consider upgrading Streamlit to enable automatic refresh.)"
                )
        except Exception:
            # Unexpected errors while requesting a rerun are ignored to avoid
            # crashing the app.
            pass
        # No key yet: never fall through to the persistence step below.
        return None

    # ------------------------------------------------------------------
    # 4. Key received – persist it and notify user
    # ------------------------------------------------------------------
    ExpectedParrotKeyHandler().store_ep_api_key(api_key)
    os.environ["EXPECTED_PARROT_API_KEY"] = api_key
    path_to_env = write_api_key_to_env(api_key)
    load_dotenv()

    st.success("API-key retrieved and stored. You are now logged-in! 🎉")
    st.caption(f"Key saved to `{path_to_env}`.")

    return api_key
3004
+
1772
3005
  def transfer_credits(
1773
3006
  self,
1774
3007
  credits_transferred: int,
@@ -1816,6 +3049,53 @@ class Coop(CoopFunctionsMixin):
1816
3049
  self._resolve_server_response(response)
1817
3050
  return response.json()
1818
3051
 
3052
def pay_for_service(
    self,
    credits_transferred: int,
    recipient_username: str,
    service_name: str,
) -> dict:
    """
    Pay another Expected Parrot user for a named service with credits.

    Credits are transferred from the authenticated user's account to the
    recipient's account.

    Parameters:
        credits_transferred (int): Number of credits to pay; sent as ``cost_credits``.
        recipient_username (str): Username of the user receiving the credits.
        service_name (str): Name of the service being paid for.

    Returns:
        dict: The server's JSON description of the transfer. Per the API it
        includes ``success``, ``transaction_id`` and ``remaining_credits``.

    Raises:
        CoopServerResponseError: If there's an error communicating with the server
            or if the transfer criteria aren't met (e.g., insufficient credits).

    Example:
        >>> result = coop.pay_for_service(
        ...     credits_transferred=100,
        ...     service_name="service_name",
        ...     recipient_username="friend_username",
        ... )
        >>> print(f"Transfer successful! You have {result['remaining_credits']} credits left.")
    """
    payment_details = {
        "cost_credits": credits_transferred,
        "service_name": service_name,
        "recipient_username": recipient_username,
    }
    payment_response = self._send_server_request(
        uri="api/v0/users/pay-for-service",
        method="POST",
        payload=payment_details,
    )
    self._resolve_server_response(payment_response)
    return payment_response.json()
3098
+
1819
3099
  def get_balance(self) -> dict:
1820
3100
  """
1821
3101
  Get the current credit balance for the authenticated user.
@@ -1835,10 +3115,178 @@ class Coop(CoopFunctionsMixin):
1835
3115
  >>> balance = coop.get_balance()
1836
3116
  >>> print(f"You have {balance['credits']} credits available.")
1837
3117
  """
1838
- response = self._send_server_request(uri="api/users/get_balance", method="GET")
3118
+ response = self._send_server_request(
3119
+ uri="api/v0/users/get-balance", method="GET"
3120
+ )
3121
+ self._resolve_server_response(response)
3122
+ return response.json()
3123
+
3124
def get_profile(self) -> dict:
    """
    Fetch the authenticated user's profile from Expected Parrot.

    Issues a GET request to ``api/v0/users/profile`` with the client's
    credentials and returns the decoded JSON body.

    Returns:
        dict: The profile payload returned by the server. It presumably includes
        fields such as ``username`` and ``email``; confirm against the API.

    Raises:
        CoopServerResponseError: Raised by ``_resolve_server_response`` if the
            server reports an error

    Example:
        >>> profile = coop.get_profile()
        >>> print(f"Welcome, {profile['username']}!")
    """
    profile_response = self._send_server_request(
        uri="api/v0/users/profile",
        method="GET",
    )
    # Surface server-side errors as exceptions before decoding.
    self._resolve_server_response(profile_response)
    return profile_response.json()
1841
3146
 
3147
def login_gradio(self, timeout: int = 120, launch: bool = True, **launch_kwargs):
    """
    Start the EDSL auth-token login flow inside a **Gradio** application.

    A random auth token is generated and embedded in a login link that is
    rendered in a Gradio UI. The Expected Parrot server is polled for the API
    key associated with that token each time the status is refreshed: once
    when the page loads, and again whenever the user clicks the
    "I've logged in" button. Once a key is obtained it is stored via
    :py:class:`~edsl.coop.ep_key_handling.ExpectedParrotKeyHandler`, exported
    to ``os.environ["EXPECTED_PARROT_API_KEY"]`` and written to the local
    ``.env`` file, so subsequent sessions pick it up automatically.

    Parameters
    ----------
    timeout : int, default 120
        Seconds, measured from the call to this method, after which an
        unsuccessful poll reports a time-out instead of a "waiting" message.
    launch : bool, default True
        If ``True`` the Gradio app is launched immediately with
        ``demo.launch(**launch_kwargs)``. Set this to ``False`` if you want
        to embed the returned :class:`gradio.Blocks` object into an existing
        Gradio interface.
    **launch_kwargs
        Additional keyword arguments forwarded to ``gr.Blocks.launch`` when
        *launch* is ``True``.

    Returns
    -------
    gradio.Blocks | None
        • ``None`` if *launch* is ``True`` (after ``demo.launch`` returns).
        • The constructed ``gr.Blocks`` if *launch* is ``False``, so the
          caller can compose it further.
        The API key itself is not returned. It is kept in the Blocks'
        ``gr.State`` and persisted as described above.

    Raises
    ------
    ImportError
        If Gradio is not installed.
    """
    try:
        import gradio as gr
    except ModuleNotFoundError as exc:
        raise ImportError(
            "Gradio is required for `login_gradio`. Install it with `pip install gradio`."
        ) from exc

    import secrets
    import time
    import os
    from dotenv import load_dotenv
    from .ep_key_handling import ExpectedParrotKeyHandler
    from ..utilities.utilities import write_api_key_to_env

    # ------------------------------------------------------------------
    # 1. Prepare auth-token
    # ------------------------------------------------------------------
    edsl_auth_token = secrets.token_urlsafe(16)
    login_url = (
        f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
    )
    # The time-out window starts now, not when the page is first loaded.
    start_time = time.time()

    def _button_update(**kwargs):
        """Build an update for the refresh button across Gradio versions."""
        # Older Gradio exposes ``gr.Button.update``; newer releases only
        # provide the generic ``gr.update``.
        try:
            return gr.Button.update(**kwargs)
        except AttributeError:
            return gr.update(**kwargs)

    # ------------------------------------------------------------------
    # 2. Build Gradio interface
    # ------------------------------------------------------------------
    with gr.Blocks() as demo:
        gr.HTML(
            f'🔗 <b>Log in to Expected Parrot</b> → <a href="{login_url}" target="_blank">click here</a>'
        )
        status_md = gr.Markdown("Waiting for login…")
        refresh_btn = gr.Button(
            "I've logged in – click to continue", elem_id="refresh-btn"
        )
        key_state = gr.State(value=None)

        # --------------------------------------------------------------
        # Polling callback
        # --------------------------------------------------------------
        def _refresh(current_key):  # noqa: D401, pylint: disable=unused-argument
            """Poll server for API-key and update UI accordingly.

            Returns a ``(status_markdown, button_update, key_or_None)`` tuple
            matching the ``outputs`` of the listeners registered below.
            """
            api_key = self._get_api_key(edsl_auth_token)
            # Fall back to env var in case the key was obtained earlier in this
            # session. NOTE: this also means an already-configured key is
            # accepted without the user completing this login.
            if not api_key:
                api_key = os.environ.get("EXPECTED_PARROT_API_KEY")
            elapsed = time.time() - start_time
            remaining = max(0, int(timeout - elapsed))

            if api_key:
                # Persist and expose the key
                ExpectedParrotKeyHandler().store_ep_api_key(api_key)
                os.environ["EXPECTED_PARROT_API_KEY"] = api_key
                path_to_env = write_api_key_to_env(api_key)
                load_dotenv()
                success_msg = (
                    "API-key retrieved and stored 🎉\n\n"
                    f"Key saved to `{path_to_env}`."
                )
                # Hide the button: there is nothing left to poll for.
                return (
                    success_msg,
                    _button_update(interactive=False, visible=False),
                    api_key,
                )

            if elapsed > timeout:
                err_msg = (
                    "Timed-out waiting for login. Please refresh the page "
                    "or restart the app to try again."
                )
                return (
                    err_msg,
                    _button_update(),
                    None,
                )

            info_msg = f"Waiting for login… ({remaining}s left)"
            return (
                info_msg,
                _button_update(),
                None,
            )

        # Initial status check when the interface loads
        demo.load(
            fn=_refresh,
            inputs=key_state,
            outputs=[status_md, refresh_btn, key_state],
        )
        # Re-poll whenever the user clicks the button. Without this listener
        # the button did nothing and only a full page reload re-checked status.
        refresh_btn.click(
            fn=_refresh,
            inputs=key_state,
            outputs=[status_md, refresh_btn, key_state],
        )

    # ------------------------------------------------------------------
    # 3. Launch or return interface
    # ------------------------------------------------------------------
    if launch:
        demo.launch(**launch_kwargs)
        return None
    return demo
3289
+
1842
3290
 
1843
3291
  def main():
1844
3292
  """
@@ -1973,5 +3421,14 @@ def main():
1973
3421
  job = Jobs.example()
1974
3422
  coop.remote_inference_cost(job)
1975
3423
  job_coop_object = coop.remote_inference_create(job)
1976
- job_coop_results = coop.remote_inference_get(job_coop_object.get("uuid"))
3424
+ job_coop_results = coop.new_remote_inference_get(job_coop_object.get("uuid"))
1977
3425
  coop.get(job_coop_results.get("results_uuid"))
3426
+
3427
+ import streamlit as st
3428
+ from edsl.coop import Coop
3429
+
3430
+ coop = Coop() # no API-key required yet
3431
+ api_key = coop.login_streamlit() # renders link + handles polling & storage
3432
+
3433
+ if api_key:
3434
+ st.success("Ready to use EDSL with remote features!")