sutro 0.1.16__tar.gz → 0.1.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sutro
3
- Version: 0.1.16
3
+ Version: 0.1.17
4
4
  Summary: Sutro Python SDK
5
5
  Project-URL: Homepage, https://sutro.sh
6
6
  Project-URL: Documentation, https://docs.sutro.sh
@@ -9,7 +9,7 @@ installer = "uv"
9
9
 
10
10
  [project]
11
11
  name = "sutro"
12
- version = "0.1.16"
12
+ version = "0.1.17"
13
13
  description = "Sutro Python SDK"
14
14
  readme = "README.md"
15
15
  requires-python = ">=3.10"
@@ -1,6 +1,7 @@
1
1
  import threading
2
2
  from concurrent.futures import ThreadPoolExecutor
3
3
  from contextlib import contextmanager
4
+ from enum import Enum
4
5
 
5
6
  import requests
6
7
  import pandas as pd
@@ -17,6 +18,31 @@ import time
17
18
  from pydantic import BaseModel
18
19
  import json
19
20
 
21
+
22
class JobStatus(str, Enum):
    """Job states that the API and SDK can report.

    Every member is also a ``str``, so a member compares equal to its raw
    string value (for example ``JobStatus.RUNNING == "RUNNING"``).
    """

    UNKNOWN = "UNKNOWN"
    QUEUED = "QUEUED"  # waiting to start
    STARTING = "STARTING"  # in the process of starting up
    RUNNING = "RUNNING"  # actively running
    SUCCEEDED = "SUCCEEDED"  # completed successfully
    CANCELLING = "CANCELLING"  # cancellation is in progress
    CANCELLED = "CANCELLED"  # canceled by the user
    FAILED = "FAILED"  # failed

    @classmethod
    def terminal_statuses(cls) -> list["JobStatus"]:
        """Return a new list of the statuses treated as terminal."""
        # NOTE(review): CANCELLING is listed as terminal although it denotes
        # an in-progress cancellation -- confirm this is intended.
        terminal_names = ("SUCCEEDED", "FAILED", "CANCELLING", "CANCELLED")
        return [cls[name] for name in terminal_names]

    def is_terminal(self) -> bool:
        """Return True when this status is one of ``terminal_statuses()``."""
        return self in type(self).terminal_statuses()
45
+
20
46
  # Initialize colorama (required for Windows)
21
47
  init()
22
48
 
@@ -181,7 +207,7 @@ class Sutro:
181
207
  sampling_params: dict = None,
182
208
  system_prompt: str = None,
183
209
  dry_run: bool = False,
184
- stay_attached: bool = False,
210
+ stay_attached: Optional[bool] = None,
185
211
  random_seed_per_input: bool = False,
186
212
  truncate_rows: bool = False
187
213
  ):
@@ -211,7 +237,7 @@ class Sutro:
211
237
 
212
238
  """
213
239
  input_data = self.handle_data_helper(data, column)
214
- stay_attached = stay_attached or job_priority == 0
240
+ stay_attached = stay_attached if stay_attached is not None else job_priority == 0
215
241
 
216
242
  # Convert BaseModel to dict if needed
217
243
  if output_schema is not None:
@@ -264,7 +290,7 @@ class Sutro:
264
290
  )
265
291
  spinner.stop()
266
292
  print(to_colored_text(response.json(), state="fail"))
267
- return
293
+ return None
268
294
  else:
269
295
  if dry_run:
270
296
  spinner.write(
@@ -375,7 +401,7 @@ class Sutro:
375
401
  )
376
402
  )
377
403
  spinner.stop()
378
- return
404
+ return None
379
405
 
380
406
  results = job_results_response.json()["results"]
381
407
 
@@ -401,6 +427,8 @@ class Sutro:
401
427
  return data
402
428
 
403
429
  return results
430
+ return None
431
+ return None
404
432
 
405
433
  def register_stream_listener(self, job_id: str) -> str:
406
434
  """Register a new stream listener and get a session token."""
@@ -691,40 +719,60 @@ class Sutro:
691
719
  return
692
720
  return response.json()["jobs"]
693
721
 
694
- def get_job_status(self, job_id: str):
722
def _fetch_job_status(self, job_id: str):
    """
    Core logic to fetch a job's status from the API.

    Issues a GET request to ``{base_url}/job-status/{job_id}`` using the
    client's API key. Unlike ``get_job_status`` this helper does no console
    output, so it can be called repeatedly (e.g. while polling).

    Args:
        job_id (str): The ID of the job to retrieve the status for.

    Returns:
        The value stored under ``["job_status"][job_id]`` in the response
        JSON, i.e. the status of this job rather than the whole response
        body. (Presumably a status string such as "RUNNING" -- confirm
        against the API.)

    Raises:
        requests.HTTPError: If the API responds with a 4xx or 5xx status
            code (raised by ``response.raise_for_status()``).
    """
    endpoint = f"{self.base_url}/job-status/{job_id}"
    headers = {
        "Authorization": f"Key {self.api_key}",
        "Content-Type": "application/json",
    }

    response = requests.get(endpoint, headers=headers)
    # Turn error responses into exceptions so callers decide how to report them.
    response.raise_for_status()

    return response.json()["job_status"][job_id]
745
+
746
def get_job_status(self, job_id: str):
    """
    Get the status of a job by its ID.

    This method retrieves the status of a job using its unique identifier,
    showing a spinner while the request is in flight.

    Args:
        job_id (str): The ID of the job to retrieve the status for.

    Returns:
        The status of the job as returned by ``_fetch_job_status``, or
        None if the API answered with an HTTP error (the error is printed).
    """
    with yaspin(
        SPINNER,
        text=to_colored_text(f"Checking job status with ID: {job_id}"),
        color=YASPIN_COLOR,
    ) as spinner:
        try:
            # _fetch_job_status already extracts ["job_status"][job_id];
            # indexing its result again would raise instead of returning.
            job_status = self._fetch_job_status(job_id)
        except requests.HTTPError as e:
            spinner.write(
                to_colored_text(
                    f"Bad status code: {e.response.status_code}", state="fail"
                )
            )
            spinner.stop()
            print(to_colored_text(e.response.json(), state="fail"))
            return None

        spinner.write(to_colored_text("✔ Job status retrieved!", state="success"))
        return job_status
728
776
 
729
777
  def get_job_results(
730
778
  self,
@@ -1113,3 +1161,56 @@ class Sutro:
1113
1161
  print(to_colored_text(f"Error: {response.json()}", state="fail"))
1114
1162
  return
1115
1163
  return response.json()["quotas"]
1164
+
1165
def await_job_completion(self, job_id: str, timeout: Optional[int] = 7200) -> list | None:
    """
    Waits for job completion to occur and then returns the results upon
    a successful completion.

    Polls the job's status every 5 seconds and shows it in the spinner text.

    Args:
        job_id (str): The ID of the job to await.
        timeout (Optional[int]): The maximum time in seconds to wait for the
            job to finish. Default is 7200 (2 hours). Pass None to wait
            indefinitely.

    Returns:
        The value returned by ``get_job_results`` once the job has
        succeeded. None if the job failed, was cancelled, the status
        request returned an HTTP error, or the timeout elapsed.
    """
    POLL_INTERVAL = 5

    # Monotonic clock: elapsed time is unaffected by system clock changes.
    start_time = time.monotonic()
    with yaspin(
        SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
    ) as spinner:
        # A timeout of None disables the deadline instead of raising TypeError.
        while timeout is None or (time.monotonic() - start_time) < timeout:
            try:
                status = self._fetch_job_status(job_id)
            except requests.HTTPError as e:
                spinner.write(
                    to_colored_text(
                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
                spinner.stop()
                print(to_colored_text(e.response.json(), state="fail"))
                return None

            spinner.text = to_colored_text(f"Job status is {status} for {job_id}")

            if status == JobStatus.SUCCEEDED:
                spinner.write(to_colored_text("Job completed! Retrieving results...", "success"))
                spinner.stop()  # Stop this spinner as `get_job_results` has its own spinner text
                return self.get_job_results(job_id)
            if status == JobStatus.FAILED:
                spinner.write(to_colored_text("Job has failed", "fail"))
                return None
            if status == JobStatus.CANCELLED:
                spinner.write(to_colored_text("Job has been cancelled"))
                return None

            time.sleep(POLL_INTERVAL)

        # Loop ended without reaching a final status: the deadline elapsed.
        spinner.write(
            to_colored_text(
                f"Timed out after {timeout} seconds waiting for job {job_id}", "fail"
            )
        )

    return None
File without changes
File without changes
File without changes
File without changes
File without changes