sutro 0.1.35__tar.gz → 0.1.37__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sutro-0.1.35 → sutro-0.1.37}/PKG-INFO +14 -15
- {sutro-0.1.35 → sutro-0.1.37}/pyproject.toml +8 -10
- {sutro-0.1.35 → sutro-0.1.37}/sutro/sdk.py +161 -27
- sutro-0.1.35/.gitignore +0 -4
- {sutro-0.1.35 → sutro-0.1.37}/LICENSE +0 -0
- {sutro-0.1.35 → sutro-0.1.37}/README.md +0 -0
- {sutro-0.1.35 → sutro-0.1.37}/sutro/__init__.py +0 -0
- {sutro-0.1.35 → sutro-0.1.37}/sutro/cli.py +0 -0
|
@@ -1,24 +1,23 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sutro
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.37
|
|
4
4
|
Summary: Sutro Python SDK
|
|
5
|
-
Project-URL: Homepage, https://sutro.sh
|
|
6
|
-
Project-URL: Documentation, https://docs.sutro.sh
|
|
7
5
|
License-Expression: Apache-2.0
|
|
8
|
-
|
|
6
|
+
Requires-Dist: numpy>=2.1.1,<3.0.0
|
|
7
|
+
Requires-Dist: requests>=2.32.3,<3.0.0
|
|
8
|
+
Requires-Dist: pandas>=2.2.3,<3.0.0
|
|
9
|
+
Requires-Dist: polars>=1.33.0,<=1.34.0
|
|
10
|
+
Requires-Dist: click>=8.1.7,<9.0.0
|
|
11
|
+
Requires-Dist: colorama>=0.4.4,<1.0.0
|
|
12
|
+
Requires-Dist: yaspin>=3.2.0,<4.0.0
|
|
13
|
+
Requires-Dist: tqdm>=4.67.1,<5.0.0
|
|
14
|
+
Requires-Dist: pydantic>=2.11.4,<3.0.0
|
|
15
|
+
Requires-Dist: pyarrow>=21.0.0,<22.0.0
|
|
16
|
+
Requires-Dist: ruff==0.13.1 ; extra == 'dev'
|
|
9
17
|
Requires-Python: >=3.10
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
Requires-Dist: numpy<3.0.0,>=2.1.1
|
|
13
|
-
Requires-Dist: pandas<3.0.0,>=2.2.3
|
|
14
|
-
Requires-Dist: polars<=1.8.2
|
|
15
|
-
Requires-Dist: pyarrow<22.0.0,>=21.0.0
|
|
16
|
-
Requires-Dist: pydantic<3.0.0,>=2.11.4
|
|
17
|
-
Requires-Dist: requests<3.0.0,>=2.32.3
|
|
18
|
-
Requires-Dist: tqdm<5.0.0,>=4.67.1
|
|
19
|
-
Requires-Dist: yaspin<4.0.0,>=3.2.0
|
|
18
|
+
Project-URL: Documentation, https://docs.sutro.sh
|
|
19
|
+
Project-URL: Homepage, https://sutro.sh
|
|
20
20
|
Provides-Extra: dev
|
|
21
|
-
Requires-Dist: ruff==0.13.1; extra == 'dev'
|
|
22
21
|
Description-Content-Type: text/markdown
|
|
23
22
|
|
|
24
23
|

|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[build-system]
|
|
2
|
-
requires = ["
|
|
3
|
-
build-backend = "
|
|
2
|
+
requires = ["uv_build>=0.7.19,<=0.9.2"]
|
|
3
|
+
build-backend = "uv_build"
|
|
4
4
|
|
|
5
5
|
[tool.hatch.env]
|
|
6
6
|
requires = ["pip"]
|
|
@@ -9,7 +9,7 @@ installer = "uv"
|
|
|
9
9
|
|
|
10
10
|
[project]
|
|
11
11
|
name = "sutro"
|
|
12
|
-
version = "0.1.
|
|
12
|
+
version = "0.1.37"
|
|
13
13
|
description = "Sutro Python SDK"
|
|
14
14
|
readme = "README.md"
|
|
15
15
|
requires-python = ">=3.10"
|
|
@@ -18,7 +18,7 @@ dependencies = [
|
|
|
18
18
|
"numpy>=2.1.1,<3.0.0",
|
|
19
19
|
"requests>=2.32.3,<3.0.0",
|
|
20
20
|
"pandas>=2.2.3,<3.0.0",
|
|
21
|
-
"polars
|
|
21
|
+
"polars>=1.33.0,<=1.34.0",
|
|
22
22
|
"click>=8.1.7,<9.0.0",
|
|
23
23
|
"colorama>=0.4.4,<1.0.0",
|
|
24
24
|
"yaspin>=3.2.0,<4.0.0",
|
|
@@ -39,16 +39,14 @@ sutro = "sutro.cli:cli"
|
|
|
39
39
|
"Homepage" = "https://sutro.sh"
|
|
40
40
|
"Documentation" = "https://docs.sutro.sh"
|
|
41
41
|
|
|
42
|
-
[tool.
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
[tool.hatch.build.targets.sdist]
|
|
46
|
-
include = [
|
|
42
|
+
[tool.uv.build-backend]
|
|
43
|
+
module-root = "."
|
|
44
|
+
source-include = [
|
|
47
45
|
"sutro",
|
|
48
46
|
"README.md",
|
|
49
47
|
"LICENSE",
|
|
50
48
|
]
|
|
51
|
-
exclude = [
|
|
49
|
+
source-exclude = [
|
|
52
50
|
"demo_data",
|
|
53
51
|
"demo.py",
|
|
54
52
|
".gitignore",
|
|
@@ -14,7 +14,10 @@ import time
|
|
|
14
14
|
from pydantic import BaseModel
|
|
15
15
|
import pyarrow.parquet as pq
|
|
16
16
|
import shutil
|
|
17
|
+
import importlib.metadata
|
|
17
18
|
|
|
19
|
+
JOB_NAME_CHAR_LIMIT = 45
|
|
20
|
+
JOB_DESCRIPTION_CHAR_LIMIT = 512
|
|
18
21
|
|
|
19
22
|
class JobStatus(str, Enum):
|
|
20
23
|
"""Job statuses that will be returned by the API & SDK"""
|
|
@@ -62,15 +65,20 @@ ModelOptions = Literal[
|
|
|
62
65
|
"llama-3.3-70b",
|
|
63
66
|
"llama-3.3-70b",
|
|
64
67
|
"qwen-3-4b",
|
|
68
|
+
"qwen-3-14b",
|
|
65
69
|
"qwen-3-32b",
|
|
70
|
+
"qwen-3-30b-a3b",
|
|
71
|
+
"qwen-3-235b-a22b",
|
|
66
72
|
"qwen-3-4b-thinking",
|
|
73
|
+
"qwen-3-14b-thinking",
|
|
67
74
|
"qwen-3-32b-thinking",
|
|
75
|
+
"qwen-3-235b-a22b-thinking",
|
|
76
|
+
"qwen-3-30b-a3b-thinking",
|
|
68
77
|
"gemma-3-4b-it",
|
|
78
|
+
"gemma-3-12b-it",
|
|
69
79
|
"gemma-3-27b-it",
|
|
70
|
-
"gpt-oss-120b",
|
|
71
80
|
"gpt-oss-20b",
|
|
72
|
-
"
|
|
73
|
-
"qwen-3-30b-a3b-thinking",
|
|
81
|
+
"gpt-oss-120b",
|
|
74
82
|
"qwen-3-embedding-0.6b",
|
|
75
83
|
"qwen-3-embedding-6b",
|
|
76
84
|
"qwen-3-embedding-8b",
|
|
@@ -78,7 +86,7 @@ ModelOptions = Literal[
|
|
|
78
86
|
|
|
79
87
|
|
|
80
88
|
def to_colored_text(
|
|
81
|
-
text: str, state: Optional[Literal["success", "fail"]] = None
|
|
89
|
+
text: str, state: Optional[Literal["success", "fail", "callout"]] = None
|
|
82
90
|
) -> str:
|
|
83
91
|
"""
|
|
84
92
|
Apply color to text based on state.
|
|
@@ -96,6 +104,8 @@ def to_colored_text(
|
|
|
96
104
|
return f"{Fore.GREEN}{text}{Style.RESET_ALL}"
|
|
97
105
|
case "fail":
|
|
98
106
|
return f"{Fore.RED}{text}{Style.RESET_ALL}"
|
|
107
|
+
case "callout":
|
|
108
|
+
return f"{Fore.MAGENTA}{text}{Style.RESET_ALL}"
|
|
99
109
|
case _:
|
|
100
110
|
# Default to blue for normal/processing states
|
|
101
111
|
return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
|
|
@@ -117,6 +127,34 @@ class Sutro:
|
|
|
117
127
|
def __init__(self, api_key: str = None, base_url: str = "https://api.sutro.sh/"):
|
|
118
128
|
self.api_key = api_key or self.check_for_api_key()
|
|
119
129
|
self.base_url = base_url
|
|
130
|
+
self.check_version("sutro")
|
|
131
|
+
|
|
132
|
+
def check_version(self, package_name: str):
|
|
133
|
+
try:
|
|
134
|
+
# Local version
|
|
135
|
+
local_version = importlib.metadata.version(package_name)
|
|
136
|
+
except importlib.metadata.PackageNotFoundError:
|
|
137
|
+
print(f"{package_name} is not installed.")
|
|
138
|
+
return
|
|
139
|
+
|
|
140
|
+
try:
|
|
141
|
+
# Latest release from PyPI
|
|
142
|
+
resp = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=2)
|
|
143
|
+
resp.raise_for_status()
|
|
144
|
+
latest_version = resp.json()["info"]["version"]
|
|
145
|
+
|
|
146
|
+
if local_version != latest_version:
|
|
147
|
+
msg = (f"⚠️ You are using {package_name} {local_version}, "
|
|
148
|
+
f"but the latest release is {latest_version}. "
|
|
149
|
+
f"Run `[uv] pip install -U {package_name}` to upgrade.")
|
|
150
|
+
print(to_colored_text(
|
|
151
|
+
msg,
|
|
152
|
+
state="callout"
|
|
153
|
+
)
|
|
154
|
+
)
|
|
155
|
+
except Exception as e:
|
|
156
|
+
# Fail silently or log, you don’t want this blocking usage
|
|
157
|
+
pass
|
|
120
158
|
|
|
121
159
|
def check_for_api_key(self):
|
|
122
160
|
"""
|
|
@@ -159,6 +197,39 @@ class Sutro:
|
|
|
159
197
|
"""
|
|
160
198
|
self.api_key = api_key
|
|
161
199
|
|
|
200
|
+
def do_dataframe_column_concatenation(self, data: Union[pd.DataFrame, pl.DataFrame], column: Union[str, List[str]]):
|
|
201
|
+
"""
|
|
202
|
+
If the user has supplied a dataframe and a list of columns, this will intelligenly concatenate the columns into a single column, accepting separator strings.
|
|
203
|
+
"""
|
|
204
|
+
try:
|
|
205
|
+
if isinstance(data, pd.DataFrame):
|
|
206
|
+
series_parts = []
|
|
207
|
+
for p in column:
|
|
208
|
+
if p in data.columns:
|
|
209
|
+
s = data[p].astype("string").fillna("")
|
|
210
|
+
else:
|
|
211
|
+
# Treat as a literal separator
|
|
212
|
+
s = pd.Series([p] * len(data), index=data.index, dtype="string")
|
|
213
|
+
series_parts.append(s)
|
|
214
|
+
|
|
215
|
+
out = series_parts[0]
|
|
216
|
+
for s in series_parts[1:]:
|
|
217
|
+
out = out.str.cat(s, na_rep="")
|
|
218
|
+
|
|
219
|
+
return out.tolist()
|
|
220
|
+
elif isinstance(data, pl.DataFrame):
|
|
221
|
+
exprs = []
|
|
222
|
+
for p in column:
|
|
223
|
+
if p in data.columns:
|
|
224
|
+
exprs.append(pl.col(p).cast(pl.Utf8).fill_null(""))
|
|
225
|
+
else:
|
|
226
|
+
exprs.append(pl.lit(p))
|
|
227
|
+
|
|
228
|
+
result = data.select(pl.concat_str(exprs, separator="", ignore_nulls=False).alias("concat"))
|
|
229
|
+
return result["concat"].to_list()
|
|
230
|
+
except Exception as e:
|
|
231
|
+
raise ValueError(f"Error handling column concatentation: {e}")
|
|
232
|
+
|
|
162
233
|
def handle_data_helper(
|
|
163
234
|
self, data: Union[List, pd.DataFrame, pl.DataFrame, str], column: str = None
|
|
164
235
|
):
|
|
@@ -167,7 +238,10 @@ class Sutro:
|
|
|
167
238
|
elif isinstance(data, (pd.DataFrame, pl.DataFrame)):
|
|
168
239
|
if column is None:
|
|
169
240
|
raise ValueError("Column name must be specified for DataFrame input")
|
|
170
|
-
|
|
241
|
+
if isinstance(column, list):
|
|
242
|
+
input_data = self.do_dataframe_column_concatenation(data, column)
|
|
243
|
+
elif isinstance(column, str):
|
|
244
|
+
input_data = data[column].to_list()
|
|
171
245
|
elif isinstance(data, str):
|
|
172
246
|
if data.startswith("dataset-"):
|
|
173
247
|
input_data = data + ":" + column
|
|
@@ -212,7 +286,7 @@ class Sutro:
|
|
|
212
286
|
self,
|
|
213
287
|
data: Union[List, pd.DataFrame, pl.DataFrame, str],
|
|
214
288
|
model: ModelOptions,
|
|
215
|
-
column: str,
|
|
289
|
+
column: Union[str, List[str]],
|
|
216
290
|
output_column: str,
|
|
217
291
|
job_priority: int,
|
|
218
292
|
json_schema: Dict[str, Any],
|
|
@@ -222,7 +296,15 @@ class Sutro:
|
|
|
222
296
|
stay_attached: Optional[bool],
|
|
223
297
|
random_seed_per_input: bool,
|
|
224
298
|
truncate_rows: bool,
|
|
299
|
+
name: str,
|
|
300
|
+
description: str,
|
|
225
301
|
):
|
|
302
|
+
# Validate name and description lengths
|
|
303
|
+
if name is not None and len(name) > JOB_NAME_CHAR_LIMIT:
|
|
304
|
+
raise ValueError(f"Job name cannot exceed {JOB_NAME_CHAR_LIMIT} characters.")
|
|
305
|
+
if description is not None and len(description) > JOB_DESCRIPTION_CHAR_LIMIT:
|
|
306
|
+
raise ValueError(f"Job description cannot exceed {JOB_DESCRIPTION_CHAR_LIMIT} characters.")
|
|
307
|
+
|
|
226
308
|
input_data = self.handle_data_helper(data, column)
|
|
227
309
|
endpoint = f"{self.base_url}/batch-inference"
|
|
228
310
|
headers = {
|
|
@@ -239,6 +321,8 @@ class Sutro:
|
|
|
239
321
|
"sampling_params": sampling_params,
|
|
240
322
|
"random_seed_per_input": random_seed_per_input,
|
|
241
323
|
"truncate_rows": truncate_rows,
|
|
324
|
+
"name": name,
|
|
325
|
+
"description": description,
|
|
242
326
|
}
|
|
243
327
|
|
|
244
328
|
# There are two gotchas with yaspin:
|
|
@@ -284,9 +368,10 @@ class Sutro:
|
|
|
284
368
|
)
|
|
285
369
|
return job_id
|
|
286
370
|
else:
|
|
371
|
+
name_text = f" and name {name}" if name is not None else ""
|
|
287
372
|
spinner.write(
|
|
288
373
|
to_colored_text(
|
|
289
|
-
f"🛠 Priority {job_priority} Job created with ID: {job_id}.",
|
|
374
|
+
f"🛠 Priority {job_priority} Job created with ID: {job_id}{name_text}.",
|
|
290
375
|
state="success",
|
|
291
376
|
)
|
|
292
377
|
)
|
|
@@ -435,7 +520,21 @@ class Sutro:
|
|
|
435
520
|
|
|
436
521
|
results = job_results_response.json()["results"]["outputs"]
|
|
437
522
|
|
|
438
|
-
|
|
523
|
+
if isinstance(data, (pd.DataFrame, pl.DataFrame)):
|
|
524
|
+
if isinstance(data, pd.DataFrame):
|
|
525
|
+
data[output_column] = results
|
|
526
|
+
elif isinstance(data, pl.DataFrame):
|
|
527
|
+
data = data.with_columns(pl.Series(output_column, results))
|
|
528
|
+
print(data)
|
|
529
|
+
spinner.write(
|
|
530
|
+
to_colored_text(
|
|
531
|
+
f"✔ Displaying result preview. You can join the results on the original dataframe with `so.get_job_results('{job_id}', with_original_df=<original_df>)`",
|
|
532
|
+
state="success",
|
|
533
|
+
)
|
|
534
|
+
)
|
|
535
|
+
else:
|
|
536
|
+
print(results)
|
|
537
|
+
spinner.write(
|
|
439
538
|
to_colored_text(
|
|
440
539
|
f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
|
|
441
540
|
state="success",
|
|
@@ -443,14 +542,7 @@ class Sutro:
|
|
|
443
542
|
)
|
|
444
543
|
spinner.stop()
|
|
445
544
|
|
|
446
|
-
|
|
447
|
-
if isinstance(data, pd.DataFrame):
|
|
448
|
-
data[output_column] = results
|
|
449
|
-
elif isinstance(data, pl.DataFrame):
|
|
450
|
-
data = data.with_columns(pl.Series(output_column, results))
|
|
451
|
-
return data
|
|
452
|
-
|
|
453
|
-
return results
|
|
545
|
+
return job_id
|
|
454
546
|
return None
|
|
455
547
|
return None
|
|
456
548
|
|
|
@@ -458,7 +550,9 @@ class Sutro:
|
|
|
458
550
|
self,
|
|
459
551
|
data: Union[List, pd.DataFrame, pl.DataFrame, str],
|
|
460
552
|
model: Union[ModelOptions, List[ModelOptions]] = "gemma-3-12b-it",
|
|
461
|
-
|
|
553
|
+
name: Union[str, List[str]] = None,
|
|
554
|
+
description: Union[str, List[str]] = None,
|
|
555
|
+
column: Union[str, List[str]] = None,
|
|
462
556
|
output_column: str = "inference_result",
|
|
463
557
|
job_priority: int = 0,
|
|
464
558
|
output_schema: Union[Dict[str, Any], BaseModel] = None,
|
|
@@ -467,7 +561,7 @@ class Sutro:
|
|
|
467
561
|
dry_run: bool = False,
|
|
468
562
|
stay_attached: Optional[bool] = None,
|
|
469
563
|
random_seed_per_input: bool = False,
|
|
470
|
-
truncate_rows: bool =
|
|
564
|
+
truncate_rows: bool = True,
|
|
471
565
|
):
|
|
472
566
|
"""
|
|
473
567
|
Run inference on the provided data.
|
|
@@ -478,7 +572,9 @@ class Sutro:
|
|
|
478
572
|
Args:
|
|
479
573
|
data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
|
|
480
574
|
model (Union[ModelOptions, List[ModelOptions]], optional): The model(s) to use for inference. Defaults to "llama-3.1-8b". You can pass a single model or a list of models. In the case of a list, the inference will be run in parallel for each model and stay_attached will be set to False.
|
|
481
|
-
|
|
575
|
+
name (Union[str, List[str]], optional): A job name for experiment/metadata tracking purposes. If using a list of models, you must pass a list of names with length equal to the number of models, or None. Defaults to None.
|
|
576
|
+
description (Union[str, List[str]], optional): A job description for experiment/metadata tracking purposes. If using a list of models, you must pass a list of descriptions with length equal to the number of models, or None. Defaults to None.
|
|
577
|
+
column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, it will concatenate the columns of the list into a single column, accepting separator strings.
|
|
482
578
|
output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
|
|
483
579
|
job_priority (int, optional): The priority of the job. Defaults to 0.
|
|
484
580
|
output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
|
|
@@ -488,10 +584,10 @@ class Sutro:
|
|
|
488
584
|
dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
|
|
489
585
|
stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
|
|
490
586
|
random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
|
|
491
|
-
truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to
|
|
587
|
+
truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to True.
|
|
492
588
|
|
|
493
589
|
Returns:
|
|
494
|
-
|
|
590
|
+
str: The ID of the inference job.
|
|
495
591
|
|
|
496
592
|
"""
|
|
497
593
|
if isinstance(model, list) == False:
|
|
@@ -503,6 +599,34 @@ class Sutro:
|
|
|
503
599
|
model_list = model
|
|
504
600
|
stay_attached = False
|
|
505
601
|
|
|
602
|
+
if isinstance(model_list, list):
|
|
603
|
+
if isinstance(name, list):
|
|
604
|
+
if len(name) != len(model_list):
|
|
605
|
+
raise ValueError("Name list must be the same length as the model list.")
|
|
606
|
+
name_list = name
|
|
607
|
+
elif isinstance(name, str):
|
|
608
|
+
raise ValueError("Name must be a list if using a list of models.")
|
|
609
|
+
elif name is None:
|
|
610
|
+
name_list = [None] * len(model_list)
|
|
611
|
+
else:
|
|
612
|
+
if isinstance(name, list):
|
|
613
|
+
raise ValueError("Name must be a string or None if using a single model.")
|
|
614
|
+
name_list = [name]
|
|
615
|
+
|
|
616
|
+
if isinstance(model_list, list):
|
|
617
|
+
if isinstance(description, list):
|
|
618
|
+
if len(description) != len(model_list):
|
|
619
|
+
raise ValueError("Descriptions list must be the same length as the model list.")
|
|
620
|
+
description_list = description
|
|
621
|
+
elif isinstance(description, str):
|
|
622
|
+
raise ValueError("Description must be a list if using a list of models.")
|
|
623
|
+
elif description is None:
|
|
624
|
+
description_list = [None] * len(model_list)
|
|
625
|
+
else:
|
|
626
|
+
if isinstance(name, list):
|
|
627
|
+
raise ValueError("Description must be a string or None if using a single model.")
|
|
628
|
+
description_list = [description]
|
|
629
|
+
|
|
506
630
|
# Convert BaseModel to dict if needed
|
|
507
631
|
if output_schema is not None:
|
|
508
632
|
if hasattr(
|
|
@@ -517,12 +641,12 @@ class Sutro:
|
|
|
517
641
|
)
|
|
518
642
|
else:
|
|
519
643
|
json_schema = None
|
|
520
|
-
|
|
644
|
+
|
|
521
645
|
results = []
|
|
522
|
-
for
|
|
646
|
+
for i in range(len(model_list)):
|
|
523
647
|
res = self._run_one_batch_inference(
|
|
524
648
|
data,
|
|
525
|
-
|
|
649
|
+
model_list[i],
|
|
526
650
|
column,
|
|
527
651
|
output_column,
|
|
528
652
|
job_priority,
|
|
@@ -533,6 +657,8 @@ class Sutro:
|
|
|
533
657
|
stay_attached,
|
|
534
658
|
random_seed_per_input,
|
|
535
659
|
truncate_rows,
|
|
660
|
+
name_list[i],
|
|
661
|
+
description_list[i],
|
|
536
662
|
)
|
|
537
663
|
results.append(res)
|
|
538
664
|
|
|
@@ -967,9 +1093,9 @@ class Sutro:
|
|
|
967
1093
|
first_row = json.loads(
|
|
968
1094
|
results_df.head(1)[output_column][0]
|
|
969
1095
|
) # checks if the first row can be json decoded
|
|
1096
|
+
results_df = results_df.map_columns(output_column, lambda s: s.str.json_decode())
|
|
970
1097
|
results_df = results_df.with_columns(
|
|
971
1098
|
pl.col(output_column)
|
|
972
|
-
.str.json_decode()
|
|
973
1099
|
.alias("output_column_json_decoded")
|
|
974
1100
|
)
|
|
975
1101
|
json_decoded_fields = first_row.keys()
|
|
@@ -979,7 +1105,15 @@ class Sutro:
|
|
|
979
1105
|
.struct.field(field)
|
|
980
1106
|
.alias(field)
|
|
981
1107
|
)
|
|
982
|
-
#
|
|
1108
|
+
if sorted(list(set(json_decoded_fields))) == ['content', 'reasoning_content']: # if it's a reasoning model, we need to unpack the content field
|
|
1109
|
+
content_keys = results_df.head(1)['content'][0].keys()
|
|
1110
|
+
for key in content_keys:
|
|
1111
|
+
results_df = results_df.with_columns(
|
|
1112
|
+
pl.col("content")
|
|
1113
|
+
.struct.field(key)
|
|
1114
|
+
.alias(key)
|
|
1115
|
+
)
|
|
1116
|
+
results_df = results_df.drop("content")
|
|
983
1117
|
results_df = results_df.drop(
|
|
984
1118
|
[output_column, "output_column_json_decoded"]
|
|
985
1119
|
)
|
|
@@ -1364,7 +1498,7 @@ class Sutro:
|
|
|
1364
1498
|
timeout (Optional[int]): The max time in seconds the function should wait for job results for. Default is 7200 (2 hours).
|
|
1365
1499
|
|
|
1366
1500
|
Returns:
|
|
1367
|
-
|
|
1501
|
+
pl.DataFrame: The results of the job in a polars DataFrame.
|
|
1368
1502
|
"""
|
|
1369
1503
|
POLL_INTERVAL = 5
|
|
1370
1504
|
|
sutro-0.1.35/.gitignore
DELETED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|