sutro 0.1.34__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sutro/__init__.py +15 -7
- sutro/cli.py +1 -1
- sutro/common.py +220 -0
- sutro/interfaces.py +91 -0
- sutro/sdk.py +400 -444
- sutro/templates/classification.py +117 -0
- sutro/templates/embed.py +53 -0
- sutro/templates/evals.py +340 -0
- sutro/validation.py +60 -0
- {sutro-0.1.34.dist-info → sutro-0.1.40.dist-info}/METADATA +14 -16
- sutro-0.1.40.dist-info/RECORD +13 -0
- sutro-0.1.40.dist-info/WHEEL +4 -0
- {sutro-0.1.34.dist-info → sutro-0.1.40.dist-info}/entry_points.txt +1 -0
- sutro-0.1.34.dist-info/RECORD +0 -8
- sutro-0.1.34.dist-info/WHEEL +0 -4
- sutro-0.1.34.dist-info/licenses/LICENSE +0 -201
sutro/sdk.py
CHANGED
@@ -1,45 +1,32 @@
-from enum import Enum
 import requests
 import pandas as pd
 import polars as pl
 import json
-from typing import Union, List, Optional,
+from typing import Union, List, Optional, Dict, Any, Type
 import os
 import sys
 from yaspin import yaspin
 from yaspin.spinners import Spinners
-from colorama import init
-from tqdm import tqdm
+from colorama import init
 import time
 from pydantic import BaseModel
 import pyarrow.parquet as pq
 import shutil
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    def terminal_statuses(cls) -> list["JobStatus"]:
-        return [
-            cls.SUCCEEDED,
-            cls.FAILED,
-            cls.CANCELLING,
-            cls.CANCELLED,
-        ]
-
-    def is_terminal(self) -> bool:
-        return self in self.terminal_statuses()
-
+from sutro.common import (
+    ModelOptions,
+    handle_data_helper,
+    normalize_output_schema,
+    to_colored_text,
+    fancy_tqdm,
+)
+from sutro.interfaces import JobStatus
+from sutro.templates.classification import ClassificationTemplates
+from sutro.templates.embed import EmbeddingTemplates
+from sutro.templates.evals import EvalTemplates
+from sutro.validation import check_version, check_for_api_key
+
+JOB_NAME_CHAR_LIMIT = 45
+JOB_DESCRIPTION_CHAR_LIMIT = 512
 
 # Initialize colorama (required for Windows)
 init()

@@ -50,56 +37,11 @@ def is_jupyter() -> bool:
     return not sys.stdout.isatty()
 
 
-#
-
+# Adding color to text is not supported in Jupyter notebooks and breaks
+# things
+BASE_OUTPUT_COLOR = None if is_jupyter() else "blue"
 SPINNER = Spinners.dots14
 
-# Models available for inference. Keep in sync with the backend configuration
-# so users get helpful autocompletion when selecting a model.
-ModelOptions = Literal[
-    "llama-3.2-3b",
-    "llama-3.1-8b",
-    "llama-3.3-70b",
-    "llama-3.3-70b",
-    "qwen-3-4b",
-    "qwen-3-32b",
-    "qwen-3-4b-thinking",
-    "qwen-3-32b-thinking",
-    "gemma-3-4b-it",
-    "gemma-3-27b-it",
-    "gpt-oss-120b",
-    "gpt-oss-20b",
-    "qwen-3-235b-a22b-thinking",
-    "qwen-3-30b-a3b-thinking",
-    "qwen-3-embedding-0.6b",
-    "qwen-3-embedding-6b",
-    "qwen-3-embedding-8b",
-]
-
-
-def to_colored_text(
-    text: str, state: Optional[Literal["success", "fail"]] = None
-) -> str:
-    """
-    Apply color to text based on state.
-
-    Args:
-        text (str): The text to color
-        state (Optional[Literal['success', 'fail']]): The state that determines the color.
-            Options: 'success', 'fail', or None (default blue)
-
-    Returns:
-        str: Text with appropriate color applied
-    """
-    match state:
-        case "success":
-            return f"{Fore.GREEN}{text}{Style.RESET_ALL}"
-        case "fail":
-            return f"{Fore.RED}{text}{Style.RESET_ALL}"
-        case _:
-            # Default to blue for normal/processing states
-            return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
-
 
 # Isn't fully support in all terminals unfortunately. We should switch to Rich
 # at some point, but even Rich links aren't clickable on MacOS Terminal

@@ -108,41 +50,20 @@ def make_clickable_link(url, text=None):
     Create a clickable link for terminals that support OSC 8 hyperlinks.
     Falls back to plain text for terminals that don't support it.
     """
+    # Don't need to add the special chars for jupyter notebook
+    if is_jupyter():
+        return url
+
     if text is None:
         text = url
     return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
 
 
-class Sutro:
+class Sutro(EmbeddingTemplates, ClassificationTemplates, EvalTemplates):
    def __init__(self, api_key: str = None, base_url: str = "https://api.sutro.sh/"):
-        self.api_key = api_key or
+        self.api_key = api_key or check_for_api_key()
        self.base_url = base_url
-
-    def check_for_api_key(self):
-        """
-        Check for an API key in the user's home directory.
-
-        This method looks for a configuration file named 'config.json' in the
-        '.sutro' directory within the user's home directory.
-        If the file exists, it attempts to read the API key from it.
-
-        Returns:
-            str or None: The API key if found in the configuration file, or None if not found.
-
-        Note:
-            The expected structure of the config.json file is:
-            {
-                "api_key": "your_api_key_here"
-            }
-        """
-        CONFIG_DIR = os.path.expanduser("~/.sutro")
-        CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
-        if os.path.exists(CONFIG_FILE):
-            with open(CONFIG_FILE, "r") as f:
-                config = json.load(f)
-                return config.get("api_key")
-        else:
-            return None
+        check_version("sutro")
 
    def set_api_key(self, api_key: str):
        """

@@ -159,43 +80,6 @@ class Sutro:
        """
        self.api_key = api_key
 
-    def handle_data_helper(
-        self, data: Union[List, pd.DataFrame, pl.DataFrame, str], column: str = None
-    ):
-        if isinstance(data, list):
-            input_data = data
-        elif isinstance(data, (pd.DataFrame, pl.DataFrame)):
-            if column is None:
-                raise ValueError("Column name must be specified for DataFrame input")
-            input_data = data[column].to_list()
-        elif isinstance(data, str):
-            if data.startswith("dataset-"):
-                input_data = data + ":" + column
-            else:
-                file_ext = os.path.splitext(data)[1].lower()
-                if file_ext == ".csv":
-                    df = pl.read_csv(data)
-                elif file_ext == ".parquet":
-                    df = pl.read_parquet(data)
-                elif file_ext in [".txt", ""]:
-                    with open(data, "r") as file:
-                        input_data = [line.strip() for line in file]
-                else:
-                    raise ValueError(f"Unsupported file type: {file_ext}")
-
-                if file_ext in [".csv", ".parquet"]:
-                    if column is None:
-                        raise ValueError(
-                            "Column name must be specified for CSV/Parquet input"
-                        )
-                    input_data = df[column].to_list()
-        else:
-            raise ValueError(
-                "Unsupported data type. Please provide a list, DataFrame, or file path."
-            )
-
-        return input_data
-
    def set_base_url(self, base_url: str):
        """
        Set the base URL for the Sutro API.

@@ -208,11 +92,48 @@ class Sutro:
        """
        self.base_url = base_url
 
+    def do_request(
+        self,
+        method: str,
+        endpoint: str,
+        api_key_override: Optional[str] = None,
+        **kwargs: Any,
+    ):
+        """
+        Helper to make authenticated requests.
+        """
+        key = self.api_key if not api_key_override else api_key_override
+        headers = {"Authorization": f"Key {key}"}
+
+        # Merge with any headers passed in kwargs
+        if "headers" in kwargs:
+            headers.update(kwargs.pop("headers"))
+
+        url = f"{self.base_url}/{endpoint.lstrip('/')}"
+
+        # Explicit method dispatch
+        method = method.upper()
+        if method == "GET":
+            response = requests.get(url, headers=headers, **kwargs)
+        elif method == "POST":
+            response = requests.post(url, headers=headers, **kwargs)
+        elif method == "PUT":
+            response = requests.put(url, headers=headers, **kwargs)
+        elif method == "DELETE":
+            response = requests.delete(url, headers=headers, **kwargs)
+        elif method == "PATCH":
+            response = requests.patch(url, headers=headers, **kwargs)
+        else:
+            raise ValueError(f"Unsupported HTTP method: {method}")
+
+        response.raise_for_status()
+        return response
+
    def _run_one_batch_inference(
        self,
        data: Union[List, pd.DataFrame, pl.DataFrame, str],
        model: ModelOptions,
-        column: str,
+        column: Union[str, List[str]],
        output_column: str,
        job_priority: int,
        json_schema: Dict[str, Any],
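Every endpoint call in the rest of the file now funnels through `do_request`, which prefixes the base URL, injects the `Key` authorization header, and raises `requests.HTTPError` on non-2xx responses via `raise_for_status()`. A minimal sketch of the resulting caller pattern, using the `list-jobs` endpoint that appears later in this diff (the client construction and key are illustrative):

import requests
from sutro.sdk import Sutro

so = Sutro(api_key="YOUR_API_KEY")  # hypothetical key

try:
    # do_request builds the URL, adds the auth header, and raises on non-2xx.
    jobs = so.do_request("GET", "list-jobs").json()["jobs"]
except requests.HTTPError as e:
    # Callers inspect e.response for the status code and error body.
    print(e.response.status_code, e.response.json())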
@@ -222,13 +143,20 @@ class Sutro:
        stay_attached: Optional[bool],
        random_seed_per_input: bool,
        truncate_rows: bool,
+        name: str,
+        description: str,
    ):
-
-
-
-
-
-
+        # Validate name and description lengths
+        if name is not None and len(name) > JOB_NAME_CHAR_LIMIT:
+            raise ValueError(
+                f"Job name cannot exceed {JOB_NAME_CHAR_LIMIT} characters."
+            )
+        if description is not None and len(description) > JOB_DESCRIPTION_CHAR_LIMIT:
+            raise ValueError(
+                f"Job description cannot exceed {JOB_DESCRIPTION_CHAR_LIMIT} characters."
+            )
+
+        input_data = handle_data_helper(data, column)
        payload = {
            "model": model,
            "inputs": input_data,

@@ -239,6 +167,8 @@ class Sutro:
            "sampling_params": sampling_params,
            "random_seed_per_input": random_seed_per_input,
            "truncate_rows": truncate_rows,
+            "name": name,
+            "description": description,
        }
 
        # There are two gotchas with yaspin:

@@ -250,18 +180,21 @@ class Sutro:
        job_id = None
        t = f"Creating {'[cost estimate] ' if cost_estimate else ''}priority {job_priority} job"
        spinner_text = to_colored_text(t)
+
        try:
-            with yaspin(SPINNER, text=spinner_text, color=
-
-
-
-
+            with yaspin(SPINNER, text=spinner_text, color=BASE_OUTPUT_COLOR) as spinner:
+                try:
+                    response = self.do_request("POST", "batch-inference", json=payload)
+                    response_data = response.json()
+                except requests.HTTPError as e:
+                    response = e.response
+                    response_data = response.json()
                if response.status_code != 200:
                    spinner.write(
                        to_colored_text(f"Error: {response.status_code}", state="fail")
                    )
                    spinner.stop()
-                    print(to_colored_text(
+                    print(to_colored_text(response_data, state="fail"))
                    return None
                else:
                    job_id = response_data["results"]

@@ -284,12 +217,14 @@ class Sutro:
                        )
                        return job_id
                    else:
+                        name_text = f" and name {name}" if name is not None else ""
                        spinner.write(
                            to_colored_text(
-                                f"🛠 Priority {job_priority} Job created with ID: {job_id}
+                                f"🛠 Priority {job_priority} Job created with ID: {job_id}{name_text}",
                                state="success",
                            )
                        )
+                        spinner.write(to_colored_text(f"Model: {model}"))
                        if not stay_attached:
                            clickable_link = make_clickable_link(
                                f"https://app.sutro.sh/jobs/{job_id}"

@@ -326,20 +261,20 @@ class Sutro:
                )
            )
            return None
-
+
        pbar = None
 
        try:
-            with
-
-
+            with self.do_request(
+                "GET",
+                f"/stream-job-progress/{job_id}",
                stream=True,
            ) as streaming_response:
                streaming_response.raise_for_status()
                spinner = yaspin(
                    SPINNER,
                    text=to_colored_text("Awaiting status updates..."),
-                    color=
+                    color=BASE_OUTPUT_COLOR,
                )
                spinner.start()
 

@@ -361,7 +296,7 @@ class Sutro:
                    if pbar is None:
                        spinner.stop()
                        postfix = "Input tokens processed: 0"
-                        pbar =
+                        pbar = fancy_tqdm(
                            total=len(input_data),
                            desc="Progress",
                            style=1,

@@ -402,28 +337,27 @@ class Sutro:
                )
                spinner.start()
 
-                payload = {
-                    "job_id": job_id,
-                }
-
                # TODO: we implment retries in cases where the job hasn't written results yet
                # it would be better if we could receive a fully succeeded status from the job
                # and not have such a race condition
                max_retries = 20  # winds up being 100 seconds cumulative delay
                retry_delay = 5  # initial delay in seconds
-
+                job_results_response = None
                for _ in range(max_retries):
-
-
-
-
-
-
-
+                    try:
+                        job_results_response = self.do_request(
+                            "POST",
+                            "job-results",
+                            json={
+                                "job_id": job_id,
+                            },
+                        )
                        break
+                    except requests.HTTPError:
+                        time.sleep(retry_delay)
+                        continue
 
-                if job_results_response.status_code != 200:
+                if not job_results_response or job_results_response.status_code != 200:
                    spinner.write(
                        to_colored_text(
                            "Job succeeded, but results are not yet available. Use `so.get_job_results('{job_id}')` to obtain results.",

@@ -435,94 +369,183 @@ class Sutro:
 
                results = job_results_response.json()["results"]["outputs"]
 
-                spinner.write(
-                    to_colored_text(
-                        f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
-                        state="success",
-                    )
-                )
-                spinner.stop()
-
                if isinstance(data, (pd.DataFrame, pl.DataFrame)):
                    if isinstance(data, pd.DataFrame):
                        data[output_column] = results
                    elif isinstance(data, pl.DataFrame):
                        data = data.with_columns(pl.Series(output_column, results))
-
+                    print(data)
+                    spinner.write(
+                        to_colored_text(
+                            f"✔ Displaying result preview. You can join the results on the original dataframe with `so.get_job_results('{job_id}', with_original_df=<original_df>)`",
+                            state="success",
+                        )
+                    )
+                else:
+                    print(results)
+                    spinner.write(
+                        to_colored_text(
+                            f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
+                            state="success",
+                        )
+                    )
+                spinner.stop()
 
-                return
+                return job_id
            return None
        return None
 
    def infer(
        self,
        data: Union[List, pd.DataFrame, pl.DataFrame, str],
-        model:
-
+        model: ModelOptions = "gemma-3-12b-it",
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+        column: Union[str, List[str]] = None,
        output_column: str = "inference_result",
        job_priority: int = 0,
-        output_schema: Union[Dict[str, Any], BaseModel] = None,
+        output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
        sampling_params: dict = None,
        system_prompt: str = None,
        dry_run: bool = False,
        stay_attached: Optional[bool] = None,
        random_seed_per_input: bool = False,
-        truncate_rows: bool =
+        truncate_rows: bool = True,
    ):
        """
        Run inference on the provided data.
 
        This method allows you to run inference on the provided data using the Sutro API.
-        It supports various data types such as lists,
+        It supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.
 
        Args:
            data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
-            model (
-
+            model (ModelOptions, optional): The model to use for inference. Defaults to "gemma-3-12b-it".
+            name (str, optional): A job name for experiment/metadata tracking purposes. Defaults to None.
+            description (str, optional): A job description for experiment/metadata tracking purposes. Defaults to None.
+            column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, it will concatenate the columns of the list into a single column, accepting separator strings.
            output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
            job_priority (int, optional): The priority of the job. Defaults to 0.
            output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
-                Can be either a dictionary representing a JSON schema or a
+                Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
            sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
            system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
            dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
            stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
            random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
-            truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to
+            truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to True.
 
        Returns:
-
+            str: The ID of the inference job.
 
        """
-
-
-        stay_attached =
-
-
+        # Default stay_attached to True for prototyping jobs (priority 0)
+        if stay_attached is None:
+            stay_attached = job_priority == 0
+
+        json_schema = None
+        if output_schema:
+            # Convert BaseModel to dict if needed
+            json_schema = normalize_output_schema(output_schema)
+
+        return self._run_one_batch_inference(
+            data,
+            model,
+            column,
+            output_column,
+            job_priority,
+            json_schema,
+            sampling_params,
+            system_prompt,
+            dry_run,
+            stay_attached,
+            random_seed_per_input,
+            truncate_rows,
+            name,
+            description,
+        )
+
+    def infer_per_model(
+        self,
+        data: Union[List, pd.DataFrame, pl.DataFrame, str],
+        models: List[ModelOptions],
+        names: List[str] = None,
+        descriptions: List[str] = None,
+        column: Union[str, List[str]] = None,
+        output_column: str = "inference_result",
+        job_priority: int = 0,
+        output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
+        sampling_params: dict = None,
+        system_prompt: str = None,
+        dry_run: bool = False,
+        random_seed_per_input: bool = False,
+        truncate_rows: bool = True,
+    ):
+        """
+        Run inference on the provided data, across multiple models. This method is often useful to sampling outputs from multiple models across the same dataset and compare the job_ids.
+
+        For input data, it supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.
+
+        Args:
+            data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
+            models (Union[ModelOptions, List[ModelOptions]], optional): The models to use for inference. Fans out each model to its own seperate job, over the same dataset.
+            names (Union[str, List[str]], optional): A job name for experiment/metadata tracking purposes. If using a list of models, you must pass a list of names with length equal to the number of models, or None. Defaults to None.
+            descriptions (Union[str, List[str]], optional): A job description for experiment/metadata tracking purposes. If using a list of models, you must pass a list of descriptions with length equal to the number of models, or None. Defaults to None.
+            column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, it will concatenate the columns of the list into a single column, accepting separator strings.
+            output_column (str, optional): The column name to store the inference job_ids in if the input is a DataFrame. Defaults to "inference_result".
+            job_priority (int, optional): The priority of the job. Defaults to 0.
+            output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
+                Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
+            sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
+            system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
+            dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
+            stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
+            random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
+            truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to True.
+
+        Returns:
+            str: The ID of the inference job.
+
+        """
+        if isinstance(names, list):
+            if len(names) != len(models):
+                raise ValueError(
+                    "names parameter must be the same length as the models parameter."
+                )
+        elif names is None:
+            names = [None] * len(models)
        else:
-
-
-
-
-            if
-            if
-                output_schema, "model_json_schema"
-            ):  # Check for pydantic Model interface
-                json_schema = output_schema.model_json_schema()
-            elif isinstance(output_schema, dict):
-                json_schema = output_schema
-            else:
+            raise ValueError(
+                "names parameter must be a list or None if using a list of models"
+            )
+
+        if isinstance(descriptions, list):
+            if len(descriptions) != len(models):
                raise ValueError(
-                    "
+                    "descriptions parameter must be the same length as the models"
+                    " parameter."
                )
+        elif descriptions is None:
+            descriptions = [None] * len(models)
        else:
-
+            raise ValueError(
+                "descriptions parameter must be a list or None if using a list of "
+                "models"
+            )
 
-
-
-
+        json_schema = None
+        if output_schema:
+            # Convert BaseModel to dict if needed
+            json_schema = normalize_output_schema(output_schema)
+
+        def start_job(
+            model_singleton: ModelOptions,
+            name_singleton: str | None,
+            description_singleton: str | None,
+        ):
+            return self._run_one_batch_inference(
                data,
-
+                model_singleton,
                column,
                output_column,
                job_priority,

@@ -530,18 +553,21 @@ class Sutro:
                sampling_params,
                system_prompt,
                dry_run,
-
+                False,
                random_seed_per_input,
                truncate_rows,
+                name_singleton,
+                description_singleton,
            )
-            results.append(res)
 
-
-
-
+        job_ids = [
+            start_job(model, name, description)
+            for model, name, description in zip(
+                models, names, descriptions, strict=True
+            )
+        ]
 
-        return
+        return job_ids
 
    def attach(self, job_id):
        """
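The new `infer` signature and the `infer_per_model` fan-out above are easiest to see side by side. A sketch based only on the signatures and docstrings in this diff; the prompts, column, and job names are illustrative, and the model IDs are drawn from the removed ModelOptions literal and the new default:

import polars as pl
from sutro.sdk import Sutro

so = Sutro()  # key resolved via check_for_api_key() when not passed

df = pl.DataFrame({"prompt": ["Summarize polars.", "Summarize pandas."]})

# Single job; name/description are the new metadata fields,
# capped at 45 and 512 characters respectively.
job_id = so.infer(
    df,
    model="gemma-3-12b-it",
    column="prompt",
    name="summaries-v1",
    description="First pass at library summaries",
)

# Fan out the same dataset across several models; one job ID per model.
job_ids = so.infer_per_model(
    df,
    models=["llama-3.1-8b", "qwen-3-32b"],
    names=["summaries-llama", "summaries-qwen"],
    column="prompt",
)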
@@ -552,20 +578,12 @@ class Sutro:
        """
 
        s = requests.Session()
-        payload = {
-            "job_id": job_id,
-        }
        pbar = None
 
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
        with yaspin(
            SPINNER,
            text=to_colored_text("Looking for job..."),
-            color=
+            color=BASE_OUTPUT_COLOR,
        ) as spinner:
            # Fetch the specific job we want to attach to
            job = self._fetch_job(job_id)

@@ -599,16 +617,16 @@ class Sutro:
        success = False
 
        try:
-            with
-
-
+            with self.do_request(
+                "GET",
+                f"/stream-job-progress/{job_id}",
                stream=True,
            ) as streaming_response:
                streaming_response.raise_for_status()
                spinner = yaspin(
                    SPINNER,
                    text=to_colored_text("Awaiting status updates..."),
-                    color=
+                    color=BASE_OUTPUT_COLOR,
                )
                clickable_link = make_clickable_link(
                    f"https://app.sutro.sh/jobs/{job_id}"

@@ -631,7 +649,7 @@ class Sutro:
                    if pbar is None:
                        spinner.stop()
                        postfix = "Input tokens processed: 0"
-                        pbar =
+                        pbar = fancy_tqdm(
                            total=total_rows,
                            desc="Progress",
                            style=1,

@@ -668,7 +686,7 @@ class Sutro:
        self,
        total: int,
        desc: str = "Progress",
-        color: str =
+        color: str = BASE_OUTPUT_COLOR,
        style=1,
        postfix: str = None,
    ):

@@ -730,56 +748,36 @@ class Sutro:
        This method retrieves a list of all jobs associated with the API key.
 
        Returns:
-            list: A list of job details.
+            list: A list of job details, or None if the request fails.
        """
-        endpoint = f"{self.base_url}/list-jobs"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
        with yaspin(
-            SPINNER, text=to_colored_text("Fetching jobs"), color=
+            SPINNER, text=to_colored_text("Fetching jobs"), color=BASE_OUTPUT_COLOR
        ) as spinner:
-
-
+            try:
+                return self._list_all_jobs_for_user()
+            except requests.HTTPError as e:
                spinner.write(
                    to_colored_text(
-                        f"Bad status code: {response.status_code}", state="fail"
+                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
                spinner.stop()
-                print(to_colored_text(response.json(), state="fail"))
-                return
-            return response.json()["jobs"]
+                print(to_colored_text(e.response.json(), state="fail"))
+                return None
 
-    def
-        """
-        Helper function to list jobs.
-        """
-        endpoint = f"{self.base_url}/list-jobs˚"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-        response = requests.get(endpoint, headers=headers)
-        if response.status_code != 200:
-            return None
+    def _list_all_jobs_for_user(self):
+        response = self.do_request("GET", "list-jobs")
        return response.json()["jobs"]
 
    def _fetch_job(self, job_id):
        """
        Helper function to fetch a single job.
        """
-
-
-
-
-        }
-        response = requests.get(endpoint, headers=headers)
-        if response.status_code != 200:
+        try:
+            response = self.do_request("GET", f"jobs/{job_id}")
+            return response.json().get("job")
+        except requests.HTTPError:
            return None
-        return response.json().get("job")
 
    def _get_job_cost_estimate(self, job_id: str):
        """

@@ -813,15 +811,7 @@ class Sutro:
        Raises:
            requests.HTTPError: If the API returns a non-200 status code.
        """
-
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
-        response = requests.get(endpoint, headers=headers)
-        response.raise_for_status()
-
+        response = self.do_request("GET", f"job-status/{job_id}")
        return response.json()["job_status"][job_id]
 
    def get_job_status(self, job_id: str):

@@ -839,7 +829,7 @@ class Sutro:
        with yaspin(
            SPINNER,
            text=to_colored_text(f"Checking job status with ID: {job_id}"),
-            color=
+            color=BASE_OUTPUT_COLOR,
        ) as spinner:
            try:
                response_data = self._fetch_job_status(job_id)

@@ -866,7 +856,7 @@ class Sutro:
        output_column: str = "inference_result",
        disable_cache: bool = False,
        unpack_json: bool = True,
-    ):
+    ) -> pl.DataFrame | pd.DataFrame:
        """
        Get the results of a job by its ID.
 

@@ -896,51 +886,44 @@ class Sutro:
            with yaspin(
                SPINNER,
                text=to_colored_text(f"Loading results from cache: {file_path}"),
-                color=
+                color=BASE_OUTPUT_COLOR,
            ) as spinner:
                results_df = pl.read_parquet(file_path)
                spinner.write(
                    to_colored_text("✔ Results loaded from cache", state="success")
                )
        else:
-            endpoint = f"{self.base_url}/job-results"
            payload = {
                "job_id": job_id,
                "include_inputs": include_inputs,
                "include_cumulative_logprobs": include_cumulative_logprobs,
            }
-            headers = {
-                "Authorization": f"Key {self.api_key}",
-                "Content-Type": "application/json",
-            }
            with yaspin(
                SPINNER,
                text=to_colored_text(f"Gathering results from job: {job_id}"),
-                color=
+                color=BASE_OUTPUT_COLOR,
            ) as spinner:
-
-
-
-
+                try:
+                    response = self.do_request("POST", "job-results", json=payload)
+                    response_data = response.json()
+                    spinner.write(
+                        to_colored_text("✔ Job results retrieved", state="success")
+                    )
+                except requests.HTTPError as e:
                    spinner.write(
                        to_colored_text(
-                            f"Bad status code: {response.status_code}", state="fail"
+                            f"Bad status code: {e.response.status_code}", state="fail"
                        )
                    )
                    spinner.stop()
-                    print(to_colored_text(response.json(), state="fail"))
+                    print(to_colored_text(e.response.json(), state="fail"))
                    return None
 
-                spinner.write(
-                    to_colored_text("✔ Job results retrieved", state="success")
-                )
-
-                response_data = response.json()
                results_df = pl.DataFrame(response_data["results"])
 
                results_df = results_df.rename({"outputs": output_column})
 
-                if disable_cache
+                if not disable_cache:
                    os.makedirs(os.path.dirname(file_path), exist_ok=True)
                    results_df.write_parquet(file_path, compression="snappy")
                    spinner.write(

@@ -967,10 +950,11 @@ class Sutro:
                first_row = json.loads(
                    results_df.head(1)[output_column][0]
                )  # checks if the first row can be json decoded
+                results_df = results_df.map_columns(
+                    output_column, lambda s: s.str.json_decode()
+                )
                results_df = results_df.with_columns(
-                    pl.col(output_column)
-                    .str.json_decode()
-                    .alias("output_column_json_decoded")
+                    pl.col(output_column).alias("output_column_json_decoded")
                )
                json_decoded_fields = first_row.keys()
                for field in json_decoded_fields:

@@ -979,11 +963,20 @@ class Sutro:
                        .struct.field(field)
                        .alias(field)
                    )
-
+                if sorted(list(set(json_decoded_fields))) == [
+                    "content",
+                    "reasoning_content",
+                ]:  # if it's a reasoning model, we need to unpack the content field
+                    content_keys = results_df.head(1)["content"][0].keys()
+                    for key in content_keys:
+                        results_df = results_df.with_columns(
+                            pl.col("content").struct.field(key).alias(key)
+                        )
+                    results_df = results_df.drop("content")
                results_df = results_df.drop(
                    [output_column, "output_column_json_decoded"]
                )
-            except Exception
+            except Exception:
                # if the first row cannot be json decoded, do nothing
                pass
 
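The unpacking block above JSON-decodes the output column, promotes each top-level field to its own column, and, when the decoded fields are exactly `content` and `reasoning_content` (a reasoning model), flattens the nested `content` struct one level further. A self-contained sketch of the same polars pattern on toy data; the column and field names here are illustrative, not taken from the SDK:

import polars as pl

df = pl.DataFrame({
    "inference_result": [
        '{"reasoning_content": "step by step...", "content": {"label": "A"}}',
    ]
})

# Decode the JSON strings into a struct column, then promote each field.
decoded = df.with_columns(
    pl.col("inference_result").str.json_decode().alias("decoded")
)
decoded = decoded.with_columns(
    pl.col("decoded").struct.field("reasoning_content").alias("reasoning_content"),
    pl.col("decoded").struct.field("content").alias("content"),
)
# Reasoning-model case: flatten the nested `content` struct one level further.
decoded = decoded.with_columns(
    pl.col("content").struct.field("label").alias("label")
)
print(decoded.drop(["inference_result", "decoded", "content"]))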
@@ -1019,25 +1012,20 @@ class Sutro:
        Returns:
            dict: The status of the job.
        """
-        endpoint = f"{self.base_url}/job-cancel/{job_id}"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
        with yaspin(
            SPINNER,
            text=to_colored_text(f"Cancelling job: {job_id}"),
-            color=
+            color=BASE_OUTPUT_COLOR,
        ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", f"job-cancel/{job_id}")
                spinner.write(to_colored_text("✔ Job cancelled", state="success"))
-
+                return response.json()
+            except requests.HTTPError as e:
                spinner.write(to_colored_text("Failed to cancel job", state="fail"))
                spinner.stop()
-                print(to_colored_text(response.json(), state="fail"))
-                return
-                return response.json()
+                print(to_colored_text(e.response.json(), state="fail"))
+                return None
 
    def create_dataset(self):
        """

@@ -1048,31 +1036,27 @@ class Sutro:
        Returns:
            str: The ID of the new dataset.
        """
-        endpoint = f"{self.base_url}/create-dataset"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
        with yaspin(
-            SPINNER, text=to_colored_text("Creating dataset"), color=
+            SPINNER, text=to_colored_text("Creating dataset"), color=BASE_OUTPUT_COLOR
        ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", "create-dataset")
+                dataset_id = response.json()["dataset_id"]
                spinner.write(
                    to_colored_text(
-                        f"
+                        f"✔ Dataset created with ID: {dataset_id}", state="success"
                    )
                )
-
-
-
-
-
-
-                        f"✔ Dataset created with ID: {dataset_id}", state="success"
+                return dataset_id
+            except requests.HTTPError as e:
+                spinner.write(
+                    to_colored_text(
+                        f"Bad status code: {e.response.status_code}", state="fail"
+                    )
                )
-
-
+                spinner.stop()
+                print(to_colored_text(e.response.json(), state="fail"))
+                return None
 
    def upload_to_dataset(
        self,

@@ -1104,8 +1088,6 @@ class Sutro:
        if dataset_id is None:
            dataset_id = self.create_dataset()
 
-        endpoint = f"{self.base_url}/upload-to-dataset"
-
        if isinstance(file_paths, str):
            # check if the file path is a directory
            if os.path.isdir(file_paths):

@@ -1120,7 +1102,7 @@ class Sutro:
        with yaspin(
            SPINNER,
            text=to_colored_text(f"Uploading files to dataset: {dataset_id}"),
-            color=
+            color=BASE_OUTPUT_COLOR,
        ) as spinner:
            count = 0
            for file_path in file_paths:

@@ -1138,8 +1120,6 @@ class Sutro:
                    "dataset_id": dataset_id,
                }
 
-                headers = {"Authorization": f"Key {self.api_key}"}
-
                count += 1
                spinner.write(
                    to_colored_text(

@@ -1148,25 +1128,18 @@ class Sutro:
                )
 
                try:
-
-
+                    self.do_request(
+                        "POST",
+                        "/upload-to-dataset",
+                        data=payload,
+                        files=files,
+                        verify=verify_ssl,
                    )
-                    if response.status_code != 200:
-                        # Stop spinner before showing error to avoid terminal width error
-                        spinner.stop()
-                        print(
-                            to_colored_text(
-                                f"Error: HTTP {response.status_code}", state="fail"
-                            )
-                        )
-                        print(to_colored_text(response.json(), state="fail"))
-                        return
-
                except requests.exceptions.RequestException as e:
                    # Stop spinner before showing error to avoid terminal width error
                    spinner.stop()
                    print(to_colored_text(f"Upload failed: {str(e)}", state="fail"))
-                    return
+                    return None
 
                spinner.write(
                    to_colored_text(

@@ -1176,57 +1149,47 @@ class Sutro:
        return dataset_id
 
    def list_datasets(self):
-        endpoint = f"{self.base_url}/list-datasets"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
        with yaspin(
-            SPINNER, text=to_colored_text("Retrieving datasets"), color=
+            SPINNER, text=to_colored_text("Retrieving datasets"), color=BASE_OUTPUT_COLOR
        ) as spinner:
-
-
+            try:
+                response = self.do_request("POST", "list-datasets")
+                spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
+                return response.json()["datasets"]
+            except requests.HTTPError as e:
                spinner.fail(
                    to_colored_text(
-                        f"Bad status code: {response.status_code}", state="fail"
+                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
-                print(to_colored_text(f"Error: {response.json()}", state="fail"))
-                return
-            spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
-            return response.json()["datasets"]
+                print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                return None
 
    def list_dataset_files(self, dataset_id: str):
-        endpoint = f"{self.base_url}/list-dataset-files"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
        payload = {
            "dataset_id": dataset_id,
        }
        with yaspin(
            SPINNER,
            text=to_colored_text(f"Listing files in dataset: {dataset_id}"),
-            color=
+            color=BASE_OUTPUT_COLOR,
        ) as spinner:
-
-
-
-            if response.status_code != 200:
-                spinner.fail(
+            try:
+                response = self.do_request("POST", "list-dataset-files", json=payload)
+                spinner.write(
                    to_colored_text(
-                        f"
+                        f"✔ Files listed in dataset: {dataset_id}", state="success"
                    )
                )
-
-
-
-
+                return response.json()["files"]
+            except requests.HTTPError as e:
+                spinner.fail(
+                    to_colored_text(
+                        f"Bad status code: {e.response.status_code}", state="fail"
+                    )
                )
-
-
+                print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                return None
 
    def download_from_dataset(
        self,

@@ -1234,8 +1197,6 @@ class Sutro:
        files: Union[List[str], str] = None,
        output_path: str = None,
    ):
-        endpoint = f"{self.base_url}/download-from-dataset"
-
        if files is None:
            files = self.list_dataset_files(dataset_id)
        elif isinstance(files, str):

@@ -1256,36 +1217,36 @@ class Sutro:
        with yaspin(
            SPINNER,
            text=to_colored_text(f"Downloading files from dataset: {dataset_id}"),
-            color=
+            color=BASE_OUTPUT_COLOR,
        ) as spinner:
            count = 0
            for file in files:
-                headers = {
-                    "Authorization": f"Key {self.api_key}",
-                    "Content-Type": "application/json",
-                }
-                payload = {
-                    "dataset_id": dataset_id,
-                    "file_name": file,
-                }
                spinner.text = to_colored_text(
                    f"Downloading file {count + 1}/{len(files)} from dataset: {dataset_id}"
                )
-
-
-
-
+
+                try:
+                    payload = {
+                        "dataset_id": dataset_id,
+                        "file_name": file,
+                    }
+                    response = self.do_request(
+                        "POST", "download-from-dataset", json=payload
+                    )
+
+                    file_content = response.content
+                    with open(os.path.join(output_path, file), "wb") as f:
+                        f.write(file_content)
+
+                    count += 1
+                except requests.HTTPError as e:
                    spinner.fail(
                        to_colored_text(
-                            f"Bad status code: {response.status_code}", state="fail"
+                            f"Bad status code: {e.response.status_code}", state="fail"
                        )
                    )
-                    print(to_colored_text(f"Error: {response.json()}", state="fail"))
+                    print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
                    return
-                file_content = response.content
-                with open(os.path.join(output_path, file), "wb") as f:
-                    f.write(file_content)
-                count += 1
            spinner.write(
                to_colored_text(
                    f"✔ {count} files successfully downloaded from dataset: {dataset_id}",

@@ -1305,54 +1266,47 @@ class Sutro:
        Returns:
            dict: The status of the authentication.
        """
-        endpoint = f"{self.base_url}/try-authentication"
-        headers = {
-            "Authorization": f"Key {api_key}",
-            "Content-Type": "application/json",
-        }
        with yaspin(
-            SPINNER, text=to_colored_text("Checking API key"), color=
+            SPINNER, text=to_colored_text("Checking API key"), color=BASE_OUTPUT_COLOR
        ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", "try-authentication", api_key)
+
                spinner.write(to_colored_text("✔"))
-
+                return response.json()
+            except requests.HTTPError as e:
                spinner.write(
                    to_colored_text(
-                        f"API key failed to authenticate: {response.status_code}",
+                        f"API key failed to authenticate: {e.response.status_code}",
                        state="fail",
                    )
                )
-                return
-                return response.json()
+                return None
 
    def get_quotas(self):
-        endpoint = f"{self.base_url}/get-quotas"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
        with yaspin(
-            SPINNER, text=to_colored_text("Fetching quotas"), color=
+            SPINNER, text=to_colored_text("Fetching quotas"), color=BASE_OUTPUT_COLOR
        ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", "get-quotas")
+                return response.json()["quotas"]
+            except requests.HTTPError as e:
                spinner.fail(
                    to_colored_text(
-                        f"Bad status code: {response.status_code}", state="fail"
+                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
-                print(to_colored_text(f"Error: {response.json()}", state="fail"))
-                return
-                return response.json()["quotas"]
+                print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                return None
 
    def await_job_completion(
        self,
        job_id: str,
        timeout: Optional[int] = 7200,
        obtain_results: bool = True,
+        output_column: str = "inference_result",
        is_cost_estimate: bool = False,
-    ) ->
+    ) -> pl.DataFrame | None:
        """
        Waits for job completion to occur and then returns the results upon
        a successful completion.

@@ -1364,14 +1318,14 @@ class Sutro:
            timeout (Optional[int]): The max time in seconds the function should wait for job results for. Default is 7200 (2 hours).
 
        Returns:
-
+            pl.DataFrame: The results of the job in a polars DataFrame.
        """
        POLL_INTERVAL = 5
 
-        results = None
+        results: pl.DataFrame | None = None
        start_time = time.time()
        with yaspin(
-            SPINNER, text=to_colored_text("Awaiting job completion"), color=
+            SPINNER, text=to_colored_text("Awaiting job completion"), color=BASE_OUTPUT_COLOR
        ) as spinner:
            if not is_cost_estimate:
                clickable_link = make_clickable_link(

@@ -1405,7 +1359,9 @@ class Sutro:
                            "Job completed! Retrieving results...", "success"
                        )
                    )
-                    results = self.get_job_results(
+                    results = self.get_job_results(
+                        job_id, output_column=output_column
+                    )
                    break
                if status == JobStatus.FAILED:
                    spinner.write(to_colored_text("Job has failed", "fail"))

@@ -1433,7 +1389,7 @@ class Sutro:
        with yaspin(
            SPINNER,
            text=to_colored_text("Retrieving job results cache contents"),
-            color=
+            color=BASE_OUTPUT_COLOR,
        ) as spinner:
            if not os.path.exists(os.path.expanduser("~/.sutro/job-results")):
                spinner.write(to_colored_text("No job results cache found", "success"))

@@ -1465,7 +1421,7 @@ class Sutro:
 
        start_time = time.time()
        with yaspin(
-            SPINNER, text=to_colored_text("Awaiting job completion"), color=
+            SPINNER, text=to_colored_text("Awaiting job completion"), color=BASE_OUTPUT_COLOR
        ) as spinner:
            while (time.time() - start_time) < timeout:
                try: