sutro 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sutro might be problematic. Click here for more details.

sutro/sdk.py CHANGED
@@ -97,6 +97,16 @@ def to_colored_text(
97
97
  # Default to blue for normal/processing states
98
98
  return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
99
99
 
100
+ # Isn't fully support in all terminals unfortunately. We should switch to Rich
101
+ # at some point, but even Rich links aren't clickable on MacOS Terminal
102
+ def make_clickable_link(url, text=None):
103
+ """
104
+ Create a clickable link for terminals that support OSC 8 hyperlinks.
105
+ Falls back to plain text for terminals that don't support it.
106
+ """
107
+ if text is None:
108
+ text = url
109
+ return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
100
110
 
101
111
  class Sutro:
102
112
  def __init__(
@@ -266,11 +276,6 @@ class Sutro:
266
276
  "random_seed_per_input": random_seed_per_input,
267
277
  "truncate_rows": truncate_rows
268
278
  }
269
- if dry_run:
270
- spinner_text = to_colored_text("Retrieving cost estimates...")
271
- else:
272
- t = f"Creating priority {job_priority} job"
273
- spinner_text = to_colored_text(t)
274
279
 
275
280
  # There are two gotchas with yaspin:
276
281
  # 1. Can't use print while in spinner is running
@@ -279,6 +284,8 @@ class Sutro:
279
284
  # Terminal size {self._terminal_width} is too small to display spinner with the given settings.
280
285
  # https://github.com/pavdmyt/yaspin/blob/9c7430b499ab4611888ece39783a870e4a05fa45/yaspin/core.py#L568-L571
281
286
  job_id = None
287
+ t = f"Creating {'[dry run] ' if dry_run else ''}priority {job_priority} job"
288
+ spinner_text = to_colored_text(t)
282
289
  with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
283
290
  response = requests.post(
284
291
  endpoint, data=json.dumps(payload), headers=headers
@@ -292,16 +299,22 @@ class Sutro:
292
299
  print(to_colored_text(response.json(), state="fail"))
293
300
  return None
294
301
  else:
302
+ job_id = response_data["results"]
295
303
  if dry_run:
296
304
  spinner.write(
297
- to_colored_text(" Cost estimates retrieved", state="success")
305
+ to_colored_text(f"Awaiting cost estimates with job ID: {job_id}. You can safely detach and retrieve the cost estimates later.", state="info")
306
+ )
307
+ spinner.stop()
308
+ self.await_job_completion(job_id, obtain_results=False)
309
+ cost_estimate = self._get_job_cost_estimate(job_id)
310
+ spinner.write(
311
+ to_colored_text(f"✔ Cost estimates retrieved for job {job_id}: ${cost_estimate}", state="success")
298
312
  )
299
- return response_data["results"]
313
+ return job_id
300
314
  else:
301
- job_id = response_data["results"]
302
315
  spinner.write(
303
316
  to_colored_text(
304
- f"🛠️ Priority {job_priority} Job created with ID: {job_id}",
317
+ f"🛠 Priority {job_priority} Job created with ID: {job_id}.",
305
318
  state="success",
306
319
  )
307
320
  )
@@ -309,12 +322,20 @@ class Sutro:
309
322
  spinner.write(
310
323
  to_colored_text(
311
324
  f"Use `so.get_job_status('{job_id}')` to check the status of the job."
325
+ )
312
326
  )
313
- )
314
327
  return job_id
315
328
 
316
329
  success = False
317
330
  if stay_attached and job_id is not None:
331
+ spinner.write(to_colored_text("Awaiting job start...", ))
332
+ spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
333
+ started = self._await_job_start(job_id)
334
+ if not started:
335
+ failure_reason = self._get_failure_reason(job_id)
336
+ spinner.write(to_colored_text(f"Failure reason: {failure_reason['message']}", "fail"))
337
+ return None
338
+
318
339
  s = requests.Session()
319
340
  payload = {
320
341
  "job_id": job_id,
@@ -589,6 +610,7 @@ class Sutro:
589
610
  text=to_colored_text("Awaiting status updates..."),
590
611
  color=YASPIN_COLOR,
591
612
  )
613
+ spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
592
614
  spinner.start()
593
615
  for line in streaming_response.iter_lines():
594
616
  if line:
@@ -719,6 +741,40 @@ class Sutro:
719
741
  return
720
742
  return response.json()["jobs"]
721
743
 
744
+ def _list_jobs_helper(self):
745
+ """
746
+ Helper function to list jobs.
747
+ """
748
+ endpoint = f"{self.base_url}/list-jobs"
749
+ headers = {
750
+ "Authorization": f"Key {self.api_key}",
751
+ "Content-Type": "application/json",
752
+ }
753
+ response = requests.get(endpoint, headers=headers)
754
+ if response.status_code != 200:
755
+ return None
756
+ return response.json()["jobs"]
757
+
758
+ def _get_job_cost_estimate(self, job_id: str):
759
+ """
760
+ Get the cost estimate for a job.
761
+ """
762
+ all_jobs = self._list_jobs_helper()
763
+ for job in all_jobs:
764
+ if job["job_id"] == job_id:
765
+ return job["cost_estimate"]
766
+ return None
767
+
768
+ def _get_failure_reason(self, job_id: str):
769
+ """
770
+ Get the failure reason for a job.
771
+ """
772
+ all_jobs = self._list_jobs_helper()
773
+ for job in all_jobs:
774
+ if job["job_id"] == job_id:
775
+ return job["failure_reason"]
776
+ return None
777
+
722
778
  def _fetch_job_status(self, job_id: str):
723
779
  """
724
780
  Core logic to fetch job status from the API.
@@ -775,10 +831,12 @@ class Sutro:
775
831
  return None
776
832
 
777
833
  def get_job_results(
778
- self,
779
- job_id: str,
780
- include_inputs: bool = False,
781
- include_cumulative_logprobs: bool = False,
834
+ self,
835
+ job_id: str,
836
+ include_inputs: bool = False,
837
+ include_cumulative_logprobs: bool = False,
838
+ with_original_df: pl.DataFrame | pd.DataFrame = None,
839
+ output_column: str = "inference_result",
782
840
  ):
783
841
  """
784
842
  Get the results of a job by its ID.
@@ -789,9 +847,11 @@ class Sutro:
789
847
  job_id (str): The ID of the job to retrieve the results for.
790
848
  include_inputs (bool, optional): Whether to include the inputs in the results. Defaults to False.
791
849
  include_cumulative_logprobs (bool, optional): Whether to include the cumulative logprobs in the results. Defaults to False.
850
+ with_original_df (pd.DataFrame | pl.DataFrame, optional): Original DataFrame to concatenate with results. Defaults to None.
851
+ output_column (str, optional): Name of the output column. Defaults to "inference_result".
792
852
 
793
853
  Returns:
794
- list: The results of the job.
854
+ Union[pl.DataFrame, pd.DataFrame]: The results as a DataFrame. By default, returns polars.DataFrame; when with_original_df is an instance of pandas.DataFrame, returns pandas.DataFrame.
795
855
  """
796
856
  endpoint = f"{self.base_url}/job-results"
797
857
  payload = {
@@ -804,18 +864,14 @@ class Sutro:
804
864
  "Content-Type": "application/json",
805
865
  }
806
866
  with yaspin(
807
- SPINNER,
808
- text=to_colored_text(f"Gathering results from job: {job_id}"),
809
- color=YASPIN_COLOR,
867
+ SPINNER,
868
+ text=to_colored_text(f"Gathering results from job: {job_id}"),
869
+ color=YASPIN_COLOR,
810
870
  ) as spinner:
811
871
  response = requests.post(
812
872
  endpoint, data=json.dumps(payload), headers=headers
813
873
  )
814
- if response.status_code == 200:
815
- spinner.write(
816
- to_colored_text("✔ Job results retrieved", state="success")
817
- )
818
- else:
874
+ if response.status_code != 200:
819
875
  spinner.write(
820
876
  to_colored_text(
821
877
  f"Bad status code: {response.status_code}", state="fail"
@@ -823,8 +879,56 @@ class Sutro:
823
879
  )
824
880
  spinner.stop()
825
881
  print(to_colored_text(response.json(), state="fail"))
826
- return
827
- return response.json()["results"]
882
+ return None
883
+
884
+ spinner.write(
885
+ to_colored_text("✔ Job results retrieved", state="success")
886
+ )
887
+
888
+ response_data = response.json()
889
+ results_df = pl.DataFrame(response_data["results"])
890
+
891
+
892
+ if len(results_df.columns ) == 1:
893
+ # Default column when API is only returning a list, and we construct the df
894
+ # from that
895
+ original_results_column = 'column_0'
896
+ else:
897
+ original_results_column = 'outputs'
898
+
899
+ results_df = results_df.rename({original_results_column: output_column})
900
+
901
+ # Ordering inputs col first seems most logical/useful
902
+ column_config = [
903
+ ('inputs', include_inputs),
904
+ (output_column, True),
905
+ ('cumulative_logprobs', include_cumulative_logprobs),
906
+ ]
907
+
908
+ columns_to_keep = [col for col, include in column_config
909
+ if include and col in results_df.columns]
910
+
911
+ results_df = results_df.select(columns_to_keep)
912
+
913
+ # Handle concatenation with original DataFrame
914
+ if with_original_df is not None:
915
+ if isinstance(with_original_df, pd.DataFrame):
916
+ # Convert to polars for consistent handling
917
+ original_pl = pl.from_pandas(with_original_df)
918
+
919
+ combined_df = original_pl.with_columns(results_df)
920
+
921
+ # Convert back to pandas to match input type
922
+ return combined_df.to_pandas()
923
+
924
+ elif isinstance(with_original_df, pl.DataFrame):
925
+ return with_original_df.with_columns(results_df)
926
+
927
+ # Return pd.DataFrame type when appropriate
928
+ if with_original_df is None and isinstance(with_original_df, pd.DataFrame):
929
+ return results_df.to_pandas()
930
+
931
+ return results_df
828
932
 
829
933
  def cancel_job(self, job_id: str):
830
934
  """
@@ -1162,7 +1266,7 @@ class Sutro:
1162
1266
  return
1163
1267
  return response.json()["quotas"]
1164
1268
 
1165
- def await_job_completion(self, job_id: str, timeout: Optional[int] = 7200) -> list | None:
1269
+ def await_job_completion(self, job_id: str, timeout: Optional[int] = 7200, obtain_results: bool = True) -> list | None:
1166
1270
  """
1167
1271
  Waits for job completion to occur and then returns the results upon
1168
1272
  a successful completion.
@@ -1181,8 +1285,9 @@ class Sutro:
1181
1285
  results = None
1182
1286
  start_time = time.time()
1183
1287
  with yaspin(
1184
- SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
1288
+ SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
1185
1289
  ) as spinner:
1290
+ spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
1186
1291
  while (time.time() - start_time) < timeout:
1187
1292
  try:
1188
1293
  status = self._fetch_job_status(job_id)
@@ -1201,7 +1306,8 @@ class Sutro:
1201
1306
  if status == JobStatus.SUCCEEDED:
1202
1307
  spinner.write(to_colored_text("Job completed! Retrieving results...", "success"))
1203
1308
  spinner.stop() # Stop this spinner as `get_job_results` has its own spinner text
1204
- results = self.get_job_results(job_id)
1309
+ if obtain_results:
1310
+ results = self.get_job_results(job_id)
1205
1311
  break
1206
1312
  if status == JobStatus.FAILED:
1207
1313
  spinner.write(to_colored_text("Job has failed", "fail"))
@@ -1213,4 +1319,43 @@ class Sutro:
1213
1319
 
1214
1320
  time.sleep(POLL_INTERVAL)
1215
1321
 
1216
- return results
1322
+ return results
1323
+
1324
+ def _await_job_start(self, job_id: str, timeout: Optional[int] = 7200):
1325
+ """
1326
+ Waits for job start to occur and then returns the results upon
1327
+ a successful start.
1328
+
1329
+ """
1330
+ POLL_INTERVAL = 5
1331
+
1332
+ start_time = time.time()
1333
+ with yaspin(
1334
+ SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
1335
+ ) as spinner:
1336
+ while (time.time() - start_time) < timeout:
1337
+ try:
1338
+ status = self._fetch_job_status(job_id)
1339
+ except requests.HTTPError as e:
1340
+ spinner.write(
1341
+ to_colored_text(
1342
+ f"Bad status code: {e.response.status_code}", state="fail"
1343
+ )
1344
+ )
1345
+ spinner.stop()
1346
+ print(to_colored_text(e.response.json(), state="fail"))
1347
+ return None
1348
+
1349
+ spinner.text = to_colored_text(f"Job status is {status} for {job_id}")
1350
+
1351
+ if status == JobStatus.RUNNING or status == JobStatus.STARTING:
1352
+ return True
1353
+ if status == JobStatus.FAILED:
1354
+ return False
1355
+ if status == JobStatus.CANCELLED:
1356
+ return False
1357
+
1358
+ time.sleep(POLL_INTERVAL)
1359
+
1360
+ return False
1361
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sutro
3
- Version: 0.1.17
3
+ Version: 0.1.19
4
4
  Summary: Sutro Python SDK
5
5
  Project-URL: Homepage, https://sutro.sh
6
6
  Project-URL: Documentation, https://docs.sutro.sh
@@ -0,0 +1,8 @@
1
+ sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
2
+ sutro/cli.py,sha256=6Qy9Vwaaho92HeO8YA_z1De4zp1dEFkSX3bEnLvdbkE,13203
3
+ sutro/sdk.py,sha256=1FLepL3M7afptWwudF310KWmPQ9ZDQ51LiT0Xoh1S_o,52705
4
+ sutro-0.1.19.dist-info/METADATA,sha256=6Irt5RX_DaIyRC6ig4XvHr0p0CKnFK0nmWA23oi1mCA,669
5
+ sutro-0.1.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ sutro-0.1.19.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
7
+ sutro-0.1.19.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
8
+ sutro-0.1.19.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
2
- sutro/cli.py,sha256=6Qy9Vwaaho92HeO8YA_z1De4zp1dEFkSX3bEnLvdbkE,13203
3
- sutro/sdk.py,sha256=z3cvU9zyAi7EJF_rFIb9mQMjE9zuh576-p8KuP_F1PM,46500
4
- sutro-0.1.17.dist-info/METADATA,sha256=UDHp7-xr8tdKmDl_GbxxRbVVC6nx62yr6-nka4vBwOY,669
5
- sutro-0.1.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- sutro-0.1.17.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
7
- sutro-0.1.17.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
8
- sutro-0.1.17.dist-info/RECORD,,
File without changes