sutro 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sutro might be problematic. Click here for more details.
- sutro/sdk.py +174 -29
- {sutro-0.1.17.dist-info → sutro-0.1.19.dist-info}/METADATA +1 -1
- sutro-0.1.19.dist-info/RECORD +8 -0
- sutro-0.1.17.dist-info/RECORD +0 -8
- {sutro-0.1.17.dist-info → sutro-0.1.19.dist-info}/WHEEL +0 -0
- {sutro-0.1.17.dist-info → sutro-0.1.19.dist-info}/entry_points.txt +0 -0
- {sutro-0.1.17.dist-info → sutro-0.1.19.dist-info}/licenses/LICENSE +0 -0
sutro/sdk.py
CHANGED
|
@@ -97,6 +97,16 @@ def to_colored_text(
|
|
|
97
97
|
# Default to blue for normal/processing states
|
|
98
98
|
return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
|
|
99
99
|
|
|
100
|
+
# OSC 8 hyperlinks aren't supported by all terminals, unfortunately. We should switch to Rich
|
|
101
|
+
# at some point, but even Rich links aren't clickable in macOS Terminal.
|
|
102
|
+
def make_clickable_link(url, text=None):
    """Wrap a label in an OSC 8 terminal hyperlink that points at ``url``.

    Args:
        url: Link target written into the OSC 8 escape sequence.
        text: Visible label. When ``None`` the URL itself is used as the
            label; any other value (including an empty string) is used
            verbatim.

    Returns:
        The label surrounded by the OSC 8 "open link" and "close link"
        escape sequences.

    Note:
        No terminal capability detection is performed here; the escape
        sequences are always emitted. Whether an unsupported terminal shows
        just the label depends on that terminal ignoring the sequences.
    """
    label = url if text is None else text
    osc8_prefix = "\033]8;;"
    string_terminator = "\033\\"
    return f"{osc8_prefix}{url}{string_terminator}{label}{osc8_prefix}{string_terminator}"
|
|
100
110
|
|
|
101
111
|
class Sutro:
|
|
102
112
|
def __init__(
|
|
@@ -266,11 +276,6 @@ class Sutro:
|
|
|
266
276
|
"random_seed_per_input": random_seed_per_input,
|
|
267
277
|
"truncate_rows": truncate_rows
|
|
268
278
|
}
|
|
269
|
-
if dry_run:
|
|
270
|
-
spinner_text = to_colored_text("Retrieving cost estimates...")
|
|
271
|
-
else:
|
|
272
|
-
t = f"Creating priority {job_priority} job"
|
|
273
|
-
spinner_text = to_colored_text(t)
|
|
274
279
|
|
|
275
280
|
# There are two gotchas with yaspin:
|
|
276
281
|
# 1. Can't use print while in spinner is running
|
|
@@ -279,6 +284,8 @@ class Sutro:
|
|
|
279
284
|
# Terminal size {self._terminal_width} is too small to display spinner with the given settings.
|
|
280
285
|
# https://github.com/pavdmyt/yaspin/blob/9c7430b499ab4611888ece39783a870e4a05fa45/yaspin/core.py#L568-L571
|
|
281
286
|
job_id = None
|
|
287
|
+
t = f"Creating {'[dry run] ' if dry_run else ''}priority {job_priority} job"
|
|
288
|
+
spinner_text = to_colored_text(t)
|
|
282
289
|
with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
|
|
283
290
|
response = requests.post(
|
|
284
291
|
endpoint, data=json.dumps(payload), headers=headers
|
|
@@ -292,16 +299,22 @@ class Sutro:
|
|
|
292
299
|
print(to_colored_text(response.json(), state="fail"))
|
|
293
300
|
return None
|
|
294
301
|
else:
|
|
302
|
+
job_id = response_data["results"]
|
|
295
303
|
if dry_run:
|
|
296
304
|
spinner.write(
|
|
297
|
-
to_colored_text("
|
|
305
|
+
to_colored_text(f"Awaiting cost estimates with job ID: {job_id}. You can safely detach and retrieve the cost estimates later.", state="info")
|
|
306
|
+
)
|
|
307
|
+
spinner.stop()
|
|
308
|
+
self.await_job_completion(job_id, obtain_results=False)
|
|
309
|
+
cost_estimate = self._get_job_cost_estimate(job_id)
|
|
310
|
+
spinner.write(
|
|
311
|
+
to_colored_text(f"✔ Cost estimates retrieved for job {job_id}: ${cost_estimate}", state="success")
|
|
298
312
|
)
|
|
299
|
-
return
|
|
313
|
+
return job_id
|
|
300
314
|
else:
|
|
301
|
-
job_id = response_data["results"]
|
|
302
315
|
spinner.write(
|
|
303
316
|
to_colored_text(
|
|
304
|
-
f"
|
|
317
|
+
f"🛠 Priority {job_priority} Job created with ID: {job_id}.",
|
|
305
318
|
state="success",
|
|
306
319
|
)
|
|
307
320
|
)
|
|
@@ -309,12 +322,20 @@ class Sutro:
|
|
|
309
322
|
spinner.write(
|
|
310
323
|
to_colored_text(
|
|
311
324
|
f"Use `so.get_job_status('{job_id}')` to check the status of the job."
|
|
325
|
+
)
|
|
312
326
|
)
|
|
313
|
-
)
|
|
314
327
|
return job_id
|
|
315
328
|
|
|
316
329
|
success = False
|
|
317
330
|
if stay_attached and job_id is not None:
|
|
331
|
+
spinner.write(to_colored_text("Awaiting job start...", ))
|
|
332
|
+
spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
|
|
333
|
+
started = self._await_job_start(job_id)
|
|
334
|
+
if not started:
|
|
335
|
+
failure_reason = self._get_failure_reason(job_id)
|
|
336
|
+
spinner.write(to_colored_text(f"Failure reason: {failure_reason['message']}", "fail"))
|
|
337
|
+
return None
|
|
338
|
+
|
|
318
339
|
s = requests.Session()
|
|
319
340
|
payload = {
|
|
320
341
|
"job_id": job_id,
|
|
@@ -589,6 +610,7 @@ class Sutro:
|
|
|
589
610
|
text=to_colored_text("Awaiting status updates..."),
|
|
590
611
|
color=YASPIN_COLOR,
|
|
591
612
|
)
|
|
613
|
+
spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
|
|
592
614
|
spinner.start()
|
|
593
615
|
for line in streaming_response.iter_lines():
|
|
594
616
|
if line:
|
|
@@ -719,6 +741,40 @@ class Sutro:
|
|
|
719
741
|
return
|
|
720
742
|
return response.json()["jobs"]
|
|
721
743
|
|
|
744
|
+
def _list_jobs_helper(self):
    """Fetch the job listing from the ``/list-jobs`` endpoint.

    Authenticates with ``self.api_key`` against ``self.base_url``.

    Returns:
        The ``"jobs"`` value of the JSON response body, or ``None`` when the
        endpoint answers with any status code other than 200.
    """
    response = requests.get(
        f"{self.base_url}/list-jobs",
        headers={
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        },
    )
    if response.status_code == 200:
        return response.json()["jobs"]
    return None
|
|
757
|
+
|
|
758
|
+
def _get_job_cost_estimate(self, job_id: str):
|
|
759
|
+
"""
|
|
760
|
+
Get the cost estimate for a job.
|
|
761
|
+
"""
|
|
762
|
+
all_jobs = self._list_jobs_helper()
|
|
763
|
+
for job in all_jobs:
|
|
764
|
+
if job["job_id"] == job_id:
|
|
765
|
+
return job["cost_estimate"]
|
|
766
|
+
return None
|
|
767
|
+
|
|
768
|
+
def _get_failure_reason(self, job_id: str):
|
|
769
|
+
"""
|
|
770
|
+
Get the failure reason for a job.
|
|
771
|
+
"""
|
|
772
|
+
all_jobs = self._list_jobs_helper()
|
|
773
|
+
for job in all_jobs:
|
|
774
|
+
if job["job_id"] == job_id:
|
|
775
|
+
return job["failure_reason"]
|
|
776
|
+
return None
|
|
777
|
+
|
|
722
778
|
def _fetch_job_status(self, job_id: str):
|
|
723
779
|
"""
|
|
724
780
|
Core logic to fetch job status from the API.
|
|
@@ -775,10 +831,12 @@ class Sutro:
|
|
|
775
831
|
return None
|
|
776
832
|
|
|
777
833
|
def get_job_results(
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
834
|
+
self,
|
|
835
|
+
job_id: str,
|
|
836
|
+
include_inputs: bool = False,
|
|
837
|
+
include_cumulative_logprobs: bool = False,
|
|
838
|
+
with_original_df: pl.DataFrame | pd.DataFrame = None,
|
|
839
|
+
output_column: str = "inference_result",
|
|
782
840
|
):
|
|
783
841
|
"""
|
|
784
842
|
Get the results of a job by its ID.
|
|
@@ -789,9 +847,11 @@ class Sutro:
|
|
|
789
847
|
job_id (str): The ID of the job to retrieve the results for.
|
|
790
848
|
include_inputs (bool, optional): Whether to include the inputs in the results. Defaults to False.
|
|
791
849
|
include_cumulative_logprobs (bool, optional): Whether to include the cumulative logprobs in the results. Defaults to False.
|
|
850
|
+
with_original_df (pd.DataFrame | pl.DataFrame, optional): Original DataFrame to concatenate with results. Defaults to None.
|
|
851
|
+
output_column (str, optional): Name of the output column. Defaults to "inference_result".
|
|
792
852
|
|
|
793
853
|
Returns:
|
|
794
|
-
|
|
854
|
+
Union[pl.DataFrame, pd.DataFrame]: The results as a DataFrame. By default, returns polars.DataFrame; when with_original_df is an instance of pandas.DataFrame, returns pandas.DataFrame.
|
|
795
855
|
"""
|
|
796
856
|
endpoint = f"{self.base_url}/job-results"
|
|
797
857
|
payload = {
|
|
@@ -804,18 +864,14 @@ class Sutro:
|
|
|
804
864
|
"Content-Type": "application/json",
|
|
805
865
|
}
|
|
806
866
|
with yaspin(
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
867
|
+
SPINNER,
|
|
868
|
+
text=to_colored_text(f"Gathering results from job: {job_id}"),
|
|
869
|
+
color=YASPIN_COLOR,
|
|
810
870
|
) as spinner:
|
|
811
871
|
response = requests.post(
|
|
812
872
|
endpoint, data=json.dumps(payload), headers=headers
|
|
813
873
|
)
|
|
814
|
-
if response.status_code
|
|
815
|
-
spinner.write(
|
|
816
|
-
to_colored_text("✔ Job results retrieved", state="success")
|
|
817
|
-
)
|
|
818
|
-
else:
|
|
874
|
+
if response.status_code != 200:
|
|
819
875
|
spinner.write(
|
|
820
876
|
to_colored_text(
|
|
821
877
|
f"Bad status code: {response.status_code}", state="fail"
|
|
@@ -823,8 +879,56 @@ class Sutro:
|
|
|
823
879
|
)
|
|
824
880
|
spinner.stop()
|
|
825
881
|
print(to_colored_text(response.json(), state="fail"))
|
|
826
|
-
return
|
|
827
|
-
|
|
882
|
+
return None
|
|
883
|
+
|
|
884
|
+
spinner.write(
|
|
885
|
+
to_colored_text("✔ Job results retrieved", state="success")
|
|
886
|
+
)
|
|
887
|
+
|
|
888
|
+
response_data = response.json()
|
|
889
|
+
results_df = pl.DataFrame(response_data["results"])
|
|
890
|
+
|
|
891
|
+
|
|
892
|
+
if len(results_df.columns ) == 1:
|
|
893
|
+
# Default column when API is only returning a list, and we construct the df
|
|
894
|
+
# from that
|
|
895
|
+
original_results_column = 'column_0'
|
|
896
|
+
else:
|
|
897
|
+
original_results_column = 'outputs'
|
|
898
|
+
|
|
899
|
+
results_df = results_df.rename({original_results_column: output_column})
|
|
900
|
+
|
|
901
|
+
# Ordering inputs col first seems most logical/useful
|
|
902
|
+
column_config = [
|
|
903
|
+
('inputs', include_inputs),
|
|
904
|
+
(output_column, True),
|
|
905
|
+
('cumulative_logprobs', include_cumulative_logprobs),
|
|
906
|
+
]
|
|
907
|
+
|
|
908
|
+
columns_to_keep = [col for col, include in column_config
|
|
909
|
+
if include and col in results_df.columns]
|
|
910
|
+
|
|
911
|
+
results_df = results_df.select(columns_to_keep)
|
|
912
|
+
|
|
913
|
+
# Handle concatenation with original DataFrame
|
|
914
|
+
if with_original_df is not None:
|
|
915
|
+
if isinstance(with_original_df, pd.DataFrame):
|
|
916
|
+
# Convert to polars for consistent handling
|
|
917
|
+
original_pl = pl.from_pandas(with_original_df)
|
|
918
|
+
|
|
919
|
+
combined_df = original_pl.with_columns(results_df)
|
|
920
|
+
|
|
921
|
+
# Convert back to pandas to match input type
|
|
922
|
+
return combined_df.to_pandas()
|
|
923
|
+
|
|
924
|
+
elif isinstance(with_original_df, pl.DataFrame):
|
|
925
|
+
return with_original_df.with_columns(results_df)
|
|
926
|
+
|
|
927
|
+
# Return pd.DataFrame type when appropriate
|
|
928
|
+
if with_original_df is None and isinstance(with_original_df, pd.DataFrame):
|
|
929
|
+
return results_df.to_pandas()
|
|
930
|
+
|
|
931
|
+
return results_df
|
|
828
932
|
|
|
829
933
|
def cancel_job(self, job_id: str):
|
|
830
934
|
"""
|
|
@@ -1162,7 +1266,7 @@ class Sutro:
|
|
|
1162
1266
|
return
|
|
1163
1267
|
return response.json()["quotas"]
|
|
1164
1268
|
|
|
1165
|
-
def await_job_completion(self, job_id: str, timeout: Optional[int] = 7200) -> list | None:
|
|
1269
|
+
def await_job_completion(self, job_id: str, timeout: Optional[int] = 7200, obtain_results: bool = True) -> list | None:
|
|
1166
1270
|
"""
|
|
1167
1271
|
Waits for job completion to occur and then returns the results upon
|
|
1168
1272
|
a successful completion.
|
|
@@ -1181,8 +1285,9 @@ class Sutro:
|
|
|
1181
1285
|
results = None
|
|
1182
1286
|
start_time = time.time()
|
|
1183
1287
|
with yaspin(
|
|
1184
|
-
|
|
1288
|
+
SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
|
|
1185
1289
|
) as spinner:
|
|
1290
|
+
spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
|
|
1186
1291
|
while (time.time() - start_time) < timeout:
|
|
1187
1292
|
try:
|
|
1188
1293
|
status = self._fetch_job_status(job_id)
|
|
@@ -1201,7 +1306,8 @@ class Sutro:
|
|
|
1201
1306
|
if status == JobStatus.SUCCEEDED:
|
|
1202
1307
|
spinner.write(to_colored_text("Job completed! Retrieving results...", "success"))
|
|
1203
1308
|
spinner.stop() # Stop this spinner as `get_job_results` has its own spinner text
|
|
1204
|
-
|
|
1309
|
+
if obtain_results:
|
|
1310
|
+
results = self.get_job_results(job_id)
|
|
1205
1311
|
break
|
|
1206
1312
|
if status == JobStatus.FAILED:
|
|
1207
1313
|
spinner.write(to_colored_text("Job has failed", "fail"))
|
|
@@ -1213,4 +1319,43 @@ class Sutro:
|
|
|
1213
1319
|
|
|
1214
1320
|
time.sleep(POLL_INTERVAL)
|
|
1215
1321
|
|
|
1216
|
-
return results
|
|
1322
|
+
return results
|
|
1323
|
+
|
|
1324
|
+
def _await_job_start(self, job_id: str, timeout: Optional[int] = 7200, poll_interval: float = 5):
    """Poll a job's status until it has started or reached a terminal state.

    A job counts as started once its status is RUNNING or STARTING. A job
    that is already SUCCEEDED when polled must have started too, so it is
    treated the same way. Otherwise a job that finishes between two polls
    would never be observed as started, and this method would keep polling
    until ``timeout`` expired.

    Args:
        job_id: ID of the job to watch.
        timeout: Maximum number of seconds to wait. ``None`` waits
            indefinitely.
        poll_interval: Seconds to sleep between status requests.

    Returns:
        True if the job started (or already succeeded); False if it failed,
        was cancelled, or did not start before ``timeout``; None if
        fetching the status raised ``requests.HTTPError`` (the error body is
        printed).
    """
    started_states = (JobStatus.RUNNING, JobStatus.STARTING, JobStatus.SUCCEEDED)
    stopped_states = (JobStatus.FAILED, JobStatus.CANCELLED)

    start_time = time.time()
    with yaspin(
        SPINNER, text=to_colored_text("Awaiting job start"), color=YASPIN_COLOR
    ) as spinner:
        while timeout is None or (time.time() - start_time) < timeout:
            try:
                status = self._fetch_job_status(job_id)
            except requests.HTTPError as e:
                spinner.write(
                    to_colored_text(
                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
                spinner.stop()
                print(to_colored_text(e.response.json(), state="fail"))
                return None

            spinner.text = to_colored_text(f"Job status is {status} for {job_id}")

            if status in started_states:
                return True
            if status in stopped_states:
                return False

            time.sleep(poll_interval)

    # Timed out without the job ever leaving its pre-start states.
    return False
|
|
1361
|
+
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
|
|
2
|
+
sutro/cli.py,sha256=6Qy9Vwaaho92HeO8YA_z1De4zp1dEFkSX3bEnLvdbkE,13203
|
|
3
|
+
sutro/sdk.py,sha256=1FLepL3M7afptWwudF310KWmPQ9ZDQ51LiT0Xoh1S_o,52705
|
|
4
|
+
sutro-0.1.19.dist-info/METADATA,sha256=6Irt5RX_DaIyRC6ig4XvHr0p0CKnFK0nmWA23oi1mCA,669
|
|
5
|
+
sutro-0.1.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
sutro-0.1.19.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
|
|
7
|
+
sutro-0.1.19.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
+
sutro-0.1.19.dist-info/RECORD,,
|
sutro-0.1.17.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
|
|
2
|
-
sutro/cli.py,sha256=6Qy9Vwaaho92HeO8YA_z1De4zp1dEFkSX3bEnLvdbkE,13203
|
|
3
|
-
sutro/sdk.py,sha256=z3cvU9zyAi7EJF_rFIb9mQMjE9zuh576-p8KuP_F1PM,46500
|
|
4
|
-
sutro-0.1.17.dist-info/METADATA,sha256=UDHp7-xr8tdKmDl_GbxxRbVVC6nx62yr6-nka4vBwOY,669
|
|
5
|
-
sutro-0.1.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
sutro-0.1.17.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
|
|
7
|
-
sutro-0.1.17.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
-
sutro-0.1.17.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|