sutro 0.1.37__py3-none-any.whl → 0.1.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of sutro might be problematic.
- sutro/cli.py +1 -1
- sutro/common.py +220 -0
- sutro/interfaces.py +90 -0
- sutro/sdk.py +333 -579
- sutro/templates/classification.py +117 -0
- sutro/templates/embed.py +53 -0
- sutro/validation.py +60 -0
- {sutro-0.1.37.dist-info → sutro-0.1.38.dist-info}/METADATA +1 -1
- sutro-0.1.38.dist-info/RECORD +12 -0
- sutro-0.1.37.dist-info/RECORD +0 -7
- {sutro-0.1.37.dist-info → sutro-0.1.38.dist-info}/WHEEL +0 -0
- {sutro-0.1.37.dist-info → sutro-0.1.38.dist-info}/entry_points.txt +0 -0
sutro/sdk.py
CHANGED
@@ -1,49 +1,32 @@
-from enum import Enum
 import requests
 import pandas as pd
 import polars as pl
 import json
-from typing import Union, List, Optional,
+from typing import Union, List, Optional, Dict, Any, Type
 import os
 import sys
 from yaspin import yaspin
 from yaspin.spinners import Spinners
-from colorama import init
-from tqdm import tqdm
+from colorama import init
 import time
 from pydantic import BaseModel
 import pyarrow.parquet as pq
 import shutil
-import
+from sutro.common import (
+    ModelOptions,
+    handle_data_helper,
+    normalize_output_schema,
+    to_colored_text,
+    fancy_tqdm,
+)
+from sutro.interfaces import JobStatus
+from sutro.templates.classification import ClassificationTemplates
+from sutro.templates.embed import EmbeddingTemplates
+from sutro.validation import check_version, check_for_api_key
 
 JOB_NAME_CHAR_LIMIT = 45
 JOB_DESCRIPTION_CHAR_LIMIT = 512
 
-class JobStatus(str, Enum):
-    """Job statuses that will be returned by the API & SDK"""
-
-    UNKNOWN = "UNKNOWN"
-    QUEUED = "QUEUED"  # Job is waiting to start
-    STARTING = "STARTING"  # Job is in the process of starting up
-    RUNNING = "RUNNING"  # Job is actively running
-    SUCCEEDED = "SUCCEEDED"  # Job completed successfully
-    CANCELLING = "CANCELLING"  # Job is in the process of being canceled
-    CANCELLED = "CANCELLED"  # Job was canceled by the user
-    FAILED = "FAILED"  # Job failed
-
-    @classmethod
-    def terminal_statuses(cls) -> list["JobStatus"]:
-        return [
-            cls.SUCCEEDED,
-            cls.FAILED,
-            cls.CANCELLING,
-            cls.CANCELLED,
-        ]
-
-    def is_terminal(self) -> bool:
-        return self in self.terminal_statuses()
-
-
 # Initialize colorama (required for Windows)
 init()
 
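The JobStatus enum leaves sdk.py for the new sutro/interfaces.py (+90 lines in the file summary above). A hedged sketch of the caller-side impact, assuming the enum keeps the members and helpers shown in the removed block:

# Hypothetical caller-side update; assumes JobStatus keeps the same
# members and the is_terminal()/terminal_statuses() helpers after the move.
from sutro.interfaces import JobStatus  # was: from sutro.sdk import JobStatus

status = JobStatus.SUCCEEDED
assert status.is_terminal()  # SUCCEEDED is listed in terminal_statuses()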
@@ -57,59 +40,6 @@ def is_jupyter() -> bool:
 YASPIN_COLOR = None if is_jupyter() else "blue"
 SPINNER = Spinners.dots14
 
-# Models available for inference. Keep in sync with the backend configuration
-# so users get helpful autocompletion when selecting a model.
-ModelOptions = Literal[
-    "llama-3.2-3b",
-    "llama-3.1-8b",
-    "llama-3.3-70b",
-    "llama-3.3-70b",
-    "qwen-3-4b",
-    "qwen-3-14b",
-    "qwen-3-32b",
-    "qwen-3-30b-a3b",
-    "qwen-3-235b-a22b",
-    "qwen-3-4b-thinking",
-    "qwen-3-14b-thinking",
-    "qwen-3-32b-thinking",
-    "qwen-3-235b-a22b-thinking",
-    "qwen-3-30b-a3b-thinking",
-    "gemma-3-4b-it",
-    "gemma-3-12b-it",
-    "gemma-3-27b-it",
-    "gpt-oss-20b",
-    "gpt-oss-120b",
-    "qwen-3-embedding-0.6b",
-    "qwen-3-embedding-6b",
-    "qwen-3-embedding-8b",
-]
-
-
-def to_colored_text(
-    text: str, state: Optional[Literal["success", "fail", "callout"]] = None
-) -> str:
-    """
-    Apply color to text based on state.
-
-    Args:
-        text (str): The text to color
-        state (Optional[Literal['success', 'fail']]): The state that determines the color.
-            Options: 'success', 'fail', or None (default blue)
-
-    Returns:
-        str: Text with appropriate color applied
-    """
-    match state:
-        case "success":
-            return f"{Fore.GREEN}{text}{Style.RESET_ALL}"
-        case "fail":
-            return f"{Fore.RED}{text}{Style.RESET_ALL}"
-        case "callout":
-            return f"{Fore.MAGENTA}{text}{Style.RESET_ALL}"
-        case _:
-            # Default to blue for normal/processing states
-            return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
-
 
 # Isn't fully support in all terminals unfortunately. We should switch to Rich
 # at some point, but even Rich links aren't clickable on MacOS Terminal
@@ -123,64 +53,11 @@ def make_clickable_link(url, text=None):
     return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
 
 
-class Sutro:
+class Sutro(EmbeddingTemplates, ClassificationTemplates):
     def __init__(self, api_key: str = None, base_url: str = "https://api.sutro.sh/"):
-        self.api_key = api_key or
+        self.api_key = api_key or check_for_api_key()
         self.base_url = base_url
-
-
-    def check_version(self, package_name: str):
-        try:
-            # Local version
-            local_version = importlib.metadata.version(package_name)
-        except importlib.metadata.PackageNotFoundError:
-            print(f"{package_name} is not installed.")
-            return
-
-        try:
-            # Latest release from PyPI
-            resp = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=2)
-            resp.raise_for_status()
-            latest_version = resp.json()["info"]["version"]
-
-            if local_version != latest_version:
-                msg = (f"⚠️ You are using {package_name} {local_version}, "
-                       f"but the latest release is {latest_version}. "
-                       f"Run `[uv] pip install -U {package_name}` to upgrade.")
-                print(to_colored_text(
-                    msg,
-                    state="callout"
-                )
-                )
-        except Exception as e:
-            # Fail silently or log, you don’t want this blocking usage
-            pass
-
-    def check_for_api_key(self):
-        """
-        Check for an API key in the user's home directory.
-
-        This method looks for a configuration file named 'config.json' in the
-        '.sutro' directory within the user's home directory.
-        If the file exists, it attempts to read the API key from it.
-
-        Returns:
-            str or None: The API key if found in the configuration file, or None if not found.
-
-        Note:
-            The expected structure of the config.json file is:
-            {
-                "api_key": "your_api_key_here"
-            }
-        """
-        CONFIG_DIR = os.path.expanduser("~/.sutro")
-        CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
-        if os.path.exists(CONFIG_FILE):
-            with open(CONFIG_FILE, "r") as f:
-                config = json.load(f)
-            return config.get("api_key")
-        else:
-            return None
+        check_version("sutro")
 
     def set_api_key(self, api_key: str):
         """
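A sketch of construction after this change, assuming check_for_api_key() in sutro.validation keeps the old behavior of reading ~/.sutro/config.json, and noting that check_version("sutro") only prints an upgrade notice (the removed implementation never raised):

from sutro.sdk import Sutro

so = Sutro(api_key="YOUR_KEY")  # explicit key wins
so_from_config = Sutro()        # falls back to ~/.sutro/config.json, per the old docstring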
@@ -197,79 +74,6 @@ class Sutro:
         """
         self.api_key = api_key
 
-    def do_dataframe_column_concatenation(self, data: Union[pd.DataFrame, pl.DataFrame], column: Union[str, List[str]]):
-        """
-        If the user has supplied a dataframe and a list of columns, this will intelligenly concatenate the columns into a single column, accepting separator strings.
-        """
-        try:
-            if isinstance(data, pd.DataFrame):
-                series_parts = []
-                for p in column:
-                    if p in data.columns:
-                        s = data[p].astype("string").fillna("")
-                    else:
-                        # Treat as a literal separator
-                        s = pd.Series([p] * len(data), index=data.index, dtype="string")
-                    series_parts.append(s)
-
-                out = series_parts[0]
-                for s in series_parts[1:]:
-                    out = out.str.cat(s, na_rep="")
-
-                return out.tolist()
-            elif isinstance(data, pl.DataFrame):
-                exprs = []
-                for p in column:
-                    if p in data.columns:
-                        exprs.append(pl.col(p).cast(pl.Utf8).fill_null(""))
-                    else:
-                        exprs.append(pl.lit(p))
-
-                result = data.select(pl.concat_str(exprs, separator="", ignore_nulls=False).alias("concat"))
-                return result["concat"].to_list()
-        except Exception as e:
-            raise ValueError(f"Error handling column concatentation: {e}")
-
-    def handle_data_helper(
-        self, data: Union[List, pd.DataFrame, pl.DataFrame, str], column: str = None
-    ):
-        if isinstance(data, list):
-            input_data = data
-        elif isinstance(data, (pd.DataFrame, pl.DataFrame)):
-            if column is None:
-                raise ValueError("Column name must be specified for DataFrame input")
-            if isinstance(column, list):
-                input_data = self.do_dataframe_column_concatenation(data, column)
-            elif isinstance(column, str):
-                input_data = data[column].to_list()
-        elif isinstance(data, str):
-            if data.startswith("dataset-"):
-                input_data = data + ":" + column
-            else:
-                file_ext = os.path.splitext(data)[1].lower()
-                if file_ext == ".csv":
-                    df = pl.read_csv(data)
-                elif file_ext == ".parquet":
-                    df = pl.read_parquet(data)
-                elif file_ext in [".txt", ""]:
-                    with open(data, "r") as file:
-                        input_data = [line.strip() for line in file]
-                else:
-                    raise ValueError(f"Unsupported file type: {file_ext}")
-
-                if file_ext in [".csv", ".parquet"]:
-                    if column is None:
-                        raise ValueError(
-                            "Column name must be specified for CSV/Parquet input"
-                        )
-                    input_data = df[column].to_list()
-        else:
-            raise ValueError(
-                "Unsupported data type. Please provide a list, DataFrame, or file path."
-            )
-
-        return input_data
-
     def set_base_url(self, base_url: str):
         """
         Set the base URL for the Sutro API.
@@ -282,6 +86,43 @@ class Sutro:
         """
         self.base_url = base_url
 
+    def do_request(
+        self,
+        method: str,
+        endpoint: str,
+        api_key_override: Optional[str] = None,
+        **kwargs: Any,
+    ):
+        """
+        Helper to make authenticated requests.
+        """
+        key = self.api_key if not api_key_override else api_key_override
+        headers = {"Authorization": f"Key {key}"}
+
+        # Merge with any headers passed in kwargs
+        if "headers" in kwargs:
+            headers.update(kwargs.pop("headers"))
+
+        url = f"{self.base_url}/{endpoint.lstrip('/')}"
+
+        # Explicit method dispatch
+        method = method.upper()
+        if method == "GET":
+            response = requests.get(url, headers=headers, **kwargs)
+        elif method == "POST":
+            response = requests.post(url, headers=headers, **kwargs)
+        elif method == "PUT":
+            response = requests.put(url, headers=headers, **kwargs)
+        elif method == "DELETE":
+            response = requests.delete(url, headers=headers, **kwargs)
+        elif method == "PATCH":
+            response = requests.patch(url, headers=headers, **kwargs)
+        else:
+            raise ValueError(f"Unsupported HTTP method: {method}")
+
+        response.raise_for_status()
+        return response
+
     def _run_one_batch_inference(
         self,
         data: Union[List, pd.DataFrame, pl.DataFrame, str],
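Because do_request ends with raise_for_status(), call sites throughout this release switch from status-code checks to try/except requests.HTTPError. A minimal usage sketch (the "list-jobs" endpoint and "jobs" key are taken from the list_jobs hunk further down):

import requests
from sutro.sdk import Sutro

so = Sutro(api_key="YOUR_KEY")
try:
    # do_request raises requests.HTTPError on any non-2xx response
    jobs = so.do_request("GET", "list-jobs").json()["jobs"]
except requests.HTTPError as e:
    print(e.response.status_code, e.response.json())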
@@ -301,16 +142,15 @@ class Sutro:
     ):
         # Validate name and description lengths
         if name is not None and len(name) > JOB_NAME_CHAR_LIMIT:
-            raise ValueError(
+            raise ValueError(
+                f"Job name cannot exceed {JOB_NAME_CHAR_LIMIT} characters."
+            )
         if description is not None and len(description) > JOB_DESCRIPTION_CHAR_LIMIT:
-            raise ValueError(
+            raise ValueError(
+                f"Job description cannot exceed {JOB_DESCRIPTION_CHAR_LIMIT} characters."
+            )
 
-        input_data =
-        endpoint = f"{self.base_url}/batch-inference"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
+        input_data = handle_data_helper(data, column)
         payload = {
             "model": model,
             "inputs": input_data,
@@ -336,16 +176,19 @@ class Sutro:
         spinner_text = to_colored_text(t)
         try:
             with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
-
-
-
-
+                try:
+                    response = self.do_request("POST", "batch-inference", json=payload)
+                    response_data = response.json()
+                except requests.HTTPError as e:
+                    response = e.response
+                    response_data = response.json()
+
                 if response.status_code != 200:
                     spinner.write(
                         to_colored_text(f"Error: {response.status_code}", state="fail")
                     )
                     spinner.stop()
-                    print(to_colored_text(
+                    print(to_colored_text(response_data, state="fail"))
                     return None
                 else:
                     job_id = response_data["results"]
@@ -371,10 +214,11 @@ class Sutro:
                     name_text = f" and name {name}" if name is not None else ""
                     spinner.write(
                         to_colored_text(
-                            f"🛠 Priority {job_priority} Job created with ID: {job_id}{name_text}
+                            f"🛠 Priority {job_priority} Job created with ID: {job_id}{name_text}",
                             state="success",
                         )
                     )
+                    spinner.write(to_colored_text(f"Model: {model}"))
                     if not stay_attached:
                         clickable_link = make_clickable_link(
                             f"https://app.sutro.sh/jobs/{job_id}"
@@ -411,13 +255,13 @@ class Sutro:
                         )
                     )
                     return None
-
+
                 pbar = None
 
                 try:
-                    with
-
-
+                    with self.do_request(
+                        "GET",
+                        f"/stream-job-progress/{job_id}",
                         stream=True,
                     ) as streaming_response:
                         streaming_response.raise_for_status()
@@ -446,7 +290,7 @@ class Sutro:
                             if pbar is None:
                                 spinner.stop()
                                 postfix = "Input tokens processed: 0"
-                                pbar =
+                                pbar = fancy_tqdm(
                                     total=len(input_data),
                                     desc="Progress",
                                     style=1,
@@ -487,28 +331,27 @@ class Sutro:
                 )
                 spinner.start()
 
-                payload = {
-                    "job_id": job_id,
-                }
-
                 # TODO: we implment retries in cases where the job hasn't written results yet
                 # it would be better if we could receive a fully succeeded status from the job
                 # and not have such a race condition
                 max_retries = 20  # winds up being 100 seconds cumulative delay
                 retry_delay = 5  # initial delay in seconds
-
+                job_results_response = None
                 for _ in range(max_retries):
-
-
-
-
-
-
-
-
+                    try:
+                        job_results_response = self.do_request(
+                            "POST",
+                            "job-results",
+                            json={
+                                "job_id": job_id,
+                            },
+                        )
                         break
+                    except requests.HTTPError:
+                        time.sleep(retry_delay)
+                        continue
 
-                if job_results_response.status_code != 200:
+                if not job_results_response or job_results_response.status_code != 200:
                     spinner.write(
                         to_colored_text(
                             "Job succeeded, but results are not yet available. Use `so.get_job_results('{job_id}')` to obtain results.",
@@ -535,11 +378,11 @@ class Sutro:
                 else:
                     print(results)
                 spinner.write(
-
-
-
+                    to_colored_text(
+                        f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
+                        state="success",
+                    )
                 )
-                )
                 spinner.stop()
 
         return job_id
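handle_data_helper is now a module-level function in sutro.common rather than a method on Sutro. A sketch for direct callers, assuming the list branch still passes data through unchanged as in the removed implementation:

from sutro.common import handle_data_helper

inputs = handle_data_helper(["a", "b", "c"], None)  # lists pass through as-is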
@@ -549,13 +392,13 @@ class Sutro:
     def infer(
         self,
         data: Union[List, pd.DataFrame, pl.DataFrame, str],
-        model:
-        name:
-        description:
+        model: ModelOptions = "gemma-3-12b-it",
+        name: Optional[str] = None,
+        description: Optional[str] = None,
         column: Union[str, List[str]] = None,
         output_column: str = "inference_result",
         job_priority: int = 0,
-        output_schema: Union[Dict[str, Any], BaseModel] = None,
+        output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
         sampling_params: dict = None,
         system_prompt: str = None,
         dry_run: bool = False,
@@ -567,18 +410,18 @@ class Sutro:
         Run inference on the provided data.
 
         This method allows you to run inference on the provided data using the Sutro API.
-        It supports various data types such as lists,
+        It supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.
 
         Args:
             data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
-            model (
-            name (
-            description (
+            model (ModelOptions, optional): The model to use for inference. Defaults to "gemma-3-12b-it".
+            name (str, optional): A job name for experiment/metadata tracking purposes. Defaults to None.
+            description (str, optional): A job description for experiment/metadata tracking purposes. Defaults to None.
             column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, it will concatenate the columns of the list into a single column, accepting separator strings.
             output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
             job_priority (int, optional): The priority of the job. Defaults to 0.
             output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
-                Can be either a dictionary representing a JSON schema or a
+                Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
             sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
             system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
             dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
@@ -590,63 +433,113 @@ class Sutro:
             str: The ID of the inference job.
 
         """
-
-
-        stay_attached =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Default stay_attached to True for prototyping jobs (priority 0)
+        if stay_attached is None:
+            stay_attached = job_priority == 0
+
+        json_schema = None
+        if output_schema:
+            # Convert BaseModel to dict if needed
+            json_schema = normalize_output_schema(output_schema)
+
+        return self._run_one_batch_inference(
+            data,
+            model,
+            column,
+            output_column,
+            job_priority,
+            json_schema,
+            sampling_params,
+            system_prompt,
+            dry_run,
+            stay_attached,
+            random_seed_per_input,
+            truncate_rows,
+            name,
+            description,
+        )
+
+    def infer_per_model(
+        self,
+        data: Union[List, pd.DataFrame, pl.DataFrame, str],
+        models: List[ModelOptions],
+        names: List[str] = None,
+        descriptions: List[str] = None,
+        column: Union[str, List[str]] = None,
+        output_column: str = "inference_result",
+        job_priority: int = 0,
+        output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
+        sampling_params: dict = None,
+        system_prompt: str = None,
+        dry_run: bool = False,
+        random_seed_per_input: bool = False,
+        truncate_rows: bool = True,
+    ):
+        """
+        Run inference on the provided data, across multiple models. This method is often useful to sampling outputs from multiple models across the same dataset and compare the job_ids.
+
+        For input data, it supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.
+
+        Args:
+            data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
+            models (Union[ModelOptions, List[ModelOptions]], optional): The models to use for inference. Fans out each model to its own seperate job, over the same dataset.
+            names (Union[str, List[str]], optional): A job name for experiment/metadata tracking purposes. If using a list of models, you must pass a list of names with length equal to the number of models, or None. Defaults to None.
+            descriptions (Union[str, List[str]], optional): A job description for experiment/metadata tracking purposes. If using a list of models, you must pass a list of descriptions with length equal to the number of models, or None. Defaults to None.
+            column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, it will concatenate the columns of the list into a single column, accepting separator strings.
+            output_column (str, optional): The column name to store the inference job_ids in if the input is a DataFrame. Defaults to "inference_result".
+            job_priority (int, optional): The priority of the job. Defaults to 0.
+            output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
+                Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
+            sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
+            system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
+            dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
+            stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
+            random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
+            truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to True.
+
+        Returns:
+            str: The ID of the inference job.
+
+        """
+        if isinstance(names, list):
+            if len(names) != len(models):
+                raise ValueError(
+                    "names parameter must be the same length as the models parameter."
+                )
+        elif names is None:
+            names = [None] * len(models)
         else:
-
-
-
-
-
-
-        if hasattr(
-            output_schema, "model_json_schema"
-        ):  # Check for pydantic Model interface
-            json_schema = output_schema.model_json_schema()
-        elif isinstance(output_schema, dict):
-            json_schema = output_schema
-        else:
+            raise ValueError(
+                "names parameter must be a list or None if using a list of models"
+            )
+
+        if isinstance(descriptions, list):
+            if len(descriptions) != len(models):
                 raise ValueError(
-                    "
+                    "descriptions parameter must be the same length as the models"
+                    " parameter."
                 )
+        elif descriptions is None:
+            descriptions = [None] * len(models)
         else:
-
-
-
-
-
+            raise ValueError(
+                "descriptions parameter must be a list or None if using a list of "
+                "models"
+            )
+
+        json_schema = None
+        if output_schema:
+            # Convert BaseModel to dict if needed
+            json_schema = normalize_output_schema(output_schema)
+
+        def start_job(
+            model_singleton: ModelOptions,
+            name_singleton: str | None,
+            description_singleton: str | None,
+        ):
+            return self._run_one_batch_inference(
                 data,
-
+                model_singleton,
                 column,
                 output_column,
                 job_priority,
@@ -654,20 +547,21 @@ class Sutro:
                 sampling_params,
                 system_prompt,
                 dry_run,
-
+                False,
                 random_seed_per_input,
                 truncate_rows,
-
-
+                name_singleton,
+                description_singleton,
             )
-            results.append(res)
 
-
-
-
-
+        job_ids = [
+            start_job(model, name, description)
+            for model, name, description in zip(
+                models, names, descriptions, strict=True
+            )
+        ]
 
-        return
+        return job_ids
 
     def attach(self, job_id):
         """
@@ -678,16 +572,8 @@ class Sutro:
         """
 
         s = requests.Session()
-        payload = {
-            "job_id": job_id,
-        }
         pbar = None
 
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
         with yaspin(
             SPINNER,
             text=to_colored_text("Looking for job..."),
@@ -725,9 +611,9 @@ class Sutro:
         success = False
 
         try:
-            with
-
-
+            with self.do_request(
+                "GET",
+                f"/stream-job-progress/{job_id}",
                 stream=True,
             ) as streaming_response:
                 streaming_response.raise_for_status()
@@ -757,7 +643,7 @@ class Sutro:
                     if pbar is None:
                         spinner.stop()
                         postfix = "Input tokens processed: 0"
-                        pbar =
+                        pbar = fancy_tqdm(
                            total=total_rows,
                            desc="Progress",
                            style=1,
@@ -790,65 +676,6 @@ class Sutro:
         if spinner:
             spinner.stop()
 
-    def fancy_tqdm(
-        self,
-        total: int,
-        desc: str = "Progress",
-        color: str = "blue",
-        style=1,
-        postfix: str = None,
-    ):
-        """
-        Creates a customized tqdm progress bar with different styling options.
-
-        Args:
-            total (int): Total iterations
-            desc (str): Description for the progress bar
-            color (str): Color of the progress bar (green, blue, red, yellow, magenta)
-            style (int): Style preset (1-4)
-            postfix (str): Postfix for the progress bar
-        """
-
-        # Style presets
-        style_presets = {
-            1: {
-                "bar_format": "{l_bar}{bar:30}| {n_fmt}/{total_fmt} | {percentage:3.0f}% {postfix}",
-                "ascii": "░▒█",
-            },
-            2: {
-                "bar_format": "╢{l_bar}{bar:30}╟ {percentage:3.0f}%",
-                "ascii": "▁▂▃▄▅▆▇█",
-            },
-            3: {
-                "bar_format": "{desc}: |{bar}| {percentage:3.0f}% [{elapsed}<{remaining}]",
-                "ascii": "◯◔◑◕●",
-            },
-            4: {
-                "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
-                "ascii": "⬜⬛",
-            },
-            5: {
-                "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
-                "ascii": "▏▎▍▌▋▊▉█",
-            },
-        }
-
-        # Get style configuration
-        style_config = style_presets.get(style, style_presets[1])
-
-        return tqdm(
-            total=total,
-            desc=desc,
-            colour=color,
-            bar_format=style_config["bar_format"],
-            ascii=style_config["ascii"],
-            ncols=80,
-            dynamic_ncols=True,
-            smoothing=0.3,
-            leave=True,
-            postfix=postfix,
-        )
-
     def list_jobs(self):
         """
         List all jobs.
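fancy_tqdm also leaves the class for sutro.common (the call sites above now call fancy_tqdm(...) instead of self.fancy_tqdm(...)). A sketch, assuming the keyword signature from the removed method is preserved in the new module:

from sutro.common import fancy_tqdm

# signature assumed identical to the removed method: total, desc, color, style, postfix
pbar = fancy_tqdm(total=100, desc="Progress", style=1, postfix="Input tokens processed: 0")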
@@ -856,56 +683,36 @@ class Sutro:
         This method retrieves a list of all jobs associated with the API key.
 
         Returns:
-            list: A list of job details.
+            list: A list of job details, or None if the request fails.
         """
-        endpoint = f"{self.base_url}/list-jobs"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
         with yaspin(
             SPINNER, text=to_colored_text("Fetching jobs"), color=YASPIN_COLOR
         ) as spinner:
-
-
+            try:
+                return self._list_all_jobs_for_user()
+            except requests.HTTPError as e:
                 spinner.write(
                     to_colored_text(
-                        f"Bad status code: {response.status_code}", state="fail"
+                        f"Bad status code: {e.response.status_code}", state="fail"
                     )
                 )
                 spinner.stop()
-                print(to_colored_text(response.json(), state="fail"))
-                return
-            return response.json()["jobs"]
+                print(to_colored_text(e.response.json(), state="fail"))
+                return None
 
-    def
-        """
-        Helper function to list jobs.
-        """
-        endpoint = f"{self.base_url}/list-jobs˚"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-        response = requests.get(endpoint, headers=headers)
-        if response.status_code != 200:
-            return None
+    def _list_all_jobs_for_user(self):
+        response = self.do_request("GET", "list-jobs")
         return response.json()["jobs"]
 
     def _fetch_job(self, job_id):
         """
         Helper function to fetch a single job.
         """
-
-
-
-
-        }
-        response = requests.get(endpoint, headers=headers)
-        if response.status_code != 200:
+        try:
+            response = self.do_request("GET", f"jobs/{job_id}")
+            return response.json().get("job")
+        except requests.HTTPError:
             return None
-        return response.json().get("job")
 
     def _get_job_cost_estimate(self, job_id: str):
         """
@@ -939,15 +746,7 @@ class Sutro:
         Raises:
             requests.HTTPError: If the API returns a non-200 status code.
         """
-
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
-        response = requests.get(endpoint, headers=headers)
-        response.raise_for_status()
-
+        response = self.do_request("GET", f"job-status/{job_id}")
         return response.json()["job_status"][job_id]
 
     def get_job_status(self, job_id: str):
@@ -992,7 +791,7 @@ class Sutro:
         output_column: str = "inference_result",
         disable_cache: bool = False,
         unpack_json: bool = True,
-    ):
+    ) -> pl.DataFrame | pd.DataFrame:
         """
         Get the results of a job by its ID.
 
@@ -1029,44 +828,37 @@ class Sutro:
                 to_colored_text("✔ Results loaded from cache", state="success")
             )
         else:
-            endpoint = f"{self.base_url}/job-results"
             payload = {
                 "job_id": job_id,
                 "include_inputs": include_inputs,
                 "include_cumulative_logprobs": include_cumulative_logprobs,
             }
-            headers = {
-                "Authorization": f"Key {self.api_key}",
-                "Content-Type": "application/json",
-            }
             with yaspin(
                 SPINNER,
                 text=to_colored_text(f"Gathering results from job: {job_id}"),
                 color=YASPIN_COLOR,
             ) as spinner:
-
-
-
-
+                try:
+                    response = self.do_request("POST", "job-results", json=payload)
+                    response_data = response.json()
+                    spinner.write(
+                        to_colored_text("✔ Job results retrieved", state="success")
+                    )
+                except requests.HTTPError as e:
                     spinner.write(
                         to_colored_text(
-                            f"Bad status code: {response.status_code}", state="fail"
+                            f"Bad status code: {e.response.status_code}", state="fail"
                        )
                    )
                    spinner.stop()
-                    print(to_colored_text(response.json(), state="fail"))
+                    print(to_colored_text(e.response.json(), state="fail"))
                    return None
 
-                spinner.write(
-                    to_colored_text("✔ Job results retrieved", state="success")
-                )
-
-                response_data = response.json()
                 results_df = pl.DataFrame(response_data["results"])
 
                 results_df = results_df.rename({"outputs": output_column})
 
-                if disable_cache
+                if not disable_cache:
                     os.makedirs(os.path.dirname(file_path), exist_ok=True)
                     results_df.write_parquet(file_path, compression="snappy")
                     spinner.write(
@@ -1093,10 +885,11 @@ class Sutro:
                 first_row = json.loads(
                     results_df.head(1)[output_column][0]
                 )  # checks if the first row can be json decoded
-                results_df = results_df.map_columns(
+                results_df = results_df.map_columns(
+                    output_column, lambda s: s.str.json_decode()
+                )
                 results_df = results_df.with_columns(
-                    pl.col(output_column)
-                    .alias("output_column_json_decoded")
+                    pl.col(output_column).alias("output_column_json_decoded")
                 )
                 json_decoded_fields = first_row.keys()
                 for field in json_decoded_fields:
@@ -1105,19 +898,20 @@ class Sutro:
                     .struct.field(field)
                    .alias(field)
                 )
-                if sorted(list(set(json_decoded_fields))) == [
-
+                if sorted(list(set(json_decoded_fields))) == [
+                    "content",
+                    "reasoning_content",
+                ]:  # if it's a reasoning model, we need to unpack the content field
+                    content_keys = results_df.head(1)["content"][0].keys()
                     for key in content_keys:
                         results_df = results_df.with_columns(
-                            pl.col("content")
-                            .struct.field(key)
-                            .alias(key)
+                            pl.col("content").struct.field(key).alias(key)
                         )
                     results_df = results_df.drop("content")
                     results_df = results_df.drop(
                         [output_column, "output_column_json_decoded"]
                     )
-            except Exception
+            except Exception:
                 # if the first row cannot be json decoded, do nothing
                 pass
 
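With unpack_json=True (the default), the hunks above JSON-decode the output column and explode each schema field into its own column; rows that decode to exactly {"content", "reasoning_content"} (reasoning models) additionally have "content" unpacked. A hedged caller-side sketch:

from sutro.sdk import Sutro

so = Sutro(api_key="YOUR_KEY")
df = so.get_job_results("job-123", output_column="sentiment")  # "job-123" is a placeholder ID
if df is not None:        # the method returns None when the request fails
    print(df.columns)     # structured-output fields appear as top-level columns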
@@ -1153,25 +947,20 @@ class Sutro:
         Returns:
             dict: The status of the job.
         """
-        endpoint = f"{self.base_url}/job-cancel/{job_id}"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
         with yaspin(
             SPINNER,
             text=to_colored_text(f"Cancelling job: {job_id}"),
             color=YASPIN_COLOR,
         ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", f"job-cancel/{job_id}")
                 spinner.write(to_colored_text("✔ Job cancelled", state="success"))
-
+                return response.json()
+            except requests.HTTPError as e:
                 spinner.write(to_colored_text("Failed to cancel job", state="fail"))
                 spinner.stop()
-                print(to_colored_text(response.json(), state="fail"))
-                return
-            return response.json()
+                print(to_colored_text(e.response.json(), state="fail"))
+                return None
 
     def create_dataset(self):
         """
@@ -1182,31 +971,27 @@ class Sutro:
         Returns:
             str: The ID of the new dataset.
         """
-        endpoint = f"{self.base_url}/create-dataset"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
         with yaspin(
             SPINNER, text=to_colored_text("Creating dataset"), color=YASPIN_COLOR
         ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", "create-dataset")
+                dataset_id = response.json()["dataset_id"]
                 spinner.write(
                     to_colored_text(
-                        f"
+                        f"✔ Dataset created with ID: {dataset_id}", state="success"
                     )
                 )
-
-
-
-
-
-
-                        f"✔ Dataset created with ID: {dataset_id}", state="success"
+                return dataset_id
+            except requests.HTTPError as e:
+                spinner.write(
+                    to_colored_text(
+                        f"Bad status code: {e.response.status_code}", state="fail"
+                    )
                 )
-
-
+                spinner.stop()
+                print(to_colored_text(e.response.json(), state="fail"))
+                return None
 
     def upload_to_dataset(
         self,
@@ -1238,8 +1023,6 @@ class Sutro:
         if dataset_id is None:
             dataset_id = self.create_dataset()
 
-        endpoint = f"{self.base_url}/upload-to-dataset"
-
         if isinstance(file_paths, str):
             # check if the file path is a directory
             if os.path.isdir(file_paths):
@@ -1272,8 +1055,6 @@ class Sutro:
                     "dataset_id": dataset_id,
                 }
 
-                headers = {"Authorization": f"Key {self.api_key}"}
-
                 count += 1
                 spinner.write(
                     to_colored_text(
@@ -1282,25 +1063,18 @@ class Sutro:
                 )
 
                 try:
-
-
+                    self.do_request(
+                        "POST",
+                        "/upload-to-dataset",
+                        data=payload,
+                        files=files,
+                        verify=verify_ssl,
                     )
-                    if response.status_code != 200:
-                        # Stop spinner before showing error to avoid terminal width error
-                        spinner.stop()
-                        print(
-                            to_colored_text(
-                                f"Error: HTTP {response.status_code}", state="fail"
-                            )
-                        )
-                        print(to_colored_text(response.json(), state="fail"))
-                        return
-
                 except requests.exceptions.RequestException as e:
                     # Stop spinner before showing error to avoid terminal width error
                     spinner.stop()
                     print(to_colored_text(f"Upload failed: {str(e)}", state="fail"))
-                    return
+                    return None
 
         spinner.write(
             to_colored_text(
@@ -1310,32 +1084,23 @@ class Sutro:
         return dataset_id
 
     def list_datasets(self):
-        endpoint = f"{self.base_url}/list-datasets"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
         with yaspin(
             SPINNER, text=to_colored_text("Retrieving datasets"), color=YASPIN_COLOR
         ) as spinner:
-
-
+            try:
+                response = self.do_request("POST", "list-datasets")
+                spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
+                return response.json()["datasets"]
+            except requests.HTTPError as e:
                 spinner.fail(
                     to_colored_text(
-                        f"Bad status code: {response.status_code}", state="fail"
+                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
-                print(to_colored_text(f"Error: {response.json()}", state="fail"))
-                return
-            spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
-            return response.json()["datasets"]
+                print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                return None
 
     def list_dataset_files(self, dataset_id: str):
-        endpoint = f"{self.base_url}/list-dataset-files"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
         payload = {
             "dataset_id": dataset_id,
         }
@@ -1344,23 +1109,22 @@ class Sutro:
             text=to_colored_text(f"Listing files in dataset: {dataset_id}"),
             color=YASPIN_COLOR,
         ) as spinner:
-
-
-
-            if response.status_code != 200:
-                spinner.fail(
+            try:
+                response = self.do_request("POST", "list-dataset-files", json=payload)
+                spinner.write(
                     to_colored_text(
-                        f"
+                        f"✔ Files listed in dataset: {dataset_id}", state="success"
                    )
                )
-
-
-
-
-
+                return response.json()["files"]
+            except requests.HTTPError as e:
+                spinner.fail(
+                    to_colored_text(
+                        f"Bad status code: {e.response.status_code}", state="fail"
+                    )
                 )
-
-
+                print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                return None
 
     def download_from_dataset(
         self,
@@ -1368,8 +1132,6 @@ class Sutro:
         files: Union[List[str], str] = None,
         output_path: str = None,
     ):
-        endpoint = f"{self.base_url}/download-from-dataset"
-
         if files is None:
             files = self.list_dataset_files(dataset_id)
         elif isinstance(files, str):
@@ -1394,32 +1156,32 @@ class Sutro:
         ) as spinner:
             count = 0
             for file in files:
-                headers = {
-                    "Authorization": f"Key {self.api_key}",
-                    "Content-Type": "application/json",
-                }
-                payload = {
-                    "dataset_id": dataset_id,
-                    "file_name": file,
-                }
                 spinner.text = to_colored_text(
                     f"Downloading file {count + 1}/{len(files)} from dataset: {dataset_id}"
                 )
-
-
-
-
+
+                try:
+                    payload = {
+                        "dataset_id": dataset_id,
+                        "file_name": file,
+                    }
+                    response = self.do_request(
+                        "POST", "download-from-dataset", json=payload
+                    )
+
+                    file_content = response.content
+                    with open(os.path.join(output_path, file), "wb") as f:
+                        f.write(file_content)
+
+                    count += 1
+                except requests.HTTPError as e:
                     spinner.fail(
                         to_colored_text(
-                            f"Bad status code: {response.status_code}", state="fail"
+                            f"Bad status code: {e.response.status_code}", state="fail"
                        )
                    )
-                    print(to_colored_text(f"Error: {response.json()}", state="fail"))
+                    print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
                     return
-                file_content = response.content
-                with open(os.path.join(output_path, file), "wb") as f:
-                    f.write(file_content)
-                count += 1
             spinner.write(
                 to_colored_text(
                     f"✔ {count} files successfully downloaded from dataset: {dataset_id}",
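The dataset methods above all route through do_request now. A sketch of the round trip they support; method and parameter names are taken from the hunks, but the positional argument order is assumed from the method bodies rather than shown:

from sutro.sdk import Sutro

so = Sutro(api_key="YOUR_KEY")
dataset_id = so.create_dataset()
so.upload_to_dataset(file_paths="data/", dataset_id=dataset_id)  # keywords used since order isn't shown
print(so.list_dataset_files(dataset_id))
so.download_from_dataset(dataset_id, output_path="out/")         # first positional assumed to be dataset_id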
@@ -1439,46 +1201,38 @@ class Sutro:
         Returns:
             dict: The status of the authentication.
         """
-        endpoint = f"{self.base_url}/try-authentication"
-        headers = {
-            "Authorization": f"Key {api_key}",
-            "Content-Type": "application/json",
-        }
         with yaspin(
             SPINNER, text=to_colored_text("Checking API key"), color=YASPIN_COLOR
         ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", "try-authentication", api_key)
+
                 spinner.write(to_colored_text("✔"))
-
+                return response.json()
+            except requests.HTTPError as e:
                 spinner.write(
                     to_colored_text(
-                        f"API key failed to authenticate: {response.status_code}",
+                        f"API key failed to authenticate: {e.response.status_code}",
                         state="fail",
                     )
                 )
-                return
-            return response.json()
+                return None
 
     def get_quotas(self):
-        endpoint = f"{self.base_url}/get-quotas"
-        headers = {
-            "Authorization": f"Key {self.api_key}",
-            "Content-Type": "application/json",
-        }
         with yaspin(
             SPINNER, text=to_colored_text("Fetching quotas"), color=YASPIN_COLOR
         ) as spinner:
-
-
+            try:
+                response = self.do_request("GET", "get-quotas")
+                return response.json()["quotas"]
+            except requests.HTTPError as e:
                 spinner.fail(
                     to_colored_text(
-                        f"Bad status code: {response.status_code}", state="fail"
+                        f"Bad status code: {e.response.status_code}", state="fail"
                    )
                )
-                print(to_colored_text(f"Error: {response.json()}", state="fail"))
-                return
-                return response.json()["quotas"]
+                print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                return None
 
     def await_job_completion(
         self,
@@ -1486,7 +1240,7 @@ class Sutro:
         timeout: Optional[int] = 7200,
         obtain_results: bool = True,
         is_cost_estimate: bool = False,
-    ) ->
+    ) -> pl.DataFrame | None:
         """
         Waits for job completion to occur and then returns the results upon
         a successful completion.
@@ -1502,7 +1256,7 @@ class Sutro:
         """
         POLL_INTERVAL = 5
 
-        results = None
+        results: pl.DataFrame | None = None
         start_time = time.time()
         with yaspin(
             SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR