sutro 0.1.18__tar.gz → 0.1.19__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sutro-0.1.18 → sutro-0.1.19}/PKG-INFO +1 -1
- {sutro-0.1.18 → sutro-0.1.19}/pyproject.toml +1 -1
- {sutro-0.1.18 → sutro-0.1.19}/sutro/sdk.py +79 -18
- {sutro-0.1.18 → sutro-0.1.19}/.gitignore +0 -0
- {sutro-0.1.18 → sutro-0.1.19}/LICENSE +0 -0
- {sutro-0.1.18 → sutro-0.1.19}/README.md +0 -0
- {sutro-0.1.18 → sutro-0.1.19}/sutro/__init__.py +0 -0
- {sutro-0.1.18 → sutro-0.1.19}/sutro/cli.py +0 -0
|
@@ -97,6 +97,16 @@ def to_colored_text(
|
|
|
97
97
|
# Default to blue for normal/processing states
|
|
98
98
|
return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
|
|
99
99
|
|
|
100
|
+
# Isn't fully supported in all terminals unfortunately. We should switch to Rich
|
|
101
|
+
# at some point, but even Rich links aren't clickable on macOS Terminal
|
|
102
|
+
def make_clickable_link(url, text=None):
|
|
103
|
+
"""
|
|
104
|
+
Create a clickable link for terminals that support OSC 8 hyperlinks.
|
|
105
|
+
Falls back to plain text for terminals that don't support it.
|
|
106
|
+
"""
|
|
107
|
+
if text is None:
|
|
108
|
+
text = url
|
|
109
|
+
return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
|
|
100
110
|
|
|
101
111
|
class Sutro:
|
|
102
112
|
def __init__(
|
|
@@ -304,7 +314,7 @@ class Sutro:
|
|
|
304
314
|
else:
|
|
305
315
|
spinner.write(
|
|
306
316
|
to_colored_text(
|
|
307
|
-
f"
|
|
317
|
+
f"🛠 Priority {job_priority} Job created with ID: {job_id}.",
|
|
308
318
|
state="success",
|
|
309
319
|
)
|
|
310
320
|
)
|
|
@@ -312,13 +322,14 @@ class Sutro:
|
|
|
312
322
|
spinner.write(
|
|
313
323
|
to_colored_text(
|
|
314
324
|
f"Use `so.get_job_status('{job_id}')` to check the status of the job."
|
|
325
|
+
)
|
|
315
326
|
)
|
|
316
|
-
)
|
|
317
327
|
return job_id
|
|
318
328
|
|
|
319
329
|
success = False
|
|
320
330
|
if stay_attached and job_id is not None:
|
|
321
|
-
spinner.write(to_colored_text("Awaiting job start...",
|
|
331
|
+
spinner.write(to_colored_text("Awaiting job start...", ))
|
|
332
|
+
spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
|
|
322
333
|
started = self._await_job_start(job_id)
|
|
323
334
|
if not started:
|
|
324
335
|
failure_reason = self._get_failure_reason(job_id)
|
|
@@ -599,6 +610,7 @@ class Sutro:
|
|
|
599
610
|
text=to_colored_text("Awaiting status updates..."),
|
|
600
611
|
color=YASPIN_COLOR,
|
|
601
612
|
)
|
|
613
|
+
spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
|
|
602
614
|
spinner.start()
|
|
603
615
|
for line in streaming_response.iter_lines():
|
|
604
616
|
if line:
|
|
@@ -819,10 +831,12 @@ class Sutro:
|
|
|
819
831
|
return None
|
|
820
832
|
|
|
821
833
|
def get_job_results(
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
834
|
+
self,
|
|
835
|
+
job_id: str,
|
|
836
|
+
include_inputs: bool = False,
|
|
837
|
+
include_cumulative_logprobs: bool = False,
|
|
838
|
+
with_original_df: pl.DataFrame | pd.DataFrame = None,
|
|
839
|
+
output_column: str = "inference_result",
|
|
826
840
|
):
|
|
827
841
|
"""
|
|
828
842
|
Get the results of a job by its ID.
|
|
@@ -833,9 +847,11 @@ class Sutro:
|
|
|
833
847
|
job_id (str): The ID of the job to retrieve the results for.
|
|
834
848
|
include_inputs (bool, optional): Whether to include the inputs in the results. Defaults to False.
|
|
835
849
|
include_cumulative_logprobs (bool, optional): Whether to include the cumulative logprobs in the results. Defaults to False.
|
|
850
|
+
with_original_df (pd.DataFrame | pl.DataFrame, optional): Original DataFrame to concatenate with results. Defaults to None.
|
|
851
|
+
output_column (str, optional): Name of the output column. Defaults to "inference_result".
|
|
836
852
|
|
|
837
853
|
Returns:
|
|
838
|
-
|
|
854
|
+
Union[pl.DataFrame, pd.DataFrame]: The results as a DataFrame. By default, returns polars.DataFrame; when with_original_df is an instance of pandas.DataFrame, returns pandas.DataFrame.
|
|
839
855
|
"""
|
|
840
856
|
endpoint = f"{self.base_url}/job-results"
|
|
841
857
|
payload = {
|
|
@@ -848,18 +864,14 @@ class Sutro:
|
|
|
848
864
|
"Content-Type": "application/json",
|
|
849
865
|
}
|
|
850
866
|
with yaspin(
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
867
|
+
SPINNER,
|
|
868
|
+
text=to_colored_text(f"Gathering results from job: {job_id}"),
|
|
869
|
+
color=YASPIN_COLOR,
|
|
854
870
|
) as spinner:
|
|
855
871
|
response = requests.post(
|
|
856
872
|
endpoint, data=json.dumps(payload), headers=headers
|
|
857
873
|
)
|
|
858
|
-
if response.status_code
|
|
859
|
-
spinner.write(
|
|
860
|
-
to_colored_text("✔ Job results retrieved", state="success")
|
|
861
|
-
)
|
|
862
|
-
else:
|
|
874
|
+
if response.status_code != 200:
|
|
863
875
|
spinner.write(
|
|
864
876
|
to_colored_text(
|
|
865
877
|
f"Bad status code: {response.status_code}", state="fail"
|
|
@@ -867,8 +879,56 @@ class Sutro:
|
|
|
867
879
|
)
|
|
868
880
|
spinner.stop()
|
|
869
881
|
print(to_colored_text(response.json(), state="fail"))
|
|
870
|
-
return
|
|
871
|
-
|
|
882
|
+
return None
|
|
883
|
+
|
|
884
|
+
spinner.write(
|
|
885
|
+
to_colored_text("✔ Job results retrieved", state="success")
|
|
886
|
+
)
|
|
887
|
+
|
|
888
|
+
response_data = response.json()
|
|
889
|
+
results_df = pl.DataFrame(response_data["results"])
|
|
890
|
+
|
|
891
|
+
|
|
892
|
+
if len(results_df.columns ) == 1:
|
|
893
|
+
# Default column when API is only returning a list, and we construct the df
|
|
894
|
+
# from that
|
|
895
|
+
original_results_column = 'column_0'
|
|
896
|
+
else:
|
|
897
|
+
original_results_column = 'outputs'
|
|
898
|
+
|
|
899
|
+
results_df = results_df.rename({original_results_column: output_column})
|
|
900
|
+
|
|
901
|
+
# Ordering inputs col first seems most logical/useful
|
|
902
|
+
column_config = [
|
|
903
|
+
('inputs', include_inputs),
|
|
904
|
+
(output_column, True),
|
|
905
|
+
('cumulative_logprobs', include_cumulative_logprobs),
|
|
906
|
+
]
|
|
907
|
+
|
|
908
|
+
columns_to_keep = [col for col, include in column_config
|
|
909
|
+
if include and col in results_df.columns]
|
|
910
|
+
|
|
911
|
+
results_df = results_df.select(columns_to_keep)
|
|
912
|
+
|
|
913
|
+
# Handle concatenation with original DataFrame
|
|
914
|
+
if with_original_df is not None:
|
|
915
|
+
if isinstance(with_original_df, pd.DataFrame):
|
|
916
|
+
# Convert to polars for consistent handling
|
|
917
|
+
original_pl = pl.from_pandas(with_original_df)
|
|
918
|
+
|
|
919
|
+
combined_df = original_pl.with_columns(results_df)
|
|
920
|
+
|
|
921
|
+
# Convert back to pandas to match input type
|
|
922
|
+
return combined_df.to_pandas()
|
|
923
|
+
|
|
924
|
+
elif isinstance(with_original_df, pl.DataFrame):
|
|
925
|
+
return with_original_df.with_columns(results_df)
|
|
926
|
+
|
|
927
|
+
# Return pd.DataFrame type when appropriate
|
|
928
|
+
if with_original_df is None and isinstance(with_original_df, pd.DataFrame):
|
|
929
|
+
return results_df.to_pandas()
|
|
930
|
+
|
|
931
|
+
return results_df
|
|
872
932
|
|
|
873
933
|
def cancel_job(self, job_id: str):
|
|
874
934
|
"""
|
|
@@ -1227,6 +1287,7 @@ class Sutro:
|
|
|
1227
1287
|
with yaspin(
|
|
1228
1288
|
SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
|
|
1229
1289
|
) as spinner:
|
|
1290
|
+
spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
|
|
1230
1291
|
while (time.time() - start_time) < timeout:
|
|
1231
1292
|
try:
|
|
1232
1293
|
status = self._fetch_job_status(job_id)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|