sutro 0.1.29__tar.gz → 0.1.31__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sutro
-Version: 0.1.29
+Version: 0.1.31
 Summary: Sutro Python SDK
 Project-URL: Homepage, https://sutro.sh
 Project-URL: Documentation, https://docs.sutro.sh
@@ -17,6 +17,8 @@ Requires-Dist: pydantic==2.11.4
 Requires-Dist: requests==2.32.3
 Requires-Dist: tqdm==4.67.1
 Requires-Dist: yaspin==3.1.0
+Provides-Extra: dev
+Requires-Dist: ruff==0.13.1; extra == 'dev'
 Description-Content-Type: text/markdown

 # sutro-client
@@ -9,7 +9,7 @@ installer = "uv"

 [project]
 name = "sutro"
-version = "0.1.29"
+version = "0.1.31"
 description = "Sutro Python SDK"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -27,6 +27,11 @@ dependencies = [
     "pyarrow==21.0.0",
 ]

+[project.optional-dependencies]
+dev = [
+    "ruff==0.13.1"
+]
+
 [project.scripts]
 sutro = "sutro.cli:cli"

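Note: the new dev extra only pins ruff, the formatter/linter whose reformatting accounts for most of the churn below. Assuming the standard extras syntax (an inference, not something the diff states), it would be installed with pip install "sutro[dev]".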
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import timezone
 import click
 from colorama import Fore, Style
 import os
@@ -35,9 +35,7 @@ def check_auth():
 def get_sdk():
     config = load_config()
     if config.get("base_url") != None:
-        return Sutro(
-            api_key=config.get("api_key"), base_url=config.get("base_url")
-        )
+        return Sutro(api_key=config.get("api_key"), base_url=config.get("base_url"))
     else:
         return Sutro(api_key=config.get("api_key"))

@@ -141,6 +139,7 @@ def jobs():
     """Manage jobs."""
     pass

+
 @jobs.command()
 @click.option(
     "--all", is_flag=True, help="Include all jobs, including cancelled and failed ones."
@@ -359,11 +358,13 @@ def download(dataset_id, file_name=None, output_path=None):
     with open(output_path + "/" + file_name, "wb") as f:
         f.write(file)

+
 @cli.group()
 def cache():
     """Manage the local job results cache."""
     pass

+
 @cache.command()
 def clear():
     """Clear the local job results cache."""
@@ -371,6 +372,7 @@ def clear():
     sdk._clear_job_results_cache()
     click.echo(Fore.GREEN + "Job results cache cleared." + Style.RESET_ALL)

+
 @cache.command()
 def show():
     """Show the contents and size of the job results cache."""
@@ -1,22 +1,17 @@
-import threading
-from concurrent.futures import ThreadPoolExecutor
-from contextlib import contextmanager
 from enum import Enum
-
 import requests
 import pandas as pd
 import polars as pl
 import json
-from typing import Union, List, Optional, Literal, Generator, Dict, Any
+from typing import Union, List, Optional, Literal, Dict, Any
 import os
 import sys
 from yaspin import yaspin
 from yaspin.spinners import Spinners
-from colorama import init, Fore, Back, Style
+from colorama import init, Fore, Style
 from tqdm import tqdm
 import time
 from pydantic import BaseModel
-import json
 import pyarrow.parquet as pq
 import shutil

@@ -45,6 +40,7 @@ class JobStatus(str, Enum):
     def is_terminal(self) -> bool:
         return self in self.terminal_statuses()

+
 # Initialize colorama (required for Windows)
 init()

@@ -71,8 +67,13 @@ ModelOptions = Literal[
     "qwen-3-32b-thinking",
     "gemma-3-4b-it",
     "gemma-3-27b-it",
-    "multilingual-e5-large-instruct",
-    "gte-qwen2-7b-instruct",
+    "gpt-oss-120b",
+    "gpt-oss-20b",
+    "qwen-3-235b-a22b-thinking",
+    "qwen-3-30b-a3b-thinking",
+    "qwen-3-embedding-0.6b",
+    "qwen-3-embedding-6b",
+    "qwen-3-embedding-8b",
 ]

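The two embedding entries (multilingual-e5-large-instruct, gte-qwen2-7b-instruct) are dropped in favor of the gpt-oss pair, two more Qwen 3 thinking variants, and three Qwen 3 embedding sizes. A minimal sketch of selecting one of the new names (the import path and key handling are assumptions; infer and its model parameter appear later in this diff):

from sutro import Sutro

so = Sutro(api_key="YOUR_API_KEY")  # assumed setup; an omitted key falls back to check_for_api_key()
res = so.infer(
    ["Summarize: the quick brown fox jumps over the lazy dog."],
    model="gpt-oss-120b",  # one of the ModelOptions added in 0.1.31
)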
@@ -99,6 +100,7 @@ def to_colored_text(
     # Default to blue for normal/processing states
     return f"{Fore.BLUE}{text}{Style.RESET_ALL}"

+
 # Isn't fully support in all terminals unfortunately. We should switch to Rich
 # at some point, but even Rich links aren't clickable on MacOS Terminal
 def make_clickable_link(url, text=None):
@@ -110,10 +112,9 @@ def make_clickable_link(url, text=None):
         text = url
     return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"

+
 class Sutro:
-    def __init__(
-        self, api_key: str = None, base_url: str = "https://api.sutro.sh/"
-    ):
+    def __init__(self, api_key: str = None, base_url: str = "https://api.sutro.sh/"):
         self.api_key = api_key or self.check_for_api_key()
         self.base_url = base_url

@@ -220,7 +221,7 @@ class Sutro:
         cost_estimate: bool,
         stay_attached: Optional[bool],
         random_seed_per_input: bool,
-        truncate_rows: bool
+        truncate_rows: bool,
     ):
         input_data = self.handle_data_helper(data, column)
         endpoint = f"{self.base_url}/batch-inference"
@@ -237,7 +238,7 @@ class Sutro:
             "cost_estimate": cost_estimate,
             "sampling_params": sampling_params,
             "random_seed_per_input": random_seed_per_input,
-            "truncate_rows": truncate_rows
+            "truncate_rows": truncate_rows,
         }

         # There are two gotchas with yaspin:
@@ -266,13 +267,20 @@ class Sutro:
                 job_id = response_data["results"]
                 if cost_estimate:
                     spinner.write(
-                        to_colored_text(f"Awaiting cost estimates with job ID: {job_id}. You can safely detach and retrieve the cost estimates later.")
+                        to_colored_text(
+                            f"Awaiting cost estimates with job ID: {job_id}. You can safely detach and retrieve the cost estimates later."
+                        )
                     )
                     spinner.stop()
-                    self.await_job_completion(job_id, obtain_results=False, is_cost_estimate=True)
+                    self.await_job_completion(
+                        job_id, obtain_results=False, is_cost_estimate=True
+                    )
                     cost_estimate = self._get_job_cost_estimate(job_id)
                     spinner.write(
-                        to_colored_text(f"✔ Cost estimates retrieved for job {job_id}: ${cost_estimate}", state="success")
+                        to_colored_text(
+                            f"✔ Cost estimates retrieved for job {job_id}: ${cost_estimate}",
+                            state="success",
+                        )
                     )
                     return job_id
                 else:
@@ -283,12 +291,14 @@ class Sutro:
                         )
                     )
                     if not stay_attached:
-                        clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
+                        clickable_link = make_clickable_link(
+                            f"https://app.sutro.sh/jobs/{job_id}"
+                        )
                         spinner.write(
                             to_colored_text(
                                 f"Use `so.get_job_status('{job_id}')` to check the status of the job, or monitor progress at {clickable_link}"
-                            )
                             )
+                        )
                         return job_id
         except KeyboardInterrupt:
             pass
@@ -298,22 +308,32 @@ class Sutro:

         success = False
         if stay_attached and job_id is not None:
-            spinner.write(to_colored_text("Awaiting job start...", ))
-            clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
-            spinner.write(to_colored_text(f'Progress can also be monitored at: {clickable_link}'))
+            spinner.write(
+                to_colored_text(
+                    "Awaiting job start...",
+                )
+            )
+            clickable_link = make_clickable_link(f"https://app.sutro.sh/jobs/{job_id}")
+            spinner.write(
+                to_colored_text(f"Progress can also be monitored at: {clickable_link}")
+            )
             started = self._await_job_start(job_id)
             if not started:
                 failure_reason = self._get_failure_reason(job_id)
-                spinner.write(to_colored_text(f"Failure reason: {failure_reason['message']}", "fail"))
+                spinner.write(
+                    to_colored_text(
+                        f"Failure reason: {failure_reason['message']}", "fail"
+                    )
+                )
                 return None
             s = requests.Session()
             pbar = None

             try:
                 with requests.get(
-                        f"{self.base_url}/stream-job-progress/{job_id}",
-                        headers=headers,
-                        stream=True,
+                    f"{self.base_url}/stream-job-progress/{job_id}",
+                    headers=headers,
+                    stream=True,
                 ) as streaming_response:
                     streaming_response.raise_for_status()
                     spinner = yaspin(
@@ -324,9 +344,9 @@ class Sutro:
                     spinner.start()

                     token_state = {
-                        'input_tokens': 0,
-                        'output_tokens': 0,
-                        'total_tokens_processed_per_second': 0
+                        "input_tokens": 0,
+                        "output_tokens": 0,
+                        "total_tokens_processed_per_second": 0,
                     }

                     for line in streaming_response.iter_lines():
@@ -340,7 +360,7 @@ class Sutro:
                             if json_obj["update_type"] == "progress":
                                 if pbar is None:
                                     spinner.stop()
-                                    postfix = f"Input tokens processed: 0"
+                                    postfix = "Input tokens processed: 0"
                                     pbar = self.fancy_tqdm(
                                         total=len(input_data),
                                         desc="Progress",
@@ -357,7 +377,8 @@ class Sutro:
                             # Currently, the way the progress stream endpoint is defined,
                             # its possible to have updates come in that only have 1 or 2 fields
                             new = {
-                                k: v for k, v in json_obj.get('result', {}).items()
+                                k: v
+                                for k, v in json_obj.get("result", {}).items()
                                 if k in token_state and v >= token_state[k]
                             }
                             token_state.update(new)
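The reworked comprehension makes the intent explicit: only counters already tracked in token_state are accepted, and only when they do not regress. A self-contained sketch of the same update rule (the sample payload is invented for illustration):

token_state = {
    "input_tokens": 0,
    "output_tokens": 0,
    "total_tokens_processed_per_second": 0,
}
json_obj = {"result": {"input_tokens": 120, "output_tokens": 40, "unrelated": 1}}
new = {
    k: v
    for k, v in json_obj.get("result", {}).items()
    if k in token_state and v >= token_state[k]
}
token_state.update(new)  # unknown keys and decreasing values are discarded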
@@ -388,8 +409,8 @@ class Sutro:
             # TODO: we implment retries in cases where the job hasn't written results yet
             # it would be better if we could receive a fully succeeded status from the job
             # and not have such a race condition
-            max_retries = 20 # winds up being 100 seconds cumulative delay
-            retry_delay = 5 # initial delay in seconds
+            max_retries = 20  # winds up being 100 seconds cumulative delay
+            retry_delay = 5  # initial delay in seconds

             for _ in range(max_retries):
                 time.sleep(retry_delay)
@@ -436,7 +457,7 @@ class Sutro:
     def infer(
         self,
         data: Union[List, pd.DataFrame, pl.DataFrame, str],
-        model: Union[ModelOptions, List[ModelOptions]] = "llama-3.1-8b",
+        model: Union[ModelOptions, List[ModelOptions]] = "gemma-3-12b-it",
         column: str = None,
         output_column: str = "inference_result",
         job_priority: int = 0,
@@ -446,7 +467,7 @@ class Sutro:
         dry_run: bool = False,
         stay_attached: Optional[bool] = None,
         random_seed_per_input: bool = False,
-        truncate_rows: bool = False
+        truncate_rows: bool = False,
     ):
         """
         Run inference on the provided data.
@@ -475,22 +496,29 @@ class Sutro:
         """
         if isinstance(model, list) == False:
             model_list = [model]
-            stay_attached = stay_attached if stay_attached is not None else job_priority == 0
+            stay_attached = (
+                stay_attached if stay_attached is not None else job_priority == 0
+            )
         else:
             model_list = model
             stay_attached = False

         # Convert BaseModel to dict if needed
         if output_schema is not None:
-            if hasattr(output_schema, 'model_json_schema'): # Check for pydantic Model interface
+            if hasattr(
+                output_schema, "model_json_schema"
+            ):  # Check for pydantic Model interface
                 json_schema = output_schema.model_json_schema()
             elif isinstance(output_schema, dict):
                 json_schema = output_schema
             else:
-                raise ValueError("Invalid output schema type. Must be a dictionary or a pydantic Model.")
+                raise ValueError(
+                    "Invalid output schema type. Must be a dictionary or a pydantic Model."
+                )
         else:
             json_schema = None

+        results = []
         for model in model_list:
             res = self._run_one_batch_inference(
                 data,
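The reformatted check above accepts either a pydantic model class or a plain dict JSON schema. A hedged sketch of both call shapes (the Sentiment model and prompts are illustrative; output_schema is the parameter validated above):

from pydantic import BaseModel

class Sentiment(BaseModel):  # illustrative schema
    label: str
    score: float

res = so.infer(prompts, output_schema=Sentiment)  # pydantic path: model_json_schema() is used
res = so.infer(prompts, output_schema=Sentiment.model_json_schema())  # equivalent dict path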
@@ -504,11 +532,16 @@ class Sutro:
                 dry_run,
                 stay_attached,
                 random_seed_per_input,
-                truncate_rows
+                truncate_rows,
             )
-            if stay_attached:
-                return res
+            results.append(res)
+
+        if len(results) > 1:
+            return results
+        elif len(results) == 1:
+            return results[0]

+        return None

     def attach(self, job_id):
         """
@@ -530,9 +563,9 @@ class Sutro:
         }

         with yaspin(
-                SPINNER,
-                text=to_colored_text("Looking for job..."),
-                color=YASPIN_COLOR,
+            SPINNER,
+            text=to_colored_text("Looking for job..."),
+            color=YASPIN_COLOR,
         ) as spinner:
             # Fetch the specific job we want to attach to
             job = self._fetch_job(job_id)
@@ -550,10 +583,14 @@ class Sutro:
                     )
                     return
                 case "FAILED":
-                    spinner.write(to_colored_text("❌ Job is in failed state.", state="fail"))
+                    spinner.write(
+                        to_colored_text("❌ Job is in failed state.", state="fail")
+                    )
                     return
                 case "CANCELLED":
-                    spinner.write(to_colored_text("❌ Job was cancelled.", state="fail"))
+                    spinner.write(
+                        to_colored_text("❌ Job was cancelled.", state="fail")
+                    )
                     return
                 case _:
                     spinner.write(to_colored_text("✔ Job found!", state="success"))
@@ -563,9 +600,9 @@ class Sutro:

             try:
                 with s.get(
-                        f"{self.base_url}/stream-job-progress/{job_id}",
-                        headers=headers,
-                        stream=True,
+                    f"{self.base_url}/stream-job-progress/{job_id}",
+                    headers=headers,
+                    stream=True,
                 ) as streaming_response:
                     streaming_response.raise_for_status()
                     spinner = yaspin(
@@ -573,8 +610,14 @@ class Sutro:
                         text=to_colored_text("Awaiting status updates..."),
                         color=YASPIN_COLOR,
                     )
-                    clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
-                    spinner.write(to_colored_text(f'Progress can also be monitored at: {clickable_link}'))
+                    clickable_link = make_clickable_link(
+                        f"https://app.sutro.sh/jobs/{job_id}"
+                    )
+                    spinner.write(
+                        to_colored_text(
+                            f"Progress can also be monitored at: {clickable_link}"
+                        )
+                    )
                     spinner.start()
                     for line in streaming_response.iter_lines():
                         if line:
@@ -587,7 +630,7 @@ class Sutro:
                             if json_obj["update_type"] == "progress":
                                 if pbar is None:
                                     spinner.stop()
-                                    postfix = f"Input tokens processed: 0"
+                                    postfix = "Input tokens processed: 0"
                                     pbar = self.fancy_tqdm(
                                         total=total_rows,
                                         desc="Progress",
@@ -621,8 +664,6 @@ class Sutro:
         if spinner:
             spinner.stop()

-
-
     def fancy_tqdm(
         self,
         total: int,
@@ -738,7 +779,7 @@ class Sutro:
         response = requests.get(endpoint, headers=headers)
         if response.status_code != 200:
             return None
-        return response.json().get('job')
+        return response.json().get("job")

     def _get_job_cost_estimate(self, job_id: str):
         """
@@ -748,8 +789,8 @@ class Sutro:
         if not job:
             return None

-        return job.get('cost_estimate')
-
+        return job.get("cost_estimate")
+
     def _get_failure_reason(self, job_id: str):
         """
         Get the failure reason for a job.
@@ -757,7 +798,7 @@ class Sutro:
         job = self._fetch_job(job_id)
         if not job:
             return None
-        return job.get('failure_reason')
+        return job.get("failure_reason")

     def _fetch_job_status(self, job_id: str):
         """
@@ -796,13 +837,15 @@ class Sutro:
             str: The status of the job.
         """
         with yaspin(
-                SPINNER,
-                text=to_colored_text(f"Checking job status with ID: {job_id}"),
-                color=YASPIN_COLOR,
+            SPINNER,
+            text=to_colored_text(f"Checking job status with ID: {job_id}"),
+            color=YASPIN_COLOR,
         ) as spinner:
             try:
                 response_data = self._fetch_job_status(job_id)
-                spinner.write(to_colored_text("✔ Job status retrieved!", state="success"))
+                spinner.write(
+                    to_colored_text("✔ Job status retrieved!", state="success")
+                )
                 return response_data["job_status"][job_id]
             except requests.HTTPError as e:
                 spinner.write(
@@ -842,14 +885,13 @@ class Sutro:
             Union[pl.DataFrame, pd.DataFrame]: The results as a DataFrame. By default, returns polars.DataFrame; when with_original_df is an instance of pandas.DataFrame, returns pandas.DataFrame.
         """

-
         file_path = os.path.expanduser(f"~/.sutro/job-results/{job_id}.snappy.parquet")
         expected_num_columns = 1 + include_inputs + include_cumulative_logprobs
         contains_expected_columns = False
         if os.path.exists(file_path):
             num_columns = pq.read_table(file_path).num_columns
             contains_expected_columns = num_columns == expected_num_columns
-
+
         if disable_cache == False and contains_expected_columns:
             with yaspin(
                 SPINNER,
@@ -857,7 +899,9 @@ class Sutro:
                 color=YASPIN_COLOR,
             ) as spinner:
                 results_df = pl.read_parquet(file_path)
-                spinner.write(to_colored_text("✔ Results loaded from cache", state="success"))
+                spinner.write(
+                    to_colored_text("✔ Results loaded from cache", state="success")
+                )
         else:
             endpoint = f"{self.base_url}/job-results"
             payload = {
@@ -894,38 +938,51 @@ class Sutro:
                 response_data = response.json()
                 results_df = pl.DataFrame(response_data["results"])

-                results_df = results_df.rename({'outputs': output_column})
+                results_df = results_df.rename({"outputs": output_column})

                 if disable_cache == False:
                     os.makedirs(os.path.dirname(file_path), exist_ok=True)
                     results_df.write_parquet(file_path, compression="snappy")
-                    spinner.write(to_colored_text("✔ Results saved to cache", state="success"))
-
+                    spinner.write(
+                        to_colored_text("✔ Results saved to cache", state="success")
+                    )
+
         # Ordering inputs col first seems most logical/useful
         column_config = [
-            ('inputs', include_inputs),
+            ("inputs", include_inputs),
             (output_column, True),
-            ('cumulative_logprobs', include_cumulative_logprobs),
+            ("cumulative_logprobs", include_cumulative_logprobs),
         ]

-        columns_to_keep = [col for col, include in column_config
-                           if include and col in results_df.columns]
+        columns_to_keep = [
+            col
+            for col, include in column_config
+            if include and col in results_df.columns
+        ]

         results_df = results_df.select(columns_to_keep)

         if unpack_json:
             try:
-                first_row = json.loads(results_df.head(1)[output_column][0]) # checks if the first row can be json decoded
+                first_row = json.loads(
+                    results_df.head(1)[output_column][0]
+                )  # checks if the first row can be json decoded
                 results_df = results_df.with_columns(
-                    pl.col(output_column).str.json_decode().alias("output_column_json_decoded")
+                    pl.col(output_column)
+                    .str.json_decode()
+                    .alias("output_column_json_decoded")
                 )
                 json_decoded_fields = first_row.keys()
                 for field in json_decoded_fields:
                     results_df = results_df.with_columns(
-                        pl.col("output_column_json_decoded").struct.field(field).alias(field)
+                        pl.col("output_column_json_decoded")
+                        .struct.field(field)
+                        .alias(field)
                     )
                 # drop the output_column and the json decoded column
-                results_df = results_df.drop([output_column, "output_column_json_decoded"])
+                results_df = results_df.drop(
+                    [output_column, "output_column_json_decoded"]
+                )
             except json.JSONDecodeError:
                 # if the first row cannot be json decoded, do nothing
                 pass
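The JSON-unpacking branch can be read in isolation. A runnable sketch of the same polars transformation (assumes a recent polars with str.json_decode; the sample frame stands in for real job results):

import json
import polars as pl

results_df = pl.DataFrame({"inference_result": ['{"label": "pos", "score": 0.9}']})
first_row = json.loads(results_df.head(1)["inference_result"][0])  # probe the first row
results_df = results_df.with_columns(
    pl.col("inference_result").str.json_decode().alias("output_column_json_decoded")
)
for field in first_row.keys():  # promote each JSON field to its own column
    results_df = results_df.with_columns(
        pl.col("output_column_json_decoded").struct.field(field).alias(field)
    )
results_df = results_df.drop(["inference_result", "output_column_json_decoded"])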
@@ -1011,7 +1068,9 @@ class Sutro:
                 return
             dataset_id = response.json()["dataset_id"]
             spinner.write(
-                to_colored_text(f"✔ Dataset created with ID: {dataset_id}", state="success")
+                to_colored_text(
+                    f"✔ Dataset created with ID: {dataset_id}", state="success"
+                )
             )
             return dataset_id

@@ -1079,8 +1138,7 @@ class Sutro:
                     "dataset_id": dataset_id,
                 }

-                headers = {
-                    "Authorization": f"Key {self.api_key}"}
+                headers = {"Authorization": f"Key {self.api_key}"}

                 count += 1
                 spinner.write(
@@ -1164,7 +1222,9 @@ class Sutro:
                 print(to_colored_text(f"Error: {response.json()}", state="fail"))
                 return
             spinner.write(
-                to_colored_text(f"✔ Files listed in dataset: {dataset_id}", state="success")
+                to_colored_text(
+                    f"✔ Files listed in dataset: {dataset_id}", state="success"
+                )
             )
             return response.json()["files"]

@@ -1286,7 +1346,13 @@ class Sutro:
             return
         return response.json()["quotas"]

-    def await_job_completion(self, job_id: str, timeout: Optional[int] = 7200, obtain_results: bool = True, is_cost_estimate: bool=False) -> list | None:
+    def await_job_completion(
+        self,
+        job_id: str,
+        timeout: Optional[int] = 7200,
+        obtain_results: bool = True,
+        is_cost_estimate: bool = False,
+    ) -> list | None:
         """
         Waits for job completion to occur and then returns the results upon
         a successful completion.
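A hedged sketch of the reshaped signature in use, mirroring how the submit path now awaits cost estimates (_get_job_cost_estimate is the internal helper shown earlier in this diff):

so.await_job_completion(job_id, obtain_results=False, is_cost_estimate=True)
estimate = so._get_job_cost_estimate(job_id)
results = so.await_job_completion(job_id)  # default: block until done, then fetch results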
@@ -1308,8 +1374,14 @@ class Sutro:
             SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
         ) as spinner:
             if not is_cost_estimate:
-                clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
-                spinner.write(to_colored_text(f'Progress can also be monitored at: {clickable_link}'))
+                clickable_link = make_clickable_link(
+                    f"https://app.sutro.sh/jobs/{job_id}"
+                )
+                spinner.write(
+                    to_colored_text(
+                        f"Progress can also be monitored at: {clickable_link}"
+                    )
+                )
             while (time.time() - start_time) < timeout:
                 try:
                     status = self._fetch_job_status(job_id)
@@ -1326,9 +1398,13 @@ class Sutro:
                 spinner.text = to_colored_text(f"Job status is {status} for {job_id}")

                 if status == JobStatus.SUCCEEDED:
-                    spinner.stop() # Stop this spinner as `get_job_results` has its own spinner text
+                    spinner.stop()  # Stop this spinner as `get_job_results` has its own spinner text
                     if obtain_results:
-                        spinner.write(to_colored_text("Job completed! Retrieving results...", "success"))
+                        spinner.write(
+                            to_colored_text(
+                                "Job completed! Retrieving results...", "success"
+                            )
+                        )
                         results = self.get_job_results(job_id)
                     break
                 if status == JobStatus.FAILED:
@@ -1338,12 +1414,11 @@ class Sutro:
                     spinner.write(to_colored_text("Job has been cancelled"))
                     return None

-
                 time.sleep(POLL_INTERVAL)

         return results

-    def _clear_job_results_cache(self): # only to be called by the CLI
+    def _clear_job_results_cache(self):  # only to be called by the CLI
         """
         Clears the cache for a job results.
         """
@@ -1356,29 +1431,41 @@ class Sutro:
         """
         # get the size of the job-results directory
         with yaspin(
-            SPINNER, text=to_colored_text("Retrieving job results cache contents"), color=YASPIN_COLOR
+            SPINNER,
+            text=to_colored_text("Retrieving job results cache contents"),
+            color=YASPIN_COLOR,
         ) as spinner:
             if not os.path.exists(os.path.expanduser("~/.sutro/job-results")):
                 spinner.write(to_colored_text("No job results cache found", "success"))
                 return
             total_size = 0
             for file in os.listdir(os.path.expanduser("~/.sutro/job-results")):
-                size = os.path.getsize(os.path.expanduser(f"~/.sutro/job-results/{file}")) / 1024 / 1024 / 1024
+                size = (
+                    os.path.getsize(os.path.expanduser(f"~/.sutro/job-results/{file}"))
+                    / 1024
+                    / 1024
+                    / 1024
+                )
                 total_size += size
                 spinner.write(to_colored_text(f"File: {file} - Size: {size} GB"))
-            spinner.write(to_colored_text(f"Total size of results cache at ~/.sutro/job-results: {total_size} GB", "success"))
-
+            spinner.write(
+                to_colored_text(
+                    f"Total size of results cache at ~/.sutro/job-results: {total_size} GB",
+                    "success",
+                )
+            )
+
     def _await_job_start(self, job_id: str, timeout: Optional[int] = 7200):
         """
         Waits for job start to occur and then returns the results upon
         a successful start.
-
+
         """
         POLL_INTERVAL = 5

         start_time = time.time()
         with yaspin(
-            SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
+            SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
         ) as spinner:
             while (time.time() - start_time) < timeout:
                 try:
@@ -1405,4 +1492,3 @@ class Sutro:
             time.sleep(POLL_INTERVAL)

         return False
-
4 files without changes.