sutro 0.0.0__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sutro might be problematic. Click here for more details.

sutro/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ from .sdk import Sutro
2
+
3
# Create the shared client instance that backs the package-level functions.
_instance = Sutro()

# Expose every callable, non-dunder attribute of the instance in the package
# namespace. The attribute fetched from the instance is already bound to it,
# so it is exported as-is; re-binding through ``__func__`` would raise
# AttributeError for callables that are not bound methods (static methods,
# callable objects) and would pass the instance as ``cls`` to classmethods.
globals().update(
    {
        name: getattr(_instance, name)
        for name in dir(_instance)
        if not name.startswith("__") and callable(getattr(_instance, name))
    }
)
sutro/cli.py ADDED
@@ -0,0 +1,418 @@
1
+ from datetime import datetime, timezone
2
+ import click
3
+ from colorama import Fore, Style
4
+ import os
5
+ import json
6
+ from sutro.sdk import Sutro
7
+ import polars as pl
8
+ import warnings
9
+
10
# Silence polars' PolarsInefficientMapWarning; this module formats columns
# with map_elements and Python callables.
warnings.filterwarnings("ignore", category=pl.PolarsInefficientMapWarning)
# Global polars display setting: hide the dataframe shape in printed tables.
pl.Config.set_tbl_hide_dataframe_shape(True)

# Directory and JSON file holding the saved CLI settings (api_key, base_url).
CONFIG_DIR = os.path.expanduser("~/.sutro")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
15
+
16
+
17
def load_config():
    """Return the parsed contents of CONFIG_FILE.

    Returns an empty dict when the config file does not exist.
    """
    # Open directly instead of checking os.path.exists first, so the file
    # cannot disappear between the check and the open.
    try:
        with open(CONFIG_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        return {}
22
+
23
+
24
def save_config(config):
    """Write *config* to CONFIG_FILE as JSON, creating CONFIG_DIR if needed.

    The login command stores the API key in this file, so it is created
    readable and writable by the owner only.
    """
    os.makedirs(CONFIG_DIR, exist_ok=True)
    # The mode passed to os.open only applies when the file is created.
    fd = os.open(CONFIG_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        try:
            # Tighten a file left behind with broader permissions. This is
            # best-effort: some filesystems reject chmod, and that must not
            # prevent the configuration from being saved.
            os.chmod(CONFIG_FILE, 0o600)
        except OSError:
            pass
        json.dump(config, f)
28
+
29
+
30
def check_auth():
    """Report whether the saved configuration contains an API key."""
    stored_key = load_config().get("api_key")
    return stored_key is not None
33
+
34
+
35
def get_sdk():
    """Build a Sutro client from the saved configuration.

    ``base_url`` is passed to the client only when one is stored; otherwise
    the client is constructed with the API key alone.
    """
    config = load_config()
    api_key = config.get("api_key")
    base_url = config.get("base_url")
    if base_url is not None:
        return Sutro(api_key=api_key, base_url=base_url)
    return Sutro(api_key=api_key)
43
+
44
+
45
def set_config_base_url(base_url: str):
    """Store *base_url* in the saved configuration, keeping the other keys."""
    save_config({**load_config(), "base_url": base_url})
49
+
50
+
51
def set_human_readable_dates(datetime_columns, df):
    """Replace timestamp-string columns with local-time display strings.

    Each name in *datetime_columns* that exists in *df* is parsed as a
    datetime, treated as UTC, converted to the local timezone and rendered
    as ``YYYY-MM-DD HH:MM:SS TZ``. Names not present in *df* are skipped.
    Returns the updated dataframe.
    """

    def _to_local_string(moment):
        # Convert UTC value to local time string
        if not moment:
            return None
        as_utc = moment.replace(tzinfo=timezone.utc)
        return as_utc.astimezone().strftime("%Y-%m-%d %H:%M:%S %Z")

    for name in datetime_columns:
        if name not in df.columns:
            continue
        localized = (
            pl.col(name)
            .str.to_datetime()
            .map_elements(_to_local_string, return_dtype=pl.Utf8)
            .alias(name)
        )
        df = df.with_columns(localized)
    return df
69
+
70
+
71
# Root command group. Notes are kept in comments because a docstring here
# would become the group's help text.
@click.group(invoke_without_command=True)
@click.pass_context
def cli(ctx):
    # Allow login and set-base-url commands without authentication
    unauthenticated = not check_auth()
    if unauthenticated and ctx.invoked_subcommand not in ("login", "set-base-url"):
        click.echo("Please login using 'sutro login'.")
        ctx.exit(1)

    # A subcommand was given: nothing more to do at the group level.
    if ctx.invoked_subcommand is not None:
        return

    # Bare invocation: print a welcome banner followed by the help text.
    message = """
Welcome to the Sutro CLI!

To see a list of all available commands, use 'sutro --help'.
"""
    click.echo(Fore.GREEN + message + Style.RESET_ALL)
    click.echo(ctx.get_help())
88
+
89
+
90
@cli.command()
def login():
    """Set or update your API key for Sutro."""
    config = load_config()
    default_api_key = config.get("api_key", "")
    default_base_url = config.get("base_url", "https://api.sutro.sh")

    # Only show the hint when there is a stored key to keep; otherwise an
    # empty line would be printed before the prompt.
    if default_api_key:
        click.echo(
            "Hint: An API key is already set. Press Enter to keep the existing key."
        )
    api_key = click.prompt(
        "Enter your API key",
        default=default_api_key,
        hide_input=True,
        show_default=False,
    )

    # Validate the entered key before anything is written to disk.
    result = get_sdk().try_authentication(api_key)
    authenticated = (
        bool(result)
        and "authenticated" in result
        and result["authenticated"] is True
    )
    if not authenticated:
        raise click.ClickException(
            Fore.RED + "Invalid API key. Try again." + Style.RESET_ALL
        )

    banner = """


▄▄▄▄▄▄▄▄▄▄▄ ▄ ▄ ▄▄▄▄▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄▄▄▄▄
▐░░░░░░░░░░░▌▐░▌ ▐░▌▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌
▐░█▀▀▀▀▀▀▀▀▀ ▐░▌ ▐░▌ ▀▀▀▀█░█▀▀▀▀ ▐░█▀▀▀▀▀▀▀█░▌▐░█▀▀▀▀▀▀▀█░▌
▐░▌ ▐░▌ ▐░▌ ▐░▌ ▐░▌ ▐░▌▐░▌ ▐░▌
▐░█▄▄▄▄▄▄▄▄▄ ▐░▌ ▐░▌ ▐░▌ ▐░█▄▄▄▄▄▄▄█░▌▐░▌ ▐░▌
▐░░░░░░░░░░░▌▐░▌ ▐░▌ ▐░▌ ▐░░░░░░░░░░░▌▐░▌ ▐░▌
▀▀▀▀▀▀▀▀▀█░▌▐░▌ ▐░▌ ▐░▌ ▐░█▀▀▀▀█░█▀▀ ▐░▌ ▐░▌
▐░▌▐░▌ ▐░▌ ▐░▌ ▐░▌ ▐░▌ ▐░▌ ▐░▌
▄▄▄▄▄▄▄▄▄█░▌▐░█▄▄▄▄▄▄▄█░▌ ▐░▌ ▐░▌ ▐░▌ ▐░█▄▄▄▄▄▄▄█░▌
▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌ ▐░▌ ▐░▌ ▐░▌▐░░░░░░░░░░░▌
▀▀▀▀▀▀▀▀▀▀▀ ▀▀▀▀▀▀▀▀▀▀▀ ▀ ▀ ▀ ▀▀▀▀▀▀▀▀▀▀▀


"""
    click.echo(Fore.BLUE + banner + Style.RESET_ALL)
    click.echo(
        Fore.GREEN + "Successfully authenticated. Welcome back!" + Style.RESET_ALL
    )

    # Persist the verified key together with the base URL that was in effect.
    save_config({"api_key": api_key, "base_url": default_base_url})
137
+
138
+
139
# Command group under which the job subcommands are registered.
@cli.group()
def jobs():
    """Manage jobs."""
143
+
144
+
145
@jobs.command()
@click.option(
    "--all", is_flag=True, help="Include all jobs, including cancelled and failed ones."
)
def list(all=False):
    """Lists historical and ongoing jobs. Will only list first 25 jobs by default. Use --all to see all jobs."""
    sdk = get_sdk()
    job_records = sdk.list_jobs()
    if not job_records:
        click.echo(Fore.YELLOW + "No jobs found." + Style.RESET_ALL)
        return

    df = pl.DataFrame(job_records)
    # TODO: this is a temporary fix to remove jobs where datetime_created is null. We should fix this on the backend.
    df = df.filter(pl.col("datetime_created").is_not_null())
    # Sort newest first while datetime_created is still the raw timestamp.
    df = df.sort(by=["datetime_created"], descending=True)

    # Format all datetime columns with a more readable format
    datetime_columns = [
        "datetime_created",
        "datetime_added",
        "datetime_started",
        "datetime_completed",
    ]
    df = set_human_readable_dates(datetime_columns, df)

    # TODO: get colors working
    # df = df.with_columns([
    #     pl.when(pl.col("status") == "SUCCEEDED")
    #         .then(pl.concat_str([pl.lit(Fore.GREEN), pl.col("status"), pl.lit(Style.RESET_ALL)]))
    #         .when(pl.col("status").is_in(["FAILED", "CANCELLED", "UNKNOWN"]))
    #         .then(pl.concat_str([pl.lit(Fore.RED), pl.col("status"), pl.lit(Style.RESET_ALL)]))
    #         .otherwise(pl.col("status"))
    #         .alias("status")
    # ])

    # fill null input_tokens and output_tokens with 0
    df = df.with_columns(
        pl.col("input_tokens").fill_null(0).alias("input_tokens"),
        pl.col("output_tokens").fill_null(0).alias("output_tokens"),
    )

    # fill null datetime_completed with empty string
    df = df.with_columns(
        pl.col("datetime_completed").fill_null("").alias("datetime_completed")
    )

    # Render the cost as a dollar amount with five decimal places.
    df = df.with_columns(
        pl.col("job_cost")
        .fill_null(0)
        .map_elements(lambda x: f"${x:.5f}", return_dtype=pl.Utf8)
        .alias("job_cost")
    )

    # Without --all, show only the 25 most recent jobs.
    if not all:
        df = df.slice(0, 25)

    with pl.Config(tbl_rows=-1, tbl_cols=-1, set_fmt_str_lengths=45):
        print(df.select(pl.all()))
204
+
205
+
206
@jobs.command()
@click.argument("job_id")
def status(job_id):
    """Get the status of a job."""
    current = get_sdk().get_job_status(job_id)
    # A falsy status is not printed.
    if current:
        print(current)
216
+
217
+
218
@jobs.command()
@click.argument("job_id")
@click.option(
    "--include-inputs", is_flag=True, help="Include the inputs in the results."
)
@click.option(
    "--include-cumulative-logprobs",
    is_flag=True,
    help="Include the cumulative logprobs in the results.",
)
@click.option(
    "--save",
    is_flag=True,
    help="Download the results to the current working directory. The file name will be the job_id.",
)
@click.option(
    "--save-format",
    type=click.Choice(["parquet", "csv"]),
    default="parquet",
    help="The format of the output file. Options: parquet, csv",
)
def results(
    job_id,
    include_inputs,
    include_cumulative_logprobs,
    save=False,
    save_format="parquet",
):
    """Get the results of a job."""
    fetched = get_sdk().get_job_results(
        job_id, include_inputs, include_cumulative_logprobs
    )
    if not fetched:
        return

    table = pl.DataFrame(fetched)

    # Without --save the table is only printed.
    if not save:
        print(table)
        return

    # click.Choice limits save_format to "parquet" or "csv".
    if save_format == "parquet":
        table.write_parquet(f"{job_id}.parquet")
    else:
        table.write_csv(f"{job_id}.csv")
    print(Fore.GREEN + f"Results saved to {job_id}.{save_format}" + Style.RESET_ALL)
263
+
264
+
265
@jobs.command()
@click.argument("job_id")
def cancel(job_id):
    """Cancel a running job."""
    outcome = get_sdk().cancel_job(job_id)
    # Confirm only when the SDK call returned a truthy result.
    if outcome:
        click.echo(Fore.GREEN + "Job cancelled successfully." + Style.RESET_ALL)
275
+
276
+
277
# Command group under which the stage subcommands are registered.
@cli.group()
def stages():
    """Manage stages."""
281
+
282
+
283
@stages.command()
def create():
    """Create a new stage."""
    new_stage_id = get_sdk().create_stage()
    # Nothing is printed when the SDK returns a falsy stage id.
    if new_stage_id:
        success = f"Stage created successfully. Stage ID: {new_stage_id}"
        click.echo(Fore.GREEN + success + Style.RESET_ALL)
295
+
296
+
297
@stages.command()
def list():
    """List all stages."""
    sdk = get_sdk()
    stage_records = sdk.list_stages()
    if not stage_records:
        click.echo(Fore.YELLOW + "No stages found." + Style.RESET_ALL)
        return
    df = pl.DataFrame(stage_records)

    # Render each schema value as plain text for display.
    df = df.with_columns(
        pl.col("schema")
        .map_elements(lambda x: str(x), return_dtype=pl.Utf8)
        .alias("schema")
    )

    # Sort newest first while datetime_added is still the raw timestamp
    # string, as the jobs listing does. Sorting after formatting would
    # compare local-time display strings instead of the original timestamps.
    df = df.sort(by=["datetime_added"], descending=True)

    # Format all datetime columns with a more readable format
    datetime_columns = ["datetime_added", "updated_at"]
    df = set_human_readable_dates(datetime_columns, df)

    with pl.Config(tbl_rows=-1, tbl_cols=-1, set_fmt_str_lengths=45):
        print(df.select(pl.all()))
320
+
321
+
322
@stages.command()
@click.argument("stage_id")
def files(stage_id):
    """List all files in a stage."""
    stage_files = get_sdk().list_stage_files(stage_id)
    if not stage_files:
        return

    # Yellow header line, then one tab-indented line per file.
    print(f"{Fore.YELLOW}Files in stage {stage_id}:{Style.RESET_ALL}")
    for entry in stage_files:
        print(f"\t{entry}")
334
+
335
+
336
@stages.command()
@click.argument("stage_id", required=False)
@click.argument("file_path", required=False)
def upload(file_path, stage_id):
    """Upload files to a stage. You can provide a single file path or a directory path to upload all files in the directory."""
    # click fills positional arguments from left to right, so when only one
    # value is given it lands in stage_id. With file_path declared as
    # required, that invocation always failed with a missing-argument error
    # even though stage_id is optional. Treat a lone value as the file path.
    if file_path is None:
        file_path, stage_id = stage_id, None
    if file_path is None:
        raise click.UsageError("Missing argument 'FILE_PATH'.")

    sdk = get_sdk()
    # NOTE(review): assumes upload_to_stage accepts stage_id=None when no
    # stage is given — confirm against the SDK.
    sdk.upload_to_stage(file_path, stage_id)
343
+
344
+
345
@stages.command()
@click.argument("stage_id")
@click.argument("file_name", required=False)
@click.argument("output_path", required=False)
def download(stage_id, file_name=None, output_path=None):
    """Download a file/files from a stage. If no files are provided, all files in the stage will be downloaded. If no output path is provided, the file will be saved to the current working directory."""
    sdk = get_sdk()
    files = sdk.download_from_stage(stage_id, [file_name], output_path)
    if not files:
        return

    # The returned content is written under file_name, so a name is required
    # at this point; previously open(None) raised a TypeError here.
    # NOTE(review): the help text promises an "all files" mode, but the names
    # of the returned files are not available here — confirm what
    # download_from_stage returns when no file name is given.
    if file_name is None:
        raise click.ClickException(
            "A file name is required to save the downloaded content."
        )

    if output_path is None:
        target = file_name
    else:
        target = os.path.join(output_path, file_name)

    for file in files:
        with open(target, "wb") as f:
            f.write(file)
362
+
363
+
364
@cli.command()
def docs():
    """Open the Sutro API docs."""
    docs_url = "https://docs.sutro.sh"
    click.launch(docs_url)
368
+
369
+
370
@cli.command()
@click.argument("base_url")
def set_base_url(base_url):
    """Set the base URL for the Sutro API."""
    # Persist first, then confirm to the user.
    set_config_base_url(base_url)
    confirmation = f"Base URL set to {base_url}."
    click.echo(Fore.GREEN + confirmation + Style.RESET_ALL)
376
+
377
+
378
@cli.command()
def quotas():
    """Get API quotas."""
    quota_table = get_sdk().get_quotas()
    if not quota_table:
        return

    print(Fore.YELLOW + "Your current quotas are: \n" + Style.RESET_ALL)

    # Entries are looked up by position; the position is shown as the job
    # priority.
    for priority in range(len(quota_table)):
        entry = quota_table[priority]
        print(f"Job Priority: {priority}")
        print(f"\tRow Quota (Maximum): {entry['row_quota']}")
        print(f"\tToken Quota (Maximum): {entry['token_quota']}")
        print("\n")

    footer = "To increase your quotas, contact us at team@sutro.sh."
    print(Fore.YELLOW + footer + Style.RESET_ALL)
397
+
398
@jobs.command()
@click.argument("job_id", required=False)
@click.option("--latest", is_flag=True, help="Attach to the latest job.")
def attach(job_id, latest):
    """Attach to a running job and stream its progress."""
    sdk = get_sdk()
    if latest:
        job_records = sdk.list_jobs()
        if not job_records:
            click.echo(Fore.YELLOW + "No jobs found." + Style.RESET_ALL)
            return
        # Pick the newest job by its creation timestamp and do not rely on
        # the order of the returned list; the jobs listing sorts on the same
        # field. Jobs without a timestamp rank last. When the list is already
        # newest-first this selects the first entry, as before.
        newest = max(
            job_records, key=lambda job: job.get("datetime_created") or ""
        )
        job_id = newest["job_id"]
        print(f"Attaching to latest job: {job_id}")
    elif not job_id:
        click.echo(Fore.YELLOW + "No job ID provided." + Style.RESET_ALL)
        return
    sdk.attach(job_id)
415
+
416
+
417
# Run the root command group when this module is executed as a script.
if __name__ == "__main__":
    cli()