sutro 0.0.0__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sutro/__init__.py +14 -0
- sutro/cli.py +418 -0
- sutro/sdk.py +1101 -0
- sutro-0.1.11.dist-info/METADATA +41 -0
- sutro-0.1.11.dist-info/RECORD +8 -0
- {sutro-0.0.0.dist-info → sutro-0.1.11.dist-info}/WHEEL +1 -2
- sutro-0.1.11.dist-info/entry_points.txt +2 -0
- sutro-0.1.11.dist-info/licenses/LICENSE +201 -0
- __init__.py +0 -1
- hi.py +0 -1
- sutro-0.0.0.dist-info/METADATA +0 -6
- sutro-0.0.0.dist-info/RECORD +0 -6
- sutro-0.0.0.dist-info/top_level.txt +0 -2
sutro/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from .sdk import Sutro
|
|
2
|
+
|
|
3
|
+
# Create a module-level default client so the SDK's methods can be called as
# package-level functions, e.g. ``sutro.<method>(...)``.
_instance = Sutro()


def _export_instance_callables(instance, namespace):
    """Copy every non-dunder callable attribute of ``instance`` into ``namespace``.

    ``getattr`` already returns correctly bound objects: instance methods come
    back bound to ``instance``, classmethods bound to the class, and
    staticmethods as plain functions. They are stored as-is. Re-binding
    through ``attr.__func__`` would fail for staticmethods (plain functions
    have no ``__func__``). It would also hand classmethods the instance
    instead of the class.

    Args:
        instance: Object whose public and single-underscore callables are exported.
        namespace: Mapping to write into (the package's ``globals()``).
    """
    for name in dir(instance):
        if name.startswith("__"):
            continue
        value = getattr(instance, name)
        if callable(value):
            namespace[name] = value


# Import all methods from the instance into the package namespace.
_export_instance_callables(_instance, globals())

# Clean up namespace: the helper is only needed while the package is imported.
del _export_instance_callables
|
sutro/cli.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
1
|
+
from datetime import datetime, timezone
|
|
2
|
+
import click
|
|
3
|
+
from colorama import Fore, Style
|
|
4
|
+
import os
|
|
5
|
+
import json
|
|
6
|
+
from sutro.sdk import Sutro
|
|
7
|
+
import polars as pl
|
|
8
|
+
import warnings
|
|
9
|
+
|
|
10
|
+
# Silence polars' "inefficient map" warning: the table formatting in this
# module deliberately uses ``map_elements`` with Python callables.
warnings.filterwarnings("ignore", category=pl.PolarsInefficientMapWarning)
# Do not print the "shape: (rows, cols)" header above rendered tables.
pl.Config.set_tbl_hide_dataframe_shape(True)

# Location of the persisted CLI configuration (written by ``login`` and
# ``set-base-url``; holds ``api_key`` and ``base_url``).
CONFIG_DIR = os.path.expanduser("~/.sutro")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def load_config():
    """Return the saved CLI configuration as a dict.

    Reads ``CONFIG_FILE`` as JSON. When the file does not exist, an empty dict
    is returned so callers can use ``.get`` without further checks.
    """
    if not os.path.exists(CONFIG_FILE):
        return {}
    with open(CONFIG_FILE, "r") as handle:
        stored = json.load(handle)
    return stored
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def save_config(config):
    """Persist ``config`` as JSON to ``CONFIG_FILE``, replacing its contents.

    Creates ``CONFIG_DIR`` if needed. The config holds the user's API key, so
    the file is created with owner-only permissions (0o600), and an existing
    file is tightened to the same mode. This keeps other local users from
    reading the key.

    Args:
        config: JSON-serializable dict to write.
    """
    os.makedirs(CONFIG_DIR, exist_ok=True)
    # os.open accepts a creation mode. Plain open() would create the file with
    # the umask-derived default, which is typically world-readable.
    fd = os.open(CONFIG_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        json.dump(config, f)
    # The creation mode is ignored when the file already existed, so tighten it.
    try:
        os.chmod(CONFIG_FILE, 0o600)
    except OSError:
        # Best effort only: failing to tighten permissions (e.g. the file is
        # owned by another user) must not abort an otherwise successful save.
        pass
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def check_auth():
    """Report whether an API key has been saved (e.g. via ``sutro login``).

    Returns:
        True when the stored config contains a non-None ``api_key``.
    """
    saved_key = load_config().get("api_key")
    return saved_key is not None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def get_sdk():
    """Build a ``Sutro`` client from the saved CLI configuration.

    ``api_key`` is always passed (possibly None). ``base_url`` is passed only
    when one is stored; otherwise the ``Sutro`` constructor's own default is
    used.
    """
    config = load_config()
    client_kwargs = {"api_key": config.get("api_key")}
    stored_base_url = config.get("base_url")
    if stored_base_url is not None:
        client_kwargs["base_url"] = stored_base_url
    return Sutro(**client_kwargs)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def set_config_base_url(base_url: str):
    """Store ``base_url`` in the CLI config, keeping every other saved key."""
    save_config({**load_config(), "base_url": base_url})
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def set_human_readable_dates(datetime_columns, df):
    """Rewrite UTC timestamp string columns as local-time display strings.

    Each name in ``datetime_columns`` that exists in ``df`` is parsed with
    ``str.to_datetime``, treated as UTC, converted to the machine's local time
    zone and formatted as ``"%Y-%m-%d %H:%M:%S %Z"``. Names missing from ``df``
    are skipped.

    Args:
        datetime_columns: Column names to convert.
        df: polars DataFrame whose listed columns hold timestamp strings.

    Returns:
        The DataFrame with the converted columns (``with_columns`` returns a
        new frame; the input is not modified in place).
    """

    def to_local_string(value):
        # Convert UTC string to local time string.
        if not value:
            return None
        local_value = value.replace(tzinfo=timezone.utc).astimezone()
        return local_value.strftime("%Y-%m-%d %H:%M:%S %Z")

    present_columns = [name for name in datetime_columns if name in df.columns]
    for name in present_columns:
        converted = (
            pl.col(name)
            .str.to_datetime()
            .map_elements(to_local_string, return_dtype=pl.Utf8)
            .alias(name)
        )
        df = df.with_columns(converted)
    return df
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@click.group(invoke_without_command=True)
@click.pass_context
def cli(ctx):
    # Root command group (no docstring on purpose: click would print it as
    # the group's help text). Every subcommand except those listed below
    # requires a saved API key. Running ``sutro`` with no subcommand prints
    # a welcome banner followed by the help text.
    allowed_without_login = ["login", "set-base-url"]
    if not check_auth() and ctx.invoked_subcommand not in allowed_without_login:
        click.echo("Please login using 'sutro login'.")
        ctx.exit(1)

    if ctx.invoked_subcommand is not None:
        return

    welcome = (
        "\nWelcome to the Sutro CLI!\n\n"
        "To see a list of all available commands, use 'sutro --help'.\n"
    )
    click.echo(Fore.GREEN + welcome + Style.RESET_ALL)
    click.echo(ctx.get_help())
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Banner shown after a successful login.
_SUTRO_BANNER = """


▄▄▄▄▄▄▄▄▄▄▄ ▄ ▄ ▄▄▄▄▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄▄▄▄▄
▐░░░░░░░░░░░▌▐░▌ ▐░▌▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌
▐░█▀▀▀▀▀▀▀▀▀ ▐░▌ ▐░▌ ▀▀▀▀█░█▀▀▀▀ ▐░█▀▀▀▀▀▀▀█░▌▐░█▀▀▀▀▀▀▀█░▌
▐░▌ ▐░▌ ▐░▌ ▐░▌ ▐░▌ ▐░▌▐░▌ ▐░▌
▐░█▄▄▄▄▄▄▄▄▄ ▐░▌ ▐░▌ ▐░▌ ▐░█▄▄▄▄▄▄▄█░▌▐░▌ ▐░▌
▐░░░░░░░░░░░▌▐░▌ ▐░▌ ▐░▌ ▐░░░░░░░░░░░▌▐░▌ ▐░▌
▀▀▀▀▀▀▀▀▀█░▌▐░▌ ▐░▌ ▐░▌ ▐░█▀▀▀▀█░█▀▀ ▐░▌ ▐░▌
▐░▌▐░▌ ▐░▌ ▐░▌ ▐░▌ ▐░▌ ▐░▌ ▐░▌
▄▄▄▄▄▄▄▄▄█░▌▐░█▄▄▄▄▄▄▄█░▌ ▐░▌ ▐░▌ ▐░▌ ▐░█▄▄▄▄▄▄▄█░▌
▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌ ▐░▌ ▐░▌ ▐░▌▐░░░░░░░░░░░▌
▀▀▀▀▀▀▀▀▀▀▀ ▀▀▀▀▀▀▀▀▀▀▀ ▀ ▀ ▀ ▀▀▀▀▀▀▀▀▀▀▀


"""


@cli.command()
def login():
    """Set or update your API key for Sutro."""
    # Prompts for an API key (pre-filled with any saved key), verifies it with
    # the server and only then persists it together with the base URL.
    config = load_config()
    default_api_key = config.get("api_key", "")
    default_base_url = config.get("base_url", "https://api.sutro.sh")

    # Only show the hint when there is a key to keep; otherwise nothing is
    # echoed (rather than an empty line).
    if default_api_key:
        click.echo(
            "Hint: An API key is already set. Press Enter to keep the existing key."
        )
    api_key = click.prompt(
        "Enter your API key",
        default=default_api_key,
        hide_input=True,
        show_default=False,
    )

    result = get_sdk().try_authentication(api_key)
    # NOTE(review): assumes try_authentication returns a dict such as
    # {"authenticated": True}; anything else is treated as a failed login.
    authenticated = isinstance(result, dict) and result.get("authenticated") is True
    if not authenticated:
        raise click.ClickException(
            Fore.RED + "Invalid API key. Try again." + Style.RESET_ALL
        )

    click.echo(Fore.BLUE + _SUTRO_BANNER + Style.RESET_ALL)
    click.echo(
        Fore.GREEN + "Successfully authenticated. Welcome back!" + Style.RESET_ALL
    )

    # Persist only after the server has accepted the key.
    save_config({"api_key": api_key, "base_url": default_base_url})
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@cli.group()
def jobs():
    """Manage jobs."""
    # Container group only; the subcommands are registered via @jobs.command().
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@jobs.command()
@click.option(
    "--all",
    is_flag=True,
    help="Show all jobs instead of only the 25 most recent.",
)
def list(all=False):
    """Lists historical and ongoing jobs. Will only list first 25 jobs by default. Use --all to see all jobs."""
    # ``list``/``all`` shadow builtins; they are kept because click derives the
    # command name ("list") and the flag's parameter name from them. The
    # ``--all`` flag only lifts the row limit; no status filtering happens here.
    default_limit = 25

    sdk = get_sdk()
    job_records = sdk.list_jobs()
    if job_records is None or len(job_records) == 0:
        click.echo(Fore.YELLOW + "No jobs found." + Style.RESET_ALL)
        return

    df = pl.DataFrame(job_records)
    # TODO: this is a temporary fix to remove jobs where datetime_created is null. We should fix this on the backend.
    df = df.filter(pl.col("datetime_created").is_not_null())
    # Sort on the raw timestamp strings, before they are reformatted for display.
    df = df.sort(by=["datetime_created"], descending=True)

    # Format all datetime columns with a more readable format.
    datetime_columns = [
        "datetime_created",
        "datetime_added",
        "datetime_started",
        "datetime_completed",
    ]
    df = set_human_readable_dates(datetime_columns, df)

    # TODO: colorize the status column (green for SUCCEEDED, red for
    # FAILED/CANCELLED/UNKNOWN) once ANSI codes render correctly in polars tables.

    # Replace nulls with display-friendly values.
    df = df.with_columns(
        pl.col("input_tokens").fill_null(0),
        pl.col("output_tokens").fill_null(0),
        pl.col("datetime_completed").fill_null(""),
        pl.col("job_cost")
        .fill_null(0)
        .map_elements(lambda x: f"${x:.5f}", return_dtype=pl.Utf8),
    )

    if not all:
        df = df.head(default_limit)

    with pl.Config(tbl_rows=-1, tbl_cols=-1, set_fmt_str_lengths=45):
        print(df)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@jobs.command()
@click.argument("job_id")
def status(job_id):
    """Get the status of a job."""
    # A falsy SDK result prints nothing.
    job_status = get_sdk().get_job_status(job_id)
    if job_status:
        print(job_status)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@jobs.command()
@click.argument("job_id")
@click.option(
    "--include-inputs", is_flag=True, help="Include the inputs in the results."
)
@click.option(
    "--include-cumulative-logprobs",
    is_flag=True,
    help="Include the cumulative logprobs in the results.",
)
@click.option(
    "--save",
    is_flag=True,
    help="Download the results to the current working directory. The file name will be the job_id.",
)
@click.option(
    "--save-format",
    type=click.Choice(["parquet", "csv"]),
    default="parquet",
    help="The format of the output file. Options: parquet, csv",
)
def results(
    job_id,
    include_inputs,
    include_cumulative_logprobs,
    save=False,
    save_format="parquet",
):
    """Get the results of a job."""
    # Without --save the results table is printed. With --save it is written
    # to "<job_id>.<format>" in the current directory.
    job_results = get_sdk().get_job_results(
        job_id, include_inputs, include_cumulative_logprobs
    )
    if not job_results:
        return

    df = pl.DataFrame(job_results)
    if not save:
        print(df)
        return

    if save_format == "parquet":
        df.write_parquet(f"{job_id}.parquet")
    else:  # click.Choice restricts this branch to "csv"
        df.write_csv(f"{job_id}.csv")
    print(Fore.GREEN + f"Results saved to {job_id}.{save_format}" + Style.RESET_ALL)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
@jobs.command()
@click.argument("job_id")
def cancel(job_id):
    """Cancel a running job."""
    # Only a truthy SDK result prints the confirmation.
    if get_sdk().cancel_job(job_id):
        click.echo(Fore.GREEN + "Job cancelled successfully." + Style.RESET_ALL)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
@cli.group()
def stages():
    """Manage stages."""
    # Container group only; the subcommands are registered via @stages.command().
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
@stages.command()
def create():
    """Create a new stage."""
    # A falsy stage id from the SDK prints nothing.
    new_stage_id = get_sdk().create_stage()
    if new_stage_id:
        click.echo(
            f"{Fore.GREEN}Stage created successfully. Stage ID: {new_stage_id}{Style.RESET_ALL}"
        )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
@stages.command()
def list():
    """List all stages."""
    # Renders the SDK's stage records as a table, newest first by datetime_added.
    stage_records = get_sdk().list_stages()
    if stage_records is None or len(stage_records) == 0:
        click.echo(Fore.YELLOW + "No stages found." + Style.RESET_ALL)
        return

    # Stringify the schema column so the table can display it.
    table = pl.DataFrame(stage_records).with_columns(
        pl.col("schema").map_elements(lambda value: str(value), return_dtype=pl.Utf8).alias("schema")
    )
    table = set_human_readable_dates(["datetime_added", "updated_at"], table)
    table = table.sort(by=["datetime_added"], descending=True)

    with pl.Config(tbl_rows=-1, tbl_cols=-1, set_fmt_str_lengths=45):
        print(table.select(pl.all()))
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
@stages.command()
@click.argument("stage_id")
def files(stage_id):
    """List all files in a stage."""
    stage_files = get_sdk().list_stage_files(stage_id)
    if not stage_files:
        return

    print(f"{Fore.YELLOW}Files in stage {stage_id}:{Style.RESET_ALL}")
    for entry in stage_files:
        print(f"\t{entry}")
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
@stages.command()
@click.argument("stage_id", required=False)
@click.argument("file_path", required=False)
def upload(file_path, stage_id):
    """Upload files to a stage. You can provide a single file path or a directory path to upload all files in the directory."""
    # click fills positional arguments left to right. With STAGE_ID declared
    # first, a lone argument always landed in stage_id and the command then
    # failed with "Missing argument FILE_PATH", so STAGE_ID could never be
    # omitted. Both arguments are now optional: with a single argument it is
    # taken as the file path (and stage_id is None), while
    # `upload STAGE_ID FILE_PATH` keeps working unchanged.
    if file_path is None:
        if stage_id is None:
            raise click.UsageError("Missing argument 'FILE_PATH'.")
        file_path, stage_id = stage_id, None

    sdk = get_sdk()
    sdk.upload_to_stage(file_path, stage_id)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
@stages.command()
@click.argument("stage_id")
@click.argument("file_name", required=False)
@click.argument("output_path", required=False)
def download(stage_id, file_name=None, output_path=None):
    """Download a file/files from a stage. If no files are provided, all files in the stage will be downloaded. If no output path is provided, the file will be saved to the current working directory."""
    sdk = get_sdk()

    # Resolve the file names up front. When no name is given, the "all files"
    # case needs them both for the request and for naming the files written
    # locally. (Previously this sent [None] and crashed on open(None, "wb").)
    if file_name is None:
        # NOTE(review): assumes list_stage_files returns plain file names,
        # as printed by `sutro stages files` — confirm against the SDK.
        file_names = sdk.list_stage_files(stage_id)
        if not file_names:
            return
    else:
        file_names = [file_name]

    contents = sdk.download_from_stage(stage_id, file_names, output_path)
    if not contents:
        return

    # An empty directory component makes os.path.join yield a path relative
    # to the current working directory.
    target_dir = output_path if output_path is not None else ""
    # NOTE(review): assumes download_from_stage returns one bytes payload per
    # requested name, in request order — confirm against the SDK.
    for name, payload in zip(file_names, contents):
        with open(os.path.join(target_dir, name), "wb") as f:
            f.write(payload)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
@cli.command()
def docs():
    """Open the Sutro API docs."""
    # click.launch hands the URL to the system's default handler (browser).
    docs_url = "https://docs.sutro.sh"
    click.launch(docs_url)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
@cli.command()
@click.argument("base_url")
def set_base_url(base_url):
    """Set the base URL for the Sutro API."""
    # Exposed as `sutro set-base-url`; the root group lets it run without login.
    set_config_base_url(base_url)
    confirmation = f"Base URL set to {base_url}."
    click.echo(Fore.GREEN + confirmation + Style.RESET_ALL)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
@cli.command()
def quotas():
    """Get API quotas."""
    # Prints one section per job priority; the priority is the index into the
    # sequence returned by the SDK.
    quota_list = get_sdk().get_quotas()
    if not quota_list:
        return

    print(Fore.YELLOW + "Your current quotas are: \n" + Style.RESET_ALL)
    for priority in range(len(quota_list)):
        print(f"Job Priority: {priority}")
        print(f"\tRow Quota (Maximum): {quota_list[priority]['row_quota']}")
        print(f"\tToken Quota (Maximum): {quota_list[priority]['token_quota']}")
        print("\n")

    footer = "To increase your quotas, contact us at team@sutro.sh."
    print(Fore.YELLOW + footer + Style.RESET_ALL)
|
|
397
|
+
|
|
398
|
+
def _latest_job_id(job_records):
    """Return the ``job_id`` of the most recently created job in ``job_records``.

    Mirrors `jobs list`: records whose ``datetime_created`` is missing/None
    are ignored, and the rest are compared on that timestamp string (the same
    string ordering `jobs list` sorts by). If no record has a timestamp, the
    first record is used.

    Args:
        job_records: Non-empty list of job dicts as returned by ``list_jobs``.
    """
    dated = [job for job in job_records if job.get("datetime_created") is not None]
    if not dated:
        return job_records[0]["job_id"]
    return max(dated, key=lambda job: job["datetime_created"])["job_id"]


@jobs.command()
@click.argument("job_id", required=False)
@click.option("--latest", is_flag=True, help="Attach to the latest job.")
def attach(job_id, latest):
    """Attach to a running job and stream its progress."""
    # --latest takes precedence over an explicit JOB_ID.
    sdk = get_sdk()
    if latest:
        job_records = sdk.list_jobs()
        if not job_records:
            click.echo(Fore.YELLOW + "No jobs found." + Style.RESET_ALL)
            return
        # `jobs list` sorts explicitly, so the API order is not relied upon
        # here either.
        job_id = _latest_job_id(job_records)
        print(f"Attaching to latest job: {job_id}")
    elif not job_id:
        click.echo(Fore.YELLOW + "No job ID provided." + Style.RESET_ALL)
        return
    sdk.attach(job_id)
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
# Allow running this module directly, e.g. ``python -m sutro.cli``.
if __name__ == "__main__":
    cli()
|