sutro 0.1.36__py3-none-any.whl → 0.1.38__py3-none-any.whl
- sutro/cli.py +1 -1
- sutro/common.py +220 -0
- sutro/interfaces.py +90 -0
- sutro/sdk.py +354 -550
- sutro/templates/classification.py +117 -0
- sutro/templates/embed.py +53 -0
- sutro/validation.py +60 -0
- {sutro-0.1.36.dist-info → sutro-0.1.38.dist-info}/METADATA +14 -15
- sutro-0.1.38.dist-info/RECORD +12 -0
- sutro-0.1.38.dist-info/WHEEL +4 -0
- {sutro-0.1.36.dist-info → sutro-0.1.38.dist-info}/entry_points.txt +1 -0
- sutro-0.1.36.dist-info/RECORD +0 -8
- sutro-0.1.36.dist-info/WHEEL +0 -4
- sutro-0.1.36.dist-info/licenses/LICENSE +0 -201
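
The headline change in sutro/sdk.py is that per-endpoint request boilerplate is replaced by a single authenticated `do_request` helper, with failures surfaced as `requests.HTTPError` instead of hand-checked status codes. A minimal sketch of the resulting call pattern, pieced together from the diff below (the job ID is a made-up placeholder):

```python
import requests
from sutro.sdk import Sutro

# api_key falls back to the stored key via check_for_api_key() when omitted
so = Sutro(api_key="YOUR_API_KEY")

try:
    # Every endpoint method now funnels through this helper; it builds the
    # Authorization header, dispatches on the HTTP verb, and calls
    # raise_for_status() before returning the response.
    response = so.do_request("POST", "job-results", json={"job_id": "job-123"})
    outputs = response.json()["results"]["outputs"]
except requests.HTTPError as e:
    print(e.response.status_code, e.response.json())
```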
sutro/sdk.py
CHANGED
@@ -1,48 +1,32 @@
-from enum import Enum
 import requests
 import pandas as pd
 import polars as pl
 import json
-from typing import Union, List, Optional,
+from typing import Union, List, Optional, Dict, Any, Type
 import os
 import sys
 from yaspin import yaspin
 from yaspin.spinners import Spinners
-from colorama import init
-from tqdm import tqdm
+from colorama import init
 import time
 from pydantic import BaseModel
 import pyarrow.parquet as pq
 import shutil
+from sutro.common import (
+    ModelOptions,
+    handle_data_helper,
+    normalize_output_schema,
+    to_colored_text,
+    fancy_tqdm,
+)
+from sutro.interfaces import JobStatus
+from sutro.templates.classification import ClassificationTemplates
+from sutro.templates.embed import EmbeddingTemplates
+from sutro.validation import check_version, check_for_api_key

 JOB_NAME_CHAR_LIMIT = 45
 JOB_DESCRIPTION_CHAR_LIMIT = 512

-class JobStatus(str, Enum):
-    """Job statuses that will be returned by the API & SDK"""
-
-    UNKNOWN = "UNKNOWN"
-    QUEUED = "QUEUED"  # Job is waiting to start
-    STARTING = "STARTING"  # Job is in the process of starting up
-    RUNNING = "RUNNING"  # Job is actively running
-    SUCCEEDED = "SUCCEEDED"  # Job completed successfully
-    CANCELLING = "CANCELLING"  # Job is in the process of being canceled
-    CANCELLED = "CANCELLED"  # Job was canceled by the user
-    FAILED = "FAILED"  # Job failed
-
-    @classmethod
-    def terminal_statuses(cls) -> list["JobStatus"]:
-        return [
-            cls.SUCCEEDED,
-            cls.FAILED,
-            cls.CANCELLING,
-            cls.CANCELLED,
-        ]
-
-    def is_terminal(self) -> bool:
-        return self in self.terminal_statuses()
-
-
 # Initialize colorama (required for Windows)
 init()

@@ -56,57 +40,6 @@ def is_jupyter() -> bool:
 YASPIN_COLOR = None if is_jupyter() else "blue"
 SPINNER = Spinners.dots14

-# Models available for inference. Keep in sync with the backend configuration
-# so users get helpful autocompletion when selecting a model.
-ModelOptions = Literal[
-    "llama-3.2-3b",
-    "llama-3.1-8b",
-    "llama-3.3-70b",
-    "llama-3.3-70b",
-    "qwen-3-4b",
-    "qwen-3-14b",
-    "qwen-3-32b",
-    "qwen-3-30b-a3b",
-    "qwen-3-235b-a22b",
-    "qwen-3-4b-thinking",
-    "qwen-3-14b-thinking",
-    "qwen-3-32b-thinking",
-    "qwen-3-235b-a22b-thinking",
-    "qwen-3-30b-a3b-thinking",
-    "gemma-3-4b-it",
-    "gemma-3-12b-it",
-    "gemma-3-27b-it",
-    "gpt-oss-20b",
-    "gpt-oss-120b",
-    "qwen-3-embedding-0.6b",
-    "qwen-3-embedding-6b",
-    "qwen-3-embedding-8b",
-]
-
-
-def to_colored_text(
-    text: str, state: Optional[Literal["success", "fail"]] = None
-) -> str:
-    """
-    Apply color to text based on state.
-
-    Args:
-        text (str): The text to color
-        state (Optional[Literal['success', 'fail']]): The state that determines the color.
-        Options: 'success', 'fail', or None (default blue)
-
-    Returns:
-        str: Text with appropriate color applied
-    """
-    match state:
-        case "success":
-            return f"{Fore.GREEN}{text}{Style.RESET_ALL}"
-        case "fail":
-            return f"{Fore.RED}{text}{Style.RESET_ALL}"
-        case _:
-            # Default to blue for normal/processing states
-            return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
-

 # Isn't fully support in all terminals unfortunately. We should switch to Rich
 # at some point, but even Rich links aren't clickable on MacOS Terminal
@@ -120,36 +53,11 @@ def make_clickable_link(url, text=None):
     return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"


-class Sutro:
+class Sutro(EmbeddingTemplates, ClassificationTemplates):
     def __init__(self, api_key: str = None, base_url: str = "https://api.sutro.sh/"):
-        self.api_key = api_key or
+        self.api_key = api_key or check_for_api_key()
         self.base_url = base_url
-
-    def check_for_api_key(self):
-        """
-        Check for an API key in the user's home directory.
-
-        This method looks for a configuration file named 'config.json' in the
-        '.sutro' directory within the user's home directory.
-        If the file exists, it attempts to read the API key from it.
-
-        Returns:
-            str or None: The API key if found in the configuration file, or None if not found.
-
-        Note:
-            The expected structure of the config.json file is:
-            {
-                "api_key": "your_api_key_here"
-            }
-        """
-        CONFIG_DIR = os.path.expanduser("~/.sutro")
-        CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
-        if os.path.exists(CONFIG_FILE):
-            with open(CONFIG_FILE, "r") as f:
-                config = json.load(f)
-                return config.get("api_key")
-        else:
-            return None
+        check_version("sutro")

     def set_api_key(self, api_key: str):
         """
@@ -166,79 +74,6 @@ class Sutro:
         """
         self.api_key = api_key

-    def do_dataframe_column_concatenation(self, data: Union[pd.DataFrame, pl.DataFrame], column: Union[str, List[str]]):
-        """
-        If the user has supplied a dataframe and a list of columns, this will intelligenly concatenate the columns into a single column, accepting separator strings.
-        """
-        try:
-            if isinstance(data, pd.DataFrame):
-                series_parts = []
-                for p in column:
-                    if p in data.columns:
-                        s = data[p].astype("string").fillna("")
-                    else:
-                        # Treat as a literal separator
-                        s = pd.Series([p] * len(data), index=data.index, dtype="string")
-                    series_parts.append(s)
-
-                out = series_parts[0]
-                for s in series_parts[1:]:
-                    out = out.str.cat(s, na_rep="")
-
-                return out.tolist()
-            elif isinstance(data, pl.DataFrame):
-                exprs = []
-                for p in column:
-                    if p in data.columns:
-                        exprs.append(pl.col(p).cast(pl.Utf8).fill_null(""))
-                    else:
-                        exprs.append(pl.lit(p))
-
-                result = data.select(pl.concat_str(exprs, separator="", ignore_nulls=False).alias("concat"))
-                return result["concat"].to_list()
-        except Exception as e:
-            raise ValueError(f"Error handling column concatentation: {e}")
-
-    def handle_data_helper(
-        self, data: Union[List, pd.DataFrame, pl.DataFrame, str], column: str = None
-    ):
-        if isinstance(data, list):
-            input_data = data
-        elif isinstance(data, (pd.DataFrame, pl.DataFrame)):
-            if column is None:
-                raise ValueError("Column name must be specified for DataFrame input")
-            if isinstance(column, list):
-                input_data = self.do_dataframe_column_concatenation(data, column)
-            elif isinstance(column, str):
-                input_data = data[column].to_list()
-        elif isinstance(data, str):
-            if data.startswith("dataset-"):
-                input_data = data + ":" + column
-            else:
-                file_ext = os.path.splitext(data)[1].lower()
-                if file_ext == ".csv":
-                    df = pl.read_csv(data)
-                elif file_ext == ".parquet":
-                    df = pl.read_parquet(data)
-                elif file_ext in [".txt", ""]:
-                    with open(data, "r") as file:
-                        input_data = [line.strip() for line in file]
-                else:
-                    raise ValueError(f"Unsupported file type: {file_ext}")
-
-                if file_ext in [".csv", ".parquet"]:
-                    if column is None:
-                        raise ValueError(
-                            "Column name must be specified for CSV/Parquet input"
-                        )
-                    input_data = df[column].to_list()
-        else:
-            raise ValueError(
-                "Unsupported data type. Please provide a list, DataFrame, or file path."
-            )
-
-        return input_data
-
     def set_base_url(self, base_url: str):
         """
         Set the base URL for the Sutro API.
@@ -251,6 +86,43 @@ class Sutro:
         """
         self.base_url = base_url

+    def do_request(
+        self,
+        method: str,
+        endpoint: str,
+        api_key_override: Optional[str] = None,
+        **kwargs: Any,
+    ):
+        """
+        Helper to make authenticated requests.
+        """
+        key = self.api_key if not api_key_override else api_key_override
+        headers = {"Authorization": f"Key {key}"}
+
+        # Merge with any headers passed in kwargs
+        if "headers" in kwargs:
+            headers.update(kwargs.pop("headers"))
+
+        url = f"{self.base_url}/{endpoint.lstrip('/')}"
+
+        # Explicit method dispatch
+        method = method.upper()
+        if method == "GET":
+            response = requests.get(url, headers=headers, **kwargs)
+        elif method == "POST":
+            response = requests.post(url, headers=headers, **kwargs)
+        elif method == "PUT":
+            response = requests.put(url, headers=headers, **kwargs)
+        elif method == "DELETE":
+            response = requests.delete(url, headers=headers, **kwargs)
+        elif method == "PATCH":
+            response = requests.patch(url, headers=headers, **kwargs)
+        else:
+            raise ValueError(f"Unsupported HTTP method: {method}")
+
+        response.raise_for_status()
+        return response
+
     def _run_one_batch_inference(
         self,
         data: Union[List, pd.DataFrame, pl.DataFrame, str],
@@ -270,16 +142,15 @@ class Sutro:
     ):
         # Validate name and description lengths
         if name is not None and len(name) > JOB_NAME_CHAR_LIMIT:
-            raise ValueError(
+            raise ValueError(
+                f"Job name cannot exceed {JOB_NAME_CHAR_LIMIT} characters."
+            )
         if description is not None and len(description) > JOB_DESCRIPTION_CHAR_LIMIT:
-            raise ValueError(
+            raise ValueError(
+                f"Job description cannot exceed {JOB_DESCRIPTION_CHAR_LIMIT} characters."
+            )

-        input_data =
-        endpoint = f"{self.base_url}/batch-inference"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
+        input_data = handle_data_helper(data, column)
         payload = {
             "model": model,
             "inputs": input_data,
@@ -305,16 +176,19 @@ class Sutro:
         spinner_text = to_colored_text(t)
         try:
             with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
-
-
-
-
+                try:
+                    response = self.do_request("POST", "batch-inference", json=payload)
+                    response_data = response.json()
+                except requests.HTTPError as e:
+                    response = e.response
+                    response_data = response.json()
+
                 if response.status_code != 200:
                     spinner.write(
                         to_colored_text(f"Error: {response.status_code}", state="fail")
                     )
                     spinner.stop()
-                    print(to_colored_text(
+                    print(to_colored_text(response_data, state="fail"))
                     return None
                 else:
                     job_id = response_data["results"]
@@ -340,10 +214,11 @@ class Sutro:
                     name_text = f" and name {name}" if name is not None else ""
                     spinner.write(
                         to_colored_text(
-                            f"🛠 Priority {job_priority} Job created with ID: {job_id}{name_text}
+                            f"🛠 Priority {job_priority} Job created with ID: {job_id}{name_text}",
                             state="success",
                         )
                     )
+                    spinner.write(to_colored_text(f"Model: {model}"))
                     if not stay_attached:
                         clickable_link = make_clickable_link(
                             f"https://app.sutro.sh/jobs/{job_id}"
@@ -380,13 +255,13 @@ class Sutro:
                             )
                         )
                         return None
-
+
                     pbar = None

                     try:
-                        with
-
-
+                        with self.do_request(
+                            "GET",
+                            f"/stream-job-progress/{job_id}",
                             stream=True,
                         ) as streaming_response:
                             streaming_response.raise_for_status()
@@ -415,7 +290,7 @@ class Sutro:
                                 if pbar is None:
                                     spinner.stop()
                                     postfix = "Input tokens processed: 0"
-                                    pbar =
+                                    pbar = fancy_tqdm(
                                         total=len(input_data),
                                         desc="Progress",
                                         style=1,
@@ -456,28 +331,27 @@ class Sutro:
                             )
                             spinner.start()

-                    payload = {
-                        "job_id": job_id,
-                    }
-
                     # TODO: we implment retries in cases where the job hasn't written results yet
                     # it would be better if we could receive a fully succeeded status from the job
                     # and not have such a race condition
                     max_retries = 20  # winds up being 100 seconds cumulative delay
                     retry_delay = 5  # initial delay in seconds
-
+                    job_results_response = None
                     for _ in range(max_retries):
-
-
-
-
-
-
-
-
+                        try:
+                            job_results_response = self.do_request(
+                                "POST",
+                                "job-results",
+                                json={
+                                    "job_id": job_id,
+                                },
+                            )
                             break
+                        except requests.HTTPError:
+                            time.sleep(retry_delay)
+                            continue

-                    if job_results_response.status_code != 200:
+                    if not job_results_response or job_results_response.status_code != 200:
                         spinner.write(
                             to_colored_text(
                                 "Job succeeded, but results are not yet available. Use `so.get_job_results('{job_id}')` to obtain results.",
@@ -489,122 +363,183 @@ class Sutro:

                     results = job_results_response.json()["results"]["outputs"]

-                    spinner.write(
-                        to_colored_text(
-                            f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
-                            state="success",
-                        )
-                    )
-                    spinner.stop()
-
                     if isinstance(data, (pd.DataFrame, pl.DataFrame)):
                         if isinstance(data, pd.DataFrame):
                             data[output_column] = results
                         elif isinstance(data, pl.DataFrame):
                             data = data.with_columns(pl.Series(output_column, results))
-
+                        print(data)
+                        spinner.write(
+                            to_colored_text(
+                                f"✔ Displaying result preview. You can join the results on the original dataframe with `so.get_job_results('{job_id}', with_original_df=<original_df>)`",
+                                state="success",
+                            )
+                        )
+                    else:
+                        print(results)
+                        spinner.write(
+                            to_colored_text(
+                                f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
+                                state="success",
+                            )
+                        )
+                    spinner.stop()

-                    return
+                    return job_id
                 return None
         return None

     def infer(
         self,
         data: Union[List, pd.DataFrame, pl.DataFrame, str],
-        model:
-        name:
-        description:
+        model: ModelOptions = "gemma-3-12b-it",
+        name: Optional[str] = None,
+        description: Optional[str] = None,
         column: Union[str, List[str]] = None,
         output_column: str = "inference_result",
         job_priority: int = 0,
-        output_schema: Union[Dict[str, Any], BaseModel] = None,
+        output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
         sampling_params: dict = None,
         system_prompt: str = None,
         dry_run: bool = False,
         stay_attached: Optional[bool] = None,
         random_seed_per_input: bool = False,
-        truncate_rows: bool =
+        truncate_rows: bool = True,
     ):
         """
         Run inference on the provided data.

         This method allows you to run inference on the provided data using the Sutro API.
-        It supports various data types such as lists,
+        It supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.

         Args:
             data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
-            model (
-            name (
-            description (
+            model (ModelOptions, optional): The model to use for inference. Defaults to "gemma-3-12b-it".
+            name (str, optional): A job name for experiment/metadata tracking purposes. Defaults to None.
+            description (str, optional): A job description for experiment/metadata tracking purposes. Defaults to None.
             column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, it will concatenate the columns of the list into a single column, accepting separator strings.
             output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
             job_priority (int, optional): The priority of the job. Defaults to 0.
             output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
-                Can be either a dictionary representing a JSON schema or a
+                Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
             sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
             system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
             dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
             stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
             random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
-            truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to
+            truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to True.

         Returns:
-
+            str: The ID of the inference job.

         """
-
-
-        stay_attached =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Default stay_attached to True for prototyping jobs (priority 0)
+        if stay_attached is None:
+            stay_attached = job_priority == 0
+
+        json_schema = None
+        if output_schema:
+            # Convert BaseModel to dict if needed
+            json_schema = normalize_output_schema(output_schema)
+
+        return self._run_one_batch_inference(
+            data,
+            model,
+            column,
+            output_column,
+            job_priority,
+            json_schema,
+            sampling_params,
+            system_prompt,
+            dry_run,
+            stay_attached,
+            random_seed_per_input,
+            truncate_rows,
+            name,
+            description,
+        )
+
+    def infer_per_model(
+        self,
+        data: Union[List, pd.DataFrame, pl.DataFrame, str],
+        models: List[ModelOptions],
+        names: List[str] = None,
+        descriptions: List[str] = None,
+        column: Union[str, List[str]] = None,
+        output_column: str = "inference_result",
+        job_priority: int = 0,
+        output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
+        sampling_params: dict = None,
+        system_prompt: str = None,
+        dry_run: bool = False,
+        random_seed_per_input: bool = False,
+        truncate_rows: bool = True,
+    ):
+        """
+        Run inference on the provided data, across multiple models. This method is often useful to sampling outputs from multiple models across the same dataset and compare the job_ids.
+
+        For input data, it supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.
+
+        Args:
+            data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
+            models (Union[ModelOptions, List[ModelOptions]], optional): The models to use for inference. Fans out each model to its own seperate job, over the same dataset.
+            names (Union[str, List[str]], optional): A job name for experiment/metadata tracking purposes. If using a list of models, you must pass a list of names with length equal to the number of models, or None. Defaults to None.
+            descriptions (Union[str, List[str]], optional): A job description for experiment/metadata tracking purposes. If using a list of models, you must pass a list of descriptions with length equal to the number of models, or None. Defaults to None.
+            column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, it will concatenate the columns of the list into a single column, accepting separator strings.
+            output_column (str, optional): The column name to store the inference job_ids in if the input is a DataFrame. Defaults to "inference_result".
+            job_priority (int, optional): The priority of the job. Defaults to 0.
+            output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
+                Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
+            sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
+            system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
+            dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
+            stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
+            random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
+            truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to True.
+
+        Returns:
+            str: The ID of the inference job.
+
+        """
+        if isinstance(names, list):
+            if len(names) != len(models):
+                raise ValueError(
+                    "names parameter must be the same length as the models parameter."
+                )
+        elif names is None:
+            names = [None] * len(models)
         else:
-
-
-
-
-
-
-            if hasattr(
-                output_schema, "model_json_schema"
-            ):  # Check for pydantic Model interface
-                json_schema = output_schema.model_json_schema()
-            elif isinstance(output_schema, dict):
-                json_schema = output_schema
-            else:
+            raise ValueError(
+                "names parameter must be a list or None if using a list of models"
+            )
+
+        if isinstance(descriptions, list):
+            if len(descriptions) != len(models):
                 raise ValueError(
-                    "
+                    "descriptions parameter must be the same length as the models"
+                    " parameter."
                 )
+        elif descriptions is None:
+            descriptions = [None] * len(models)
         else:
-
-
-
-
-
+            raise ValueError(
+                "descriptions parameter must be a list or None if using a list of "
+                "models"
+            )
+
+        json_schema = None
+        if output_schema:
+            # Convert BaseModel to dict if needed
+            json_schema = normalize_output_schema(output_schema)
+
+        def start_job(
+            model_singleton: ModelOptions,
+            name_singleton: str | None,
+            description_singleton: str | None,
+        ):
+            return self._run_one_batch_inference(
                 data,
-
+                model_singleton,
                 column,
                 output_column,
                 job_priority,
@@ -612,20 +547,21 @@ class Sutro:
                 sampling_params,
                 system_prompt,
                 dry_run,
-
+                False,
                 random_seed_per_input,
                 truncate_rows,
-
-
+                name_singleton,
+                description_singleton,
             )
-            results.append(res)

-
-
-
-
+        job_ids = [
+            start_job(model, name, description)
+            for model, name, description in zip(
+                models, names, descriptions, strict=True
+            )
+        ]

-        return
+        return job_ids

     def attach(self, job_id):
         """
@@ -636,16 +572,8 @@ class Sutro:
         """

         s = requests.Session()
-        payload = {
-            "job_id": job_id,
-        }
         pbar = None

-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
         with yaspin(
             SPINNER,
             text=to_colored_text("Looking for job..."),
@@ -683,9 +611,9 @@ class Sutro:
             success = False

             try:
-                with
-
-
+                with self.do_request(
+                    "GET",
+                    f"/stream-job-progress/{job_id}",
                     stream=True,
                 ) as streaming_response:
                     streaming_response.raise_for_status()
@@ -715,7 +643,7 @@ class Sutro:
                         if pbar is None:
                             spinner.stop()
                             postfix = "Input tokens processed: 0"
-                            pbar =
+                            pbar = fancy_tqdm(
                                 total=total_rows,
                                 desc="Progress",
                                 style=1,
@@ -748,65 +676,6 @@ class Sutro:
             if spinner:
                 spinner.stop()

-    def fancy_tqdm(
-        self,
-        total: int,
-        desc: str = "Progress",
-        color: str = "blue",
-        style=1,
-        postfix: str = None,
-    ):
-        """
-        Creates a customized tqdm progress bar with different styling options.
-
-        Args:
-            total (int): Total iterations
-            desc (str): Description for the progress bar
-            color (str): Color of the progress bar (green, blue, red, yellow, magenta)
-            style (int): Style preset (1-4)
-            postfix (str): Postfix for the progress bar
-        """
-
-        # Style presets
-        style_presets = {
-            1: {
-                "bar_format": "{l_bar}{bar:30}| {n_fmt}/{total_fmt} | {percentage:3.0f}% {postfix}",
-                "ascii": "░▒█",
-            },
-            2: {
-                "bar_format": "╢{l_bar}{bar:30}╟ {percentage:3.0f}%",
-                "ascii": "▁▂▃▄▅▆▇█",
-            },
-            3: {
-                "bar_format": "{desc}: |{bar}| {percentage:3.0f}% [{elapsed}<{remaining}]",
-                "ascii": "◯◔◑◕●",
-            },
-            4: {
-                "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
-                "ascii": "⬜⬛",
-            },
-            5: {
-                "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
-                "ascii": "▏▎▍▌▋▊▉█",
-            },
-        }
-
-        # Get style configuration
-        style_config = style_presets.get(style, style_presets[1])
-
-        return tqdm(
-            total=total,
-            desc=desc,
-            colour=color,
-            bar_format=style_config["bar_format"],
-            ascii=style_config["ascii"],
-            ncols=80,
-            dynamic_ncols=True,
-            smoothing=0.3,
-            leave=True,
-            postfix=postfix,
-        )
-
     def list_jobs(self):
         """
         List all jobs.
@@ -814,56 +683,36 @@ class Sutro:
         This method retrieves a list of all jobs associated with the API key.

         Returns:
-            list: A list of job details.
+            list: A list of job details, or None if the request fails.
         """
-        endpoint = f"{self.base_url}/list-jobs"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
         with yaspin(
             SPINNER, text=to_colored_text("Fetching jobs"), color=YASPIN_COLOR
         ) as spinner:
-
-
+            try:
+                return self._list_all_jobs_for_user()
+            except requests.HTTPError as e:
                 spinner.write(
                     to_colored_text(
-                        f"Bad status code: {response.status_code}", state="fail"
+                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
                spinner.stop()
-                print(to_colored_text(response.json(), state="fail"))
-                return
-        return response.json()["jobs"]
+                print(to_colored_text(e.response.json(), state="fail"))
+                return None

-    def
-        """
-        Helper function to list jobs.
-        """
-        endpoint = f"{self.base_url}/list-jobs˚"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-        response = requests.get(endpoint, headers=headers)
-        if response.status_code != 200:
-            return None
+    def _list_all_jobs_for_user(self):
+        response = self.do_request("GET", "list-jobs")
         return response.json()["jobs"]

     def _fetch_job(self, job_id):
         """
         Helper function to fetch a single job.
         """
-
-
-
-
-        }
-        response = requests.get(endpoint, headers=headers)
-        if response.status_code != 200:
+        try:
+            response = self.do_request("GET", f"jobs/{job_id}")
+            return response.json().get("job")
+        except requests.HTTPError:
             return None
-        return response.json().get("job")

     def _get_job_cost_estimate(self, job_id: str):
         """
@@ -897,15 +746,7 @@ class Sutro:
         Raises:
             requests.HTTPError: If the API returns a non-200 status code.
         """
-
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
-        response = requests.get(endpoint, headers=headers)
-        response.raise_for_status()
-
+        response = self.do_request("GET", f"job-status/{job_id}")
         return response.json()["job_status"][job_id]

     def get_job_status(self, job_id: str):
@@ -950,7 +791,7 @@ class Sutro:
         output_column: str = "inference_result",
         disable_cache: bool = False,
         unpack_json: bool = True,
-    ):
+    ) -> pl.DataFrame | pd.DataFrame:
         """
         Get the results of a job by its ID.

@@ -987,44 +828,37 @@ class Sutro:
                 to_colored_text("✔ Results loaded from cache", state="success")
             )
         else:
-            endpoint = f"{self.base_url}/job-results"
             payload = {
                 "job_id": job_id,
                 "include_inputs": include_inputs,
                 "include_cumulative_logprobs": include_cumulative_logprobs,
             }
-            headers = {
-                "Authorization": f"Key {self.api_key}",
-                "Content-Type": "application/json",
-            }
             with yaspin(
                 SPINNER,
                 text=to_colored_text(f"Gathering results from job: {job_id}"),
                 color=YASPIN_COLOR,
             ) as spinner:
-
-
-
-
+                try:
+                    response = self.do_request("POST", "job-results", json=payload)
+                    response_data = response.json()
+                    spinner.write(
+                        to_colored_text("✔ Job results retrieved", state="success")
+                    )
+                except requests.HTTPError as e:
                     spinner.write(
                         to_colored_text(
-                            f"Bad status code: {response.status_code}", state="fail"
+                            f"Bad status code: {e.response.status_code}", state="fail"
                        )
                    )
                    spinner.stop()
-                    print(to_colored_text(response.json(), state="fail"))
+                    print(to_colored_text(e.response.json(), state="fail"))
                    return None

-                spinner.write(
-                    to_colored_text("✔ Job results retrieved", state="success")
-                )
-
-                response_data = response.json()
                 results_df = pl.DataFrame(response_data["results"])

                 results_df = results_df.rename({"outputs": output_column})

-                if disable_cache
+                if not disable_cache:
                     os.makedirs(os.path.dirname(file_path), exist_ok=True)
                     results_df.write_parquet(file_path, compression="snappy")
                     spinner.write(
@@ -1051,10 +885,11 @@ class Sutro:
                         first_row = json.loads(
                             results_df.head(1)[output_column][0]
                         )  # checks if the first row can be json decoded
+                        results_df = results_df.map_columns(
+                            output_column, lambda s: s.str.json_decode()
+                        )
                         results_df = results_df.with_columns(
-                            pl.col(output_column)
-                            .str.json_decode()
-                            .alias("output_column_json_decoded")
+                            pl.col(output_column).alias("output_column_json_decoded")
                         )
                         json_decoded_fields = first_row.keys()
                         for field in json_decoded_fields:
@@ -1063,11 +898,20 @@ class Sutro:
                                 .struct.field(field)
                                 .alias(field)
                             )
-
+                        if sorted(list(set(json_decoded_fields))) == [
+                            "content",
+                            "reasoning_content",
+                        ]:  # if it's a reasoning model, we need to unpack the content field
+                            content_keys = results_df.head(1)["content"][0].keys()
+                            for key in content_keys:
+                                results_df = results_df.with_columns(
+                                    pl.col("content").struct.field(key).alias(key)
+                                )
+                            results_df = results_df.drop("content")
                         results_df = results_df.drop(
                             [output_column, "output_column_json_decoded"]
                         )
-                    except Exception
+                    except Exception:
                         # if the first row cannot be json decoded, do nothing
                         pass

@@ -1103,25 +947,20 @@ class Sutro:
         Returns:
             dict: The status of the job.
         """
-        endpoint = f"{self.base_url}/job-cancel/{job_id}"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
         with yaspin(
             SPINNER,
             text=to_colored_text(f"Cancelling job: {job_id}"),
             color=YASPIN_COLOR,
         ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", f"job-cancel/{job_id}")
                 spinner.write(to_colored_text("✔ Job cancelled", state="success"))
-
+                return response.json()
+            except requests.HTTPError as e:
                 spinner.write(to_colored_text("Failed to cancel job", state="fail"))
                 spinner.stop()
-                print(to_colored_text(response.json(), state="fail"))
-                return
-        return response.json()
+                print(to_colored_text(e.response.json(), state="fail"))
+                return None

     def create_dataset(self):
         """
@@ -1132,31 +971,27 @@ class Sutro:
         Returns:
             str: The ID of the new dataset.
         """
-        endpoint = f"{self.base_url}/create-dataset"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
         with yaspin(
             SPINNER, text=to_colored_text("Creating dataset"), color=YASPIN_COLOR
         ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", "create-dataset")
+                dataset_id = response.json()["dataset_id"]
                 spinner.write(
                     to_colored_text(
-                        f"
+                        f"✔ Dataset created with ID: {dataset_id}", state="success"
                     )
                 )
-
-
-
-
-
-
-                        f"✔ Dataset created with ID: {dataset_id}", state="success"
+                return dataset_id
+            except requests.HTTPError as e:
+                spinner.write(
+                    to_colored_text(
+                        f"Bad status code: {e.response.status_code}", state="fail"
+                    )
                 )
-
-
+                spinner.stop()
+                print(to_colored_text(e.response.json(), state="fail"))
+                return None

     def upload_to_dataset(
         self,
@@ -1188,8 +1023,6 @@ class Sutro:
         if dataset_id is None:
             dataset_id = self.create_dataset()

-        endpoint = f"{self.base_url}/upload-to-dataset"
-
         if isinstance(file_paths, str):
             # check if the file path is a directory
             if os.path.isdir(file_paths):
@@ -1222,8 +1055,6 @@ class Sutro:
                     "dataset_id": dataset_id,
                 }

-                headers = {"Authorization": f"Key {self.api_key}"}
-
                 count += 1
                 spinner.write(
                     to_colored_text(
@@ -1232,25 +1063,18 @@ class Sutro:
                 )

                 try:
-
-
+                    self.do_request(
+                        "POST",
+                        "/upload-to-dataset",
+                        data=payload,
+                        files=files,
+                        verify=verify_ssl,
                     )
-                    if response.status_code != 200:
-                        # Stop spinner before showing error to avoid terminal width error
-                        spinner.stop()
-                        print(
-                            to_colored_text(
-                                f"Error: HTTP {response.status_code}", state="fail"
-                            )
-                        )
-                        print(to_colored_text(response.json(), state="fail"))
-                        return
-
                 except requests.exceptions.RequestException as e:
                     # Stop spinner before showing error to avoid terminal width error
                     spinner.stop()
                     print(to_colored_text(f"Upload failed: {str(e)}", state="fail"))
-                    return
+                    return None

                 spinner.write(
                     to_colored_text(
@@ -1260,32 +1084,23 @@ class Sutro:
         return dataset_id

     def list_datasets(self):
-        endpoint = f"{self.base_url}/list-datasets"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
         with yaspin(
             SPINNER, text=to_colored_text("Retrieving datasets"), color=YASPIN_COLOR
         ) as spinner:
-
-
+            try:
+                response = self.do_request("POST", "list-datasets")
+                spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
+                return response.json()["datasets"]
+            except requests.HTTPError as e:
                 spinner.fail(
                     to_colored_text(
-                        f"Bad status code: {response.status_code}", state="fail"
+                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
-                print(to_colored_text(f"Error: {response.json()}", state="fail"))
-                return
-            spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
-            return response.json()["datasets"]
+                print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                return None

     def list_dataset_files(self, dataset_id: str):
-        endpoint = f"{self.base_url}/list-dataset-files"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
         payload = {
             "dataset_id": dataset_id,
         }
@@ -1294,23 +1109,22 @@ class Sutro:
             text=to_colored_text(f"Listing files in dataset: {dataset_id}"),
             color=YASPIN_COLOR,
         ) as spinner:
-
-
-
-            if response.status_code != 200:
-                spinner.fail(
+            try:
+                response = self.do_request("POST", "list-dataset-files", json=payload)
+                spinner.write(
                     to_colored_text(
-                        f"
+                        f"✔ Files listed in dataset: {dataset_id}", state="success"
                     )
                 )
-
-
-
-
-
+                return response.json()["files"]
+            except requests.HTTPError as e:
+                spinner.fail(
+                    to_colored_text(
+                        f"Bad status code: {e.response.status_code}", state="fail"
+                    )
                 )
-
-
+                print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                return None

     def download_from_dataset(
         self,
@@ -1318,8 +1132,6 @@ class Sutro:
         files: Union[List[str], str] = None,
         output_path: str = None,
     ):
-        endpoint = f"{self.base_url}/download-from-dataset"
-
         if files is None:
             files = self.list_dataset_files(dataset_id)
         elif isinstance(files, str):
@@ -1344,32 +1156,32 @@ class Sutro:
         ) as spinner:
             count = 0
             for file in files:
-                headers = {
-                    "Authorization": f"Key {self.api_key}",
-                    "Content-Type": "application/json",
-                }
-                payload = {
-                    "dataset_id": dataset_id,
-                    "file_name": file,
-                }
                 spinner.text = to_colored_text(
                     f"Downloading file {count + 1}/{len(files)} from dataset: {dataset_id}"
                 )
-
-
-
-
+
+                try:
+                    payload = {
+                        "dataset_id": dataset_id,
+                        "file_name": file,
+                    }
+                    response = self.do_request(
+                        "POST", "download-from-dataset", json=payload
+                    )
+
+                    file_content = response.content
+                    with open(os.path.join(output_path, file), "wb") as f:
+                        f.write(file_content)
+
+                    count += 1
+                except requests.HTTPError as e:
                     spinner.fail(
                         to_colored_text(
-                            f"Bad status code: {response.status_code}", state="fail"
+                            f"Bad status code: {e.response.status_code}", state="fail"
                        )
                    )
-                    print(to_colored_text(f"Error: {response.json()}", state="fail"))
+                    print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
                     return
-                file_content = response.content
-                with open(os.path.join(output_path, file), "wb") as f:
-                    f.write(file_content)
-                count += 1
             spinner.write(
                 to_colored_text(
                     f"✔ {count} files successfully downloaded from dataset: {dataset_id}",
@@ -1389,46 +1201,38 @@ class Sutro:
         Returns:
             dict: The status of the authentication.
         """
-        endpoint = f"{self.base_url}/try-authentication"
-        headers = {
-            "Authorization": f"Key {api_key}",
-            "Content-Type": "application/json",
-        }
         with yaspin(
             SPINNER, text=to_colored_text("Checking API key"), color=YASPIN_COLOR
         ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", "try-authentication", api_key)
+
                 spinner.write(to_colored_text("✔"))
-
+                return response.json()
+            except requests.HTTPError as e:
                 spinner.write(
                     to_colored_text(
-                        f"API key failed to authenticate: {response.status_code}",
+                        f"API key failed to authenticate: {e.response.status_code}",
                         state="fail",
                     )
                 )
-                return
-                return response.json()
+                return None

     def get_quotas(self):
-        endpoint = f"{self.base_url}/get-quotas"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
         with yaspin(
             SPINNER, text=to_colored_text("Fetching quotas"), color=YASPIN_COLOR
         ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", "get-quotas")
+                return response.json()["quotas"]
+            except requests.HTTPError as e:
                 spinner.fail(
                     to_colored_text(
-                        f"Bad status code: {response.status_code}", state="fail"
+                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
-                print(to_colored_text(f"Error: {response.json()}", state="fail"))
-                return
-                return response.json()["quotas"]
+                print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                return None

     def await_job_completion(
         self,
@@ -1436,7 +1240,7 @@ class Sutro:
         timeout: Optional[int] = 7200,
         obtain_results: bool = True,
         is_cost_estimate: bool = False,
-    ) ->
+    ) -> pl.DataFrame | None:
         """
         Waits for job completion to occur and then returns the results upon
         a successful completion.
@@ -1448,11 +1252,11 @@ class Sutro:
         timeout (Optional[int]): The max time in seconds the function should wait for job results for. Default is 7200 (2 hours).

         Returns:
-
+            pl.DataFrame: The results of the job in a polars DataFrame.
         """
         POLL_INTERVAL = 5

-        results = None
+        results: pl.DataFrame | None = None
         start_time = time.time()
         with yaspin(
             SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR