sutro 0.1.15__tar.gz → 0.1.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sutro
3
- Version: 0.1.15
3
+ Version: 0.1.17
4
4
  Summary: Sutro Python SDK
5
5
  Project-URL: Homepage, https://sutro.sh
6
6
  Project-URL: Documentation, https://docs.sutro.sh
@@ -9,7 +9,7 @@ installer = "uv"
9
9
 
10
10
  [project]
11
11
  name = "sutro"
12
- version = "0.1.15"
12
+ version = "0.1.17"
13
13
  description = "Sutro Python SDK"
14
14
  readme = "README.md"
15
15
  requires-python = ">=3.10"
@@ -1,6 +1,7 @@
1
1
  import threading
2
2
  from concurrent.futures import ThreadPoolExecutor
3
3
  from contextlib import contextmanager
4
+ from enum import Enum
4
5
 
5
6
  import requests
6
7
  import pandas as pd
@@ -17,6 +18,31 @@ import time
17
18
  from pydantic import BaseModel
18
19
  import json
19
20
 
21
+
22
+ class JobStatus(str, Enum):
23
+ """Job statuses that will be returned by the API & SDK"""
24
+
25
+ UNKNOWN = "UNKNOWN"
26
+ QUEUED = "QUEUED" # Job is waiting to start
27
+ STARTING = "STARTING" # Job is in the process of starting up
28
+ RUNNING = "RUNNING" # Job is actively running
29
+ SUCCEEDED = "SUCCEEDED" # Job completed successfully
30
+ CANCELLING = "CANCELLING" # Job is in the process of being canceled
31
+ CANCELLED = "CANCELLED" # Job was canceled by the user
32
+ FAILED = "FAILED" # Job failed
33
+
34
+ @classmethod
35
+ def terminal_statuses(cls) -> list["JobStatus"]:
36
+ return [
37
+ cls.SUCCEEDED,
38
+ cls.FAILED,
39
+ cls.CANCELLING,
40
+ cls.CANCELLED,
41
+ ]
42
+
43
+ def is_terminal(self) -> bool:
44
+ return self in self.terminal_statuses()
45
+
20
46
  # Initialize colorama (required for Windows)
21
47
  init()
22
48
 
@@ -35,16 +61,14 @@ SPINNER = Spinners.dots14
35
61
  ModelOptions = Literal[
36
62
  "llama-3.2-3b",
37
63
  "llama-3.1-8b",
38
- "llama-3.3-70b-8k",
39
- "llama-3.3-70b-64k",
40
- "qwen-qwq-32b-8k",
64
+ "llama-3.3-70b",
65
+ "llama-3.3-70b",
41
66
  "qwen-3-4b",
42
67
  "qwen-3-32b",
43
68
  "qwen-3-4b-thinking",
44
69
  "qwen-3-32b-thinking",
45
70
  "gemma-3-4b-it",
46
- "gemma-3-27b-it-16k",
47
- "gemma-3-27b-it-128k",
71
+ "gemma-3-27b-it",
48
72
  "multilingual-e5-large-instruct",
49
73
  "gte-qwen2-7b-instruct",
50
74
  ]
@@ -183,7 +207,7 @@ class Sutro:
183
207
  sampling_params: dict = None,
184
208
  system_prompt: str = None,
185
209
  dry_run: bool = False,
186
- stay_attached: bool = False,
210
+ stay_attached: Optional[bool] = None,
187
211
  random_seed_per_input: bool = False,
188
212
  truncate_rows: bool = False
189
213
  ):
@@ -213,7 +237,7 @@ class Sutro:
213
237
 
214
238
  """
215
239
  input_data = self.handle_data_helper(data, column)
216
- stay_attached = stay_attached or job_priority == 0
240
+ stay_attached = stay_attached if stay_attached is not None else job_priority == 0
217
241
 
218
242
  # Convert BaseModel to dict if needed
219
243
  if output_schema is not None:
@@ -266,7 +290,7 @@ class Sutro:
266
290
  )
267
291
  spinner.stop()
268
292
  print(to_colored_text(response.json(), state="fail"))
269
- return
293
+ return None
270
294
  else:
271
295
  if dry_run:
272
296
  spinner.write(
@@ -377,7 +401,7 @@ class Sutro:
377
401
  )
378
402
  )
379
403
  spinner.stop()
380
- return
404
+ return None
381
405
 
382
406
  results = job_results_response.json()["results"]
383
407
 
@@ -403,6 +427,8 @@ class Sutro:
403
427
  return data
404
428
 
405
429
  return results
430
+ return None
431
+ return None
406
432
 
407
433
  def register_stream_listener(self, job_id: str) -> str:
408
434
  """Register a new stream listener and get a session token."""
@@ -693,40 +719,60 @@ class Sutro:
693
719
  return
694
720
  return response.json()["jobs"]
695
721
 
696
- def get_job_status(self, job_id: str):
722
+ def _fetch_job_status(self, job_id: str):
697
723
  """
698
- Get the status of a job by its ID.
699
-
700
- This method retrieves the status of a job using its unique identifier.
724
+ Core logic to fetch job status from the API.
701
725
 
702
726
  Args:
703
727
  job_id (str): The ID of the job to retrieve the status for.
704
728
 
705
729
  Returns:
706
- str: The status of the job.
730
+ dict: The response JSON from the API.
731
+
732
+ Raises:
733
+ requests.HTTPError: If the API returns a non-200 status code.
707
734
  """
708
735
  endpoint = f"{self.base_url}/job-status/{job_id}"
709
736
  headers = {
710
737
  "Authorization": f"Key {self.api_key}",
711
738
  "Content-Type": "application/json",
712
739
  }
740
+
741
+ response = requests.get(endpoint, headers=headers)
742
+ response.raise_for_status()
743
+
744
+ return response.json()["job_status"][job_id]
745
+
746
+ def get_job_status(self, job_id: str):
747
+ """
748
+ Get the status of a job by its ID.
749
+
750
+ This method retrieves the status of a job using its unique identifier.
751
+
752
+ Args:
753
+ job_id (str): The ID of the job to retrieve the status for.
754
+
755
+ Returns:
756
+ str: The status of the job.
757
+ """
713
758
  with yaspin(
714
- SPINNER,
715
- text=to_colored_text(f"Checking job status with ID: {job_id}"),
716
- color=YASPIN_COLOR,
759
+ SPINNER,
760
+ text=to_colored_text(f"Checking job status with ID: {job_id}"),
761
+ color=YASPIN_COLOR,
717
762
  ) as spinner:
718
- response = requests.get(endpoint, headers=headers)
719
- if response.status_code != 200:
763
+ try:
764
+ response_data = self._fetch_job_status(job_id)
765
+ spinner.write(to_colored_text("✔ Job status retrieved!", state="success"))
766
+ return response_data["job_status"][job_id]
767
+ except requests.HTTPError as e:
720
768
  spinner.write(
721
769
  to_colored_text(
722
- f"Bad status code: {response.status_code}", state="fail"
770
+ f"Bad status code: {e.response.status_code}", state="fail"
723
771
  )
724
772
  )
725
773
  spinner.stop()
726
- print(to_colored_text(response.json(), state="fail"))
727
- return
728
- spinner.write(to_colored_text("✔ Job status retrieved!", state="success"))
729
- return response.json()["job_status"][job_id]
774
+ print(to_colored_text(e.response.json(), state="fail"))
775
+ return None
730
776
 
731
777
  def get_job_results(
732
778
  self,
@@ -1115,3 +1161,56 @@ class Sutro:
1115
1161
  print(to_colored_text(f"Error: {response.json()}", state="fail"))
1116
1162
  return
1117
1163
  return response.json()["quotas"]
1164
+
1165
+ def await_job_completion(self, job_id: str, timeout: Optional[int] = 7200) -> list | None:
1166
+ """
1167
+ Waits for job completion to occur and then returns the results upon
1168
+ a successful completion.
1169
+
1170
+ Prints out the job's status every 5 seconds.
1171
+
1172
+ Args:
1173
+ job_id (str): The ID of the job to await.
1174
+ timeout (Optional[int]): The max time in seconds the function should wait for job results for. Default is 7200 (2 hours).
1175
+
1176
+ Returns:
1177
+ list: The results of the job.
1178
+ """
1179
+ POLL_INTERVAL = 5
1180
+
1181
+ results = None
1182
+ start_time = time.time()
1183
+ with yaspin(
1184
+ SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
1185
+ ) as spinner:
1186
+ while (time.time() - start_time) < timeout:
1187
+ try:
1188
+ status = self._fetch_job_status(job_id)
1189
+ except requests.HTTPError as e:
1190
+ spinner.write(
1191
+ to_colored_text(
1192
+ f"Bad status code: {e.response.status_code}", state="fail"
1193
+ )
1194
+ )
1195
+ spinner.stop()
1196
+ print(to_colored_text(e.response.json(), state="fail"))
1197
+ return None
1198
+
1199
+ spinner.text = to_colored_text(f"Job status is {status} for {job_id}")
1200
+
1201
+ if status == JobStatus.SUCCEEDED:
1202
+ spinner.write(to_colored_text("Job completed! Retrieving results...", "success"))
1203
+ spinner.stop() # Stop this spinner as `get_job_results` has its own spinner text
1204
+ results = self.get_job_results(job_id)
1205
+ break
1206
+ if status == JobStatus.FAILED:
1207
+ spinner.write(to_colored_text("Job has failed", "fail"))
1208
+ return None
1209
+ if status == JobStatus.CANCELLED:
1210
+ spinner.write(to_colored_text("Job has been cancelled"))
1211
+ return None
1212
+
1213
+
1214
+ time.sleep(POLL_INTERVAL)
1215
+
1216
+ return results
File without changes
File without changes
File without changes
File without changes
File without changes