sutro 0.1.27__py3-none-any.whl → 0.1.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

sutro/cli.py CHANGED
@@ -141,7 +141,6 @@ def jobs():
     """Manage jobs."""
     pass
 
-
 @jobs.command()
 @click.option(
     "--all", is_flag=True, help="Include all jobs, including cancelled and failed ones."
@@ -360,6 +359,24 @@ def download(dataset_id, file_name=None, output_path=None):
     with open(output_path + "/" + file_name, "wb") as f:
         f.write(file)
 
+@cli.group()
+def cache():
+    """Manage the local job results cache."""
+    pass
+
+@cache.command()
+def clear():
+    """Clear the local job results cache."""
+    sdk = get_sdk()
+    sdk._clear_job_results_cache()
+    click.echo(Fore.GREEN + "Job results cache cleared." + Style.RESET_ALL)
+
+@cache.command()
+def show():
+    """Show the contents and size of the job results cache."""
+    sdk = get_sdk()
+    sdk._show_cache_contents()
+
 
 @cli.command()
 def docs():
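
The new `cache` group simply forwards to two private SDK helpers. As a minimal sketch, the commands can be exercised in-process with click's test runner, assuming the group is importable as `sutro.cli:cli` (per the wheel's entry point); exact output text may differ:

```python
# Hedged sketch: drive the new `cache` commands via click's CliRunner
# rather than a shell; assumes `cli` is exported from sutro.cli.
from click.testing import CliRunner

from sutro.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["cache", "show"])   # lists cached files and sizes
print(result.output)
result = runner.invoke(cli, ["cache", "clear"])  # removes ~/.sutro/job-results
print(result.output)
```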
@@ -395,6 +412,7 @@ def quotas():
         + Style.RESET_ALL
     )
 
+
 @jobs.command()
 @click.argument("job_id", required=False)
 @click.option("--latest", is_flag=True, help="Attach to the latest job.")
sutro/sdk.py CHANGED
@@ -17,6 +17,8 @@ from tqdm import tqdm
 import time
 from pydantic import BaseModel
 import json
+import pyarrow.parquet as pq
+import shutil
 
 
 class JobStatus(str, Enum):
@@ -813,12 +815,14 @@ class Sutro:
         return None
 
     def get_job_results(
-        self,
-        job_id: str,
-        include_inputs: bool = False,
-        include_cumulative_logprobs: bool = False,
-        with_original_df: pl.DataFrame | pd.DataFrame = None,
-        output_column: str = "inference_result",
+        self,
+        job_id: str,
+        include_inputs: bool = False,
+        include_cumulative_logprobs: bool = False,
+        with_original_df: pl.DataFrame | pd.DataFrame = None,
+        output_column: str = "inference_result",
+        disable_cache: bool = False,
+        unpack_json: bool = True,
     ):
         """
         Get the results of a job by its ID.
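
For orientation, a minimal call-site sketch using the two new parameters; the job ID is illustrative, and the assumption that the client is exported as `sutro.Sutro` and accepts an `api_key` keyword is ours:

```python
# Hedged usage sketch of the new get_job_results parameters.
from sutro import Sutro  # assumes the package exports the Sutro client

sdk = Sutro(api_key="YOUR_KEY")  # constructor signature assumed
df = sdk.get_job_results(
    "job-123",            # illustrative job ID
    disable_cache=True,   # skip reading/writing ~/.sutro/job-results
    unpack_json=False,    # keep raw JSON strings in the output column
)
```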
@@ -831,47 +835,72 @@ class Sutro:
             include_cumulative_logprobs (bool, optional): Whether to include the cumulative logprobs in the results. Defaults to False.
             with_original_df (pd.DataFrame | pl.DataFrame, optional): Original DataFrame to concatenate with results. Defaults to None.
             output_column (str, optional): Name of the output column. Defaults to "inference_result".
+            disable_cache (bool, optional): Whether to disable the cache. Defaults to False.
+            unpack_json (bool, optional): If the output_column is formatted as a JSON string, decides whether to unpack the top level JSON fields in the results into separate columns. Defaults to True.
 
         Returns:
             Union[pl.DataFrame, pd.DataFrame]: The results as a DataFrame. By default, returns polars.DataFrame; when with_original_df is an instance of pandas.DataFrame, returns pandas.DataFrame.
         """
-        endpoint = f"{self.base_url}/job-results"
-        payload = {
-            "job_id": job_id,
-            "include_inputs": include_inputs,
-            "include_cumulative_logprobs": include_cumulative_logprobs,
-        }
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-        with yaspin(
+
+
+        file_path = os.path.expanduser(f"~/.sutro/job-results/{job_id}.snappy.parquet")
+        expected_num_columns = 1 + include_inputs + include_cumulative_logprobs
+        contains_expected_columns = False
+        if os.path.exists(file_path):
+            num_columns = pq.read_table(file_path).num_columns
+            contains_expected_columns = num_columns == expected_num_columns
+
+        if disable_cache == False and contains_expected_columns:
+            with yaspin(
+                SPINNER,
+                text=to_colored_text(f"Loading results from cache: {file_path}"),
+                color=YASPIN_COLOR,
+            ) as spinner:
+                results_df = pl.read_parquet(file_path)
+                spinner.write(to_colored_text("✔ Results loaded from cache", state="success"))
+        else:
+            endpoint = f"{self.base_url}/job-results"
+            payload = {
+                "job_id": job_id,
+                "include_inputs": include_inputs,
+                "include_cumulative_logprobs": include_cumulative_logprobs,
+            }
+            headers = {
+                "Authorization": f"Key {self.api_key}",
+                "Content-Type": "application/json",
+            }
+            with yaspin(
                 SPINNER,
                 text=to_colored_text(f"Gathering results from job: {job_id}"),
                 color=YASPIN_COLOR,
-        ) as spinner:
-            response = requests.post(
-                endpoint, data=json.dumps(payload), headers=headers
-            )
-            if response.status_code != 200:
-                spinner.write(
-                    to_colored_text(
-                        f"Bad status code: {response.status_code}", state="fail"
-                    )
+            ) as spinner:
+                response = requests.post(
+                    endpoint, data=json.dumps(payload), headers=headers
                 )
-                spinner.stop()
-                print(to_colored_text(response.json(), state="fail"))
-                return None
+                if response.status_code != 200:
+                    spinner.write(
+                        to_colored_text(
+                            f"Bad status code: {response.status_code}", state="fail"
+                        )
+                    )
+                    spinner.stop()
+                    print(to_colored_text(response.json(), state="fail"))
+                    return None
 
-            spinner.write(
-                to_colored_text("✔ Job results retrieved", state="success")
-            )
+                spinner.write(
+                    to_colored_text("✔ Job results retrieved", state="success")
+                )
 
-            response_data = response.json()
-            results_df = pl.DataFrame(response_data["results"])
+                response_data = response.json()
+                results_df = pl.DataFrame(response_data["results"])
 
-            results_df = results_df.rename({'outputs': output_column})
+                results_df = results_df.rename({'outputs': output_column})
 
+                if disable_cache == False:
+                    os.makedirs(os.path.dirname(file_path), exist_ok=True)
+                    results_df.write_parquet(file_path, compression="snappy")
+                    spinner.write(to_colored_text("✔ Results saved to cache", state="success"))
+
         # Ordering inputs col first seems most logical/useful
         column_config = [
            ('inputs', include_inputs),
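
The cache hit test above trusts a cached parquet file only when its column count matches what the current flags would produce: one output column plus optional inputs and logprobs columns. A standalone sketch of that check (the helper name is ours, not part of the SDK), with a metadata-only read that avoids materializing the whole table the way `pq.read_table(...)` does:

```python
# Hedged sketch of the cache-validation idea; cache_is_usable is a
# hypothetical helper, not SDK API.
import os
import pyarrow.parquet as pq

def cache_is_usable(path: str, include_inputs: bool, include_logprobs: bool) -> bool:
    # bools coerce to 0/1, mirroring 1 + include_inputs + include_cumulative_logprobs
    expected = 1 + include_inputs + include_logprobs
    if not os.path.exists(path):
        return False
    # read only the parquet footer metadata instead of the full table
    return pq.ParquetFile(path).metadata.num_columns == expected
```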
@@ -884,6 +913,23 @@ class Sutro:
 
         results_df = results_df.select(columns_to_keep)
 
+        if unpack_json:
+            try:
+                first_row = json.loads(results_df.head(1)[output_column][0])  # checks if the first row can be json decoded
+                results_df = results_df.with_columns(
+                    pl.col(output_column).str.json_decode().alias("output_column_json_decoded")
+                )
+                json_decoded_fields = first_row.keys()
+                for field in json_decoded_fields:
+                    results_df = results_df.with_columns(
+                        pl.col("output_column_json_decoded").struct.field(field).alias(field)
+                    )
+                # drop the output_column and the json decoded column
+                results_df = results_df.drop([output_column, "output_column_json_decoded"])
+            except json.JSONDecodeError:
+                # if the first row cannot be json decoded, do nothing
+                pass
+
         # Handle concatenation with original DataFrame
         if with_original_df is not None:
             if isinstance(with_original_df, pd.DataFrame):
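
The unpacking step relies on polars' `str.json_decode` plus struct field extraction. A self-contained toy run of the same transformation, with illustrative column names and data:

```python
# Toy reproduction of the unpack_json step on made-up data.
import polars as pl

df = pl.DataFrame({
    "inference_result": [
        '{"label": "positive", "score": 0.97}',
        '{"label": "negative", "score": 0.12}',
    ]
})
# decode JSON strings into a struct column
df = df.with_columns(
    pl.col("inference_result").str.json_decode().alias("decoded")
)
# promote each top-level field to its own column (keys come from row 0 in the SDK)
for field in ("label", "score"):
    df = df.with_columns(pl.col("decoded").struct.field(field).alias(field))
df = df.drop(["inference_result", "decoded"])
print(df)  # two columns: label, score
```

Note that `df.unnest("decoded")` would expand every struct field in one call; the per-field loop shown here just mirrors what the SDK ships.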
@@ -1296,6 +1342,31 @@ class Sutro:
             time.sleep(POLL_INTERVAL)
 
         return results
+
+    def _clear_job_results_cache(self):  # only to be called by the CLI
+        """
+        Clears the job results cache.
+        """
+        if os.path.exists(os.path.expanduser("~/.sutro/job-results")):
+            shutil.rmtree(os.path.expanduser("~/.sutro/job-results"))
+
+    def _show_cache_contents(self):
+        """
+        Shows the contents and size of each file in the job results cache.
+        """
+        # get the size of the job-results directory
+        with yaspin(
+            SPINNER, text=to_colored_text("Retrieving job results cache contents"), color=YASPIN_COLOR
+        ) as spinner:
+            if not os.path.exists(os.path.expanduser("~/.sutro/job-results")):
+                spinner.write(to_colored_text("No job results cache found", "success"))
+                return
+            total_size = 0
+            for file in os.listdir(os.path.expanduser("~/.sutro/job-results")):
+                size = os.path.getsize(os.path.expanduser(f"~/.sutro/job-results/{file}")) / 1024 / 1024 / 1024
+                total_size += size
+                spinner.write(to_colored_text(f"File: {file} - Size: {size} GB"))
+            spinner.write(to_colored_text(f"Total size of results cache at ~/.sutro/job-results: {total_size} GB", "success"))
 
     def _await_job_start(self, job_id: str, timeout: Optional[int] = 7200):
         """
{sutro-0.1.27.dist-info → sutro-0.1.29.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sutro
-Version: 0.1.27
+Version: 0.1.29
 Summary: Sutro Python SDK
 Project-URL: Homepage, https://sutro.sh
 Project-URL: Documentation, https://docs.sutro.sh
@@ -12,6 +12,7 @@ Requires-Dist: colorama==0.4.4
 Requires-Dist: numpy==2.1.1
 Requires-Dist: pandas==2.2.3
 Requires-Dist: polars==1.8.2
+Requires-Dist: pyarrow==21.0.0
 Requires-Dist: pydantic==2.11.4
 Requires-Dist: requests==2.32.3
 Requires-Dist: tqdm==4.67.1
{sutro-0.1.27.dist-info → sutro-0.1.29.dist-info}/RECORD CHANGED
@@ -0,0 +1,8 @@
+sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
+sutro/cli.py,sha256=8DrJVbjoayCUz4iszlj35Tv1q1gzDVzx_CuF6gZHwuU,13636
+sutro/sdk.py,sha256=4cdRclI_meTrp26yhVeJDMlbWrqEb5Lwufi5d8WTmh0,55357
+sutro-0.1.29.dist-info/METADATA,sha256=a6c-nO9s4x3a2cT99JuGIySRgqMzGlWA6iUEXtnGvow,700
+sutro-0.1.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+sutro-0.1.29.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
+sutro-0.1.29.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+sutro-0.1.29.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
-sutro/cli.py,sha256=6Qy9Vwaaho92HeO8YA_z1De4zp1dEFkSX3bEnLvdbkE,13203
-sutro/sdk.py,sha256=zBnLIg9wvoAplGEVNdgfCsgnUKR48NDPVOB4rBSYR4M,51552
-sutro-0.1.27.dist-info/METADATA,sha256=KO6cMAG83DpMnP0wf9B_9Fda1_zipO9Gqnfy-kn5HuM,669
-sutro-0.1.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-sutro-0.1.27.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
-sutro-0.1.27.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-sutro-0.1.27.dist-info/RECORD,,