sutro 0.1.27__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sutro might be problematic. Click here for more details.
- sutro/cli.py +19 -1
- sutro/sdk.py +80 -29
- {sutro-0.1.27.dist-info → sutro-0.1.28.dist-info}/METADATA +2 -1
- sutro-0.1.28.dist-info/RECORD +8 -0
- sutro-0.1.27.dist-info/RECORD +0 -8
- {sutro-0.1.27.dist-info → sutro-0.1.28.dist-info}/WHEEL +0 -0
- {sutro-0.1.27.dist-info → sutro-0.1.28.dist-info}/entry_points.txt +0 -0
- {sutro-0.1.27.dist-info → sutro-0.1.28.dist-info}/licenses/LICENSE +0 -0
sutro/cli.py
CHANGED
|
@@ -141,7 +141,6 @@ def jobs():
|
|
|
141
141
|
"""Manage jobs."""
|
|
142
142
|
pass
|
|
143
143
|
|
|
144
|
-
|
|
145
144
|
@jobs.command()
|
|
146
145
|
@click.option(
|
|
147
146
|
"--all", is_flag=True, help="Include all jobs, including cancelled and failed ones."
|
|
@@ -360,6 +359,24 @@ def download(dataset_id, file_name=None, output_path=None):
|
|
|
360
359
|
with open(output_path + "/" + file_name, "wb") as f:
|
|
361
360
|
f.write(file)
|
|
362
361
|
|
|
362
|
+
@cli.group()
|
|
363
|
+
def cache():
|
|
364
|
+
"""Manage the local job results cache."""
|
|
365
|
+
pass
|
|
366
|
+
|
|
367
|
+
@cache.command()
|
|
368
|
+
def clear():
|
|
369
|
+
"""Clear the local job results cache."""
|
|
370
|
+
sdk = get_sdk()
|
|
371
|
+
sdk._clear_job_results_cache()
|
|
372
|
+
click.echo(Fore.GREEN + "Job results cache cleared." + Style.RESET_ALL)
|
|
373
|
+
|
|
374
|
+
@cache.command()
|
|
375
|
+
def show():
|
|
376
|
+
"""Show the contents and size of the job results cache."""
|
|
377
|
+
sdk = get_sdk()
|
|
378
|
+
sdk._show_cache_contents()
|
|
379
|
+
|
|
363
380
|
|
|
364
381
|
@cli.command()
|
|
365
382
|
def docs():
|
|
@@ -395,6 +412,7 @@ def quotas():
|
|
|
395
412
|
+ Style.RESET_ALL
|
|
396
413
|
)
|
|
397
414
|
|
|
415
|
+
|
|
398
416
|
@jobs.command()
|
|
399
417
|
@click.argument("job_id", required=False)
|
|
400
418
|
@click.option("--latest", is_flag=True, help="Attach to the latest job.")
|
sutro/sdk.py
CHANGED
|
@@ -17,6 +17,8 @@ from tqdm import tqdm
|
|
|
17
17
|
import time
|
|
18
18
|
from pydantic import BaseModel
|
|
19
19
|
import json
|
|
20
|
+
import pyarrow.parquet as pq
|
|
21
|
+
import shutil
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
class JobStatus(str, Enum):
|
|
@@ -819,6 +821,7 @@ class Sutro:
|
|
|
819
821
|
include_cumulative_logprobs: bool = False,
|
|
820
822
|
with_original_df: pl.DataFrame | pd.DataFrame = None,
|
|
821
823
|
output_column: str = "inference_result",
|
|
824
|
+
disable_cache: bool = False
|
|
822
825
|
):
|
|
823
826
|
"""
|
|
824
827
|
Get the results of a job by its ID.
|
|
@@ -835,43 +838,66 @@ class Sutro:
|
|
|
835
838
|
Returns:
|
|
836
839
|
Union[pl.DataFrame, pd.DataFrame]: The results as a DataFrame. By default, returns polars.DataFrame; when with_original_df is an instance of pandas.DataFrame, returns pandas.DataFrame.
|
|
837
840
|
"""
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
841
|
+
|
|
842
|
+
|
|
843
|
+
file_path = os.path.expanduser(f"~/.sutro/job-results/{job_id}.snappy.parquet")
|
|
844
|
+
expected_num_columns = 1 + include_inputs + include_cumulative_logprobs
|
|
845
|
+
contains_expected_columns = False
|
|
846
|
+
if os.path.exists(file_path):
|
|
847
|
+
num_columns = pq.read_table(file_path).num_columns
|
|
848
|
+
contains_expected_columns = num_columns == expected_num_columns
|
|
849
|
+
|
|
850
|
+
if disable_cache == False and contains_expected_columns:
|
|
851
|
+
with yaspin(
|
|
852
|
+
SPINNER,
|
|
853
|
+
text=to_colored_text(f"Loading results from cache: {file_path}"),
|
|
854
|
+
color=YASPIN_COLOR,
|
|
855
|
+
) as spinner:
|
|
856
|
+
results_df = pl.read_parquet(file_path)
|
|
857
|
+
spinner.write(to_colored_text("✔ Results loaded from cache", state="success"))
|
|
858
|
+
else:
|
|
859
|
+
endpoint = f"{self.base_url}/job-results"
|
|
860
|
+
payload = {
|
|
861
|
+
"job_id": job_id,
|
|
862
|
+
"include_inputs": include_inputs,
|
|
863
|
+
"include_cumulative_logprobs": include_cumulative_logprobs,
|
|
864
|
+
}
|
|
865
|
+
headers = {
|
|
866
|
+
"Authorization": f"Key {self.api_key}",
|
|
867
|
+
"Content-Type": "application/json",
|
|
868
|
+
}
|
|
869
|
+
with yaspin(
|
|
849
870
|
SPINNER,
|
|
850
871
|
text=to_colored_text(f"Gathering results from job: {job_id}"),
|
|
851
872
|
color=YASPIN_COLOR,
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
)
|
|
856
|
-
if response.status_code != 200:
|
|
857
|
-
spinner.write(
|
|
858
|
-
to_colored_text(
|
|
859
|
-
f"Bad status code: {response.status_code}", state="fail"
|
|
860
|
-
)
|
|
873
|
+
) as spinner:
|
|
874
|
+
response = requests.post(
|
|
875
|
+
endpoint, data=json.dumps(payload), headers=headers
|
|
861
876
|
)
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
877
|
+
if response.status_code != 200:
|
|
878
|
+
spinner.write(
|
|
879
|
+
to_colored_text(
|
|
880
|
+
f"Bad status code: {response.status_code}", state="fail"
|
|
881
|
+
)
|
|
882
|
+
)
|
|
883
|
+
spinner.stop()
|
|
884
|
+
print(to_colored_text(response.json(), state="fail"))
|
|
885
|
+
return None
|
|
865
886
|
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
887
|
+
spinner.write(
|
|
888
|
+
to_colored_text("✔ Job results retrieved", state="success")
|
|
889
|
+
)
|
|
869
890
|
|
|
870
|
-
|
|
871
|
-
|
|
891
|
+
response_data = response.json()
|
|
892
|
+
results_df = pl.DataFrame(response_data["results"])
|
|
872
893
|
|
|
873
|
-
|
|
894
|
+
results_df = results_df.rename({'outputs': output_column})
|
|
874
895
|
|
|
896
|
+
if disable_cache == False:
|
|
897
|
+
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
|
898
|
+
results_df.write_parquet(file_path, compression="snappy")
|
|
899
|
+
spinner.write(to_colored_text("✔ Results saved to cache", state="success"))
|
|
900
|
+
|
|
875
901
|
# Ordering inputs col first seems most logical/useful
|
|
876
902
|
column_config = [
|
|
877
903
|
('inputs', include_inputs),
|
|
@@ -1296,6 +1322,31 @@ class Sutro:
|
|
|
1296
1322
|
time.sleep(POLL_INTERVAL)
|
|
1297
1323
|
|
|
1298
1324
|
return results
|
|
1325
|
+
|
|
1326
|
+
def _clear_job_results_cache(self): # only to be called by the CLI
|
|
1327
|
+
"""
|
|
1328
|
+
Clears the cache for a job results.
|
|
1329
|
+
"""
|
|
1330
|
+
if os.path.exists(os.path.expanduser("~/.sutro/job-results")):
|
|
1331
|
+
shutil.rmtree(os.path.expanduser("~/.sutro/job-results"))
|
|
1332
|
+
|
|
1333
|
+
def _show_cache_contents(self):
|
|
1334
|
+
"""
|
|
1335
|
+
Shows the contents and size of each file in the job results cache.
|
|
1336
|
+
"""
|
|
1337
|
+
# get the size of the job-results directory
|
|
1338
|
+
with yaspin(
|
|
1339
|
+
SPINNER, text=to_colored_text("Retrieving job results cache contents"), color=YASPIN_COLOR
|
|
1340
|
+
) as spinner:
|
|
1341
|
+
if not os.path.exists(os.path.expanduser("~/.sutro/job-results")):
|
|
1342
|
+
spinner.write(to_colored_text("No job results cache found", "success"))
|
|
1343
|
+
return
|
|
1344
|
+
total_size = 0
|
|
1345
|
+
for file in os.listdir(os.path.expanduser("~/.sutro/job-results")):
|
|
1346
|
+
size = os.path.getsize(os.path.expanduser(f"~/.sutro/job-results/{file}")) / 1024 / 1024 / 1024
|
|
1347
|
+
total_size += size
|
|
1348
|
+
spinner.write(to_colored_text(f"File: {file} - Size: {size} GB"))
|
|
1349
|
+
spinner.write(to_colored_text(f"Total size of results cache at ~/.sutro/job-results: {total_size} GB", "success"))
|
|
1299
1350
|
|
|
1300
1351
|
def _await_job_start(self, job_id: str, timeout: Optional[int] = 7200):
|
|
1301
1352
|
"""
|
{sutro-0.1.27.dist-info → sutro-0.1.28.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sutro
|
|
3
|
-
Version: 0.1.27
|
|
3
|
+
Version: 0.1.28
|
|
4
4
|
Summary: Sutro Python SDK
|
|
5
5
|
Project-URL: Homepage, https://sutro.sh
|
|
6
6
|
Project-URL: Documentation, https://docs.sutro.sh
|
|
@@ -12,6 +12,7 @@ Requires-Dist: colorama==0.4.4
|
|
|
12
12
|
Requires-Dist: numpy==2.1.1
|
|
13
13
|
Requires-Dist: pandas==2.2.3
|
|
14
14
|
Requires-Dist: polars==1.8.2
|
|
15
|
+
Requires-Dist: pyarrow==21.0.0
|
|
15
16
|
Requires-Dist: pydantic==2.11.4
|
|
16
17
|
Requires-Dist: requests==2.32.3
|
|
17
18
|
Requires-Dist: tqdm==4.67.1
|
sutro-0.1.28.dist-info/RECORD
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
|
|
2
|
+
sutro/cli.py,sha256=8DrJVbjoayCUz4iszlj35Tv1q1gzDVzx_CuF6gZHwuU,13636
|
|
3
|
+
sutro/sdk.py,sha256=0_KC8eOOEunEPChrpToIucbg6q-icyxSUeMwd1UKDSY,54150
|
|
4
|
+
sutro-0.1.28.dist-info/METADATA,sha256=Nu730OdGc5TMVMlw3_kqxZ03z8_3jsz57Og5PIhk7Wo,700
|
|
5
|
+
sutro-0.1.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
sutro-0.1.28.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
|
|
7
|
+
sutro-0.1.28.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
+
sutro-0.1.28.dist-info/RECORD,,
|
sutro-0.1.27.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
|
|
2
|
-
sutro/cli.py,sha256=6Qy9Vwaaho92HeO8YA_z1De4zp1dEFkSX3bEnLvdbkE,13203
|
|
3
|
-
sutro/sdk.py,sha256=zBnLIg9wvoAplGEVNdgfCsgnUKR48NDPVOB4rBSYR4M,51552
|
|
4
|
-
sutro-0.1.27.dist-info/METADATA,sha256=KO6cMAG83DpMnP0wf9B_9Fda1_zipO9Gqnfy-kn5HuM,669
|
|
5
|
-
sutro-0.1.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
sutro-0.1.27.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
|
|
7
|
-
sutro-0.1.27.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
-
sutro-0.1.27.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|