sutro 0.1.16__tar.gz → 0.1.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sutro
3
- Version: 0.1.16
3
+ Version: 0.1.18
4
4
  Summary: Sutro Python SDK
5
5
  Project-URL: Homepage, https://sutro.sh
6
6
  Project-URL: Documentation, https://docs.sutro.sh
@@ -9,7 +9,7 @@ installer = "uv"
9
9
 
10
10
  [project]
11
11
  name = "sutro"
12
- version = "0.1.16"
12
+ version = "0.1.18"
13
13
  description = "Sutro Python SDK"
14
14
  readme = "README.md"
15
15
  requires-python = ">=3.10"
@@ -1,6 +1,7 @@
1
1
  import threading
2
2
  from concurrent.futures import ThreadPoolExecutor
3
3
  from contextlib import contextmanager
4
+ from enum import Enum
4
5
 
5
6
  import requests
6
7
  import pandas as pd
@@ -17,6 +18,31 @@ import time
17
18
  from pydantic import BaseModel
18
19
  import json
19
20
 
21
+
22
+ class JobStatus(str, Enum):
23
+ """Job statuses that will be returned by the API & SDK."""
24
+
25
+ UNKNOWN = "UNKNOWN"
26
+ QUEUED = "QUEUED" # Job is waiting to start
27
+ STARTING = "STARTING" # Job is in the process of starting up
28
+ RUNNING = "RUNNING" # Job is actively running
29
+ SUCCEEDED = "SUCCEEDED" # Job completed successfully
30
+ CANCELLING = "CANCELLING" # Job is in the process of being canceled
31
+ CANCELLED = "CANCELLED" # Job was canceled by the user
32
+ FAILED = "FAILED" # Job failed
33
+
34
+ @classmethod
35
+ def terminal_statuses(cls) -> list["JobStatus"]:
36
+ return [
37
+ cls.SUCCEEDED,
38
+ cls.FAILED,
39
+ cls.CANCELLING, # NOTE(review): transitional state counted as terminal here - confirm intent
40
+ cls.CANCELLED,
41
+ ]
42
+
43
+ def is_terminal(self) -> bool:
44
+ return self in self.terminal_statuses()
45
+
20
46
  # Initialize colorama (required for Windows)
21
47
  init()
22
48
 
@@ -181,7 +207,7 @@ class Sutro:
181
207
  sampling_params: dict = None,
182
208
  system_prompt: str = None,
183
209
  dry_run: bool = False,
184
- stay_attached: bool = False,
210
+ stay_attached: Optional[bool] = None,
185
211
  random_seed_per_input: bool = False,
186
212
  truncate_rows: bool = False
187
213
  ):
@@ -211,7 +237,7 @@ class Sutro:
211
237
 
212
238
  """
213
239
  input_data = self.handle_data_helper(data, column)
214
- stay_attached = stay_attached or job_priority == 0
240
+ stay_attached = stay_attached if stay_attached is not None else job_priority == 0
215
241
 
216
242
  # Convert BaseModel to dict if needed
217
243
  if output_schema is not None:
@@ -240,11 +266,6 @@ class Sutro:
240
266
  "random_seed_per_input": random_seed_per_input,
241
267
  "truncate_rows": truncate_rows
242
268
  }
243
- if dry_run:
244
- spinner_text = to_colored_text("Retrieving cost estimates...")
245
- else:
246
- t = f"Creating priority {job_priority} job"
247
- spinner_text = to_colored_text(t)
248
269
 
249
270
  # There are two gotchas with yaspin:
250
271
  # 1. Can't use print while in spinner is running
@@ -253,6 +274,8 @@ class Sutro:
253
274
  # Terminal size {self._terminal_width} is too small to display spinner with the given settings.
254
275
  # https://github.com/pavdmyt/yaspin/blob/9c7430b499ab4611888ece39783a870e4a05fa45/yaspin/core.py#L568-L571
255
276
  job_id = None
277
+ t = f"Creating {'[dry run] ' if dry_run else ''}priority {job_priority} job"
278
+ spinner_text = to_colored_text(t)
256
279
  with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
257
280
  response = requests.post(
258
281
  endpoint, data=json.dumps(payload), headers=headers
@@ -264,15 +287,21 @@ class Sutro:
264
287
  )
265
288
  spinner.stop()
266
289
  print(to_colored_text(response.json(), state="fail"))
267
- return
290
+ return None
268
291
  else:
292
+ job_id = response_data["results"]
269
293
  if dry_run:
270
294
  spinner.write(
271
- to_colored_text(" Cost estimates retrieved", state="success")
295
+ to_colored_text(f"Awaiting cost estimates with job ID: {job_id}. You can safely detach and retrieve the cost estimates later.", state="info")
272
296
  )
273
- return response_data["results"]
297
+ spinner.stop()
298
+ self.await_job_completion(job_id, obtain_results=False)
299
+ cost_estimate = self._get_job_cost_estimate(job_id)
300
+ spinner.write(
301
+ to_colored_text(f"✔ Cost estimates retrieved for job {job_id}: ${cost_estimate}", state="success")
302
+ )
303
+ return job_id
274
304
  else:
275
- job_id = response_data["results"]
276
305
  spinner.write(
277
306
  to_colored_text(
278
307
  f"🛠️ Priority {job_priority} Job created with ID: {job_id}",
@@ -289,6 +318,13 @@ class Sutro:
289
318
 
290
319
  success = False
291
320
  if stay_attached and job_id is not None:
321
+ spinner.write(to_colored_text("Awaiting job start...", "info"))
322
+ started = self._await_job_start(job_id)
323
+ if not started:
324
+ failure_reason = self._get_failure_reason(job_id)
325
+ spinner.write(to_colored_text(f"Failure reason: {failure_reason['message']}", "fail"))
326
+ return None
327
+
292
328
  s = requests.Session()
293
329
  payload = {
294
330
  "job_id": job_id,
@@ -375,7 +411,7 @@ class Sutro:
375
411
  )
376
412
  )
377
413
  spinner.stop()
378
- return
414
+ return None
379
415
 
380
416
  results = job_results_response.json()["results"]
381
417
 
@@ -401,6 +437,8 @@ class Sutro:
401
437
  return data
402
438
 
403
439
  return results
440
+ return None
441
+ return None
404
442
 
405
443
  def register_stream_listener(self, job_id: str) -> str:
406
444
  """Register a new stream listener and get a session token."""
@@ -691,40 +729,94 @@ class Sutro:
691
729
  return
692
730
  return response.json()["jobs"]
693
731
 
694
- def get_job_status(self, job_id: str):
732
+ def _list_jobs_helper(self):
695
733
  """
696
- Get the status of a job by its ID.
734
+ Fetch the job list from the list-jobs endpoint; returns None on a non-200 response.
735
+ """
736
+ endpoint = f"{self.base_url}/list-jobs"
737
+ headers = {
738
+ "Authorization": f"Key {self.api_key}",
739
+ "Content-Type": "application/json",
740
+ }
741
+ response = requests.get(endpoint, headers=headers)
742
+ if response.status_code != 200:
743
+ return None
744
+ return response.json()["jobs"]
697
745
 
698
- This method retrieves the status of a job using its unique identifier.
746
+ def _get_job_cost_estimate(self, job_id: str):
747
+ """
748
+ Get the cost estimate for a job, or None if the job is not listed or the job list could not be fetched.
749
+ """
750
+ all_jobs = self._list_jobs_helper()
751
+ for job in all_jobs or []: # _list_jobs_helper returns None on a non-200 response
752
+ if job["job_id"] == job_id:
753
+ return job["cost_estimate"]
754
+ return None
755
+
756
+ def _get_failure_reason(self, job_id: str):
757
+ """
758
+ Get the failure reason for a job, or None if the job is not listed or the job list could not be fetched.
759
+ """
760
+ all_jobs = self._list_jobs_helper()
761
+ for job in all_jobs or []: # _list_jobs_helper returns None on a non-200 response
762
+ if job["job_id"] == job_id:
763
+ return job["failure_reason"]
764
+ return None
765
+
766
+ def _fetch_job_status(self, job_id: str):
767
+ """
768
+ Core logic to fetch job status from the API.
699
769
 
700
770
  Args:
701
771
  job_id (str): The ID of the job to retrieve the status for.
702
772
 
703
773
  Returns:
704
- str: The status of the job.
774
+ The per-job entry response.json()["job_status"][job_id], not the full response JSON.
775
+
776
+ Raises:
777
+ requests.HTTPError: If the API returns a 4xx or 5xx status code.
705
778
  """
706
779
  endpoint = f"{self.base_url}/job-status/{job_id}"
707
780
  headers = {
708
781
  "Authorization": f"Key {self.api_key}",
709
782
  "Content-Type": "application/json",
710
783
  }
784
+
785
+ response = requests.get(endpoint, headers=headers)
786
+ response.raise_for_status()
787
+
788
+ return response.json()["job_status"][job_id]
789
+
790
+ def get_job_status(self, job_id: str):
791
+ """
792
+ Get the status of a job by its ID.
793
+
794
+ This method retrieves the status of a job using its unique identifier.
795
+
796
+ Args:
797
+ job_id (str): The ID of the job to retrieve the status for.
798
+
799
+ Returns:
800
+ str: The status of the job, or None if the request fails with an HTTP error.
801
+ """
711
802
  with yaspin(
712
- SPINNER,
713
- text=to_colored_text(f"Checking job status with ID: {job_id}"),
714
- color=YASPIN_COLOR,
803
+ SPINNER,
804
+ text=to_colored_text(f"Checking job status with ID: {job_id}"),
805
+ color=YASPIN_COLOR,
715
806
  ) as spinner:
716
- response = requests.get(endpoint, headers=headers)
717
- if response.status_code != 200:
807
+ try:
808
+ response_data = self._fetch_job_status(job_id)
809
+ spinner.write(to_colored_text("✔ Job status retrieved!", state="success"))
810
+ return response_data # _fetch_job_status already extracts ["job_status"][job_id]
811
+ except requests.HTTPError as e:
718
812
  spinner.write(
719
813
  to_colored_text(
720
- f"Bad status code: {response.status_code}", state="fail"
814
+ f"Bad status code: {e.response.status_code}", state="fail"
721
815
  )
722
816
  )
723
817
  spinner.stop()
724
- print(to_colored_text(response.json(), state="fail"))
725
- return
726
- spinner.write(to_colored_text("✔ Job status retrieved!", state="success"))
727
- return response.json()["job_status"][job_id]
818
+ print(to_colored_text(e.response.json(), state="fail"))
819
+ return None
728
820
 
729
821
  def get_job_results(
730
822
  self,
@@ -1113,3 +1205,96 @@ class Sutro:
1113
1205
  print(to_colored_text(f"Error: {response.json()}", state="fail"))
1114
1206
  return
1115
1207
  return response.json()["quotas"]
1208
+
1209
+ def await_job_completion(self, job_id: str, timeout: Optional[int] = 7200, obtain_results: bool = True) -> list | None:
1210
+ """
1211
+ Waits for job completion to occur and then returns the results upon
1212
+ a successful completion.
1213
+
1214
+ Prints out the job's status every 5 seconds.
1215
+
1216
+ Args:
1217
+ job_id (str): The ID of the job to await.
1218
+ timeout (Optional[int]): The max time in seconds the function should wait for job results for. Default is 7200 (2 hours). None waits indefinitely.
1219
+
1220
+ Returns:
1221
+ list | None: The results of the job; None if obtain_results is False, or the job failed, was cancelled, timed out, or its status could not be fetched.
1222
+ """
1223
+ POLL_INTERVAL = 5
1224
+
1225
+ results = None
1226
+ start_time = time.time()
1227
+ with yaspin(
1228
+ SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
1229
+ ) as spinner:
1230
+ while timeout is None or (time.time() - start_time) < timeout:
1231
+ try:
1232
+ status = self._fetch_job_status(job_id)
1233
+ except requests.HTTPError as e:
1234
+ spinner.write(
1235
+ to_colored_text(
1236
+ f"Bad status code: {e.response.status_code}", state="fail"
1237
+ )
1238
+ )
1239
+ spinner.stop()
1240
+ print(to_colored_text(e.response.json(), state="fail"))
1241
+ return None
1242
+
1243
+ spinner.text = to_colored_text(f"Job status is {status} for {job_id}")
1244
+
1245
+ if status == JobStatus.SUCCEEDED:
1246
+ spinner.write(to_colored_text("Job completed! Retrieving results...", "success"))
1247
+ spinner.stop() # Stop this spinner as `get_job_results` has its own spinner text
1248
+ if obtain_results:
1249
+ results = self.get_job_results(job_id)
1250
+ break
1251
+ if status == JobStatus.FAILED:
1252
+ spinner.write(to_colored_text("Job has failed", "fail"))
1253
+ return None
1254
+ if status == JobStatus.CANCELLED:
1255
+ spinner.write(to_colored_text("Job has been cancelled"))
1256
+ return None
1257
+
1258
+
1259
+ time.sleep(POLL_INTERVAL)
1260
+
1261
+ return results
1262
+
1263
+ def _await_job_start(self, job_id: str, timeout: Optional[int] = 7200):
1264
+ """
1265
+ Polls the job status every 5 seconds until the job has started (timeout=None waits indefinitely).
1266
+ Returns True once the job is STARTING, RUNNING or SUCCEEDED; False on FAILED, CANCELLED or timeout.
1267
+ Returns None if the status request fails with an HTTP error.
1268
+ """
1269
+ POLL_INTERVAL = 5
1270
+
1271
+ start_time = time.time()
1272
+ with yaspin(
1273
+ SPINNER, text=to_colored_text("Awaiting job start"), color=YASPIN_COLOR
1274
+ ) as spinner:
1275
+ while timeout is None or (time.time() - start_time) < timeout:
1276
+ try:
1277
+ status = self._fetch_job_status(job_id)
1278
+ except requests.HTTPError as e:
1279
+ spinner.write(
1280
+ to_colored_text(
1281
+ f"Bad status code: {e.response.status_code}", state="fail"
1282
+ )
1283
+ )
1284
+ spinner.stop()
1285
+ print(to_colored_text(e.response.json(), state="fail"))
1286
+ return None
1287
+
1288
+ spinner.text = to_colored_text(f"Job status is {status} for {job_id}")
1289
+
1290
+ if status in (JobStatus.RUNNING, JobStatus.STARTING, JobStatus.SUCCEEDED): # a fast job may already have finished between polls
1291
+ return True
1292
+ if status == JobStatus.FAILED:
1293
+ return False
1294
+ if status == JobStatus.CANCELLED:
1295
+ return False
1296
+
1297
+ time.sleep(POLL_INTERVAL)
1298
+
1299
+ return False
1300
+
File without changes
File without changes
File without changes
File without changes
File without changes