edsl 0.1.60__py3-none-any.whl → 0.1.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +65 -17
- edsl/agents/agent_list.py +117 -33
- edsl/base/base_class.py +88 -11
- edsl/config/config_class.py +7 -2
- edsl/coop/coop.py +1552 -95
- edsl/coop/coop_jobs_objects.py +2 -2
- edsl/coop/coop_prolific_filters.py +171 -0
- edsl/coop/coop_regular_objects.py +3 -1
- edsl/dataset/display/table_display.py +40 -7
- edsl/db_list/sqlite_list.py +102 -3
- edsl/jobs/data_structures.py +46 -31
- edsl/jobs/jobs.py +73 -2
- edsl/jobs/remote_inference.py +47 -13
- edsl/prompts/prompt.py +7 -2
- edsl/questions/loop_processor.py +289 -10
- edsl/questions/question_registry.py +4 -1
- edsl/questions/templates/dict/answering_instructions.jinja +0 -1
- edsl/scenarios/file_store.py +69 -0
- edsl/scenarios/scenario.py +233 -0
- edsl/scenarios/scenario_list.py +31 -1
- edsl/scenarios/scenario_source.py +605 -498
- edsl/surveys/survey.py +198 -163
- {edsl-0.1.60.dist-info → edsl-0.1.62.dist-info}/METADATA +3 -3
- {edsl-0.1.60.dist-info → edsl-0.1.62.dist-info}/RECORD +28 -27
- {edsl-0.1.60.dist-info → edsl-0.1.62.dist-info}/LICENSE +0 -0
- {edsl-0.1.60.dist-info → edsl-0.1.62.dist-info}/WHEEL +0 -0
- {edsl-0.1.60.dist-info → edsl-0.1.62.dist-info}/entry_points.txt +0 -0
edsl/coop/coop.py
CHANGED
@@ -3,8 +3,9 @@ import base64
|
|
3
3
|
import json
|
4
4
|
import requests
|
5
5
|
import time
|
6
|
+
import os
|
6
7
|
|
7
|
-
from typing import Any, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
|
8
|
+
from typing import Any, Dict, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
|
8
9
|
from uuid import UUID
|
9
10
|
|
10
11
|
from .. import __version__
|
@@ -35,6 +36,7 @@ from .utils import (
|
|
35
36
|
from .coop_functions import CoopFunctionsMixin
|
36
37
|
from .coop_regular_objects import CoopRegularObjects
|
37
38
|
from .coop_jobs_objects import CoopJobsObjects
|
39
|
+
from .coop_prolific_filters import CoopProlificFilters
|
38
40
|
from .ep_key_handling import ExpectedParrotKeyHandler
|
39
41
|
|
40
42
|
from ..inference_services.data_structures import ServiceToModelsMapping
|
@@ -66,7 +68,6 @@ class JobRunInterviewDetails(TypedDict):
|
|
66
68
|
|
67
69
|
|
68
70
|
class LatestJobRunDetails(TypedDict):
|
69
|
-
|
70
71
|
# For running, completed, and partially failed jobs
|
71
72
|
interview_details: Optional[JobRunInterviewDetails] = None
|
72
73
|
|
@@ -296,9 +297,9 @@ class Coop(CoopFunctionsMixin):
|
|
296
297
|
message = str(response.json().get("detail"))
|
297
298
|
except json.JSONDecodeError:
|
298
299
|
raise CoopServerResponseError(
|
299
|
-
f"Server returned status code {response.status_code}."
|
300
|
-
"JSON response could not be decoded."
|
301
|
-
"The server response was:
|
300
|
+
f"Server returned status code {response.status_code}. "
|
301
|
+
f"JSON response could not be decoded. "
|
302
|
+
f"The server response was: {response.text}"
|
302
303
|
)
|
303
304
|
# print(response.text)
|
304
305
|
if "The API key you provided is invalid" in message and check_api_key:
|
@@ -597,7 +598,7 @@ class Coop(CoopFunctionsMixin):
|
|
597
598
|
else:
|
598
599
|
from .exceptions import CoopResponseError
|
599
600
|
|
600
|
-
raise CoopResponseError("No signed url was provided
|
601
|
+
raise CoopResponseError("No signed url was provided.")
|
601
602
|
|
602
603
|
response = requests.put(
|
603
604
|
signed_url, data=json_data.encode(), headers=headers
|
@@ -694,6 +695,35 @@ class Coop(CoopFunctionsMixin):
|
|
694
695
|
"""
|
695
696
|
obj_uuid, owner_username, alias = self._resolve_uuid_or_alias(url_or_uuid)
|
696
697
|
|
698
|
+
# Handle alias-based retrieval with new/old format detection
|
699
|
+
if not obj_uuid and owner_username and alias:
|
700
|
+
# First, get object info to determine format and UUID
|
701
|
+
info_response = self._send_server_request(
|
702
|
+
uri="api/v0/object/alias/info",
|
703
|
+
method="GET",
|
704
|
+
params={"owner_username": owner_username, "alias": alias},
|
705
|
+
)
|
706
|
+
self._resolve_server_response(info_response)
|
707
|
+
info_data = info_response.json()
|
708
|
+
|
709
|
+
obj_uuid = info_data.get("uuid")
|
710
|
+
is_new_format = info_data.get("is_new_format", False)
|
711
|
+
|
712
|
+
# Validate object type if expected
|
713
|
+
if expected_object_type:
|
714
|
+
actual_object_type = info_data.get("object_type")
|
715
|
+
if actual_object_type != expected_object_type:
|
716
|
+
from .exceptions import CoopObjectTypeError
|
717
|
+
|
718
|
+
raise CoopObjectTypeError(
|
719
|
+
f"Expected {expected_object_type=} but got {actual_object_type=}"
|
720
|
+
)
|
721
|
+
|
722
|
+
# Use pull method for new format objects
|
723
|
+
if is_new_format:
|
724
|
+
return self.pull(obj_uuid, expected_object_type)
|
725
|
+
|
726
|
+
# Handle UUID-based retrieval or legacy alias objects
|
697
727
|
if obj_uuid:
|
698
728
|
response = self._send_server_request(
|
699
729
|
uri="api/v0/object",
|
@@ -915,6 +945,39 @@ class Coop(CoopFunctionsMixin):
|
|
915
945
|
|
916
946
|
obj_uuid, owner_username, obj_alias = self._resolve_uuid_or_alias(url_or_uuid)
|
917
947
|
|
948
|
+
# If we're updating the value, we need to check the storage format
|
949
|
+
if value:
|
950
|
+
# If we don't have a UUID but have an alias, get the UUID and format info first
|
951
|
+
if not obj_uuid and owner_username and obj_alias:
|
952
|
+
# Get object info including UUID and format
|
953
|
+
info_response = self._send_server_request(
|
954
|
+
uri="api/v0/object/alias/info",
|
955
|
+
method="GET",
|
956
|
+
params={"owner_username": owner_username, "alias": obj_alias},
|
957
|
+
)
|
958
|
+
self._resolve_server_response(info_response)
|
959
|
+
info_data = info_response.json()
|
960
|
+
|
961
|
+
obj_uuid = info_data.get("uuid")
|
962
|
+
is_new_format = info_data.get("is_new_format", False)
|
963
|
+
else:
|
964
|
+
# We have a UUID, check the format
|
965
|
+
format_check_response = self._send_server_request(
|
966
|
+
uri="api/v0/object/check-format",
|
967
|
+
method="POST",
|
968
|
+
payload={"object_uuid": str(obj_uuid)},
|
969
|
+
)
|
970
|
+
self._resolve_server_response(format_check_response)
|
971
|
+
format_data = format_check_response.json()
|
972
|
+
is_new_format = format_data.get("is_new_format", False)
|
973
|
+
|
974
|
+
if is_new_format:
|
975
|
+
# Handle new format objects: update metadata first, then upload content
|
976
|
+
return self._patch_new_format_object(
|
977
|
+
obj_uuid, description, alias, value, visibility
|
978
|
+
)
|
979
|
+
|
980
|
+
# Handle traditional format objects or metadata-only updates
|
918
981
|
if obj_uuid:
|
919
982
|
uri = "api/v0/object"
|
920
983
|
params = {"uuid": obj_uuid}
|
@@ -944,6 +1007,80 @@ class Coop(CoopFunctionsMixin):
|
|
944
1007
|
self._resolve_server_response(response)
|
945
1008
|
return response.json()
|
946
1009
|
|
1010
|
+
def _patch_new_format_object(
|
1011
|
+
self,
|
1012
|
+
obj_uuid: UUID,
|
1013
|
+
description: Optional[str],
|
1014
|
+
alias: Optional[str],
|
1015
|
+
value: EDSLObject,
|
1016
|
+
visibility: Optional[VisibilityType],
|
1017
|
+
) -> dict:
|
1018
|
+
"""
|
1019
|
+
Handle patching of objects stored in the new format (GCS).
|
1020
|
+
"""
|
1021
|
+
# Step 1: Update metadata only (no json_string)
|
1022
|
+
if description is not None or alias is not None or visibility is not None:
|
1023
|
+
metadata_response = self._send_server_request(
|
1024
|
+
uri="api/v0/object",
|
1025
|
+
method="PATCH",
|
1026
|
+
params={"uuid": obj_uuid},
|
1027
|
+
payload={
|
1028
|
+
"description": description,
|
1029
|
+
"alias": alias,
|
1030
|
+
"json_string": None, # Don't send content to traditional endpoint
|
1031
|
+
"visibility": visibility,
|
1032
|
+
},
|
1033
|
+
)
|
1034
|
+
self._resolve_server_response(metadata_response)
|
1035
|
+
|
1036
|
+
# Step 2: Get signed upload URL for content update
|
1037
|
+
upload_url_response = self._send_server_request(
|
1038
|
+
uri="api/v0/object/upload-url",
|
1039
|
+
method="POST",
|
1040
|
+
payload={"object_uuid": str(obj_uuid)},
|
1041
|
+
)
|
1042
|
+
self._resolve_server_response(upload_url_response)
|
1043
|
+
upload_data = upload_url_response.json()
|
1044
|
+
|
1045
|
+
# Step 3: Upload the object content to GCS
|
1046
|
+
signed_url = upload_data.get("signed_url")
|
1047
|
+
if not signed_url:
|
1048
|
+
raise CoopServerResponseError("Failed to get signed upload URL")
|
1049
|
+
|
1050
|
+
json_content = json.dumps(
|
1051
|
+
value.to_dict(),
|
1052
|
+
default=self._json_handle_none,
|
1053
|
+
allow_nan=False,
|
1054
|
+
)
|
1055
|
+
|
1056
|
+
# Upload to GCS using signed URL
|
1057
|
+
gcs_response = requests.put(
|
1058
|
+
signed_url,
|
1059
|
+
data=json_content,
|
1060
|
+
headers={"Content-Type": "application/json"},
|
1061
|
+
)
|
1062
|
+
|
1063
|
+
if gcs_response.status_code != 200:
|
1064
|
+
raise CoopServerResponseError(
|
1065
|
+
f"Failed to upload object to GCS: {gcs_response.status_code}"
|
1066
|
+
)
|
1067
|
+
|
1068
|
+
# Step 4: Confirm upload and trigger queue worker processing
|
1069
|
+
confirm_response = self._send_server_request(
|
1070
|
+
uri="api/v0/object/confirm-upload",
|
1071
|
+
method="POST",
|
1072
|
+
payload={"object_uuid": str(obj_uuid)},
|
1073
|
+
)
|
1074
|
+
self._resolve_server_response(confirm_response)
|
1075
|
+
confirm_data = confirm_response.json()
|
1076
|
+
|
1077
|
+
return {
|
1078
|
+
"status": "success",
|
1079
|
+
"message": "Object updated successfully (new format - uploaded to GCS and processing triggered)",
|
1080
|
+
"object_uuid": str(obj_uuid),
|
1081
|
+
"processing_started": confirm_data.get("processing_started", False),
|
1082
|
+
}
|
1083
|
+
|
947
1084
|
################
|
948
1085
|
# Remote Cache
|
949
1086
|
################
|
@@ -1025,6 +1162,115 @@ class Coop(CoopFunctionsMixin):
|
|
1025
1162
|
is handled by Expected Parrot's infrastructure, and you can check the status
|
1026
1163
|
and retrieve results later.
|
1027
1164
|
|
1165
|
+
Parameters:
|
1166
|
+
job (Jobs): The EDSL job to run in the cloud
|
1167
|
+
description (str, optional): A human-readable description of the job
|
1168
|
+
status (RemoteJobStatus): Initial status, should be "queued" for normal use
|
1169
|
+
Possible values: "queued", "running", "completed", "failed"
|
1170
|
+
visibility (VisibilityType): Access level for the job information. One of:
|
1171
|
+
- "private": Only accessible by the owner
|
1172
|
+
- "public": Accessible by anyone
|
1173
|
+
- "unlisted": Accessible with the link, but not listed publicly
|
1174
|
+
initial_results_visibility (VisibilityType): Access level for the job results
|
1175
|
+
iterations (int): Number of times to run each interview (default: 1)
|
1176
|
+
fresh (bool): If True, ignore existing cache entries and generate new results
|
1177
|
+
|
1178
|
+
Returns:
|
1179
|
+
RemoteInferenceCreationInfo: Information about the created job including:
|
1180
|
+
- uuid: The unique identifier for the job
|
1181
|
+
- description: The job description
|
1182
|
+
- status: Current status of the job
|
1183
|
+
- iterations: Number of iterations for each interview
|
1184
|
+
- visibility: Access level for the job
|
1185
|
+
- version: EDSL version used to create the job
|
1186
|
+
|
1187
|
+
Raises:
|
1188
|
+
CoopServerResponseError: If there's an error communicating with the server
|
1189
|
+
|
1190
|
+
Notes:
|
1191
|
+
- Remote jobs run asynchronously and may take time to complete
|
1192
|
+
- Use remote_inference_get() with the returned UUID to check status
|
1193
|
+
- Credits are consumed based on the complexity of the job
|
1194
|
+
|
1195
|
+
Example:
|
1196
|
+
>>> from edsl.jobs import Jobs
|
1197
|
+
>>> job = Jobs.example()
|
1198
|
+
>>> job_info = coop.remote_inference_create(job=job, description="My job")
|
1199
|
+
>>> print(f"Job created with UUID: {job_info['uuid']}")
|
1200
|
+
"""
|
1201
|
+
response = self._send_server_request(
|
1202
|
+
uri="api/v0/new-remote-inference",
|
1203
|
+
method="POST",
|
1204
|
+
payload={
|
1205
|
+
"json_string": "offloaded",
|
1206
|
+
"description": description,
|
1207
|
+
"status": status,
|
1208
|
+
"iterations": iterations,
|
1209
|
+
"visibility": visibility,
|
1210
|
+
"version": self._edsl_version,
|
1211
|
+
"initial_results_visibility": initial_results_visibility,
|
1212
|
+
"fresh": fresh,
|
1213
|
+
},
|
1214
|
+
)
|
1215
|
+
self._resolve_server_response(response)
|
1216
|
+
response_json = response.json()
|
1217
|
+
upload_signed_url = response_json.get("upload_signed_url")
|
1218
|
+
if not upload_signed_url:
|
1219
|
+
from .exceptions import CoopResponseError
|
1220
|
+
|
1221
|
+
raise CoopResponseError("No signed url was provided.")
|
1222
|
+
|
1223
|
+
response = requests.put(
|
1224
|
+
upload_signed_url,
|
1225
|
+
data=json.dumps(
|
1226
|
+
job.to_dict(),
|
1227
|
+
default=self._json_handle_none,
|
1228
|
+
).encode(),
|
1229
|
+
headers={"Content-Type": "application/json"},
|
1230
|
+
)
|
1231
|
+
self._resolve_gcs_response(response)
|
1232
|
+
|
1233
|
+
job_uuid = response_json.get("job_uuid")
|
1234
|
+
|
1235
|
+
response = self._send_server_request(
|
1236
|
+
uri="api/v0/new-remote-inference/uploaded",
|
1237
|
+
method="POST",
|
1238
|
+
payload={
|
1239
|
+
"job_uuid": job_uuid,
|
1240
|
+
"message": "Job uploaded successfully",
|
1241
|
+
},
|
1242
|
+
)
|
1243
|
+
response_json = response.json()
|
1244
|
+
|
1245
|
+
return RemoteInferenceCreationInfo(
|
1246
|
+
**{
|
1247
|
+
"uuid": response_json.get("job_uuid"),
|
1248
|
+
"description": response_json.get("description", ""),
|
1249
|
+
"status": response_json.get("status"),
|
1250
|
+
"iterations": response_json.get("iterations", ""),
|
1251
|
+
"visibility": response_json.get("visibility", ""),
|
1252
|
+
"version": self._edsl_version,
|
1253
|
+
}
|
1254
|
+
)
|
1255
|
+
|
1256
|
+
def old_remote_inference_create(
|
1257
|
+
self,
|
1258
|
+
job: "Jobs",
|
1259
|
+
description: Optional[str] = None,
|
1260
|
+
status: RemoteJobStatus = "queued",
|
1261
|
+
visibility: Optional[VisibilityType] = "unlisted",
|
1262
|
+
initial_results_visibility: Optional[VisibilityType] = "unlisted",
|
1263
|
+
iterations: Optional[int] = 1,
|
1264
|
+
fresh: Optional[bool] = False,
|
1265
|
+
) -> RemoteInferenceCreationInfo:
|
1266
|
+
"""
|
1267
|
+
Create a remote inference job for execution in the Expected Parrot cloud.
|
1268
|
+
|
1269
|
+
This method sends a job to be executed in the cloud, which can be more efficient
|
1270
|
+
for large jobs or when you want to run jobs in the background. The job execution
|
1271
|
+
is handled by Expected Parrot's infrastructure, and you can check the status
|
1272
|
+
and retrieve results later.
|
1273
|
+
|
1028
1274
|
Parameters:
|
1029
1275
|
job (Jobs): The EDSL job to run in the cloud
|
1030
1276
|
description (str, optional): A human-readable description of the job
|
@@ -1208,79 +1454,233 @@ class Coop(CoopFunctionsMixin):
|
|
1208
1454
|
}
|
1209
1455
|
)
|
1210
1456
|
|
1211
|
-
def
|
1212
|
-
self,
|
1213
|
-
|
1457
|
+
def new_remote_inference_get(
|
1458
|
+
self,
|
1459
|
+
job_uuid: Optional[str] = None,
|
1460
|
+
results_uuid: Optional[str] = None,
|
1461
|
+
include_json_string: Optional[bool] = False,
|
1462
|
+
) -> RemoteInferenceResponse:
|
1214
1463
|
"""
|
1215
|
-
|
1464
|
+
Get the status and details of a remote inference job.
|
1216
1465
|
|
1217
|
-
|
1218
|
-
|
1466
|
+
This method retrieves the current status and information about a remote job,
|
1467
|
+
including links to results if the job has completed successfully.
|
1468
|
+
|
1469
|
+
Parameters:
|
1470
|
+
job_uuid (str, optional): The UUID of the remote job to check
|
1471
|
+
results_uuid (str, optional): The UUID of the results associated with the job
|
1472
|
+
(can be used if you only have the results UUID)
|
1473
|
+
include_json_string (bool, optional): If True, include the json string for the job in the response
|
1219
1474
|
|
1220
1475
|
Returns:
|
1221
|
-
|
1476
|
+
RemoteInferenceResponse: Information about the job including:
|
1477
|
+
job_uuid: The unique identifier for the job
|
1478
|
+
results_uuid: The UUID of the results
|
1479
|
+
results_url: URL to access the results
|
1480
|
+
status: Current status ("queued", "running", "completed", "failed")
|
1481
|
+
version: EDSL version used for the job
|
1482
|
+
job_json_string: The json string for the job (if include_json_string is True)
|
1483
|
+
latest_job_run_details: Metadata about the job status
|
1484
|
+
interview_details: Metadata about the job interview status (for jobs that have reached running status)
|
1485
|
+
total_interviews: The total number of interviews in the job
|
1486
|
+
completed_interviews: The number of completed interviews
|
1487
|
+
interviews_with_exceptions: The number of completed interviews that have exceptions
|
1488
|
+
exception_counters: A list of exception counts for the job
|
1489
|
+
exception_type: The type of exception
|
1490
|
+
inference_service: The inference service
|
1491
|
+
model: The model
|
1492
|
+
question_name: The name of the question
|
1493
|
+
exception_count: The number of exceptions
|
1494
|
+
failure_reason: The reason the job failed (failed jobs only)
|
1495
|
+
failure_description: The description of the failure (failed jobs only)
|
1496
|
+
error_report_uuid: The UUID of the error report (partially failed jobs only)
|
1497
|
+
cost_credits: The cost of the job run in credits
|
1498
|
+
cost_usd: The cost of the job run in USD
|
1499
|
+
expenses: The expenses incurred by the job run
|
1500
|
+
service: The service
|
1501
|
+
model: The model
|
1502
|
+
token_type: The type of token (input or output)
|
1503
|
+
price_per_million_tokens: The price per million tokens
|
1504
|
+
tokens_count: The number of tokens consumed
|
1505
|
+
cost_credits: The cost of the service/model/token type combination in credits
|
1506
|
+
cost_usd: The cost of the service/model/token type combination in USD
|
1222
1507
|
|
1223
1508
|
Raises:
|
1224
|
-
|
1225
|
-
|
1226
|
-
valid_status_types = [
|
1227
|
-
"queued",
|
1228
|
-
"running",
|
1229
|
-
"completed",
|
1230
|
-
"failed",
|
1231
|
-
"cancelled",
|
1232
|
-
"cancelling",
|
1233
|
-
"partial_failed",
|
1234
|
-
]
|
1235
|
-
if isinstance(status, list):
|
1236
|
-
invalid_statuses = [s for s in status if s not in valid_status_types]
|
1237
|
-
if invalid_statuses:
|
1238
|
-
raise CoopValueError(
|
1239
|
-
f"Invalid status type(s): {invalid_statuses}. "
|
1240
|
-
f"Valid types are: {valid_status_types}"
|
1241
|
-
)
|
1242
|
-
return status
|
1243
|
-
else:
|
1244
|
-
if status not in valid_status_types:
|
1245
|
-
raise CoopValueError(
|
1246
|
-
f"Invalid status type: {status}. "
|
1247
|
-
f"Valid types are: {valid_status_types}"
|
1248
|
-
)
|
1249
|
-
return [status]
|
1250
|
-
|
1251
|
-
def remote_inference_list(
|
1252
|
-
self,
|
1253
|
-
status: Union[RemoteJobStatus, List[RemoteJobStatus], None] = None,
|
1254
|
-
search_query: Union[str, None] = None,
|
1255
|
-
page: int = 1,
|
1256
|
-
page_size: int = 10,
|
1257
|
-
sort_ascending: bool = False,
|
1258
|
-
) -> "CoopJobsObjects":
|
1259
|
-
"""
|
1260
|
-
Retrieve jobs owned by the user.
|
1509
|
+
ValueError: If neither job_uuid nor results_uuid is provided
|
1510
|
+
CoopServerResponseError: If there's an error communicating with the server
|
1261
1511
|
|
1262
1512
|
Notes:
|
1263
|
-
|
1264
|
-
|
1513
|
+
- Either job_uuid or results_uuid must be provided
|
1514
|
+
- If both are provided, job_uuid takes precedence
|
1515
|
+
- For completed jobs, you can use the results_url to view or download results
|
1516
|
+
- For failed jobs, check the latest_error_report_url for debugging information
|
1517
|
+
|
1518
|
+
Example:
|
1519
|
+
>>> job_status = coop.new_remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
|
1520
|
+
>>> print(f"Job status: {job_status['status']}")
|
1521
|
+
>>> if job_status['status'] == 'completed':
|
1522
|
+
... print(f"Results available at: {job_status['results_url']}")
|
1265
1523
|
"""
|
1266
|
-
|
1524
|
+
if job_uuid is None and results_uuid is None:
|
1525
|
+
from .exceptions import CoopValueError
|
1267
1526
|
|
1268
|
-
|
1269
|
-
|
1270
|
-
|
1271
|
-
|
1272
|
-
|
1273
|
-
|
1527
|
+
raise CoopValueError("Either job_uuid or results_uuid must be provided.")
|
1528
|
+
elif job_uuid is not None:
|
1529
|
+
params = {"job_uuid": job_uuid}
|
1530
|
+
else:
|
1531
|
+
params = {"results_uuid": results_uuid}
|
1532
|
+
if include_json_string:
|
1533
|
+
params["include_json_string"] = include_json_string
|
1274
1534
|
|
1275
|
-
|
1276
|
-
"
|
1277
|
-
"
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1535
|
+
response = self._send_server_request(
|
1536
|
+
uri="api/v0/remote-inference",
|
1537
|
+
method="GET",
|
1538
|
+
params=params,
|
1539
|
+
)
|
1540
|
+
self._resolve_server_response(response)
|
1541
|
+
data = response.json()
|
1542
|
+
|
1543
|
+
results_uuid = data.get("results_uuid")
|
1544
|
+
|
1545
|
+
if results_uuid is None:
|
1546
|
+
results_url = None
|
1547
|
+
else:
|
1548
|
+
results_url = f"{self.url}/content/{results_uuid}"
|
1549
|
+
|
1550
|
+
latest_job_run_details = data.get("latest_job_run_details", {})
|
1551
|
+
if data.get("status") == "partial_failed":
|
1552
|
+
latest_error_report_uuid = latest_job_run_details.get("error_report_uuid")
|
1553
|
+
if latest_error_report_uuid is None:
|
1554
|
+
latest_job_run_details["error_report_url"] = None
|
1555
|
+
else:
|
1556
|
+
latest_error_report_url = (
|
1557
|
+
f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
|
1558
|
+
)
|
1559
|
+
latest_job_run_details["error_report_url"] = latest_error_report_url
|
1560
|
+
|
1561
|
+
json_string = data.get("job_json_string")
|
1562
|
+
|
1563
|
+
# The job has been offloaded to GCS
|
1564
|
+
if include_json_string and json_string == "offloaded":
|
1565
|
+
|
1566
|
+
# Attempt to fetch JSON string from GCS
|
1567
|
+
response = self._send_server_request(
|
1568
|
+
uri="api/v0/remote-inference/pull",
|
1569
|
+
method="POST",
|
1570
|
+
payload={"job_uuid": job_uuid},
|
1571
|
+
)
|
1572
|
+
# Handle any errors in the response
|
1573
|
+
self._resolve_server_response(response)
|
1574
|
+
if "signed_url" not in response.json():
|
1575
|
+
from .exceptions import CoopResponseError
|
1576
|
+
|
1577
|
+
raise CoopResponseError("No signed url was provided.")
|
1578
|
+
signed_url = response.json().get("signed_url")
|
1579
|
+
|
1580
|
+
if signed_url == "": # The job is in legacy format
|
1581
|
+
job_json = json_string
|
1582
|
+
|
1583
|
+
try:
|
1584
|
+
response = requests.get(signed_url)
|
1585
|
+
self._resolve_gcs_response(response)
|
1586
|
+
job_json = json.dumps(response.json())
|
1587
|
+
except Exception:
|
1588
|
+
job_json = json_string
|
1589
|
+
|
1590
|
+
# If the job is in legacy format, we should already have the JSON string
|
1591
|
+
# from the first API call
|
1592
|
+
elif include_json_string and not json_string == "offloaded":
|
1593
|
+
job_json = json_string
|
1594
|
+
|
1595
|
+
# If include_json_string is False, we don't need the JSON string at all
|
1596
|
+
else:
|
1597
|
+
job_json = None
|
1598
|
+
|
1599
|
+
return RemoteInferenceResponse(
|
1600
|
+
**{
|
1601
|
+
"job_uuid": data.get("job_uuid"),
|
1602
|
+
"results_uuid": results_uuid,
|
1603
|
+
"results_url": results_url,
|
1604
|
+
"status": data.get("status"),
|
1605
|
+
"version": data.get("version"),
|
1606
|
+
"job_json_string": job_json,
|
1607
|
+
"latest_job_run_details": latest_job_run_details,
|
1608
|
+
}
|
1609
|
+
)
|
1610
|
+
|
1611
|
+
def _validate_remote_job_status_types(
|
1612
|
+
self, status: Union[RemoteJobStatus, List[RemoteJobStatus]]
|
1613
|
+
) -> List[RemoteJobStatus]:
|
1614
|
+
"""
|
1615
|
+
Validate visibility types and return a list of valid types.
|
1616
|
+
|
1617
|
+
Args:
|
1618
|
+
visibility: Single visibility type or list of visibility types to validate
|
1619
|
+
|
1620
|
+
Returns:
|
1621
|
+
List of validated visibility types
|
1622
|
+
|
1623
|
+
Raises:
|
1624
|
+
CoopValueError: If any visibility type is invalid
|
1625
|
+
"""
|
1626
|
+
valid_status_types = [
|
1627
|
+
"queued",
|
1628
|
+
"running",
|
1629
|
+
"completed",
|
1630
|
+
"failed",
|
1631
|
+
"cancelled",
|
1632
|
+
"cancelling",
|
1633
|
+
"partial_failed",
|
1634
|
+
]
|
1635
|
+
if isinstance(status, list):
|
1636
|
+
invalid_statuses = [s for s in status if s not in valid_status_types]
|
1637
|
+
if invalid_statuses:
|
1638
|
+
raise CoopValueError(
|
1639
|
+
f"Invalid status type(s): {invalid_statuses}. "
|
1640
|
+
f"Valid types are: {valid_status_types}"
|
1641
|
+
)
|
1642
|
+
return status
|
1643
|
+
else:
|
1644
|
+
if status not in valid_status_types:
|
1645
|
+
raise CoopValueError(
|
1646
|
+
f"Invalid status type: {status}. "
|
1647
|
+
f"Valid types are: {valid_status_types}"
|
1648
|
+
)
|
1649
|
+
return [status]
|
1650
|
+
|
1651
|
+
def remote_inference_list(
|
1652
|
+
self,
|
1653
|
+
status: Union[RemoteJobStatus, List[RemoteJobStatus], None] = None,
|
1654
|
+
search_query: Union[str, None] = None,
|
1655
|
+
page: int = 1,
|
1656
|
+
page_size: int = 10,
|
1657
|
+
sort_ascending: bool = False,
|
1658
|
+
) -> "CoopJobsObjects":
|
1659
|
+
"""
|
1660
|
+
Retrieve jobs owned by the user.
|
1661
|
+
|
1662
|
+
Notes:
|
1663
|
+
- search_query only works with the description field.
|
1664
|
+
- If sort_ascending is False, then the most recently created jobs are returned first.
|
1665
|
+
"""
|
1666
|
+
from ..scenarios import Scenario
|
1667
|
+
|
1668
|
+
if page < 1:
|
1669
|
+
raise CoopValueError("The page must be greater than or equal to 1.")
|
1670
|
+
if page_size < 1:
|
1671
|
+
raise CoopValueError("The page size must be greater than or equal to 1.")
|
1672
|
+
if page_size > 100:
|
1673
|
+
raise CoopValueError("The page size must be less than or equal to 100.")
|
1674
|
+
|
1675
|
+
params = {
|
1676
|
+
"page": page,
|
1677
|
+
"page_size": page_size,
|
1678
|
+
"sort_ascending": sort_ascending,
|
1679
|
+
}
|
1680
|
+
if status:
|
1681
|
+
params["status"] = self._validate_remote_job_status_types(status)
|
1682
|
+
if search_query:
|
1683
|
+
params["search_query"] = search_query
|
1284
1684
|
|
1285
1685
|
response = self._send_server_request(
|
1286
1686
|
uri="api/v0/remote-inference/list",
|
@@ -1368,25 +1768,57 @@ class Coop(CoopFunctionsMixin):
|
|
1368
1768
|
def create_project(
|
1369
1769
|
self,
|
1370
1770
|
survey: "Survey",
|
1771
|
+
scenario_list: Optional["ScenarioList"] = None,
|
1772
|
+
scenario_list_method: Optional[
|
1773
|
+
Literal["randomize", "loop", "single_scenario"]
|
1774
|
+
] = None,
|
1371
1775
|
project_name: str = "Project",
|
1372
1776
|
survey_description: Optional[str] = None,
|
1373
1777
|
survey_alias: Optional[str] = None,
|
1374
1778
|
survey_visibility: Optional[VisibilityType] = "unlisted",
|
1779
|
+
scenario_list_description: Optional[str] = None,
|
1780
|
+
scenario_list_alias: Optional[str] = None,
|
1781
|
+
scenario_list_visibility: Optional[VisibilityType] = "unlisted",
|
1375
1782
|
):
|
1376
1783
|
"""
|
1377
1784
|
Create a survey object on Coop, then create a project from the survey.
|
1378
1785
|
"""
|
1379
|
-
|
1786
|
+
if scenario_list is None and scenario_list_method is not None:
|
1787
|
+
raise CoopValueError(
|
1788
|
+
"You must specify both a scenario list and a scenario list method to use scenarios with your survey."
|
1789
|
+
)
|
1790
|
+
elif scenario_list is not None and scenario_list_method is None:
|
1791
|
+
raise CoopValueError(
|
1792
|
+
"You must specify both a scenario list and a scenario list method to use scenarios with your survey."
|
1793
|
+
)
|
1794
|
+
survey_details = self.push(
|
1380
1795
|
object=survey,
|
1381
1796
|
description=survey_description,
|
1382
1797
|
alias=survey_alias,
|
1383
1798
|
visibility=survey_visibility,
|
1384
1799
|
)
|
1385
1800
|
survey_uuid = survey_details.get("uuid")
|
1801
|
+
if scenario_list is not None:
|
1802
|
+
scenario_list_details = self.push(
|
1803
|
+
object=scenario_list,
|
1804
|
+
description=scenario_list_description,
|
1805
|
+
alias=scenario_list_alias,
|
1806
|
+
visibility=scenario_list_visibility,
|
1807
|
+
)
|
1808
|
+
scenario_list_uuid = scenario_list_details.get("uuid")
|
1809
|
+
else:
|
1810
|
+
scenario_list_uuid = None
|
1386
1811
|
response = self._send_server_request(
|
1387
1812
|
uri="api/v0/projects/create-from-survey",
|
1388
1813
|
method="POST",
|
1389
|
-
payload={
|
1814
|
+
payload={
|
1815
|
+
"project_name": project_name,
|
1816
|
+
"survey_uuid": str(survey_uuid),
|
1817
|
+
"scenario_list_uuid": (
|
1818
|
+
str(scenario_list_uuid) if scenario_list_uuid is not None else None
|
1819
|
+
),
|
1820
|
+
"scenario_list_method": scenario_list_method,
|
1821
|
+
},
|
1390
1822
|
)
|
1391
1823
|
self._resolve_server_response(response)
|
1392
1824
|
response_json = response.json()
|
@@ -1413,14 +1845,26 @@ class Coop(CoopFunctionsMixin):
|
|
1413
1845
|
return {
|
1414
1846
|
"project_name": response_json.get("project_name"),
|
1415
1847
|
"project_job_uuids": response_json.get("job_uuids"),
|
1848
|
+
"project_prolific_studies": [
|
1849
|
+
{
|
1850
|
+
"study_id": study.get("id"),
|
1851
|
+
"name": study.get("name"),
|
1852
|
+
"status": study.get("status"),
|
1853
|
+
"num_participants": study.get("total_available_places"),
|
1854
|
+
"places_taken": study.get("places_taken"),
|
1855
|
+
}
|
1856
|
+
for study in response_json.get("prolific_studies", [])
|
1857
|
+
],
|
1416
1858
|
}
|
1417
1859
|
|
1418
|
-
def
|
1860
|
+
def _turn_human_responses_into_results(
|
1419
1861
|
self,
|
1420
|
-
|
1862
|
+
human_responses: List[dict],
|
1863
|
+
survey_json_string: str,
|
1864
|
+
scenario_list_json_string: Optional[str] = None,
|
1421
1865
|
) -> Union["Results", "ScenarioList"]:
|
1422
1866
|
"""
|
1423
|
-
|
1867
|
+
Turn a list of human responses into a Results object.
|
1424
1868
|
|
1425
1869
|
If generating the Results object fails, a ScenarioList will be returned instead.
|
1426
1870
|
"""
|
@@ -1430,16 +1874,19 @@ class Coop(CoopFunctionsMixin):
|
|
1430
1874
|
from ..scenarios import Scenario, ScenarioList
|
1431
1875
|
from ..surveys import Survey
|
1432
1876
|
|
1433
|
-
response = self._send_server_request(
|
1434
|
-
uri=f"api/v0/projects/{project_uuid}/human-responses",
|
1435
|
-
method="GET",
|
1436
|
-
)
|
1437
|
-
self._resolve_server_response(response)
|
1438
|
-
response_json = response.json()
|
1439
|
-
human_responses = response_json.get("human_responses", [])
|
1440
|
-
|
1441
1877
|
try:
|
1442
|
-
|
1878
|
+
survey = Survey.from_dict(json.loads(survey_json_string))
|
1879
|
+
|
1880
|
+
model = Model("test")
|
1881
|
+
|
1882
|
+
if scenario_list_json_string is not None:
|
1883
|
+
scenario_list = ScenarioList.from_dict(
|
1884
|
+
json.loads(scenario_list_json_string)
|
1885
|
+
)
|
1886
|
+
else:
|
1887
|
+
scenario_list = ScenarioList()
|
1888
|
+
|
1889
|
+
results = None
|
1443
1890
|
|
1444
1891
|
for response in human_responses:
|
1445
1892
|
response_uuid = response.get("response_uuid")
|
@@ -1449,8 +1896,14 @@ class Coop(CoopFunctionsMixin):
|
|
1449
1896
|
)
|
1450
1897
|
|
1451
1898
|
response_dict = json.loads(response.get("response_json_string"))
|
1899
|
+
agent_traits_json_string = response.get("agent_traits_json_string")
|
1900
|
+
scenario_uuid = response.get("scenario_uuid")
|
1901
|
+
if agent_traits_json_string is not None:
|
1902
|
+
agent_traits = json.loads(agent_traits_json_string)
|
1903
|
+
else:
|
1904
|
+
agent_traits = {}
|
1452
1905
|
|
1453
|
-
a = Agent(name=response_uuid, instruction="")
|
1906
|
+
a = Agent(name=response_uuid, instruction="", traits=agent_traits)
|
1454
1907
|
|
1455
1908
|
def create_answer_function(response_data):
|
1456
1909
|
def f(self, question, scenario):
|
@@ -1458,27 +1911,38 @@ class Coop(CoopFunctionsMixin):
|
|
1458
1911
|
|
1459
1912
|
return f
|
1460
1913
|
|
1914
|
+
scenario = None
|
1915
|
+
if scenario_uuid is not None:
|
1916
|
+
for s in scenario_list:
|
1917
|
+
if s.get("uuid") == scenario_uuid:
|
1918
|
+
scenario = s
|
1919
|
+
break
|
1920
|
+
|
1921
|
+
if scenario is None:
|
1922
|
+
raise RuntimeError("Scenario not found.")
|
1923
|
+
|
1461
1924
|
a.add_direct_question_answering_method(
|
1462
1925
|
create_answer_function(response_dict)
|
1463
1926
|
)
|
1464
|
-
agent_list.append(a)
|
1465
1927
|
|
1466
|
-
|
1467
|
-
survey = Survey.from_dict(json.loads(survey_json_string))
|
1928
|
+
job = survey.by(a).by(model)
|
1468
1929
|
|
1469
|
-
|
1470
|
-
|
1471
|
-
|
1472
|
-
.
|
1473
|
-
.run(
|
1930
|
+
if scenario is not None:
|
1931
|
+
job = job.by(scenario)
|
1932
|
+
|
1933
|
+
question_results = job.run(
|
1474
1934
|
cache=Cache(),
|
1475
1935
|
disable_remote_cache=True,
|
1476
1936
|
disable_remote_inference=True,
|
1477
1937
|
print_exceptions=False,
|
1478
1938
|
)
|
1479
|
-
|
1939
|
+
|
1940
|
+
if results is None:
|
1941
|
+
results = question_results
|
1942
|
+
else:
|
1943
|
+
results = results + question_results
|
1480
1944
|
return results
|
1481
|
-
except Exception:
|
1945
|
+
except Exception as e:
|
1482
1946
|
human_response_scenarios = []
|
1483
1947
|
for response in human_responses:
|
1484
1948
|
response_uuid = response.get("response_uuid")
|
@@ -1493,6 +1957,427 @@ class Coop(CoopFunctionsMixin):
|
|
1493
1957
|
human_response_scenarios.append(scenario)
|
1494
1958
|
return ScenarioList(human_response_scenarios)
|
1495
1959
|
|
1960
|
+
def get_project_human_responses(
    self,
    project_uuid: str,
) -> Union["Results", "ScenarioList"]:
    """
    Fetch the human responses stored for a project and convert them.

    The server reply is expected to carry three entries: the list of
    ``human_responses`` (defaults to an empty list when absent), the
    ``survey_json_string`` and the ``scenario_list_json_string``. All three
    are handed to ``_turn_human_responses_into_results``.

    Parameters:
        project_uuid (str): UUID of the project whose responses are requested.

    Returns:
        Whatever ``_turn_human_responses_into_results`` produces: a Results
        object, or a ScenarioList if building the Results object fails.
    """
    endpoint = f"api/v0/projects/{project_uuid}/human-responses"
    server_reply = self._send_server_request(uri=endpoint, method="GET")
    self._resolve_server_response(server_reply)

    body = server_reply.json()
    return self._turn_human_responses_into_results(
        body.get("human_responses", []),
        body.get("survey_json_string"),
        body.get("scenario_list_json_string"),
    )
|
1982
|
+
|
1983
|
+
def list_prolific_filters(self) -> "CoopProlificFilters":
    """
    Get a ScenarioList of supported Prolific filters. This list has several methods
    that you can use to create valid filter dicts for use with Coop.create_prolific_study().

    Call find() to examine a specific filter by ID:
    >>> filters = coop.list_prolific_filters()
    >>> filters.find("age")
    Scenario(
        {
            "filter_id": "age",
            "type": "range",
            "range_filter_min": 18,
            "range_filter_max": 100,
            ...
        }
    )

    Call create_study_filter() to create a valid filter dict:
    >>> filters.create_study_filter("age", min=30, max=40)
    {
        "filter_id": "age",
        "selected_range": {
            "lower": 30,
            "upper": 40,
        },
    }

    Each filter returned by the server becomes one Scenario. Range-only
    fields are None for non-range filters, and select-only fields are None
    for non-select filters.

    Returns:
        CoopProlificFilters: built from the list of per-filter Scenarios.
    """
    from ..scenarios import Scenario

    response = self._send_server_request(
        uri="api/v0/prolific-filters",
        method="GET",
    )
    self._resolve_server_response(response)
    response_json = response.json()
    filters = response_json.get("prolific_filters", [])

    filter_scenarios = []
    # `prolific_filter` rather than `filter`, so the builtin is not shadowed.
    for prolific_filter in filters:
        filter_type = prolific_filter.get("type")
        question = prolific_filter.get("question")
        is_range = filter_type == "range"
        is_select = filter_type == "select"
        choices = prolific_filter.get("choices")

        scenario = Scenario(
            {
                "filter_id": prolific_filter.get("filter_id"),
                "title": prolific_filter.get("title"),
                "question": (
                    f"Participants were asked the following: {question}"
                    if question
                    else None
                ),
                "type": filter_type,
                "range_filter_min": (
                    prolific_filter.get("min") if is_range else None
                ),
                "range_filter_max": (
                    prolific_filter.get("max") if is_range else None
                ),
                # `choices` may be missing or an explicit null; count both
                # as zero options instead of raising TypeError in len().
                "select_filter_num_options": (
                    len(choices or []) if is_select else None
                ),
                "select_filter_options": choices if is_select else None,
            }
        )
        filter_scenarios.append(scenario)
    return CoopProlificFilters(filter_scenarios)
|
2052
|
+
|
2053
|
+
@staticmethod
def _validate_prolific_study_cost(
    estimated_completion_time_minutes: int,
    participant_payment_cents: int,
    minimum_usd_per_hour: float = 8.0,
) -> tuple[bool, float]:
    """
    Check whether a Prolific study pays less than the minimum hourly rate.

    Parameters:
        estimated_completion_time_minutes (int): Expected time to finish the
            study, in minutes. Must not be zero.
        participant_payment_cents (int): Payment per participant, in cents.
        minimum_usd_per_hour (float): Threshold in USD per hour. Defaults to
            8.0, the minimum amount for using Prolific.

    Returns:
        tuple[bool, float]: ``(is_underpayment, cost_usd_per_hour)``. The
        first value is True when the hourly rate is strictly below the
        threshold; the second is the hourly rate in USD.

    Raises:
        ValueError: If the estimated completion time is zero, because the
            hourly rate is undefined in that case.
    """
    if estimated_completion_time_minutes == 0:
        # Without this guard the division below raises ZeroDivisionError.
        raise ValueError(
            "estimated_completion_time_minutes must be non-zero to compute "
            "an hourly rate."
        )

    estimated_completion_time_hours = estimated_completion_time_minutes / 60
    participant_payment_usd = participant_payment_cents / 100
    cost_usd_per_hour = participant_payment_usd / estimated_completion_time_hours

    return cost_usd_per_hour < minimum_usd_per_hour, cost_usd_per_hour
|
2071
|
+
|
2072
|
+
def create_prolific_study(
    self,
    project_uuid: str,
    name: str,
    description: str,
    num_participants: int,
    estimated_completion_time_minutes: int,
    participant_payment_cents: int,
    device_compatibility: Optional[
        List[Literal["desktop", "tablet", "mobile"]]
    ] = None,
    peripheral_requirements: Optional[
        List[Literal["audio", "camera", "download", "microphone"]]
    ] = None,
    filters: Optional[List[Dict]] = None,
) -> dict:
    """
    Create a Prolific study for a project and return its details as a dict.

    Filters are optional. To build them, first fetch the supported filters
    with Coop.list_prolific_filters(), then call create_study_filter() on
    the returned CoopProlificFilters object to obtain a valid filter dict.

    Defaults applied when an argument is None: all three device types are
    allowed, no peripherals are required, and no filters are used.

    Raises:
        CoopValueError: If the payment works out to less than the minimum
            hourly rate reported by ``_validate_prolific_study_cost``.
    """
    # Reject underpaid studies before contacting the server.
    is_underpayment, cost_usd_per_hour = self._validate_prolific_study_cost(
        estimated_completion_time_minutes, participant_payment_cents
    )
    if is_underpayment:
        raise CoopValueError(
            f"The current participant payment of ${cost_usd_per_hour:.2f} USD per hour is below the minimum payment for using Prolific ($8.00 USD per hour)."
        )

    if device_compatibility is None:
        device_compatibility = ["desktop", "tablet", "mobile"]
    if peripheral_requirements is None:
        peripheral_requirements = []
    if filters is None:
        filters = []

    request_body = {
        "name": name,
        "description": description,
        "total_available_places": num_participants,
        "estimated_completion_time": estimated_completion_time_minutes,
        "reward": participant_payment_cents,
        "device_compatibility": device_compatibility,
        "peripheral_requirements": peripheral_requirements,
        "filters": filters,
    }
    server_reply = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/prolific-studies",
        method="POST",
        payload=request_body,
    )
    self._resolve_server_response(server_reply)
    study = server_reply.json()

    # (key in the returned dict, key in the server reply)
    field_map = (
        ("study_id", "study_id"),
        ("status", "status"),
        ("admin_url", "admin_url"),
        ("respondent_url", "respondent_url"),
        ("name", "name"),
        ("description", "description"),
        ("num_participants", "total_available_places"),
        ("estimated_completion_time_minutes", "estimated_completion_time"),
        ("participant_payment_cents", "reward"),
        ("total_cost_cents", "total_cost"),
        ("device_compatibility", "device_compatibility"),
        ("peripheral_requirements", "peripheral_requirements"),
        ("filters", "filters"),
    )
    return {ours: study.get(theirs) for ours, theirs in field_map}
|
2143
|
+
|
2144
|
+
def update_prolific_study(
    self,
    project_uuid: str,
    study_id: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    num_participants: Optional[int] = None,
    estimated_completion_time_minutes: Optional[int] = None,
    participant_payment_cents: Optional[int] = None,
    device_compatibility: Optional[
        List[Literal["desktop", "tablet", "mobile"]]
    ] = None,
    peripheral_requirements: Optional[
        List[Literal["audio", "camera", "download", "microphone"]]
    ] = None,
    filters: Optional[List[Dict]] = None,
) -> dict:
    """
    Update a Prolific study. Returns a dict with the study details.

    Only arguments that are not None are sent to the server; every other
    field of the study is left untouched.

    When the completion time and/or the payment is being changed, the
    current study is fetched so the resulting hourly rate can be checked
    with ``_validate_prolific_study_cost`` before the update is sent.

    Raises:
        CoopValueError: If the supplied completion time is not positive, or
            if the update would bring the hourly rate below the minimum.
    """
    changes_cost = (
        estimated_completion_time_minutes is not None
        or participant_payment_cents is not None
    )
    # The rate check only matters when one of its inputs changes; skipping
    # it otherwise also saves a round-trip to the server.
    if changes_cost:
        if (
            estimated_completion_time_minutes is not None
            and estimated_completion_time_minutes <= 0
        ):
            raise CoopValueError(
                "The estimated completion time must be greater than 0 minutes."
            )

        study = self.get_prolific_study(project_uuid, study_id)

        # Compare against None explicitly: with `or`, an explicit 0 would be
        # treated as "not supplied", so a 0-cent payment would be validated
        # against the *current* payment and slip past the minimum-rate check.
        updated_completion_time = (
            estimated_completion_time_minutes
            if estimated_completion_time_minutes is not None
            else study.get("estimated_completion_time_minutes")
        )
        updated_payment = (
            participant_payment_cents
            if participant_payment_cents is not None
            else study.get("participant_payment_cents")
        )

        is_underpayment, cost_usd_per_hour = self._validate_prolific_study_cost(
            updated_completion_time, updated_payment
        )
        if is_underpayment:
            raise CoopValueError(
                f"This update would result in a participant payment of ${cost_usd_per_hour:.2f} USD per hour, which is below the minimum payment for using Prolific ($8.00 USD per hour)."
            )

    # (server field name, supplied value); None means "leave unchanged".
    candidate_fields = (
        ("name", name),
        ("description", description),
        ("total_available_places", num_participants),
        ("estimated_completion_time", estimated_completion_time_minutes),
        ("reward", participant_payment_cents),
        ("device_compatibility", device_compatibility),
        ("peripheral_requirements", peripheral_requirements),
        ("filters", filters),
    )
    payload = {
        field: value for field, value in candidate_fields if value is not None
    }

    response = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}",
        method="PATCH",
        payload=payload,
    )
    self._resolve_server_response(response)
    response_json = response.json()
    return {
        "study_id": response_json.get("study_id"),
        "status": response_json.get("status"),
        "admin_url": response_json.get("admin_url"),
        "respondent_url": response_json.get("respondent_url"),
        "name": response_json.get("name"),
        "description": response_json.get("description"),
        "num_participants": response_json.get("total_available_places"),
        "estimated_completion_time_minutes": response_json.get(
            "estimated_completion_time"
        ),
        "participant_payment_cents": response_json.get("reward"),
        "total_cost_cents": response_json.get("total_cost"),
        "device_compatibility": response_json.get("device_compatibility"),
        "peripheral_requirements": response_json.get("peripheral_requirements"),
        "filters": response_json.get("filters"),
    }
|
2224
|
+
|
2225
|
+
def publish_prolific_study(
    self,
    project_uuid: str,
    study_id: str,
) -> dict:
    """
    Publish a Prolific study.

    Parameters:
        project_uuid (str): UUID of the project that owns the study.
        study_id (str): Identifier of the study to publish.

    Returns:
        dict: The decoded JSON body of the server reply.
    """
    endpoint = f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/publish"
    server_reply = self._send_server_request(uri=endpoint, method="POST")
    self._resolve_server_response(server_reply)
    return server_reply.json()
|
2239
|
+
|
2240
|
+
def get_prolific_study(self, project_uuid: str, study_id: str) -> dict:
    """
    Look up a Prolific study and return its details as a dict.

    The server's field names are translated to the names used by the other
    Prolific helpers (for example ``total_available_places`` becomes
    ``num_participants`` and ``reward`` becomes ``participant_payment_cents``).
    Fields missing from the server reply come back as None.
    """
    server_reply = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}",
        method="GET",
    )
    self._resolve_server_response(server_reply)
    study = server_reply.json()

    # (key in the returned dict, key in the server reply)
    field_map = (
        ("study_id", "study_id"),
        ("status", "status"),
        ("admin_url", "admin_url"),
        ("respondent_url", "respondent_url"),
        ("name", "name"),
        ("description", "description"),
        ("num_participants", "total_available_places"),
        ("estimated_completion_time_minutes", "estimated_completion_time"),
        ("participant_payment_cents", "reward"),
        ("total_cost_cents", "total_cost"),
        ("device_compatibility", "device_compatibility"),
        ("peripheral_requirements", "peripheral_requirements"),
        ("filters", "filters"),
    )
    return {ours: study.get(theirs) for ours, theirs in field_map}
|
2267
|
+
|
2268
|
+
def get_prolific_study_responses(
    self,
    project_uuid: str,
    study_id: str,
) -> Union["Results", "ScenarioList"]:
    """
    Return a Results object with the human responses for a Prolific study.

    If generating the Results object fails, a ScenarioList will be returned instead.

    Parameters:
        project_uuid (str): UUID of the project that owns the study.
        study_id (str): Identifier of the Prolific study.
    """
    endpoint = (
        f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/responses"
    )
    server_reply = self._send_server_request(uri=endpoint, method="GET")
    self._resolve_server_response(server_reply)

    body = server_reply.json()
    # No scenario list is passed here; the converter then uses its default.
    return self._turn_human_responses_into_results(
        body.get("human_responses", []),
        body.get("survey_json_string"),
    )
|
2290
|
+
|
2291
|
+
def delete_prolific_study(
    self,
    project_uuid: str,
    study_id: str,
) -> dict:
    """
    Delete a Prolific study.

    Note: Only draft studies can be deleted. Once you publish a study, it cannot be deleted.

    Returns:
        dict: The decoded JSON body of the server reply.
    """
    endpoint = f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}"
    server_reply = self._send_server_request(uri=endpoint, method="DELETE")
    self._resolve_server_response(server_reply)
    return server_reply.json()
|
2307
|
+
|
2308
|
+
def approve_prolific_study_submission(
    self,
    project_uuid: str,
    study_id: str,
    submission_id: str,
) -> dict:
    """
    Approve one submission of a Prolific study.

    Parameters:
        project_uuid (str): UUID of the project that owns the study.
        study_id (str): Identifier of the Prolific study.
        submission_id (str): Identifier of the submission to approve.

    Returns:
        dict: The decoded JSON body of the server reply.
    """
    endpoint = f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/submissions/{submission_id}/approve"
    server_reply = self._send_server_request(uri=endpoint, method="POST")
    self._resolve_server_response(server_reply)
    return server_reply.json()
|
2323
|
+
|
2324
|
+
def reject_prolific_study_submission(
    self,
    project_uuid: str,
    study_id: str,
    submission_id: str,
    reason: Literal[
        "TOO_QUICKLY",
        "TOO_SLOWLY",
        "FAILED_INSTRUCTIONS",
        "INCOMP_LONGITUDINAL",
        "FAILED_CHECK",
        "LOW_EFFORT",
        "MALINGERING",
        "NO_CODE",
        "BAD_CODE",
        "NO_DATA",
        "UNSUPP_DEVICE",
        "OTHER",
    ],
    explanation: str,
) -> dict:
    """
    Reject one submission of a Prolific study.

    Both arguments describing the rejection are checked locally before any
    request is made.

    Parameters:
        project_uuid (str): UUID of the project that owns the study.
        study_id (str): Identifier of the Prolific study.
        submission_id (str): Identifier of the submission to reject.
        reason: One of the rejection categories listed in the signature.
        explanation (str): Free-text justification, 100 characters or more.

    Returns:
        dict: The decoded JSON body of the server reply.

    Raises:
        CoopValueError: If ``reason`` is not a known category or
            ``explanation`` is shorter than 100 characters.
    """
    # Kept as a list: its repr is embedded in the error message below.
    allowed_reasons = [
        "TOO_QUICKLY",
        "TOO_SLOWLY",
        "FAILED_INSTRUCTIONS",
        "INCOMP_LONGITUDINAL",
        "FAILED_CHECK",
        "LOW_EFFORT",
        "MALINGERING",
        "NO_CODE",
        "BAD_CODE",
        "NO_DATA",
        "UNSUPP_DEVICE",
        "OTHER",
    ]
    if reason not in allowed_reasons:
        raise CoopValueError(
            f"Invalid rejection reason. Please use one of the following: {allowed_reasons}."
        )

    explanation_is_long_enough = len(explanation) >= 100
    if not explanation_is_long_enough:
        raise CoopValueError(
            "Rejection explanation must be at least 100 characters."
        )

    endpoint = f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/submissions/{submission_id}/reject"
    rejection = {
        "reason": reason,
        "explanation": explanation,
    }
    server_reply = self._send_server_request(
        uri=endpoint,
        method="POST",
        payload=rejection,
    )
    self._resolve_server_response(server_reply)
    return server_reply.json()
|
2380
|
+
|
1496
2381
|
def __repr__(self):
    """Describe the client by its API key and server URL."""
    # NOTE(review): the API key is rendered in clear text, so it will show
    # up wherever the repr is printed or logged - confirm this is intended.
    template = "Client(api_key='{}', url='{}')"
    return template.format(self.api_key, self.url)
|
@@ -1686,6 +2571,235 @@ class Coop(CoopFunctionsMixin):
|
|
1686
2571
|
self._resolve_server_response(response)
|
1687
2572
|
return response.json().get("uuid")
|
1688
2573
|
|
2574
|
+
def pull(
    self,
    url_or_uuid: Optional[Union[str, UUID]] = None,
    expected_object_type: Optional[ObjectType] = None,
) -> Any:
    """
    Retrieve an object, downloading it through a signed URL when possible.

    The server is asked for a signed URL for the object, the JSON behind
    that URL is downloaded, and the object is rebuilt with the EDSL class
    registered for its type. The method falls back to ``self.get`` when:

    - the alias lookup reports that the object is not in the new format,
    - the server returns an empty signed URL,
    - the download through the signed URL fails, or
    - the object type is not known (no ``expected_object_type`` was given
      and the object was not addressed by alias), so no class can be
      chosen to rebuild the downloaded JSON.

    Parameters:
        url_or_uuid (Union[str, UUID], optional): Identifier for the object to retrieve.
            Can be one of:
            - UUID string (e.g., "123e4567-e89b-12d3-a456-426614174000")
            - Full URL (e.g., "https://expectedparrot.com/content/123e4567...")
            - Alias URL (e.g., "https://expectedparrot.com/content/username/my-survey")
        expected_object_type (ObjectType, optional): If provided, the object is
            rebuilt as this type (e.g., "survey", "agent"). For alias URLs it is
            also validated against the type reported by the server.

    Returns:
        The object built by the registered class's ``from_dict``, or
        whatever ``self.get`` returns on a fallback path.

    Raises:
        CoopObjectTypeError: If an alias resolves to an object whose type
            differs from ``expected_object_type``.
        CoopResponseError: If the server reply contains no ``signed_url`` key.

    Example:
        >>> survey = coop.pull("123e4567-e89b-12d3-a456-426614174000", "survey")
        >>> survey = coop.pull("https://expectedparrot.com/content/username/my-survey")
    """
    obj_uuid, owner_username, alias = self._resolve_uuid_or_alias(url_or_uuid)

    # Type used to pick the class for `from_dict`; for alias lookups the
    # server-reported type fills it in when the caller gave none.
    object_type = expected_object_type

    # Handle alias-based retrieval with new/old format detection
    if not obj_uuid and owner_username and alias:
        # First, get object info to determine format and UUID
        info_response = self._send_server_request(
            uri="api/v0/object/alias/info",
            method="GET",
            params={"owner_username": owner_username, "alias": alias},
        )
        self._resolve_server_response(info_response)
        info_data = info_response.json()

        obj_uuid = info_data.get("uuid")
        is_new_format = info_data.get("is_new_format", False)
        actual_object_type = info_data.get("object_type")

        # Validate object type if expected
        if expected_object_type and actual_object_type != expected_object_type:
            from .exceptions import CoopObjectTypeError

            raise CoopObjectTypeError(
                f"Expected {expected_object_type=} but got {actual_object_type=}"
            )

        # Use get method for old format objects
        if not is_new_format:
            return self.get(url_or_uuid, expected_object_type)

        if object_type is None:
            object_type = actual_object_type

    if object_type is None:
        # Nothing tells us which class should rebuild the downloaded JSON.
        # Previously this path ended in an UnboundLocalError; delegate to
        # `get` instead, before spending requests on a signed URL.
        return self.get(url_or_uuid, expected_object_type)

    # Send the request to the API endpoint with the resolved UUID
    response = self._send_server_request(
        uri="api/v0/object/pull",
        method="POST",
        payload={"object_uuid": obj_uuid},
    )
    # Handle any errors in the response
    self._resolve_server_response(response)
    response_json = response.json()
    if "signed_url" not in response_json:
        from .exceptions import CoopResponseError

        raise CoopResponseError("No signed url was provided.")
    signed_url = response_json.get("signed_url")

    if signed_url == "":  # it is in old format
        return self.get(url_or_uuid, expected_object_type)

    try:
        gcs_response = requests.get(signed_url)
        self._resolve_gcs_response(gcs_response)
    except Exception:
        # Deliberate best-effort: any download problem falls back to `get`.
        return self.get(url_or_uuid, expected_object_type)

    object_dict = gcs_response.json()
    edsl_class = ObjectRegistry.get_edsl_class_by_object_type(object_type)
    return edsl_class.from_dict(object_dict)
|
2671
|
+
|
2672
|
+
def get_upload_url(self, object_uuid: str) -> dict:
    """
    Request a signed upload URL for replacing an existing object's content.

    The URL allows a direct upload to Google Cloud Storage for objects
    stored in the new format, and the object keeps its existing UUID.

    Parameters:
        object_uuid (str): The UUID of the object to get an upload URL for

    Returns:
        dict: A response containing:
            - signed_url: The signed URL for uploading new content
            - object_uuid: The UUID of the object
            - message: Success message

    Raises:
        CoopServerResponseError: If there's an error communicating with the server
        HTTPException: If the object is not found, not owned by user, or not in new format

    Notes:
        - Only works with objects stored in the new format (transition table)
        - User must be the owner of the object
        - The signed URL expires after 60 minutes

    Example:
        >>> response = coop.get_upload_url("123e4567-e89b-12d3-a456-426614174000")
        >>> upload_url = response['signed_url']
        >>> # Use the upload_url to PUT new content directly to GCS
    """
    request_body = {"object_uuid": object_uuid}
    server_reply = self._send_server_request(
        uri="api/v0/object/upload-url",
        method="POST",
        payload=request_body,
    )
    self._resolve_server_response(server_reply)
    return server_reply.json()
|
2709
|
+
|
2710
|
+
def push(
    self,
    object: EDSLObject,
    description: Optional[str] = None,
    alias: Optional[str] = None,
    visibility: Optional[VisibilityType] = "unlisted",
) -> dict:
    """
    Upload an object to Google Cloud Storage through a signed URL.

    The method runs in three steps: it registers the object with the server
    (which answers with a signed URL and the new object's UUID), PUTs the
    serialized object to that URL, and finally confirms the upload with
    the server.

    Parameters:
        object (EDSLObject): The object to upload. Its ``to_dict()`` output
            is serialized to JSON; ``get_hash()`` is used when available.
        description (str, optional): Description sent with the object.
        alias (str, optional): Alias sent with the object.
        visibility (VisibilityType, optional): Defaults to "unlisted".

    Returns:
        dict: Details of the stored object, with the keys ``description``,
        ``object_type``, ``url``, ``alias_url``, ``uuid``, ``version`` and
        ``visibility``.

    Raises:
        CoopResponseError: If the server reply lacks a signed URL (the reply
            text is used as the message) or lacks an object UUID.
        CoopServerResponseError: If there's an error communicating with the server

    Example:
        >>> details = coop.push(my_survey, description="A short survey")
        >>> print(f"Stored at: {details['url']}")
    """
    from .exceptions import CoopResponseError

    object_type = ObjectRegistry.get_object_type_by_edsl_class(object)
    object_dict = object.to_dict()
    object_hash = object.get_hash() if hasattr(object, "get_hash") else None

    # Send the request to the API endpoint
    # NOTE(review): unlike the other requests in this class, this reply is
    # not passed through `_resolve_server_response` - confirm that is intended.
    response = self._send_server_request(
        uri="api/v0/object/push",
        method="POST",
        payload={
            "object_type": object_type,
            "description": description,
            "alias": alias,
            "visibility": visibility,
            "object_hash": object_hash,
            "version": self._edsl_version,
        },
    )
    response_json = response.json()

    signed_url = response_json.get("signed_url")
    if signed_url is None:
        raise CoopResponseError(response.text)

    # Check the UUID before uploading, so a reply that cannot be confirmed
    # later does not leave an uploaded blob behind.
    object_uuid = response_json.get("object_uuid")
    if object_uuid is None:
        raise CoopResponseError("No object uuid was received from the server.")
    owner_username = response_json.get("owner_username")
    object_alias = response_json.get("alias")

    json_data = json.dumps(
        object_dict,
        default=self._json_handle_none,
        allow_nan=False,
    )
    upload_response = requests.put(
        signed_url,
        data=json_data.encode(),
        headers={"Content-Type": "application/json"},
    )
    self._resolve_gcs_response(upload_response)

    # Confirm the upload completion
    confirm_response = self._send_server_request(
        uri="api/v0/object/confirm-upload",
        method="POST",
        payload={"object_uuid": object_uuid},
    )
    self._resolve_server_response(confirm_response)

    return {
        "description": response_json.get("description"),
        "object_type": object_type,
        "url": f"{self.url}/content/{object_uuid}",
        "alias_url": self._get_alias_url(owner_username, object_alias),
        "uuid": object_uuid,
        "version": self._edsl_version,
        "visibility": response_json.get("visibility"),
    }
|
2802
|
+
|
1689
2803
|
def _display_login_url(
|
1690
2804
|
self, edsl_auth_token: str, link_description: Optional[str] = None
|
1691
2805
|
):
|
@@ -1769,6 +2883,125 @@ class Coop(CoopFunctionsMixin):
|
|
1769
2883
|
# Add API key to environment
|
1770
2884
|
load_dotenv()
|
1771
2885
|
|
2886
|
+
def login_streamlit(self, timeout: int = 120):
|
2887
|
+
"""
|
2888
|
+
Start the EDSL auth token login flow inside a Streamlit application.
|
2889
|
+
|
2890
|
+
This helper is functionally equivalent to ``Coop.login`` but renders the
|
2891
|
+
login link and status updates directly in the Streamlit UI. The method
|
2892
|
+
will automatically poll the Expected Parrot server for the API-key
|
2893
|
+
associated with the generated auth-token and, once received, store it
|
2894
|
+
via ``ExpectedParrotKeyHandler`` and write it to the local ``.env``
|
2895
|
+
file so subsequent sessions pick it up automatically.
|
2896
|
+
|
2897
|
+
Parameters
|
2898
|
+
----------
|
2899
|
+
timeout : int, default 120
|
2900
|
+
How many seconds to wait for the user to complete the login before
|
2901
|
+
giving up and showing an error in the Streamlit app.
|
2902
|
+
|
2903
|
+
Returns
|
2904
|
+
-------
|
2905
|
+
str | None
|
2906
|
+
The API-key if the user logged-in successfully, otherwise ``None``.
|
2907
|
+
"""
|
2908
|
+
try:
|
2909
|
+
import streamlit as st
|
2910
|
+
from streamlit.runtime.scriptrunner import get_script_run_ctx
|
2911
|
+
except ModuleNotFoundError as exc:
|
2912
|
+
raise ImportError(
|
2913
|
+
"Streamlit is required for `login_streamlit`. Install it with `pip install streamlit`."
|
2914
|
+
) from exc
|
2915
|
+
|
2916
|
+
# Ensure we are actually running inside a Streamlit script. If not, give a
|
2917
|
+
# clear error message instead of crashing when `st.experimental_rerun` is
|
2918
|
+
# invoked outside the Streamlit runtime.
|
2919
|
+
if get_script_run_ctx() is None:
|
2920
|
+
raise RuntimeError(
|
2921
|
+
"`login_streamlit` must be invoked from within a running Streamlit "
|
2922
|
+
"app (use `streamlit run your_script.py`). If you need to obtain an "
|
2923
|
+
"API-key in a regular Python script or notebook, use `Coop.login()` "
|
2924
|
+
"instead."
|
2925
|
+
)
|
2926
|
+
|
2927
|
+
import secrets
|
2928
|
+
import time
|
2929
|
+
import os
|
2930
|
+
from dotenv import load_dotenv
|
2931
|
+
from .ep_key_handling import ExpectedParrotKeyHandler
|
2932
|
+
from ..utilities.utilities import write_api_key_to_env
|
2933
|
+
|
2934
|
+
# ------------------------------------------------------------------
|
2935
|
+
# 1. Prepare auth-token and store state across reruns
|
2936
|
+
# ------------------------------------------------------------------
|
2937
|
+
if "edsl_auth_token" not in st.session_state:
|
2938
|
+
st.session_state.edsl_auth_token = secrets.token_urlsafe(16)
|
2939
|
+
st.session_state.login_start_time = time.time()
|
2940
|
+
|
2941
|
+
edsl_auth_token: str = st.session_state.edsl_auth_token
|
2942
|
+
login_url = (
|
2943
|
+
f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
|
2944
|
+
)
|
2945
|
+
|
2946
|
+
# ------------------------------------------------------------------
|
2947
|
+
# 2. Render clickable login link
|
2948
|
+
# ------------------------------------------------------------------
|
2949
|
+
st.markdown(
|
2950
|
+
f"🔗 **Log in to Expected Parrot** → [click here]({login_url})",
|
2951
|
+
unsafe_allow_html=True,
|
2952
|
+
)
|
2953
|
+
|
2954
|
+
# ------------------------------------------------------------------
|
2955
|
+
# 3. Poll server for API-key (runs once per Streamlit execution)
|
2956
|
+
# ------------------------------------------------------------------
|
2957
|
+
api_key = self._get_api_key(edsl_auth_token)
|
2958
|
+
if api_key is None:
|
2959
|
+
elapsed = time.time() - st.session_state.login_start_time
|
2960
|
+
if elapsed > timeout:
|
2961
|
+
st.error(
|
2962
|
+
"Timed-out waiting for login. Please rerun the app to try again."
|
2963
|
+
)
|
2964
|
+
return None
|
2965
|
+
|
2966
|
+
remaining = int(timeout - elapsed)
|
2967
|
+
st.info(f"Waiting for login… ({remaining}s left)")
|
2968
|
+
# Trigger a rerun after a short delay to continue polling
|
2969
|
+
time.sleep(1)
|
2970
|
+
|
2971
|
+
# Attempt a rerun in a version-agnostic way. Different Streamlit
|
2972
|
+
# releases expose the helper under different names.
|
2973
|
+
def _safe_rerun():
|
2974
|
+
if hasattr(st, "experimental_rerun"):
|
2975
|
+
st.experimental_rerun()
|
2976
|
+
elif hasattr(st, "rerun"):
|
2977
|
+
st.rerun() # introduced in newer versions
|
2978
|
+
else:
|
2979
|
+
# Fallback – advise the user to update Streamlit for automatic polling.
|
2980
|
+
st.warning(
|
2981
|
+
"Please refresh the page to continue the login flow. "
|
2982
|
+
"(Consider upgrading Streamlit to enable automatic refresh.)"
|
2983
|
+
)
|
2984
|
+
|
2985
|
+
try:
|
2986
|
+
_safe_rerun()
|
2987
|
+
except Exception:
|
2988
|
+
# The Streamlit runtime intercepts the rerun exception; any other
|
2989
|
+
# unexpected errors are ignored to avoid crashing the app.
|
2990
|
+
pass
|
2991
|
+
|
2992
|
+
# ------------------------------------------------------------------
|
2993
|
+
# 4. Key received – persist it and notify user
|
2994
|
+
# ------------------------------------------------------------------
|
2995
|
+
ExpectedParrotKeyHandler().store_ep_api_key(api_key)
|
2996
|
+
os.environ["EXPECTED_PARROT_API_KEY"] = api_key
|
2997
|
+
path_to_env = write_api_key_to_env(api_key)
|
2998
|
+
load_dotenv()
|
2999
|
+
|
3000
|
+
st.success("API-key retrieved and stored. You are now logged-in! 🎉")
|
3001
|
+
st.caption(f"Key saved to `{path_to_env}`.")
|
3002
|
+
|
3003
|
+
return api_key
|
3004
|
+
|
1772
3005
|
def transfer_credits(
|
1773
3006
|
self,
|
1774
3007
|
credits_transferred: int,
|
@@ -1816,6 +3049,53 @@ class Coop(CoopFunctionsMixin):
|
|
1816
3049
|
self._resolve_server_response(response)
|
1817
3050
|
return response.json()
|
1818
3051
|
|
3052
|
+
def pay_for_service(
    self,
    credits_transferred: int,
    recipient_username: str,
    service_name: str,
) -> dict:
    """
    Pay another user for a named service using credits.

    Sends a POST request to the ``api/v0/users/pay-for-service`` endpoint
    with the credit amount, the service name and the recipient's username,
    then hands the response to ``_resolve_server_response`` before returning
    its decoded JSON body.

    Parameters:
        credits_transferred (int): The number of credits to pay; sent to the
            server under the ``cost_credits`` key
        recipient_username (str): The username of the user being paid
        service_name (str): The name of the service being paid for

    Returns:
        dict: The JSON body of the server response. The exact schema is
            defined by the server; it presumably mirrors the credit-transfer
            response (``success``, ``transaction_id``, ``remaining_credits``)
            -- TODO confirm against the API.

    Raises:
        CoopServerResponseError: Presumably raised by
            ``_resolve_server_response`` when the server reports an error
            (e.g., insufficient credits) -- verify against that helper.

    Example:
        >>> result = coop.pay_for_service(
        ...     credits_transferred=100,
        ...     service_name="service_name",
        ...     recipient_username="friend_username",
        ... )
        >>> print(f"Transfer successful! You have {result['remaining_credits']} credits left.")
    """
    # Assemble the request body first so the request call stays compact.
    request_body = {
        "cost_credits": credits_transferred,
        "service_name": service_name,
        "recipient_username": recipient_username,
    }
    server_reply = self._send_server_request(
        uri="api/v0/users/pay-for-service",
        method="POST",
        payload=request_body,
    )
    # Let the shared response handler inspect the reply before decoding it.
    self._resolve_server_response(server_reply)
    return server_reply.json()
|
3098
|
+
|
1819
3099
|
def get_balance(self) -> dict:
|
1820
3100
|
"""
|
1821
3101
|
Get the current credit balance for the authenticated user.
|
@@ -1835,10 +3115,178 @@ class Coop(CoopFunctionsMixin):
|
|
1835
3115
|
>>> balance = coop.get_balance()
|
1836
3116
|
>>> print(f"You have {balance['credits']} credits available.")
|
1837
3117
|
"""
|
1838
|
-
response = self._send_server_request(
|
3118
|
+
response = self._send_server_request(
|
3119
|
+
uri="api/v0/users/get-balance", method="GET"
|
3120
|
+
)
|
3121
|
+
self._resolve_server_response(response)
|
3122
|
+
return response.json()
|
3123
|
+
|
3124
|
+
def get_profile(self) -> dict:
    """
    Fetch the profile of the currently authenticated user.

    Issues a GET request to the ``api/v0/users/profile`` endpoint, passes the
    response through ``_resolve_server_response`` and returns the decoded
    JSON body.

    Returns:
        dict: The JSON body of the server response. It presumably contains
            at least ``username`` and ``email`` -- the schema is defined by
            the server, TODO confirm.

    Raises:
        CoopServerResponseError: Presumably raised by
            ``_resolve_server_response`` when the server reports an error --
            verify against that helper.

    Example:
        >>> profile = coop.get_profile()
        >>> print(f"Welcome, {profile['username']}!")
    """
    profile_response = self._send_server_request(
        uri="api/v0/users/profile",
        method="GET",
    )
    # Let the shared response handler inspect the reply before decoding it.
    self._resolve_server_response(profile_response)
    return profile_response.json()
|
1841
3146
|
|
3147
|
+
def login_gradio(self, timeout: int = 120, launch: bool = True, **launch_kwargs):
    """
    Start the EDSL auth token login flow inside a **Gradio** application.

    This helper mirrors the behaviour of :py:meth:`Coop.login_streamlit` but
    renders the login link and status updates inside a Gradio UI. A fresh
    auth-token is generated and embedded in the login link. The server is
    asked for the matching API-key (via ``self._get_api_key``) once when the
    page loads and again every time the user presses the
    "I've logged in" button. Once a key is available it is stored via
    :py:class:`~edsl.coop.ep_key_handling.ExpectedParrotKeyHandler`, exported
    as the ``EXPECTED_PARROT_API_KEY`` environment variable and written to
    the local ``.env`` file so subsequent sessions pick it up automatically.

    Parameters
    ----------
    timeout : int, default 120
        How many seconds (measured from the moment this method is called)
        the user has to complete the login. A status check performed after
        the deadline reports a time-out instead of a countdown.
    launch : bool, default True
        If ``True`` the Gradio app is immediately launched with
        ``demo.launch(**launch_kwargs)``. Set this to ``False`` if you want
        to embed the returned :class:`gradio.Blocks` object into an existing
        Gradio interface.
    **launch_kwargs
        Additional keyword-arguments forwarded to ``gr.Blocks.launch`` when
        *launch* is ``True``.

    Returns
    -------
    gradio.Blocks | None
        • If *launch* is ``True`` the method returns ``None`` after
          ``demo.launch`` returns.
        • If *launch* is ``False`` the constructed ``gr.Blocks`` is
          returned so the caller can compose it further.

        The API-key itself is never returned by this method; it is held in
        the interface's ``gr.State`` and persisted as described above.

    Raises
    ------
    ImportError
        If Gradio is not installed.
    """
    try:
        import gradio as gr
    except ModuleNotFoundError as exc:
        raise ImportError(
            "Gradio is required for `login_gradio`. Install it with `pip install gradio`."
        ) from exc

    import secrets
    import time
    import os
    from dotenv import load_dotenv
    from .ep_key_handling import ExpectedParrotKeyHandler
    from ..utilities.utilities import write_api_key_to_env

    # ------------------------------------------------------------------
    # 1. Prepare auth-token
    # ------------------------------------------------------------------
    edsl_auth_token = secrets.token_urlsafe(16)
    login_url = (
        f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
    )
    start_time = time.time()

    # ------------------------------------------------------------------
    # 2. Build Gradio interface
    # ------------------------------------------------------------------
    with gr.Blocks() as demo:
        gr.HTML(
            f'🔗 <b>Log in to Expected Parrot</b> → <a href="{login_url}" target="_blank">click here</a>'
        )
        status_md = gr.Markdown("Waiting for login…")
        refresh_btn = gr.Button(
            "I've logged in – click to continue", elem_id="refresh-btn"
        )
        key_state = gr.State(value=None)

        # Build an "update" object for the refresh button. Some Gradio
        # releases expose ``gr.Button.update``; where that attribute is
        # missing fall back to the generic ``gr.update``.
        def _button_update(**kwargs):
            try:
                return gr.Button.update(**kwargs)
            except AttributeError:
                return gr.update(**kwargs)

        # --------------------------------------------------------------
        # Polling callback
        # --------------------------------------------------------------
        def _refresh(current_key):  # noqa: D401, pylint: disable=unused-argument
            """Poll server for API-key and update UI accordingly.

            Returns a ``(status message, button update, key)`` triple that
            matches the ``outputs`` wired up below.
            """
            api_key = self._get_api_key(edsl_auth_token)
            # Fall back to env var in case the key was obtained earlier in this session
            if not api_key:
                api_key = os.environ.get("EXPECTED_PARROT_API_KEY")
            elapsed = time.time() - start_time
            remaining = max(0, int(timeout - elapsed))

            if api_key:
                # Persist and expose the key
                ExpectedParrotKeyHandler().store_ep_api_key(api_key)
                os.environ["EXPECTED_PARROT_API_KEY"] = api_key
                path_to_env = write_api_key_to_env(api_key)
                load_dotenv()
                success_msg = (
                    "API-key retrieved and stored 🎉\n\n"
                    f"Key saved to `{path_to_env}`."
                )
                # The button has served its purpose; hide it.
                return (
                    success_msg,
                    _button_update(interactive=False, visible=False),
                    api_key,
                )

            if elapsed > timeout:
                err_msg = (
                    "Timed-out waiting for login. Please refresh the page "
                    "or restart the app to try again."
                )
                return (
                    err_msg,
                    _button_update(),
                    None,
                )

            info_msg = f"Waiting for login… ({remaining}s left)"
            return (
                info_msg,
                _button_update(),
                None,
            )

        refresh_outputs = [status_md, refresh_btn, key_state]

        # Initial status check when the interface loads
        demo.load(
            fn=_refresh,
            inputs=key_state,
            outputs=refresh_outputs,
        )

        # Re-check whenever the user confirms they have logged in. Without
        # this binding the button is inert and the key can never be picked
        # up after the initial page load.
        refresh_btn.click(
            fn=_refresh,
            inputs=key_state,
            outputs=refresh_outputs,
        )

    # ------------------------------------------------------------------
    # 3. Launch or return interface
    # ------------------------------------------------------------------
    if launch:
        demo.launch(**launch_kwargs)
        return None
    return demo
|
3289
|
+
|
1842
3290
|
|
1843
3291
|
def main():
|
1844
3292
|
"""
|
@@ -1973,5 +3421,14 @@ def main():
|
|
1973
3421
|
job = Jobs.example()
|
1974
3422
|
coop.remote_inference_cost(job)
|
1975
3423
|
job_coop_object = coop.remote_inference_create(job)
|
1976
|
-
job_coop_results = coop.
|
3424
|
+
job_coop_results = coop.new_remote_inference_get(job_coop_object.get("uuid"))
|
1977
3425
|
coop.get(job_coop_results.get("results_uuid"))
|
3426
|
+
|
3427
|
+
import streamlit as st
|
3428
|
+
from edsl.coop import Coop
|
3429
|
+
|
3430
|
+
coop = Coop() # no API-key required yet
|
3431
|
+
api_key = coop.login_streamlit() # renders link + handles polling & storage
|
3432
|
+
|
3433
|
+
if api_key:
|
3434
|
+
st.success("Ready to use EDSL with remote features!")
|