edsl 0.1.59__py3-none-any.whl → 0.1.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edsl/coop/coop.py CHANGED
@@ -3,8 +3,9 @@ import base64
3
3
  import json
4
4
  import requests
5
5
  import time
6
+ import os
6
7
 
7
- from typing import Any, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
8
+ from typing import Any, Dict, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
8
9
  from uuid import UUID
9
10
 
10
11
  from .. import __version__
@@ -35,6 +36,7 @@ from .utils import (
35
36
  from .coop_functions import CoopFunctionsMixin
36
37
  from .coop_regular_objects import CoopRegularObjects
37
38
  from .coop_jobs_objects import CoopJobsObjects
39
+ from .coop_prolific_filters import CoopProlificFilters
38
40
  from .ep_key_handling import ExpectedParrotKeyHandler
39
41
 
40
42
  from ..inference_services.data_structures import ServiceToModelsMapping
@@ -66,7 +68,6 @@ class JobRunInterviewDetails(TypedDict):
66
68
 
67
69
 
68
70
  class LatestJobRunDetails(TypedDict):
69
-
70
71
  # For running, completed, and partially failed jobs
71
72
  interview_details: Optional[JobRunInterviewDetails] = None
72
73
 
@@ -296,9 +297,9 @@ class Coop(CoopFunctionsMixin):
296
297
  message = str(response.json().get("detail"))
297
298
  except json.JSONDecodeError:
298
299
  raise CoopServerResponseError(
299
- f"Server returned status code {response.status_code}."
300
- "JSON response could not be decoded.",
301
- "The server response was: " + response.text,
300
+ f"Server returned status code {response.status_code}. "
301
+ f"JSON response could not be decoded. "
302
+ f"The server response was: {response.text}"
302
303
  )
303
304
  # print(response.text)
304
305
  if "The API key you provided is invalid" in message and check_api_key:
@@ -694,6 +695,35 @@ class Coop(CoopFunctionsMixin):
694
695
  """
695
696
  obj_uuid, owner_username, alias = self._resolve_uuid_or_alias(url_or_uuid)
696
697
 
698
+ # Handle alias-based retrieval with new/old format detection
699
+ if not obj_uuid and owner_username and alias:
700
+ # First, get object info to determine format and UUID
701
+ info_response = self._send_server_request(
702
+ uri="api/v0/object/alias/info",
703
+ method="GET",
704
+ params={"owner_username": owner_username, "alias": alias},
705
+ )
706
+ self._resolve_server_response(info_response)
707
+ info_data = info_response.json()
708
+
709
+ obj_uuid = info_data.get("uuid")
710
+ is_new_format = info_data.get("is_new_format", False)
711
+
712
+ # Validate object type if expected
713
+ if expected_object_type:
714
+ actual_object_type = info_data.get("object_type")
715
+ if actual_object_type != expected_object_type:
716
+ from .exceptions import CoopObjectTypeError
717
+
718
+ raise CoopObjectTypeError(
719
+ f"Expected {expected_object_type=} but got {actual_object_type=}"
720
+ )
721
+
722
+ # Use pull method for new format objects
723
+ if is_new_format:
724
+ return self.pull(obj_uuid, expected_object_type)
725
+
726
+ # Handle UUID-based retrieval or legacy alias objects
697
727
  if obj_uuid:
698
728
  response = self._send_server_request(
699
729
  uri="api/v0/object",
@@ -915,6 +945,26 @@ class Coop(CoopFunctionsMixin):
915
945
 
916
946
  obj_uuid, owner_username, obj_alias = self._resolve_uuid_or_alias(url_or_uuid)
917
947
 
948
+ # If we have a UUID and are updating the value, check the storage format first
949
+ if obj_uuid and value:
950
+ # Check if object is in new format (GCS)
951
+ format_check_response = self._send_server_request(
952
+ uri="api/v0/object/check-format",
953
+ method="POST",
954
+ payload={"object_uuid": str(obj_uuid)},
955
+ )
956
+ self._resolve_server_response(format_check_response)
957
+ format_data = format_check_response.json()
958
+
959
+ is_new_format = format_data.get("is_new_format", False)
960
+
961
+ if is_new_format:
962
+ # Handle new format objects: update metadata first, then upload content
963
+ return self._patch_new_format_object(
964
+ obj_uuid, description, alias, value, visibility
965
+ )
966
+
967
+ # Handle traditional format objects or metadata-only updates
918
968
  if obj_uuid:
919
969
  uri = "api/v0/object"
920
970
  params = {"uuid": obj_uuid}
@@ -944,6 +994,70 @@ class Coop(CoopFunctionsMixin):
944
994
  self._resolve_server_response(response)
945
995
  return response.json()
946
996
 
997
def _patch_new_format_object(
    self,
    obj_uuid: UUID,
    description: Optional[str],
    alias: Optional[str],
    value: EDSLObject,
    visibility: Optional[VisibilityType],
) -> dict:
    """
    Patch an object whose content is stored in the new (GCS-backed) format.

    If any of ``description``, ``alias`` or ``visibility`` is given, the
    metadata is first PATCHed to ``api/v0/object`` with ``json_string`` set to
    None. The serialized ``value`` is then uploaded with a PUT to a signed URL
    obtained from ``api/v0/object/upload-url``.

    Returns:
        dict: ``status``, ``message`` and ``object_uuid`` summarising the update.

    Raises:
        CoopServerResponseError: if no signed URL is returned, or if the PUT to
            the signed URL does not answer with HTTP 200.
    """
    has_metadata = any(
        field is not None for field in (description, alias, visibility)
    )
    if has_metadata:
        metadata_response = self._send_server_request(
            uri="api/v0/object",
            method="PATCH",
            params={"uuid": obj_uuid},
            payload={
                "description": description,
                "alias": alias,
                # The content itself travels through the signed URL below.
                "json_string": None,
                "visibility": visibility,
            },
        )
        self._resolve_server_response(metadata_response)

    url_response = self._send_server_request(
        uri="api/v0/object/upload-url",
        method="POST",
        payload={"object_uuid": str(obj_uuid)},
    )
    self._resolve_server_response(url_response)
    signed_url = url_response.json().get("signed_url")
    if not signed_url:
        raise CoopServerResponseError("Failed to get signed upload URL")

    serialized = json.dumps(
        value.to_dict(),
        default=self._json_handle_none,
        allow_nan=False,
    )
    upload_response = requests.put(
        signed_url,
        data=serialized,
        headers={"Content-Type": "application/json"},
    )
    if upload_response.status_code != 200:
        raise CoopServerResponseError(
            f"Failed to upload object to GCS: {upload_response.status_code}"
        )

    return {
        "status": "success",
        "message": "Object updated successfully (new format - uploaded to GCS)",
        "object_uuid": str(obj_uuid),
    }
1060
+
947
1061
  ################
948
1062
  # Remote Cache
949
1063
  ################
@@ -1025,6 +1139,115 @@ class Coop(CoopFunctionsMixin):
1025
1139
  is handled by Expected Parrot's infrastructure, and you can check the status
1026
1140
  and retrieve results later.
1027
1141
 
1142
+ Parameters:
1143
+ job (Jobs): The EDSL job to run in the cloud
1144
+ description (str, optional): A human-readable description of the job
1145
+ status (RemoteJobStatus): Initial status, should be "queued" for normal use
1146
+ Possible values: "queued", "running", "completed", "failed"
1147
+ visibility (VisibilityType): Access level for the job information. One of:
1148
+ - "private": Only accessible by the owner
1149
+ - "public": Accessible by anyone
1150
+ - "unlisted": Accessible with the link, but not listed publicly
1151
+ initial_results_visibility (VisibilityType): Access level for the job results
1152
+ iterations (int): Number of times to run each interview (default: 1)
1153
+ fresh (bool): If True, ignore existing cache entries and generate new results
1154
+
1155
+ Returns:
1156
+ RemoteInferenceCreationInfo: Information about the created job including:
1157
+ - uuid: The unique identifier for the job
1158
+ - description: The job description
1159
+ - status: Current status of the job
1160
+ - iterations: Number of iterations for each interview
1161
+ - visibility: Access level for the job
1162
+ - version: EDSL version used to create the job
1163
+
1164
+ Raises:
1165
+ CoopServerResponseError: If there's an error communicating with the server
1166
+
1167
+ Notes:
1168
+ - Remote jobs run asynchronously and may take time to complete
1169
+ - Use remote_inference_get() with the returned UUID to check status
1170
+ - Credits are consumed based on the complexity of the job
1171
+
1172
+ Example:
1173
+ >>> from edsl.jobs import Jobs
1174
+ >>> job = Jobs.example()
1175
+ >>> job_info = coop.remote_inference_create(job=job, description="My job")
1176
+ >>> print(f"Job created with UUID: {job_info['uuid']}")
1177
+ """
1178
+ response = self._send_server_request(
1179
+ uri="api/v0/new-remote-inference",
1180
+ method="POST",
1181
+ payload={
1182
+ "json_string": "offloaded",
1183
+ "description": description,
1184
+ "status": status,
1185
+ "iterations": iterations,
1186
+ "visibility": visibility,
1187
+ "version": self._edsl_version,
1188
+ "initial_results_visibility": initial_results_visibility,
1189
+ "fresh": fresh,
1190
+ },
1191
+ )
1192
+ self._resolve_server_response(response)
1193
+ response_json = response.json()
1194
+ upload_signed_url = response_json.get("upload_signed_url")
1195
+ if not upload_signed_url:
1196
+ from .exceptions import CoopResponseError
1197
+
1198
+ raise CoopResponseError("No signed url was provided received")
1199
+
1200
+ response = requests.put(
1201
+ upload_signed_url,
1202
+ data=json.dumps(
1203
+ job.to_dict(),
1204
+ default=self._json_handle_none,
1205
+ ).encode(),
1206
+ headers={"Content-Type": "application/json"},
1207
+ )
1208
+ self._resolve_gcs_response(response)
1209
+
1210
+ job_uuid = response_json.get("job_uuid")
1211
+
1212
+ response = self._send_server_request(
1213
+ uri="api/v0/new-remote-inference/uploaded",
1214
+ method="POST",
1215
+ payload={
1216
+ "job_uuid": job_uuid,
1217
+ "message": "Job uploaded successfully",
1218
+ },
1219
+ )
1220
+ response_json = response.json()
1221
+
1222
+ return RemoteInferenceCreationInfo(
1223
+ **{
1224
+ "uuid": response_json.get("job_uuid"),
1225
+ "description": response_json.get("description", ""),
1226
+ "status": response_json.get("status"),
1227
+ "iterations": response_json.get("iterations", ""),
1228
+ "visibility": response_json.get("visibility", ""),
1229
+ "version": self._edsl_version,
1230
+ }
1231
+ )
1232
+
1233
+ def old_remote_inference_create(
1234
+ self,
1235
+ job: "Jobs",
1236
+ description: Optional[str] = None,
1237
+ status: RemoteJobStatus = "queued",
1238
+ visibility: Optional[VisibilityType] = "unlisted",
1239
+ initial_results_visibility: Optional[VisibilityType] = "unlisted",
1240
+ iterations: Optional[int] = 1,
1241
+ fresh: Optional[bool] = False,
1242
+ ) -> RemoteInferenceCreationInfo:
1243
+ """
1244
+ Create a remote inference job for execution in the Expected Parrot cloud.
1245
+
1246
+ This method sends a job to be executed in the cloud, which can be more efficient
1247
+ for large jobs or when you want to run jobs in the background. The job execution
1248
+ is handled by Expected Parrot's infrastructure, and you can check the status
1249
+ and retrieve results later.
1250
+
1028
1251
  Parameters:
1029
1252
  job (Jobs): The EDSL job to run in the cloud
1030
1253
  description (str, optional): A human-readable description of the job
@@ -1368,25 +1591,57 @@ class Coop(CoopFunctionsMixin):
1368
1591
  def create_project(
1369
1592
  self,
1370
1593
  survey: "Survey",
1594
+ scenario_list: Optional["ScenarioList"] = None,
1595
+ scenario_list_method: Optional[
1596
+ Literal["randomize", "loop", "single_scenario"]
1597
+ ] = None,
1371
1598
  project_name: str = "Project",
1372
1599
  survey_description: Optional[str] = None,
1373
1600
  survey_alias: Optional[str] = None,
1374
1601
  survey_visibility: Optional[VisibilityType] = "unlisted",
1602
+ scenario_list_description: Optional[str] = None,
1603
+ scenario_list_alias: Optional[str] = None,
1604
+ scenario_list_visibility: Optional[VisibilityType] = "unlisted",
1375
1605
  ):
1376
1606
  """
1377
1607
  Create a survey object on Coop, then create a project from the survey.
1378
1608
  """
1379
- survey_details = self.create(
1609
+ if scenario_list is None and scenario_list_method is not None:
1610
+ raise CoopValueError(
1611
+ "You must specify both a scenario list and a scenario list method to use scenarios with your survey."
1612
+ )
1613
+ elif scenario_list is not None and scenario_list_method is None:
1614
+ raise CoopValueError(
1615
+ "You must specify both a scenario list and a scenario list method to use scenarios with your survey."
1616
+ )
1617
+ survey_details = self.push(
1380
1618
  object=survey,
1381
1619
  description=survey_description,
1382
1620
  alias=survey_alias,
1383
1621
  visibility=survey_visibility,
1384
1622
  )
1385
1623
  survey_uuid = survey_details.get("uuid")
1624
+ if scenario_list is not None:
1625
+ scenario_list_details = self.push(
1626
+ object=scenario_list,
1627
+ description=scenario_list_description,
1628
+ alias=scenario_list_alias,
1629
+ visibility=scenario_list_visibility,
1630
+ )
1631
+ scenario_list_uuid = scenario_list_details.get("uuid")
1632
+ else:
1633
+ scenario_list_uuid = None
1386
1634
  response = self._send_server_request(
1387
1635
  uri="api/v0/projects/create-from-survey",
1388
1636
  method="POST",
1389
- payload={"project_name": project_name, "survey_uuid": str(survey_uuid)},
1637
+ payload={
1638
+ "project_name": project_name,
1639
+ "survey_uuid": str(survey_uuid),
1640
+ "scenario_list_uuid": (
1641
+ str(scenario_list_uuid) if scenario_list_uuid is not None else None
1642
+ ),
1643
+ "scenario_list_method": scenario_list_method,
1644
+ },
1390
1645
  )
1391
1646
  self._resolve_server_response(response)
1392
1647
  response_json = response.json()
@@ -1413,14 +1668,26 @@ class Coop(CoopFunctionsMixin):
1413
1668
  return {
1414
1669
  "project_name": response_json.get("project_name"),
1415
1670
  "project_job_uuids": response_json.get("job_uuids"),
1671
+ "project_prolific_studies": [
1672
+ {
1673
+ "study_id": study.get("id"),
1674
+ "name": study.get("name"),
1675
+ "status": study.get("status"),
1676
+ "num_participants": study.get("total_available_places"),
1677
+ "places_taken": study.get("places_taken"),
1678
+ }
1679
+ for study in response_json.get("prolific_studies", [])
1680
+ ],
1416
1681
  }
1417
1682
 
1418
- def get_project_human_responses(
1683
+ def _turn_human_responses_into_results(
1419
1684
  self,
1420
- project_uuid: str,
1685
+ human_responses: List[dict],
1686
+ survey_json_string: str,
1687
+ scenario_list_json_string: Optional[str] = None,
1421
1688
  ) -> Union["Results", "ScenarioList"]:
1422
1689
  """
1423
- Return a Results object with the human responses for a project.
1690
+ Turn a list of human responses into a Results object.
1424
1691
 
1425
1692
  If generating the Results object fails, a ScenarioList will be returned instead.
1426
1693
  """
@@ -1430,16 +1697,19 @@ class Coop(CoopFunctionsMixin):
1430
1697
  from ..scenarios import Scenario, ScenarioList
1431
1698
  from ..surveys import Survey
1432
1699
 
1433
- response = self._send_server_request(
1434
- uri=f"api/v0/projects/{project_uuid}/human-responses",
1435
- method="GET",
1436
- )
1437
- self._resolve_server_response(response)
1438
- response_json = response.json()
1439
- human_responses = response_json.get("human_responses", [])
1440
-
1441
1700
  try:
1442
- agent_list = AgentList()
1701
+ survey = Survey.from_dict(json.loads(survey_json_string))
1702
+
1703
+ model = Model("test")
1704
+
1705
+ if scenario_list_json_string is not None:
1706
+ scenario_list = ScenarioList.from_dict(
1707
+ json.loads(scenario_list_json_string)
1708
+ )
1709
+ else:
1710
+ scenario_list = ScenarioList()
1711
+
1712
+ results = None
1443
1713
 
1444
1714
  for response in human_responses:
1445
1715
  response_uuid = response.get("response_uuid")
@@ -1449,8 +1719,14 @@ class Coop(CoopFunctionsMixin):
1449
1719
  )
1450
1720
 
1451
1721
  response_dict = json.loads(response.get("response_json_string"))
1722
+ agent_traits_json_string = response.get("agent_traits_json_string")
1723
+ scenario_uuid = response.get("scenario_uuid")
1724
+ if agent_traits_json_string is not None:
1725
+ agent_traits = json.loads(agent_traits_json_string)
1726
+ else:
1727
+ agent_traits = {}
1452
1728
 
1453
- a = Agent(name=response_uuid, instruction="")
1729
+ a = Agent(name=response_uuid, instruction="", traits=agent_traits)
1454
1730
 
1455
1731
  def create_answer_function(response_data):
1456
1732
  def f(self, question, scenario):
@@ -1458,27 +1734,38 @@ class Coop(CoopFunctionsMixin):
1458
1734
 
1459
1735
  return f
1460
1736
 
1737
+ scenario = None
1738
+ if scenario_uuid is not None:
1739
+ for s in scenario_list:
1740
+ if s.get("uuid") == scenario_uuid:
1741
+ scenario = s
1742
+ break
1743
+
1744
+ if scenario is None:
1745
+ raise RuntimeError("Scenario not found.")
1746
+
1461
1747
  a.add_direct_question_answering_method(
1462
1748
  create_answer_function(response_dict)
1463
1749
  )
1464
- agent_list.append(a)
1465
1750
 
1466
- survey_json_string = response_json.get("survey_json_string")
1467
- survey = Survey.from_dict(json.loads(survey_json_string))
1751
+ job = survey.by(a).by(model)
1468
1752
 
1469
- model = Model("test")
1470
- results = (
1471
- survey.by(agent_list)
1472
- .by(model)
1473
- .run(
1753
+ if scenario is not None:
1754
+ job = job.by(scenario)
1755
+
1756
+ question_results = job.run(
1474
1757
  cache=Cache(),
1475
1758
  disable_remote_cache=True,
1476
1759
  disable_remote_inference=True,
1477
1760
  print_exceptions=False,
1478
1761
  )
1479
- )
1762
+
1763
+ if results is None:
1764
+ results = question_results
1765
+ else:
1766
+ results = results + question_results
1480
1767
  return results
1481
- except Exception:
1768
+ except Exception as e:
1482
1769
  human_response_scenarios = []
1483
1770
  for response in human_responses:
1484
1771
  response_uuid = response.get("response_uuid")
@@ -1493,71 +1780,492 @@ class Coop(CoopFunctionsMixin):
1493
1780
  human_response_scenarios.append(scenario)
1494
1781
  return ScenarioList(human_response_scenarios)
1495
1782
 
1496
- def __repr__(self):
1497
- """Return a string representation of the client."""
1498
- return f"Client(api_key='{self.api_key}', url='{self.url}')"
1499
-
1500
- async def remote_async_execute_model_call(
1501
- self, model_dict: dict, user_prompt: str, system_prompt: str
1502
- ) -> dict:
1503
- url = self.api_url + "/inference/"
1504
- # print("Now using url: ", url)
1505
- data = {
1506
- "model_dict": model_dict,
1507
- "user_prompt": user_prompt,
1508
- "system_prompt": system_prompt,
1509
- }
1510
- # Use aiohttp to send a POST request asynchronously
1511
- async with aiohttp.ClientSession() as session:
1512
- async with session.post(url, json=data) as response:
1513
- response_data = await response.json()
1514
- return response_data
1515
-
1516
- def web(
1783
def get_project_human_responses(
    self,
    project_uuid: str,
) -> Union["Results", "ScenarioList"]:
    """
    Return a Results object with the human responses for a project.

    The responses, the survey JSON and the optional scenario-list JSON are
    fetched from ``api/v0/projects/{project_uuid}/human-responses``. They are
    then converted by ``_turn_human_responses_into_results``. If generating the
    Results object fails, a ScenarioList will be returned instead.
    """
    server_response = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/human-responses",
        method="GET",
    )
    self._resolve_server_response(server_response)
    body = server_response.json()
    return self._turn_human_responses_into_results(
        body.get("human_responses", []),
        body.get("survey_json_string"),
        body.get("scenario_list_json_string"),
    )
1533
1805
 
1534
- def fetch_prices(self) -> dict:
1806
def list_prolific_filters(self) -> "CoopProlificFilters":
    """
    Get a ScenarioList of supported Prolific filters. This list has several methods
    that you can use to create valid filter dicts for use with Coop.create_prolific_study().

    Each filter returned by ``api/v0/prolific-filters`` becomes one Scenario.
    Range-only fields (``range_filter_min``/``range_filter_max``) are None for
    non-range filters. Select-only fields (``select_filter_num_options``/
    ``select_filter_options``) are None for non-select filters.

    Call find() to examine a specific filter by ID:
    >>> filters = coop.list_prolific_filters()
    >>> filters.find("age")
    Scenario(
        {
            "filter_id": "age",
            "type": "range",
            "range_filter_min": 18,
            "range_filter_max": 100,
            ...
        }
    )

    Call create_study_filter() to create a valid filter dict:
    >>> filters.create_study_filter("age", min=30, max=40)
    {
        "filter_id": "age",
        "selected_range": {
            "lower": 30,
            "upper": 40,
        },
    }
    """
    from ..scenarios import Scenario

    response = self._send_server_request(
        uri="api/v0/prolific-filters",
        method="GET",
    )
    self._resolve_server_response(response)
    raw_filters = response.json().get("prolific_filters", [])

    filter_scenarios = []
    # Named ``prolific_filter`` (not ``filter``) to avoid shadowing the builtin.
    for prolific_filter in raw_filters:
        filter_type = prolific_filter.get("type")
        question = prolific_filter.get("question")
        choices = prolific_filter.get("choices")
        is_range = filter_type == "range"
        is_select = filter_type == "select"
        filter_scenarios.append(
            Scenario(
                {
                    "filter_id": prolific_filter.get("filter_id"),
                    "title": prolific_filter.get("title"),
                    "question": (
                        f"Participants were asked the following: {question}"
                        if question
                        else None
                    ),
                    "type": filter_type,
                    "range_filter_min": (
                        prolific_filter.get("min") if is_range else None
                    ),
                    "range_filter_max": (
                        prolific_filter.get("max") if is_range else None
                    ),
                    # ``choices`` may be missing or null; treat both as zero
                    # options instead of letting len(None) raise TypeError.
                    "select_filter_num_options": (
                        len(choices or []) if is_select else None
                    ),
                    "select_filter_options": choices if is_select else None,
                }
            )
        )
    return CoopProlificFilters(filter_scenarios)
1552
1875
 
1553
- Raises:
1554
- ValueError: If the EDSL_FETCH_TOKEN_PRICES configuration setting is invalid
1876
@staticmethod
def _validate_prolific_study_cost(
    estimated_completion_time_minutes: int, participant_payment_cents: int
) -> tuple[bool, float]:
    """
    Check whether a Prolific study's hourly participant pay is below the minimum.

    Returns:
        tuple[bool, float]: ``(is_underpayment, cost_usd_per_hour)``.
        ``is_underpayment`` is True when the hourly rate is below $8.00 USD.
        ``cost_usd_per_hour`` is the participant payment converted to USD per hour.

    Raises:
        CoopValueError: if ``estimated_completion_time_minutes`` is missing or not
            positive. No hourly rate can be computed in that case. Previously a
            0 raised ZeroDivisionError, and a negative value produced a
            misleading negative hourly rate.
    """
    if estimated_completion_time_minutes is None or estimated_completion_time_minutes <= 0:
        raise CoopValueError(
            "The estimated completion time must be a positive number of minutes."
        )

    # $8.00 USD per hour is the minimum amount for using Prolific.
    min_usd_per_hour = 8

    estimated_completion_time_hours = estimated_completion_time_minutes / 60
    participant_payment_usd = participant_payment_cents / 100
    cost_usd_per_hour = participant_payment_usd / estimated_completion_time_hours

    return cost_usd_per_hour < min_usd_per_hour, cost_usd_per_hour
1894
+
1895
def create_prolific_study(
    self,
    project_uuid: str,
    name: str,
    description: str,
    num_participants: int,
    estimated_completion_time_minutes: int,
    participant_payment_cents: int,
    device_compatibility: Optional[
        List[Literal["desktop", "tablet", "mobile"]]
    ] = None,
    peripheral_requirements: Optional[
        List[Literal["audio", "camera", "download", "microphone"]]
    ] = None,
    filters: Optional[List[Dict]] = None,
) -> dict:
    """
    Create a Prolific study for a Coop project and return the study details.

    To add filters, first fetch the supported ones with
    Coop.list_prolific_filters(). Then use the create_study_filter method of the
    returned CoopProlificFilters object to build valid filter dicts.

    Parameters:
        project_uuid: UUID of the project the study belongs to.
        name, description: Study name and description sent to the server.
        num_participants: Sent as ``total_available_places``.
        estimated_completion_time_minutes: Sent as ``estimated_completion_time``.
        participant_payment_cents: Per-participant payment, sent as ``reward``.
        device_compatibility: Defaults to desktop, tablet and mobile.
        peripheral_requirements: Defaults to no requirements.
        filters: Defaults to no filters.

    Raises:
        CoopValueError: if the payment works out below $8.00 USD per hour.
    """
    too_low, hourly_rate = self._validate_prolific_study_cost(
        estimated_completion_time_minutes, participant_payment_cents
    )
    if too_low:
        raise CoopValueError(
            f"The current participant payment of ${hourly_rate:.2f} USD per hour is below the minimum payment for using Prolific ($8.00 USD per hour)."
        )

    if device_compatibility is None:
        device_compatibility = ["desktop", "tablet", "mobile"]
    if peripheral_requirements is None:
        peripheral_requirements = []
    if filters is None:
        filters = []

    study_payload = {
        "name": name,
        "description": description,
        "total_available_places": num_participants,
        "estimated_completion_time": estimated_completion_time_minutes,
        "reward": participant_payment_cents,
        "device_compatibility": device_compatibility,
        "peripheral_requirements": peripheral_requirements,
        "filters": filters,
    }
    response = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/prolific-studies",
        method="POST",
        payload=study_payload,
    )
    self._resolve_server_response(response)
    created = response.json()

    # (client-facing key, server key) pairs, in the order they are returned.
    key_pairs = (
        ("study_id", "study_id"),
        ("status", "status"),
        ("admin_url", "admin_url"),
        ("respondent_url", "respondent_url"),
        ("name", "name"),
        ("description", "description"),
        ("num_participants", "total_available_places"),
        ("estimated_completion_time_minutes", "estimated_completion_time"),
        ("participant_payment_cents", "reward"),
        ("total_cost_cents", "total_cost"),
        ("device_compatibility", "device_compatibility"),
        ("peripheral_requirements", "peripheral_requirements"),
        ("filters", "filters"),
    )
    return {client_key: created.get(server_key) for client_key, server_key in key_pairs}
1966
+
1967
def update_prolific_study(
    self,
    project_uuid: str,
    study_id: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    num_participants: Optional[int] = None,
    estimated_completion_time_minutes: Optional[int] = None,
    participant_payment_cents: Optional[int] = None,
    device_compatibility: Optional[
        List[Literal["desktop", "tablet", "mobile"]]
    ] = None,
    peripheral_requirements: Optional[
        List[Literal["audio", "camera", "download", "microphone"]]
    ] = None,
    filters: Optional[List[Dict]] = None,
) -> dict:
    """
    Update a Prolific study. Returns a dict with the study details.

    Only arguments that are not None are sent in the PATCH payload.

    When the completion time or the payment is being changed, the resulting
    hourly rate is checked against the $8.00 USD/hour minimum before the PATCH
    is sent. Each value used in that check is the new one when given, otherwise
    the study's current value from get_prolific_study().

    Raises:
        CoopValueError: if the update would pay below $8.00 USD per hour.
    """
    if (
        estimated_completion_time_minutes is not None
        or participant_payment_cents is not None
    ):
        study = self.get_prolific_study(project_uuid, study_id)
        # Compare with None rather than using `or`. With `or`, an explicit 0 was
        # replaced by the stored value for validation while 0 was still sent in
        # the payload. For example, a reward of 0 bypassed the minimum-pay check.
        updated_completion_time = (
            estimated_completion_time_minutes
            if estimated_completion_time_minutes is not None
            else study.get("estimated_completion_time_minutes")
        )
        updated_payment = (
            participant_payment_cents
            if participant_payment_cents is not None
            else study.get("participant_payment_cents")
        )

        is_underpayment, cost_usd_per_hour = self._validate_prolific_study_cost(
            updated_completion_time, updated_payment
        )
        if is_underpayment:
            raise CoopValueError(
                f"This update would result in a participant payment of ${cost_usd_per_hour:.2f} USD per hour, which is below the minimum payment for using Prolific ($8.00 USD per hour)."
            )

    # Server-side field names; entries left as None are not sent.
    candidate_fields = {
        "name": name,
        "description": description,
        "total_available_places": num_participants,
        "estimated_completion_time": estimated_completion_time_minutes,
        "reward": participant_payment_cents,
        "device_compatibility": device_compatibility,
        "peripheral_requirements": peripheral_requirements,
        "filters": filters,
    }
    payload = {
        key: value for key, value in candidate_fields.items() if value is not None
    }

    response = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}",
        method="PATCH",
        payload=payload,
    )
    self._resolve_server_response(response)
    response_json = response.json()
    return {
        "study_id": response_json.get("study_id"),
        "status": response_json.get("status"),
        "admin_url": response_json.get("admin_url"),
        "respondent_url": response_json.get("respondent_url"),
        "name": response_json.get("name"),
        "description": response_json.get("description"),
        "num_participants": response_json.get("total_available_places"),
        "estimated_completion_time_minutes": response_json.get(
            "estimated_completion_time"
        ),
        "participant_payment_cents": response_json.get("reward"),
        "total_cost_cents": response_json.get("total_cost"),
        "device_compatibility": response_json.get("device_compatibility"),
        "peripheral_requirements": response_json.get("peripheral_requirements"),
        "filters": response_json.get("filters"),
    }
2047
+
2048
def publish_prolific_study(
    self,
    project_uuid: str,
    study_id: str,
) -> dict:
    """
    Publish a Prolific study belonging to a Coop project.

    Sends a POST to the study's ``/publish`` endpoint and returns the server's
    decoded JSON response as-is.
    """
    publish_uri = (
        f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/publish"
    )
    result = self._send_server_request(uri=publish_uri, method="POST")
    self._resolve_server_response(result)
    return result.json()
2062
+
2063
def get_prolific_study(self, project_uuid: str, study_id: str) -> dict:
    """
    Fetch one Prolific study of a project and return its details.

    Server field names are renamed for the caller:
    ``total_available_places`` becomes ``num_participants``,
    ``estimated_completion_time`` becomes ``estimated_completion_time_minutes``,
    ``reward`` becomes ``participant_payment_cents``, and
    ``total_cost`` becomes ``total_cost_cents``. Missing fields map to None.
    """
    study_uri = f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}"
    raw = self._send_server_request(uri=study_uri, method="GET")
    self._resolve_server_response(raw)
    study_json = raw.json()

    renamed_fields = (
        ("study_id", "study_id"),
        ("status", "status"),
        ("admin_url", "admin_url"),
        ("respondent_url", "respondent_url"),
        ("name", "name"),
        ("description", "description"),
        ("num_participants", "total_available_places"),
        ("estimated_completion_time_minutes", "estimated_completion_time"),
        ("participant_payment_cents", "reward"),
        ("total_cost_cents", "total_cost"),
        ("device_compatibility", "device_compatibility"),
        ("peripheral_requirements", "peripheral_requirements"),
        ("filters", "filters"),
    )
    return {
        client_key: study_json.get(server_key)
        for client_key, server_key in renamed_fields
    }
2090
+
2091
def get_prolific_study_responses(
    self,
    project_uuid: str,
    study_id: str,
) -> Union["Results", "ScenarioList"]:
    """
    Retrieve the human responses collected by a Prolific study.

    The raw responses and the survey JSON returned by the server are handed
    to ``_turn_human_responses_into_results``. Per the original contract, that
    yields a Results object, or a ScenarioList if building Results fails.

    Parameters:
        project_uuid: UUID of the Coop project that owns the study.
        study_id: Identifier of the Prolific study.
    """
    responses_uri = (
        f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}/responses"
    )
    server_response = self._send_server_request(uri=responses_uri, method="GET")
    self._resolve_server_response(server_response)
    body = server_response.json()
    # A missing "human_responses" key is treated as "no responses yet".
    return self._turn_human_responses_into_results(
        body.get("human_responses", []),
        body.get("survey_json_string"),
    )
2113
+
2114
def delete_prolific_study(
    self,
    project_uuid: str,
    study_id: str,
) -> dict:
    """
    Delete a Prolific study from a Coop project.

    Note: Only draft studies can be deleted. Once you publish a study, it
    cannot be deleted.

    Parameters:
        project_uuid: UUID of the Coop project that owns the study.
        study_id: Identifier of the Prolific study to delete.

    Returns:
        dict: The JSON body returned by the server.
    """
    study_uri = f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}"
    server_response = self._send_server_request(uri=study_uri, method="DELETE")
    self._resolve_server_response(server_response)
    return server_response.json()
2130
+
2131
def approve_prolific_study_submission(
    self,
    project_uuid: str,
    study_id: str,
    submission_id: str,
) -> dict:
    """
    Approve one participant submission of a Prolific study.

    Parameters:
        project_uuid: UUID of the Coop project that owns the study.
        study_id: Identifier of the Prolific study.
        submission_id: Identifier of the submission to approve.

    Returns:
        dict: The JSON body returned by the server.
    """
    submission_base = (
        f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}"
        f"/submissions/{submission_id}"
    )
    server_response = self._send_server_request(
        uri=submission_base + "/approve", method="POST"
    )
    self._resolve_server_response(server_response)
    return server_response.json()
2146
+
2147
def reject_prolific_study_submission(
    self,
    project_uuid: str,
    study_id: str,
    submission_id: str,
    reason: Literal[
        "TOO_QUICKLY",
        "TOO_SLOWLY",
        "FAILED_INSTRUCTIONS",
        "INCOMP_LONGITUDINAL",
        "FAILED_CHECK",
        "LOW_EFFORT",
        "MALINGERING",
        "NO_CODE",
        "BAD_CODE",
        "NO_DATA",
        "UNSUPP_DEVICE",
        "OTHER",
    ],
    explanation: str,
) -> dict:
    """
    Reject one participant submission of a Prolific study.

    Both arguments are validated locally before any request is sent.

    Parameters:
        project_uuid: UUID of the Coop project that owns the study.
        study_id: Identifier of the Prolific study.
        submission_id: Identifier of the submission to reject.
        reason: One of the rejection codes listed in the signature.
        explanation: Free-text justification; at least 100 characters.

    Returns:
        dict: The JSON body returned by the server.

    Raises:
        CoopValueError: If ``reason`` is not a known code or ``explanation``
            is shorter than 100 characters.
    """
    # Kept as a list: its repr is embedded verbatim in the error message.
    allowed_reasons = (
        "TOO_QUICKLY TOO_SLOWLY FAILED_INSTRUCTIONS INCOMP_LONGITUDINAL "
        "FAILED_CHECK LOW_EFFORT MALINGERING NO_CODE BAD_CODE NO_DATA "
        "UNSUPP_DEVICE OTHER"
    ).split()
    if reason not in allowed_reasons:
        raise CoopValueError(
            f"Invalid rejection reason. Please use one of the following: {allowed_reasons}."
        )
    if len(explanation) < 100:
        raise CoopValueError(
            "Rejection explanation must be at least 100 characters."
        )

    reject_uri = (
        f"api/v0/projects/{project_uuid}/prolific-studies/{study_id}"
        f"/submissions/{submission_id}/reject"
    )
    request_payload = {"reason": reason, "explanation": explanation}
    server_response = self._send_server_request(
        uri=reject_uri, method="POST", payload=request_payload
    )
    self._resolve_server_response(server_response)
    return server_response.json()
2203
+
2204
def __repr__(self):
    """
    Return a string representation of the client.

    The API key is a secret, and reprs routinely end up in logs, tracebacks
    and notebook output. Only its last four characters are shown (or
    ``****`` for very short keys). A missing key still renders as ``None``.
    """
    key = self.api_key
    if key is None:
        shown_key = None
    else:
        key_text = str(key)
        shown_key = f"...{key_text[-4:]}" if len(key_text) > 8 else "****"
    return f"Client(api_key='{shown_key}', url='{self.url}')"
2207
+
2208
async def remote_async_execute_model_call(
    self, model_dict: dict, user_prompt: str, system_prompt: str
) -> dict:
    """
    Asynchronously send one model call to the ``/inference/`` endpoint.

    Parameters:
        model_dict: Serialized model description, forwarded as ``model_dict``.
        user_prompt: The user prompt text.
        system_prompt: The system prompt text.

    Returns:
        dict: The JSON body decoded from the server's response.
    """
    # NOTE(review): unlike the requests made via _send_server_request, no auth
    # headers are attached here — confirm this endpoint is meant to be open.
    endpoint = self.api_url + "/inference/"
    request_body = {
        "model_dict": model_dict,
        "user_prompt": user_prompt,
        "system_prompt": system_prompt,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, json=request_body) as http_response:
            return await http_response.json()
2223
+
2224
def web(
    self,
    survey: dict,
    platform: Literal[
        "google_forms", "lime_survey", "survey_monkey"
    ] = "lime_survey",
    email=None,
):
    """
    Ask the server to export ``survey`` to an external survey platform.

    The survey and the email (empty string when none is given) are JSON-encoded
    and wrapped in a ``json_string`` field. That wrapper is then JSON-encoded
    again as the request body.

    Parameters:
        survey: Serialized survey to export.
        platform: Target platform; selects the ``export_to_<platform>`` endpoint.
        email: Optional email address forwarded to the server.

    Returns:
        The raw ``requests`` response object. Its status is not checked here.
    """
    export_url = f"{self.api_url}/api/v0/export_to_{platform}"
    inner_payload = {"survey": survey, "email": email or ""}
    wrapped_body = {"json_string": json.dumps(inner_payload)}
    return requests.post(export_url, headers=self.headers, data=json.dumps(wrapped_body))
2241
+
2242
+ def fetch_prices(self) -> dict:
2243
+ """
2244
+ Fetch the current pricing information for language models.
2245
+
2246
+ This method retrieves the latest pricing information for all supported language models
2247
+ from the Expected Parrot API. The pricing data is used to estimate costs for jobs
2248
+ and to optimize model selection based on budget constraints.
2249
+
2250
+ Returns:
2251
+ dict: A dictionary mapping (service, model) tuples to pricing information.
2252
+ Each entry contains token pricing for input and output tokens.
2253
+ Example structure:
2254
+ {
2255
+ ('openai', 'gpt-4'): {
2256
+ 'input': {'usd_per_1M_tokens': 30.0, ...},
2257
+ 'output': {'usd_per_1M_tokens': 60.0, ...}
2258
+ }
2259
+ }
2260
+
2261
+ Raises:
2262
+ ValueError: If the EDSL_FETCH_TOKEN_PRICES configuration setting is invalid
2263
+
2264
+ Notes:
2265
+ - Returns an empty dict if EDSL_FETCH_TOKEN_PRICES is set to "False"
2266
+ - The pricing data is cached to minimize API calls
2267
+ - Pricing may vary based on the model, provider, and token type (input/output)
2268
+ - All prices are in USD per million tokens
1561
2269
 
1562
2270
  Example:
1563
2271
  >>> prices = coop.fetch_prices()
@@ -1686,6 +2394,235 @@ class Coop(CoopFunctionsMixin):
1686
2394
  self._resolve_server_response(response)
1687
2395
  return response.json().get("uuid")
1688
2396
 
2397
def pull(
    self,
    url_or_uuid: Optional[Union[str, UUID]] = None,
    expected_object_type: Optional[ObjectType] = None,
) -> Any:
    """
    Download an object from Coop, reading its content straight from Google Cloud Storage.

    The server (``api/v0/object/pull``) is asked for a signed download URL.
    The object's JSON is downloaded from that URL and rebuilt with the EDSL
    class's ``from_dict``. The method falls back to :py:meth:`get` when:
    the object is stored in the old format, the server returns an empty
    signed URL, the direct download fails, or the object's type cannot be
    determined (no ``expected_object_type`` and no type reported for an
    alias).

    Parameters:
        url_or_uuid (Union[str, UUID], optional): Identifier for the object to retrieve.
            Can be one of:
            - UUID string (e.g., "123e4567-e89b-12d3-a456-426614174000")
            - Full URL (e.g., "https://expectedparrot.com/content/123e4567...")
            - Alias URL (e.g., "https://expectedparrot.com/content/username/my-survey")
        expected_object_type (ObjectType, optional): Type used to pick the EDSL
            class that rebuilds the object (e.g., "survey", "agent"). For alias
            lookups it is also checked against the type reported by the server.

    Returns:
        The reconstructed EDSL object, or whatever :py:meth:`get` returns on
        the fallback paths.

    Raises:
        CoopObjectTypeError: If an alias resolves to an object of a different
            type than ``expected_object_type``.
        CoopResponseError: If the pull response does not include a signed URL.
        CoopServerResponseError: If the server returns an error response.

    Example:
        >>> survey = coop.pull("123e4567-e89b-12d3-a456-426614174000", expected_object_type="survey")
        >>> survey = coop.pull("https://expectedparrot.com/content/username/my-survey")
    """
    obj_uuid, owner_username, alias = self._resolve_uuid_or_alias(url_or_uuid)
    object_type = expected_object_type

    # Handle alias-based retrieval with new/old format detection
    if not obj_uuid and owner_username and alias:
        # First, get object info to determine format, UUID and type
        info_response = self._send_server_request(
            uri="api/v0/object/alias/info",
            method="GET",
            params={"owner_username": owner_username, "alias": alias},
        )
        self._resolve_server_response(info_response)
        info_data = info_response.json()

        obj_uuid = info_data.get("uuid")
        is_new_format = info_data.get("is_new_format", False)
        actual_object_type = info_data.get("object_type")

        # Validate object type if expected
        if expected_object_type and actual_object_type != expected_object_type:
            from .exceptions import CoopObjectTypeError

            raise CoopObjectTypeError(
                f"Expected {expected_object_type=} but got {actual_object_type=}"
            )

        # Use get method for old format objects
        if not is_new_format:
            return self.get(url_or_uuid, expected_object_type)

        # The alias info tells us the type even when the caller did not.
        object_type = object_type or actual_object_type

    # Without a type there is no EDSL class to rebuild the object with
    # (previously this crashed with UnboundLocalError after downloading).
    # Use get(), the same fallback this method relies on for other cases.
    if object_type is None:
        return self.get(url_or_uuid, expected_object_type)

    # Send the request to the API endpoint with the resolved UUID
    response = self._send_server_request(
        uri="api/v0/object/pull",
        method="POST",
        payload={"object_uuid": obj_uuid},
    )
    self._resolve_server_response(response)
    response_json = response.json()
    if "signed_url" not in response_json:
        from .exceptions import CoopResponseError

        raise CoopResponseError("No signed URL was provided in the server response")
    signed_url = response_json.get("signed_url")

    if signed_url == "":  # the server signals an old-format object
        return self.get(url_or_uuid, expected_object_type)

    # Best-effort direct download; any failure falls back to the regular path.
    try:
        gcs_response = requests.get(signed_url)
        self._resolve_gcs_response(gcs_response)
    except Exception:
        return self.get(url_or_uuid, expected_object_type)

    object_dict = gcs_response.json()
    edsl_class = ObjectRegistry.get_edsl_class_by_object_type(object_type)
    return edsl_class.from_dict(object_dict)
2494
+
2495
def get_upload_url(self, object_uuid: str) -> dict:
    """
    Request a signed URL for replacing the stored content of an existing object.

    The object's UUID is preserved; the returned URL is meant for uploading
    new content directly to Google Cloud Storage.

    Parameters:
        object_uuid (str): UUID of the object whose content will be replaced.

    Returns:
        dict: The server's JSON body. Per the server contract this includes
        ``signed_url``, ``object_uuid`` and a ``message``.

    Notes:
        Per the server contract: only objects stored in the new format are
        supported, the caller must own the object, and the signed URL
        expires after 60 minutes.

    Example:
        >>> response = coop.get_upload_url("123e4567-e89b-12d3-a456-426614174000")
        >>> upload_url = response['signed_url']
    """
    request_payload = {"object_uuid": object_uuid}
    server_response = self._send_server_request(
        uri="api/v0/object/upload-url",
        method="POST",
        payload=request_payload,
    )
    self._resolve_server_response(server_response)
    return server_response.json()
2532
+
2533
def push(
    self,
    object: EDSLObject,
    description: Optional[str] = None,
    alias: Optional[str] = None,
    visibility: Optional[VisibilityType] = "unlisted",
) -> dict:
    """
    Upload an EDSL object to Coop through a signed Google Cloud Storage URL.

    The upload is a three-step handshake:

    1. Ask the server (``api/v0/object/push``) to register the object and
       return a signed upload URL plus the new object's UUID.
    2. PUT the object's JSON serialization (``object.to_dict()``) directly to
       the signed URL.
    3. Notify the server that the upload finished
       (``api/v0/object/confirm-upload``).

    Parameters:
        object (EDSLObject): The EDSL object to upload.
        description (str, optional): Human-readable description.
        alias (str, optional): Alias used to build a friendly URL.
        visibility (VisibilityType, optional): Visibility of the stored object;
            defaults to "unlisted".

    Returns:
        dict: Metadata with the keys ``description``, ``object_type``, ``url``,
        ``alias_url``, ``uuid``, ``version`` and ``visibility``.

    Raises:
        CoopServerResponseError: If the server rejects the push or the upload
            confirmation.
        CoopResponseError: If the push response lacks a signed URL or an
            object UUID.

    Example:
        >>> result = coop.push(survey, description="My survey", visibility="public")
        >>> print(f"Stored at: {result['url']}")
    """
    object_type = ObjectRegistry.get_object_type_by_edsl_class(object)
    object_dict = object.to_dict()
    object_hash = object.get_hash() if hasattr(object, "get_hash") else None

    # Step 1: register the object and obtain a signed upload URL
    response = self._send_server_request(
        uri="api/v0/object/push",
        method="POST",
        payload={
            "object_type": object_type,
            "description": description,
            "alias": alias,
            "visibility": visibility,
            "object_hash": object_hash,
            "version": self._edsl_version,
        },
    )
    # Surface HTTP/auth errors with the standard Coop handling (as every other
    # request in this client does) before inspecting the body.
    self._resolve_server_response(response)
    response_json = response.json()

    signed_url = response_json.get("signed_url")
    if signed_url is None:
        from .exceptions import CoopResponseError

        raise CoopResponseError(response.text)

    object_uuid = response_json.get("object_uuid")
    owner_username = response_json.get("owner_username")
    object_alias = response_json.get("alias")

    # Validate before uploading: without a UUID the upload cannot be confirmed.
    if object_uuid is None:
        from .exceptions import CoopResponseError

        raise CoopResponseError("No object UUID was provided in the server response")

    # Step 2: upload the serialized object directly to GCS
    json_data = json.dumps(
        object_dict,
        default=self._json_handle_none,
        allow_nan=False,
    )
    upload_response = requests.put(
        signed_url,
        data=json_data.encode(),
        headers={"Content-Type": "application/json"},
    )
    self._resolve_gcs_response(upload_response)

    # Step 3: confirm the upload completion
    confirm_response = self._send_server_request(
        uri="api/v0/object/confirm-upload",
        method="POST",
        payload={"object_uuid": object_uuid},
    )
    self._resolve_server_response(confirm_response)

    return {
        "description": response_json.get("description"),
        "object_type": object_type,
        "url": f"{self.url}/content/{object_uuid}",
        "alias_url": self._get_alias_url(owner_username, object_alias),
        "uuid": object_uuid,
        "version": self._edsl_version,
        "visibility": response_json.get("visibility"),
    }
2625
+
1689
2626
  def _display_login_url(
1690
2627
  self, edsl_auth_token: str, link_description: Optional[str] = None
1691
2628
  ):
@@ -1769,6 +2706,125 @@ class Coop(CoopFunctionsMixin):
1769
2706
  # Add API key to environment
1770
2707
  load_dotenv()
1771
2708
 
2709
def login_streamlit(self, timeout: int = 120):
    """
    Start the EDSL auth token login flow inside a Streamlit application.

    This helper is functionally equivalent to ``Coop.login`` but renders the
    login link and status updates directly in the Streamlit UI. On every
    script run it asks the Expected Parrot server once for the API-key tied
    to the auth-token kept in ``st.session_state``. While the key is missing,
    it schedules another rerun. Once the key is received, it stores it via
    ``ExpectedParrotKeyHandler``, sets it in ``os.environ`` and writes it to
    the local ``.env`` file so subsequent sessions pick it up automatically.

    Parameters
    ----------
    timeout : int, default 120
        How many seconds to wait for the user to complete the login before
        giving up and showing an error in the Streamlit app. After a timeout
        the auth-token is discarded, so the next rerun starts a fresh attempt.

    Returns
    -------
    str | None
        The API-key if the user logged-in successfully, otherwise ``None``.

    Raises
    ------
    ImportError
        If Streamlit is not installed.
    RuntimeError
        If called outside a running Streamlit app.
    """
    try:
        import streamlit as st
        from streamlit.runtime.scriptrunner import get_script_run_ctx
    except ModuleNotFoundError as exc:
        raise ImportError(
            "Streamlit is required for `login_streamlit`. Install it with `pip install streamlit`."
        ) from exc

    # Ensure we are actually running inside a Streamlit script. If not, give a
    # clear error message instead of failing when a rerun is requested
    # outside the Streamlit runtime.
    if get_script_run_ctx() is None:
        raise RuntimeError(
            "`login_streamlit` must be invoked from within a running Streamlit "
            "app (use `streamlit run your_script.py`). If you need to obtain an "
            "API-key in a regular Python script or notebook, use `Coop.login()` "
            "instead."
        )

    import secrets
    import time
    import os
    from dotenv import load_dotenv
    from .ep_key_handling import ExpectedParrotKeyHandler
    from ..utilities.utilities import write_api_key_to_env

    # ------------------------------------------------------------------
    # 1. Prepare auth-token and store state across reruns
    # ------------------------------------------------------------------
    if "edsl_auth_token" not in st.session_state:
        st.session_state.edsl_auth_token = secrets.token_urlsafe(16)
        st.session_state.login_start_time = time.time()

    edsl_auth_token: str = st.session_state.edsl_auth_token
    login_url = (
        f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
    )

    # ------------------------------------------------------------------
    # 2. Render clickable login link
    # ------------------------------------------------------------------
    st.markdown(
        f"🔗 **Log in to Expected Parrot** → [click here]({login_url})",
        unsafe_allow_html=True,
    )

    # ------------------------------------------------------------------
    # 3. Poll server for API-key (runs once per Streamlit execution)
    # ------------------------------------------------------------------
    api_key = self._get_api_key(edsl_auth_token)
    if api_key is None:
        elapsed = time.time() - st.session_state.login_start_time
        if elapsed > timeout:
            # Discard the expired token so the next rerun generates a fresh
            # token/link instead of timing out again immediately.
            del st.session_state["edsl_auth_token"]
            del st.session_state["login_start_time"]
            st.error(
                "Timed-out waiting for login. Please rerun the app to try again."
            )
            return None

        remaining = int(timeout - elapsed)
        st.info(f"Waiting for login… ({remaining}s left)")
        # Trigger a rerun after a short delay to continue polling
        time.sleep(1)

        # Attempt a rerun in a version-agnostic way. ``st.rerun`` is the
        # current API; ``st.experimental_rerun`` is the older, deprecated name.
        def _safe_rerun():
            if hasattr(st, "rerun"):
                st.rerun()
            elif hasattr(st, "experimental_rerun"):
                st.experimental_rerun()
            else:
                # Fallback – advise the user to update Streamlit for automatic polling.
                st.warning(
                    "Please refresh the page to continue the login flow. "
                    "(Consider upgrading Streamlit to enable automatic refresh.)"
                )

        try:
            _safe_rerun()
        except Exception:
            # Presumably Streamlit's rerun signal is not an ``Exception``
            # subclass and so passes through to the runtime; any other
            # unexpected error is ignored to avoid crashing the app.
            pass

        # Reached only when no rerun was triggered. The key is still missing,
        # so never fall through to the persistence step with ``None``.
        return None

    # ------------------------------------------------------------------
    # 4. Key received – persist it and notify user
    # ------------------------------------------------------------------
    ExpectedParrotKeyHandler().store_ep_api_key(api_key)
    os.environ["EXPECTED_PARROT_API_KEY"] = api_key
    path_to_env = write_api_key_to_env(api_key)
    load_dotenv()

    st.success("API-key retrieved and stored. You are now logged-in! 🎉")
    st.caption(f"Key saved to `{path_to_env}`.")

    return api_key
2827
+
1772
2828
  def transfer_credits(
1773
2829
  self,
1774
2830
  credits_transferred: int,
@@ -1835,10 +2891,155 @@ class Coop(CoopFunctionsMixin):
1835
2891
  >>> balance = coop.get_balance()
1836
2892
  >>> print(f"You have {balance['credits']} credits available.")
1837
2893
  """
1838
- response = self._send_server_request(uri="api/users/get_balance", method="GET")
2894
+ response = self._send_server_request(
2895
+ uri="api/v0/users/get-balance", method="GET"
2896
+ )
1839
2897
  self._resolve_server_response(response)
1840
2898
  return response.json()
1841
2899
 
2900
def login_gradio(self, timeout: int = 120, launch: bool = True, **launch_kwargs):
    """
    Start the EDSL auth token login flow inside a **Gradio** application.

    This helper mirrors the behaviour of :py:meth:`Coop.login_streamlit` but
    renders the login link and status updates inside a Gradio UI. The server
    is checked for the API-key tied to a newly generated auth-token when the
    page loads and each time the user clicks the "I've logged in" button.
    Once the key is received, it is stored via
    :pyclass:`~edsl.coop.ep_key_handling.ExpectedParrotKeyHandler`, set in
    ``os.environ`` and written to the local ``.env`` file so subsequent
    sessions pick it up automatically.

    Parameters
    ----------
    timeout : int, default 120
        How many seconds after this method is called the user has to complete
        the login before the UI reports a time-out.
    launch : bool, default True
        If ``True`` the Gradio app is immediately launched with
        ``demo.launch(**launch_kwargs)``. Set this to ``False`` if you want
        to embed the returned :class:`gradio.Blocks` object into an existing
        Gradio interface.
    **launch_kwargs
        Additional keyword-arguments forwarded to ``gr.Blocks.launch`` when
        *launch* is ``True``.

    Returns
    -------
    gradio.Blocks | None
        • If *launch* is ``True`` the method returns ``None`` after
          ``demo.launch`` returns.
        • If *launch* is ``False`` the constructed ``gr.Blocks`` is
          returned so the caller can compose it further.
        The API-key itself is never returned. It is exposed through the
        ``gr.State`` component inside the Blocks and through
        ``os.environ["EXPECTED_PARROT_API_KEY"]``.
    """
    try:
        import gradio as gr
    except ModuleNotFoundError as exc:
        raise ImportError(
            "Gradio is required for `login_gradio`. Install it with `pip install gradio`."
        ) from exc

    import secrets
    import time
    import os
    from dotenv import load_dotenv
    from .ep_key_handling import ExpectedParrotKeyHandler
    from ..utilities.utilities import write_api_key_to_env

    # ------------------------------------------------------------------
    # 1. Prepare auth-token
    # ------------------------------------------------------------------
    # NOTE(review): the token and start time are fixed when the Blocks are
    # built, so reloading the page after a time-out reuses them; only
    # rebuilding the app starts a fresh attempt.
    edsl_auth_token = secrets.token_urlsafe(16)
    login_url = (
        f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
    )
    start_time = time.time()

    # ------------------------------------------------------------------
    # 2. Build Gradio interface
    # ------------------------------------------------------------------
    with gr.Blocks() as demo:
        gr.HTML(
            f'🔗 <b>Log in to Expected Parrot</b> → <a href="{login_url}" target="_blank">click here</a>'
        )
        status_md = gr.Markdown("Waiting for login…")
        refresh_btn = gr.Button(
            "I've logged in – click to continue", elem_id="refresh-btn"
        )
        key_state = gr.State(value=None)

        # --------------------------------------------------------------
        # Polling callback
        # --------------------------------------------------------------
        def _refresh(current_key):  # noqa: D401, pylint: disable=unused-argument
            """Poll server for API-key and update UI accordingly."""

            # Fallback helper to generate a `update` object for the refresh
            # button (``gr.Button.update`` is absent in newer Gradio releases).
            def _button_update(**kwargs):
                try:
                    return gr.Button.update(**kwargs)
                except AttributeError:
                    return gr.update(**kwargs)

            api_key = self._get_api_key(edsl_auth_token)
            # Fall back to env var in case the key was obtained earlier in this session
            # NOTE(review): this also accepts any pre-existing key in the
            # environment as a successful login — confirm that is intended.
            if not api_key:
                api_key = os.environ.get("EXPECTED_PARROT_API_KEY")
            elapsed = time.time() - start_time
            remaining = max(0, int(timeout - elapsed))

            if api_key:
                # Persist and expose the key
                ExpectedParrotKeyHandler().store_ep_api_key(api_key)
                os.environ["EXPECTED_PARROT_API_KEY"] = api_key
                path_to_env = write_api_key_to_env(api_key)
                load_dotenv()
                success_msg = (
                    "API-key retrieved and stored 🎉\n\n"
                    f"Key saved to `{path_to_env}`."
                )
                return (
                    success_msg,
                    _button_update(interactive=False, visible=False),
                    api_key,
                )

            if elapsed > timeout:
                err_msg = (
                    "Timed-out waiting for login. Please refresh the page "
                    "or restart the app to try again."
                )
                return (
                    err_msg,
                    _button_update(),
                    None,
                )

            info_msg = f"Waiting for login… ({remaining}s left)"
            return (
                info_msg,
                _button_update(),
                None,
            )

        refresh_outputs = [status_md, refresh_btn, key_state]

        # Initial status check when the interface loads
        demo.load(
            fn=_refresh,
            inputs=key_state,
            outputs=refresh_outputs,
        )
        # Re-check whenever the user reports they have logged in; without this
        # listener the button did nothing and the UI never left "waiting".
        refresh_btn.click(
            fn=_refresh,
            inputs=key_state,
            outputs=refresh_outputs,
        )

    # ------------------------------------------------------------------
    # 3. Launch or return interface
    # ------------------------------------------------------------------
    if launch:
        demo.launch(**launch_kwargs)
        return None
    return demo
3042
+
1842
3043
 
1843
3044
  def main():
1844
3045
  """
@@ -1975,3 +3176,12 @@ def main():
1975
3176
  job_coop_object = coop.remote_inference_create(job)
1976
3177
  job_coop_results = coop.remote_inference_get(job_coop_object.get("uuid"))
1977
3178
  coop.get(job_coop_results.get("results_uuid"))
3179
+
3180
+ import streamlit as st
3181
+ from edsl.coop import Coop
3182
+
3183
+ coop = Coop() # no API-key required yet
3184
+ api_key = coop.login_streamlit() # renders link + handles polling & storage
3185
+
3186
+ if api_key:
3187
+ st.success("Ready to use EDSL with remote features!")