sutro 0.1.18__tar.gz → 0.1.20__tar.gz
- {sutro-0.1.18 → sutro-0.1.20}/PKG-INFO +1 -1
- {sutro-0.1.18 → sutro-0.1.20}/pyproject.toml +1 -1
- {sutro-0.1.18 → sutro-0.1.20}/sutro/sdk.py +151 -144
- {sutro-0.1.18 → sutro-0.1.20}/.gitignore +0 -0
- {sutro-0.1.18 → sutro-0.1.20}/LICENSE +0 -0
- {sutro-0.1.18 → sutro-0.1.20}/README.md +0 -0
- {sutro-0.1.18 → sutro-0.1.20}/sutro/__init__.py +0 -0
- {sutro-0.1.18 → sutro-0.1.20}/sutro/cli.py +0 -0
sutro/sdk.py:

@@ -97,6 +97,16 @@ def to_colored_text(
     # Default to blue for normal/processing states
     return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
 
+# Isn't fully supported in all terminals unfortunately. We should switch to Rich
+# at some point, but even Rich links aren't clickable on MacOS Terminal
+def make_clickable_link(url, text=None):
+    """
+    Create a clickable link for terminals that support OSC 8 hyperlinks.
+    Falls back to plain text for terminals that don't support it.
+    """
+    if text is None:
+        text = url
+    return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
 
 class Sutro:
     def __init__(
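Note: make_clickable_link emits the OSC 8 escape sequence directly. A minimal sketch of how the new helper behaves (URLs invented for illustration; terminals without OSC 8 support generally render just the link text):

    # Demo of the OSC 8 helper added in 0.1.20; text defaults to the URL itself.
    print(make_clickable_link("https://app.sutro.sh/jobs/123", "job 123"))
    print(make_clickable_link("https://app.sutro.sh"))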
@@ -104,7 +114,6 @@ class Sutro:
     ):
         self.api_key = api_key or self.check_for_api_key()
         self.base_url = base_url
-        self.HEARTBEAT_INTERVAL_SECONDS = 15  # Keep in sync w what the backend expects
 
     def check_for_api_key(self):
         """
@@ -276,68 +285,67 @@ class Sutro:
         job_id = None
         t = f"Creating {'[dry run] ' if dry_run else ''}priority {job_priority} job"
         spinner_text = to_colored_text(t)
-        with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
-            response = requests.post(
-                endpoint, data=json.dumps(payload), headers=headers
-            )
-            response_data = response.json()
-            if response.status_code != 200:
-                spinner.write(
-                    to_colored_text(f"Error: {response.status_code}", state="fail")
-                )
-                spinner.stop()
-                return None
-            else:
-                job_id = response_data["results"]
-                if dry_run:
-                    spinner.write(
-                        to_colored_text(f"Awaiting cost estimates with job ID: {job_id}. You can safely detach and retrieve the cost estimates later.", state="info")
-                    )
-                    spinner.stop()
-                    self.await_job_completion(job_id, obtain_results=False)
-                    cost_estimate = self._get_job_cost_estimate(job_id)
-                    spinner.write(
-                        to_colored_text(f"✔ Cost estimates retrieved for job {job_id}: ${cost_estimate}", state="success")
-                    )
-                    return job_id
-                else:
-                    spinner.write(
-                        to_colored_text(
-                            f"🛠 Priority {job_priority} Job created with ID: {job_id}.",
-                            state="success",
-                        )
-                    )
-                    spinner.write(
-                        to_colored_text(
-                            f"Use `so.get_job_status('{job_id}')` to check the status of the job."
-                        )
-                    )
+        try:
+            with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
+                response = requests.post(
+                    endpoint, data=json.dumps(payload), headers=headers
+                )
+                response_data = response.json()
+                if response.status_code != 200:
+                    spinner.write(
+                        to_colored_text(f"Error: {response.status_code}", state="fail")
+                    )
+                    spinner.stop()
+                    print(to_colored_text(response.json(), state="fail"))
+                    return None
+                else:
+                    job_id = response_data["results"]
+                    if dry_run:
+                        spinner.write(
+                            to_colored_text(f"Awaiting cost estimates with job ID: {job_id}. You can safely detach and retrieve the cost estimates later.", state="info")
+                        )
+                        spinner.stop()
+                        self.await_job_completion(job_id, obtain_results=False)
+                        cost_estimate = self._get_job_cost_estimate(job_id)
+                        spinner.write(
+                            to_colored_text(f"✔ Cost estimates retrieved for job {job_id}: ${cost_estimate}", state="success")
+                        )
+                        return job_id
+                    else:
+                        spinner.write(
+                            to_colored_text(
+                                f"🛠 Priority {job_priority} Job created with ID: {job_id}.",
+                                state="success",
+                            )
+                        )
+                        if not stay_attached:
+                            spinner.write(
+                                to_colored_text(
+                                    f"Use `so.get_job_status('{job_id}')` to check the status of the job."
+                                )
+                            )
+                            return job_id
+        except KeyboardInterrupt:
+            pass
+        finally:
+            if spinner:
+                spinner.stop()
 
         success = False
         if stay_attached and job_id is not None:
-            spinner.write(to_colored_text("Awaiting job start...",
+            spinner.write(to_colored_text("Awaiting job start...", ))
+            spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
             started = self._await_job_start(job_id)
             if not started:
                 failure_reason = self._get_failure_reason(job_id)
                 spinner.write(to_colored_text(f"Failure reason: {failure_reason['message']}", "fail"))
                 return None
-
             s = requests.Session()
-            payload = {
-                "job_id": job_id,
-            }
             pbar = None
 
-            # Use the heartbeat session context manager
-            with self.stream_heartbeat_session(job_id, session_token) as s:
-                with s.get(
-                    f"{self.base_url}/stream-job-progress/{job_id}?request_session_token={session_token}",
+            try:
+                with requests.get(
+                    f"{self.base_url}/stream-job-progress/{job_id}",
                     headers=headers,
                     stream=True,
                 ) as streaming_response:
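Note: job creation now runs inside try/except KeyboardInterrupt/finally, so Ctrl-C detaches cleanly instead of leaving a live spinner (the same shape reappears in the streaming loop and in attach below). A reduced, self-contained sketch of the pattern, with a sleep standing in for the HTTP call:

    import time
    from yaspin import yaspin

    spinner = None
    try:
        with yaspin(text="Creating job") as spinner:
            time.sleep(5)  # stand-in for requests.post(...)
    except KeyboardInterrupt:
        pass  # swallow Ctrl-C; the user simply detaches
    finally:
        if spinner:  # guaranteed cleanup so the terminal isn't left in a bad state
            spinner.stop()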
@@ -348,6 +356,13 @@ class Sutro:
                        color=YASPIN_COLOR,
                    )
                    spinner.start()
+
+                    token_state = {
+                        'input_tokens': 0,
+                        'output_tokens': 0,
+                        'total_tokens_processed_per_second': 0
+                    }
+
                    for line in streaming_response.iter_lines():
                        if line:
                            try:
@@ -370,12 +385,30 @@ class Sutro:
                                    pbar.update(json_obj["result"] - pbar.n)
                                    pbar.refresh()
                                    if json_obj["result"] == len(input_data):
-                                        pbar.close()
                                        success = True
                                elif json_obj["update_type"] == "tokens":
+                                    # Update only the values that are present in this update.
+                                    # Currently, the way the progress stream endpoint is defined,
+                                    # it's possible to have updates come in that only have 1 or 2 fields
+                                    new = {
+                                        k: v for k, v in json_obj.get('result', {}).items()
+                                        if k in token_state and v >= token_state[k]
+                                    }
+                                    token_state.update(new)
+
                                    if pbar is not None:
-                                        pbar.postfix = f"Input tokens processed: {
+                                        pbar.postfix = f"Input tokens processed: {token_state['input_tokens']}, Output tokens generated: {token_state['output_tokens']}, Total tokens/s: {token_state['total_tokens_processed_per_second']}"
                                        pbar.refresh()
+
+            except KeyboardInterrupt:
+                pass
+            finally:
+                # Need to clean these up on keyboard exit otherwise it causes
+                # an error
+                if pbar is not None:
+                    pbar.close()
+                if spinner is not None:
+                    spinner.stop()
            if success:
                spinner.text = to_colored_text(
                    "✔ Job succeeded. Obtaining results...", state="success"
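Note: token_state plus the filtered dict comprehension implement a monotonic merge; a streamed update may carry only a subset of the three counters, and stale (lower) values are ignored. A standalone illustration with invented values:

    token_state = {'input_tokens': 0, 'output_tokens': 0,
                   'total_tokens_processed_per_second': 0}

    def apply_update(result: dict) -> None:
        # Keep only known keys whose values do not move backwards.
        new = {k: v for k, v in result.items()
               if k in token_state and v >= token_state[k]}
        token_state.update(new)

    apply_update({'input_tokens': 120})                      # partial update applies
    apply_update({'input_tokens': 90, 'output_tokens': 40})  # stale input_tokens dropped
    print(token_state['input_tokens'], token_state['output_tokens'])  # 120 40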
@@ -440,87 +473,6 @@ class Sutro:
                 return None
             return None
 
-    def register_stream_listener(self, job_id: str) -> str:
-        """Register a new stream listener and get a session token."""
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-        with requests.post(
-            f"{self.base_url}/register-stream-listener/{job_id}",
-            headers=headers,
-        ) as response:
-            response.raise_for_status()
-            data = response.json()
-            return data["request_session_token"]
-
-    # This is a best effort action and is ok if it sometimes doesn't complete etc
-    def unregister_stream_listener(self, job_id: str, session_token: str):
-        """Explicitly unregister a stream listener."""
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-        with requests.post(
-            f"{self.base_url}/unregister-stream-listener/{job_id}",
-            headers=headers,
-            json={"request_session_token": session_token},
-        ) as response:
-            response.raise_for_status()
-
-    def start_heartbeat(
-        self,
-        job_id: str,
-        session_token: str,
-        session: requests.Session,
-        stop_event: threading.Event
-    ):
-        """Send heartbeats until stopped."""
-        while not stop_event.is_set():
-            try:
-                headers = {
-                    "Authorization": f"Key {self.api_key}",
-                    "Content-Type": "application/json",
-                }
-                response = session.post(
-                    f"{self.base_url}/stream-heartbeat/{job_id}",
-                    headers=headers,
-                    params={"request_session_token": session_token},
-                )
-                response.raise_for_status()
-            except Exception as e:
-                if not stop_event.is_set():  # Only log if we weren't stopping anyway
-                    print(f"Heartbeat failed for job {job_id}: {e}")
-
-            for _ in range(self.HEARTBEAT_INTERVAL_SECONDS):
-                if stop_event.is_set():
-                    break
-                time.sleep(1)
-
-    @contextmanager
-    def stream_heartbeat_session(self, job_id: str, session_token: str) -> Generator[requests.Session, None, None]:
-        """Context manager that handles session registration and heartbeat."""
-        session = requests.Session()
-        stop_heartbeat = threading.Event()
-
-        # Run this concurrently in a thread so we can not block main SDK path/behavior
-        # but still run heartbeat requests
-        with ThreadPoolExecutor(max_workers=1) as executor:
-            executor.submit(
-                self.start_heartbeat,
-                job_id,
-                session_token,
-                session,
-                stop_heartbeat
-            )
-
-            try:
-                yield session
-            finally:
-                # Signal stop and cleanup
-                stop_heartbeat.set()
-                self.unregister_stream_listener(job_id, session_token)
-                session.close()
 
     def attach(self, job_id):
         """
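Note: with this hunk all of the listener-registration and heartbeat machinery is gone; the progress stream is now a plain requests.get. For reference, the removed helpers followed a standard stop-event pattern. A generic, self-contained sketch of that pattern (names are illustrative, not the SDK's API):

    import threading
    import time
    from concurrent.futures import ThreadPoolExecutor
    from contextlib import contextmanager

    def beat(stop_event: threading.Event, interval: int = 15) -> None:
        while not stop_event.is_set():
            print("heartbeat")  # stand-in for session.post(...)
            for _ in range(interval):  # sleep in 1s slices so stopping is prompt
                if stop_event.is_set():
                    break
                time.sleep(1)

    @contextmanager
    def heartbeat_session():
        stop = threading.Event()
        with ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(beat, stop)
            try:
                yield
            finally:
                stop.set()  # lets the worker exit so the executor can shut down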
@@ -585,11 +537,9 @@ class Sutro:
         total_rows = job["num_rows"]
         success = False
 
-
-
-        with self.stream_heartbeat_session(job_id, session_token) as s:
+        try:
             with s.get(
-                f"{self.base_url}/stream-job-progress/{job_id}?request_session_token={session_token}",
+                f"{self.base_url}/stream-job-progress/{job_id}",
                 headers=headers,
                 stream=True,
             ) as streaming_response:
@@ -599,6 +549,7 @@ class Sutro:
                    text=to_colored_text("Awaiting status updates..."),
                    color=YASPIN_COLOR,
                )
+                spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
                spinner.start()
                for line in streaming_response.iter_lines():
                    if line:
@@ -637,6 +588,13 @@ class Sutro:
                )
            )
            spinner.stop()
+        except KeyboardInterrupt:
+            pass
+        finally:
+            if pbar:
+                pbar.close()
+            if spinner:
+                spinner.stop()
 
 
 
@@ -819,10 +777,12 @@ class Sutro:
         return None
 
     def get_job_results(
         self,
         job_id: str,
         include_inputs: bool = False,
         include_cumulative_logprobs: bool = False,
+        with_original_df: pl.DataFrame | pd.DataFrame = None,
+        output_column: str = "inference_result",
     ):
         """
         Get the results of a job by its ID.
@@ -833,9 +793,11 @@ class Sutro:
             job_id (str): The ID of the job to retrieve the results for.
             include_inputs (bool, optional): Whether to include the inputs in the results. Defaults to False.
             include_cumulative_logprobs (bool, optional): Whether to include the cumulative logprobs in the results. Defaults to False.
+            with_original_df (pd.DataFrame | pl.DataFrame, optional): Original DataFrame to concatenate with results. Defaults to None.
+            output_column (str, optional): Name of the output column. Defaults to "inference_result".
 
         Returns:
-
+            Union[pl.DataFrame, pd.DataFrame]: The results as a DataFrame. By default, returns polars.DataFrame; when with_original_df is an instance of pandas.DataFrame, returns pandas.DataFrame.
         """
         endpoint = f"{self.base_url}/job-results"
         payload = {
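Note: the two new parameters make it easy to carry results back onto the frame you built the job from. A hypothetical call (job ID, column name, and the import path are illustrative, not taken from the package docs):

    import polars as pl
    from sutro import Sutro  # assumed import path

    so = Sutro(api_key="...")
    df = pl.DataFrame({"review": ["great!", "meh"]})
    results = so.get_job_results(
        "job-123",                  # invented job ID
        with_original_df=df,        # results come back joined onto this frame
        output_column="sentiment",  # instead of the default "inference_result"
    )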
@@ -848,18 +810,14 @@ class Sutro:
             "Content-Type": "application/json",
         }
         with yaspin(
             SPINNER,
             text=to_colored_text(f"Gathering results from job: {job_id}"),
             color=YASPIN_COLOR,
         ) as spinner:
             response = requests.post(
                 endpoint, data=json.dumps(payload), headers=headers
             )
-            if response.status_code == 200:
-                spinner.write(
-                    to_colored_text("✔ Job results retrieved", state="success")
-                )
-            else:
+            if response.status_code != 200:
                 spinner.write(
                     to_colored_text(
                         f"Bad status code: {response.status_code}", state="fail"
@@ -867,8 +825,56 @@ class Sutro:
                 )
                 spinner.stop()
                 print(to_colored_text(response.json(), state="fail"))
-                return
-
+                return None
+
+            spinner.write(
+                to_colored_text("✔ Job results retrieved", state="success")
+            )
+
+            response_data = response.json()
+            results_df = pl.DataFrame(response_data["results"])
+
+            if len(results_df.columns) == 1:
+                # Default column when API is only returning a list, and we construct the df
+                # from that
+                original_results_column = 'column_0'
+            else:
+                original_results_column = 'outputs'
+
+            results_df = results_df.rename({original_results_column: output_column})
+
+            # Ordering inputs col first seems most logical/useful
+            column_config = [
+                ('inputs', include_inputs),
+                (output_column, True),
+                ('cumulative_logprobs', include_cumulative_logprobs),
+            ]
+
+            columns_to_keep = [col for col, include in column_config
+                               if include and col in results_df.columns]
+
+            results_df = results_df.select(columns_to_keep)
+
+            # Handle concatenation with original DataFrame
+            if with_original_df is not None:
+                if isinstance(with_original_df, pd.DataFrame):
+                    # Convert to polars for consistent handling
+                    original_pl = pl.from_pandas(with_original_df)
+
+                    combined_df = original_pl.with_columns(results_df)
+
+                    # Convert back to pandas to match input type
+                    return combined_df.to_pandas()
+
+                elif isinstance(with_original_df, pl.DataFrame):
+                    return with_original_df.with_columns(results_df)
+
+            # Return pd.DataFrame type when appropriate
+            if with_original_df is None and isinstance(with_original_df, pd.DataFrame):
+                return results_df.to_pandas()
+
+            return results_df
 
     def cancel_job(self, job_id: str):
         """
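Review note on the new result shaping: column_config keeps 'inputs' and 'cumulative_logprobs' only when requested and always keeps the output column. Also worth flagging that the final guard, "with_original_df is None and isinstance(with_original_df, pd.DataFrame)", can never be true (None fails the isinstance check), so calls without with_original_df always return polars. A standalone polars illustration of the selection step, with sample data invented:

    import polars as pl

    results_df = pl.DataFrame({
        "inputs": ["a", "b"],
        "inference_result": ["x", "y"],
        "cumulative_logprobs": [-0.1, -0.2],
    })
    include_inputs, include_cumulative_logprobs = True, False
    column_config = [
        ("inputs", include_inputs),
        ("inference_result", True),  # the output column is always kept
        ("cumulative_logprobs", include_cumulative_logprobs),
    ]
    keep = [c for c, inc in column_config if inc and c in results_df.columns]
    print(results_df.select(keep))  # inputs + inference_result only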
@@ -1227,6 +1233,7 @@ class Sutro:
         with yaspin(
             SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
         ) as spinner:
+            spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
             while (time.time() - start_time) < timeout:
                 try:
                     status = self._fetch_job_status(job_id)