sutro 0.1.29__py3-none-any.whl → 0.1.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of sutro has been flagged as possibly problematic.
- sutro/cli.py +6 -4
- sutro/sdk.py +180 -94
- {sutro-0.1.29.dist-info → sutro-0.1.31.dist-info}/METADATA +3 -1
- sutro-0.1.31.dist-info/RECORD +8 -0
- sutro-0.1.29.dist-info/RECORD +0 -8
- {sutro-0.1.29.dist-info → sutro-0.1.31.dist-info}/WHEEL +0 -0
- {sutro-0.1.29.dist-info → sutro-0.1.31.dist-info}/entry_points.txt +0 -0
- {sutro-0.1.29.dist-info → sutro-0.1.31.dist-info}/licenses/LICENSE +0 -0
sutro/cli.py
CHANGED
@@ -1,4 +1,4 @@
-from datetime import
+from datetime import timezone
 import click
 from colorama import Fore, Style
 import os
@@ -35,9 +35,7 @@ def check_auth():
 def get_sdk():
     config = load_config()
     if config.get("base_url") != None:
-        return Sutro(
-            api_key=config.get("api_key"), base_url=config.get("base_url")
-        )
+        return Sutro(api_key=config.get("api_key"), base_url=config.get("base_url"))
     else:
         return Sutro(api_key=config.get("api_key"))
@@ -141,6 +139,7 @@ def jobs():
     """Manage jobs."""
     pass

+
 @jobs.command()
 @click.option(
     "--all", is_flag=True, help="Include all jobs, including cancelled and failed ones."
@@ -359,11 +358,13 @@ def download(dataset_id, file_name=None, output_path=None):
     with open(output_path + "/" + file_name, "wb") as f:
         f.write(file)

+
 @cli.group()
 def cache():
     """Manage the local job results cache."""
     pass

+
 @cache.command()
 def clear():
     """Clear the local job results cache."""
@@ -371,6 +372,7 @@ def clear():
     sdk._clear_job_results_cache()
     click.echo(Fore.GREEN + "Job results cache cleared." + Style.RESET_ALL)

+
 @cache.command()
 def show():
     """Show the contents and size of the job results cache."""
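Aside from the narrowed datetime import in the first hunk, every cli.py change above is formatting-only: the Sutro(...) constructor call is collapsed onto one line and blank lines between Click commands are normalized. For context on the import, a minimal sketch of what a timezone-only import supports; the actual call sites in cli.py are not shown in this diff (assumption):

    # Sketch only: timezone enables timezone-aware timestamps.
    from datetime import datetime, timezone

    now_utc = datetime.now(timezone.utc)  # timezone-aware current time in UTC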
sutro/sdk.py
CHANGED
@@ -1,22 +1,17 @@
-import threading
-from concurrent.futures import ThreadPoolExecutor
-from contextlib import contextmanager
 from enum import Enum
-
 import requests
 import pandas as pd
 import polars as pl
 import json
-from typing import Union, List, Optional, Literal,
+from typing import Union, List, Optional, Literal, Dict, Any
 import os
 import sys
 from yaspin import yaspin
 from yaspin.spinners import Spinners
-from colorama import init, Fore,
+from colorama import init, Fore, Style
 from tqdm import tqdm
 import time
 from pydantic import BaseModel
-import json
 import pyarrow.parquet as pq
 import shutil

@@ -45,6 +40,7 @@ class JobStatus(str, Enum):
     def is_terminal(self) -> bool:
         return self in self.terminal_statuses()

+
 # Initialize colorama (required for Windows)
 init()

@@ -71,8 +67,13 @@ ModelOptions = Literal[
     "qwen-3-32b-thinking",
     "gemma-3-4b-it",
     "gemma-3-27b-it",
-    "
-    "
+    "gpt-oss-120b",
+    "gpt-oss-20b",
+    "qwen-3-235b-a22b-thinking",
+    "qwen-3-30b-a3b-thinking",
+    "qwen-3-embedding-0.6b",
+    "qwen-3-embedding-6b",
+    "qwen-3-embedding-8b",
 ]

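The two truncated entries are replaced by seven new ModelOptions literals covering gpt-oss and additional qwen-3 thinking/embedding variants. A hedged usage sketch, assuming the package root re-exports the Sutro class and an API key is configured:

    # Sketch only: the model name comes from the ModelOptions literals above.
    from sutro import Sutro  # assumes the package root re-exports Sutro

    so = Sutro()  # falls back to check_for_api_key() when no key is passed
    result = so.infer(
        ["What is the capital of France?"],
        model="gpt-oss-20b",  # one of the newly added literals
    )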
@@ -99,6 +100,7 @@ def to_colored_text(
     # Default to blue for normal/processing states
     return f"{Fore.BLUE}{text}{Style.RESET_ALL}"

+
 # Isn't fully support in all terminals unfortunately. We should switch to Rich
 # at some point, but even Rich links aren't clickable on MacOS Terminal
 def make_clickable_link(url, text=None):
@@ -110,10 +112,9 @@ def make_clickable_link(url, text=None):
         text = url
     return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"

+
 class Sutro:
-    def __init__(
-        self, api_key: str = None, base_url: str = "https://api.sutro.sh/"
-    ):
+    def __init__(self, api_key: str = None, base_url: str = "https://api.sutro.sh/"):
         self.api_key = api_key or self.check_for_api_key()
         self.base_url = base_url

@@ -220,7 +221,7 @@ class Sutro:
         cost_estimate: bool,
         stay_attached: Optional[bool],
         random_seed_per_input: bool,
-        truncate_rows: bool
+        truncate_rows: bool,
     ):
         input_data = self.handle_data_helper(data, column)
         endpoint = f"{self.base_url}/batch-inference"
@@ -237,7 +238,7 @@ class Sutro:
             "cost_estimate": cost_estimate,
             "sampling_params": sampling_params,
             "random_seed_per_input": random_seed_per_input,
-            "truncate_rows": truncate_rows
+            "truncate_rows": truncate_rows,
         }

         # There are two gotchas with yaspin:
@@ -266,13 +267,20 @@ class Sutro:
                 job_id = response_data["results"]
                 if cost_estimate:
                     spinner.write(
-                        to_colored_text(
+                        to_colored_text(
+                            f"Awaiting cost estimates with job ID: {job_id}. You can safely detach and retrieve the cost estimates later."
+                        )
                     )
                     spinner.stop()
-                    self.await_job_completion(
+                    self.await_job_completion(
+                        job_id, obtain_results=False, is_cost_estimate=True
+                    )
                     cost_estimate = self._get_job_cost_estimate(job_id)
                     spinner.write(
-                        to_colored_text(
+                        to_colored_text(
+                            f"✔ Cost estimates retrieved for job {job_id}: ${cost_estimate}",
+                            state="success",
+                        )
                     )
                     return job_id
                 else:
@@ -283,12 +291,14 @@ class Sutro:
                     )
                 )
                 if not stay_attached:
-                    clickable_link = make_clickable_link(
+                    clickable_link = make_clickable_link(
+                        f"https://app.sutro.sh/jobs/{job_id}"
+                    )
                     spinner.write(
                         to_colored_text(
                             f"Use `so.get_job_status('{job_id}')` to check the status of the job, or monitor progress at {clickable_link}"
-                        )
                         )
+                    )
                     return job_id
         except KeyboardInterrupt:
             pass
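These hunks complete the previously truncated cost-estimate branch: the submission returns a job ID, the client blocks on the estimate job without pulling results, then reads the estimate off the job record. Reduced to its control flow (names as they appear in this diff):

    # Control-flow sketch of the cost-estimate branch reconstructed above.
    job_id = response_data["results"]  # ID returned by the /batch-inference call
    self.await_job_completion(job_id, obtain_results=False, is_cost_estimate=True)
    cost_estimate = self._get_job_cost_estimate(job_id)  # reads job.get("cost_estimate")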
@@ -298,22 +308,32 @@ class Sutro:

         success = False
         if stay_attached and job_id is not None:
-            spinner.write(
-
-
+            spinner.write(
+                to_colored_text(
+                    "Awaiting job start...",
+                )
+            )
+            clickable_link = make_clickable_link(f"https://app.sutro.sh/jobs/{job_id}")
+            spinner.write(
+                to_colored_text(f"Progress can also be monitored at: {clickable_link}")
+            )
             started = self._await_job_start(job_id)
             if not started:
                 failure_reason = self._get_failure_reason(job_id)
-                spinner.write(
+                spinner.write(
+                    to_colored_text(
+                        f"Failure reason: {failure_reason['message']}", "fail"
+                    )
+                )
                 return None
             s = requests.Session()
             pbar = None

             try:
                 with requests.get(
-
-
-
+                    f"{self.base_url}/stream-job-progress/{job_id}",
+                    headers=headers,
+                    stream=True,
                 ) as streaming_response:
                     streaming_response.raise_for_status()
                     spinner = yaspin(
@@ -324,9 +344,9 @@ class Sutro:
                     spinner.start()

                     token_state = {
-
-
-
+                        "input_tokens": 0,
+                        "output_tokens": 0,
+                        "total_tokens_processed_per_second": 0,
                     }

                     for line in streaming_response.iter_lines():
@@ -340,7 +360,7 @@ class Sutro:
                         if json_obj["update_type"] == "progress":
                             if pbar is None:
                                 spinner.stop()
-                                postfix =
+                                postfix = "Input tokens processed: 0"
                                 pbar = self.fancy_tqdm(
                                     total=len(input_data),
                                     desc="Progress",
@@ -357,7 +377,8 @@ class Sutro:
                             # Currently, the way the progress stream endpoint is defined,
                             # its possible to have updates come in that only have 1 or 2 fields
                             new = {
-                                k: v
+                                k: v
+                                for k, v in json_obj.get("result", {}).items()
                                 if k in token_state and v >= token_state[k]
                             }
                             token_state.update(new)
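The restored comprehension documents the guard on streamed progress updates: json_obj["result"] may carry only a subset of the counters, so unknown keys are ignored and values are only accepted when they do not move backwards. A self-contained illustration:

    # Standalone illustration of the monotonic-update filter used above.
    token_state = {
        "input_tokens": 0,
        "output_tokens": 0,
        "total_tokens_processed_per_second": 0,
    }
    update = {"input_tokens": 120, "unexpected_field": 1}  # hypothetical partial update
    new = {
        k: v
        for k, v in update.items()
        if k in token_state and v >= token_state[k]  # drop unknown keys and regressions
    }
    token_state.update(new)
    assert token_state["input_tokens"] == 120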
@@ -388,8 +409,8 @@ class Sutro:
         # TODO: we implment retries in cases where the job hasn't written results yet
         # it would be better if we could receive a fully succeeded status from the job
         # and not have such a race condition
-        max_retries = 20
-        retry_delay = 5
+        max_retries = 20  # winds up being 100 seconds cumulative delay
+        retry_delay = 5  # initial delay in seconds

         for _ in range(max_retries):
             time.sleep(retry_delay)
@@ -436,7 +457,7 @@ class Sutro:
     def infer(
         self,
         data: Union[List, pd.DataFrame, pl.DataFrame, str],
-        model: Union[ModelOptions, List[ModelOptions]] = "
+        model: Union[ModelOptions, List[ModelOptions]] = "gemma-3-12b-it",
         column: str = None,
         output_column: str = "inference_result",
         job_priority: int = 0,
@@ -446,7 +467,7 @@ class Sutro:
         dry_run: bool = False,
         stay_attached: Optional[bool] = None,
         random_seed_per_input: bool = False,
-        truncate_rows: bool = False
+        truncate_rows: bool = False,
     ):
         """
         Run inference on the provided data.
@@ -475,22 +496,29 @@ class Sutro:
         """
         if isinstance(model, list) == False:
             model_list = [model]
-            stay_attached =
+            stay_attached = (
+                stay_attached if stay_attached is not None else job_priority == 0
+            )
         else:
             model_list = model
             stay_attached = False

         # Convert BaseModel to dict if needed
         if output_schema is not None:
-            if hasattr(
+            if hasattr(
+                output_schema, "model_json_schema"
+            ):  # Check for pydantic Model interface
                 json_schema = output_schema.model_json_schema()
             elif isinstance(output_schema, dict):
                 json_schema = output_schema
             else:
-                raise ValueError(
+                raise ValueError(
+                    "Invalid output schema type. Must be a dictionary or a pydantic Model."
+                )
         else:
             json_schema = None

+        results = []
         for model in model_list:
             res = self._run_one_batch_inference(
                 data,
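The hasattr check duck-types on model_json_schema(), so any pydantic v2 model (METADATA pins pydantic 2.x) works alongside a plain dict JSON schema. For example, both forms would pass the branch above:

    # Both forms satisfy the schema check above (pydantic v2 API).
    from pydantic import BaseModel

    class Sentiment(BaseModel):
        label: str
        score: float

    schema_from_model = Sentiment.model_json_schema()  # dict via the pydantic interface
    schema_as_dict = {"type": "object", "properties": {"label": {"type": "string"}}}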
@@ -504,11 +532,16 @@ class Sutro:
                 dry_run,
                 stay_attached,
                 random_seed_per_input,
-                truncate_rows
+                truncate_rows,
             )
-
-
+            results.append(res)
+
+        if len(results) > 1:
+            return results
+        elif len(results) == 1:
+            return results[0]

+        return None

     def attach(self, job_id):
         """
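With the new results list, infer() gains an explicit return contract: a list when several models were requested, the single result for one model, and None if the loop never ran. A hedged sketch of calling code (so as in the earlier sketch; data is hypothetical):

    # Return shapes implied by the hunk above.
    data = ["classify this sentence"]
    one = so.infer(data, model="gemma-3-12b-it")                  # single result
    many = so.infer(data, model=["gpt-oss-20b", "gpt-oss-120b"])  # list, one per model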
@@ -530,9 +563,9 @@ class Sutro:
         }

         with yaspin(
-
-
-
+            SPINNER,
+            text=to_colored_text("Looking for job..."),
+            color=YASPIN_COLOR,
         ) as spinner:
             # Fetch the specific job we want to attach to
             job = self._fetch_job(job_id)
@@ -550,10 +583,14 @@ class Sutro:
                     )
                     return
                 case "FAILED":
-                    spinner.write(
+                    spinner.write(
+                        to_colored_text("❌ Job is in failed state.", state="fail")
+                    )
                     return
                 case "CANCELLED":
-                    spinner.write(
+                    spinner.write(
+                        to_colored_text("❌ Job was cancelled.", state="fail")
+                    )
                     return
                 case _:
                     spinner.write(to_colored_text("✔ Job found!", state="success"))
@@ -563,9 +600,9 @@ class Sutro:

         try:
             with s.get(
-
-
-
+                f"{self.base_url}/stream-job-progress/{job_id}",
+                headers=headers,
+                stream=True,
             ) as streaming_response:
                 streaming_response.raise_for_status()
                 spinner = yaspin(
@@ -573,8 +610,14 @@ class Sutro:
                     text=to_colored_text("Awaiting status updates..."),
                     color=YASPIN_COLOR,
                 )
-                clickable_link = make_clickable_link(
-
+                clickable_link = make_clickable_link(
+                    f"https://app.sutro.sh/jobs/{job_id}"
+                )
+                spinner.write(
+                    to_colored_text(
+                        f"Progress can also be monitored at: {clickable_link}"
+                    )
+                )
                 spinner.start()
                 for line in streaming_response.iter_lines():
                     if line:
@@ -587,7 +630,7 @@ class Sutro:
                         if json_obj["update_type"] == "progress":
                             if pbar is None:
                                 spinner.stop()
-                                postfix =
+                                postfix = "Input tokens processed: 0"
                                 pbar = self.fancy_tqdm(
                                     total=total_rows,
                                     desc="Progress",
@@ -621,8 +664,6 @@ class Sutro:
         if spinner:
             spinner.stop()

-
-
     def fancy_tqdm(
         self,
         total: int,
@@ -738,7 +779,7 @@ class Sutro:
         response = requests.get(endpoint, headers=headers)
         if response.status_code != 200:
             return None
-        return response.json().get(
+        return response.json().get("job")

     def _get_job_cost_estimate(self, job_id: str):
         """
@@ -748,8 +789,8 @@ class Sutro:
         if not job:
             return None

-        return job.get(
-
+        return job.get("cost_estimate")
+
     def _get_failure_reason(self, job_id: str):
         """
         Get the failure reason for a job.
@@ -757,7 +798,7 @@ class Sutro:
         job = self._fetch_job(job_id)
         if not job:
             return None
-        return job.get(
+        return job.get("failure_reason")

     def _fetch_job_status(self, job_id: str):
         """
@@ -796,13 +837,15 @@ class Sutro:
             str: The status of the job.
         """
         with yaspin(
-
-
-
+            SPINNER,
+            text=to_colored_text(f"Checking job status with ID: {job_id}"),
+            color=YASPIN_COLOR,
         ) as spinner:
             try:
                 response_data = self._fetch_job_status(job_id)
-                spinner.write(
+                spinner.write(
+                    to_colored_text("✔ Job status retrieved!", state="success")
+                )
                 return response_data["job_status"][job_id]
             except requests.HTTPError as e:
                 spinner.write(
@@ -842,14 +885,13 @@ class Sutro:
             Union[pl.DataFrame, pd.DataFrame]: The results as a DataFrame. By default, returns polars.DataFrame; when with_original_df is an instance of pandas.DataFrame, returns pandas.DataFrame.
         """

-
         file_path = os.path.expanduser(f"~/.sutro/job-results/{job_id}.snappy.parquet")
         expected_num_columns = 1 + include_inputs + include_cumulative_logprobs
         contains_expected_columns = False
         if os.path.exists(file_path):
             num_columns = pq.read_table(file_path).num_columns
             contains_expected_columns = num_columns == expected_num_columns
-
+
         if disable_cache == False and contains_expected_columns:
             with yaspin(
                 SPINNER,
@@ -857,7 +899,9 @@ class Sutro:
                 color=YASPIN_COLOR,
             ) as spinner:
                 results_df = pl.read_parquet(file_path)
-                spinner.write(
+                spinner.write(
+                    to_colored_text("✔ Results loaded from cache", state="success")
+                )
         else:
             endpoint = f"{self.base_url}/job-results"
             payload = {
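The cache validity test is a column count: one output column plus one each for the optional inputs and cumulative logprobs, leaning on Python's bool-to-int coercion:

    # Arithmetic behind expected_num_columns (bools coerce to 0/1).
    include_inputs, include_cumulative_logprobs = True, False
    expected_num_columns = 1 + include_inputs + include_cumulative_logprobs
    assert expected_num_columns == 2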
@@ -894,38 +938,51 @@ class Sutro:
                 response_data = response.json()
                 results_df = pl.DataFrame(response_data["results"])

-                results_df = results_df.rename({
+                results_df = results_df.rename({"outputs": output_column})

                 if disable_cache == False:
                     os.makedirs(os.path.dirname(file_path), exist_ok=True)
                     results_df.write_parquet(file_path, compression="snappy")
-                    spinner.write(
-
+                    spinner.write(
+                        to_colored_text("✔ Results saved to cache", state="success")
+                    )
+
         # Ordering inputs col first seems most logical/useful
         column_config = [
-            (
+            ("inputs", include_inputs),
             (output_column, True),
-            (
+            ("cumulative_logprobs", include_cumulative_logprobs),
         ]

-        columns_to_keep = [
-
+        columns_to_keep = [
+            col
+            for col, include in column_config
+            if include and col in results_df.columns
+        ]

         results_df = results_df.select(columns_to_keep)

         if unpack_json:
             try:
-                first_row = json.loads(
+                first_row = json.loads(
+                    results_df.head(1)[output_column][0]
+                )  # checks if the first row can be json decoded
                 results_df = results_df.with_columns(
-                    pl.col(output_column)
+                    pl.col(output_column)
+                    .str.json_decode()
+                    .alias("output_column_json_decoded")
                 )
                 json_decoded_fields = first_row.keys()
                 for field in json_decoded_fields:
                     results_df = results_df.with_columns(
-                        pl.col("output_column_json_decoded")
+                        pl.col("output_column_json_decoded")
+                        .struct.field(field)
+                        .alias(field)
                     )
                 # drop the output_column and the json decoded column
-                results_df = results_df.drop(
+                results_df = results_df.drop(
+                    [output_column, "output_column_json_decoded"]
+                )
             except json.JSONDecodeError:
                 # if the first row cannot be json decoded, do nothing
                 pass
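The unpack branch decodes the output column as JSON once, promotes each top-level field of the first row to its own column, then drops the intermediates. The same technique in a standalone polars snippet (hypothetical data):

    # Standalone sketch of the polars JSON-unpack technique used above.
    import json
    import polars as pl

    df = pl.DataFrame({"inference_result": ['{"label": "pos", "score": 0.9}']})
    fields = json.loads(df.head(1)["inference_result"][0]).keys()
    df = df.with_columns(
        pl.col("inference_result").str.json_decode().alias("decoded")
    )
    for field in fields:
        df = df.with_columns(pl.col("decoded").struct.field(field).alias(field))
    df = df.drop(["inference_result", "decoded"])  # keep only the unpacked columns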
@@ -1011,7 +1068,9 @@ class Sutro:
                 return
             dataset_id = response.json()["dataset_id"]
             spinner.write(
-                to_colored_text(
+                to_colored_text(
+                    f"✔ Dataset created with ID: {dataset_id}", state="success"
+                )
             )
             return dataset_id

@@ -1079,8 +1138,7 @@ class Sutro:
                     "dataset_id": dataset_id,
                 }

-                headers = {
-                    "Authorization": f"Key {self.api_key}"}
+                headers = {"Authorization": f"Key {self.api_key}"}

                 count += 1
                 spinner.write(
@@ -1164,7 +1222,9 @@ class Sutro:
                 print(to_colored_text(f"Error: {response.json()}", state="fail"))
                 return
             spinner.write(
-                to_colored_text(
+                to_colored_text(
+                    f"✔ Files listed in dataset: {dataset_id}", state="success"
+                )
             )
             return response.json()["files"]
@@ -1286,7 +1346,13 @@ class Sutro:
             return
         return response.json()["quotas"]

-    def await_job_completion(
+    def await_job_completion(
+        self,
+        job_id: str,
+        timeout: Optional[int] = 7200,
+        obtain_results: bool = True,
+        is_cost_estimate: bool = False,
+    ) -> list | None:
         """
         Waits for job completion to occur and then returns the results upon
         a successful completion.
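The completed signature is what makes the cost-estimate flow from the earlier hunks possible: obtain_results=False waits for a terminal status without downloading anything, and the list | None annotation matches the results-or-None behavior below. A hedged call sketch:

    # job_id is whatever so.infer(...) returned earlier (sketch, not a full program).
    results = so.await_job_completion(job_id, timeout=3600)  # poll, then fetch results
    so.await_job_completion(job_id, obtain_results=False)    # just block until terminal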
@@ -1308,8 +1374,14 @@ class Sutro:
             SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
         ) as spinner:
             if not is_cost_estimate:
-                clickable_link = make_clickable_link(
-
+                clickable_link = make_clickable_link(
+                    f"https://app.sutro.sh/jobs/{job_id}"
+                )
+                spinner.write(
+                    to_colored_text(
+                        f"Progress can also be monitored at: {clickable_link}"
+                    )
+                )
             while (time.time() - start_time) < timeout:
                 try:
                     status = self._fetch_job_status(job_id)
@@ -1326,9 +1398,13 @@ class Sutro:
                     spinner.text = to_colored_text(f"Job status is {status} for {job_id}")

                     if status == JobStatus.SUCCEEDED:
-                        spinner.stop()
+                        spinner.stop()  # Stop this spinner as `get_job_results` has its own spinner text
                         if obtain_results:
-                            spinner.write(
+                            spinner.write(
+                                to_colored_text(
+                                    "Job completed! Retrieving results...", "success"
+                                )
+                            )
                             results = self.get_job_results(job_id)
                         break
                     if status == JobStatus.FAILED:
@@ -1338,12 +1414,11 @@ class Sutro:
                         spinner.write(to_colored_text("Job has been cancelled"))
                         return None

-
                 time.sleep(POLL_INTERVAL)

             return results

-    def _clear_job_results_cache(self):
+    def _clear_job_results_cache(self):  # only to be called by the CLI
         """
         Clears the cache for a job results.
         """
@@ -1356,29 +1431,41 @@ class Sutro:
         """
         # get the size of the job-results directory
         with yaspin(
-            SPINNER,
+            SPINNER,
+            text=to_colored_text("Retrieving job results cache contents"),
+            color=YASPIN_COLOR,
         ) as spinner:
             if not os.path.exists(os.path.expanduser("~/.sutro/job-results")):
                 spinner.write(to_colored_text("No job results cache found", "success"))
                 return
             total_size = 0
             for file in os.listdir(os.path.expanduser("~/.sutro/job-results")):
-                size =
+                size = (
+                    os.path.getsize(os.path.expanduser(f"~/.sutro/job-results/{file}"))
+                    / 1024
+                    / 1024
+                    / 1024
+                )
                 total_size += size
                 spinner.write(to_colored_text(f"File: {file} - Size: {size} GB"))
-            spinner.write(
-
+            spinner.write(
+                to_colored_text(
+                    f"Total size of results cache at ~/.sutro/job-results: {total_size} GB",
+                    "success",
+                )
+            )
+
     def _await_job_start(self, job_id: str, timeout: Optional[int] = 7200):
         """
         Waits for job start to occur and then returns the results upon
         a successful start.
-
+
         """
         POLL_INTERVAL = 5

         start_time = time.time()
         with yaspin(
-
+            SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
         ) as spinner:
             while (time.time() - start_time) < timeout:
                 try:
@@ -1405,4 +1492,3 @@ class Sutro:
             time.sleep(POLL_INTERVAL)

         return False
-
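Since os.path.getsize returns bytes, the chained divisions report binary gigabytes (labelled GB in the output). The equivalent per-file conversion:

    # Per-file conversion used by the cache `show` command above.
    import os

    file_path = os.path.expanduser("~/.sutro/job-results/example.snappy.parquet")  # hypothetical file
    size_gb = os.path.getsize(file_path) / 1024 / 1024 / 1024  # bytes -> GiB (printed as "GB")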
{sutro-0.1.29.dist-info → sutro-0.1.31.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sutro
-Version: 0.1.29
+Version: 0.1.31
 Summary: Sutro Python SDK
 Project-URL: Homepage, https://sutro.sh
 Project-URL: Documentation, https://docs.sutro.sh
@@ -17,6 +17,8 @@ Requires-Dist: pydantic==2.11.4
 Requires-Dist: requests==2.32.3
 Requires-Dist: tqdm==4.67.1
 Requires-Dist: yaspin==3.1.0
+Provides-Extra: dev
+Requires-Dist: ruff==0.13.1; extra == 'dev'
 Description-Content-Type: text/markdown

 # sutro-client
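The new dev extra pinning ruff 0.13.1 fits the character of this release: nearly every cli.py and sdk.py hunk above is a formatting-only reflow (trailing commas, wrapped call arguments, normalized blank lines) of the kind a ruff format pass produces, alongside the genuinely functional changes to ModelOptions, infer()'s return value, and await_job_completion()'s signature.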
sutro-0.1.31.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
+sutro/cli.py,sha256=YzwIdpHUcG6GEc1IofRFD5rdxAZpgXy3OWH0MqehryA,13608
+sutro/sdk.py,sha256=UoDGXsknj8S6aLn6d4GrtCFqhzqdlbqvcmOFG-hkr44,57008
+sutro-0.1.31.dist-info/METADATA,sha256=9_2xTAwhE0_6USQFVyqy6CQIAs_W6Xc20QFtyHp2-gM,764
+sutro-0.1.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+sutro-0.1.31.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
+sutro-0.1.31.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+sutro-0.1.31.dist-info/RECORD,,
sutro-0.1.29.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
-sutro/cli.py,sha256=8DrJVbjoayCUz4iszlj35Tv1q1gzDVzx_CuF6gZHwuU,13636
-sutro/sdk.py,sha256=4cdRclI_meTrp26yhVeJDMlbWrqEb5Lwufi5d8WTmh0,55357
-sutro-0.1.29.dist-info/METADATA,sha256=a6c-nO9s4x3a2cT99JuGIySRgqMzGlWA6iUEXtnGvow,700
-sutro-0.1.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-sutro-0.1.29.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
-sutro-0.1.29.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-sutro-0.1.29.dist-info/RECORD,,
{sutro-0.1.29.dist-info → sutro-0.1.31.dist-info}/WHEEL
File without changes
{sutro-0.1.29.dist-info → sutro-0.1.31.dist-info}/entry_points.txt
File without changes
{sutro-0.1.29.dist-info → sutro-0.1.31.dist-info}/licenses/LICENSE
File without changes