sutro 0.1.33__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sutro/__init__.py +15 -7
- sutro/cli.py +1 -1
- sutro/common.py +220 -0
- sutro/interfaces.py +91 -0
- sutro/sdk.py +400 -444
- sutro/templates/classification.py +117 -0
- sutro/templates/embed.py +53 -0
- sutro/templates/evals.py +340 -0
- sutro/validation.py +60 -0
- {sutro-0.1.33.dist-info → sutro-0.1.40.dist-info}/METADATA +14 -16
- sutro-0.1.40.dist-info/RECORD +13 -0
- sutro-0.1.40.dist-info/WHEEL +4 -0
- {sutro-0.1.33.dist-info → sutro-0.1.40.dist-info}/entry_points.txt +1 -0
- sutro-0.1.33.dist-info/RECORD +0 -8
- sutro-0.1.33.dist-info/WHEEL +0 -4
- sutro-0.1.33.dist-info/licenses/LICENSE +0 -201
sutro/__init__.py
CHANGED
|
@@ -1,14 +1,22 @@
|
|
|
1
1
|
from .sdk import Sutro
|
|
2
2
|
|
|
3
|
-
# Create
|
|
3
|
+
# Create a singleton instance
|
|
4
4
|
_instance = Sutro()
|
|
5
5
|
|
|
6
|
-
#
|
|
7
|
-
from types import MethodType
|
|
8
|
-
|
|
6
|
+
# Export all public methods from the instance
|
|
9
7
|
for attr in dir(_instance):
|
|
10
|
-
if callable(getattr(_instance, attr)) and not attr.startswith("
|
|
11
|
-
globals()[attr] =
|
|
8
|
+
if callable(getattr(_instance, attr)) and not attr.startswith("_"):
|
|
9
|
+
globals()[attr] = getattr(_instance, attr)
|
|
10
|
+
|
|
11
|
+
# Optionally export the class itself if users need direct access
|
|
12
|
+
# Sutro is already imported and available
|
|
13
|
+
|
|
14
|
+
# Define __all__ for clean imports
|
|
15
|
+
__all__ = ["Sutro"] + [
|
|
16
|
+
attr
|
|
17
|
+
for attr in dir(_instance)
|
|
18
|
+
if callable(getattr(_instance, attr)) and not attr.startswith("_")
|
|
19
|
+
]
|
|
12
20
|
|
|
13
21
|
# Clean up namespace
|
|
14
|
-
del
|
|
22
|
+
del attr
|
sutro/cli.py
CHANGED
|
@@ -52,7 +52,7 @@ def set_human_readable_dates(datetime_columns, df):
|
|
|
52
52
|
# Convert UTC string to local time string
|
|
53
53
|
df = df.with_columns(
|
|
54
54
|
pl.col(col)
|
|
55
|
-
.str.to_datetime()
|
|
55
|
+
.str.to_datetime("%Y-%m-%dT%H:%M:%S%.f%Z")
|
|
56
56
|
.map_elements(
|
|
57
57
|
lambda dt: dt.replace(tzinfo=timezone.utc)
|
|
58
58
|
.astimezone()
|
sutro/common.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Union, List, Literal, Dict, Any, Type, Optional
|
|
3
|
+
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import polars as pl
|
|
6
|
+
from colorama import Fore, Style
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
|
|
10
|
+
# Embedding-capable models. Subset of ModelOptions below.
EmbeddingModelOptions = Literal[
    "qwen-3-embedding-0.6b",
    "qwen-3-embedding-6b",
    "qwen-3-embedding-8b",
]

# Models available for inference. Keep in sync with the backend configuration
# so users get helpful autocompletion when selecting a model.
# NOTE: the previous version listed "llama-3.3-70b" twice; the duplicate has
# been removed (Literal deduplicates anyway, so runtime behavior is unchanged).
ModelOptions = Literal[
    "llama-3.2-3b",
    "llama-3.1-8b",
    "llama-3.3-70b",
    "qwen-3-4b",
    "qwen-3-14b",
    "qwen-3-32b",
    "qwen-3-30b-a3b",
    "qwen-3-235b-a22b",
    "qwen-3-4b-thinking",
    "qwen-3-14b-thinking",
    "qwen-3-32b-thinking",
    "qwen-3-235b-a22b-thinking",
    "qwen-3-30b-a3b-thinking",
    "gemma-3-4b-it",
    "gemma-3-12b-it",
    "gemma-3-27b-it",
    "gpt-oss-20b",
    "gpt-oss-120b",
    "qwen-3-embedding-0.6b",
    "qwen-3-embedding-6b",
    "qwen-3-embedding-8b",
]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def do_dataframe_column_concatenation(
    data: Union[pd.DataFrame, pl.DataFrame], column: Union[str, List[str]]
):
    """
    Intelligently concatenate multiple dataframe columns into a single list
    of strings, accepting literal separator strings.

    Each entry in ``column`` that matches a column name contributes that
    column's values (cast to string, nulls treated as ""); any other entry
    is treated as a literal separator repeated for every row.

    Args:
        data: A pandas or polars DataFrame.
        column: Ordered list of column names and/or literal separator strings.

    Returns:
        A list of concatenated strings (one per row), or None if ``data`` is
        neither a pandas nor a polars DataFrame.

    Raises:
        ValueError: If the concatenation fails for any reason.
    """
    try:
        if isinstance(data, pd.DataFrame):
            series_parts = []
            for p in column:
                if p in data.columns:
                    s = data[p].astype("string").fillna("")
                else:
                    # Treat as a literal separator
                    s = pd.Series([p] * len(data), index=data.index, dtype="string")
                series_parts.append(s)

            out = series_parts[0]
            for s in series_parts[1:]:
                out = out.str.cat(s, na_rep="")

            return out.tolist()
        elif isinstance(data, pl.DataFrame):
            exprs = []
            for p in column:
                if p in data.columns:
                    exprs.append(pl.col(p).cast(pl.Utf8).fill_null(""))
                else:
                    exprs.append(pl.lit(p))

            result = data.select(
                pl.concat_str(exprs, separator="", ignore_nulls=False).alias("concat")
            )
            return result["concat"].to_list()
        return None
    except Exception as e:
        # Surface a uniform error type; chain the original cause for debugging.
        # (Fixed typo in the previous message: "concatentation".)
        raise ValueError(f"Error handling column concatenation: {e}") from e
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def handle_data_helper(
    data: Union[List, pd.DataFrame, pl.DataFrame, str],
    column: Union[str, List[str]] = None,
):
    """
    Normalize the supported input shapes into the form the API expects.

    Args:
        data: One of:
            - a list of prompts (returned unchanged),
            - a pandas/polars DataFrame (requires ``column``),
            - a dataset ID string starting with "dataset-" (returned as
              "<dataset-id>:<column>"),
            - a path to a .csv/.parquet/.txt (or extensionless) file.
        column: Column name, or a list of column names/literal separators
            for DataFrame inputs.

    Returns:
        A list of inputs, or a "dataset-id:column" string.

    Raises:
        ValueError: If the data type, file type, or column specification
            is unsupported.
    """
    if isinstance(data, list):
        input_data = data
    elif isinstance(data, (pd.DataFrame, pl.DataFrame)):
        if column is None:
            raise ValueError("Column name must be specified for DataFrame input")
        if isinstance(column, list):
            # Multiple columns (and/or literal separators) are concatenated
            # row-wise into a single prompt per row.
            input_data = do_dataframe_column_concatenation(data, column)
        elif isinstance(column, str):
            input_data = data[column].to_list()
        else:
            # Previously this fell through and raised UnboundLocalError.
            raise ValueError(
                "Column must be a string or a list of strings for DataFrame input"
            )
    elif isinstance(data, str):
        if data.startswith("dataset-"):
            input_data = data + ":" + column
        else:
            file_ext = os.path.splitext(data)[1].lower()
            if file_ext == ".csv":
                df = pl.read_csv(data)
            elif file_ext == ".parquet":
                df = pl.read_parquet(data)
            elif file_ext in [".txt", ""]:
                with open(data, "r") as file:
                    input_data = [line.strip() for line in file]
            else:
                raise ValueError(f"Unsupported file type: {file_ext}")

            if file_ext in [".csv", ".parquet"]:
                if column is None:
                    raise ValueError(
                        "Column name must be specified for CSV/Parquet input"
                    )
                input_data = df[column].to_list()
    else:
        raise ValueError(
            "Unsupported data type. Please provide a list, DataFrame, or file path."
        )

    return input_data
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def normalize_output_schema(
    output_schema: Union[Dict[str, Any], Type[BaseModel], None],
):
    """Consolidate any varied types for output_schema to dict format."""
    # Duck-type pydantic (v2) models/classes via their schema hook rather
    # than importing a hard isinstance check.
    if hasattr(output_schema, "model_json_schema"):
        return output_schema.model_json_schema()
    # Plain dicts are passed through untouched.
    if isinstance(output_schema, dict):
        return output_schema
    raise ValueError(
        "Invalid output schema type. Must be a dictionary or a pydantic Model."
    )
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def to_colored_text(
    text: str, state: Optional[Literal["success", "fail", "callout"]] = None
) -> str:
    """
    Apply color to text based on state.

    Args:
        text (str): The text to color
        state (Optional[Literal['success', 'fail', 'callout']]): The state that determines the color.
            Options: 'success' (green), 'fail' (red), 'callout' (magenta),
            or None (default blue)

    Returns:
        str: Text with appropriate color applied
    """
    match state:
        case "success":
            return f"{Fore.GREEN}{text}{Style.RESET_ALL}"
        case "fail":
            return f"{Fore.RED}{text}{Style.RESET_ALL}"
        case "callout":
            return f"{Fore.MAGENTA}{text}{Style.RESET_ALL}"
        case _:
            # Default to blue for normal/processing states
            return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def fancy_tqdm(
    total: int,
    desc: str = "Progress",
    color: str = "blue",
    style: int = 1,
    postfix: Optional[str] = None,
):
    """
    Creates a customized tqdm progress bar with different styling options.

    Args:
        total (int): Total iterations
        desc (str): Description for the progress bar
        color (str): Color of the progress bar (green, blue, red, yellow, magenta)
        style (int): Style preset (1-5); unknown values fall back to preset 1
        postfix (Optional[str]): Postfix for the progress bar

    Returns:
        The configured tqdm progress bar instance.
    """

    # Style presets
    style_presets = {
        1: {
            "bar_format": "{l_bar}{bar:30}| {n_fmt}/{total_fmt} | {percentage:3.0f}% {postfix}",
            "ascii": "░▒█",
        },
        2: {
            "bar_format": "╢{l_bar}{bar:30}╟ {percentage:3.0f}%",
            "ascii": "▁▂▃▄▅▆▇█",
        },
        3: {
            "bar_format": "{desc}: |{bar}| {percentage:3.0f}% [{elapsed}<{remaining}]",
            "ascii": "◯◔◑◕●",
        },
        4: {
            "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
            "ascii": "⬜⬛",
        },
        5: {
            "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
            "ascii": "▏▎▍▌▋▊▉█",
        },
    }

    # Get style configuration (unknown style numbers fall back to preset 1)
    style_config = style_presets.get(style, style_presets[1])

    return tqdm(
        total=total,
        desc=desc,
        colour=color,
        bar_format=style_config["bar_format"],
        ascii=style_config["ascii"],
        ncols=80,
        dynamic_ncols=True,
        smoothing=0.3,
        leave=True,
        postfix=postfix,
    )
|
sutro/interfaces.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import polars as pl
|
|
5
|
+
from typing import Any, Dict, List, Optional, Union, Type
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from sutro.common import ModelOptions
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BaseSutroClient:
    """
    Base class declaring attributes and method interfaces for template function mixins
    to use.

    Methods here are stubs (bodies are ``...``); the concrete implementations
    live on the real client (see sutro/sdk.py). Mixins type against this class
    so they can call these methods without importing the full SDK.
    """

    # Core inference method interface
    def infer(
        self,
        data: Union[List, pd.DataFrame, pl.DataFrame, str],
        model: Union[ModelOptions, List[ModelOptions]] = "gemma-3-12b-it",
        name: Union[str, List[str]] = None,
        description: Union[str, List[str]] = None,
        column: Union[str, List[str]] = None,
        output_column: str = "inference_result",
        job_priority: int = 0,
        output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
        sampling_params: dict = None,
        system_prompt: str = None,
        dry_run: bool = False,
        stay_attached: Optional[bool] = None,
        random_seed_per_input: bool = False,
        truncate_rows: bool = True,
    ) -> Any:
        """
        Run inference on a dataset.

        Args:
            data: Input data (list, DataFrame, or dataset ID)
            model: Model(s) to use for inference
            name: Job name(s)
            description: Job description(s)
            column: Column(s) to process
            output_column: Name for output column
            job_priority: Job priority (0-10, higher = more priority)
            output_schema: Pydantic model or JSON schema for structured output
            sampling_params: Model sampling parameters
            system_prompt: System prompt for the model
            dry_run: If True, validate without running
            stay_attached: Wait for job completion
            random_seed_per_input: Use random seed per input
            truncate_rows: Truncate long inputs

        Returns:
            Job result or job ID
        """
        ...

    def await_job_completion(
        self,
        job_id: str,
        timeout: Optional[int] = 7200,
        obtain_results: bool = True,
        output_column: str = "inference_result",
        is_cost_estimate: bool = False,
    ) -> pl.DataFrame | None:
        """
        Wait on a job and optionally return its results.

        Args:
            job_id: ID of the job to wait on
            timeout: Seconds to wait (default 2 hours) — NOTE(review):
                semantics of ``None`` are defined by the implementation; confirm
                against sutro/sdk.py
            obtain_results: If True, fetch and return the job's results
            output_column: Name for the results column
            is_cost_estimate: Whether the job is a cost-estimate run

        Returns:
            A polars DataFrame of results, or None
        """
        ...
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class JobStatus(str, Enum):
    """Job statuses that will be returned by the API & SDK"""

    UNKNOWN = "UNKNOWN"
    QUEUED = "QUEUED"  # Waiting to start
    STARTING = "STARTING"  # In the process of starting up
    RUNNING = "RUNNING"  # Actively running
    SUCCEEDED = "SUCCEEDED"  # Completed successfully
    CANCELLING = "CANCELLING"  # Cancellation in progress
    CANCELLED = "CANCELLED"  # Cancelled by the user
    FAILED = "FAILED"  # Finished with an error

    @classmethod
    def terminal_statuses(cls) -> list["JobStatus"]:
        # Statuses after which the job makes no further progress.
        return [cls.SUCCEEDED, cls.FAILED, cls.CANCELLING, cls.CANCELLED]

    def is_terminal(self) -> bool:
        return self in type(self).terminal_statuses()
|