sutro 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sutro might be problematic. Click here for more details.
- sutro/sdk.py +211 -26
- {sutro-0.1.16.dist-info → sutro-0.1.18.dist-info}/METADATA +1 -1
- sutro-0.1.18.dist-info/RECORD +8 -0
- sutro-0.1.16.dist-info/RECORD +0 -8
- {sutro-0.1.16.dist-info → sutro-0.1.18.dist-info}/WHEEL +0 -0
- {sutro-0.1.16.dist-info → sutro-0.1.18.dist-info}/entry_points.txt +0 -0
- {sutro-0.1.16.dist-info → sutro-0.1.18.dist-info}/licenses/LICENSE +0 -0
sutro/sdk.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import threading
|
|
2
2
|
from concurrent.futures import ThreadPoolExecutor
|
|
3
3
|
from contextlib import contextmanager
|
|
4
|
+
from enum import Enum
|
|
4
5
|
|
|
5
6
|
import requests
|
|
6
7
|
import pandas as pd
|
|
@@ -17,6 +18,31 @@ import time
|
|
|
17
18
|
from pydantic import BaseModel
|
|
18
19
|
import json
|
|
19
20
|
|
|
21
|
+
|
|
22
|
+
class JobStatus(str, Enum):
    """Lifecycle states a job can be reported in by the API and the SDK.

    Members are also plain strings, so they compare equal to the raw
    status text (e.g. ``JobStatus.RUNNING == "RUNNING"``).
    """

    UNKNOWN = "UNKNOWN"
    QUEUED = "QUEUED"  # Waiting to start
    STARTING = "STARTING"  # Start-up in progress
    RUNNING = "RUNNING"  # Actively executing
    SUCCEEDED = "SUCCEEDED"  # Finished successfully
    CANCELLING = "CANCELLING"  # Cancellation in progress
    CANCELLED = "CANCELLED"  # Cancelled by the user
    FAILED = "FAILED"  # Finished with a failure

    @classmethod
    def terminal_statuses(cls) -> list["JobStatus"]:
        """Return the statuses treated as terminal, in a fixed order.

        Note that CANCELLING is included alongside CANCELLED.
        """
        terminal_names = ("SUCCEEDED", "FAILED", "CANCELLING", "CANCELLED")
        return [cls[name] for name in terminal_names]

    def is_terminal(self) -> bool:
        """Return True when this status is one of ``terminal_statuses()``."""
        return any(self is candidate for candidate in self.terminal_statuses())
|
|
45
|
+
|
|
20
46
|
# Initialize colorama (required for Windows)
|
|
21
47
|
init()
|
|
22
48
|
|
|
@@ -181,7 +207,7 @@ class Sutro:
|
|
|
181
207
|
sampling_params: dict = None,
|
|
182
208
|
system_prompt: str = None,
|
|
183
209
|
dry_run: bool = False,
|
|
184
|
-
stay_attached: bool =
|
|
210
|
+
stay_attached: Optional[bool] = None,
|
|
185
211
|
random_seed_per_input: bool = False,
|
|
186
212
|
truncate_rows: bool = False
|
|
187
213
|
):
|
|
@@ -211,7 +237,7 @@ class Sutro:
|
|
|
211
237
|
|
|
212
238
|
"""
|
|
213
239
|
input_data = self.handle_data_helper(data, column)
|
|
214
|
-
stay_attached = stay_attached
|
|
240
|
+
stay_attached = stay_attached if stay_attached is not None else job_priority == 0
|
|
215
241
|
|
|
216
242
|
# Convert BaseModel to dict if needed
|
|
217
243
|
if output_schema is not None:
|
|
@@ -240,11 +266,6 @@ class Sutro:
|
|
|
240
266
|
"random_seed_per_input": random_seed_per_input,
|
|
241
267
|
"truncate_rows": truncate_rows
|
|
242
268
|
}
|
|
243
|
-
if dry_run:
|
|
244
|
-
spinner_text = to_colored_text("Retrieving cost estimates...")
|
|
245
|
-
else:
|
|
246
|
-
t = f"Creating priority {job_priority} job"
|
|
247
|
-
spinner_text = to_colored_text(t)
|
|
248
269
|
|
|
249
270
|
# There are two gotchas with yaspin:
|
|
250
271
|
# 1. Can't use print while in spinner is running
|
|
@@ -253,6 +274,8 @@ class Sutro:
|
|
|
253
274
|
# Terminal size {self._terminal_width} is too small to display spinner with the given settings.
|
|
254
275
|
# https://github.com/pavdmyt/yaspin/blob/9c7430b499ab4611888ece39783a870e4a05fa45/yaspin/core.py#L568-L571
|
|
255
276
|
job_id = None
|
|
277
|
+
t = f"Creating {'[dry run] ' if dry_run else ''}priority {job_priority} job"
|
|
278
|
+
spinner_text = to_colored_text(t)
|
|
256
279
|
with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
|
|
257
280
|
response = requests.post(
|
|
258
281
|
endpoint, data=json.dumps(payload), headers=headers
|
|
@@ -264,15 +287,21 @@ class Sutro:
|
|
|
264
287
|
)
|
|
265
288
|
spinner.stop()
|
|
266
289
|
print(to_colored_text(response.json(), state="fail"))
|
|
267
|
-
return
|
|
290
|
+
return None
|
|
268
291
|
else:
|
|
292
|
+
job_id = response_data["results"]
|
|
269
293
|
if dry_run:
|
|
270
294
|
spinner.write(
|
|
271
|
-
to_colored_text("
|
|
295
|
+
to_colored_text(f"Awaiting cost estimates with job ID: {job_id}. You can safely detach and retrieve the cost estimates later.", state="info")
|
|
272
296
|
)
|
|
273
|
-
|
|
297
|
+
spinner.stop()
|
|
298
|
+
self.await_job_completion(job_id, obtain_results=False)
|
|
299
|
+
cost_estimate = self._get_job_cost_estimate(job_id)
|
|
300
|
+
spinner.write(
|
|
301
|
+
to_colored_text(f"✔ Cost estimates retrieved for job {job_id}: ${cost_estimate}", state="success")
|
|
302
|
+
)
|
|
303
|
+
return job_id
|
|
274
304
|
else:
|
|
275
|
-
job_id = response_data["results"]
|
|
276
305
|
spinner.write(
|
|
277
306
|
to_colored_text(
|
|
278
307
|
f"🛠️ Priority {job_priority} Job created with ID: {job_id}",
|
|
@@ -289,6 +318,13 @@ class Sutro:
|
|
|
289
318
|
|
|
290
319
|
success = False
|
|
291
320
|
if stay_attached and job_id is not None:
|
|
321
|
+
spinner.write(to_colored_text("Awaiting job start...", "info"))
|
|
322
|
+
started = self._await_job_start(job_id)
|
|
323
|
+
if not started:
|
|
324
|
+
failure_reason = self._get_failure_reason(job_id)
|
|
325
|
+
spinner.write(to_colored_text(f"Failure reason: {failure_reason['message']}", "fail"))
|
|
326
|
+
return None
|
|
327
|
+
|
|
292
328
|
s = requests.Session()
|
|
293
329
|
payload = {
|
|
294
330
|
"job_id": job_id,
|
|
@@ -375,7 +411,7 @@ class Sutro:
|
|
|
375
411
|
)
|
|
376
412
|
)
|
|
377
413
|
spinner.stop()
|
|
378
|
-
return
|
|
414
|
+
return None
|
|
379
415
|
|
|
380
416
|
results = job_results_response.json()["results"]
|
|
381
417
|
|
|
@@ -401,6 +437,8 @@ class Sutro:
|
|
|
401
437
|
return data
|
|
402
438
|
|
|
403
439
|
return results
|
|
440
|
+
return None
|
|
441
|
+
return None
|
|
404
442
|
|
|
405
443
|
def register_stream_listener(self, job_id: str) -> str:
|
|
406
444
|
"""Register a new stream listener and get a session token."""
|
|
@@ -691,40 +729,94 @@ class Sutro:
|
|
|
691
729
|
return
|
|
692
730
|
return response.json()["jobs"]
|
|
693
731
|
|
|
694
|
-
def
|
|
732
|
+
def _list_jobs_helper(self):
    """
    Fetch the job list from the `/list-jobs` endpoint.

    Returns:
        The value stored under the "jobs" key of the JSON response, or
        None when the API answers with any status code other than 200.
    """
    response = requests.get(
        f"{self.base_url}/list-jobs",
        headers={
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        },
    )
    if response.status_code == 200:
        return response.json()["jobs"]
    return None
|
|
697
745
|
|
|
698
|
-
|
|
746
|
+
def _get_job_cost_estimate(self, job_id: str):
    """
    Get the cost estimate for a job.

    Looks the job up in the list returned by `_list_jobs_helper()`.

    Args:
        job_id (str): The ID of the job to look up.

    Returns:
        The job's "cost_estimate" value, or None when the job list could
        not be retrieved or contains no job with this ID.
    """
    # `_list_jobs_helper` returns None on a non-200 response; fall back to
    # an empty list so a failed listing yields None instead of a TypeError.
    all_jobs = self._list_jobs_helper() or []
    for job in all_jobs:
        if job["job_id"] == job_id:
            return job["cost_estimate"]
    return None
|
|
755
|
+
|
|
756
|
+
def _get_failure_reason(self, job_id: str):
    """
    Get the failure reason for a job.

    Looks the job up in the list returned by `_list_jobs_helper()`.

    Args:
        job_id (str): The ID of the job to look up.

    Returns:
        The job's "failure_reason" value, or None when the job list could
        not be retrieved or contains no job with this ID.
    """
    # `_list_jobs_helper` returns None on a non-200 response; fall back to
    # an empty list so a failed listing yields None instead of a TypeError.
    all_jobs = self._list_jobs_helper() or []
    for job in all_jobs:
        if job["job_id"] == job_id:
            return job["failure_reason"]
    return None
|
|
765
|
+
|
|
766
|
+
def _fetch_job_status(self, job_id: str):
    """
    Query the `/job-status/{job_id}` endpoint for a single job.

    Args:
        job_id (str): The ID of the job to retrieve the status for.

    Returns:
        The entry stored for `job_id` under the "job_status" key of the
        response JSON.

    Raises:
        requests.HTTPError: If `raise_for_status()` rejects the response.
    """
    auth_headers = {
        "Authorization": f"Key {self.api_key}",
        "Content-Type": "application/json",
    }
    response = requests.get(
        f"{self.base_url}/job-status/{job_id}", headers=auth_headers
    )
    response.raise_for_status()

    body = response.json()
    return body["job_status"][job_id]
|
|
789
|
+
|
|
790
|
+
def get_job_status(self, job_id: str):
    """
    Get the status of a job by its ID.

    Shows a spinner while the status is fetched through
    `_fetch_job_status`, and reports success or failure on the terminal.

    Args:
        job_id (str): The ID of the job to retrieve the status for.

    Returns:
        The status of the job as returned by `_fetch_job_status`, or None
        when the request fails with an HTTP error.
    """
    with yaspin(
        SPINNER,
        text=to_colored_text(f"Checking job status with ID: {job_id}"),
        color=YASPIN_COLOR,
    ) as spinner:
        try:
            # `_fetch_job_status` already unwraps
            # response["job_status"][job_id]; indexing its result a second
            # time would raise instead of returning the status.
            status = self._fetch_job_status(job_id)
        except requests.HTTPError as e:
            spinner.write(
                to_colored_text(
                    f"Bad status code: {e.response.status_code}", state="fail"
                )
            )
            spinner.stop()
            print(to_colored_text(e.response.json(), state="fail"))
            return None

        spinner.write(to_colored_text("✔ Job status retrieved!", state="success"))
        return status
|
|
728
820
|
|
|
729
821
|
def get_job_results(
|
|
730
822
|
self,
|
|
@@ -1113,3 +1205,96 @@ class Sutro:
|
|
|
1113
1205
|
print(to_colored_text(f"Error: {response.json()}", state="fail"))
|
|
1114
1206
|
return
|
|
1115
1207
|
return response.json()["quotas"]
|
|
1208
|
+
|
|
1209
|
+
def await_job_completion(self, job_id: str, timeout: Optional[int] = 7200, obtain_results: bool = True) -> list | None:
    """
    Wait for a job to finish and, optionally, return its results.

    Polls the job status every 5 seconds and shows the latest status in
    the spinner text.

    Args:
        job_id (str): The ID of the job to await.
        timeout (Optional[int]): The max time in seconds the function should
            wait for the job. Default is 7200 (2 hours). None waits without
            a time limit.
        obtain_results (bool): When True (default), fetch the results with
            `get_job_results` once the job has succeeded. When False, only
            wait for the job to finish.

    Returns:
        The value returned by `get_job_results` when the job succeeded and
        `obtain_results` is True. None in every other case: the job failed
        or was cancelled, the status request returned an HTTP error, the
        timeout elapsed, or results were not requested.
    """
    POLL_INTERVAL = 5

    results = None
    start_time = time.time()
    with yaspin(
        SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
    ) as spinner:
        # The signature allows timeout=None; treat it as "no deadline"
        # because comparing the elapsed time against None raises TypeError.
        while timeout is None or (time.time() - start_time) < timeout:
            try:
                status = self._fetch_job_status(job_id)
            except requests.HTTPError as e:
                spinner.write(
                    to_colored_text(
                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
                spinner.stop()
                print(to_colored_text(e.response.json(), state="fail"))
                return None

            spinner.text = to_colored_text(f"Job status is {status} for {job_id}")

            if status == JobStatus.SUCCEEDED:
                # Only announce result retrieval when it will actually happen.
                done_message = (
                    "Job completed! Retrieving results..."
                    if obtain_results
                    else "Job completed!"
                )
                spinner.write(to_colored_text(done_message, "success"))
                spinner.stop()  # Stop this spinner as `get_job_results` has its own spinner text
                if obtain_results:
                    results = self.get_job_results(job_id)
                break
            if status == JobStatus.FAILED:
                spinner.write(to_colored_text("Job has failed", "fail"))
                return None
            if status == JobStatus.CANCELLED:
                spinner.write(to_colored_text("Job has been cancelled"))
                return None

            time.sleep(POLL_INTERVAL)
        else:
            # Reached only when the loop condition became false (deadline
            # passed), never after `break`; tell the user why we return None.
            spinner.write(
                to_colored_text(
                    f"Timed out after {timeout} seconds waiting for job {job_id}",
                    "fail",
                )
            )

    return results
|
|
1262
|
+
|
|
1263
|
+
def _await_job_start(self, job_id: str, timeout: Optional[int] = 7200):
    """
    Wait until a job has started.

    Polls the job status every 5 seconds and shows the latest status in
    the spinner text.

    Args:
        job_id (str): The ID of the job to await.
        timeout (Optional[int]): The max time in seconds to wait for the
            job to start. Default is 7200 (2 hours). None waits without a
            time limit.

    Returns:
        bool: True once the job is reported as STARTING, RUNNING or
        SUCCEEDED. False when the job is FAILED or CANCELLED, the status
        request returned an HTTP error, or the timeout elapsed.
    """
    POLL_INTERVAL = 5

    # SUCCEEDED counts as started: a job that finishes between two polls
    # never shows up as RUNNING, and would otherwise be polled until the
    # timeout expires.
    started_statuses = (JobStatus.STARTING, JobStatus.RUNNING, JobStatus.SUCCEEDED)
    aborted_statuses = (JobStatus.FAILED, JobStatus.CANCELLED)

    start_time = time.time()
    with yaspin(
        SPINNER, text=to_colored_text("Awaiting job start"), color=YASPIN_COLOR
    ) as spinner:
        # The signature allows timeout=None; treat it as "no deadline"
        # because comparing the elapsed time against None raises TypeError.
        while timeout is None or (time.time() - start_time) < timeout:
            try:
                status = self._fetch_job_status(job_id)
            except requests.HTTPError as e:
                spinner.write(
                    to_colored_text(
                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
                spinner.stop()
                print(to_colored_text(e.response.json(), state="fail"))
                # Falsy like the other "did not start" outcomes.
                return False

            spinner.text = to_colored_text(f"Job status is {status} for {job_id}")

            if status in started_statuses:
                return True
            if status in aborted_statuses:
                return False

            time.sleep(POLL_INTERVAL)

    return False
|
|
1300
|
+
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
|
|
2
|
+
sutro/cli.py,sha256=6Qy9Vwaaho92HeO8YA_z1De4zp1dEFkSX3bEnLvdbkE,13203
|
|
3
|
+
sutro/sdk.py,sha256=Jjv6FQjRyHVF0_6qYHaP-qDOzdx8FlGLKIUod4g7sZU,49687
|
|
4
|
+
sutro-0.1.18.dist-info/METADATA,sha256=DpPBKOfef-Nlou5uN8AacVUtCZeWfq_iCIH8JBLRUNE,669
|
|
5
|
+
sutro-0.1.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
sutro-0.1.18.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
|
|
7
|
+
sutro-0.1.18.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
+
sutro-0.1.18.dist-info/RECORD,,
|
sutro-0.1.16.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
|
|
2
|
-
sutro/cli.py,sha256=6Qy9Vwaaho92HeO8YA_z1De4zp1dEFkSX3bEnLvdbkE,13203
|
|
3
|
-
sutro/sdk.py,sha256=Y89dwGKIn5_bnX2P_MoCTIw4nhc80NXHxAUgVDCuzko,42905
|
|
4
|
-
sutro-0.1.16.dist-info/METADATA,sha256=-grXlBSpbRTMXnBXGMXbN_vI46wuIGkAwqgh6gTeI8g,669
|
|
5
|
-
sutro-0.1.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
sutro-0.1.16.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
|
|
7
|
-
sutro-0.1.16.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
-
sutro-0.1.16.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|