sutro 0.1.27__tar.gz → 0.1.29__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sutro might be problematic.
- {sutro-0.1.27 → sutro-0.1.29}/PKG-INFO +2 -1
- {sutro-0.1.27 → sutro-0.1.29}/pyproject.toml +2 -1
- {sutro-0.1.27 → sutro-0.1.29}/sutro/cli.py +19 -1
- {sutro-0.1.27 → sutro-0.1.29}/sutro/sdk.py +106 -35
- {sutro-0.1.27 → sutro-0.1.29}/.gitignore +0 -0
- {sutro-0.1.27 → sutro-0.1.29}/LICENSE +0 -0
- {sutro-0.1.27 → sutro-0.1.29}/README.md +0 -0
- {sutro-0.1.27 → sutro-0.1.29}/sutro/__init__.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sutro
-Version: 0.1.27
+Version: 0.1.29
 Summary: Sutro Python SDK
 Project-URL: Homepage, https://sutro.sh
 Project-URL: Documentation, https://docs.sutro.sh
@@ -12,6 +12,7 @@ Requires-Dist: colorama==0.4.4
 Requires-Dist: numpy==2.1.1
 Requires-Dist: pandas==2.2.3
 Requires-Dist: polars==1.8.2
+Requires-Dist: pyarrow==21.0.0
 Requires-Dist: pydantic==2.11.4
 Requires-Dist: requests==2.32.3
 Requires-Dist: tqdm==4.67.1
pyproject.toml

@@ -9,7 +9,7 @@ installer = "uv"
 
 [project]
 name = "sutro"
-version = "0.1.27"
+version = "0.1.29"
 description = "Sutro Python SDK"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -24,6 +24,7 @@ dependencies = [
     "yaspin==3.1.0",
     "tqdm==4.67.1",
     "pydantic==2.11.4",
+    "pyarrow==21.0.0",
 ]
 
 [project.scripts]
sutro/cli.py

@@ -141,7 +141,6 @@ def jobs():
     """Manage jobs."""
     pass
 
-
 @jobs.command()
 @click.option(
     "--all", is_flag=True, help="Include all jobs, including cancelled and failed ones."
@@ -360,6 +359,24 @@ def download(dataset_id, file_name=None, output_path=None):
     with open(output_path + "/" + file_name, "wb") as f:
         f.write(file)
 
+@cli.group()
+def cache():
+    """Manage the local job results cache."""
+    pass
+
+@cache.command()
+def clear():
+    """Clear the local job results cache."""
+    sdk = get_sdk()
+    sdk._clear_job_results_cache()
+    click.echo(Fore.GREEN + "Job results cache cleared." + Style.RESET_ALL)
+
+@cache.command()
+def show():
+    """Show the contents and size of the job results cache."""
+    sdk = get_sdk()
+    sdk._show_cache_contents()
+
 
 @cli.command()
 def docs():
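The new group is exposed as "sutro cache clear" and "sutro cache show" through the console script declared under [project.scripts]. Below is a minimal sketch of driving the commands in-process with click's test runner; the sutro.cli.cli import path and an already-configured API key for get_sdk() are assumptions, not something this diff confirms.

    # Sketch: invoke the new cache commands in-process via click's CliRunner.
    # Assumes the click group object is importable as sutro.cli.cli and that
    # get_sdk() can resolve an API key from existing configuration.
    from click.testing import CliRunner
    from sutro.cli import cli

    runner = CliRunner()

    # Equivalent to running `sutro cache show` from a shell.
    result = runner.invoke(cli, ["cache", "show"])
    print(result.output)

    # Equivalent to `sutro cache clear`.
    result = runner.invoke(cli, ["cache", "clear"])
    print(result.output)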
@@ -395,6 +412,7 @@ def quotas():
         + Style.RESET_ALL
     )
 
+
 @jobs.command()
 @click.argument("job_id", required=False)
 @click.option("--latest", is_flag=True, help="Attach to the latest job.")
sutro/sdk.py

@@ -17,6 +17,8 @@ from tqdm import tqdm
 import time
 from pydantic import BaseModel
 import json
+import pyarrow.parquet as pq
+import shutil
 
 
 class JobStatus(str, Enum):
@@ -813,12 +815,14 @@ class Sutro:
         return None
 
     def get_job_results(
-
-
-
-
-
-
+        self,
+        job_id: str,
+        include_inputs: bool = False,
+        include_cumulative_logprobs: bool = False,
+        with_original_df: pl.DataFrame | pd.DataFrame = None,
+        output_column: str = "inference_result",
+        disable_cache: bool = False,
+        unpack_json: bool = True,
     ):
         """
         Get the results of a job by its ID.
@@ -831,47 +835,72 @@ class Sutro:
             include_cumulative_logprobs (bool, optional): Whether to include the cumulative logprobs in the results. Defaults to False.
             with_original_df (pd.DataFrame | pl.DataFrame, optional): Original DataFrame to concatenate with results. Defaults to None.
             output_column (str, optional): Name of the output column. Defaults to "inference_result".
+            disable_cache (bool, optional): Whether to disable the cache. Defaults to False.
+            unpack_json (bool, optional): If the output_column is formatted as a JSON string, decides whether to unpack the top level JSON fields in the results into separate columns. Defaults to True.
 
         Returns:
             Union[pl.DataFrame, pd.DataFrame]: The results as a DataFrame. By default, returns polars.DataFrame; when with_original_df is an instance of pandas.DataFrame, returns pandas.DataFrame.
         """
-
-
-
-
-
-
-
-
-
-
-
+
+
+        file_path = os.path.expanduser(f"~/.sutro/job-results/{job_id}.snappy.parquet")
+        expected_num_columns = 1 + include_inputs + include_cumulative_logprobs
+        contains_expected_columns = False
+        if os.path.exists(file_path):
+            num_columns = pq.read_table(file_path).num_columns
+            contains_expected_columns = num_columns == expected_num_columns
+
+        if disable_cache == False and contains_expected_columns:
+            with yaspin(
+                SPINNER,
+                text=to_colored_text(f"Loading results from cache: {file_path}"),
+                color=YASPIN_COLOR,
+            ) as spinner:
+                results_df = pl.read_parquet(file_path)
+                spinner.write(to_colored_text("✔ Results loaded from cache", state="success"))
+        else:
+            endpoint = f"{self.base_url}/job-results"
+            payload = {
+                "job_id": job_id,
+                "include_inputs": include_inputs,
+                "include_cumulative_logprobs": include_cumulative_logprobs,
+            }
+            headers = {
+                "Authorization": f"Key {self.api_key}",
+                "Content-Type": "application/json",
+            }
+            with yaspin(
                 SPINNER,
                 text=to_colored_text(f"Gathering results from job: {job_id}"),
                 color=YASPIN_COLOR,
-
-
-
-            )
-            if response.status_code != 200:
-                spinner.write(
-                    to_colored_text(
-                        f"Bad status code: {response.status_code}", state="fail"
-                    )
+            ) as spinner:
+                response = requests.post(
+                    endpoint, data=json.dumps(payload), headers=headers
                 )
-
-
-
+                if response.status_code != 200:
+                    spinner.write(
+                        to_colored_text(
+                            f"Bad status code: {response.status_code}", state="fail"
+                        )
+                    )
+                    spinner.stop()
+                    print(to_colored_text(response.json(), state="fail"))
+                    return None
 
-
-
-
+                spinner.write(
+                    to_colored_text("✔ Job results retrieved", state="success")
+                )
 
-
-
+                response_data = response.json()
+                results_df = pl.DataFrame(response_data["results"])
 
-
+                results_df = results_df.rename({'outputs': output_column})
 
+                if disable_cache == False:
+                    os.makedirs(os.path.dirname(file_path), exist_ok=True)
+                    results_df.write_parquet(file_path, compression="snappy")
+                    spinner.write(to_colored_text("✔ Results saved to cache", state="success"))
+
         # Ordering inputs col first seems most logical/useful
         column_config = [
             ('inputs', include_inputs),
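Since cached results are written as snappy-compressed Parquet files keyed by job ID under ~/.sutro/job-results/, a cache entry can also be inspected directly. Below is a small sketch mirroring the column-count check above; the job ID is a placeholder.

    # Sketch: inspect a cached job result directly (job ID is a placeholder).
    import os
    import polars as pl
    import pyarrow.parquet as pq

    job_id = "job-123"
    file_path = os.path.expanduser(f"~/.sutro/job-results/{job_id}.snappy.parquet")

    if os.path.exists(file_path):
        # The SDK compares this column count against
        # 1 + include_inputs + include_cumulative_logprobs to decide whether
        # the cached file matches the requested flags.
        print("columns:", pq.read_table(file_path).num_columns)
        print(pl.read_parquet(file_path).head())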
@@ -884,6 +913,23 @@
 
         results_df = results_df.select(columns_to_keep)
 
+        if unpack_json:
+            try:
+                first_row = json.loads(results_df.head(1)[output_column][0]) # checks if the first row can be json decoded
+                results_df = results_df.with_columns(
+                    pl.col(output_column).str.json_decode().alias("output_column_json_decoded")
+                )
+                json_decoded_fields = first_row.keys()
+                for field in json_decoded_fields:
+                    results_df = results_df.with_columns(
+                        pl.col("output_column_json_decoded").struct.field(field).alias(field)
+                    )
+                # drop the output_column and the json decoded column
+                results_df = results_df.drop([output_column, "output_column_json_decoded"])
+            except json.JSONDecodeError:
+                # if the first row cannot be json decoded, do nothing
+                pass
+
         # Handle concatenation with original DataFrame
         if with_original_df is not None:
             if isinstance(with_original_df, pd.DataFrame):
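Taken together, here is a hedged usage sketch of the 0.1.29 get_job_results call; the sutro.sdk.Sutro import path and constructor arguments are assumptions, while disable_cache and unpack_json come from the signature shown above.

    # Sketch: calling get_job_results with the new caching/JSON options.
    # Constructor usage is assumed; disable_cache and unpack_json are the
    # parameters added in this release.
    from sutro.sdk import Sutro

    so = Sutro(api_key="YOUR_API_KEY")  # hypothetical construction

    # Default behaviour: results cached under ~/.sutro/job-results/ and JSON
    # string outputs unpacked into one column per top-level field.
    df = so.get_job_results("job-123")

    # Bypass the local cache and keep the raw output column as a string.
    df_raw = so.get_job_results("job-123", disable_cache=True, unpack_json=False)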
@@ -1296,6 +1342,31 @@
             time.sleep(POLL_INTERVAL)
 
         return results
+
+    def _clear_job_results_cache(self): # only to be called by the CLI
+        """
+        Clears the cache for a job results.
+        """
+        if os.path.exists(os.path.expanduser("~/.sutro/job-results")):
+            shutil.rmtree(os.path.expanduser("~/.sutro/job-results"))
+
+    def _show_cache_contents(self):
+        """
+        Shows the contents and size of each file in the job results cache.
+        """
+        # get the size of the job-results directory
+        with yaspin(
+            SPINNER, text=to_colored_text("Retrieving job results cache contents"), color=YASPIN_COLOR
+        ) as spinner:
+            if not os.path.exists(os.path.expanduser("~/.sutro/job-results")):
+                spinner.write(to_colored_text("No job results cache found", "success"))
+                return
+            total_size = 0
+            for file in os.listdir(os.path.expanduser("~/.sutro/job-results")):
+                size = os.path.getsize(os.path.expanduser(f"~/.sutro/job-results/{file}")) / 1024 / 1024 / 1024
+                total_size += size
+                spinner.write(to_colored_text(f"File: {file} - Size: {size} GB"))
+            spinner.write(to_colored_text(f"Total size of results cache at ~/.sutro/job-results: {total_size} GB", "success"))
 
     def _await_job_start(self, job_id: str, timeout: Optional[int] = 7200):
         """