sutro 0.1.16__tar.gz → 0.1.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sutro-0.1.16 → sutro-0.1.17}/PKG-INFO +1 -1
- {sutro-0.1.16 → sutro-0.1.17}/pyproject.toml +1 -1
- {sutro-0.1.16 → sutro-0.1.17}/sutro/sdk.py +120 -19
- {sutro-0.1.16 → sutro-0.1.17}/.gitignore +0 -0
- {sutro-0.1.16 → sutro-0.1.17}/LICENSE +0 -0
- {sutro-0.1.16 → sutro-0.1.17}/README.md +0 -0
- {sutro-0.1.16 → sutro-0.1.17}/sutro/__init__.py +0 -0
- {sutro-0.1.16 → sutro-0.1.17}/sutro/cli.py +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import threading
|
|
2
2
|
from concurrent.futures import ThreadPoolExecutor
|
|
3
3
|
from contextlib import contextmanager
|
|
4
|
+
from enum import Enum
|
|
4
5
|
|
|
5
6
|
import requests
|
|
6
7
|
import pandas as pd
|
|
@@ -17,6 +18,31 @@ import time
|
|
|
17
18
|
from pydantic import BaseModel
|
|
18
19
|
import json
|
|
19
20
|
|
|
21
|
+
|
|
22
|
+
class JobStatus(str, Enum):
    """Job statuses that will be returned by the API & SDK.

    Because this enum mixes in ``str``, each member compares equal to the raw
    status string with the same value (e.g. ``JobStatus.RUNNING == "RUNNING"``).
    """

    UNKNOWN = "UNKNOWN"
    QUEUED = "QUEUED"  # Job is waiting to start
    STARTING = "STARTING"  # Job is in the process of starting up
    RUNNING = "RUNNING"  # Job is actively running
    SUCCEEDED = "SUCCEEDED"  # Job completed successfully
    CANCELLING = "CANCELLING"  # Job is in the process of being canceled
    CANCELLED = "CANCELLED"  # Job was canceled by the user
    FAILED = "FAILED"  # Job failed

    @classmethod
    def terminal_statuses(cls) -> list["JobStatus"]:
        """Return the statuses that represent a finished job.

        CANCELLING is intentionally excluded. Like STARTING, it is an
        in-progress state ("in the process of being canceled") that precedes
        the final CANCELLED status, so treating it as terminal would make
        callers stop waiting before the cancellation has completed.

        Returns:
            list[JobStatus]: SUCCEEDED, FAILED and CANCELLED.
        """
        return [
            cls.SUCCEEDED,
            cls.FAILED,
            cls.CANCELLED,
        ]

    def is_terminal(self) -> bool:
        """Return True if this status is one of ``terminal_statuses()``."""
        return self in self.terminal_statuses()
|
|
45
|
+
|
|
20
46
|
# Initialize colorama (required for Windows)
|
|
21
47
|
init()
|
|
22
48
|
|
|
@@ -181,7 +207,7 @@ class Sutro:
|
|
|
181
207
|
sampling_params: dict = None,
|
|
182
208
|
system_prompt: str = None,
|
|
183
209
|
dry_run: bool = False,
|
|
184
|
-
stay_attached: bool =
|
|
210
|
+
stay_attached: Optional[bool] = None,
|
|
185
211
|
random_seed_per_input: bool = False,
|
|
186
212
|
truncate_rows: bool = False
|
|
187
213
|
):
|
|
@@ -211,7 +237,7 @@ class Sutro:
|
|
|
211
237
|
|
|
212
238
|
"""
|
|
213
239
|
input_data = self.handle_data_helper(data, column)
|
|
214
|
-
stay_attached = stay_attached
|
|
240
|
+
stay_attached = stay_attached if stay_attached is not None else job_priority == 0
|
|
215
241
|
|
|
216
242
|
# Convert BaseModel to dict if needed
|
|
217
243
|
if output_schema is not None:
|
|
@@ -264,7 +290,7 @@ class Sutro:
|
|
|
264
290
|
)
|
|
265
291
|
spinner.stop()
|
|
266
292
|
print(to_colored_text(response.json(), state="fail"))
|
|
267
|
-
return
|
|
293
|
+
return None
|
|
268
294
|
else:
|
|
269
295
|
if dry_run:
|
|
270
296
|
spinner.write(
|
|
@@ -375,7 +401,7 @@ class Sutro:
|
|
|
375
401
|
)
|
|
376
402
|
)
|
|
377
403
|
spinner.stop()
|
|
378
|
-
return
|
|
404
|
+
return None
|
|
379
405
|
|
|
380
406
|
results = job_results_response.json()["results"]
|
|
381
407
|
|
|
@@ -401,6 +427,8 @@ class Sutro:
|
|
|
401
427
|
return data
|
|
402
428
|
|
|
403
429
|
return results
|
|
430
|
+
return None
|
|
431
|
+
return None
|
|
404
432
|
|
|
405
433
|
def register_stream_listener(self, job_id: str) -> str:
|
|
406
434
|
"""Register a new stream listener and get a session token."""
|
|
@@ -691,40 +719,60 @@ class Sutro:
|
|
|
691
719
|
return
|
|
692
720
|
return response.json()["jobs"]
|
|
693
721
|
|
|
694
|
-
def
|
|
722
|
+
def _fetch_job_status(self, job_id: str):
    """
    Core logic to fetch job status from the API.

    Performs a single GET request with no console output; errors are left
    to the caller to handle.

    Args:
        job_id (str): The ID of the job to retrieve the status for.

    Returns:
        The job's status value, i.e. ``response.json()["job_status"][job_id]``
        (expected to be a status string such as ``"RUNNING"``, which compares
        equal to the matching ``JobStatus`` member). Note that this is the
        extracted status, not the full response JSON.

    Raises:
        requests.HTTPError: If the API responds with a 4xx/5xx status code
            (raised by ``response.raise_for_status()``).
        KeyError: If the response JSON has no ``job_status`` mapping or no
            entry for ``job_id``.
    """
    endpoint = f"{self.base_url}/job-status/{job_id}"
    headers = {
        "Authorization": f"Key {self.api_key}",
        "Content-Type": "application/json",
    }

    response = requests.get(endpoint, headers=headers)
    response.raise_for_status()

    return response.json()["job_status"][job_id]
|
|
745
|
+
|
|
746
|
+
def get_job_status(self, job_id: str):
    """
    Get the status of a job by its ID.

    This method retrieves the status of a job using its unique identifier,
    showing a spinner while the request is in flight.

    Args:
        job_id (str): The ID of the job to retrieve the status for.

    Returns:
        str: The status of the job, or None if the API responded with an
            error status code (the error details are printed instead).
    """
    with yaspin(
        SPINNER,
        text=to_colored_text(f"Checking job status with ID: {job_id}"),
        color=YASPIN_COLOR,
    ) as spinner:
        try:
            # `_fetch_job_status` already returns
            # `response.json()["job_status"][job_id]`; indexing its result
            # with ["job_status"][job_id] again would fail on the status value.
            status = self._fetch_job_status(job_id)
        except requests.HTTPError as e:
            spinner.write(
                to_colored_text(
                    f"Bad status code: {e.response.status_code}", state="fail"
                )
            )
            spinner.stop()
            try:
                error_body = e.response.json()
            except ValueError:
                # Error responses are not guaranteed to carry a JSON body.
                error_body = e.response.text
            print(to_colored_text(error_body, state="fail"))
            return None
        spinner.write(to_colored_text("✔ Job status retrieved!", state="success"))
        return status
|
|
728
776
|
|
|
729
777
|
def get_job_results(
|
|
730
778
|
self,
|
|
@@ -1113,3 +1161,56 @@ class Sutro:
|
|
|
1113
1161
|
print(to_colored_text(f"Error: {response.json()}", state="fail"))
|
|
1114
1162
|
return
|
|
1115
1163
|
return response.json()["quotas"]
|
|
1164
|
+
|
|
1165
|
+
def await_job_completion(
    self,
    job_id: str,
    timeout: Optional[int] = 7200,
    poll_interval: float = 5,
) -> Optional[list]:
    """
    Wait for a job to finish and return its results if it succeeds.

    Polls the job's status every ``poll_interval`` seconds, showing the most
    recent status in a spinner, until the job succeeds, fails, is cancelled,
    or ``timeout`` elapses. A job in CANCELLING keeps being polled until it
    reaches CANCELLED.

    Args:
        job_id (str): The ID of the job to await.
        timeout (Optional[int]): The max time in seconds the function should
            wait for job results. Default is 7200 (2 hours). Pass None to wait
            without a time limit.
        poll_interval (float): Seconds to sleep between status checks.
            Default is 5.

    Returns:
        Optional[list]: The results of the job (from ``get_job_results``), or
            None if the job failed, was cancelled, the status request returned
            an HTTP error, or the timeout elapsed.
    """
    start_time = time.time()
    with yaspin(
        SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
    ) as spinner:
        # `timeout is None` means "no deadline"; comparing a float to None
        # would otherwise raise TypeError.
        while timeout is None or (time.time() - start_time) < timeout:
            try:
                status = self._fetch_job_status(job_id)
            except requests.HTTPError as e:
                spinner.write(
                    to_colored_text(
                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
                spinner.stop()
                try:
                    error_body = e.response.json()
                except ValueError:
                    # Error responses are not guaranteed to carry a JSON body.
                    error_body = e.response.text
                print(to_colored_text(error_body, state="fail"))
                return None

            spinner.text = to_colored_text(f"Job status is {status} for {job_id}")

            if status == JobStatus.SUCCEEDED:
                spinner.write(
                    to_colored_text("Job completed! Retrieving results...", state="success")
                )
                spinner.stop()  # Stop this spinner as `get_job_results` has its own spinner text
                return self.get_job_results(job_id)
            if status == JobStatus.FAILED:
                spinner.write(to_colored_text("Job has failed", state="fail"))
                return None
            if status == JobStatus.CANCELLED:
                spinner.write(to_colored_text("Job has been cancelled"))
                return None

            time.sleep(poll_interval)

        # Only reached when the deadline passed without a final status; say so
        # explicitly so a timeout is distinguishable from a failed job.
        spinner.write(
            to_colored_text(
                f"Timed out after {timeout} seconds waiting for job {job_id}",
                state="fail",
            )
        )
    return None
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|