sutro 0.1.37__py3-none-any.whl → 0.1.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sutro might be problematic. Click here for more details.

sutro/cli.py CHANGED
@@ -52,7 +52,7 @@ def set_human_readable_dates(datetime_columns, df):
52
52
  # Convert UTC string to local time string
53
53
  df = df.with_columns(
54
54
  pl.col(col)
55
- .str.to_datetime()
55
+ .str.to_datetime("%Y-%m-%dT%H:%M:%S%.f%Z")
56
56
  .map_elements(
57
57
  lambda dt: dt.replace(tzinfo=timezone.utc)
58
58
  .astimezone()
sutro/common.py ADDED
@@ -0,0 +1,220 @@
1
+ import os
2
+ from typing import Union, List, Literal, Dict, Any, Type, Optional
3
+
4
+ import pandas as pd
5
+ import polars as pl
6
+ from colorama import Fore, Style
7
+ from pydantic import BaseModel
8
+ from tqdm import tqdm
9
+
10
# Models that produce embeddings rather than text completions.
EmbeddingModelOptions = Literal[
    "qwen-3-embedding-0.6b", "qwen-3-embedding-6b", "qwen-3-embedding-8b"
]
15
+
16
# Models available for inference. Keep in sync with the backend configuration
# so users get helpful autocompletion when selecting a model.
# Fix: "llama-3.3-70b" was listed twice; the duplicate entry is removed.
ModelOptions = Literal[
    "llama-3.2-3b",
    "llama-3.1-8b",
    "llama-3.3-70b",
    "qwen-3-4b",
    "qwen-3-14b",
    "qwen-3-32b",
    "qwen-3-30b-a3b",
    "qwen-3-235b-a22b",
    "qwen-3-4b-thinking",
    "qwen-3-14b-thinking",
    "qwen-3-32b-thinking",
    "qwen-3-235b-a22b-thinking",
    "qwen-3-30b-a3b-thinking",
    "gemma-3-4b-it",
    "gemma-3-12b-it",
    "gemma-3-27b-it",
    "gpt-oss-20b",
    "gpt-oss-120b",
    "qwen-3-embedding-0.6b",
    "qwen-3-embedding-6b",
    "qwen-3-embedding-8b",
]
42
+
43
+
44
def do_dataframe_column_concatenation(
    data: Union[pd.DataFrame, pl.DataFrame], column: Union[str, List[str]]
):
    """
    Intelligently concatenate DataFrame columns into a single list of strings.

    Each entry in ``column`` that matches a column name contributes that
    column's values (cast to string, nulls replaced with ""); any other
    entry is treated as a literal separator string inserted in order.

    Args:
        data: A pandas or polars DataFrame.
        column: Column names and/or literal separator strings, in order.

    Returns:
        A list of concatenated strings, or None if ``data`` is neither a
        pandas nor a polars DataFrame.

    Raises:
        ValueError: If concatenation fails (e.g. ``column`` is empty).
    """
    try:
        if isinstance(data, pd.DataFrame):
            series_parts = []
            for p in column:
                if p in data.columns:
                    s = data[p].astype("string").fillna("")
                else:
                    # Treat as a literal separator
                    s = pd.Series([p] * len(data), index=data.index, dtype="string")
                series_parts.append(s)

            out = series_parts[0]
            for s in series_parts[1:]:
                out = out.str.cat(s, na_rep="")

            return out.tolist()
        elif isinstance(data, pl.DataFrame):
            exprs = []
            for p in column:
                if p in data.columns:
                    exprs.append(pl.col(p).cast(pl.Utf8).fill_null(""))
                else:
                    exprs.append(pl.lit(p))

            result = data.select(
                pl.concat_str(exprs, separator="", ignore_nulls=False).alias("concat")
            )
            return result["concat"].to_list()
        # Unsupported input type: mirror the original contract and return None.
        return None
    except Exception as e:
        # Typo fix: "concatentation" -> "concatenation" in the error message.
        raise ValueError(f"Error handling column concatenation: {e}")
81
+
82
+
83
def handle_data_helper(
    data: Union[List, pd.DataFrame, pl.DataFrame, str],
    column: Union[str, List[str], None] = None,
):
    """
    Normalize the supported input forms into a list of values (or a
    dataset reference string) ready to submit for inference.

    Args:
        data: A list of inputs, a pandas/polars DataFrame, a dataset ID
            (a string starting with "dataset-"), or a path to a
            .csv/.parquet/.txt file (or an extensionless text file).
        column: Column name — or a list of column names and literal
            separators — to read when ``data`` is a DataFrame, CSV, or
            Parquet input. Required for those input types.

    Returns:
        A list of input values, or a "dataset-...:column" reference string.

    Raises:
        ValueError: If a required column is missing, the column argument
            has an unsupported type, the file type is unsupported, or
            ``data`` itself has an unsupported type.
    """
    if isinstance(data, list):
        input_data = data
    elif isinstance(data, (pd.DataFrame, pl.DataFrame)):
        if column is None:
            raise ValueError("Column name must be specified for DataFrame input")
        if isinstance(column, list):
            input_data = do_dataframe_column_concatenation(data, column)
        elif isinstance(column, str):
            input_data = data[column].to_list()
        else:
            # Previously this fell through and raised UnboundLocalError.
            raise ValueError("Column must be a string or a list of strings")
    elif isinstance(data, str):
        if data.startswith("dataset-"):
            if column is None:
                # Previously this raised a confusing TypeError on concatenation.
                raise ValueError("Column name must be specified for dataset input")
            input_data = data + ":" + column
        else:
            file_ext = os.path.splitext(data)[1].lower()
            if file_ext == ".csv":
                df = pl.read_csv(data)
            elif file_ext == ".parquet":
                df = pl.read_parquet(data)
            elif file_ext in [".txt", ""]:
                with open(data, "r") as file:
                    input_data = [line.strip() for line in file]
            else:
                raise ValueError(f"Unsupported file type: {file_ext}")

            if file_ext in [".csv", ".parquet"]:
                if column is None:
                    raise ValueError(
                        "Column name must be specified for CSV/Parquet input"
                    )
                input_data = df[column].to_list()
    else:
        raise ValueError(
            "Unsupported data type. Please provide a list, DataFrame, or file path."
        )

    return input_data
122
+
123
+
124
def normalize_output_schema(
    output_schema: Union[Dict[str, Any], Type[BaseModel], None],
):
    """Return ``output_schema`` normalized to a plain dict.

    A pydantic model class is converted via its ``model_json_schema()``
    method; a dict is returned unchanged; anything else is rejected.
    """
    to_schema = getattr(output_schema, "model_json_schema", None)
    if to_schema is not None:
        return to_schema()
    if isinstance(output_schema, dict):
        return output_schema
    raise ValueError(
        "Invalid output schema type. Must be a dictionary or a pydantic Model."
    )
136
+
137
+
138
def to_colored_text(
    text: str, state: Optional[Literal["success", "fail", "callout"]] = None
) -> str:
    """
    Wrap text in an ANSI color escape chosen by state.

    Args:
        text (str): The text to color
        state: 'success' colors green, 'fail' red, 'callout' magenta;
            any other value (including None) defaults to blue.

    Returns:
        str: Text with appropriate color applied
    """
    # Blue is the default for normal/processing states.
    color = {
        "success": Fore.GREEN,
        "fail": Fore.RED,
        "callout": Fore.MAGENTA,
    }.get(state, Fore.BLUE)
    return f"{color}{text}{Style.RESET_ALL}"
162
+
163
+
164
def fancy_tqdm(
    total: int,
    desc: str = "Progress",
    color: str = "blue",
    style: int = 1,
    postfix: Optional[str] = None,
) -> tqdm:
    """
    Creates a customized tqdm progress bar with different styling options.

    Args:
        total (int): Total iterations
        desc (str): Description for the progress bar
        color (str): Color of the progress bar (green, blue, red, yellow, magenta)
        style (int): Style preset (1-5); any unknown value falls back to preset 1
        postfix (str): Postfix for the progress bar

    Returns:
        tqdm: A configured progress bar instance (caller must update/close it).
    """

    # Style presets: each pairs a bar_format template with the ascii
    # characters tqdm uses to draw partial/full bar segments.
    style_presets = {
        1: {
            "bar_format": "{l_bar}{bar:30}| {n_fmt}/{total_fmt} | {percentage:3.0f}% {postfix}",
            "ascii": "░▒█",
        },
        2: {
            "bar_format": "╢{l_bar}{bar:30}╟ {percentage:3.0f}%",
            "ascii": "▁▂▃▄▅▆▇█",
        },
        3: {
            "bar_format": "{desc}: |{bar}| {percentage:3.0f}% [{elapsed}<{remaining}]",
            "ascii": "◯◔◑◕●",
        },
        4: {
            "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
            "ascii": "⬜⬛",
        },
        5: {
            "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
            "ascii": "▏▎▍▌▋▊▉█",
        },
    }

    # Get style configuration (preset 1 when the requested style is unknown)
    style_config = style_presets.get(style, style_presets[1])

    return tqdm(
        total=total,
        desc=desc,
        colour=color,
        bar_format=style_config["bar_format"],
        ascii=style_config["ascii"],
        ncols=80,
        dynamic_ncols=True,
        smoothing=0.3,
        leave=True,
        postfix=postfix,
    )
sutro/interfaces.py ADDED
@@ -0,0 +1,90 @@
1
+ from enum import Enum
2
+
3
+ import pandas as pd
4
+ import polars as pl
5
+ from typing import Any, Dict, List, Optional, Union, Type
6
+ from pydantic import BaseModel
7
+
8
+ from sutro.common import ModelOptions
9
+
10
+
11
class BaseSutroClient:
    """
    Base class declaring attributes and method interfaces for template function mixins
    to use.

    Methods here are interface stubs (bodies are intentionally ``...``);
    the concrete client class is expected to provide the implementations.
    """

    # Core inference method interface
    def infer(
        self,
        data: Union[List, pd.DataFrame, pl.DataFrame, str],
        model: Union[ModelOptions, List[ModelOptions]] = "gemma-3-12b-it",
        name: Union[str, List[str]] = None,
        description: Union[str, List[str]] = None,
        column: Union[str, List[str]] = None,
        output_column: str = "inference_result",
        job_priority: int = 0,
        output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
        sampling_params: dict = None,
        system_prompt: str = None,
        dry_run: bool = False,
        stay_attached: Optional[bool] = None,
        random_seed_per_input: bool = False,
        truncate_rows: bool = True,
    ) -> Any:
        """
        Run inference on a dataset.

        Args:
            data: Input data (list, DataFrame, or dataset ID)
            model: Model(s) to use for inference
            name: Job name(s)
            description: Job description(s)
            column: Column(s) to process
            output_column: Name for output column
            job_priority: Job priority (0-10, higher = more priority)
            output_schema: Pydantic model or JSON schema for structured output
            sampling_params: Model sampling parameters
            system_prompt: System prompt for the model
            dry_run: If True, validate without running
            stay_attached: Wait for job completion
            random_seed_per_input: Use random seed per input
            truncate_rows: Truncate long inputs

        Returns:
            Job result or job ID
        """
        ...

    def await_job_completion(
        self,
        job_id: str,
        timeout: Optional[int] = 7200,
        obtain_results: bool = True,
        is_cost_estimate: bool = False,
    ) -> pl.DataFrame | None:
        """
        Block until the given job finishes (interface stub).

        Args:
            job_id: ID of the job to wait on.
            timeout: Maximum seconds to wait (default 7200, i.e. 2 hours).
            obtain_results: If True, fetch and return the job's results.
            is_cost_estimate: NOTE(review): presumably marks the awaited job
                as a dry-run/cost-estimate job — confirm against the
                concrete implementation.

        Returns:
            Results as a polars DataFrame when available, else None.
        """
        ...
66
+
67
+
68
class JobStatus(str, Enum):
    """Job statuses that will be returned by the API & SDK"""

    UNKNOWN = "UNKNOWN"
    # Waiting to start
    QUEUED = "QUEUED"
    # In the process of starting up
    STARTING = "STARTING"
    # Actively running
    RUNNING = "RUNNING"
    # Completed successfully
    SUCCEEDED = "SUCCEEDED"
    # In the process of being canceled
    CANCELLING = "CANCELLING"
    # Canceled by the user
    CANCELLED = "CANCELLED"
    # Failed
    FAILED = "FAILED"

    @classmethod
    def terminal_statuses(cls) -> list["JobStatus"]:
        """Statuses after which the job will make no further progress."""
        return [cls.SUCCEEDED, cls.FAILED, cls.CANCELLING, cls.CANCELLED]

    def is_terminal(self) -> bool:
        """Return True when this status is one of the terminal statuses."""
        return self in type(self).terminal_statuses()