sutro 0.1.18__tar.gz → 0.1.20__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sutro might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sutro
3
- Version: 0.1.18
3
+ Version: 0.1.20
4
4
  Summary: Sutro Python SDK
5
5
  Project-URL: Homepage, https://sutro.sh
6
6
  Project-URL: Documentation, https://docs.sutro.sh
@@ -9,7 +9,7 @@ installer = "uv"
9
9
 
10
10
  [project]
11
11
  name = "sutro"
12
- version = "0.1.18"
12
+ version = "0.1.20"
13
13
  description = "Sutro Python SDK"
14
14
  readme = "README.md"
15
15
  requires-python = ">=3.10"
@@ -97,6 +97,16 @@ def to_colored_text(
97
97
  # Default to blue for normal/processing states
98
98
  return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
99
99
 
100
+ # Isn't fully supported in all terminals, unfortunately. We should switch to Rich
101
+ # at some point, but even Rich links aren't clickable on MacOS Terminal
102
+ def make_clickable_link(url, text=None):
103
+ """
104
+ Create a clickable link for terminals that support OSC 8 hyperlinks.
105
+ Falls back to plain text for terminals that don't support it.
106
+ """
107
+ if text is None:
108
+ text = url
109
+ return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
100
110
 
101
111
  class Sutro:
102
112
  def __init__(
@@ -104,7 +114,6 @@ class Sutro:
104
114
  ):
105
115
  self.api_key = api_key or self.check_for_api_key()
106
116
  self.base_url = base_url
107
- self.HEARTBEAT_INTERVAL_SECONDS = 15 # Keep in sync w what the backend expects
108
117
 
109
118
  def check_for_api_key(self):
110
119
  """
@@ -276,68 +285,67 @@ class Sutro:
276
285
  job_id = None
277
286
  t = f"Creating {'[dry run] ' if dry_run else ''}priority {job_priority} job"
278
287
  spinner_text = to_colored_text(t)
279
- with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
280
- response = requests.post(
281
- endpoint, data=json.dumps(payload), headers=headers
282
- )
283
- response_data = response.json()
284
- if response.status_code != 200:
285
- spinner.write(
286
- to_colored_text(f"Error: {response.status_code}", state="fail")
288
+ try:
289
+ with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
290
+ response = requests.post(
291
+ endpoint, data=json.dumps(payload), headers=headers
287
292
  )
288
- spinner.stop()
289
- print(to_colored_text(response.json(), state="fail"))
290
- return None
291
- else:
292
- job_id = response_data["results"]
293
- if dry_run:
293
+ response_data = response.json()
294
+ if response.status_code != 200:
294
295
  spinner.write(
295
- to_colored_text(f"Awaiting cost estimates with job ID: {job_id}. You can safely detach and retrieve the cost estimates later.", state="info")
296
+ to_colored_text(f"Error: {response.status_code}", state="fail")
296
297
  )
297
298
  spinner.stop()
298
- self.await_job_completion(job_id, obtain_results=False)
299
- cost_estimate = self._get_job_cost_estimate(job_id)
300
- spinner.write(
301
- to_colored_text(f"✔ Cost estimates retrieved for job {job_id}: ${cost_estimate}", state="success")
302
- )
303
- return job_id
299
+ print(to_colored_text(response.json(), state="fail"))
300
+ return None
304
301
  else:
305
- spinner.write(
306
- to_colored_text(
307
- f"🛠️ Priority {job_priority} Job created with ID: {job_id}",
308
- state="success",
302
+ job_id = response_data["results"]
303
+ if dry_run:
304
+ spinner.write(
305
+ to_colored_text(f"Awaiting cost estimates with job ID: {job_id}. You can safely detach and retrieve the cost estimates later.", state="info")
309
306
  )
310
- )
311
- if not stay_attached:
307
+ spinner.stop()
308
+ self.await_job_completion(job_id, obtain_results=False)
309
+ cost_estimate = self._get_job_cost_estimate(job_id)
310
+ spinner.write(
311
+ to_colored_text(f"✔ Cost estimates retrieved for job {job_id}: ${cost_estimate}", state="success")
312
+ )
313
+ return job_id
314
+ else:
312
315
  spinner.write(
313
316
  to_colored_text(
314
- f"Use `so.get_job_status('{job_id}')` to check the status of the job."
317
+ f"🛠 Priority {job_priority} Job created with ID: {job_id}.",
318
+ state="success",
315
319
  )
316
320
  )
317
- return job_id
321
+ if not stay_attached:
322
+ spinner.write(
323
+ to_colored_text(
324
+ f"Use `so.get_job_status('{job_id}')` to check the status of the job."
325
+ )
326
+ )
327
+ return job_id
328
+ except KeyboardInterrupt:
329
+ pass
330
+ finally:
331
+ if spinner:
332
+ spinner.stop()
318
333
 
319
334
  success = False
320
335
  if stay_attached and job_id is not None:
321
- spinner.write(to_colored_text("Awaiting job start...", "info"))
336
+ spinner.write(to_colored_text("Awaiting job start...", ))
337
+ spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
322
338
  started = self._await_job_start(job_id)
323
339
  if not started:
324
340
  failure_reason = self._get_failure_reason(job_id)
325
341
  spinner.write(to_colored_text(f"Failure reason: {failure_reason['message']}", "fail"))
326
342
  return None
327
-
328
343
  s = requests.Session()
329
- payload = {
330
- "job_id": job_id,
331
- }
332
344
  pbar = None
333
345
 
334
- # Register for stream and get session token
335
- session_token = self.register_stream_listener(job_id)
336
-
337
- # Use the heartbeat session context manager
338
- with self.stream_heartbeat_session(job_id, session_token) as s:
339
- with s.get(
340
- f"{self.base_url}/stream-job-progress/{job_id}?request_session_token={session_token}",
346
+ try:
347
+ with requests.get(
348
+ f"{self.base_url}/stream-job-progress/{job_id}",
341
349
  headers=headers,
342
350
  stream=True,
343
351
  ) as streaming_response:
@@ -348,6 +356,13 @@ class Sutro:
348
356
  color=YASPIN_COLOR,
349
357
  )
350
358
  spinner.start()
359
+
360
+ token_state = {
361
+ 'input_tokens': 0,
362
+ 'output_tokens': 0,
363
+ 'total_tokens_processed_per_second': 0
364
+ }
365
+
351
366
  for line in streaming_response.iter_lines():
352
367
  if line:
353
368
  try:
@@ -370,12 +385,30 @@ class Sutro:
370
385
  pbar.update(json_obj["result"] - pbar.n)
371
386
  pbar.refresh()
372
387
  if json_obj["result"] == len(input_data):
373
- pbar.close()
374
388
  success = True
375
389
  elif json_obj["update_type"] == "tokens":
390
+ # Update only the values that are present in this update
391
+ # Currently, the way the progress stream endpoint is defined,
392
+ # its possible to have updates come in that only have 1 or 2 fields
393
+ new = {
394
+ k: v for k, v in json_obj.get('result', {}).items()
395
+ if k in token_state and v >= token_state[k]
396
+ }
397
+ token_state.update(new)
398
+
376
399
  if pbar is not None:
377
- pbar.postfix = f"Input tokens processed: {json_obj['result']['input_tokens']}, Tokens generated: {json_obj['result']['output_tokens']}, Total tokens/s: {json_obj['result'].get('total_tokens_processed_per_second')}"
400
+ pbar.postfix = f"Input tokens processed: {token_state['input_tokens']}, Output tokens generated: {token_state['output_tokens']}, Total tokens/s: {token_state['total_tokens_processed_per_second']}"
378
401
  pbar.refresh()
402
+
403
+ except KeyboardInterrupt:
404
+ pass
405
+ finally:
406
+ # Need to clean these up on keyboard exit otherwise it causes
407
+ # an error
408
+ if pbar is not None:
409
+ pbar.close()
410
+ if spinner is not None:
411
+ spinner.stop()
379
412
  if success:
380
413
  spinner.text = to_colored_text(
381
414
  "✔ Job succeeded. Obtaining results...", state="success"
@@ -440,87 +473,6 @@ class Sutro:
440
473
  return None
441
474
  return None
442
475
 
443
- def register_stream_listener(self, job_id: str) -> str:
444
- """Register a new stream listener and get a session token."""
445
- headers = {
446
- "Authorization": f"Key {self.api_key}",
447
- "Content-Type": "application/json",
448
- }
449
- with requests.post(
450
- f"{self.base_url}/register-stream-listener/{job_id}",
451
- headers=headers,
452
- ) as response:
453
- response.raise_for_status()
454
- data = response.json()
455
- return data["request_session_token"]
456
-
457
- # This is a best effort action and is ok if it sometimes doesn't complete etc
458
- def unregister_stream_listener(self, job_id: str, session_token: str):
459
- """Explicitly unregister a stream listener."""
460
- headers = {
461
- "Authorization": f"Key {self.api_key}",
462
- "Content-Type": "application/json",
463
- }
464
- with requests.post(
465
- f"{self.base_url}/unregister-stream-listener/{job_id}",
466
- headers=headers,
467
- json={"request_session_token": session_token},
468
- ) as response:
469
- response.raise_for_status()
470
-
471
- def start_heartbeat(
472
- self,
473
- job_id: str,
474
- session_token: str,
475
- session: requests.Session,
476
- stop_event: threading.Event
477
- ):
478
- """Send heartbeats until stopped."""
479
- while not stop_event.is_set():
480
- try:
481
- headers = {
482
- "Authorization": f"Key {self.api_key}",
483
- "Content-Type": "application/json",
484
- }
485
- response = session.post(
486
- f"{self.base_url}/stream-heartbeat/{job_id}",
487
- headers=headers,
488
- params={"request_session_token": session_token},
489
- )
490
- response.raise_for_status()
491
- except Exception as e:
492
- if not stop_event.is_set(): # Only log if we weren't stopping anyway
493
- print(f"Heartbeat failed for job {job_id}: {e}")
494
-
495
- for _ in range(self.HEARTBEAT_INTERVAL_SECONDS):
496
- if stop_event.is_set():
497
- break
498
- time.sleep(1)
499
-
500
- @contextmanager
501
- def stream_heartbeat_session(self, job_id: str, session_token: str) -> Generator[requests.Session, None, None]:
502
- """Context manager that handles session registration and heartbeat."""
503
- session = requests.Session()
504
- stop_heartbeat = threading.Event()
505
-
506
- # Run this concurrently in a thread so we can not block main SDK path/behavior
507
- # but still run heartbeat requests
508
- with ThreadPoolExecutor(max_workers=1) as executor:
509
- executor.submit(
510
- self.start_heartbeat,
511
- job_id,
512
- session_token,
513
- session,
514
- stop_heartbeat
515
- )
516
-
517
- try:
518
- yield session
519
- finally:
520
- # Signal stop and cleanup
521
- stop_heartbeat.set()
522
- self.unregister_stream_listener(job_id, session_token)
523
- session.close()
524
476
 
525
477
  def attach(self, job_id):
526
478
  """
@@ -585,11 +537,9 @@ class Sutro:
585
537
  total_rows = job["num_rows"]
586
538
  success = False
587
539
 
588
- session_token = self.register_stream_listener(job_id)
589
-
590
- with self.stream_heartbeat_session(job_id, session_token) as s:
540
+ try:
591
541
  with s.get(
592
- f"{self.base_url}/stream-job-progress/{job_id}?request_session_token={session_token}",
542
+ f"{self.base_url}/stream-job-progress/{job_id}",
593
543
  headers=headers,
594
544
  stream=True,
595
545
  ) as streaming_response:
@@ -599,6 +549,7 @@ class Sutro:
599
549
  text=to_colored_text("Awaiting status updates..."),
600
550
  color=YASPIN_COLOR,
601
551
  )
552
+ spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
602
553
  spinner.start()
603
554
  for line in streaming_response.iter_lines():
604
555
  if line:
@@ -637,6 +588,13 @@ class Sutro:
637
588
  )
638
589
  )
639
590
  spinner.stop()
591
+ except KeyboardInterrupt:
592
+ pass
593
+ finally:
594
+ if pbar:
595
+ pbar.close()
596
+ if spinner:
597
+ spinner.stop()
640
598
 
641
599
 
642
600
 
@@ -819,10 +777,12 @@ class Sutro:
819
777
  return None
820
778
 
821
779
  def get_job_results(
822
- self,
823
- job_id: str,
824
- include_inputs: bool = False,
825
- include_cumulative_logprobs: bool = False,
780
+ self,
781
+ job_id: str,
782
+ include_inputs: bool = False,
783
+ include_cumulative_logprobs: bool = False,
784
+ with_original_df: pl.DataFrame | pd.DataFrame = None,
785
+ output_column: str = "inference_result",
826
786
  ):
827
787
  """
828
788
  Get the results of a job by its ID.
@@ -833,9 +793,11 @@ class Sutro:
833
793
  job_id (str): The ID of the job to retrieve the results for.
834
794
  include_inputs (bool, optional): Whether to include the inputs in the results. Defaults to False.
835
795
  include_cumulative_logprobs (bool, optional): Whether to include the cumulative logprobs in the results. Defaults to False.
796
+ with_original_df (pd.DataFrame | pl.DataFrame, optional): Original DataFrame to concatenate with results. Defaults to None.
797
+ output_column (str, optional): Name of the output column. Defaults to "inference_result".
836
798
 
837
799
  Returns:
838
- list: The results of the job.
800
+ Union[pl.DataFrame, pd.DataFrame]: The results as a DataFrame. By default, returns polars.DataFrame; when with_original_df is an instance of pandas.DataFrame, returns pandas.DataFrame.
839
801
  """
840
802
  endpoint = f"{self.base_url}/job-results"
841
803
  payload = {
@@ -848,18 +810,14 @@ class Sutro:
848
810
  "Content-Type": "application/json",
849
811
  }
850
812
  with yaspin(
851
- SPINNER,
852
- text=to_colored_text(f"Gathering results from job: {job_id}"),
853
- color=YASPIN_COLOR,
813
+ SPINNER,
814
+ text=to_colored_text(f"Gathering results from job: {job_id}"),
815
+ color=YASPIN_COLOR,
854
816
  ) as spinner:
855
817
  response = requests.post(
856
818
  endpoint, data=json.dumps(payload), headers=headers
857
819
  )
858
- if response.status_code == 200:
859
- spinner.write(
860
- to_colored_text("✔ Job results retrieved", state="success")
861
- )
862
- else:
820
+ if response.status_code != 200:
863
821
  spinner.write(
864
822
  to_colored_text(
865
823
  f"Bad status code: {response.status_code}", state="fail"
@@ -867,8 +825,56 @@ class Sutro:
867
825
  )
868
826
  spinner.stop()
869
827
  print(to_colored_text(response.json(), state="fail"))
870
- return
871
- return response.json()["results"]
828
+ return None
829
+
830
+ spinner.write(
831
+ to_colored_text("✔ Job results retrieved", state="success")
832
+ )
833
+
834
+ response_data = response.json()
835
+ results_df = pl.DataFrame(response_data["results"])
836
+
837
+
838
+ if len(results_df.columns ) == 1:
839
+ # Default column when API is only returning a list, and we construct the df
840
+ # from that
841
+ original_results_column = 'column_0'
842
+ else:
843
+ original_results_column = 'outputs'
844
+
845
+ results_df = results_df.rename({original_results_column: output_column})
846
+
847
+ # Ordering inputs col first seems most logical/useful
848
+ column_config = [
849
+ ('inputs', include_inputs),
850
+ (output_column, True),
851
+ ('cumulative_logprobs', include_cumulative_logprobs),
852
+ ]
853
+
854
+ columns_to_keep = [col for col, include in column_config
855
+ if include and col in results_df.columns]
856
+
857
+ results_df = results_df.select(columns_to_keep)
858
+
859
+ # Handle concatenation with original DataFrame
860
+ if with_original_df is not None:
861
+ if isinstance(with_original_df, pd.DataFrame):
862
+ # Convert to polars for consistent handling
863
+ original_pl = pl.from_pandas(with_original_df)
864
+
865
+ combined_df = original_pl.with_columns(results_df)
866
+
867
+ # Convert back to pandas to match input type
868
+ return combined_df.to_pandas()
869
+
870
+ elif isinstance(with_original_df, pl.DataFrame):
871
+ return with_original_df.with_columns(results_df)
872
+
873
+ # Return pd.DataFrame type when appropriate
874
+ if with_original_df is None and isinstance(with_original_df, pd.DataFrame):
875
+ return results_df.to_pandas()
876
+
877
+ return results_df
872
878
 
873
879
  def cancel_job(self, job_id: str):
874
880
  """
@@ -1227,6 +1233,7 @@ class Sutro:
1227
1233
  with yaspin(
1228
1234
  SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
1229
1235
  ) as spinner:
1236
+ spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
1230
1237
  while (time.time() - start_time) < timeout:
1231
1238
  try:
1232
1239
  status = self._fetch_job_status(job_id)
File without changes
File without changes
File without changes
File without changes
File without changes