edsl 0.1.59__py3-none-any.whl → 0.1.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +65 -17
- edsl/agents/agent_list.py +117 -33
- edsl/base/base_class.py +80 -11
- edsl/base/data_transfer_models.py +5 -0
- edsl/base/enums.py +7 -2
- edsl/config/config_class.py +7 -2
- edsl/coop/coop.py +1295 -85
- edsl/coop/coop_prolific_filters.py +171 -0
- edsl/dataset/dataset_operations_mixin.py +2 -2
- edsl/dataset/display/table_display.py +40 -7
- edsl/db_list/sqlite_list.py +102 -3
- edsl/inference_services/services/__init__.py +3 -1
- edsl/inference_services/services/open_ai_service_v2.py +243 -0
- edsl/jobs/data_structures.py +48 -30
- edsl/jobs/jobs.py +73 -2
- edsl/jobs/remote_inference.py +49 -15
- edsl/key_management/key_lookup_builder.py +25 -3
- edsl/language_models/language_model.py +2 -1
- edsl/language_models/raw_response_handler.py +126 -7
- edsl/questions/loop_processor.py +289 -10
- edsl/questions/templates/dict/answering_instructions.jinja +0 -1
- edsl/results/result.py +37 -0
- edsl/results/results.py +1 -0
- edsl/scenarios/scenario_list.py +31 -1
- edsl/scenarios/scenario_source.py +606 -498
- edsl/surveys/survey.py +198 -163
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/METADATA +4 -4
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/RECORD +32 -30
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/LICENSE +0 -0
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/WHEEL +0 -0
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/entry_points.txt +0 -0
edsl/coop/coop.py
CHANGED
@@ -3,8 +3,9 @@ import base64
 import json
 import requests
 import time
+import os

-from typing import Any, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
+from typing import Any, Dict, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
 from uuid import UUID

 from .. import __version__
@@ -35,6 +36,7 @@ from .utils import (
 from .coop_functions import CoopFunctionsMixin
 from .coop_regular_objects import CoopRegularObjects
 from .coop_jobs_objects import CoopJobsObjects
+from .coop_prolific_filters import CoopProlificFilters
 from .ep_key_handling import ExpectedParrotKeyHandler

 from ..inference_services.data_structures import ServiceToModelsMapping
@@ -66,7 +68,6 @@ class JobRunInterviewDetails(TypedDict):


 class LatestJobRunDetails(TypedDict):
-
     # For running, completed, and partially failed jobs
     interview_details: Optional[JobRunInterviewDetails] = None

@@ -296,9 +297,9 @@ class Coop(CoopFunctionsMixin):
                message = str(response.json().get("detail"))
            except json.JSONDecodeError:
                raise CoopServerResponseError(
-                    f"Server returned status code {response.status_code}."
-                    "JSON response could not be decoded."
-                    "The server response was:
+                    f"Server returned status code {response.status_code}. "
+                    f"JSON response could not be decoded. "
+                    f"The server response was: {response.text}"
                )
            # print(response.text)
            if "The API key you provided is invalid" in message and check_api_key:
@@ -694,6 +695,35 @@ class Coop(CoopFunctionsMixin):
        """
        obj_uuid, owner_username, alias = self._resolve_uuid_or_alias(url_or_uuid)

+        # Handle alias-based retrieval with new/old format detection
+        if not obj_uuid and owner_username and alias:
+            # First, get object info to determine format and UUID
+            info_response = self._send_server_request(
+                uri="api/v0/object/alias/info",
+                method="GET",
+                params={"owner_username": owner_username, "alias": alias},
+            )
+            self._resolve_server_response(info_response)
+            info_data = info_response.json()
+
+            obj_uuid = info_data.get("uuid")
+            is_new_format = info_data.get("is_new_format", False)
+
+            # Validate object type if expected
+            if expected_object_type:
+                actual_object_type = info_data.get("object_type")
+                if actual_object_type != expected_object_type:
+                    from .exceptions import CoopObjectTypeError
+
+                    raise CoopObjectTypeError(
+                        f"Expected {expected_object_type=} but got {actual_object_type=}"
+                    )
+
+            # Use pull method for new format objects
+            if is_new_format:
+                return self.pull(obj_uuid, expected_object_type)
+
+        # Handle UUID-based retrieval or legacy alias objects
        if obj_uuid:
            response = self._send_server_request(
                uri="api/v0/object",
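For orientation, a minimal sketch of how the alias-aware retrieval added above might be used from client code; the alias URL below is an illustrative placeholder, and the constructor is assumed to read the API key from the environment:

```python
from edsl.coop import Coop

coop = Coop()  # assumes EXPECTED_PARROT_API_KEY is set in the environment

# Old-format objects go through the traditional endpoint; new-format (GCS-backed)
# objects are transparently routed through coop.pull().
survey = coop.get(
    "https://expectedparrot.com/content/some_user/my-survey",  # hypothetical alias URL
    expected_object_type="survey",
)
```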
@@ -915,6 +945,26 @@ class Coop(CoopFunctionsMixin):

        obj_uuid, owner_username, obj_alias = self._resolve_uuid_or_alias(url_or_uuid)

+        # If we have a UUID and are updating the value, check the storage format first
+        if obj_uuid and value:
+            # Check if object is in new format (GCS)
+            format_check_response = self._send_server_request(
+                uri="api/v0/object/check-format",
+                method="POST",
+                payload={"object_uuid": str(obj_uuid)},
+            )
+            self._resolve_server_response(format_check_response)
+            format_data = format_check_response.json()
+
+            is_new_format = format_data.get("is_new_format", False)
+
+            if is_new_format:
+                # Handle new format objects: update metadata first, then upload content
+                return self._patch_new_format_object(
+                    obj_uuid, description, alias, value, visibility
+                )
+
+        # Handle traditional format objects or metadata-only updates
        if obj_uuid:
            uri = "api/v0/object"
            params = {"uuid": obj_uuid}
@@ -944,6 +994,70 @@ class Coop(CoopFunctionsMixin):
        self._resolve_server_response(response)
        return response.json()

+    def _patch_new_format_object(
+        self,
+        obj_uuid: UUID,
+        description: Optional[str],
+        alias: Optional[str],
+        value: EDSLObject,
+        visibility: Optional[VisibilityType],
+    ) -> dict:
+        """
+        Handle patching of objects stored in the new format (GCS).
+        """
+        # Step 1: Update metadata only (no json_string)
+        if description is not None or alias is not None or visibility is not None:
+            metadata_response = self._send_server_request(
+                uri="api/v0/object",
+                method="PATCH",
+                params={"uuid": obj_uuid},
+                payload={
+                    "description": description,
+                    "alias": alias,
+                    "json_string": None,  # Don't send content to traditional endpoint
+                    "visibility": visibility,
+                },
+            )
+            self._resolve_server_response(metadata_response)
+
+        # Step 2: Get signed upload URL for content update
+        upload_url_response = self._send_server_request(
+            uri="api/v0/object/upload-url",
+            method="POST",
+            payload={"object_uuid": str(obj_uuid)},
+        )
+        self._resolve_server_response(upload_url_response)
+        upload_data = upload_url_response.json()
+
+        # Step 3: Upload the object content to GCS
+        signed_url = upload_data.get("signed_url")
+        if not signed_url:
+            raise CoopServerResponseError("Failed to get signed upload URL")
+
+        json_content = json.dumps(
+            value.to_dict(),
+            default=self._json_handle_none,
+            allow_nan=False,
+        )
+
+        # Upload to GCS using signed URL
+        gcs_response = requests.put(
+            signed_url,
+            data=json_content,
+            headers={"Content-Type": "application/json"},
+        )
+
+        if gcs_response.status_code != 200:
+            raise CoopServerResponseError(
+                f"Failed to upload object to GCS: {gcs_response.status_code}"
+            )
+
+        return {
+            "status": "success",
+            "message": "Object updated successfully (new format - uploaded to GCS)",
+            "object_uuid": str(obj_uuid),
+        }
+
    ################
    # Remote Cache
    ################
@@ -1025,6 +1139,115 @@ class Coop(CoopFunctionsMixin):
        is handled by Expected Parrot's infrastructure, and you can check the status
        and retrieve results later.

+        Parameters:
+            job (Jobs): The EDSL job to run in the cloud
+            description (str, optional): A human-readable description of the job
+            status (RemoteJobStatus): Initial status, should be "queued" for normal use
+                Possible values: "queued", "running", "completed", "failed"
+            visibility (VisibilityType): Access level for the job information. One of:
+                - "private": Only accessible by the owner
+                - "public": Accessible by anyone
+                - "unlisted": Accessible with the link, but not listed publicly
+            initial_results_visibility (VisibilityType): Access level for the job results
+            iterations (int): Number of times to run each interview (default: 1)
+            fresh (bool): If True, ignore existing cache entries and generate new results
+
+        Returns:
+            RemoteInferenceCreationInfo: Information about the created job including:
+                - uuid: The unique identifier for the job
+                - description: The job description
+                - status: Current status of the job
+                - iterations: Number of iterations for each interview
+                - visibility: Access level for the job
+                - version: EDSL version used to create the job
+
+        Raises:
+            CoopServerResponseError: If there's an error communicating with the server
+
+        Notes:
+            - Remote jobs run asynchronously and may take time to complete
+            - Use remote_inference_get() with the returned UUID to check status
+            - Credits are consumed based on the complexity of the job
+
+        Example:
+            >>> from edsl.jobs import Jobs
+            >>> job = Jobs.example()
+            >>> job_info = coop.remote_inference_create(job=job, description="My job")
+            >>> print(f"Job created with UUID: {job_info['uuid']}")
+        """
+        response = self._send_server_request(
+            uri="api/v0/new-remote-inference",
+            method="POST",
+            payload={
+                "json_string": "offloaded",
+                "description": description,
+                "status": status,
+                "iterations": iterations,
+                "visibility": visibility,
+                "version": self._edsl_version,
+                "initial_results_visibility": initial_results_visibility,
+                "fresh": fresh,
+            },
+        )
+        self._resolve_server_response(response)
+        response_json = response.json()
+        upload_signed_url = response_json.get("upload_signed_url")
+        if not upload_signed_url:
+            from .exceptions import CoopResponseError
+
+            raise CoopResponseError("No signed url was provided received")
+
+        response = requests.put(
+            upload_signed_url,
+            data=json.dumps(
+                job.to_dict(),
+                default=self._json_handle_none,
+            ).encode(),
+            headers={"Content-Type": "application/json"},
+        )
+        self._resolve_gcs_response(response)
+
+        job_uuid = response_json.get("job_uuid")
+
+        response = self._send_server_request(
+            uri="api/v0/new-remote-inference/uploaded",
+            method="POST",
+            payload={
+                "job_uuid": job_uuid,
+                "message": "Job uploaded successfully",
+            },
+        )
+        response_json = response.json()
+
+        return RemoteInferenceCreationInfo(
+            **{
+                "uuid": response_json.get("job_uuid"),
+                "description": response_json.get("description", ""),
+                "status": response_json.get("status"),
+                "iterations": response_json.get("iterations", ""),
+                "visibility": response_json.get("visibility", ""),
+                "version": self._edsl_version,
+            }
+        )
+
+    def old_remote_inference_create(
+        self,
+        job: "Jobs",
+        description: Optional[str] = None,
+        status: RemoteJobStatus = "queued",
+        visibility: Optional[VisibilityType] = "unlisted",
+        initial_results_visibility: Optional[VisibilityType] = "unlisted",
+        iterations: Optional[int] = 1,
+        fresh: Optional[bool] = False,
+    ) -> RemoteInferenceCreationInfo:
+        """
+        Create a remote inference job for execution in the Expected Parrot cloud.
+
+        This method sends a job to be executed in the cloud, which can be more efficient
+        for large jobs or when you want to run jobs in the background. The job execution
+        is handled by Expected Parrot's infrastructure, and you can check the status
+        and retrieve results later.
+
        Parameters:
            job (Jobs): The EDSL job to run in the cloud
            description (str, optional): A human-readable description of the job
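A hedged usage sketch of the new two-step remote job creation shown above (the client first registers the job, then uploads the job JSON to a signed URL internally); the job and description are illustrative:

```python
from edsl.coop import Coop
from edsl.jobs import Jobs

coop = Coop()
job = Jobs.example()

# Creates the job record, offloads the job JSON to GCS, and confirms the upload.
job_info = coop.remote_inference_create(job=job, description="My job")
print(job_info["uuid"], job_info["status"])
```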
@@ -1368,25 +1591,57 @@ class Coop(CoopFunctionsMixin):
    def create_project(
        self,
        survey: "Survey",
+        scenario_list: Optional["ScenarioList"] = None,
+        scenario_list_method: Optional[
+            Literal["randomize", "loop", "single_scenario"]
+        ] = None,
        project_name: str = "Project",
        survey_description: Optional[str] = None,
        survey_alias: Optional[str] = None,
        survey_visibility: Optional[VisibilityType] = "unlisted",
+        scenario_list_description: Optional[str] = None,
+        scenario_list_alias: Optional[str] = None,
+        scenario_list_visibility: Optional[VisibilityType] = "unlisted",
    ):
        """
        Create a survey object on Coop, then create a project from the survey.
        """
-
+        if scenario_list is None and scenario_list_method is not None:
+            raise CoopValueError(
+                "You must specify both a scenario list and a scenario list method to use scenarios with your survey."
+            )
+        elif scenario_list is not None and scenario_list_method is None:
+            raise CoopValueError(
+                "You must specify both a scenario list and a scenario list method to use scenarios with your survey."
+            )
+        survey_details = self.push(
            object=survey,
            description=survey_description,
            alias=survey_alias,
            visibility=survey_visibility,
        )
        survey_uuid = survey_details.get("uuid")
+        if scenario_list is not None:
+            scenario_list_details = self.push(
+                object=scenario_list,
+                description=scenario_list_description,
+                alias=scenario_list_alias,
+                visibility=scenario_list_visibility,
+            )
+            scenario_list_uuid = scenario_list_details.get("uuid")
+        else:
+            scenario_list_uuid = None
        response = self._send_server_request(
            uri="api/v0/projects/create-from-survey",
            method="POST",
-            payload={
+            payload={
+                "project_name": project_name,
+                "survey_uuid": str(survey_uuid),
+                "scenario_list_uuid": (
+                    str(scenario_list_uuid) if scenario_list_uuid is not None else None
+                ),
+                "scenario_list_method": scenario_list_method,
+            },
        )
        self._resolve_server_response(response)
        response_json = response.json()
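A minimal sketch of calling create_project with the new scenario-list parameters; the example objects are assumptions (EDSL classes generally expose an example() constructor), and both scenario arguments must be supplied together or a CoopValueError is raised:

```python
from edsl.coop import Coop
from edsl.surveys import Survey
from edsl.scenarios import ScenarioList

coop = Coop()
survey = Survey.example()
scenarios = ScenarioList.example()  # assumed example constructor

project = coop.create_project(
    survey,
    scenario_list=scenarios,
    scenario_list_method="randomize",
    project_name="My project",
)
```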
@@ -1413,14 +1668,26 @@ class Coop(CoopFunctionsMixin):
        return {
            "project_name": response_json.get("project_name"),
            "project_job_uuids": response_json.get("job_uuids"),
+            "project_prolific_studies": [
+                {
+                    "study_id": study.get("id"),
+                    "name": study.get("name"),
+                    "status": study.get("status"),
+                    "num_participants": study.get("total_available_places"),
+                    "places_taken": study.get("places_taken"),
+                }
+                for study in response_json.get("prolific_studies", [])
+            ],
        }

-    def
+    def _turn_human_responses_into_results(
        self,
-
+        human_responses: List[dict],
+        survey_json_string: str,
+        scenario_list_json_string: Optional[str] = None,
    ) -> Union["Results", "ScenarioList"]:
        """
-
+        Turn a list of human responses into a Results object.

        If generating the Results object fails, a ScenarioList will be returned instead.
        """
@@ -1430,16 +1697,19 @@ class Coop(CoopFunctionsMixin):
        from ..scenarios import Scenario, ScenarioList
        from ..surveys import Survey

-        response = self._send_server_request(
-            uri=f"api/v0/projects/{project_uuid}/human-responses",
-            method="GET",
-        )
-        self._resolve_server_response(response)
-        response_json = response.json()
-        human_responses = response_json.get("human_responses", [])
-
        try:
-
+            survey = Survey.from_dict(json.loads(survey_json_string))
+
+            model = Model("test")
+
+            if scenario_list_json_string is not None:
+                scenario_list = ScenarioList.from_dict(
+                    json.loads(scenario_list_json_string)
+                )
+            else:
+                scenario_list = ScenarioList()
+
+            results = None

            for response in human_responses:
                response_uuid = response.get("response_uuid")
@@ -1449,8 +1719,14 @@ class Coop(CoopFunctionsMixin):
                )

                response_dict = json.loads(response.get("response_json_string"))
+                agent_traits_json_string = response.get("agent_traits_json_string")
+                scenario_uuid = response.get("scenario_uuid")
+                if agent_traits_json_string is not None:
+                    agent_traits = json.loads(agent_traits_json_string)
+                else:
+                    agent_traits = {}

-                a = Agent(name=response_uuid, instruction="")
+                a = Agent(name=response_uuid, instruction="", traits=agent_traits)

                def create_answer_function(response_data):
                    def f(self, question, scenario):
@@ -1458,27 +1734,38 @@ class Coop(CoopFunctionsMixin):

                    return f

+                scenario = None
+                if scenario_uuid is not None:
+                    for s in scenario_list:
+                        if s.get("uuid") == scenario_uuid:
+                            scenario = s
+                            break
+
+                    if scenario is None:
+                        raise RuntimeError("Scenario not found.")
+
                a.add_direct_question_answering_method(
                    create_answer_function(response_dict)
                )
-                agent_list.append(a)

-
-                survey = Survey.from_dict(json.loads(survey_json_string))
+                job = survey.by(a).by(model)

-
-
-
-                .
-                .run(
+                if scenario is not None:
+                    job = job.by(scenario)
+
+                question_results = job.run(
                    cache=Cache(),
                    disable_remote_cache=True,
                    disable_remote_inference=True,
                    print_exceptions=False,
                )
-
+
+                if results is None:
+                    results = question_results
+                else:
+                    results = results + question_results
            return results
-        except Exception:
+        except Exception as e:
            human_response_scenarios = []
            for response in human_responses:
                response_uuid = response.get("response_uuid")
@@ -1493,71 +1780,492 @@ class Coop(CoopFunctionsMixin):
                human_response_scenarios.append(scenario)
            return ScenarioList(human_response_scenarios)

-    def
-        """Return a string representation of the client."""
-        return f"Client(api_key='{self.api_key}', url='{self.url}')"
-
-    async def remote_async_execute_model_call(
-        self, model_dict: dict, user_prompt: str, system_prompt: str
-    ) -> dict:
-        url = self.api_url + "/inference/"
-        # print("Now using url: ", url)
-        data = {
-            "model_dict": model_dict,
-            "user_prompt": user_prompt,
-            "system_prompt": system_prompt,
-        }
-        # Use aiohttp to send a POST request asynchronously
-        async with aiohttp.ClientSession() as session:
-            async with session.post(url, json=data) as response:
-                response_data = await response.json()
-                return response_data
-
-    def web(
+    def get_project_human_responses(
        self,
-
-
-
-
-        email=None,
-    ):
-        url = f"{self.api_url}/api/v0/export_to_{platform}"
-        if email:
-            data = {"json_string": json.dumps({"survey": survey, "email": email})}
-        else:
-            data = {"json_string": json.dumps({"survey": survey, "email": ""})}
+        project_uuid: str,
+    ) -> Union["Results", "ScenarioList"]:
+        """
+        Return a Results object with the human responses for a project.

-
+        If generating the Results object fails, a ScenarioList will be returned instead.
+        """
+        response = self._send_server_request(
+            uri=f"api/v0/projects/{project_uuid}/human-responses",
+            method="GET",
+        )
+        self._resolve_server_response(response)
+        response_json = response.json()
+        human_responses = response_json.get("human_responses", [])
+        survey_json_string = response_json.get("survey_json_string")
+        scenario_list_json_string = response_json.get("scenario_list_json_string")

-        return
+        return self._turn_human_responses_into_results(
+            human_responses, survey_json_string, scenario_list_json_string
+        )

-    def
+    def list_prolific_filters(self) -> "CoopProlificFilters":
        """
-
+        Get a ScenarioList of supported Prolific filters. This list has several methods
+        that you can use to create valid filter dicts for use with Coop.create_prolific_study().

-
-
-
+        Call find() to examine a specific filter by ID:
+        >>> filters = coop.list_prolific_filters()
+        >>> filters.find("age")
+        Scenario(
+            {
+                "filter_id": "age",
+                "type": "range",
+                "range_filter_min": 18,
+                "range_filter_max": 100,
+                ...
+            }
+        )

-
-
-
-
+        Call create_study_filter() to create a valid filter dict:
+        >>> filters.create_study_filter("age", min=30, max=40)
+        {
+            "filter_id": "age",
+            "selected_range": {
+                "lower": 30,
+                "upper": 40,
+            },
+        }
+        """
+        from ..scenarios import Scenario
+
+        response = self._send_server_request(
+            uri="api/v0/prolific-filters",
+            method="GET",
+        )
+        self._resolve_server_response(response)
+        response_json = response.json()
+        filters = response_json.get("prolific_filters", [])
+        filter_scenarios = []
+        for filter in filters:
+            filter_type = filter.get("type")
+            question = filter.get("question")
+            scenario = Scenario(
                {
-                    (
-
-
-
+                    "filter_id": filter.get("filter_id"),
+                    "title": filter.get("title"),
+                    "question": (
+                        f"Participants were asked the following: {question}"
+                        if question
+                        else None
+                    ),
+                    "type": filter_type,
+                    "range_filter_min": (
+                        filter.get("min") if filter_type == "range" else None
+                    ),
+                    "range_filter_max": (
+                        filter.get("max") if filter_type == "range" else None
+                    ),
+                    "select_filter_num_options": (
+                        len(filter.get("choices", []))
+                        if filter_type == "select"
+                        else None
+                    ),
+                    "select_filter_options": (
+                        filter.get("choices") if filter_type == "select" else None
+                    ),
                }
+            )
+            filter_scenarios.append(scenario)
+        return CoopProlificFilters(filter_scenarios)

-
-
+    @staticmethod
+    def _validate_prolific_study_cost(
+        estimated_completion_time_minutes: int, participant_payment_cents: int
+    ) -> tuple[bool, float]:
+        """
+        If the cost of a Prolific study is below the threshold, return True.
+        Otherwise, return False.
+        The second value in the tuple is the cost of the study in USD per hour.
+        """
+        estimated_completion_time_hours = estimated_completion_time_minutes / 60
+        participant_payment_usd = participant_payment_cents / 100
+        cost_usd_per_hour = participant_payment_usd / estimated_completion_time_hours

-
-
-
-
-
+        # $8.00 USD per hour is the minimum amount for using Prolific
+        if cost_usd_per_hour < 8:
+            return True, cost_usd_per_hour
+        else:
+            return False, cost_usd_per_hour
+
+    def create_prolific_study(
+        self,
+        project_uuid: str,
+        name: str,
+        description: str,
+        num_participants: int,
+        estimated_completion_time_minutes: int,
+        participant_payment_cents: int,
+        device_compatibility: Optional[
+            List[Literal["desktop", "tablet", "mobile"]]
+        ] = None,
+        peripheral_requirements: Optional[
+            List[Literal["audio", "camera", "download", "microphone"]]
+        ] = None,
+        filters: Optional[List[Dict]] = None,
+    ) -> dict:
+        """
+        Create a Prolific study for a project. Returns a dict with the study details.
+
+        To add filters to your study, you should first pull the list of supported
+        filters using Coop.list_prolific_filters().
+        Then, you can use the create_study_filter method of the returned
+        CoopProlificFilters object to create a valid filter dict.
+        """
+        is_underpayment, cost_usd_per_hour = self._validate_prolific_study_cost(
+            estimated_completion_time_minutes, participant_payment_cents
+        )
+        if is_underpayment:
+            raise CoopValueError(
+                f"The current participant payment of ${cost_usd_per_hour:.2f} USD per hour is below the minimum payment for using Prolific ($8.00 USD per hour)."
+            )
+
+        response = self._send_server_request(
+            uri=f"api/v0/projects/{project_uuid}/prolific-studies",
+            method="POST",
+            payload={
+                "name": name,
+                "description": description,
+                "total_available_places": num_participants,
+                "estimated_completion_time": estimated_completion_time_minutes,
+                "reward": participant_payment_cents,
+                "device_compatibility": (
+                    ["desktop", "tablet", "mobile"]
+                    if device_compatibility is None
+                    else device_compatibility
+                ),
+                "peripheral_requirements": (
+                    [] if peripheral_requirements is None else peripheral_requirements
+                ),
+                "filters": [] if filters is None else filters,
+            },
+        )
+        self._resolve_server_response(response)
+        response_json = response.json()
+        return {
+            "study_id": response_json.get("study_id"),
+            "status": response_json.get("status"),
+            "admin_url": response_json.get("admin_url"),
+            "respondent_url": response_json.get("respondent_url"),
+            "name": response_json.get("name"),
+            "description": response_json.get("description"),
+            "num_participants": response_json.get("total_available_places"),
+            "estimated_completion_time_minutes": response_json.get(
+                "estimated_completion_time"
+            ),
+            "participant_payment_cents": response_json.get("reward"),
+            "total_cost_cents": response_json.get("total_cost"),
+            "device_compatibility": response_json.get("device_compatibility"),
+            "peripheral_requirements": response_json.get("peripheral_requirements"),
+            "filters": response_json.get("filters"),
+        }
+
+    def update_prolific_study(
+        self,
+        project_uuid: str,
+        study_id: str,
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+        num_participants: Optional[int] = None,
+        estimated_completion_time_minutes: Optional[int] = None,
+        participant_payment_cents: Optional[int] = None,
+        device_compatibility: Optional[
+            List[Literal["desktop", "tablet", "mobile"]]
+        ] = None,
+        peripheral_requirements: Optional[
+            List[Literal["audio", "camera", "download", "microphone"]]
+        ] = None,
+        filters: Optional[List[Dict]] = None,
+    ) -> dict:
+        """
+        Update a Prolific study. Returns a dict with the study details.
+        """
+        study = self.get_prolific_study(project_uuid, study_id)
+
+        current_completion_time = study.get("estimated_completion_time_minutes")
+        current_payment = study.get("participant_payment_cents")
+
+        updated_completion_time = (
+            estimated_completion_time_minutes or current_completion_time
+        )
+        updated_payment = participant_payment_cents or current_payment
+
+        is_underpayment, cost_usd_per_hour = self._validate_prolific_study_cost(
+            updated_completion_time, updated_payment
+        )
+        if is_underpayment:
+            raise CoopValueError(
+                f"This update would result in a participant payment of ${cost_usd_per_hour:.2f} USD per hour, which is below the minimum payment for using Prolific ($8.00 USD per hour)."
+            )
+
+        payload = {}
+        if name is not None:
+            payload["name"] = name
+        if description is not None:
+            payload["description"] = description
+        if num_participants is not None:
+            payload["total_available_places"] = num_participants
+        if estimated_completion_time_minutes is not None:
+            payload["estimated_completion_time"] = estimated_completion_time_minutes
+        if participant_payment_cents is not None:
+            payload["reward"] = participant_payment_cents
+        if device_compatibility is not None:
+            payload["device_compatibility"] = device_compatibility
+        if peripheral_requirements is not None:
+            payload["peripheral_requirements"] = peripheral_requirements
+        if filters is not None:
+            payload["filters"] = filters
+
+        response = self._send_server_request(
+            uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}",
+            method="PATCH",
+            payload=payload,
+        )
+        self._resolve_server_response(response)
+        response_json = response.json()
+        return {
+            "study_id": response_json.get("study_id"),
+            "status": response_json.get("status"),
+            "admin_url": response_json.get("admin_url"),
+            "respondent_url": response_json.get("respondent_url"),
+            "name": response_json.get("name"),
+            "description": response_json.get("description"),
+            "num_participants": response_json.get("total_available_places"),
+            "estimated_completion_time_minutes": response_json.get(
+                "estimated_completion_time"
+            ),
+            "participant_payment_cents": response_json.get("reward"),
+            "total_cost_cents": response_json.get("total_cost"),
+            "device_compatibility": response_json.get("device_compatibility"),
+            "peripheral_requirements": response_json.get("peripheral_requirements"),
+            "filters": response_json.get("filters"),
+        }
+
+    def publish_prolific_study(
+        self,
+        project_uuid: str,
+        study_id: str,
+    ) -> dict:
+        """
+        Publish a Prolific study.
+        """
+        response = self._send_server_request(
+            uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/publish",
+            method="POST",
+        )
+        self._resolve_server_response(response)
+        return response.json()
+
+    def get_prolific_study(self, project_uuid: str, study_id: str) -> dict:
+        """
+        Get a Prolific study. Returns a dict with the study details.
+        """
+        response = self._send_server_request(
+            uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}",
+            method="GET",
+        )
+        self._resolve_server_response(response)
+        response_json = response.json()
+        return {
+            "study_id": response_json.get("study_id"),
+            "status": response_json.get("status"),
+            "admin_url": response_json.get("admin_url"),
+            "respondent_url": response_json.get("respondent_url"),
+            "name": response_json.get("name"),
+            "description": response_json.get("description"),
+            "num_participants": response_json.get("total_available_places"),
+            "estimated_completion_time_minutes": response_json.get(
+                "estimated_completion_time"
+            ),
+            "participant_payment_cents": response_json.get("reward"),
+            "total_cost_cents": response_json.get("total_cost"),
+            "device_compatibility": response_json.get("device_compatibility"),
+            "peripheral_requirements": response_json.get("peripheral_requirements"),
+            "filters": response_json.get("filters"),
+        }
+
+    def get_prolific_study_responses(
+        self,
+        project_uuid: str,
+        study_id: str,
+    ) -> Union["Results", "ScenarioList"]:
+        """
+        Return a Results object with the human responses for a project.
+
+        If generating the Results object fails, a ScenarioList will be returned instead.
+        """
+        response = self._send_server_request(
+            uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/responses",
+            method="GET",
+        )
+        self._resolve_server_response(response)
+        response_json = response.json()
+        human_responses = response_json.get("human_responses", [])
+        survey_json_string = response_json.get("survey_json_string")
+
+        return self._turn_human_responses_into_results(
+            human_responses, survey_json_string
+        )
+
+    def delete_prolific_study(
+        self,
+        project_uuid: str,
+        study_id: str,
+    ) -> dict:
+        """
+        Deletes a Prolific study.
+
+        Note: Only draft studies can be deleted. Once you publish a study, it cannot be deleted.
+        """
+        response = self._send_server_request(
+            uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}",
+            method="DELETE",
+        )
+        self._resolve_server_response(response)
+        return response.json()
+
+    def approve_prolific_study_submission(
+        self,
+        project_uuid: str,
+        study_id: str,
+        submission_id: str,
+    ) -> dict:
+        """
+        Approve a Prolific study submission.
+        """
+        response = self._send_server_request(
+            uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/submissions/{submission_id}/approve",
+            method="POST",
+        )
+        self._resolve_server_response(response)
+        return response.json()
+
+    def reject_prolific_study_submission(
+        self,
+        project_uuid: str,
+        study_id: str,
+        submission_id: str,
+        reason: Literal[
+            "TOO_QUICKLY",
+            "TOO_SLOWLY",
+            "FAILED_INSTRUCTIONS",
+            "INCOMP_LONGITUDINAL",
+            "FAILED_CHECK",
+            "LOW_EFFORT",
+            "MALINGERING",
+            "NO_CODE",
+            "BAD_CODE",
+            "NO_DATA",
+            "UNSUPP_DEVICE",
+            "OTHER",
+        ],
+        explanation: str,
+    ) -> dict:
+        """
+        Reject a Prolific study submission.
+        """
+        valid_rejection_reasons = [
+            "TOO_QUICKLY",
+            "TOO_SLOWLY",
+            "FAILED_INSTRUCTIONS",
+            "INCOMP_LONGITUDINAL",
+            "FAILED_CHECK",
+            "LOW_EFFORT",
+            "MALINGERING",
+            "NO_CODE",
+            "BAD_CODE",
+            "NO_DATA",
+            "UNSUPP_DEVICE",
+            "OTHER",
+        ]
+        if reason not in valid_rejection_reasons:
+            raise CoopValueError(
+                f"Invalid rejection reason. Please use one of the following: {valid_rejection_reasons}."
+            )
+        if len(explanation) < 100:
+            raise CoopValueError(
+                "Rejection explanation must be at least 100 characters."
+            )
+        response = self._send_server_request(
+            uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/submissions/{submission_id}/reject",
+            method="POST",
+            payload={
+                "reason": reason,
+                "explanation": explanation,
+            },
+        )
+        self._resolve_server_response(response)
+        return response.json()
+
+    def __repr__(self):
+        """Return a string representation of the client."""
+        return f"Client(api_key='{self.api_key}', url='{self.url}')"
+
+    async def remote_async_execute_model_call(
+        self, model_dict: dict, user_prompt: str, system_prompt: str
+    ) -> dict:
+        url = self.api_url + "/inference/"
+        # print("Now using url: ", url)
+        data = {
+            "model_dict": model_dict,
+            "user_prompt": user_prompt,
+            "system_prompt": system_prompt,
+        }
+        # Use aiohttp to send a POST request asynchronously
+        async with aiohttp.ClientSession() as session:
+            async with session.post(url, json=data) as response:
+                response_data = await response.json()
+                return response_data
+
+    def web(
+        self,
+        survey: dict,
+        platform: Literal[
+            "google_forms", "lime_survey", "survey_monkey"
+        ] = "lime_survey",
+        email=None,
+    ):
+        url = f"{self.api_url}/api/v0/export_to_{platform}"
+        if email:
+            data = {"json_string": json.dumps({"survey": survey, "email": email})}
+        else:
+            data = {"json_string": json.dumps({"survey": survey, "email": ""})}
+
+        response_json = requests.post(url, headers=self.headers, data=json.dumps(data))
+
+        return response_json
+
+    def fetch_prices(self) -> dict:
+        """
+        Fetch the current pricing information for language models.
+
+        This method retrieves the latest pricing information for all supported language models
+        from the Expected Parrot API. The pricing data is used to estimate costs for jobs
+        and to optimize model selection based on budget constraints.
+
+        Returns:
+            dict: A dictionary mapping (service, model) tuples to pricing information.
+                Each entry contains token pricing for input and output tokens.
+                Example structure:
+                {
+                    ('openai', 'gpt-4'): {
+                        'input': {'usd_per_1M_tokens': 30.0, ...},
+                        'output': {'usd_per_1M_tokens': 60.0, ...}
+                    }
+                }
+
+        Raises:
+            ValueError: If the EDSL_FETCH_TOKEN_PRICES configuration setting is invalid
+
+        Notes:
+            - Returns an empty dict if EDSL_FETCH_TOKEN_PRICES is set to "False"
+            - The pricing data is cached to minimize API calls
+            - Pricing may vary based on the model, provider, and token type (input/output)
+            - All prices are in USD per million tokens

        Example:
            >>> prices = coop.fetch_prices()
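A hedged sketch of the Prolific workflow introduced above: list the supported filters, build a filter dict, create a study, and publish it. The project UUID, payment, and timing values are illustrative; note the payment must work out to at least $8.00 USD/hour or create_prolific_study raises CoopValueError:

```python
from edsl.coop import Coop

coop = Coop()

# Filters come back as a CoopProlificFilters object with helper methods.
filters = coop.list_prolific_filters()
age_filter = filters.create_study_filter("age", min=30, max=40)

study = coop.create_prolific_study(
    project_uuid="your-project-uuid",      # placeholder
    name="Opinion survey",
    description="A short opinion survey",
    num_participants=50,
    estimated_completion_time_minutes=10,
    participant_payment_cents=150,         # $9.00/hour at 10 minutes
    filters=[age_filter],
)
coop.publish_prolific_study("your-project-uuid", study["study_id"])
```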
@@ -1686,6 +2394,235 @@ class Coop(CoopFunctionsMixin):
        self._resolve_server_response(response)
        return response.json().get("uuid")

+    def pull(
+        self,
+        url_or_uuid: Optional[Union[str, UUID]] = None,
+        expected_object_type: Optional[ObjectType] = None,
+    ) -> dict:
+        """
+        Generate a signed URL for pulling an object directly from Google Cloud Storage.
+
+        This method gets a signed URL that allows direct download access to the object from
+        Google Cloud Storage, which is more efficient for large files.
+
+        Parameters:
+            url_or_uuid (Union[str, UUID], optional): Identifier for the object to retrieve.
+                Can be one of:
+                - UUID string (e.g., "123e4567-e89b-12d3-a456-426614174000")
+                - Full URL (e.g., "https://expectedparrot.com/content/123e4567...")
+                - Alias URL (e.g., "https://expectedparrot.com/content/username/my-survey")
+            expected_object_type (ObjectType, optional): If provided, validates that the
+                retrieved object is of the expected type (e.g., "survey", "agent")
+
+        Returns:
+            dict: A response containing the signed_url for direct download
+
+        Raises:
+            CoopNoUUIDError: If no UUID or URL is provided
+            CoopInvalidURLError: If the URL format is invalid
+            CoopServerResponseError: If there's an error communicating with the server
+            HTTPException: If the object or object files are not found
+
+        Example:
+            >>> response = coop.pull("123e4567-e89b-12d3-a456-426614174000")
+            >>> response = coop.pull("https://expectedparrot.com/content/username/my-survey")
+            >>> print(f"Download URL: {response['signed_url']}")
+            >>> # Use the signed_url to download the object directly
+        """
+        obj_uuid, owner_username, alias = self._resolve_uuid_or_alias(url_or_uuid)
+
+        # Handle alias-based retrieval with new/old format detection
+        if not obj_uuid and owner_username and alias:
+            # First, get object info to determine format and UUID
+            info_response = self._send_server_request(
+                uri="api/v0/object/alias/info",
+                method="GET",
+                params={"owner_username": owner_username, "alias": alias},
+            )
+            self._resolve_server_response(info_response)
+            info_data = info_response.json()
+
+            obj_uuid = info_data.get("uuid")
+            is_new_format = info_data.get("is_new_format", False)
+
+            # Validate object type if expected
+            if expected_object_type:
+                actual_object_type = info_data.get("object_type")
+                if actual_object_type != expected_object_type:
+                    from .exceptions import CoopObjectTypeError
+
+                    raise CoopObjectTypeError(
+                        f"Expected {expected_object_type=} but got {actual_object_type=}"
+                    )
+
+            # Use get method for old format objects
+            if not is_new_format:
+                return self.get(url_or_uuid, expected_object_type)
+
+        # Send the request to the API endpoint with the resolved UUID
+        response = self._send_server_request(
+            uri="api/v0/object/pull",
+            method="POST",
+            payload={"object_uuid": obj_uuid},
+        )
+        # Handle any errors in the response
+        self._resolve_server_response(response)
+        if "signed_url" not in response.json():
+            from .exceptions import CoopResponseError
+
+            raise CoopResponseError("No signed url was provided received")
+        signed_url = response.json().get("signed_url")
+
+        if signed_url == "":  # it is in old format
+            return self.get(url_or_uuid, expected_object_type)
+
+        try:
+            response = requests.get(signed_url)
+
+            self._resolve_gcs_response(response)
+
+        except Exception:
+            return self.get(url_or_uuid, expected_object_type)
+        object_dict = response.json()
+        if expected_object_type is not None:
+            edsl_class = ObjectRegistry.get_edsl_class_by_object_type(
+                expected_object_type
+            )
+        edsl_object = edsl_class.from_dict(object_dict)
+        # Return the response containing the signed URL
+        return edsl_object
+
+    def get_upload_url(self, object_uuid: str) -> dict:
+        """
+        Get a signed upload URL for updating the content of an existing object.
+
+        This method gets a signed URL that allows direct upload to Google Cloud Storage
+        for objects stored in the new format, while preserving the existing UUID.
+
+        Parameters:
+            object_uuid (str): The UUID of the object to get an upload URL for
+
+        Returns:
+            dict: A response containing:
+                - signed_url: The signed URL for uploading new content
+                - object_uuid: The UUID of the object
+                - message: Success message
+
+        Raises:
+            CoopServerResponseError: If there's an error communicating with the server
+            HTTPException: If the object is not found, not owned by user, or not in new format
+
+        Notes:
+            - Only works with objects stored in the new format (transition table)
+            - User must be the owner of the object
+            - The signed URL expires after 60 minutes
+
+        Example:
+            >>> response = coop.get_upload_url("123e4567-e89b-12d3-a456-426614174000")
+            >>> upload_url = response['signed_url']
+            >>> # Use the upload_url to PUT new content directly to GCS
+        """
+        response = self._send_server_request(
+            uri="api/v0/object/upload-url",
+            method="POST",
+            payload={"object_uuid": object_uuid},
+        )
+        self._resolve_server_response(response)
+        return response.json()
+
+    def push(
+        self,
+        object: EDSLObject,
+        description: Optional[str] = None,
+        alias: Optional[str] = None,
+        visibility: Optional[VisibilityType] = "unlisted",
+    ) -> dict:
+        """
+        Generate a signed URL for pushing an object directly to Google Cloud Storage.
+
+        This method gets a signed URL that allows direct upload access to Google Cloud Storage,
+        which is more efficient for large files.
+
+        Parameters:
+            object_type (ObjectType): The type of object to be uploaded
+
+        Returns:
+            dict: A response containing the signed_url for direct upload and optionally a job_id
+
+        Raises:
+            CoopServerResponseError: If there's an error communicating with the server
+
+        Example:
+            >>> response = coop.push("scenario")
+            >>> print(f"Upload URL: {response['signed_url']}")
+            >>> # Use the signed_url to upload the object directly
+        """
+
+        object_type = ObjectRegistry.get_object_type_by_edsl_class(object)
+        object_dict = object.to_dict()
+        object_hash = object.get_hash() if hasattr(object, "get_hash") else None
+
+        # Send the request to the API endpoint
+        response = self._send_server_request(
+            uri="api/v0/object/push",
+            method="POST",
+            payload={
+                "object_type": object_type,
+                "description": description,
+                "alias": alias,
+                "visibility": visibility,
+                "object_hash": object_hash,
+                "version": self._edsl_version,
+            },
+        )
+        response_json = response.json()
+        if response_json.get("signed_url") is not None:
+            signed_url = response_json.get("signed_url")
+        else:
+            from .exceptions import CoopResponseError
+
+            raise CoopResponseError(response.text)
+
+        json_data = json.dumps(
+            object_dict,
+            default=self._json_handle_none,
+            allow_nan=False,
+        )
+        response = requests.put(
+            signed_url,
+            data=json_data.encode(),
+            headers={"Content-Type": "application/json"},
+        )
+        self._resolve_gcs_response(response)
+
+        # Send confirmation that upload was completed
+        object_uuid = response_json.get("object_uuid", None)
+        owner_username = response_json.get("owner_username", None)
+        object_alias = response_json.get("alias", None)
+
+        if object_uuid is None:
+            from .exceptions import CoopResponseError
+
+            raise CoopResponseError("No object uuid was provided received")
+
+        # Confirm the upload completion
+        confirm_response = self._send_server_request(
+            uri="api/v0/object/confirm-upload",
+            method="POST",
+            payload={"object_uuid": object_uuid},
+        )
+        self._resolve_server_response(confirm_response)
+
+        return {
+            "description": response_json.get("description"),
+            "object_type": object_type,
+            "url": f"{self.url}/content/{object_uuid}",
+            "alias_url": self._get_alias_url(owner_username, object_alias),
+            "uuid": object_uuid,
+            "version": self._edsl_version,
+            "visibility": response_json.get("visibility"),
+        }
+
    def _display_login_url(
        self, edsl_auth_token: str, link_description: Optional[str] = None
    ):
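A minimal push/pull round trip using the new GCS-backed methods above; the example survey and alias are illustrative placeholders:

```python
from edsl.coop import Coop
from edsl.surveys import Survey

coop = Coop()

# push() uploads the serialized object to a signed GCS URL, then confirms the upload.
details = coop.push(Survey.example(), description="Demo survey", alias="demo-survey")

# pull() resolves the UUID or alias, fetches a signed download URL, and
# falls back to get() for objects still stored in the old format.
survey = coop.pull(details["uuid"], expected_object_type="survey")
```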
@@ -1769,6 +2706,125 @@ class Coop(CoopFunctionsMixin):
|
|
1769
2706
|
# Add API key to environment
|
1770
2707
|
load_dotenv()
|
1771
2708
|
|
2709
|
+
def login_streamlit(self, timeout: int = 120):
|
2710
|
+
"""
|
2711
|
+
Start the EDSL auth token login flow inside a Streamlit application.
|
2712
|
+
|
2713
|
+
This helper is functionally equivalent to ``Coop.login`` but renders the
|
2714
|
+
login link and status updates directly in the Streamlit UI. The method
|
2715
|
+
will automatically poll the Expected Parrot server for the API-key
|
2716
|
+
associated with the generated auth-token and, once received, store it
|
2717
|
+
via ``ExpectedParrotKeyHandler`` and write it to the local ``.env``
|
2718
|
+
file so subsequent sessions pick it up automatically.
|
2719
|
+
|
2720
|
+
Parameters
|
2721
|
+
----------
|
2722
|
+
timeout : int, default 120
|
2723
|
+
How many seconds to wait for the user to complete the login before
|
2724
|
+
giving up and showing an error in the Streamlit app.
|
2725
|
+
|
2726
|
+
Returns
|
2727
|
+
-------
|
2728
|
+
str | None
|
2729
|
+
The API-key if the user logged-in successfully, otherwise ``None``.
|
+        """
+        try:
+            import streamlit as st
+            from streamlit.runtime.scriptrunner import get_script_run_ctx
+        except ModuleNotFoundError as exc:
+            raise ImportError(
+                "Streamlit is required for `login_streamlit`. Install it with `pip install streamlit`."
+            ) from exc
+
+        # Ensure we are actually running inside a Streamlit script. If not, give a
+        # clear error message instead of crashing when `st.experimental_rerun` is
+        # invoked outside the Streamlit runtime.
+        if get_script_run_ctx() is None:
+            raise RuntimeError(
+                "`login_streamlit` must be invoked from within a running Streamlit "
+                "app (use `streamlit run your_script.py`). If you need to obtain an "
+                "API-key in a regular Python script or notebook, use `Coop.login()` "
+                "instead."
+            )
+
+        import secrets
+        import time
+        import os
+        from dotenv import load_dotenv
+        from .ep_key_handling import ExpectedParrotKeyHandler
+        from ..utilities.utilities import write_api_key_to_env
+
+        # ------------------------------------------------------------------
+        # 1. Prepare auth-token and store state across reruns
+        # ------------------------------------------------------------------
+        if "edsl_auth_token" not in st.session_state:
+            st.session_state.edsl_auth_token = secrets.token_urlsafe(16)
+            st.session_state.login_start_time = time.time()
+
+        edsl_auth_token: str = st.session_state.edsl_auth_token
+        login_url = (
+            f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
+        )
+
+        # ------------------------------------------------------------------
+        # 2. Render clickable login link
+        # ------------------------------------------------------------------
+        st.markdown(
+            f"🔗 **Log in to Expected Parrot** → [click here]({login_url})",
+            unsafe_allow_html=True,
+        )
+
+        # ------------------------------------------------------------------
+        # 3. Poll server for API-key (runs once per Streamlit execution)
+        # ------------------------------------------------------------------
+        api_key = self._get_api_key(edsl_auth_token)
+        if api_key is None:
+            elapsed = time.time() - st.session_state.login_start_time
+            if elapsed > timeout:
+                st.error(
+                    "Timed-out waiting for login. Please rerun the app to try again."
+                )
+                return None
+
+            remaining = int(timeout - elapsed)
+            st.info(f"Waiting for login… ({remaining}s left)")
+            # Trigger a rerun after a short delay to continue polling
+            time.sleep(1)
+
+            # Attempt a rerun in a version-agnostic way. Different Streamlit
+            # releases expose the helper under different names.
+            def _safe_rerun():
+                if hasattr(st, "experimental_rerun"):
+                    st.experimental_rerun()
+                elif hasattr(st, "rerun"):
+                    st.rerun()  # introduced in newer versions
+                else:
+                    # Fallback – advise the user to update Streamlit for automatic polling.
+                    st.warning(
+                        "Please refresh the page to continue the login flow. "
+                        "(Consider upgrading Streamlit to enable automatic refresh.)"
+                    )
+
+            try:
+                _safe_rerun()
+            except Exception:
+                # The Streamlit runtime intercepts the rerun exception; any other
+                # unexpected errors are ignored to avoid crashing the app.
+                pass
+
+        # ------------------------------------------------------------------
+        # 4. Key received – persist it and notify user
+        # ------------------------------------------------------------------
+        ExpectedParrotKeyHandler().store_ep_api_key(api_key)
+        os.environ["EXPECTED_PARROT_API_KEY"] = api_key
+        path_to_env = write_api_key_to_env(api_key)
+        load_dotenv()
+
+        st.success("API-key retrieved and stored. You are now logged-in! 🎉")
+        st.caption(f"Key saved to `{path_to_env}`.")
+
+        return api_key
+
     def transfer_credits(
         self,
         credits_transferred: int,
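The `login_streamlit` helper added above is designed to be called from inside a script launched with `streamlit run`: it generates an auth token, renders a login link, polls the server for the associated API key, and persists the key once it arrives. A minimal usage sketch (illustrative only, not part of the package diff; the file name, page title, and messages are invented):

# login_demo.py (illustrative sketch, not part of edsl); run with: streamlit run login_demo.py
import streamlit as st

from edsl.coop import Coop

st.title("Connect to Expected Parrot")

coop = Coop()                      # no API key is required at this point
api_key = coop.login_streamlit()   # renders the login link, polls for the key, stores it

if api_key:
    # The key is now in os.environ["EXPECTED_PARROT_API_KEY"] and written to .env,
    # so remote-inference calls later in the app authenticate automatically.
    st.success("Logged in; remote features are ready to use.")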
@@ -1835,10 +2891,155 @@ class Coop(CoopFunctionsMixin):
         >>> balance = coop.get_balance()
         >>> print(f"You have {balance['credits']} credits available.")
         """
-        response = self._send_server_request(
+        response = self._send_server_request(
+            uri="api/v0/users/get-balance", method="GET"
+        )
         self._resolve_server_response(response)
         return response.json()
 
+    def login_gradio(self, timeout: int = 120, launch: bool = True, **launch_kwargs):
+        """
+        Start the EDSL auth token login flow inside a **Gradio** application.
+
+        This helper mirrors the behaviour of :py:meth:`Coop.login_streamlit` but
+        renders the login link and status updates inside a Gradio UI. It will
+        poll the Expected Parrot server for the API-key associated with a newly
+        generated auth-token and, once received, store it via
+        :pyclass:`~edsl.coop.ep_key_handling.ExpectedParrotKeyHandler` as well as
+        in the local ``.env`` file so subsequent sessions pick it up
+        automatically.
+
+        Parameters
+        ----------
+        timeout : int, default 120
+            How many seconds to wait for the user to complete the login before
+            giving up.
+        launch : bool, default True
+            If ``True`` the Gradio app is immediately launched with
+            ``demo.launch(**launch_kwargs)``. Set this to ``False`` if you want
+            to embed the returned :class:`gradio.Blocks` object into an existing
+            Gradio interface.
+        **launch_kwargs
+            Additional keyword-arguments forwarded to ``gr.Blocks.launch`` when
+            *launch* is ``True``.
+
+        Returns
+        -------
+        str | gradio.Blocks | None
+            • If the API-key is retrieved within *timeout* seconds while the
+              function is executing (e.g. when *launch* is ``False`` and the
+              caller integrates the Blocks into another app) the key is
+              returned.
+            • If *launch* is ``True`` the method returns ``None`` after the
+              Gradio app has been launched.
+            • If *launch* is ``False`` the constructed ``gr.Blocks`` is
+              returned so the caller can compose it further.
+        """
+        try:
+            import gradio as gr
+        except ModuleNotFoundError as exc:
+            raise ImportError(
+                "Gradio is required for `login_gradio`. Install it with `pip install gradio`."
+            ) from exc
+
+        import secrets
+        import time
+        import os
+        from dotenv import load_dotenv
+        from .ep_key_handling import ExpectedParrotKeyHandler
+        from ..utilities.utilities import write_api_key_to_env
+
+        # ------------------------------------------------------------------
+        # 1. Prepare auth-token
+        # ------------------------------------------------------------------
+        edsl_auth_token = secrets.token_urlsafe(16)
+        login_url = (
+            f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
+        )
+        start_time = time.time()
+
+        # ------------------------------------------------------------------
+        # 2. Build Gradio interface
+        # ------------------------------------------------------------------
+        with gr.Blocks() as demo:
+            gr.HTML(
+                f'🔗 <b>Log in to Expected Parrot</b> → <a href="{login_url}" target="_blank">click here</a>'
+            )
+            status_md = gr.Markdown("Waiting for login…")
+            refresh_btn = gr.Button(
+                "I've logged in – click to continue", elem_id="refresh-btn"
+            )
+            key_state = gr.State(value=None)
+
+            # --------------------------------------------------------------
+            # Polling callback
+            # --------------------------------------------------------------
+            def _refresh(current_key):  # noqa: D401, pylint: disable=unused-argument
+                """Poll server for API-key and update UI accordingly."""
+
+                # Fallback helper to generate a `update` object for the refresh button
+                def _button_update(**kwargs):
+                    try:
+                        return gr.Button.update(**kwargs)
+                    except AttributeError:
+                        return gr.update(**kwargs)
+
+                api_key = self._get_api_key(edsl_auth_token)
+                # Fall back to env var in case the key was obtained earlier in this session
+                if not api_key:
+                    api_key = os.environ.get("EXPECTED_PARROT_API_KEY")
+                elapsed = time.time() - start_time
+                remaining = max(0, int(timeout - elapsed))
+
+                if api_key:
+                    # Persist and expose the key
+                    ExpectedParrotKeyHandler().store_ep_api_key(api_key)
+                    os.environ["EXPECTED_PARROT_API_KEY"] = api_key
+                    path_to_env = write_api_key_to_env(api_key)
+                    load_dotenv()
+                    success_msg = (
+                        "API-key retrieved and stored 🎉\n\n"
+                        f"Key saved to `{path_to_env}`."
+                    )
+                    return (
+                        success_msg,
+                        _button_update(interactive=False, visible=False),
+                        api_key,
+                    )
+
+                if elapsed > timeout:
+                    err_msg = (
+                        "Timed-out waiting for login. Please refresh the page "
+                        "or restart the app to try again."
+                    )
+                    return (
+                        err_msg,
+                        _button_update(),
+                        None,
+                    )
+
+                info_msg = f"Waiting for login… ({remaining}s left)"
+                return (
+                    info_msg,
+                    _button_update(),
+                    None,
+                )
+
+            # Initial status check when the interface loads
+            demo.load(
+                fn=_refresh,
+                inputs=key_state,
+                outputs=[status_md, refresh_btn, key_state],
+            )
+
+        # ------------------------------------------------------------------
+        # 3. Launch or return interface
+        # ------------------------------------------------------------------
+        if launch:
+            demo.launch(**launch_kwargs)
+            return None
+        return demo
+
3043
|
|
1843
3044
|
def main():
|
1844
3045
|
"""
|
@@ -1975,3 +3176,12 @@ def main():
     job_coop_object = coop.remote_inference_create(job)
     job_coop_results = coop.remote_inference_get(job_coop_object.get("uuid"))
     coop.get(job_coop_results.get("results_uuid"))
+
+    import streamlit as st
+    from edsl.coop import Coop
+
+    coop = Coop()  # no API-key required yet
+    api_key = coop.login_streamlit()  # renders link + handles polling & storage
+
+    if api_key:
+        st.success("Ready to use EDSL with remote features!")