sutro 0.1.37__py3-none-any.whl → 0.1.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of sutro has been marked as potentially problematic.

sutro/sdk.py CHANGED
@@ -1,49 +1,32 @@
- from enum import Enum
  import requests
  import pandas as pd
  import polars as pl
  import json
- from typing import Union, List, Optional, Literal, Dict, Any
+ from typing import Union, List, Optional, Dict, Any, Type
  import os
  import sys
  from yaspin import yaspin
  from yaspin.spinners import Spinners
- from colorama import init, Fore, Style
- from tqdm import tqdm
+ from colorama import init
  import time
  from pydantic import BaseModel
  import pyarrow.parquet as pq
  import shutil
- import importlib.metadata
+ from sutro.common import (
+     ModelOptions,
+     handle_data_helper,
+     normalize_output_schema,
+     to_colored_text,
+     fancy_tqdm,
+ )
+ from sutro.interfaces import JobStatus
+ from sutro.templates.classification import ClassificationTemplates
+ from sutro.templates.embed import EmbeddingTemplates
+ from sutro.validation import check_version, check_for_api_key

  JOB_NAME_CHAR_LIMIT = 45
  JOB_DESCRIPTION_CHAR_LIMIT = 512

- class JobStatus(str, Enum):
-     """Job statuses that will be returned by the API & SDK"""
-
-     UNKNOWN = "UNKNOWN"
-     QUEUED = "QUEUED"  # Job is waiting to start
-     STARTING = "STARTING"  # Job is in the process of starting up
-     RUNNING = "RUNNING"  # Job is actively running
-     SUCCEEDED = "SUCCEEDED"  # Job completed successfully
-     CANCELLING = "CANCELLING"  # Job is in the process of being canceled
-     CANCELLED = "CANCELLED"  # Job was canceled by the user
-     FAILED = "FAILED"  # Job failed
-
-     @classmethod
-     def terminal_statuses(cls) -> list["JobStatus"]:
-         return [
-             cls.SUCCEEDED,
-             cls.FAILED,
-             cls.CANCELLING,
-             cls.CANCELLED,
-         ]
-
-     def is_terminal(self) -> bool:
-         return self in self.terminal_statuses()
-
-
  # Initialize colorama (required for Windows)
  init()

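The JobStatus enum is functionally unchanged; it has simply moved out of sutro/sdk.py into sutro.interfaces (see the new import above). A minimal sketch of how downstream code uses it after this change, assuming only what the removed definition shows:

    from sutro.interfaces import JobStatus  # new import path in 0.1.38

    status = JobStatus("RUNNING")
    print(status.is_terminal())               # False: the job is still active
    print(JobStatus.CANCELLED.is_terminal())  # True: CANCELLED is terminal

Note that CANCELLING also counts as terminal per terminal_statuses(), so polling loops stop as soon as cancellation begins rather than waiting for CANCELLED.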
@@ -57,59 +40,6 @@ def is_jupyter() -> bool:
  YASPIN_COLOR = None if is_jupyter() else "blue"
  SPINNER = Spinners.dots14

- # Models available for inference. Keep in sync with the backend configuration
- # so users get helpful autocompletion when selecting a model.
- ModelOptions = Literal[
-     "llama-3.2-3b",
-     "llama-3.1-8b",
-     "llama-3.3-70b",
-     "llama-3.3-70b",
-     "qwen-3-4b",
-     "qwen-3-14b",
-     "qwen-3-32b",
-     "qwen-3-30b-a3b",
-     "qwen-3-235b-a22b",
-     "qwen-3-4b-thinking",
-     "qwen-3-14b-thinking",
-     "qwen-3-32b-thinking",
-     "qwen-3-235b-a22b-thinking",
-     "qwen-3-30b-a3b-thinking",
-     "gemma-3-4b-it",
-     "gemma-3-12b-it",
-     "gemma-3-27b-it",
-     "gpt-oss-20b",
-     "gpt-oss-120b",
-     "qwen-3-embedding-0.6b",
-     "qwen-3-embedding-6b",
-     "qwen-3-embedding-8b",
- ]
-
-
- def to_colored_text(
-     text: str, state: Optional[Literal["success", "fail", "callout"]] = None
- ) -> str:
-     """
-     Apply color to text based on state.
-
-     Args:
-         text (str): The text to color
-         state (Optional[Literal['success', 'fail']]): The state that determines the color.
-             Options: 'success', 'fail', or None (default blue)
-
-     Returns:
-         str: Text with appropriate color applied
-     """
-     match state:
-         case "success":
-             return f"{Fore.GREEN}{text}{Style.RESET_ALL}"
-         case "fail":
-             return f"{Fore.RED}{text}{Style.RESET_ALL}"
-         case "callout":
-             return f"{Fore.MAGENTA}{text}{Style.RESET_ALL}"
-         case _:
-             # Default to blue for normal/processing states
-             return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
-

  # Isn't fully supported in all terminals unfortunately. We should switch to Rich
  # at some point, but even Rich links aren't clickable on MacOS Terminal
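ModelOptions and to_colored_text also move to sutro.common with unchanged signatures. A short usage sketch based on the removed definition (colorama's Fore/Style now live behind sutro.common):

    from sutro.common import to_colored_text

    print(to_colored_text("Job succeeded", state="success"))  # green
    print(to_colored_text("Job failed", state="fail"))        # red
    print(to_colored_text("Heads up", state="callout"))       # magenta
    print(to_colored_text("Processing..."))                   # default blue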
@@ -123,64 +53,11 @@ def make_clickable_link(url, text=None):
      return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"


- class Sutro:
+ class Sutro(EmbeddingTemplates, ClassificationTemplates):
      def __init__(self, api_key: str = None, base_url: str = "https://api.sutro.sh/"):
-         self.api_key = api_key or self.check_for_api_key()
+         self.api_key = api_key or check_for_api_key()
          self.base_url = base_url
-         self.check_version("sutro")
-
-     def check_version(self, package_name: str):
-         try:
-             # Local version
-             local_version = importlib.metadata.version(package_name)
-         except importlib.metadata.PackageNotFoundError:
-             print(f"{package_name} is not installed.")
-             return
-
-         try:
-             # Latest release from PyPI
-             resp = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=2)
-             resp.raise_for_status()
-             latest_version = resp.json()["info"]["version"]
-
-             if local_version != latest_version:
-                 msg = (f"⚠️ You are using {package_name} {local_version}, "
-                        f"but the latest release is {latest_version}. "
-                        f"Run `[uv] pip install -U {package_name}` to upgrade.")
-                 print(to_colored_text(
-                     msg,
-                     state="callout"
-                 )
-                 )
-         except Exception as e:
-             # Fail silently or log, you don’t want this blocking usage
-             pass
-
-     def check_for_api_key(self):
-         """
-         Check for an API key in the user's home directory.
-
-         This method looks for a configuration file named 'config.json' in the
-         '.sutro' directory within the user's home directory.
-         If the file exists, it attempts to read the API key from it.
-
-         Returns:
-             str or None: The API key if found in the configuration file, or None if not found.
-
-         Note:
-             The expected structure of the config.json file is:
-             {
-                 "api_key": "your_api_key_here"
-             }
-         """
-         CONFIG_DIR = os.path.expanduser("~/.sutro")
-         CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
-         if os.path.exists(CONFIG_FILE):
-             with open(CONFIG_FILE, "r") as f:
-                 config = json.load(f)
-             return config.get("api_key")
-         else:
-             return None
+         check_version("sutro")

      def set_api_key(self, api_key: str):
          """
@@ -197,79 +74,6 @@ class Sutro:
          """
          self.api_key = api_key

-     def do_dataframe_column_concatenation(self, data: Union[pd.DataFrame, pl.DataFrame], column: Union[str, List[str]]):
-         """
-         If the user has supplied a dataframe and a list of columns, this will intelligently concatenate the columns into a single column, accepting separator strings.
-         """
-         try:
-             if isinstance(data, pd.DataFrame):
-                 series_parts = []
-                 for p in column:
-                     if p in data.columns:
-                         s = data[p].astype("string").fillna("")
-                     else:
-                         # Treat as a literal separator
-                         s = pd.Series([p] * len(data), index=data.index, dtype="string")
-                     series_parts.append(s)
-
-                 out = series_parts[0]
-                 for s in series_parts[1:]:
-                     out = out.str.cat(s, na_rep="")
-
-                 return out.tolist()
-             elif isinstance(data, pl.DataFrame):
-                 exprs = []
-                 for p in column:
-                     if p in data.columns:
-                         exprs.append(pl.col(p).cast(pl.Utf8).fill_null(""))
-                     else:
-                         exprs.append(pl.lit(p))
-
-                 result = data.select(pl.concat_str(exprs, separator="", ignore_nulls=False).alias("concat"))
-                 return result["concat"].to_list()
-         except Exception as e:
-             raise ValueError(f"Error handling column concatenation: {e}")
-
-     def handle_data_helper(
-         self, data: Union[List, pd.DataFrame, pl.DataFrame, str], column: str = None
-     ):
-         if isinstance(data, list):
-             input_data = data
-         elif isinstance(data, (pd.DataFrame, pl.DataFrame)):
-             if column is None:
-                 raise ValueError("Column name must be specified for DataFrame input")
-             if isinstance(column, list):
-                 input_data = self.do_dataframe_column_concatenation(data, column)
-             elif isinstance(column, str):
-                 input_data = data[column].to_list()
-         elif isinstance(data, str):
-             if data.startswith("dataset-"):
-                 input_data = data + ":" + column
-             else:
-                 file_ext = os.path.splitext(data)[1].lower()
-                 if file_ext == ".csv":
-                     df = pl.read_csv(data)
-                 elif file_ext == ".parquet":
-                     df = pl.read_parquet(data)
-                 elif file_ext in [".txt", ""]:
-                     with open(data, "r") as file:
-                         input_data = [line.strip() for line in file]
-                 else:
-                     raise ValueError(f"Unsupported file type: {file_ext}")
-
-                 if file_ext in [".csv", ".parquet"]:
-                     if column is None:
-                         raise ValueError(
-                             "Column name must be specified for CSV/Parquet input"
-                         )
-                     input_data = df[column].to_list()
-         else:
-             raise ValueError(
-                 "Unsupported data type. Please provide a list, DataFrame, or file path."
-             )
-
-         return input_data
-
      def set_base_url(self, base_url: str):
          """
          Set the base URL for the Sutro API.
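handle_data_helper and do_dataframe_column_concatenation move to sutro.common; the column-list behavior is unchanged. A sketch of what it enables, assuming the package root exposes the Sutro client (list entries that are not column names are treated as literal separators):

    import polars as pl
    from sutro import Sutro  # assumed import path

    so = Sutro(api_key="...")
    df = pl.DataFrame({"title": ["A", "B"], "body": ["first", "second"]})

    # Sends "A: first" and "B: second" as the model inputs.
    job_id = so.infer(df, model="gemma-3-12b-it", column=["title", ": ", "body"])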
@@ -282,6 +86,43 @@ class Sutro:
          """
          self.base_url = base_url

+     def do_request(
+         self,
+         method: str,
+         endpoint: str,
+         api_key_override: Optional[str] = None,
+         **kwargs: Any,
+     ):
+         """
+         Helper to make authenticated requests.
+         """
+         key = self.api_key if not api_key_override else api_key_override
+         headers = {"Authorization": f"Key {key}"}
+
+         # Merge with any headers passed in kwargs
+         if "headers" in kwargs:
+             headers.update(kwargs.pop("headers"))
+
+         url = f"{self.base_url}/{endpoint.lstrip('/')}"
+
+         # Explicit method dispatch
+         method = method.upper()
+         if method == "GET":
+             response = requests.get(url, headers=headers, **kwargs)
+         elif method == "POST":
+             response = requests.post(url, headers=headers, **kwargs)
+         elif method == "PUT":
+             response = requests.put(url, headers=headers, **kwargs)
+         elif method == "DELETE":
+             response = requests.delete(url, headers=headers, **kwargs)
+         elif method == "PATCH":
+             response = requests.patch(url, headers=headers, **kwargs)
+         else:
+             raise ValueError(f"Unsupported HTTP method: {method}")
+
+         response.raise_for_status()
+         return response
+
      def _run_one_batch_inference(
          self,
          data: Union[List, pd.DataFrame, pl.DataFrame, str],
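The new do_request helper centralizes the Authorization header, joins the endpoint onto base_url, and calls raise_for_status(), so failures surface as requests.HTTPError. That is why the call sites throughout this diff switch from status-code checks to try/except. The resulting internal pattern:

    try:
        response = self.do_request("GET", f"job-status/{job_id}")
        data = response.json()
    except requests.HTTPError as e:
        print(e.response.status_code, e.response.json())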
@@ -301,16 +142,15 @@ class Sutro:
      ):
          # Validate name and description lengths
          if name is not None and len(name) > JOB_NAME_CHAR_LIMIT:
-             raise ValueError(f"Job name cannot exceed {JOB_NAME_CHAR_LIMIT} characters.")
+             raise ValueError(
+                 f"Job name cannot exceed {JOB_NAME_CHAR_LIMIT} characters."
+             )
          if description is not None and len(description) > JOB_DESCRIPTION_CHAR_LIMIT:
-             raise ValueError(f"Job description cannot exceed {JOB_DESCRIPTION_CHAR_LIMIT} characters.")
+             raise ValueError(
+                 f"Job description cannot exceed {JOB_DESCRIPTION_CHAR_LIMIT} characters."
+             )

-         input_data = self.handle_data_helper(data, column)
-         endpoint = f"{self.base_url}/batch-inference"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
+         input_data = handle_data_helper(data, column)
          payload = {
              "model": model,
              "inputs": input_data,
@@ -336,16 +176,19 @@ class Sutro:
          spinner_text = to_colored_text(t)
          try:
              with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
-                 response = requests.post(
-                     endpoint, data=json.dumps(payload), headers=headers
-                 )
-                 response_data = response.json()
+                 try:
+                     response = self.do_request("POST", "batch-inference", json=payload)
+                     response_data = response.json()
+                 except requests.HTTPError as e:
+                     response = e.response
+                     response_data = response.json()
+
                  if response.status_code != 200:
                      spinner.write(
                          to_colored_text(f"Error: {response.status_code}", state="fail")
                      )
                      spinner.stop()
-                     print(to_colored_text(response.json(), state="fail"))
+                     print(to_colored_text(response_data, state="fail"))
                      return None
                  else:
                      job_id = response_data["results"]
@@ -371,10 +214,11 @@ class Sutro:
                      name_text = f" and name {name}" if name is not None else ""
                      spinner.write(
                          to_colored_text(
-                             f"🛠 Priority {job_priority} Job created with ID: {job_id}{name_text}.",
+                             f"🛠 Priority {job_priority} Job created with ID: {job_id}{name_text}",
                              state="success",
                          )
                      )
+                     spinner.write(to_colored_text(f"Model: {model}"))
                      if not stay_attached:
                          clickable_link = make_clickable_link(
                              f"https://app.sutro.sh/jobs/{job_id}"
@@ -411,13 +255,13 @@ class Sutro:
                          )
                      )
                      return None
-                 s = requests.Session()
+
                  pbar = None

                  try:
-                     with requests.get(
-                         f"{self.base_url}/stream-job-progress/{job_id}",
-                         headers=headers,
+                     with self.do_request(
+                         "GET",
+                         f"/stream-job-progress/{job_id}",
                          stream=True,
                      ) as streaming_response:
                          streaming_response.raise_for_status()
@@ -446,7 +290,7 @@ class Sutro:
                          if pbar is None:
                              spinner.stop()
                              postfix = "Input tokens processed: 0"
-                             pbar = self.fancy_tqdm(
+                             pbar = fancy_tqdm(
                                  total=len(input_data),
                                  desc="Progress",
                                  style=1,
@@ -487,28 +331,27 @@ class Sutro:
                      )
                      spinner.start()

-                 payload = {
-                     "job_id": job_id,
-                 }
-
                  # TODO: we implement retries in cases where the job hasn't written results yet
                  # it would be better if we could receive a fully succeeded status from the job
                  # and not have such a race condition
                  max_retries = 20  # winds up being 100 seconds cumulative delay
                  retry_delay = 5  # initial delay in seconds
-
+                 job_results_response = None
                  for _ in range(max_retries):
-                     time.sleep(retry_delay)
-
-                     job_results_response = s.post(
-                         f"{self.base_url}/job-results",
-                         headers=headers,
-                         data=json.dumps(payload),
-                     )
-                     if job_results_response.status_code == 200:
+                     try:
+                         job_results_response = self.do_request(
+                             "POST",
+                             "job-results",
+                             json={
+                                 "job_id": job_id,
+                             },
+                         )
                          break
+                     except requests.HTTPError:
+                         time.sleep(retry_delay)
+                         continue

-                 if job_results_response.status_code != 200:
+                 if not job_results_response or job_results_response.status_code != 200:
                      spinner.write(
                          to_colored_text(
                              "Job succeeded, but results are not yet available. Use `so.get_job_results('{job_id}')` to obtain results.",
@@ -535,11 +378,11 @@ class Sutro:
                  else:
                      print(results)
                  spinner.write(
-                     to_colored_text(
-                         f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
-                         state="success",
+                     to_colored_text(
+                         f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
+                         state="success",
+                     )
                  )
-                 )
                  spinner.stop()

                  return job_id
@@ -549,13 +392,13 @@ class Sutro:
      def infer(
          self,
          data: Union[List, pd.DataFrame, pl.DataFrame, str],
-         model: Union[ModelOptions, List[ModelOptions]] = "gemma-3-12b-it",
-         name: Union[str, List[str]] = None,
-         description: Union[str, List[str]] = None,
+         model: ModelOptions = "gemma-3-12b-it",
+         name: Optional[str] = None,
+         description: Optional[str] = None,
          column: Union[str, List[str]] = None,
          output_column: str = "inference_result",
          job_priority: int = 0,
-         output_schema: Union[Dict[str, Any], BaseModel] = None,
+         output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
          sampling_params: dict = None,
          system_prompt: str = None,
          dry_run: bool = False,
@@ -567,18 +410,18 @@ class Sutro:
          Run inference on the provided data.

          This method allows you to run inference on the provided data using the Sutro API.
-         It supports various data types such as lists, pandas DataFrames, polars DataFrames, file paths and datasets.
+         It supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.

          Args:
              data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
-             model (Union[ModelOptions, List[ModelOptions]], optional): The model(s) to use for inference. Defaults to "llama-3.1-8b". You can pass a single model or a list of models. In the case of a list, the inference will be run in parallel for each model and stay_attached will be set to False.
-             name (Union[str, List[str]], optional): A job name for experiment/metadata tracking purposes. If using a list of models, you must pass a list of names with length equal to the number of models, or None. Defaults to None.
-             description (Union[str, List[str]], optional): A job description for experiment/metadata tracking purposes. If using a list of models, you must pass a list of descriptions with length equal to the number of models, or None. Defaults to None.
+             model (ModelOptions, optional): The model to use for inference. Defaults to "gemma-3-12b-it".
+             name (str, optional): A job name for experiment/metadata tracking purposes. Defaults to None.
+             description (str, optional): A job description for experiment/metadata tracking purposes. Defaults to None.
              column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, it will concatenate the columns of the list into a single column, accepting separator strings.
              output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
              job_priority (int, optional): The priority of the job. Defaults to 0.
              output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
-                 Can be either a dictionary representing a JSON schema or a pydantic BaseModel. Defaults to None.
+                 Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
              sampling_params (dict, optional): The sampling parameters to use at generation time, i.e. temperature, top_p, etc.
              system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
              dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
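The output_schema annotation is corrected to Type[BaseModel]: callers pass the model class itself, which normalize_output_schema converts to a JSON schema. A sketch under that assumption:

    from pydantic import BaseModel
    from sutro import Sutro  # assumed import path

    class Sentiment(BaseModel):
        label: str
        confidence: float

    so = Sutro(api_key="...")
    job_id = so.infer(
        ["I love this!", "Terrible."],
        model="gemma-3-12b-it",
        output_schema=Sentiment,  # the class, not an instance
    )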
@@ -590,63 +433,113 @@ class Sutro:
          str: The ID of the inference job.

          """
-         if isinstance(model, list) == False:
-             model_list = [model]
-             stay_attached = (
-                 stay_attached if stay_attached is not None else job_priority == 0
-             )
-         else:
-             model_list = model
-             stay_attached = False
-
-         if isinstance(model_list, list):
-             if isinstance(name, list):
-                 if len(name) != len(model_list):
-                     raise ValueError("Name list must be the same length as the model list.")
-                 name_list = name
-             elif isinstance(name, str):
-                 raise ValueError("Name must be a list if using a list of models.")
-             elif name is None:
-                 name_list = [None] * len(model_list)
-         else:
-             if isinstance(name, list):
-                 raise ValueError("Name must be a string or None if using a single model.")
-             name_list = [name]
-
-         if isinstance(model_list, list):
-             if isinstance(description, list):
-                 if len(description) != len(model_list):
-                     raise ValueError("Descriptions list must be the same length as the model list.")
-                 description_list = description
-             elif isinstance(description, str):
-                 raise ValueError("Description must be a list if using a list of models.")
-             elif description is None:
-                 description_list = [None] * len(model_list)
+         # Default stay_attached to True for prototyping jobs (priority 0)
+         if stay_attached is None:
+             stay_attached = job_priority == 0
+
+         json_schema = None
+         if output_schema:
+             # Convert BaseModel to dict if needed
+             json_schema = normalize_output_schema(output_schema)
+
+         return self._run_one_batch_inference(
+             data,
+             model,
+             column,
+             output_column,
+             job_priority,
+             json_schema,
+             sampling_params,
+             system_prompt,
+             dry_run,
+             stay_attached,
+             random_seed_per_input,
+             truncate_rows,
+             name,
+             description,
+         )
+
+     def infer_per_model(
+         self,
+         data: Union[List, pd.DataFrame, pl.DataFrame, str],
+         models: List[ModelOptions],
+         names: List[str] = None,
+         descriptions: List[str] = None,
+         column: Union[str, List[str]] = None,
+         output_column: str = "inference_result",
+         job_priority: int = 0,
+         output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
+         sampling_params: dict = None,
+         system_prompt: str = None,
+         dry_run: bool = False,
+         random_seed_per_input: bool = False,
+         truncate_rows: bool = True,
+     ):
+         """
+         Run inference on the provided data, across multiple models. This method is often useful for sampling outputs from multiple models across the same dataset and comparing results across the returned job_ids.
+
+         For input data, it supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.
+
+         Args:
+             data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
+             models (List[ModelOptions]): The models to use for inference. Fans out each model to its own separate job, over the same dataset.
+             names (List[str], optional): Job names for experiment/metadata tracking purposes. Must be a list of names with length equal to the number of models, or None. Defaults to None.
+             descriptions (List[str], optional): Job descriptions for experiment/metadata tracking purposes. Must be a list of descriptions with length equal to the number of models, or None. Defaults to None.
+             column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, it will concatenate the columns of the list into a single column, accepting separator strings.
+             output_column (str, optional): The column name to store the inference job_ids in if the input is a DataFrame. Defaults to "inference_result".
+             job_priority (int, optional): The priority of the job. Defaults to 0.
+             output_schema (Union[Dict[str, Any], Type[BaseModel]], optional): A structured schema for the output.
+                 Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
+             sampling_params (dict, optional): The sampling parameters to use at generation time, i.e. temperature, top_p, etc.
+             system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
+             dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
+             random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
+             truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to True.
+
+         Returns:
+             list: The IDs of the inference jobs, one per model.
+
+         """
+         if isinstance(names, list):
+             if len(names) != len(models):
+                 raise ValueError(
+                     "names parameter must be the same length as the models parameter."
+                 )
+         elif names is None:
+             names = [None] * len(models)
          else:
-             if isinstance(name, list):
-                 raise ValueError("Description must be a string or None if using a single model.")
-             description_list = [description]
-
-         # Convert BaseModel to dict if needed
-         if output_schema is not None:
-             if hasattr(
-                 output_schema, "model_json_schema"
-             ):  # Check for pydantic Model interface
-                 json_schema = output_schema.model_json_schema()
-             elif isinstance(output_schema, dict):
-                 json_schema = output_schema
-             else:
+             raise ValueError(
+                 "names parameter must be a list or None if using a list of models"
+             )
+
+         if isinstance(descriptions, list):
+             if len(descriptions) != len(models):
                  raise ValueError(
-                     "Invalid output schema type. Must be a dictionary or a pydantic Model."
+                     "descriptions parameter must be the same length as the models"
+                     " parameter."
                  )
+         elif descriptions is None:
+             descriptions = [None] * len(models)
          else:
-             json_schema = None
-
-         results = []
-         for i in range(len(model_list)):
-             res = self._run_one_batch_inference(
+             raise ValueError(
+                 "descriptions parameter must be a list or None if using a list of "
+                 "models"
+             )
+
+         json_schema = None
+         if output_schema:
+             # Convert BaseModel to dict if needed
+             json_schema = normalize_output_schema(output_schema)
+
+         def start_job(
+             model_singleton: ModelOptions,
+             name_singleton: str | None,
+             description_singleton: str | None,
+         ):
+             return self._run_one_batch_inference(
                  data,
-                 model_list[i],
+                 model_singleton,
                  column,
                  output_column,
                  job_priority,
@@ -654,20 +547,21 @@ class Sutro:
                  sampling_params,
                  system_prompt,
                  dry_run,
-                 stay_attached,
+                 False,
                  random_seed_per_input,
                  truncate_rows,
-                 name_list[i],
-                 description_list[i],
+                 name_singleton,
+                 description_singleton,
              )
-             results.append(res)

-         if len(results) > 1:
-             return results
-         elif len(results) == 1:
-             return results[0]
+         job_ids = [
+             start_job(model, name, description)
+             for model, name, description in zip(
+                 models, names, descriptions, strict=True
+             )
+         ]

-         return None
+         return job_ids

      def attach(self, job_id):
          """
@@ -678,16 +572,8 @@ class Sutro:
          """

          s = requests.Session()
-         payload = {
-             "job_id": job_id,
-         }
          pbar = None

-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
-
          with yaspin(
              SPINNER,
              text=to_colored_text("Looking for job..."),
@@ -725,9 +611,9 @@ class Sutro:
              success = False

              try:
-                 with s.get(
-                     f"{self.base_url}/stream-job-progress/{job_id}",
-                     headers=headers,
+                 with self.do_request(
+                     "GET",
+                     f"/stream-job-progress/{job_id}",
                      stream=True,
                  ) as streaming_response:
                      streaming_response.raise_for_status()
@@ -757,7 +643,7 @@ class Sutro:
                          if pbar is None:
                              spinner.stop()
                              postfix = "Input tokens processed: 0"
-                             pbar = self.fancy_tqdm(
+                             pbar = fancy_tqdm(
                                  total=total_rows,
                                  desc="Progress",
                                  style=1,
@@ -790,65 +676,6 @@ class Sutro:
              if spinner:
                  spinner.stop()

-     def fancy_tqdm(
-         self,
-         total: int,
-         desc: str = "Progress",
-         color: str = "blue",
-         style=1,
-         postfix: str = None,
-     ):
-         """
-         Creates a customized tqdm progress bar with different styling options.
-
-         Args:
-             total (int): Total iterations
-             desc (str): Description for the progress bar
-             color (str): Color of the progress bar (green, blue, red, yellow, magenta)
-             style (int): Style preset (1-5)
-             postfix (str): Postfix for the progress bar
-         """
-
-         # Style presets
-         style_presets = {
-             1: {
-                 "bar_format": "{l_bar}{bar:30}| {n_fmt}/{total_fmt} | {percentage:3.0f}% {postfix}",
-                 "ascii": "░▒█",
-             },
-             2: {
-                 "bar_format": "╢{l_bar}{bar:30}╟ {percentage:3.0f}%",
-                 "ascii": "▁▂▃▄▅▆▇█",
-             },
-             3: {
-                 "bar_format": "{desc}: |{bar}| {percentage:3.0f}% [{elapsed}<{remaining}]",
-                 "ascii": "◯◔◑◕●",
-             },
-             4: {
-                 "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
-                 "ascii": "⬜⬛",
-             },
-             5: {
-                 "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
-                 "ascii": "▏▎▍▌▋▊▉█",
-             },
-         }
-
-         # Get style configuration
-         style_config = style_presets.get(style, style_presets[1])
-
-         return tqdm(
-             total=total,
-             desc=desc,
-             colour=color,
-             bar_format=style_config["bar_format"],
-             ascii=style_config["ascii"],
-             ncols=80,
-             dynamic_ncols=True,
-             smoothing=0.3,
-             leave=True,
-             postfix=postfix,
-         )
-
      def list_jobs(self):
          """
          List all jobs.
@@ -856,56 +683,36 @@ class Sutro:
          This method retrieves a list of all jobs associated with the API key.

          Returns:
-             list: A list of job details.
+             list: A list of job details, or None if the request fails.
          """
-         endpoint = f"{self.base_url}/list-jobs"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
-
          with yaspin(
              SPINNER, text=to_colored_text("Fetching jobs"), color=YASPIN_COLOR
          ) as spinner:
-             response = requests.get(endpoint, headers=headers)
-             if response.status_code != 200:
+             try:
+                 return self._list_all_jobs_for_user()
+             except requests.HTTPError as e:
                  spinner.write(
                      to_colored_text(
-                         f"Bad status code: {response.status_code}", state="fail"
+                         f"Bad status code: {e.response.status_code}", state="fail"
                      )
                  )
                  spinner.stop()
-                 print(to_colored_text(response.json(), state="fail"))
-                 return
-         return response.json()["jobs"]
+                 print(to_colored_text(e.response.json(), state="fail"))
+                 return None

-     def _list_jobs_helper(self):
-         """
-         Helper function to list jobs.
-         """
-         endpoint = f"{self.base_url}/list-jobs˚"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
-         response = requests.get(endpoint, headers=headers)
-         if response.status_code != 200:
-             return None
+     def _list_all_jobs_for_user(self):
+         response = self.do_request("GET", "list-jobs")
          return response.json()["jobs"]

      def _fetch_job(self, job_id):
          """
          Helper function to fetch a single job.
          """
-         endpoint = f"{self.base_url}/jobs/{job_id}"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
-         response = requests.get(endpoint, headers=headers)
-         if response.status_code != 200:
+         try:
+             response = self.do_request("GET", f"jobs/{job_id}")
+             return response.json().get("job")
+         except requests.HTTPError:
              return None
-         return response.json().get("job")

      def _get_job_cost_estimate(self, job_id: str):
          """
@@ -939,15 +746,7 @@ class Sutro:
          Raises:
              requests.HTTPError: If the API returns a non-200 status code.
          """
-         endpoint = f"{self.base_url}/job-status/{job_id}"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
-
-         response = requests.get(endpoint, headers=headers)
-         response.raise_for_status()
-
+         response = self.do_request("GET", f"job-status/{job_id}")
          return response.json()["job_status"][job_id]

      def get_job_status(self, job_id: str):
@@ -992,7 +791,7 @@ class Sutro:
          output_column: str = "inference_result",
          disable_cache: bool = False,
          unpack_json: bool = True,
-     ):
+     ) -> pl.DataFrame | pd.DataFrame:
          """
          Get the results of a job by its ID.

@@ -1029,44 +828,37 @@ class Sutro:
                  to_colored_text("✔ Results loaded from cache", state="success")
              )
          else:
-             endpoint = f"{self.base_url}/job-results"
              payload = {
                  "job_id": job_id,
                  "include_inputs": include_inputs,
                  "include_cumulative_logprobs": include_cumulative_logprobs,
              }
-             headers = {
-                 "Authorization": f"Key {self.api_key}",
-                 "Content-Type": "application/json",
-             }
              with yaspin(
                  SPINNER,
                  text=to_colored_text(f"Gathering results from job: {job_id}"),
                  color=YASPIN_COLOR,
              ) as spinner:
-                 response = requests.post(
-                     endpoint, data=json.dumps(payload), headers=headers
-                 )
-                 if response.status_code != 200:
+                 try:
+                     response = self.do_request("POST", "job-results", json=payload)
+                     response_data = response.json()
+                     spinner.write(
+                         to_colored_text("✔ Job results retrieved", state="success")
+                     )
+                 except requests.HTTPError as e:
                      spinner.write(
                          to_colored_text(
-                             f"Bad status code: {response.status_code}", state="fail"
+                             f"Bad status code: {e.response.status_code}", state="fail"
                          )
                      )
                      spinner.stop()
-                     print(to_colored_text(response.json(), state="fail"))
+                     print(to_colored_text(e.response.json(), state="fail"))
                      return None

-                 spinner.write(
-                     to_colored_text("✔ Job results retrieved", state="success")
-                 )
-
-                 response_data = response.json()
                  results_df = pl.DataFrame(response_data["results"])

                  results_df = results_df.rename({"outputs": output_column})

-                 if disable_cache == False:
+                 if not disable_cache:
                      os.makedirs(os.path.dirname(file_path), exist_ok=True)
                      results_df.write_parquet(file_path, compression="snappy")
                      spinner.write(
1093
885
  first_row = json.loads(
1094
886
  results_df.head(1)[output_column][0]
1095
887
  ) # checks if the first row can be json decoded
1096
- results_df = results_df.map_columns(output_column, lambda s: s.str.json_decode())
888
+ results_df = results_df.map_columns(
889
+ output_column, lambda s: s.str.json_decode()
890
+ )
1097
891
  results_df = results_df.with_columns(
1098
- pl.col(output_column)
1099
- .alias("output_column_json_decoded")
892
+ pl.col(output_column).alias("output_column_json_decoded")
1100
893
  )
1101
894
  json_decoded_fields = first_row.keys()
1102
895
  for field in json_decoded_fields:
@@ -1105,19 +898,20 @@ class Sutro:
1105
898
  .struct.field(field)
1106
899
  .alias(field)
1107
900
  )
1108
- if sorted(list(set(json_decoded_fields))) == ['content', 'reasoning_content']: # if it's a reasoning model, we need to unpack the content field
1109
- content_keys = results_df.head(1)['content'][0].keys()
901
+ if sorted(list(set(json_decoded_fields))) == [
902
+ "content",
903
+ "reasoning_content",
904
+ ]: # if it's a reasoning model, we need to unpack the content field
905
+ content_keys = results_df.head(1)["content"][0].keys()
1110
906
  for key in content_keys:
1111
907
  results_df = results_df.with_columns(
1112
- pl.col("content")
1113
- .struct.field(key)
1114
- .alias(key)
908
+ pl.col("content").struct.field(key).alias(key)
1115
909
  )
1116
910
  results_df = results_df.drop("content")
1117
911
  results_df = results_df.drop(
1118
912
  [output_column, "output_column_json_decoded"]
1119
913
  )
1120
- except Exception as e:
914
+ except Exception:
1121
915
  # if the first row cannot be json decoded, do nothing
1122
916
  pass
1123
917
 
@@ -1153,25 +947,20 @@ class Sutro:
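With the default unpack_json=True, a structured-output job comes back with its JSON fields exploded into top-level columns (and for reasoning models, the nested content struct is unpacked as well). A sketch, reusing the client and the illustrative Sentiment schema from the earlier sketches:

    results = so.get_job_results(job_id, output_column="sentiment")

    # results is a Polars DataFrame; with a {"label", "confidence"} schema it now
    # has `label` and `confidence` columns. Pass unpack_json=False to keep the
    # raw JSON strings in the `sentiment` column instead.
    print(results.head())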
1153
947
  Returns:
1154
948
  dict: The status of the job.
1155
949
  """
1156
- endpoint = f"{self.base_url}/job-cancel/{job_id}"
1157
- headers = {
1158
- "Authorization": f"Key {self.api_key}",
1159
- "Content-Type": "application/json",
1160
- }
1161
950
  with yaspin(
1162
951
  SPINNER,
1163
952
  text=to_colored_text(f"Cancelling job: {job_id}"),
1164
953
  color=YASPIN_COLOR,
1165
954
  ) as spinner:
1166
- response = requests.get(endpoint, headers=headers)
1167
- if response.status_code == 200:
955
+ try:
956
+ response = self.do_request("GET", f"job-cancel/{job_id}")
1168
957
  spinner.write(to_colored_text("✔ Job cancelled", state="success"))
1169
- else:
958
+ return response.json()
959
+ except requests.HTTPError as e:
1170
960
  spinner.write(to_colored_text("Failed to cancel job", state="fail"))
1171
961
  spinner.stop()
1172
- print(to_colored_text(response.json(), state="fail"))
1173
- return
1174
- return response.json()
962
+ print(to_colored_text(e.response.json(), state="fail"))
963
+ return None
1175
964
 
1176
965
  def create_dataset(self):
1177
966
  """
@@ -1182,31 +971,27 @@ class Sutro:
          Returns:
              str: The ID of the new dataset.
          """
-         endpoint = f"{self.base_url}/create-dataset"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
          with yaspin(
              SPINNER, text=to_colored_text("Creating dataset"), color=YASPIN_COLOR
          ) as spinner:
-             response = requests.get(endpoint, headers=headers)
-             if response.status_code != 200:
+             try:
+                 response = self.do_request("GET", "create-dataset")
+                 dataset_id = response.json()["dataset_id"]
                  spinner.write(
                      to_colored_text(
-                         f"Bad status code: {response.status_code}", state="fail"
+                         f" Dataset created with ID: {dataset_id}", state="success"
                      )
                  )
-                 spinner.stop()
-                 print(to_colored_text(response.json(), state="fail"))
-                 return
-             dataset_id = response.json()["dataset_id"]
-             spinner.write(
-                 to_colored_text(
-                     f"✔ Dataset created with ID: {dataset_id}", state="success"
+                 return dataset_id
+             except requests.HTTPError as e:
+                 spinner.write(
+                     to_colored_text(
+                         f"Bad status code: {e.response.status_code}", state="fail"
+                     )
                  )
-             )
-             return dataset_id
+                 spinner.stop()
+                 print(to_colored_text(e.response.json(), state="fail"))
+                 return None

      def upload_to_dataset(
          self,
@@ -1238,8 +1023,6 @@ class Sutro:
          if dataset_id is None:
              dataset_id = self.create_dataset()

-         endpoint = f"{self.base_url}/upload-to-dataset"
-
          if isinstance(file_paths, str):
              # check if the file path is a directory
              if os.path.isdir(file_paths):
@@ -1272,8 +1055,6 @@ class Sutro:
                      "dataset_id": dataset_id,
                  }

-                 headers = {"Authorization": f"Key {self.api_key}"}
-
                  count += 1
                  spinner.write(
                      to_colored_text(
1282
1063
  )
1283
1064
 
1284
1065
  try:
1285
- response = requests.post(
1286
- endpoint, headers=headers, data=payload, files=files
1066
+ self.do_request(
1067
+ "POST",
1068
+ "/upload-to-dataset",
1069
+ data=payload,
1070
+ files=files,
1071
+ verify=verify_ssl,
1287
1072
  )
1288
- if response.status_code != 200:
1289
- # Stop spinner before showing error to avoid terminal width error
1290
- spinner.stop()
1291
- print(
1292
- to_colored_text(
1293
- f"Error: HTTP {response.status_code}", state="fail"
1294
- )
1295
- )
1296
- print(to_colored_text(response.json(), state="fail"))
1297
- return
1298
-
1299
1073
  except requests.exceptions.RequestException as e:
1300
1074
  # Stop spinner before showing error to avoid terminal width error
1301
1075
  spinner.stop()
1302
1076
  print(to_colored_text(f"Upload failed: {str(e)}", state="fail"))
1303
- return
1077
+ return None
1304
1078
 
1305
1079
  spinner.write(
1306
1080
  to_colored_text(
@@ -1310,32 +1084,23 @@ class Sutro:
1310
1084
  return dataset_id
1311
1085
 
1312
1086
  def list_datasets(self):
1313
- endpoint = f"{self.base_url}/list-datasets"
1314
- headers = {
1315
- "Authorization": f"Key {self.api_key}",
1316
- "Content-Type": "application/json",
1317
- }
1318
1087
  with yaspin(
1319
1088
  SPINNER, text=to_colored_text("Retrieving datasets"), color=YASPIN_COLOR
1320
1089
  ) as spinner:
1321
- response = requests.post(endpoint, headers=headers)
1322
- if response.status_code != 200:
1090
+ try:
1091
+ response = self.do_request("POST", "list-datasets")
1092
+ spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
1093
+ return response.json()["datasets"]
1094
+ except requests.HTTPError as e:
1323
1095
  spinner.fail(
1324
1096
  to_colored_text(
1325
- f"Bad status code: {response.status_code}", state="fail"
1097
+ f"Bad status code: {e.response.status_code}", state="fail"
1326
1098
  )
1327
1099
  )
1328
- print(to_colored_text(f"Error: {response.json()}", state="fail"))
1329
- return
1330
- spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
1331
- return response.json()["datasets"]
1100
+ print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
1101
+ return None
1332
1102
 
1333
1103
  def list_dataset_files(self, dataset_id: str):
1334
- endpoint = f"{self.base_url}/list-dataset-files"
1335
- headers = {
1336
- "Authorization": f"Key {self.api_key}",
1337
- "Content-Type": "application/json",
1338
- }
1339
1104
  payload = {
1340
1105
  "dataset_id": dataset_id,
1341
1106
  }
@@ -1344,23 +1109,22 @@ class Sutro:
1344
1109
  text=to_colored_text(f"Listing files in dataset: {dataset_id}"),
1345
1110
  color=YASPIN_COLOR,
1346
1111
  ) as spinner:
1347
- response = requests.post(
1348
- endpoint, headers=headers, data=json.dumps(payload)
1349
- )
1350
- if response.status_code != 200:
1351
- spinner.fail(
1112
+ try:
1113
+ response = self.do_request("POST", "list-dataset-files", json=payload)
1114
+ spinner.write(
1352
1115
  to_colored_text(
1353
- f"Bad status code: {response.status_code}", state="fail"
1116
+ f" Files listed in dataset: {dataset_id}", state="success"
1354
1117
  )
1355
1118
  )
1356
- print(to_colored_text(f"Error: {response.json()}", state="fail"))
1357
- return
1358
- spinner.write(
1359
- to_colored_text(
1360
- f" Files listed in dataset: {dataset_id}", state="success"
1119
+ return response.json()["files"]
1120
+ except requests.HTTPError as e:
1121
+ spinner.fail(
1122
+ to_colored_text(
1123
+ f"Bad status code: {e.response.status_code}", state="fail"
1124
+ )
1361
1125
  )
1362
- )
1363
- return response.json()["files"]
1126
+ print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
1127
+ return None
1364
1128
 
1365
1129
  def download_from_dataset(
1366
1130
  self,
@@ -1368,8 +1132,6 @@ class Sutro:
          files: Union[List[str], str] = None,
          output_path: str = None,
      ):
-         endpoint = f"{self.base_url}/download-from-dataset"
-
          if files is None:
              files = self.list_dataset_files(dataset_id)
          elif isinstance(files, str):
@@ -1394,32 +1156,32 @@ class Sutro:
          ) as spinner:
              count = 0
              for file in files:
-                 headers = {
-                     "Authorization": f"Key {self.api_key}",
-                     "Content-Type": "application/json",
-                 }
-                 payload = {
-                     "dataset_id": dataset_id,
-                     "file_name": file,
-                 }
                  spinner.text = to_colored_text(
                      f"Downloading file {count + 1}/{len(files)} from dataset: {dataset_id}"
                  )
-                 response = requests.post(
-                     endpoint, headers=headers, data=json.dumps(payload)
-                 )
-                 if response.status_code != 200:
+
+                 try:
+                     payload = {
+                         "dataset_id": dataset_id,
+                         "file_name": file,
+                     }
+                     response = self.do_request(
+                         "POST", "download-from-dataset", json=payload
+                     )
+
+                     file_content = response.content
+                     with open(os.path.join(output_path, file), "wb") as f:
+                         f.write(file_content)
+
+                     count += 1
+                 except requests.HTTPError as e:
                      spinner.fail(
                          to_colored_text(
-                             f"Bad status code: {response.status_code}", state="fail"
+                             f"Bad status code: {e.response.status_code}", state="fail"
                          )
                      )
-                     print(to_colored_text(f"Error: {response.json()}", state="fail"))
+                     print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
                      return
-                 file_content = response.content
-                 with open(os.path.join(output_path, file), "wb") as f:
-                     f.write(file_content)
-                 count += 1
              spinner.write(
                  to_colored_text(
                      f"✔ {count} files successfully downloaded from dataset: {dataset_id}",
@@ -1439,46 +1201,38 @@ class Sutro:
          Returns:
              dict: The status of the authentication.
          """
-         endpoint = f"{self.base_url}/try-authentication"
-         headers = {
-             "Authorization": f"Key {api_key}",
-             "Content-Type": "application/json",
-         }
          with yaspin(
              SPINNER, text=to_colored_text("Checking API key"), color=YASPIN_COLOR
          ) as spinner:
-             response = requests.get(endpoint, headers=headers)
-             if response.status_code == 200:
+             try:
+                 response = self.do_request("GET", "try-authentication", api_key)
+
                  spinner.write(to_colored_text("✔"))
-             else:
+                 return response.json()
+             except requests.HTTPError as e:
                  spinner.write(
                      to_colored_text(
-                         f"API key failed to authenticate: {response.status_code}",
+                         f"API key failed to authenticate: {e.response.status_code}",
                          state="fail",
                      )
                  )
-                 return
-         return response.json()
+                 return None

      def get_quotas(self):
-         endpoint = f"{self.base_url}/get-quotas"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
          with yaspin(
              SPINNER, text=to_colored_text("Fetching quotas"), color=YASPIN_COLOR
          ) as spinner:
-             response = requests.get(endpoint, headers=headers)
-             if response.status_code != 200:
+             try:
+                 response = self.do_request("GET", "get-quotas")
+                 return response.json()["quotas"]
+             except requests.HTTPError as e:
                  spinner.fail(
                      to_colored_text(
-                         f"Bad status code: {response.status_code}", state="fail"
+                         f"Bad status code: {e.response.status_code}", state="fail"
                      )
                  )
-                 print(to_colored_text(f"Error: {response.json()}", state="fail"))
-                 return
-         return response.json()["quotas"]
+                 print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                 return None

      def await_job_completion(
          self,
@@ -1486,7 +1240,7 @@ class Sutro:
          timeout: Optional[int] = 7200,
          obtain_results: bool = True,
          is_cost_estimate: bool = False,
-     ) -> list | None:
+     ) -> pl.DataFrame | None:
          """
          Waits for job completion to occur and then returns the results upon
          a successful completion.
@@ -1502,7 +1256,7 @@ class Sutro:
          """
          POLL_INTERVAL = 5

-         results = None
+         results: pl.DataFrame | None = None
          start_time = time.time()
          with yaspin(
              SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
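await_job_completion now advertises a Polars DataFrame (or None) instead of a list, matching what get_job_results actually produces. A sketch of the detached-then-await pattern for non-prototyping jobs:

    from sutro import Sutro  # assumed import path

    so = Sutro(api_key="...")
    job_id = so.infer(["input one", "input two"], job_priority=1, stay_attached=False)

    # Polls every few seconds up to `timeout`, then fetches results on success.
    results = so.await_job_completion(job_id, timeout=7200)
    if results is not None:
        print(results.shape)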