sutro 0.1.34__py3-none-any.whl → 0.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sutro/sdk.py CHANGED
@@ -1,45 +1,32 @@
- from enum import Enum
  import requests
  import pandas as pd
  import polars as pl
  import json
- from typing import Union, List, Optional, Literal, Dict, Any
+ from typing import Union, List, Optional, Dict, Any, Type
  import os
  import sys
  from yaspin import yaspin
  from yaspin.spinners import Spinners
- from colorama import init, Fore, Style
- from tqdm import tqdm
+ from colorama import init
  import time
  from pydantic import BaseModel
  import pyarrow.parquet as pq
  import shutil
-
-
- class JobStatus(str, Enum):
-     """Job statuses that will be returned by the API & SDK"""
-
-     UNKNOWN = "UNKNOWN"
-     QUEUED = "QUEUED"  # Job is waiting to start
-     STARTING = "STARTING"  # Job is in the process of starting up
-     RUNNING = "RUNNING"  # Job is actively running
-     SUCCEEDED = "SUCCEEDED"  # Job completed successfully
-     CANCELLING = "CANCELLING"  # Job is in the process of being canceled
-     CANCELLED = "CANCELLED"  # Job was canceled by the user
-     FAILED = "FAILED"  # Job failed
-
-     @classmethod
-     def terminal_statuses(cls) -> list["JobStatus"]:
-         return [
-             cls.SUCCEEDED,
-             cls.FAILED,
-             cls.CANCELLING,
-             cls.CANCELLED,
-         ]
-
-     def is_terminal(self) -> bool:
-         return self in self.terminal_statuses()
-
+ from sutro.common import (
+     ModelOptions,
+     handle_data_helper,
+     normalize_output_schema,
+     to_colored_text,
+     fancy_tqdm,
+ )
+ from sutro.interfaces import JobStatus
+ from sutro.templates.classification import ClassificationTemplates
+ from sutro.templates.embed import EmbeddingTemplates
+ from sutro.templates.evals import EvalTemplates
+ from sutro.validation import check_version, check_for_api_key
+
+ JOB_NAME_CHAR_LIMIT = 45
+ JOB_DESCRIPTION_CHAR_LIMIT = 512

  # Initialize colorama (required for Windows)
  init()
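
The JobStatus enum now ships from sutro.interfaces instead of being defined in sdk.py. A minimal caller-side sketch, assuming the relocated class keeps the helpers shown in the removed code:

    # Hypothetical sketch; assumes the relocated JobStatus keeps the
    # is_terminal()/terminal_statuses() helpers from the removed class above.
    from sutro.interfaces import JobStatus

    status = JobStatus("RUNNING")
    print(status.is_terminal())               # False
    print(JobStatus.SUCCEEDED.is_terminal())  # True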
@@ -50,56 +37,11 @@ def is_jupyter() -> bool:
      return not sys.stdout.isatty()


- # `color` param not supported in Jupyter notebooks
- YASPIN_COLOR = None if is_jupyter() else "blue"
+ # Adding color to text is not supported in Jupyter notebooks and breaks
+ # things
+ BASE_OUTPUT_COLOR = None if is_jupyter() else "blue"
  SPINNER = Spinners.dots14

- # Models available for inference. Keep in sync with the backend configuration
- # so users get helpful autocompletion when selecting a model.
- ModelOptions = Literal[
-     "llama-3.2-3b",
-     "llama-3.1-8b",
-     "llama-3.3-70b",
-     "llama-3.3-70b",
-     "qwen-3-4b",
-     "qwen-3-32b",
-     "qwen-3-4b-thinking",
-     "qwen-3-32b-thinking",
-     "gemma-3-4b-it",
-     "gemma-3-27b-it",
-     "gpt-oss-120b",
-     "gpt-oss-20b",
-     "qwen-3-235b-a22b-thinking",
-     "qwen-3-30b-a3b-thinking",
-     "qwen-3-embedding-0.6b",
-     "qwen-3-embedding-6b",
-     "qwen-3-embedding-8b",
- ]
-
-
- def to_colored_text(
-     text: str, state: Optional[Literal["success", "fail"]] = None
- ) -> str:
-     """
-     Apply color to text based on state.
-
-     Args:
-         text (str): The text to color
-         state (Optional[Literal['success', 'fail']]): The state that determines the color.
-             Options: 'success', 'fail', or None (default blue)
-
-     Returns:
-         str: Text with appropriate color applied
-     """
-     match state:
-         case "success":
-             return f"{Fore.GREEN}{text}{Style.RESET_ALL}"
-         case "fail":
-             return f"{Fore.RED}{text}{Style.RESET_ALL}"
-         case _:
-             # Default to blue for normal/processing states
-             return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
-

  # Isn't fully support in all terminals unfortunately. We should switch to Rich
  # at some point, but even Rich links aren't clickable on MacOS Terminal
@@ -108,41 +50,20 @@ def make_clickable_link(url, text=None):
      Create a clickable link for terminals that support OSC 8 hyperlinks.
      Falls back to plain text for terminals that don't support it.
      """
+     # Don't need to add the special chars for jupyter notebook
+     if is_jupyter():
+         return url
+
      if text is None:
          text = url
      return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"


- class Sutro:
+ class Sutro(EmbeddingTemplates, ClassificationTemplates, EvalTemplates):
      def __init__(self, api_key: str = None, base_url: str = "https://api.sutro.sh/"):
-         self.api_key = api_key or self.check_for_api_key()
+         self.api_key = api_key or check_for_api_key()
          self.base_url = base_url
-
-     def check_for_api_key(self):
-         """
-         Check for an API key in the user's home directory.
-
-         This method looks for a configuration file named 'config.json' in the
-         '.sutro' directory within the user's home directory.
-         If the file exists, it attempts to read the API key from it.
-
-         Returns:
-             str or None: The API key if found in the configuration file, or None if not found.
-
-         Note:
-             The expected structure of the config.json file is:
-             {
-                 "api_key": "your_api_key_here"
-             }
-         """
-         CONFIG_DIR = os.path.expanduser("~/.sutro")
-         CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
-         if os.path.exists(CONFIG_FILE):
-             with open(CONFIG_FILE, "r") as f:
-                 config = json.load(f)
-             return config.get("api_key")
-         else:
-             return None
+         check_version("sutro")

      def set_api_key(self, api_key: str):
          """
@@ -159,43 +80,6 @@ class Sutro:
          """
          self.api_key = api_key

-     def handle_data_helper(
-         self, data: Union[List, pd.DataFrame, pl.DataFrame, str], column: str = None
-     ):
-         if isinstance(data, list):
-             input_data = data
-         elif isinstance(data, (pd.DataFrame, pl.DataFrame)):
-             if column is None:
-                 raise ValueError("Column name must be specified for DataFrame input")
-             input_data = data[column].to_list()
-         elif isinstance(data, str):
-             if data.startswith("dataset-"):
-                 input_data = data + ":" + column
-             else:
-                 file_ext = os.path.splitext(data)[1].lower()
-                 if file_ext == ".csv":
-                     df = pl.read_csv(data)
-                 elif file_ext == ".parquet":
-                     df = pl.read_parquet(data)
-                 elif file_ext in [".txt", ""]:
-                     with open(data, "r") as file:
-                         input_data = [line.strip() for line in file]
-                 else:
-                     raise ValueError(f"Unsupported file type: {file_ext}")
-
-                 if file_ext in [".csv", ".parquet"]:
-                     if column is None:
-                         raise ValueError(
-                             "Column name must be specified for CSV/Parquet input"
-                         )
-                     input_data = df[column].to_list()
-         else:
-             raise ValueError(
-                 "Unsupported data type. Please provide a list, DataFrame, or file path."
-             )
-
-         return input_data
-
      def set_base_url(self, base_url: str):
          """
          Set the base URL for the Sutro API.
@@ -208,11 +92,48 @@ class Sutro:
          """
          self.base_url = base_url

+     def do_request(
+         self,
+         method: str,
+         endpoint: str,
+         api_key_override: Optional[str] = None,
+         **kwargs: Any,
+     ):
+         """
+         Helper to make authenticated requests.
+         """
+         key = self.api_key if not api_key_override else api_key_override
+         headers = {"Authorization": f"Key {key}"}
+
+         # Merge with any headers passed in kwargs
+         if "headers" in kwargs:
+             headers.update(kwargs.pop("headers"))
+
+         url = f"{self.base_url}/{endpoint.lstrip('/')}"
+
+         # Explicit method dispatch
+         method = method.upper()
+         if method == "GET":
+             response = requests.get(url, headers=headers, **kwargs)
+         elif method == "POST":
+             response = requests.post(url, headers=headers, **kwargs)
+         elif method == "PUT":
+             response = requests.put(url, headers=headers, **kwargs)
+         elif method == "DELETE":
+             response = requests.delete(url, headers=headers, **kwargs)
+         elif method == "PATCH":
+             response = requests.patch(url, headers=headers, **kwargs)
+         else:
+             raise ValueError(f"Unsupported HTTP method: {method}")
+
+         response.raise_for_status()
+         return response
+
      def _run_one_batch_inference(
          self,
          data: Union[List, pd.DataFrame, pl.DataFrame, str],
          model: ModelOptions,
-         column: str,
+         column: Union[str, List[str]],
          output_column: str,
          job_priority: int,
          json_schema: Dict[str, Any],
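
All HTTP calls in the class now funnel through the new do_request helper, which injects the Authorization header, normalizes the endpoint path, and raises requests.HTTPError on non-2xx responses. A usage sketch (client setup and key are illustrative):

    # Illustrative only; do_request raises requests.HTTPError on failure.
    so = Sutro(api_key="YOUR_KEY")
    response = so.do_request("GET", "list-jobs")  # leading slash optional
    jobs = response.json()["jobs"]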
@@ -222,13 +143,20 @@
          stay_attached: Optional[bool],
          random_seed_per_input: bool,
          truncate_rows: bool,
+         name: str,
+         description: str,
      ):
-         input_data = self.handle_data_helper(data, column)
-         endpoint = f"{self.base_url}/batch-inference"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
+         # Validate name and description lengths
+         if name is not None and len(name) > JOB_NAME_CHAR_LIMIT:
+             raise ValueError(
+                 f"Job name cannot exceed {JOB_NAME_CHAR_LIMIT} characters."
+             )
+         if description is not None and len(description) > JOB_DESCRIPTION_CHAR_LIMIT:
+             raise ValueError(
+                 f"Job description cannot exceed {JOB_DESCRIPTION_CHAR_LIMIT} characters."
+             )
+
+         input_data = handle_data_helper(data, column)
          payload = {
              "model": model,
              "inputs": input_data,
@@ -239,6 +167,8 @@
              "sampling_params": sampling_params,
              "random_seed_per_input": random_seed_per_input,
              "truncate_rows": truncate_rows,
+             "name": name,
+             "description": description,
          }

          # There are two gotchas with yaspin:
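
Jobs can now carry a name and description, validated client-side against the new character limits before any request is sent. A hypothetical illustration, continuing the sketch above:

    # Hypothetical: a 46-character name trips the 45-character limit locally,
    # before any API call is made.
    try:
        so.infer(["Hello, world"], name="x" * 46)
    except ValueError as err:
        print(err)  # Job name cannot exceed 45 characters.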
@@ -250,18 +180,21 @@
          job_id = None
          t = f"Creating {'[cost estimate] ' if cost_estimate else ''}priority {job_priority} job"
          spinner_text = to_colored_text(t)
+
          try:
-             with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
-                 response = requests.post(
-                     endpoint, data=json.dumps(payload), headers=headers
-                 )
-                 response_data = response.json()
+             with yaspin(SPINNER, text=spinner_text, color=BASE_OUTPUT_COLOR) as spinner:
+                 try:
+                     response = self.do_request("POST", "batch-inference", json=payload)
+                     response_data = response.json()
+                 except requests.HTTPError as e:
+                     response = e.response
+                     response_data = response.json()
                  if response.status_code != 200:
                      spinner.write(
                          to_colored_text(f"Error: {response.status_code}", state="fail")
                      )
                      spinner.stop()
-                     print(to_colored_text(response.json(), state="fail"))
+                     print(to_colored_text(response_data, state="fail"))
                      return None
                  else:
                      job_id = response_data["results"]
@@ -284,12 +217,14 @@
                          )
                          return job_id
                      else:
+                         name_text = f" and name {name}" if name is not None else ""
                          spinner.write(
                              to_colored_text(
-                                 f"🛠 Priority {job_priority} Job created with ID: {job_id}.",
+                                 f"🛠 Priority {job_priority} Job created with ID: {job_id}{name_text}",
                                  state="success",
                              )
                          )
+                         spinner.write(to_colored_text(f"Model: {model}"))
                          if not stay_attached:
                              clickable_link = make_clickable_link(
                                  f"https://app.sutro.sh/jobs/{job_id}"
@@ -326,20 +261,20 @@
                          )
                      )
                      return None
-             s = requests.Session()
+
              pbar = None

              try:
-                 with requests.get(
-                     f"{self.base_url}/stream-job-progress/{job_id}",
-                     headers=headers,
+                 with self.do_request(
+                     "GET",
+                     f"/stream-job-progress/{job_id}",
                      stream=True,
                  ) as streaming_response:
                      streaming_response.raise_for_status()
                      spinner = yaspin(
                          SPINNER,
                          text=to_colored_text("Awaiting status updates..."),
-                         color=YASPIN_COLOR,
+                         color=BASE_OUTPUT_COLOR,
                      )
                      spinner.start()

@@ -361,7 +296,7 @@
                          if pbar is None:
                              spinner.stop()
                              postfix = "Input tokens processed: 0"
-                             pbar = self.fancy_tqdm(
+                             pbar = fancy_tqdm(
                                  total=len(input_data),
                                  desc="Progress",
                                  style=1,
@@ -402,28 +337,27 @@
                              )
                              spinner.start()

-                 payload = {
-                     "job_id": job_id,
-                 }
-
                  # TODO: we implment retries in cases where the job hasn't written results yet
                  # it would be better if we could receive a fully succeeded status from the job
                  # and not have such a race condition
                  max_retries = 20  # winds up being 100 seconds cumulative delay
                  retry_delay = 5  # initial delay in seconds
-
+                 job_results_response = None
                  for _ in range(max_retries):
-                     time.sleep(retry_delay)
-
-                     job_results_response = s.post(
-                         f"{self.base_url}/job-results",
-                         headers=headers,
-                         data=json.dumps(payload),
-                     )
-                     if job_results_response.status_code == 200:
+                     try:
+                         job_results_response = self.do_request(
+                             "POST",
+                             "job-results",
+                             json={
+                                 "job_id": job_id,
+                             },
+                         )
                          break
+                     except requests.HTTPError:
+                         time.sleep(retry_delay)
+                         continue

-                 if job_results_response.status_code != 200:
+                 if not job_results_response or job_results_response.status_code != 200:
                      spinner.write(
                          to_colored_text(
                              "Job succeeded, but results are not yet available. Use `so.get_job_results('{job_id}')` to obtain results.",
@@ -435,94 +369,183 @@

                  results = job_results_response.json()["results"]["outputs"]

-                 spinner.write(
-                     to_colored_text(
-                         f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
-                         state="success",
-                     )
-                 )
-                 spinner.stop()
-
                  if isinstance(data, (pd.DataFrame, pl.DataFrame)):
                      if isinstance(data, pd.DataFrame):
                          data[output_column] = results
                      elif isinstance(data, pl.DataFrame):
                          data = data.with_columns(pl.Series(output_column, results))
-                     return data
+                     print(data)
+                     spinner.write(
+                         to_colored_text(
+                             f"✔ Displaying result preview. You can join the results on the original dataframe with `so.get_job_results('{job_id}', with_original_df=<original_df>)`",
+                             state="success",
+                         )
+                     )
+                 else:
+                     print(results)
+                     spinner.write(
+                         to_colored_text(
+                             f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
+                             state="success",
+                         )
+                     )
+                 spinner.stop()

-                 return results
+                 return job_id
              return None
          return None

      def infer(
          self,
          data: Union[List, pd.DataFrame, pl.DataFrame, str],
-         model: Union[ModelOptions, List[ModelOptions]] = "gemma-3-12b-it",
-         column: str = None,
+         model: ModelOptions = "gemma-3-12b-it",
+         name: Optional[str] = None,
+         description: Optional[str] = None,
+         column: Union[str, List[str]] = None,
          output_column: str = "inference_result",
          job_priority: int = 0,
-         output_schema: Union[Dict[str, Any], BaseModel] = None,
+         output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
          sampling_params: dict = None,
          system_prompt: str = None,
          dry_run: bool = False,
          stay_attached: Optional[bool] = None,
          random_seed_per_input: bool = False,
-         truncate_rows: bool = False,
+         truncate_rows: bool = True,
      ):
          """
          Run inference on the provided data.

          This method allows you to run inference on the provided data using the Sutro API.
-         It supports various data types such as lists, pandas DataFrames, polars DataFrames, file paths and datasets.
+         It supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.

          Args:
              data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
-             model (Union[ModelOptions, List[ModelOptions]], optional): The model(s) to use for inference. Defaults to "llama-3.1-8b". You can pass a single model or a list of models. In the case of a list, the inference will be run in parallel for each model and stay_attached will be set to False.
-             column (str, optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset.
+             model (ModelOptions, optional): The model to use for inference. Defaults to "gemma-3-12b-it".
+             name (str, optional): A job name for experiment/metadata tracking purposes. Defaults to None.
+             description (str, optional): A job description for experiment/metadata tracking purposes. Defaults to None.
+             column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, the listed columns are concatenated into a single column; separator strings are accepted.
              output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
              job_priority (int, optional): The priority of the job. Defaults to 0.
              output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
-                 Can be either a dictionary representing a JSON schema or a pydantic BaseModel. Defaults to None.
+                 Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
              sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
              system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
              dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
              stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
              random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
-             truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to False.
+             truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to True.

          Returns:
-             Union[List, pd.DataFrame, pl.DataFrame, str]: The results of the inference.
+             str: The ID of the inference job.

          """
-         if isinstance(model, list) == False:
-             model_list = [model]
-             stay_attached = (
-                 stay_attached if stay_attached is not None else job_priority == 0
-             )
+         # Default stay_attached to True for prototyping jobs (priority 0)
+         if stay_attached is None:
+             stay_attached = job_priority == 0
+
+         json_schema = None
+         if output_schema:
+             # Convert BaseModel to dict if needed
+             json_schema = normalize_output_schema(output_schema)
+
+         return self._run_one_batch_inference(
+             data,
+             model,
+             column,
+             output_column,
+             job_priority,
+             json_schema,
+             sampling_params,
+             system_prompt,
+             dry_run,
+             stay_attached,
+             random_seed_per_input,
+             truncate_rows,
+             name,
+             description,
+         )
+
+     def infer_per_model(
+         self,
+         data: Union[List, pd.DataFrame, pl.DataFrame, str],
+         models: List[ModelOptions],
+         names: List[str] = None,
+         descriptions: List[str] = None,
+         column: Union[str, List[str]] = None,
+         output_column: str = "inference_result",
+         job_priority: int = 0,
+         output_schema: Union[Dict[str, Any], Type[BaseModel]] = None,
+         sampling_params: dict = None,
+         system_prompt: str = None,
+         dry_run: bool = False,
+         random_seed_per_input: bool = False,
+         truncate_rows: bool = True,
+     ):
+         """
+         Run inference on the provided data, across multiple models. This method is often useful for sampling outputs from multiple models on the same dataset and comparing the resulting jobs.
+
+         For input data, it supports various data types such as lists, DataFrames (Polars or Pandas), file paths and datasets.
+
+         Args:
+             data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
+             models (List[ModelOptions]): The models to use for inference. Fans out each model to its own separate job, over the same dataset.
+             names (List[str], optional): Job names for experiment/metadata tracking purposes. Must be a list with length equal to the number of models, or None. Defaults to None.
+             descriptions (List[str], optional): Job descriptions for experiment/metadata tracking purposes. Must be a list with length equal to the number of models, or None. Defaults to None.
+             column (Union[str, List[str]], optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset. If a list is supplied, the listed columns are concatenated into a single column; separator strings are accepted.
+             output_column (str, optional): The column name to store the inference job_ids in if the input is a DataFrame. Defaults to "inference_result".
+             job_priority (int, optional): The priority of the job. Defaults to 0.
+             output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
+                 Can be either a dictionary representing a JSON schema or a class that inherits from Pydantic BaseModel. Defaults to None.
+             sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
+             system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
+             dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
+             random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
+             truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to True.
+
+         Returns:
+             list: The IDs of the inference jobs, one per model.
+
+         """
+         if isinstance(names, list):
+             if len(names) != len(models):
+                 raise ValueError(
+                     "names parameter must be the same length as the models parameter."
+                 )
+         elif names is None:
+             names = [None] * len(models)
          else:
-             model_list = model
-             stay_attached = False
-
-         # Convert BaseModel to dict if needed
-         if output_schema is not None:
-             if hasattr(
-                 output_schema, "model_json_schema"
-             ):  # Check for pydantic Model interface
-                 json_schema = output_schema.model_json_schema()
-             elif isinstance(output_schema, dict):
-                 json_schema = output_schema
-             else:
+             raise ValueError(
+                 "names parameter must be a list or None if using a list of models"
+             )
+
+         if isinstance(descriptions, list):
+             if len(descriptions) != len(models):
                  raise ValueError(
-                     "Invalid output schema type. Must be a dictionary or a pydantic Model."
+                     "descriptions parameter must be the same length as the models"
+                     " parameter."
                  )
+         elif descriptions is None:
+             descriptions = [None] * len(models)
          else:
-             json_schema = None
+             raise ValueError(
+                 "descriptions parameter must be a list or None if using a list of "
+                 "models"
+             )

-         results = []
-         for model in model_list:
-             res = self._run_one_batch_inference(
+         json_schema = None
+         if output_schema:
+             # Convert BaseModel to dict if needed
+             json_schema = normalize_output_schema(output_schema)
+
+         def start_job(
+             model_singleton: ModelOptions,
+             name_singleton: str | None,
+             description_singleton: str | None,
+         ):
+             return self._run_one_batch_inference(
                  data,
-                 model,
+                 model_singleton,
                  column,
                  output_column,
                  job_priority,
@@ -530,18 +553,21 @@
                  sampling_params,
                  system_prompt,
                  dry_run,
-                 stay_attached,
+                 False,
                  random_seed_per_input,
                  truncate_rows,
+                 name_singleton,
+                 description_singleton,
              )
-             results.append(res)

-         if len(results) > 1:
-             return results
-         elif len(results) == 1:
-             return results[0]
+         job_ids = [
+             start_job(model, name, description)
+             for model, name, description in zip(
+                 models, names, descriptions, strict=True
+             )
+         ]

-         return None
+         return job_ids

      def attach(self, job_id):
          """
@@ -552,20 +578,12 @@
          """

          s = requests.Session()
-         payload = {
-             "job_id": job_id,
-         }
          pbar = None

-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
-
          with yaspin(
              SPINNER,
              text=to_colored_text("Looking for job..."),
-             color=YASPIN_COLOR,
+             color=BASE_OUTPUT_COLOR,
          ) as spinner:
              # Fetch the specific job we want to attach to
              job = self._fetch_job(job_id)
@@ -599,16 +617,16 @@
              success = False

              try:
-                 with s.get(
-                     f"{self.base_url}/stream-job-progress/{job_id}",
-                     headers=headers,
+                 with self.do_request(
+                     "GET",
+                     f"/stream-job-progress/{job_id}",
                      stream=True,
                  ) as streaming_response:
                      streaming_response.raise_for_status()
                      spinner = yaspin(
                          SPINNER,
                          text=to_colored_text("Awaiting status updates..."),
-                         color=YASPIN_COLOR,
+                         color=BASE_OUTPUT_COLOR,
                      )
                      clickable_link = make_clickable_link(
                          f"https://app.sutro.sh/jobs/{job_id}"
@@ -631,7 +649,7 @@
                          if pbar is None:
                              spinner.stop()
                              postfix = "Input tokens processed: 0"
-                             pbar = self.fancy_tqdm(
+                             pbar = fancy_tqdm(
                                  total=total_rows,
                                  desc="Progress",
                                  style=1,
@@ -668,7 +686,7 @@
          self,
          total: int,
          desc: str = "Progress",
-         color: str = "blue",
+         color: str = BASE_OUTPUT_COLOR,
          style=1,
          postfix: str = None,
      ):
@@ -730,56 +748,36 @@
          This method retrieves a list of all jobs associated with the API key.

          Returns:
-             list: A list of job details.
+             list: A list of job details, or None if the request fails.
          """
-         endpoint = f"{self.base_url}/list-jobs"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
-
          with yaspin(
-             SPINNER, text=to_colored_text("Fetching jobs"), color=YASPIN_COLOR
+             SPINNER, text=to_colored_text("Fetching jobs"), color=BASE_OUTPUT_COLOR
          ) as spinner:
-             response = requests.get(endpoint, headers=headers)
-             if response.status_code != 200:
+             try:
+                 return self._list_all_jobs_for_user()
+             except requests.HTTPError as e:
                  spinner.write(
                      to_colored_text(
-                         f"Bad status code: {response.status_code}", state="fail"
+                         f"Bad status code: {e.response.status_code}", state="fail"
                      )
                  )
                  spinner.stop()
-                 print(to_colored_text(response.json(), state="fail"))
-                 return
-         return response.json()["jobs"]
+                 print(to_colored_text(e.response.json(), state="fail"))
+                 return None

-     def _list_jobs_helper(self):
-         """
-         Helper function to list jobs.
-         """
-         endpoint = f"{self.base_url}/list-jobs˚"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
-         response = requests.get(endpoint, headers=headers)
-         if response.status_code != 200:
-             return None
+     def _list_all_jobs_for_user(self):
+         response = self.do_request("GET", "list-jobs")
          return response.json()["jobs"]

      def _fetch_job(self, job_id):
          """
          Helper function to fetch a single job.
          """
-         endpoint = f"{self.base_url}/jobs/{job_id}"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
-         response = requests.get(endpoint, headers=headers)
-         if response.status_code != 200:
+         try:
+             response = self.do_request("GET", f"jobs/{job_id}")
+             return response.json().get("job")
+         except requests.HTTPError:
              return None
-         return response.json().get("job")

      def _get_job_cost_estimate(self, job_id: str):
          """
@@ -813,15 +811,7 @@
          Raises:
              requests.HTTPError: If the API returns a non-200 status code.
          """
-         endpoint = f"{self.base_url}/job-status/{job_id}"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
-
-         response = requests.get(endpoint, headers=headers)
-         response.raise_for_status()
-
+         response = self.do_request("GET", f"job-status/{job_id}")
          return response.json()["job_status"][job_id]

      def get_job_status(self, job_id: str):
@@ -839,7 +829,7 @@
          with yaspin(
              SPINNER,
              text=to_colored_text(f"Checking job status with ID: {job_id}"),
-             color=YASPIN_COLOR,
+             color=BASE_OUTPUT_COLOR,
          ) as spinner:
              try:
                  response_data = self._fetch_job_status(job_id)
@@ -866,7 +856,7 @@
          output_column: str = "inference_result",
          disable_cache: bool = False,
          unpack_json: bool = True,
-     ):
+     ) -> pl.DataFrame | pd.DataFrame:
          """
          Get the results of a job by its ID.

@@ -896,51 +886,44 @@
              with yaspin(
                  SPINNER,
                  text=to_colored_text(f"Loading results from cache: {file_path}"),
-                 color=YASPIN_COLOR,
+                 color=BASE_OUTPUT_COLOR,
              ) as spinner:
                  results_df = pl.read_parquet(file_path)
                  spinner.write(
                      to_colored_text("✔ Results loaded from cache", state="success")
                  )
          else:
-             endpoint = f"{self.base_url}/job-results"
              payload = {
                  "job_id": job_id,
                  "include_inputs": include_inputs,
                  "include_cumulative_logprobs": include_cumulative_logprobs,
              }
-             headers = {
-                 "Authorization": f"Key {self.api_key}",
-                 "Content-Type": "application/json",
-             }
              with yaspin(
                  SPINNER,
                  text=to_colored_text(f"Gathering results from job: {job_id}"),
-                 color=YASPIN_COLOR,
+                 color=BASE_OUTPUT_COLOR,
              ) as spinner:
-                 response = requests.post(
-                     endpoint, data=json.dumps(payload), headers=headers
-                 )
-                 if response.status_code != 200:
+                 try:
+                     response = self.do_request("POST", "job-results", json=payload)
+                     response_data = response.json()
+                     spinner.write(
+                         to_colored_text("✔ Job results retrieved", state="success")
+                     )
+                 except requests.HTTPError as e:
                      spinner.write(
                          to_colored_text(
-                             f"Bad status code: {response.status_code}", state="fail"
+                             f"Bad status code: {e.response.status_code}", state="fail"
                          )
                      )
                      spinner.stop()
-                     print(to_colored_text(response.json(), state="fail"))
+                     print(to_colored_text(e.response.json(), state="fail"))
                      return None

-                 spinner.write(
-                     to_colored_text("✔ Job results retrieved", state="success")
-                 )
-
-                 response_data = response.json()
                  results_df = pl.DataFrame(response_data["results"])

                  results_df = results_df.rename({"outputs": output_column})

-                 if disable_cache == False:
+                 if not disable_cache:
                      os.makedirs(os.path.dirname(file_path), exist_ok=True)
                      results_df.write_parquet(file_path, compression="snappy")
                      spinner.write(
@@ -967,10 +950,11 @@
                  first_row = json.loads(
                      results_df.head(1)[output_column][0]
                  )  # checks if the first row can be json decoded
+                 results_df = results_df.map_columns(
+                     output_column, lambda s: s.str.json_decode()
+                 )
                  results_df = results_df.with_columns(
-                     pl.col(output_column)
-                     .str.json_decode()
-                     .alias("output_column_json_decoded")
+                     pl.col(output_column).alias("output_column_json_decoded")
                  )
                  json_decoded_fields = first_row.keys()
                  for field in json_decoded_fields:
@@ -979,11 +963,20 @@
                          .struct.field(field)
                          .alias(field)
                      )
-                 # drop the output_column and the json decoded column
+                 if sorted(list(set(json_decoded_fields))) == [
+                     "content",
+                     "reasoning_content",
+                 ]:  # if it's a reasoning model, we need to unpack the content field
+                     content_keys = results_df.head(1)["content"][0].keys()
+                     for key in content_keys:
+                         results_df = results_df.with_columns(
+                             pl.col("content").struct.field(key).alias(key)
+                         )
+                     results_df = results_df.drop("content")
                  results_df = results_df.drop(
                      [output_column, "output_column_json_decoded"]
                  )
-             except Exception as e:
+             except Exception:
                  # if the first row cannot be json decoded, do nothing
                  pass

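The new branch unpacks reasoning-model outputs whose decoded JSON carries exactly the keys content and reasoning_content. A toy polars illustration of the same shape, with all data invented:

    # Toy illustration with invented data; mirrors the struct unpacking above.
    import polars as pl

    raw = pl.DataFrame({
        "inference_result": [
            '{"reasoning_content": "chain of thought...", "content": {"label": "spam"}}'
        ]
    })
    decoded = raw.with_columns(
        pl.col("inference_result").str.json_decode().alias("decoded")
    )
    flat = decoded.with_columns(
        pl.col("decoded").struct.field("reasoning_content").alias("reasoning_content"),
        pl.col("decoded").struct.field("content").struct.field("label").alias("label"),
    )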
@@ -1019,25 +1012,20 @@
          Returns:
              dict: The status of the job.
          """
-         endpoint = f"{self.base_url}/job-cancel/{job_id}"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
          with yaspin(
              SPINNER,
              text=to_colored_text(f"Cancelling job: {job_id}"),
-             color=YASPIN_COLOR,
+             color=BASE_OUTPUT_COLOR,
          ) as spinner:
-             response = requests.get(endpoint, headers=headers)
-             if response.status_code == 200:
+             try:
+                 response = self.do_request("GET", f"job-cancel/{job_id}")
                  spinner.write(to_colored_text("✔ Job cancelled", state="success"))
-             else:
+                 return response.json()
+             except requests.HTTPError as e:
                  spinner.write(to_colored_text("Failed to cancel job", state="fail"))
                  spinner.stop()
-                 print(to_colored_text(response.json(), state="fail"))
-                 return
-             return response.json()
+                 print(to_colored_text(e.response.json(), state="fail"))
+                 return None

      def create_dataset(self):
          """
@@ -1048,31 +1036,27 @@
          Returns:
              str: The ID of the new dataset.
          """
-         endpoint = f"{self.base_url}/create-dataset"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
          with yaspin(
-             SPINNER, text=to_colored_text("Creating dataset"), color=YASPIN_COLOR
+             SPINNER, text=to_colored_text("Creating dataset"), color=BASE_OUTPUT_COLOR
          ) as spinner:
-             response = requests.get(endpoint, headers=headers)
-             if response.status_code != 200:
+             try:
+                 response = self.do_request("GET", "create-dataset")
+                 dataset_id = response.json()["dataset_id"]
                  spinner.write(
                      to_colored_text(
-                         f"Bad status code: {response.status_code}", state="fail"
+                         f" Dataset created with ID: {dataset_id}", state="success"
                      )
                  )
-                 spinner.stop()
-                 print(to_colored_text(response.json(), state="fail"))
-                 return
-             dataset_id = response.json()["dataset_id"]
-             spinner.write(
-                 to_colored_text(
-                     f"✔ Dataset created with ID: {dataset_id}", state="success"
+                 return dataset_id
+             except requests.HTTPError as e:
+                 spinner.write(
+                     to_colored_text(
+                         f"Bad status code: {e.response.status_code}", state="fail"
+                     )
                  )
-             )
-         return dataset_id
+                 spinner.stop()
+                 print(to_colored_text(e.response.json(), state="fail"))
+                 return None

      def upload_to_dataset(
          self,
@@ -1104,8 +1088,6 @@
          if dataset_id is None:
              dataset_id = self.create_dataset()

-         endpoint = f"{self.base_url}/upload-to-dataset"
-
          if isinstance(file_paths, str):
              # check if the file path is a directory
              if os.path.isdir(file_paths):
@@ -1120,7 +1102,7 @@
          with yaspin(
              SPINNER,
              text=to_colored_text(f"Uploading files to dataset: {dataset_id}"),
-             color=YASPIN_COLOR,
+             color=BASE_OUTPUT_COLOR,
          ) as spinner:
              count = 0
              for file_path in file_paths:
@@ -1138,8 +1120,6 @@
                      "dataset_id": dataset_id,
                  }

-                 headers = {"Authorization": f"Key {self.api_key}"}
-
                  count += 1
                  spinner.write(
                      to_colored_text(
@@ -1148,25 +1128,18 @@
                  )

                  try:
-                     response = requests.post(
-                         endpoint, headers=headers, data=payload, files=files
+                     self.do_request(
+                         "POST",
+                         "/upload-to-dataset",
+                         data=payload,
+                         files=files,
+                         verify=verify_ssl,
                      )
-                     if response.status_code != 200:
-                         # Stop spinner before showing error to avoid terminal width error
-                         spinner.stop()
-                         print(
-                             to_colored_text(
-                                 f"Error: HTTP {response.status_code}", state="fail"
-                             )
-                         )
-                         print(to_colored_text(response.json(), state="fail"))
-                         return
-
                  except requests.exceptions.RequestException as e:
                      # Stop spinner before showing error to avoid terminal width error
                      spinner.stop()
                      print(to_colored_text(f"Upload failed: {str(e)}", state="fail"))
-                     return
+                     return None

                  spinner.write(
                      to_colored_text(
@@ -1176,57 +1149,47 @@
              return dataset_id

      def list_datasets(self):
-         endpoint = f"{self.base_url}/list-datasets"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
          with yaspin(
-             SPINNER, text=to_colored_text("Retrieving datasets"), color=YASPIN_COLOR
+             SPINNER, text=to_colored_text("Retrieving datasets"), color=BASE_OUTPUT_COLOR
          ) as spinner:
-             response = requests.post(endpoint, headers=headers)
-             if response.status_code != 200:
+             try:
+                 response = self.do_request("POST", "list-datasets")
+                 spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
+                 return response.json()["datasets"]
+             except requests.HTTPError as e:
                  spinner.fail(
                      to_colored_text(
-                         f"Bad status code: {response.status_code}", state="fail"
+                         f"Bad status code: {e.response.status_code}", state="fail"
                      )
                  )
-                 print(to_colored_text(f"Error: {response.json()}", state="fail"))
-                 return
-             spinner.write(to_colored_text("✔ Datasets retrieved", state="success"))
-             return response.json()["datasets"]
+                 print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                 return None

      def list_dataset_files(self, dataset_id: str):
-         endpoint = f"{self.base_url}/list-dataset-files"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
          payload = {
              "dataset_id": dataset_id,
          }
          with yaspin(
              SPINNER,
              text=to_colored_text(f"Listing files in dataset: {dataset_id}"),
-             color=YASPIN_COLOR,
+             color=BASE_OUTPUT_COLOR,
          ) as spinner:
-             response = requests.post(
-                 endpoint, headers=headers, data=json.dumps(payload)
-             )
-             if response.status_code != 200:
-                 spinner.fail(
+             try:
+                 response = self.do_request("POST", "list-dataset-files", json=payload)
+                 spinner.write(
                      to_colored_text(
-                         f"Bad status code: {response.status_code}", state="fail"
+                         f" Files listed in dataset: {dataset_id}", state="success"
                      )
                  )
-                 print(to_colored_text(f"Error: {response.json()}", state="fail"))
-                 return
-             spinner.write(
-                 to_colored_text(
-                     f" Files listed in dataset: {dataset_id}", state="success"
+                 return response.json()["files"]
+             except requests.HTTPError as e:
+                 spinner.fail(
+                     to_colored_text(
+                         f"Bad status code: {e.response.status_code}", state="fail"
+                     )
                  )
-             )
-             return response.json()["files"]
+                 print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                 return None

      def download_from_dataset(
          self,
@@ -1234,8 +1197,6 @@
          files: Union[List[str], str] = None,
          output_path: str = None,
      ):
-         endpoint = f"{self.base_url}/download-from-dataset"
-
          if files is None:
              files = self.list_dataset_files(dataset_id)
          elif isinstance(files, str):
@@ -1256,36 +1217,36 @@
          with yaspin(
              SPINNER,
              text=to_colored_text(f"Downloading files from dataset: {dataset_id}"),
-             color=YASPIN_COLOR,
+             color=BASE_OUTPUT_COLOR,
          ) as spinner:
              count = 0
              for file in files:
-                 headers = {
-                     "Authorization": f"Key {self.api_key}",
-                     "Content-Type": "application/json",
-                 }
-                 payload = {
-                     "dataset_id": dataset_id,
-                     "file_name": file,
-                 }
                  spinner.text = to_colored_text(
                      f"Downloading file {count + 1}/{len(files)} from dataset: {dataset_id}"
                  )
-                 response = requests.post(
-                     endpoint, headers=headers, data=json.dumps(payload)
-                 )
-                 if response.status_code != 200:
+
+                 try:
+                     payload = {
+                         "dataset_id": dataset_id,
+                         "file_name": file,
+                     }
+                     response = self.do_request(
+                         "POST", "download-from-dataset", json=payload
+                     )
+
+                     file_content = response.content
+                     with open(os.path.join(output_path, file), "wb") as f:
+                         f.write(file_content)
+
+                     count += 1
+                 except requests.HTTPError as e:
                      spinner.fail(
                          to_colored_text(
-                             f"Bad status code: {response.status_code}", state="fail"
+                             f"Bad status code: {e.response.status_code}", state="fail"
                          )
                      )
-                     print(to_colored_text(f"Error: {response.json()}", state="fail"))
+                     print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
                      return
-                 file_content = response.content
-                 with open(os.path.join(output_path, file), "wb") as f:
-                     f.write(file_content)
-                 count += 1
              spinner.write(
                  to_colored_text(
                      f"✔ {count} files successfully downloaded from dataset: {dataset_id}",
@@ -1305,54 +1266,47 @@
          Returns:
              dict: The status of the authentication.
          """
-         endpoint = f"{self.base_url}/try-authentication"
-         headers = {
-             "Authorization": f"Key {api_key}",
-             "Content-Type": "application/json",
-         }
          with yaspin(
-             SPINNER, text=to_colored_text("Checking API key"), color=YASPIN_COLOR
+             SPINNER, text=to_colored_text("Checking API key"), color=BASE_OUTPUT_COLOR
          ) as spinner:
-             response = requests.get(endpoint, headers=headers)
-             if response.status_code == 200:
+             try:
+                 response = self.do_request("GET", "try-authentication", api_key)
+
                  spinner.write(to_colored_text("✔"))
-             else:
+                 return response.json()
+             except requests.HTTPError as e:
                  spinner.write(
                      to_colored_text(
-                         f"API key failed to authenticate: {response.status_code}",
+                         f"API key failed to authenticate: {e.response.status_code}",
                          state="fail",
                      )
                  )
-                 return
-         return response.json()
+                 return None

      def get_quotas(self):
-         endpoint = f"{self.base_url}/get-quotas"
-         headers = {
-             "Authorization": f"Key {self.api_key}",
-             "Content-Type": "application/json",
-         }
          with yaspin(
-             SPINNER, text=to_colored_text("Fetching quotas"), color=YASPIN_COLOR
+             SPINNER, text=to_colored_text("Fetching quotas"), color=BASE_OUTPUT_COLOR
          ) as spinner:
-             response = requests.get(endpoint, headers=headers)
-             if response.status_code != 200:
+             try:
+                 response = self.do_request("GET", "get-quotas")
+                 return response.json()["quotas"]
+             except requests.HTTPError as e:
                  spinner.fail(
                      to_colored_text(
-                         f"Bad status code: {response.status_code}", state="fail"
+                         f"Bad status code: {e.response.status_code}", state="fail"
                      )
                  )
-                 print(to_colored_text(f"Error: {response.json()}", state="fail"))
-                 return
-         return response.json()["quotas"]
+                 print(to_colored_text(f"Error: {e.response.json()}", state="fail"))
+                 return None

      def await_job_completion(
          self,
          job_id: str,
          timeout: Optional[int] = 7200,
          obtain_results: bool = True,
+         output_column: str = "inference_result",
          is_cost_estimate: bool = False,
-     ) -> list | None:
+     ) -> pl.DataFrame | None:
          """
          Waits for job completion to occur and then returns the results upon
          a successful completion.
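
await_job_completion() gains an output_column pass-through and now returns a polars DataFrame instead of a raw list. A hypothetical end-to-end call, with values illustrative:

    # Hypothetical usage; job_id comes from an earlier so.infer(...) call.
    results_df = so.await_job_completion(
        job_id,
        timeout=3600,
        output_column="inference_result",
    )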
@@ -1364,14 +1318,14 @@
              timeout (Optional[int]): The max time in seconds the function should wait for job results for. Default is 7200 (2 hours).

          Returns:
-             list: The results of the job.
+             pl.DataFrame: The results of the job in a polars DataFrame.
          """
          POLL_INTERVAL = 5

-         results = None
+         results: pl.DataFrame | None = None
          start_time = time.time()
          with yaspin(
-             SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
+             SPINNER, text=to_colored_text("Awaiting job completion"), color=BASE_OUTPUT_COLOR
          ) as spinner:
              if not is_cost_estimate:
                  clickable_link = make_clickable_link(
@@ -1405,7 +1359,9 @@
                              "Job completed! Retrieving results...", "success"
                          )
                      )
-                     results = self.get_job_results(job_id)
+                     results = self.get_job_results(
+                         job_id, output_column=output_column
+                     )
                      break
                  if status == JobStatus.FAILED:
                      spinner.write(to_colored_text("Job has failed", "fail"))
@@ -1433,7 +1389,7 @@
          with yaspin(
              SPINNER,
              text=to_colored_text("Retrieving job results cache contents"),
-             color=YASPIN_COLOR,
+             color=BASE_OUTPUT_COLOR,
          ) as spinner:
              if not os.path.exists(os.path.expanduser("~/.sutro/job-results")):
                  spinner.write(to_colored_text("No job results cache found", "success"))
@@ -1465,7 +1421,7 @@

          start_time = time.time()
          with yaspin(
-             SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
+             SPINNER, text=to_colored_text("Awaiting job completion"), color=BASE_OUTPUT_COLOR
          ) as spinner:
              while (time.time() - start_time) < timeout:
                  try: