sutro 0.1.36__py3-none-any.whl → 0.1.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sutro might be problematic. Click here for more details.

sutro/sdk.py CHANGED
@@ -1,48 +1,32 @@
1
- from enum import Enum
2
1
  import requests
3
2
  import pandas as pd
4
3
  import polars as pl
5
4
  import json
6
- from typing import Union, List, Optional, Literal, Dict, Any
5
+ from typing import Union, List, Optional, Dict, Any, Type
7
6
  import os
8
7
  import sys
9
8
  from yaspin import yaspin
10
9
  from yaspin.spinners import Spinners
11
- from colorama import init, Fore, Style
12
- from tqdm import tqdm
10
+ from colorama import init
13
11
  import time
14
12
  from pydantic import BaseModel
15
13
  import pyarrow.parquet as pq
16
14
  import shutil
15
+ from sutro.common import (
16
+ ModelOptions,
17
+ handle_data_helper,
18
+ normalize_output_schema,
19
+ to_colored_text,
20
+ fancy_tqdm,
21
+ )
22
+ from sutro.interfaces import JobStatus
23
+ from sutro.templates.classification import ClassificationTemplates
24
+ from sutro.templates.embed import EmbeddingTemplates
25
+ from sutro.validation import check_version, check_for_api_key
17
26
 
18
27
  JOB_NAME_CHAR_LIMIT = 45
19
28
  JOB_DESCRIPTION_CHAR_LIMIT = 512
20
29
 
21
- class JobStatus(str, Enum):
22
- """Job statuses that will be returned by the API & SDK"""
23
-
24
- UNKNOWN = "UNKNOWN"
25
- QUEUED = "QUEUED" # Job is waiting to start
26
- STARTING = "STARTING" # Job is in the process of starting up
27
- RUNNING = "RUNNING" # Job is actively running
28
- SUCCEEDED = "SUCCEEDED" # Job completed successfully
29
- CANCELLING = "CANCELLING" # Job is in the process of being canceled
30
- CANCELLED = "CANCELLED" # Job was canceled by the user
31
- FAILED = "FAILED" # Job failed
32
-
33
- @classmethod
34
- def terminal_statuses(cls) -> list["JobStatus"]:
35
- return [
36
- cls.SUCCEEDED,
37
- cls.FAILED,
38
- cls.CANCELLING,
39
- cls.CANCELLED,
40
- ]
41
-
42
- def is_terminal(self) -> bool:
43
- return self in self.terminal_statuses()
44
-
45
-
46
30
  # Initialize colorama (required for Windows)
47
31
  init()
48
32
 
@@ -56,57 +40,6 @@ def is_jupyter() -> bool:
56
40
  YASPIN_COLOR = None if is_jupyter() else "blue"
57
41
  SPINNER = Spinners.dots14
58
42
 
59
- # Models available for inference. Keep in sync with the backend configuration
60
- # so users get helpful autocompletion when selecting a model.
61
- ModelOptions = Literal[
62
- "llama-3.2-3b",
63
- "llama-3.1-8b",
64
- "llama-3.3-70b",
65
- "llama-3.3-70b",
66
- "qwen-3-4b",
67
- "qwen-3-14b",
68
- "qwen-3-32b",
69
- "qwen-3-30b-a3b",
70
- "qwen-3-235b-a22b",
71
- "qwen-3-4b-thinking",
72
- "qwen-3-14b-thinking",
73
- "qwen-3-32b-thinking",
74
- "qwen-3-235b-a22b-thinking",
75
- "qwen-3-30b-a3b-thinking",
76
- "gemma-3-4b-it",
77
- "gemma-3-12b-it",
78
- "gemma-3-27b-it",
79
- "gpt-oss-20b",
80
- "gpt-oss-120b",
81
- "qwen-3-embedding-0.6b",
82
- "qwen-3-embedding-6b",
83
- "qwen-3-embedding-8b",
84
- ]
85
-
86
-
87
- def to_colored_text(
88
- text: str, state: Optional[Literal["success", "fail"]] = None
89
- ) -> str:
90
- """
91
- Apply color to text based on state.
92
-
93
- Args:
94
- text (str): The text to color
95
- state (Optional[Literal['success', 'fail']]): The state that determines the color.
96
- Options: 'success', 'fail', or None (default blue)
97
-
98
- Returns:
99
- str: Text with appropriate color applied
100
- """
101
- match state:
102
- case "success":
103
- return f"{Fore.GREEN}{text}{Style.RESET_ALL}"
104
- case "fail":
105
- return f"{Fore.RED}{text}{Style.RESET_ALL}"
106
- case _:
107
- # Default to blue for normal/processing states
108
- return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
109
-
110
43
 
111
44
  # Isn't fully supported in all terminals, unfortunately. We should switch to Rich
112
45
  # at some point, but even Rich links aren't clickable on MacOS Terminal
@@ -120,36 +53,11 @@ def make_clickable_link(url, text=None):
120
53
  return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
121
54
 
122
55
 
123
- class Sutro:
56
+ class Sutro(EmbeddingTemplates, ClassificationTemplates):
124
57
  def __init__(self, api_key: str = None, base_url: str = "https://api.sutro.sh/"):
125
- self.api_key = api_key or self.check_for_api_key()
58
+ self.api_key = api_key or check_for_api_key()
126
59
  self.base_url = base_url
127
-
128
- def check_for_api_key(self):
129
- """
130
- Check for an API key in the user's home directory.
131
-
132
- This method looks for a configuration file named 'config.json' in the
133
- '.sutro' directory within the user's home directory.
134
- If the file exists, it attempts to read the API key from it.
135
-
136
- Returns:
137
- str or None: The API key if found in the configuration file, or None if not found.
138
-
139
- Note:
140
- The expected structure of the config.json file is:
141
- {
142
- "api_key": "your_api_key_here"
143
- }
144
- """
145
- CONFIG_DIR = os.path.expanduser("~/.sutro")
146
- CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
147
- if os.path.exists(CONFIG_FILE):
148
- with open(CONFIG_FILE, "r") as f:
149
- config = json.load(f)
150
- return config.get("api_key")
151
- else:
152
- return None
60
+ check_version("sutro")
153
61
 
154
62
  def set_api_key(self, api_key: str):
155
63
  """
@@ -166,79 +74,6 @@ class Sutro:
166
74
  """
167
75
  self.api_key = api_key
168
76
 
169
- def do_dataframe_column_concatenation(self, data: Union[pd.DataFrame, pl.DataFrame], column: Union[str, List[str]]):
170
- """
171
- If the user has supplied a dataframe and a list of columns, this will intelligenly concatenate the columns into a single column, accepting separator strings.
172
- """
173
- try:
174
- if isinstance(data, pd.DataFrame):
175
- series_parts = []
176
- for p in column:
177
- if p in data.columns:
178
- s = data[p].astype("string").fillna("")
179
- else:
180
- # Treat as a literal separator
181
- s = pd.Series([p] * len(data), index=data.index, dtype="string")
182
- series_parts.append(s)
183
-
184
- out = series_parts[0]
185
- for s in series_parts[1:]:
186
- out = out.str.cat(s, na_rep="")
187
-
188
- return out.tolist()
189
- elif isinstance(data, pl.DataFrame):
190
- exprs = []
191
- for p in column:
192
- if p in data.columns:
193
- exprs.append(pl.col(p).cast(pl.Utf8).fill_null(""))
194
- else:
195
- exprs.append(pl.lit(p))
196
-
197
- result = data.select(pl.concat_str(exprs, separator="", ignore_nulls=False).alias("concat"))
198
- return result["concat"].to_list()
199
- except Exception as e:
200
- raise ValueError(f"Error handling column concatentation: {e}")
201
-
202
- def handle_data_helper(
203
- self, data: Union[List, pd.DataFrame, pl.DataFrame, str], column: str = None
204
- ):
205
- if isinstance(data, list):
206
- input_data = data
207
- elif isinstance(data, (pd.DataFrame, pl.DataFrame)):
208
- if column is None:
209
- raise ValueError("Column name must be specified for DataFrame input")
210
- if isinstance(column, list):
211
- input_data = self.do_dataframe_column_concatenation(data, column)
212
- elif isinstance(column, str):
213
- input_data = data[column].to_list()
214
- elif isinstance(data, str):
215
- if data.startswith("dataset-"):
216
- input_data = data + ":" + column
217
- else:
218
- file_ext = os.path.splitext(data)[1].lower()
219
- if file_ext == ".csv":
220
- df = pl.read_csv(data)
221
- elif file_ext == ".parquet":
222
- df = pl.read_parquet(data)
223
- elif file_ext in [".txt", ""]:
224
- with open(data, "r") as file:
225
- input_data = [line.strip() for line in file]
226
- else:
227
- raise ValueError(f"Unsupported file type: {file_ext}")
228
-
229
- if file_ext in [".csv", ".parquet"]:
230
- if column is None:
231
- raise ValueError(
232
- "Column name must be specified for CSV/Parquet input"
233
- )
234
- input_data = df[column].to_list()
235
- else:
236
- raise ValueError(
237
- "Unsupported data type. Please provide a list, DataFrame, or file path."
238
- )
239
-
240
- return input_data
241
-
242
77
  def set_base_url(self, base_url: str):
243
78
  """
244
79
  Set the base URL for the Sutro API.
@@ -251,6 +86,43 @@ class Sutro:
251
86
  """
252
87
  self.base_url = base_url
253
88
 
89
+ def do_request(
90
+ self,
91
+ method: str,
92
+ endpoint: str,
93
+ api_key_override: Optional[str] = None,
94
+ **kwargs: Any,
95
+ ):
96
+ """
97
+ Helper to make authenticated requests.
98
+ """
99
+ key = self.api_key if not api_key_override else api_key_override
100
+ headers = {"Authorization": f"Key {key}"}
101
+
102
+ # Merge with any headers passed in kwargs
103
+ if "headers" in kwargs:
104
+ headers.update(kwargs.pop("headers"))
105
+
106
+ url = f"{self.base_url}/{endpoint.lstrip('/')}"
107
+
108
+ # Explicit method dispatch
109
+ method = method.upper()
110
+ if method == "GET":
111
+ response = requests.get(url, headers=headers, **kwargs)
112
+ elif method == "POST":
113
+ response = requests.post(url, headers=headers, **kwargs)
114
+ elif method == "PUT":
115
+ response = requests.put(url, headers=headers, **kwargs)
116
+ elif method == "DELETE":
117
+ response = requests.delete(url, headers=headers, **kwargs)
118
+ elif method == "PATCH":
119
+ response = requests.patch(url, headers=headers, **kwargs)
120
+ else:
121
+ raise ValueError(f"Unsupported HTTP method: {method}")
122
+
123
+ response.raise_for_status()
124
+ return response
125
+
254
126
  def _run_one_batch_inference(
255
127
  self,
256
128
  data: Union[List, pd.DataFrame, pl.DataFrame, str],
@@ -270,16 +142,15 @@ class Sutro:
270
142
  ):
271
143
  # Validate name and description lengths
272
144
  if name is not None and len(name) > JOB_NAME_CHAR_LIMIT:
273
- raise ValueError(f"Job name cannot exceed {JOB_NAME_CHAR_LIMIT} characters.")
145
+ raise ValueError(
146
+ f"Job name cannot exceed {JOB_NAME_CHAR_LIMIT} characters."
147
+ )
274
148
  if description is not None and len(description) > JOB_DESCRIPTION_CHAR_LIMIT:
275
- raise ValueError(f"Job description cannot exceed {JOB_DESCRIPTION_CHAR_LIMIT} characters.")
149
+ raise ValueError(
150
+ f"Job description cannot exceed {JOB_DESCRIPTION_CHAR_LIMIT} characters."
151
+ )
276
152
 
277
- input_data = self.handle_data_helper(data, column)
278
- endpoint = f"{self.base_url}/batch-inference"
279
- headers = {
280
- "Authorization": f"Key {self.api_key}",
281
- "Content-Type": "application/json",
282
- }
153
+ input_data = handle_data_helper(data, column)
283
154
  payload = {
284
155
  "model": model,
285
156
  "inputs": input_data,
@@ -305,16 +176,19 @@ class Sutro:
305
176
  spinner_text = to_colored_text(t)
306
177
  try:
307
178
  with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
308
- response = requests.post(
309
- endpoint, data=json.dumps(payload), headers=headers
310
- )
311
- response_data = response.json()
179
+ try:
180
+ response = self.do_request("POST", "batch-inference", json=payload)
181
+ response_data = response.json()
182
+ except requests.HTTPError as e:
183
+ response = e.response
184
+ response_data = response.json()
185
+
312
186
  if response.status_code != 200:
313
187
  spinner.write(
314
188
  to_colored_text(f"Error: {response.status_code}", state="fail")
315
189
  )
316
190
  spinner.stop()
317
- print(to_colored_text(response.json(), state="fail"))
191
+ print(to_colored_text(response_data, state="fail"))
318
192
  return None
319
193
  else:
320
194
  job_id = response_data["results"]
@@ -340,10 +214,11 @@ class Sutro:
340
214
  name_text = f" and name {name}" if name is not None else ""
341
215
  spinner.write(
342
216
  to_colored_text(
343
- f"🛠 Priority {job_priority} Job created with ID: {job_id}{name_text}.",
217
+ f"🛠 Priority {job_priority} Job created with ID: {job_id}{name_text}",
344
218
  state="success",
345
219
  )
346
220
  )
221
+ spinner.write(to_colored_text(f"Model: {model}"))
347
222
  if not stay_attached:
348
223
  clickable_link = make_clickable_link(
349
224
  f"https://app.sutro.sh/jobs/{job_id}"
@@ -380,13 +255,13 @@ class Sutro:
380
255
  )
381
256
  )
382
257
  return None
383
- s = requests.Session()
258
+
384
259
  pbar = None
385
260
 
386
261
  try:
387
- with requests.get(
388
- f"{self.base_url}/stream-job-progress/{job_id}",
389
- headers=headers,
262
+ with self.do_request(
263
+ "GET",
264
+ f"/stream-job-progress/{job_id}",
390
265
  stream=True,
391
266
  ) as streaming_response:
392
267
  streaming_response.raise_for_status()
@@ -415,7 +290,7 @@ class Sutro:
415
290
  if pbar is None:
416
291
  spinner.stop()
417
292
  postfix = "Input tokens processed: 0"
418
- pbar = self.fancy_tqdm(
293
+ pbar = fancy_tqdm(
419
294
  total=len(input_data),
420
295
  desc="Progress",
421
296
  style=1,
@@ -456,28 +331,27 @@ class Sutro:
456
331
  )
457
332
  spinner.start()
458
333
 
459
- payload = {
460
- "job_id": job_id,
461
- }
462
-
463
334
  # TODO: we implement retries in cases where the job hasn't written results yet
464
335
  # it would be better if we could receive a fully succeeded status from the job
465
336
  # and not have such a race condition
466
337
  max_retries = 20 # winds up being 100 seconds cumulative delay
467
338
  retry_delay = 5 # initial delay in seconds
468
-
339
+ job_results_response = None
469
340
  for _ in range(max_retries):
470
- time.sleep(retry_delay)
471
-
472
- job_results_response = s.post(
473
- f"{self.base_url}/job-results",
474
- headers=headers,
475
- data=json.dumps(payload),
476
- )
477
- if job_results_response.status_code == 200:
341
+ try:
342
+ job_results_response = self.do_request(
343
+ "POST",
344
+ "job-results",
345
+ json={
346
+ "job_id": job_id,
347
+ },
348
+ )
478
349
  break
350
+ except requests.HTTPError:
351
+ time.sleep(retry_delay)
352
+ continue
479
353
 
480
- if job_results_response.status_code != 200:
354
+ if not job_results_response or job_results_response.status_code != 200:
481
355
  spinner.write(
482
356
  to_colored_text(
483
357
  "Job succeeded, but results are not yet available. Use `so.get_job_results('{job_id}')` to obtain results.",
@@ -489,122 +363,183 @@ class Sutro:
489
363
 
490
364
  results = job_results_response.json()["results"]["outputs"]
491
365
 
492
- spinner.write(
493
- to_colored_text(
494
- f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
495
- state="success",
496
- )
497
- )
498
- spinner.stop()
499
-
500
366
  if isinstance(data, (pd.DataFrame, pl.DataFrame)):
501
367
  if isinstance(data, pd.DataFrame):
502
368
  data[output_column] = results
503
369
  elif isinstance(data, pl.DataFrame):
504
370
  data = data.with_columns(pl.Series(output_column, results))
505
- return data
371
+ print(data)
372
+ spinner.write(
373
+ to_colored_text(
374
+ f"✔ Displaying result preview. You can join the results on the original dataframe with `so.get_job_results('{job_id}', with_original_df=<original_df>)`",
375
+ state="success",
376
+ )
377
+ )
378
+ else:
379
+ print(results)
380
+ spinner.write(
381
+ to_colored_text(
382
+ f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
383
+ state="success",
384
+ )
385
+ )
386
+ spinner.stop()
506
387
 
507
- return results
388
+ return job_id
508
389
  return None
509
390
  return None
510
391
 
511
392
  def infer(
512
393
  self,
513
394
  data: Union[List, pd.DataFrame, pl.DataFrame, str],
514
- model: Union[ModelOptions, List[ModelOptions]] = "gemma-3-12b-it",
515
- name: Union[str, List[str]] = None,
516
- description: Union[str, List[str]] = None,
395
+ model: ModelOptions = "gemma-3-12b-it",
396
+ name: Optional[str] = None,
397
+ description: Optional[str] = None,
517
398
  column: Union[str, List[str]] = None,
518
399
  output_column: str = "inference_result",
519
400
  job_priority: int = 0,
520
- output_schema: Union[Dict[str, Any], BaseModel] = None,
401
+ output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
521
402
  sampling_params: dict = None,
522
403
  system_prompt: str = None,
523
404
  dry_run: bool = False,
524
405
  stay_attached: Optional[bool] = None,
525
406
  random_seed_per_input: bool = False,
526
- truncate_rows: bool = False,
407
+ truncate_rows: bool = True,
527
408
  ):
528
409
  """
529
410
  Run inference on the provided data.
530
411
 
531
412
  This method allows you to run inference on the provided data using the Sutro API.
532
- It supports various data types such as lists, pandas DataFrames, polars DataFrames, file paths and datasets.
413
+ It supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.
533
414
 
534
415
  Args:
535
416
  data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
536
- model (Union[ModelOptions, List[ModelOptions]], optional): The model(s) to use for inference. Defaults to "llama-3.1-8b". You can pass a single model or a list of models. In the case of a list, the inference will be run in parallel for each model and stay_attached will be set to False.
537
- name (Union[str, List[str]], optional): A job name for experiment/metadata tracking purposes. If using a list of models, you must pass a list of names with length equal to the number of models, or None. Defaults to None.
538
- description (Union[str, List[str]], optional): A job description for experiment/metadata tracking purposes. If using a list of models, you must pass a list of descriptions with length equal to the number of models, or None. Defaults to None.
417
+ model (ModelOptions, optional): The model to use for inference. Defaults to "gemma-3-12b-it".
418
+ name (str, optional): A job name for experiment/metadata tracking purposes. Defaults to None.
419
+ description (str, optional): A job description for experiment/metadata tracking purposes. Defaults to None.
539
420
  column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, it will concatenate the columns of the list into a single column, accepting separator strings.
540
421
  output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
541
422
  job_priority (int, optional): The priority of the job. Defaults to 0.
542
423
  output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
543
- Can be either a dictionary representing a JSON schema or a pydantic BaseModel. Defaults to None.
424
+ Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
544
425
  sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
545
426
  system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
546
427
  dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
547
428
  stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
548
429
  random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
549
- truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to False.
430
+ truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to True.
550
431
 
551
432
  Returns:
552
- Union[List, pd.DataFrame, pl.DataFrame, str]: The results of the inference.
433
+ str: The ID of the inference job.
553
434
 
554
435
  """
555
- if isinstance(model, list) == False:
556
- model_list = [model]
557
- stay_attached = (
558
- stay_attached if stay_attached is not None else job_priority == 0
559
- )
560
- else:
561
- model_list = model
562
- stay_attached = False
563
-
564
- if isinstance(model_list, list):
565
- if isinstance(name, list):
566
- if len(name) != len(model_list):
567
- raise ValueError("Name list must be the same length as the model list.")
568
- name_list = name
569
- elif isinstance(name, str):
570
- raise ValueError("Name must be a list if using a list of models.")
571
- else:
572
- if isinstance(name, list):
573
- raise ValueError("Name must be a string or None if using a single model.")
574
- name_list = [name]
575
-
576
- if isinstance(model_list, list):
577
- if isinstance(description, list):
578
- if len(description) != len(model_list):
579
- raise ValueError("Descriptions list must be the same length as the model list.")
580
- description_list = description
581
- elif isinstance(description, str):
582
- raise ValueError("Description must be a list if using a list of models.")
436
+ # Default stay_attached to True for prototyping jobs (priority 0)
437
+ if stay_attached is None:
438
+ stay_attached = job_priority == 0
439
+
440
+ json_schema = None
441
+ if output_schema:
442
+ # Convert BaseModel to dict if needed
443
+ json_schema = normalize_output_schema(output_schema)
444
+
445
+ return self._run_one_batch_inference(
446
+ data,
447
+ model,
448
+ column,
449
+ output_column,
450
+ job_priority,
451
+ json_schema,
452
+ sampling_params,
453
+ system_prompt,
454
+ dry_run,
455
+ stay_attached,
456
+ random_seed_per_input,
457
+ truncate_rows,
458
+ name,
459
+ description,
460
+ )
461
+
462
+ def infer_per_model(
463
+ self,
464
+ data: Union[List, pd.DataFrame, pl.DataFrame, str],
465
+ models: List[ModelOptions],
466
+ names: List[str] = None,
467
+ descriptions: List[str] = None,
468
+ column: Union[str, List[str]] = None,
469
+ output_column: str = "inference_result",
470
+ job_priority: int = 0,
471
+ output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
472
+ sampling_params: dict = None,
473
+ system_prompt: str = None,
474
+ dry_run: bool = False,
475
+ random_seed_per_input: bool = False,
476
+ truncate_rows: bool = True,
477
+ ):
478
+ """
479
+ Run inference on the provided data, across multiple models. This method is often useful for sampling outputs from multiple models across the same dataset and comparing the job_ids.
480
+
481
+ For input data, it supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.
482
+
483
+ Args:
484
+ data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
485
+ models (List[ModelOptions]): The models to use for inference. Fans out each model to its own separate job, over the same dataset.
486
+ names (Union[str, List[str]], optional): A job name for experiment/metadata tracking purposes. If using a list of models, you must pass a list of names with length equal to the number of models, or None. Defaults to None.
487
+ descriptions (Union[str, List[str]], optional): A job description for experiment/metadata tracking purposes. If using a list of models, you must pass a list of descriptions with length equal to the number of models, or None. Defaults to None.
488
+ column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, it will concatenate the columns of the list into a single column, accepting separator strings.
489
+ output_column (str, optional): The column name to store the inference job_ids in if the input is a DataFrame. Defaults to "inference_result".
490
+ job_priority (int, optional): The priority of the job. Defaults to 0.
491
+ output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
492
+ Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
493
+ sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
494
+ system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
495
+ dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
496
+ stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
497
+ random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
498
+ truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to True.
499
+
500
+ Returns:
501
+ list: One job ID per model, in the same order as `models`.
502
+
503
+ """
504
+ if isinstance(names, list):
505
+ if len(names) != len(models):
506
+ raise ValueError(
507
+ "names parameter must be the same length as the models parameter."
508
+ )
509
+ elif names is None:
510
+ names = [None] * len(models)
583
511
  else:
584
- if isinstance(name, list):
585
- raise ValueError("Description must be a string or None if using a single model.")
586
- description_list = [description]
587
-
588
- # Convert BaseModel to dict if needed
589
- if output_schema is not None:
590
- if hasattr(
591
- output_schema, "model_json_schema"
592
- ): # Check for pydantic Model interface
593
- json_schema = output_schema.model_json_schema()
594
- elif isinstance(output_schema, dict):
595
- json_schema = output_schema
596
- else:
512
+ raise ValueError(
513
+ "names parameter must be a list or None if using a list of models"
514
+ )
515
+
516
+ if isinstance(descriptions, list):
517
+ if len(descriptions) != len(models):
597
518
  raise ValueError(
598
- "Invalid output schema type. Must be a dictionary or a pydantic Model."
519
+ "descriptions parameter must be the same length as the models"
520
+ " parameter."
599
521
  )
522
+ elif descriptions is None:
523
+ descriptions = [None] * len(models)
600
524
  else:
601
- json_schema = None
602
-
603
- results = []
604
- for i in range(len(model_list)):
605
- res = self._run_one_batch_inference(
525
+ raise ValueError(
526
+ "descriptions parameter must be a list or None if using a list of "
527
+ "models"
528
+ )
529
+
530
+ json_schema = None
531
+ if output_schema:
532
+ # Convert BaseModel to dict if needed
533
+ json_schema = normalize_output_schema(output_schema)
534
+
535
+ def start_job(
536
+ model_singleton: ModelOptions,
537
+ name_singleton: str | None,
538
+ description_singleton: str | None,
539
+ ):
540
+ return self._run_one_batch_inference(
606
541
  data,
607
- model_list[i],
542
+ model_singleton,
608
543
  column,
609
544
  output_column,
610
545
  job_priority,
@@ -612,20 +547,21 @@ class Sutro:
612
547
  sampling_params,
613
548
  system_prompt,
614
549
  dry_run,
615
- stay_attached,
550
+ False,
616
551
  random_seed_per_input,
617
552
  truncate_rows,
618
- name_list[i],
619
- description_list[i],
553
+ name_singleton,
554
+ description_singleton,
620
555
  )
621
- results.append(res)
622
556
 
623
- if len(results) > 1:
624
- return results
625
- elif len(results) == 1:
626
- return results[0]
557
+ job_ids = [
558
+ start_job(model, name, description)
559
+ for model, name, description in zip(
560
+ models, names, descriptions, strict=True
561
+ )
562
+ ]
627
563
 
628
- return None
564
+ return job_ids
629
565
 
630
566
  def attach(self, job_id):
631
567
  """
@@ -636,16 +572,8 @@ class Sutro:
636
572
  """
637
573
 
638
574
  s = requests.Session()
639
- payload = {
640
- "job_id": job_id,
641
- }
642
575
  pbar = None
643
576
 
644
- headers = {
645
- "Authorization": f"Key {self.api_key}",
646
- "Content-Type": "application/json",
647
- }
648
-
649
577
  with yaspin(
650
578
  SPINNER,
651
579
  text=to_colored_text("Looking for job..."),
@@ -683,9 +611,9 @@ class Sutro:
683
611
  success = False
684
612
 
685
613
  try:
686
- with s.get(
687
- f"{self.base_url}/stream-job-progress/{job_id}",
688
- headers=headers,
614
+ with self.do_request(
615
+ "GET",
616
+ f"/stream-job-progress/{job_id}",
689
617
  stream=True,
690
618
  ) as streaming_response:
691
619
  streaming_response.raise_for_status()
@@ -715,7 +643,7 @@ class Sutro:
715
643
  if pbar is None:
716
644
  spinner.stop()
717
645
  postfix = "Input tokens processed: 0"
718
- pbar = self.fancy_tqdm(
646
+ pbar = fancy_tqdm(
719
647
  total=total_rows,
720
648
  desc="Progress",
721
649
  style=1,
@@ -748,65 +676,6 @@ class Sutro:
748
676
  if spinner:
749
677
  spinner.stop()
750
678
 
751
- def fancy_tqdm(
752
- self,
753
- total: int,
754
- desc: str = "Progress",
755
- color: str = "blue",
756
- style=1,
757
- postfix: str = None,
758
- ):
759
- """
760
- Creates a customized tqdm progress bar with different styling options.
761
-
762
- Args:
763
- total (int): Total iterations
764
- desc (str): Description for the progress bar
765
- color (str): Color of the progress bar (green, blue, red, yellow, magenta)
766
- style (int): Style preset (1-4)
767
- postfix (str): Postfix for the progress bar
768
- """
769
-
770
- # Style presets
771
- style_presets = {
772
- 1: {
773
- "bar_format": "{l_bar}{bar:30}| {n_fmt}/{total_fmt} | {percentage:3.0f}% {postfix}",
774
- "ascii": "░▒█",
775
- },
776
- 2: {
777
- "bar_format": "╢{l_bar}{bar:30}╟ {percentage:3.0f}%",
778
- "ascii": "▁▂▃▄▅▆▇█",
779
- },
780
- 3: {
781
- "bar_format": "{desc}: |{bar}| {percentage:3.0f}% [{elapsed}<{remaining}]",
782
- "ascii": "◯◔◑◕●",
783
- },
784
- 4: {
785
- "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
786
- "ascii": "⬜⬛",
787
- },
788
- 5: {
789
- "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
790
- "ascii": "▏▎▍▌▋▊▉█",
791
- },
792
- }
793
-
794
- # Get style configuration
795
- style_config = style_presets.get(style, style_presets[1])
796
-
797
- return tqdm(
798
- total=total,
799
- desc=desc,
800
- colour=color,
801
- bar_format=style_config["bar_format"],
802
- ascii=style_config["ascii"],
803
- ncols=80,
804
- dynamic_ncols=True,
805
- smoothing=0.3,
806
- leave=True,
807
- postfix=postfix,
808
- )
809
-
810
679
  def list_jobs(self):
811
680
  """
812
681
  List all jobs.
@@ -814,56 +683,36 @@ class Sutro:
814
683
  This method retrieves a list of all jobs associated with the API key.
815
684
 
816
685
  Returns:
817
- list: A list of job details.
686
+ list: A list of job details, or None if the request fails.
818
687
  """
819
- endpoint = f"{self.base_url}/list-jobs"
820
- headers = {
821
- "Authorization": f"Key {self.api_key}",
822
- "Content-Type": "application/json",
823
- }
824
-
825
688
  with yaspin(
826
689
  SPINNER, text=to_colored_text("Fetching jobs"), color=YASPIN_COLOR
827
690
  ) as spinner:
828
- response = requests.get(endpoint, headers=headers)
829
- if response.status_code != 200:
691
+ try:
692
+ return self._list_all_jobs_for_user()
693
+ except requests.HTTPError as e:
830
694
  spinner.write(
831
695
  to_colored_text(
832
- f"Bad status code: {response.status_code}", state="fail"
696
+ f"Bad status code: {e.response.status_code}", state="fail"
833
697
  )
834
698
  )
835
699
  spinner.stop()
836
- print(to_colored_text(response.json(), state="fail"))
837
- return
838
- return response.json()["jobs"]
700
+ print(to_colored_text(e.response.json(), state="fail"))
701
+ return None
839
702
 
840
- def _list_jobs_helper(self):
841
- """
842
- Helper function to list jobs.
843
- """
844
- endpoint = f"{self.base_url}/list-jobs˚"
845
- headers = {
846
- "Authorization": f"Key {self.api_key}",
847
- "Content-Type": "application/json",
848
- }
849
- response = requests.get(endpoint, headers=headers)
850
- if response.status_code != 200:
851
- return None
703
+ def _list_all_jobs_for_user(self):
704
+ response = self.do_request("GET", "list-jobs")
852
705
  return response.json()["jobs"]
853
706
 
854
707
  def _fetch_job(self, job_id):
855
708
  """
856
709
  Helper function to fetch a single job.
857
710
  """
858
- endpoint = f"{self.base_url}/jobs/{job_id}"
859
- headers = {
860
- "Authorization": f"Key {self.api_key}",
861
- "Content-Type": "application/json",
862
- }
863
- response = requests.get(endpoint, headers=headers)
864
- if response.status_code != 200:
711
+ try:
712
+ response = self.do_request("GET", f"jobs/{job_id}")
713
+ return response.json().get("job")
714
+ except requests.HTTPError:
865
715
  return None
866
- return response.json().get("job")
867
716
 
868
717
  def _get_job_cost_estimate(self, job_id: str):
869
718
  """
@@ -897,15 +746,7 @@ class Sutro:
897
746
  Raises:
898
747
  requests.HTTPError: If the API returns a non-200 status code.
899
748
  """
900
- endpoint = f"{self.base_url}/job-status/{job_id}"
901
- headers = {
902
- "Authorization": f"Key {self.api_key}",
903
- "Content-Type": "application/json",
904
- }
905
-
906
- response = requests.get(endpoint, headers=headers)
907
- response.raise_for_status()
908
-
749
+ response = self.do_request("GET", f"job-status/{job_id}")
909
750
  return response.json()["job_status"][job_id]
910
751
 
911
752
  def get_job_status(self, job_id: str):
@@ -950,7 +791,7 @@ class Sutro:
950
791
  output_column: str = "inference_result",
951
792
  disable_cache: bool = False,
952
793
  unpack_json: bool = True,
953
- ):
794
+ ) -> pl.DataFrame | pd.DataFrame:
954
795
  """
955
796
  Get the results of a job by its ID.
956
797
 
@@ -987,44 +828,37 @@ class Sutro:
987
828
  to_colored_text("✔ Results loaded from cache", state="success")
988
829
  )
989
830
  else:
990
- endpoint = f"{self.base_url}/job-results"
991
831
  payload = {
992
832
  "job_id": job_id,
993
833
  "include_inputs": include_inputs,
994
834
  "include_cumulative_logprobs": include_cumulative_logprobs,
995
835
  }
996
- headers = {
997
- "Authorization": f"Key {self.api_key}",
998
- "Content-Type": "application/json",
999
- }
1000
836
  with yaspin(
1001
837
  SPINNER,
1002
838
  text=to_colored_text(f"Gathering results from job: {job_id}"),
1003
839
  color=YASPIN_COLOR,
1004
840
  ) as spinner:
1005
- response = requests.post(
1006
- endpoint, data=json.dumps(payload), headers=headers
1007
- )
1008
- if response.status_code != 200:
841
+ try:
842
+ response = self.do_request("POST", "job-results", json=payload)
843
+ response_data = response.json()
844
+ spinner.write(
845
+ to_colored_text("✔ Job results retrieved", state="success")
846
+ )
847
+ except requests.HTTPError as e:
1009
848
  spinner.write(
1010
849
  to_colored_text(
1011
- f"Bad status code: {response.status_code}", state="fail"
850
+ f"Bad status code: {e.response.status_code}", state="fail"
1012
851
  )
1013
852
  )
1014
853
  spinner.stop()
1015
- print(to_colored_text(response.json(), state="fail"))
854
+ print(to_colored_text(e.response.json(), state="fail"))
1016
855
  return None
1017
856
 
1018
- spinner.write(
1019
- to_colored_text("✔ Job results retrieved", state="success")
1020
- )
1021
-
1022
- response_data = response.json()
1023
857
  results_df = pl.DataFrame(response_data["results"])
1024
858
 
1025
859
  results_df = results_df.rename({"outputs": output_column})
1026
860
 
1027
- if disable_cache == False:
861
+ if not disable_cache:
1028
862
  os.makedirs(os.path.dirname(file_path), exist_ok=True)
1029
863
  results_df.write_parquet(file_path, compression="snappy")
1030
864
  spinner.write(
@@ -1051,10 +885,11 @@ class Sutro:
1051
885
  first_row = json.loads(
1052
886
  results_df.head(1)[output_column][0]
1053
887
  ) # checks if the first row can be json decoded
888
+ results_df = results_df.map_columns(
889
+ output_column, lambda s: s.str.json_decode()
890
+ )
1054
891
  results_df = results_df.with_columns(
1055
- pl.col(output_column)
1056
- .str.json_decode()
1057
- .alias("output_column_json_decoded")
892
+ pl.col(output_column).alias("output_column_json_decoded")
1058
893
  )
1059
894
  json_decoded_fields = first_row.keys()
1060
895
  for field in json_decoded_fields:
@@ -1063,11 +898,20 @@ class Sutro:
1063
898
  .struct.field(field)
1064
899
  .alias(field)
1065
900
  )
1066
- # drop the output_column and the json decoded column
901
+ if sorted(list(set(json_decoded_fields))) == [
902
+ "content",
903
+ "reasoning_content",
904
+ ]: # if it's a reasoning model, we need to unpack the content field
905
+ content_keys = results_df.head(1)["content"][0].keys()
906
+ for key in content_keys:
907
+ results_df = results_df.with_columns(
908
+ pl.col("content").struct.field(key).alias(key)
909
+ )
910
+ results_df = results_df.drop("content")
1067
911
  results_df = results_df.drop(
1068
912
  [output_column, "output_column_json_decoded"]
1069
913
  )
1070
- except Exception as e:
914
+ except Exception:
1071
915
  # if the first row cannot be json decoded, do nothing
1072
916
  pass
1073
917
 
@@ -1103,25 +947,20 @@ class Sutro:
1103
947
  Returns:
1104
948
  dict: The status of the job.
1105
949
  """
1106
- endpoint = f"{self.base_url}/job-cancel/{job_id}"
1107
- headers = {
1108
- "Authorization": f"Key {self.api_key}",
1109
- "Content-Type": "application/json",
1110
- }
1111
950
  with yaspin(
1112
951
  SPINNER,
1113
952
  text=to_colored_text(f"Cancelling job: {job_id}"),
1114
953
  color=YASPIN_COLOR,
1115
954
  ) as spinner:
1116
- response = requests.get(endpoint, headers=headers)
1117
- if response.status_code == 200:
955
+ try:
956
+ response = self.do_request("GET", f"job-cancel/{job_id}")
1118
957
  spinner.write(to_colored_text("✔ Job cancelled", state="success"))
1119
- else:
958
+ return response.json()
959
+ except requests.HTTPError as e:
1120
960
  spinner.write(to_colored_text("Failed to cancel job", state="fail"))
1121
961
  spinner.stop()
1122
- print(to_colored_text(response.json(), state="fail"))
1123
- return
1124
- return response.json()
962
+ print(to_colored_text(e.response.json(), state="fail"))
963
+ return None
1125
964
 
1126
965
  def create_dataset(self):
1127
966
  """
@@ -1132,31 +971,27 @@ class Sutro:
1132
971
  Returns:
1133
972
  str: The ID of the new dataset.
1134
973
  """
1135
- endpoint = f"{self.base_url}/create-dataset"
1136
- headers = {
1137
- "Authorization": f"Key {self.api_key}",
1138
- "Content-Type": "application/json",
1139
- }
1140
974
  with yaspin(
1141
975
  SPINNER, text=to_colored_text("Creating dataset"), color=YASPIN_COLOR
1142
976
  ) as spinner:
1143
- response = requests.get(endpoint, headers=headers)
1144
- if response.status_code != 200:
977
+ try:
978
+ response = self.do_request("GET", "create-dataset")
979
+ dataset_id = response.json()["dataset_id"]
1145
980
  spinner.write(
1146
981
  to_colored_text(
1147
- f"Bad status code: {response.status_code}", state="fail"
982
+ f" Dataset created with ID: {dataset_id}", state="success"
1148
983
  )
1149
984
  )
1150
- spinner.stop()
1151
- print(to_colored_text(response.json(), state="fail"))
1152
- return
1153
- dataset_id = response.json()["dataset_id"]
1154
- spinner.write(
1155
- to_colored_text(
1156
- f"✔ Dataset created with ID: {dataset_id}", state="success"
985
+ return dataset_id
986
+ except requests.HTTPError as e:
987
+ spinner.write(
988
+ to_colored_text(
989
+ f"Bad status code: {e.response.status_code}", state="fail"
990
+ )
1157
991
  )
1158
- )
1159
- return dataset_id
992
+ spinner.stop()
993
+ print(to_colored_text(e.response.json(), state="fail"))
994
+ return None
1160
995
 
1161
996
  def upload_to_dataset(
1162
997
  self,
@@ -1188,8 +1023,6 @@ class Sutro:
1188
1023
  if dataset_id is None:
1189
1024
  dataset_id = self.create_dataset()
1190
1025
 
1191
- endpoint = f"{self.base_url}/upload-to-dataset"
1192
-
1193
1026
  if isinstance(file_paths, str):
1194
1027
  # check if the file path is a directory
1195
1028
  if os.path.isdir(file_paths):
@@ -1222,8 +1055,6 @@ class Sutro:
1222
1055
  "dataset_id": dataset_id,
1223
1056
  }
1224
1057
 
1225
- headers = {"Authorization": f"Key {self.api_key}"}
1226
-
1227
1058
  count += 1
1228
1059
  spinner.write(
1229
1060
  to_colored_text(
@@ -1232,25 +1063,18 @@ class Sutro:
1232
1063
  )
1233
1064
 
1234
1065
  try:
1235
- response = requests.post(
1236
- endpoint, headers=headers, data=payload, files=files
1066
+ self.do_request(
1067
+ "POST",
1068
+ "/upload-to-dataset",
1069
+ data=payload,
1070
+ files=files,
1071
+ verify=verify_ssl,
1237
1072
  )
1238
- if response.status_code != 200:
1239
- # Stop spinner before showing error to avoid terminal width error
1240
- spinner.stop()
1241
- print(
1242
- to_colored_text(
1243
- f"Error: HTTP {response.status_code}", state="fail"
1244
- )
1245
- )
1246
- print(to_colored_text(response.json(), state="fail"))
1247
- return
1248
-
1249
1073
  except requests.exceptions.RequestException as e:
1250
1074
  # Stop spinner before showing error to avoid terminal width error
1251
1075
  spinner.stop()
1252
1076
  print(to_colored_text(f"Upload failed: {str(e)}", state="fail"))
1253
- return
1077
+ return None
1254
1078
 
1255
1079
  spinner.write(
1256
1080
  to_colored_text(
@@ -1260,32 +1084,23 @@ class Sutro:
1260
1084
  return dataset_id
1261
1085
 
1262
1086
  def list_datasets(self):
1263
- endpoint = f"{self.base_url}/list-datasets"
1264
- headers = {
1265
- "Authorization": f"Key {self.api_key}",
1266
- "Content-Type": "application/json",
1267
- }
1268
1087
  with yaspin(
1269
1088
  SPINNER, text=to_colored_text("Retrieving datasets"), color=YASPIN_COLOR
1270
1089
  ) as spinner:
1271
- response = requests.post(endpoint, headers=headers)
1272
- if response.status_code != 200:
1090
+ try:
1091
+ response = self.do_request("POST", "list-datasets")
1092
+ spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
1093
+ return response.json()["datasets"]
1094
+ except requests.HTTPError as e:
1273
1095
  spinner.fail(
1274
1096
  to_colored_text(
1275
- f"Bad status code: {response.status_code}", state="fail"
1097
+ f"Bad status code: {e.response.status_code}", state="fail"
1276
1098
  )
1277
1099
  )
1278
- print(to_colored_text(f"Error: {response.json()}", state="fail"))
1279
- return
1280
- spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
1281
- return response.json()["datasets"]
1100
+ print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
1101
+ return None
1282
1102
 
1283
1103
  def list_dataset_files(self, dataset_id: str):
1284
- endpoint = f"{self.base_url}/list-dataset-files"
1285
- headers = {
1286
- "Authorization": f"Key {self.api_key}",
1287
- "Content-Type": "application/json",
1288
- }
1289
1104
  payload = {
1290
1105
  "dataset_id": dataset_id,
1291
1106
  }
@@ -1294,23 +1109,22 @@ class Sutro:
1294
1109
  text=to_colored_text(f"Listing files in dataset: {dataset_id}"),
1295
1110
  color=YASPIN_COLOR,
1296
1111
  ) as spinner:
1297
- response = requests.post(
1298
- endpoint, headers=headers, data=json.dumps(payload)
1299
- )
1300
- if response.status_code != 200:
1301
- spinner.fail(
1112
+ try:
1113
+ response = self.do_request("POST", "list-dataset-files", json=payload)
1114
+ spinner.write(
1302
1115
  to_colored_text(
1303
- f"Bad status code: {response.status_code}", state="fail"
1116
+ f" Files listed in dataset: {dataset_id}", state="success"
1304
1117
  )
1305
1118
  )
1306
- print(to_colored_text(f"Error: {response.json()}", state="fail"))
1307
- return
1308
- spinner.write(
1309
- to_colored_text(
1310
- f" Files listed in dataset: {dataset_id}", state="success"
1119
+ return response.json()["files"]
1120
+ except requests.HTTPError as e:
1121
+ spinner.fail(
1122
+ to_colored_text(
1123
+ f"Bad status code: {e.response.status_code}", state="fail"
1124
+ )
1311
1125
  )
1312
- )
1313
- return response.json()["files"]
1126
+ print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
1127
+ return None
1314
1128
 
1315
1129
  def download_from_dataset(
1316
1130
  self,
@@ -1318,8 +1132,6 @@ class Sutro:
1318
1132
  files: Union[List[str], str] = None,
1319
1133
  output_path: str = None,
1320
1134
  ):
1321
- endpoint = f"{self.base_url}/download-from-dataset"
1322
-
1323
1135
  if files is None:
1324
1136
  files = self.list_dataset_files(dataset_id)
1325
1137
  elif isinstance(files, str):
@@ -1344,32 +1156,32 @@ class Sutro:
1344
1156
  ) as spinner:
1345
1157
  count = 0
1346
1158
  for file in files:
1347
- headers = {
1348
- "Authorization": f"Key {self.api_key}",
1349
- "Content-Type": "application/json",
1350
- }
1351
- payload = {
1352
- "dataset_id": dataset_id,
1353
- "file_name": file,
1354
- }
1355
1159
  spinner.text = to_colored_text(
1356
1160
  f"Downloading file {count + 1}/{len(files)} from dataset: {dataset_id}"
1357
1161
  )
1358
- response = requests.post(
1359
- endpoint, headers=headers, data=json.dumps(payload)
1360
- )
1361
- if response.status_code != 200:
1162
+
1163
+ try:
1164
+ payload = {
1165
+ "dataset_id": dataset_id,
1166
+ "file_name": file,
1167
+ }
1168
+ response = self.do_request(
1169
+ "POST", "download-from-dataset", json=payload
1170
+ )
1171
+
1172
+ file_content = response.content
1173
+ with open(os.path.join(output_path, file), "wb") as f:
1174
+ f.write(file_content)
1175
+
1176
+ count += 1
1177
+ except requests.HTTPError as e:
1362
1178
  spinner.fail(
1363
1179
  to_colored_text(
1364
- f"Bad status code: {response.status_code}", state="fail"
1180
+ f"Bad status code: {e.response.status_code}", state="fail"
1365
1181
  )
1366
1182
  )
1367
- print(to_colored_text(f"Error: {response.json()}", state="fail"))
1183
+ print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
1368
1184
  return
1369
- file_content = response.content
1370
- with open(os.path.join(output_path, file), "wb") as f:
1371
- f.write(file_content)
1372
- count += 1
1373
1185
  spinner.write(
1374
1186
  to_colored_text(
1375
1187
  f"✔ {count} files successfully downloaded from dataset: {dataset_id}",
@@ -1389,46 +1201,38 @@ class Sutro:
1389
1201
  Returns:
1390
1202
  dict: The status of the authentication.
1391
1203
  """
1392
- endpoint = f"{self.base_url}/try-authentication"
1393
- headers = {
1394
- "Authorization": f"Key {api_key}",
1395
- "Content-Type": "application/json",
1396
- }
1397
1204
  with yaspin(
1398
1205
  SPINNER, text=to_colored_text("Checking API key"), color=YASPIN_COLOR
1399
1206
  ) as spinner:
1400
- response = requests.get(endpoint, headers=headers)
1401
- if response.status_code == 200:
1207
+ try:
1208
+ response = self.do_request("GET", "try-authentication", api_key)
1209
+
1402
1210
  spinner.write(to_colored_text("✔"))
1403
- else:
1211
+ return response.json()
1212
+ except requests.HTTPError as e:
1404
1213
  spinner.write(
1405
1214
  to_colored_text(
1406
- f"API key failed to authenticate: {response.status_code}",
1215
+ f"API key failed to authenticate: {e.response.status_code}",
1407
1216
  state="fail",
1408
1217
  )
1409
1218
  )
1410
- return
1411
- return response.json()
1219
+ return None
1412
1220
 
1413
1221
  def get_quotas(self):
1414
- endpoint = f"{self.base_url}/get-quotas"
1415
- headers = {
1416
- "Authorization": f"Key {self.api_key}",
1417
- "Content-Type": "application/json",
1418
- }
1419
1222
  with yaspin(
1420
1223
  SPINNER, text=to_colored_text("Fetching quotas"), color=YASPIN_COLOR
1421
1224
  ) as spinner:
1422
- response = requests.get(endpoint, headers=headers)
1423
- if response.status_code != 200:
1225
+ try:
1226
+ response = self.do_request("GET", "get-quotas")
1227
+ return response.json()["quotas"]
1228
+ except requests.HTTPError as e:
1424
1229
  spinner.fail(
1425
1230
  to_colored_text(
1426
- f"Bad status code: {response.status_code}", state="fail"
1231
+ f"Bad status code: {e.response.status_code}", state="fail"
1427
1232
  )
1428
1233
  )
1429
- print(to_colored_text(f"Error: {response.json()}", state="fail"))
1430
- return
1431
- return response.json()["quotas"]
1234
+ print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
1235
+ return None
1432
1236
 
1433
1237
  def await_job_completion(
1434
1238
  self,
@@ -1436,7 +1240,7 @@ class Sutro:
1436
1240
  timeout: Optional[int] = 7200,
1437
1241
  obtain_results: bool = True,
1438
1242
  is_cost_estimate: bool = False,
1439
- ) -> list | None:
1243
+ ) -> pl.DataFrame | None:
1440
1244
  """
1441
1245
  Waits for job completion to occur and then returns the results upon
1442
1246
  a successful completion.
@@ -1448,11 +1252,11 @@ class Sutro:
1448
1252
  timeout (Optional[int]): The max time in seconds the function should wait for job results for. Default is 7200 (2 hours).
1449
1253
 
1450
1254
  Returns:
1451
- list: The results of the job.
1255
+ pl.DataFrame: The results of the job in a polars DataFrame.
1452
1256
  """
1453
1257
  POLL_INTERVAL = 5
1454
1258
 
1455
- results = None
1259
+ results: pl.DataFrame | None = None
1456
1260
  start_time = time.time()
1457
1261
  with yaspin(
1458
1262
  SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR