edsl 0.1.61__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +66 -0
- edsl/__version__.py +1 -1
- edsl/base/base_class.py +53 -0
- edsl/cli.py +93 -27
- edsl/config/config_class.py +4 -0
- edsl/coop/coop.py +403 -28
- edsl/coop/coop_jobs_objects.py +2 -2
- edsl/coop/coop_regular_objects.py +3 -1
- edsl/dataset/dataset.py +47 -41
- edsl/dataset/dataset_operations_mixin.py +138 -15
- edsl/dataset/report_from_template.py +509 -0
- edsl/inference_services/services/azure_ai.py +8 -2
- edsl/inference_services/services/open_ai_service.py +7 -5
- edsl/jobs/jobs.py +5 -4
- edsl/jobs/jobs_checks.py +11 -6
- edsl/jobs/remote_inference.py +17 -10
- edsl/prompts/prompt.py +7 -2
- edsl/questions/question_registry.py +4 -1
- edsl/results/result.py +93 -38
- edsl/results/results.py +24 -15
- edsl/scenarios/file_store.py +69 -0
- edsl/scenarios/scenario.py +233 -0
- edsl/scenarios/scenario_list.py +294 -130
- edsl/scenarios/scenario_source.py +1 -2
- {edsl-0.1.61.dist-info → edsl-1.0.0.dist-info}/METADATA +1 -1
- {edsl-0.1.61.dist-info → edsl-1.0.0.dist-info}/RECORD +29 -28
- {edsl-0.1.61.dist-info → edsl-1.0.0.dist-info}/LICENSE +0 -0
- {edsl-0.1.61.dist-info → edsl-1.0.0.dist-info}/WHEEL +0 -0
- {edsl-0.1.61.dist-info → edsl-1.0.0.dist-info}/entry_points.txt +0 -0
edsl/coop/coop.py
CHANGED
@@ -273,6 +273,118 @@ class Coop(CoopFunctionsMixin):
|
|
273
273
|
|
274
274
|
return user_stable_version < server_stable_version
|
275
275
|
|
276
|
+
def check_for_updates(self, silent: bool = False) -> Optional[dict]:
|
277
|
+
"""
|
278
|
+
Check if there's a newer version of EDSL available.
|
279
|
+
|
280
|
+
Args:
|
281
|
+
silent: If True, don't print any messages to console
|
282
|
+
|
283
|
+
Returns:
|
284
|
+
dict with version info if update is available, None otherwise
|
285
|
+
"""
|
286
|
+
try:
|
287
|
+
# Use the new /version/updates endpoint
|
288
|
+
response = self._send_server_request(
|
289
|
+
uri="version/updates", method="GET", timeout=5
|
290
|
+
)
|
291
|
+
|
292
|
+
data = response.json()
|
293
|
+
|
294
|
+
# Extract version information from the response
|
295
|
+
current_version = data.get("current") # Latest version in use
|
296
|
+
guid_message = data.get("guid_message", "") # Message about updates
|
297
|
+
force_update = (
|
298
|
+
"force update" in guid_message.lower() if guid_message else False
|
299
|
+
)
|
300
|
+
# Check if update is needed
|
301
|
+
if current_version and self._user_version_is_outdated(
|
302
|
+
user_version_str=self._edsl_version,
|
303
|
+
server_version_str=current_version,
|
304
|
+
):
|
305
|
+
update_data = {
|
306
|
+
"current_version": self._edsl_version,
|
307
|
+
"latest_version": current_version,
|
308
|
+
"guid_message": guid_message,
|
309
|
+
"force_update": force_update,
|
310
|
+
"update_command": "pip install --upgrade edsl",
|
311
|
+
}
|
312
|
+
|
313
|
+
if not silent:
|
314
|
+
print("\n" + "=" * 60)
|
315
|
+
print("📦 EDSL Update Available!")
|
316
|
+
print(f"Your version: {self._edsl_version}")
|
317
|
+
print(f"Latest version: {current_version}")
|
318
|
+
|
319
|
+
# Display the guid message if present
|
320
|
+
if guid_message:
|
321
|
+
print(f"\n{guid_message}")
|
322
|
+
|
323
|
+
# Prompt user for update
|
324
|
+
prompt_message = "\nDo you want to update now? [Y/n] "
|
325
|
+
if force_update:
|
326
|
+
prompt_message = "\n⚠️ FORCE UPDATE REQUIRED - Do you want to update now? [Y/n] "
|
327
|
+
|
328
|
+
print(prompt_message, end="")
|
329
|
+
|
330
|
+
try:
|
331
|
+
user_input = input().strip().lower()
|
332
|
+
if user_input in ["", "y", "yes"]:
|
333
|
+
# Actually run the update
|
334
|
+
print("\nUpdating EDSL...")
|
335
|
+
import subprocess
|
336
|
+
import sys
|
337
|
+
|
338
|
+
try:
|
339
|
+
# Run pip install --upgrade edsl
|
340
|
+
result = subprocess.run(
|
341
|
+
[
|
342
|
+
sys.executable,
|
343
|
+
"-m",
|
344
|
+
"pip",
|
345
|
+
"install",
|
346
|
+
"--upgrade",
|
347
|
+
"edsl",
|
348
|
+
],
|
349
|
+
capture_output=True,
|
350
|
+
text=True,
|
351
|
+
)
|
352
|
+
|
353
|
+
if result.returncode == 0:
|
354
|
+
print(
|
355
|
+
"✅ Update successful! Please restart your application."
|
356
|
+
)
|
357
|
+
else:
|
358
|
+
print(f"❌ Update failed: {result.stderr}")
|
359
|
+
print(
|
360
|
+
"You can try updating manually with: pip install --upgrade edsl"
|
361
|
+
)
|
362
|
+
except Exception as e:
|
363
|
+
print(f"❌ Update failed: {str(e)}")
|
364
|
+
print(
|
365
|
+
"You can try updating manually with: pip install --upgrade edsl"
|
366
|
+
)
|
367
|
+
else:
|
368
|
+
print(
|
369
|
+
"\nUpdate skipped. You can update later with: pip install --upgrade edsl"
|
370
|
+
)
|
371
|
+
|
372
|
+
print("=" * 60 + "\n")
|
373
|
+
|
374
|
+
except (EOFError, KeyboardInterrupt):
|
375
|
+
print(
|
376
|
+
"\nUpdate skipped. You can update later with: pip install --upgrade edsl"
|
377
|
+
)
|
378
|
+
print("=" * 60 + "\n")
|
379
|
+
|
380
|
+
return update_data
|
381
|
+
|
382
|
+
except Exception:
|
383
|
+
# Silently fail if we can't check for updates
|
384
|
+
pass
|
385
|
+
|
386
|
+
return None
|
387
|
+
|
276
388
|
def _resolve_server_response(
|
277
389
|
self, response: requests.Response, check_api_key: bool = True
|
278
390
|
) -> None:
|
@@ -280,18 +392,35 @@ class Coop(CoopFunctionsMixin):
|
|
280
392
|
Check the response from the server and raise errors as appropriate.
|
281
393
|
"""
|
282
394
|
# Get EDSL version from header
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
395
|
+
server_edsl_version = response.headers.get("X-EDSL-Version")
|
396
|
+
|
397
|
+
if server_edsl_version:
|
398
|
+
if self._user_version_is_outdated(
|
399
|
+
user_version_str=self._edsl_version,
|
400
|
+
server_version_str=server_edsl_version,
|
401
|
+
):
|
402
|
+
# Get additional info from server if available
|
403
|
+
update_info = response.headers.get("X-EDSL-Update-Info", "")
|
404
|
+
|
405
|
+
print("\n" + "=" * 60)
|
406
|
+
print("📦 EDSL Update Available!")
|
407
|
+
print(f"Your version: {self._edsl_version}")
|
408
|
+
print(f"Latest version: {server_edsl_version}")
|
409
|
+
if update_info:
|
410
|
+
print(f"Update info: {update_info}")
|
411
|
+
print(
|
412
|
+
"\nYour version is out of date - can we update to latest version? [Y/n]"
|
413
|
+
)
|
414
|
+
|
415
|
+
try:
|
416
|
+
user_input = input().strip().lower()
|
417
|
+
if user_input in ["", "y", "yes"]:
|
418
|
+
print("To update, run: pip install --upgrade edsl")
|
419
|
+
print("=" * 60 + "\n")
|
420
|
+
except (EOFError, KeyboardInterrupt):
|
421
|
+
# Handle non-interactive environments
|
422
|
+
print("To update, run: pip install --upgrade edsl")
|
423
|
+
print("=" * 60 + "\n")
|
295
424
|
if response.status_code >= 400:
|
296
425
|
try:
|
297
426
|
message = str(response.json().get("detail"))
|
@@ -598,7 +727,7 @@ class Coop(CoopFunctionsMixin):
|
|
598
727
|
else:
|
599
728
|
from .exceptions import CoopResponseError
|
600
729
|
|
601
|
-
raise CoopResponseError("No signed url was provided
|
730
|
+
raise CoopResponseError("No signed url was provided.")
|
602
731
|
|
603
732
|
response = requests.put(
|
604
733
|
signed_url, data=json_data.encode(), headers=headers
|
@@ -945,18 +1074,31 @@ class Coop(CoopFunctionsMixin):
|
|
945
1074
|
|
946
1075
|
obj_uuid, owner_username, obj_alias = self._resolve_uuid_or_alias(url_or_uuid)
|
947
1076
|
|
948
|
-
# If we
|
949
|
-
if
|
950
|
-
#
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
1077
|
+
# If we're updating the value, we need to check the storage format
|
1078
|
+
if value:
|
1079
|
+
# If we don't have a UUID but have an alias, get the UUID and format info first
|
1080
|
+
if not obj_uuid and owner_username and obj_alias:
|
1081
|
+
# Get object info including UUID and format
|
1082
|
+
info_response = self._send_server_request(
|
1083
|
+
uri="api/v0/object/alias/info",
|
1084
|
+
method="GET",
|
1085
|
+
params={"owner_username": owner_username, "alias": obj_alias},
|
1086
|
+
)
|
1087
|
+
self._resolve_server_response(info_response)
|
1088
|
+
info_data = info_response.json()
|
958
1089
|
|
959
|
-
|
1090
|
+
obj_uuid = info_data.get("uuid")
|
1091
|
+
is_new_format = info_data.get("is_new_format", False)
|
1092
|
+
else:
|
1093
|
+
# We have a UUID, check the format
|
1094
|
+
format_check_response = self._send_server_request(
|
1095
|
+
uri="api/v0/object/check-format",
|
1096
|
+
method="POST",
|
1097
|
+
payload={"object_uuid": str(obj_uuid)},
|
1098
|
+
)
|
1099
|
+
self._resolve_server_response(format_check_response)
|
1100
|
+
format_data = format_check_response.json()
|
1101
|
+
is_new_format = format_data.get("is_new_format", False)
|
960
1102
|
|
961
1103
|
if is_new_format:
|
962
1104
|
# Handle new format objects: update metadata first, then upload content
|
@@ -1052,10 +1194,20 @@ class Coop(CoopFunctionsMixin):
|
|
1052
1194
|
f"Failed to upload object to GCS: {gcs_response.status_code}"
|
1053
1195
|
)
|
1054
1196
|
|
1197
|
+
# Step 4: Confirm upload and trigger queue worker processing
|
1198
|
+
confirm_response = self._send_server_request(
|
1199
|
+
uri="api/v0/object/confirm-upload",
|
1200
|
+
method="POST",
|
1201
|
+
payload={"object_uuid": str(obj_uuid)},
|
1202
|
+
)
|
1203
|
+
self._resolve_server_response(confirm_response)
|
1204
|
+
confirm_data = confirm_response.json()
|
1205
|
+
|
1055
1206
|
return {
|
1056
1207
|
"status": "success",
|
1057
|
-
"message": "Object updated successfully (new format - uploaded to GCS)",
|
1208
|
+
"message": "Object updated successfully (new format - uploaded to GCS and processing triggered)",
|
1058
1209
|
"object_uuid": str(obj_uuid),
|
1210
|
+
"processing_started": confirm_data.get("processing_started", False),
|
1059
1211
|
}
|
1060
1212
|
|
1061
1213
|
################
|
@@ -1195,7 +1347,7 @@ class Coop(CoopFunctionsMixin):
|
|
1195
1347
|
if not upload_signed_url:
|
1196
1348
|
from .exceptions import CoopResponseError
|
1197
1349
|
|
1198
|
-
raise CoopResponseError("No signed url was provided
|
1350
|
+
raise CoopResponseError("No signed url was provided.")
|
1199
1351
|
|
1200
1352
|
response = requests.put(
|
1201
1353
|
upload_signed_url,
|
@@ -1431,6 +1583,159 @@ class Coop(CoopFunctionsMixin):
|
|
1431
1583
|
}
|
1432
1584
|
)
|
1433
1585
|
|
1586
|
+
def new_remote_inference_get(
|
1587
|
+
self,
|
1588
|
+
job_uuid: Optional[str] = None,
|
1589
|
+
results_uuid: Optional[str] = None,
|
1590
|
+
include_json_string: Optional[bool] = False,
|
1591
|
+
) -> RemoteInferenceResponse:
|
1592
|
+
"""
|
1593
|
+
Get the status and details of a remote inference job.
|
1594
|
+
|
1595
|
+
This method retrieves the current status and information about a remote job,
|
1596
|
+
including links to results if the job has completed successfully.
|
1597
|
+
|
1598
|
+
Parameters:
|
1599
|
+
job_uuid (str, optional): The UUID of the remote job to check
|
1600
|
+
results_uuid (str, optional): The UUID of the results associated with the job
|
1601
|
+
(can be used if you only have the results UUID)
|
1602
|
+
include_json_string (bool, optional): If True, include the json string for the job in the response
|
1603
|
+
|
1604
|
+
Returns:
|
1605
|
+
RemoteInferenceResponse: Information about the job including:
|
1606
|
+
job_uuid: The unique identifier for the job
|
1607
|
+
results_uuid: The UUID of the results
|
1608
|
+
results_url: URL to access the results
|
1609
|
+
status: Current status ("queued", "running", "completed", "failed")
|
1610
|
+
version: EDSL version used for the job
|
1611
|
+
job_json_string: The json string for the job (if include_json_string is True)
|
1612
|
+
latest_job_run_details: Metadata about the job status
|
1613
|
+
interview_details: Metadata about the job interview status (for jobs that have reached running status)
|
1614
|
+
total_interviews: The total number of interviews in the job
|
1615
|
+
completed_interviews: The number of completed interviews
|
1616
|
+
interviews_with_exceptions: The number of completed interviews that have exceptions
|
1617
|
+
exception_counters: A list of exception counts for the job
|
1618
|
+
exception_type: The type of exception
|
1619
|
+
inference_service: The inference service
|
1620
|
+
model: The model
|
1621
|
+
question_name: The name of the question
|
1622
|
+
exception_count: The number of exceptions
|
1623
|
+
failure_reason: The reason the job failed (failed jobs only)
|
1624
|
+
failure_description: The description of the failure (failed jobs only)
|
1625
|
+
error_report_uuid: The UUID of the error report (partially failed jobs only)
|
1626
|
+
cost_credits: The cost of the job run in credits
|
1627
|
+
cost_usd: The cost of the job run in USD
|
1628
|
+
expenses: The expenses incurred by the job run
|
1629
|
+
service: The service
|
1630
|
+
model: The model
|
1631
|
+
token_type: The type of token (input or output)
|
1632
|
+
price_per_million_tokens: The price per million tokens
|
1633
|
+
tokens_count: The number of tokens consumed
|
1634
|
+
cost_credits: The cost of the service/model/token type combination in credits
|
1635
|
+
cost_usd: The cost of the service/model/token type combination in USD
|
1636
|
+
|
1637
|
+
Raises:
|
1638
|
+
ValueError: If neither job_uuid nor results_uuid is provided
|
1639
|
+
CoopServerResponseError: If there's an error communicating with the server
|
1640
|
+
|
1641
|
+
Notes:
|
1642
|
+
- Either job_uuid or results_uuid must be provided
|
1643
|
+
- If both are provided, job_uuid takes precedence
|
1644
|
+
- For completed jobs, you can use the results_url to view or download results
|
1645
|
+
- For failed jobs, check the latest_error_report_url for debugging information
|
1646
|
+
|
1647
|
+
Example:
|
1648
|
+
>>> job_status = coop.new_remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
|
1649
|
+
>>> print(f"Job status: {job_status['status']}")
|
1650
|
+
>>> if job_status['status'] == 'completed':
|
1651
|
+
... print(f"Results available at: {job_status['results_url']}")
|
1652
|
+
"""
|
1653
|
+
if job_uuid is None and results_uuid is None:
|
1654
|
+
from .exceptions import CoopValueError
|
1655
|
+
|
1656
|
+
raise CoopValueError("Either job_uuid or results_uuid must be provided.")
|
1657
|
+
elif job_uuid is not None:
|
1658
|
+
params = {"job_uuid": job_uuid}
|
1659
|
+
else:
|
1660
|
+
params = {"results_uuid": results_uuid}
|
1661
|
+
if include_json_string:
|
1662
|
+
params["include_json_string"] = include_json_string
|
1663
|
+
|
1664
|
+
response = self._send_server_request(
|
1665
|
+
uri="api/v0/remote-inference",
|
1666
|
+
method="GET",
|
1667
|
+
params=params,
|
1668
|
+
)
|
1669
|
+
self._resolve_server_response(response)
|
1670
|
+
data = response.json()
|
1671
|
+
|
1672
|
+
results_uuid = data.get("results_uuid")
|
1673
|
+
|
1674
|
+
if results_uuid is None:
|
1675
|
+
results_url = None
|
1676
|
+
else:
|
1677
|
+
results_url = f"{self.url}/content/{results_uuid}"
|
1678
|
+
|
1679
|
+
latest_job_run_details = data.get("latest_job_run_details", {})
|
1680
|
+
if data.get("status") == "partial_failed":
|
1681
|
+
latest_error_report_uuid = latest_job_run_details.get("error_report_uuid")
|
1682
|
+
if latest_error_report_uuid is None:
|
1683
|
+
latest_job_run_details["error_report_url"] = None
|
1684
|
+
else:
|
1685
|
+
latest_error_report_url = (
|
1686
|
+
f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
|
1687
|
+
)
|
1688
|
+
latest_job_run_details["error_report_url"] = latest_error_report_url
|
1689
|
+
|
1690
|
+
json_string = data.get("job_json_string")
|
1691
|
+
|
1692
|
+
# The job has been offloaded to GCS
|
1693
|
+
if include_json_string and json_string == "offloaded":
|
1694
|
+
# Attempt to fetch JSON string from GCS
|
1695
|
+
response = self._send_server_request(
|
1696
|
+
uri="api/v0/remote-inference/pull",
|
1697
|
+
method="POST",
|
1698
|
+
payload={"job_uuid": job_uuid},
|
1699
|
+
)
|
1700
|
+
# Handle any errors in the response
|
1701
|
+
self._resolve_server_response(response)
|
1702
|
+
if "signed_url" not in response.json():
|
1703
|
+
from .exceptions import CoopResponseError
|
1704
|
+
|
1705
|
+
raise CoopResponseError("No signed url was provided.")
|
1706
|
+
signed_url = response.json().get("signed_url")
|
1707
|
+
|
1708
|
+
if signed_url == "": # The job is in legacy format
|
1709
|
+
job_json = json_string
|
1710
|
+
|
1711
|
+
try:
|
1712
|
+
response = requests.get(signed_url)
|
1713
|
+
self._resolve_gcs_response(response)
|
1714
|
+
job_json = json.dumps(response.json())
|
1715
|
+
except Exception:
|
1716
|
+
job_json = json_string
|
1717
|
+
|
1718
|
+
# If the job is in legacy format, we should already have the JSON string
|
1719
|
+
# from the first API call
|
1720
|
+
elif include_json_string and not json_string == "offloaded":
|
1721
|
+
job_json = json_string
|
1722
|
+
|
1723
|
+
# If include_json_string is False, we don't need the JSON string at all
|
1724
|
+
else:
|
1725
|
+
job_json = None
|
1726
|
+
|
1727
|
+
return RemoteInferenceResponse(
|
1728
|
+
**{
|
1729
|
+
"job_uuid": data.get("job_uuid"),
|
1730
|
+
"results_uuid": results_uuid,
|
1731
|
+
"results_url": results_url,
|
1732
|
+
"status": data.get("status"),
|
1733
|
+
"version": data.get("version"),
|
1734
|
+
"job_json_string": job_json,
|
1735
|
+
"latest_job_run_details": latest_job_run_details,
|
1736
|
+
}
|
1737
|
+
)
|
1738
|
+
|
1434
1739
|
def _validate_remote_job_status_types(
|
1435
1740
|
self, status: Union[RemoteJobStatus, List[RemoteJobStatus]]
|
1436
1741
|
) -> List[RemoteJobStatus]:
|
@@ -2470,7 +2775,7 @@ class Coop(CoopFunctionsMixin):
|
|
2470
2775
|
if "signed_url" not in response.json():
|
2471
2776
|
from .exceptions import CoopResponseError
|
2472
2777
|
|
2473
|
-
raise CoopResponseError("No signed url was provided
|
2778
|
+
raise CoopResponseError("No signed url was provided.")
|
2474
2779
|
signed_url = response.json().get("signed_url")
|
2475
2780
|
|
2476
2781
|
if signed_url == "": # it is in old format
|
@@ -2872,6 +3177,53 @@ class Coop(CoopFunctionsMixin):
|
|
2872
3177
|
self._resolve_server_response(response)
|
2873
3178
|
return response.json()
|
2874
3179
|
|
3180
|
+
def pay_for_service(
|
3181
|
+
self,
|
3182
|
+
credits_transferred: int,
|
3183
|
+
recipient_username: str,
|
3184
|
+
service_name: str,
|
3185
|
+
) -> dict:
|
3186
|
+
"""
|
3187
|
+
Pay for a service.
|
3188
|
+
|
3189
|
+
This method transfers a specified number of credits from the authenticated user's
|
3190
|
+
account to another user's account on the Expected Parrot platform.
|
3191
|
+
|
3192
|
+
Parameters:
|
3193
|
+
credits_transferred (int): The number of credits to transfer to the recipient
|
3194
|
+
recipient_username (str): The username of the recipient
|
3195
|
+
service_name (str): The name of the service to pay for
|
3196
|
+
|
3197
|
+
Returns:
|
3198
|
+
dict: Information about the transfer transaction, including:
|
3199
|
+
- success: Whether the transaction was successful
|
3200
|
+
- transaction_id: A unique identifier for the transaction
|
3201
|
+
- remaining_credits: The number of credits remaining in the sender's account
|
3202
|
+
|
3203
|
+
Raises:
|
3204
|
+
CoopServerResponseError: If there's an error communicating with the server
|
3205
|
+
or if the transfer criteria aren't met (e.g., insufficient credits)
|
3206
|
+
|
3207
|
+
Example:
|
3208
|
+
>>> result = coop.pay_for_service(
|
3209
|
+
... credits_transferred=100,
|
3210
|
+
... service_name="service_name",
|
3211
|
+
... recipient_username="friend_username",
|
3212
|
+
... )
|
3213
|
+
>>> print(f"Transfer successful! You have {result['remaining_credits']} credits left.")
|
3214
|
+
"""
|
3215
|
+
response = self._send_server_request(
|
3216
|
+
uri="api/v0/users/pay-for-service",
|
3217
|
+
method="POST",
|
3218
|
+
payload={
|
3219
|
+
"cost_credits": credits_transferred,
|
3220
|
+
"service_name": service_name,
|
3221
|
+
"recipient_username": recipient_username,
|
3222
|
+
},
|
3223
|
+
)
|
3224
|
+
self._resolve_server_response(response)
|
3225
|
+
return response.json()
|
3226
|
+
|
2875
3227
|
def get_balance(self) -> dict:
|
2876
3228
|
"""
|
2877
3229
|
Get the current credit balance for the authenticated user.
|
@@ -2897,6 +3249,29 @@ class Coop(CoopFunctionsMixin):
|
|
2897
3249
|
self._resolve_server_response(response)
|
2898
3250
|
return response.json()
|
2899
3251
|
|
3252
|
+
def get_profile(self) -> dict:
|
3253
|
+
"""
|
3254
|
+
Get the current user's profile information.
|
3255
|
+
|
3256
|
+
This method retrieves the authenticated user's profile information from
|
3257
|
+
the Expected Parrot platform using their API key.
|
3258
|
+
|
3259
|
+
Returns:
|
3260
|
+
dict: User profile information including:
|
3261
|
+
- username: The user's username
|
3262
|
+
- email: The user's email address
|
3263
|
+
|
3264
|
+
Raises:
|
3265
|
+
CoopServerResponseError: If there's an error communicating with the server
|
3266
|
+
|
3267
|
+
Example:
|
3268
|
+
>>> profile = coop.get_profile()
|
3269
|
+
>>> print(f"Welcome, {profile['username']}!")
|
3270
|
+
"""
|
3271
|
+
response = self._send_server_request(uri="api/v0/users/profile", method="GET")
|
3272
|
+
self._resolve_server_response(response)
|
3273
|
+
return response.json()
|
3274
|
+
|
2900
3275
|
def login_gradio(self, timeout: int = 120, launch: bool = True, **launch_kwargs):
|
2901
3276
|
"""
|
2902
3277
|
Start the EDSL auth token login flow inside a **Gradio** application.
|
@@ -3174,7 +3549,7 @@ def main():
|
|
3174
3549
|
job = Jobs.example()
|
3175
3550
|
coop.remote_inference_cost(job)
|
3176
3551
|
job_coop_object = coop.remote_inference_create(job)
|
3177
|
-
job_coop_results = coop.
|
3552
|
+
job_coop_results = coop.new_remote_inference_get(job_coop_object.get("uuid"))
|
3178
3553
|
coop.get(job_coop_results.get("results_uuid"))
|
3179
3554
|
|
3180
3555
|
import streamlit as st
|
edsl/coop/coop_jobs_objects.py
CHANGED
@@ -26,7 +26,7 @@ class CoopJobsObjects(CoopObjects):
|
|
26
26
|
|
27
27
|
c = Coop()
|
28
28
|
job_details = [
|
29
|
-
c.
|
29
|
+
c.new_remote_inference_get(obj["uuid"], include_json_string=True)
|
30
30
|
for obj in self
|
31
31
|
]
|
32
32
|
|
@@ -53,7 +53,7 @@ class CoopJobsObjects(CoopObjects):
|
|
53
53
|
|
54
54
|
for obj in self:
|
55
55
|
if obj.get("results_uuid"):
|
56
|
-
result = c.
|
56
|
+
result = c.pull(obj["results_uuid"], expected_object_type="results")
|
57
57
|
results.append(result)
|
58
58
|
|
59
59
|
return results
|
edsl/dataset/dataset.py
CHANGED
@@ -1017,6 +1017,53 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
1017
1017
|
# Save the document
|
1018
1018
|
doc.save(output_file)
|
1019
1019
|
|
1020
|
+
def unique(self) -> "Dataset":
|
1021
|
+
"""
|
1022
|
+
Remove duplicate rows from the dataset.
|
1023
|
+
|
1024
|
+
Returns:
|
1025
|
+
A new Dataset with duplicate rows removed.
|
1026
|
+
|
1027
|
+
Examples:
|
1028
|
+
>>> d = Dataset([{'a': [1, 2, 3, 1]}, {'b': [4, 5, 6, 4]}])
|
1029
|
+
>>> d.unique().data
|
1030
|
+
[{'a': [1, 2, 3]}, {'b': [4, 5, 6]}]
|
1031
|
+
|
1032
|
+
>>> d = Dataset([{'x': ['a', 'b', 'a']}, {'y': [1, 2, 1]}])
|
1033
|
+
>>> d.unique().data
|
1034
|
+
[{'x': ['a', 'b']}, {'y': [1, 2]}]
|
1035
|
+
|
1036
|
+
>>> # Dataset with a single column
|
1037
|
+
>>> Dataset([{'value': [1, 2, 3, 2, 1, 3]}]).unique().data
|
1038
|
+
[{'value': [1, 2, 3]}]
|
1039
|
+
"""
|
1040
|
+
# Convert data to tuples for each row to make them hashable
|
1041
|
+
rows = []
|
1042
|
+
for i in range(len(self)):
|
1043
|
+
row = tuple(entry[list(entry.keys())[0]][i] for entry in self.data)
|
1044
|
+
rows.append(row)
|
1045
|
+
|
1046
|
+
# Keep track of unique rows and their indices
|
1047
|
+
unique_rows = []
|
1048
|
+
indices = []
|
1049
|
+
|
1050
|
+
# Use a set to track seen rows
|
1051
|
+
seen = set()
|
1052
|
+
for i, row in enumerate(rows):
|
1053
|
+
if row not in seen:
|
1054
|
+
seen.add(row)
|
1055
|
+
unique_rows.append(row)
|
1056
|
+
indices.append(i)
|
1057
|
+
|
1058
|
+
# Create a new dataset with only the unique rows
|
1059
|
+
new_data = []
|
1060
|
+
for entry in self.data:
|
1061
|
+
key, values = list(entry.items())[0]
|
1062
|
+
new_values = [values[i] for i in indices]
|
1063
|
+
new_data.append({key: new_values})
|
1064
|
+
|
1065
|
+
return Dataset(new_data)
|
1066
|
+
|
1020
1067
|
def expand(self, field: str, number_field: bool = False) -> "Dataset":
|
1021
1068
|
"""
|
1022
1069
|
Expand a field containing lists into multiple rows.
|
@@ -1086,47 +1133,6 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
1086
1133
|
|
1087
1134
|
return Dataset(new_data)
|
1088
1135
|
|
1089
|
-
def unique(self) -> "Dataset":
|
1090
|
-
"""Return a new dataset with only unique observations.
|
1091
|
-
|
1092
|
-
Examples:
|
1093
|
-
>>> d = Dataset([{'a': [1, 2, 2, 3]}, {'b': [4, 5, 5, 6]}])
|
1094
|
-
>>> d.unique().data
|
1095
|
-
[{'a': [1, 2, 3]}, {'b': [4, 5, 6]}]
|
1096
|
-
|
1097
|
-
>>> d = Dataset([{'x': ['a', 'a', 'b']}, {'y': [1, 1, 2]}])
|
1098
|
-
>>> d.unique().data
|
1099
|
-
[{'x': ['a', 'b']}, {'y': [1, 2]}]
|
1100
|
-
"""
|
1101
|
-
# Get all column names and values
|
1102
|
-
headers, data = self._tabular()
|
1103
|
-
|
1104
|
-
# Create a list of unique rows
|
1105
|
-
unique_rows = []
|
1106
|
-
seen = set()
|
1107
|
-
|
1108
|
-
for row in data:
|
1109
|
-
# Convert the row to a hashable representation for comparison
|
1110
|
-
# We need to handle potential unhashable types
|
1111
|
-
try:
|
1112
|
-
row_key = tuple(map(lambda x: str(x) if isinstance(x, (list, dict)) else x, row))
|
1113
|
-
if row_key not in seen:
|
1114
|
-
seen.add(row_key)
|
1115
|
-
unique_rows.append(row)
|
1116
|
-
except:
|
1117
|
-
# Fallback for complex objects: compare based on string representation
|
1118
|
-
row_str = str(row)
|
1119
|
-
if row_str not in seen:
|
1120
|
-
seen.add(row_str)
|
1121
|
-
unique_rows.append(row)
|
1122
|
-
|
1123
|
-
# Create a new dataset with unique combinations
|
1124
|
-
new_data = []
|
1125
|
-
for i, header in enumerate(headers):
|
1126
|
-
values = [row[i] for row in unique_rows]
|
1127
|
-
new_data.append({header: values})
|
1128
|
-
|
1129
|
-
return Dataset(new_data)
|
1130
1136
|
|
1131
1137
|
|
1132
1138
|
if __name__ == "__main__":
|