sutro 0.0.0__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sutro/__init__.py +14 -0
- sutro/cli.py +418 -0
- sutro/sdk.py +1101 -0
- sutro-0.1.11.dist-info/METADATA +41 -0
- sutro-0.1.11.dist-info/RECORD +8 -0
- {sutro-0.0.0.dist-info → sutro-0.1.11.dist-info}/WHEEL +1 -2
- sutro-0.1.11.dist-info/entry_points.txt +2 -0
- sutro-0.1.11.dist-info/licenses/LICENSE +201 -0
- __init__.py +0 -1
- hi.py +0 -1
- sutro-0.0.0.dist-info/METADATA +0 -6
- sutro-0.0.0.dist-info/RECORD +0 -6
- sutro-0.0.0.dist-info/top_level.txt +0 -2
sutro/sdk.py
ADDED
@@ -0,0 +1,1101 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager

import requests
import pandas as pd
import polars as pl
import json
from typing import Union, List, Optional, Literal, Generator, Dict, Any
import os
import sys
from yaspin import yaspin
from yaspin.spinners import Spinners
from colorama import init, Fore, Back, Style
from tqdm import tqdm
import time
from pydantic import BaseModel

# Initialize colorama (required for Windows)
init()


# This is how yaspin defines its is_jupyter logic
def is_jupyter() -> bool:
    return not sys.stdout.isatty()


# `color` param not supported in Jupyter notebooks
YASPIN_COLOR = None if is_jupyter() else "blue"
SPINNER = Spinners.dots14


def to_colored_text(
    text: str, state: Optional[Literal["success", "fail"]] = None
) -> str:
    """
    Apply color to text based on state.

    Args:
        text (str): The text to color
        state (Optional[Literal['success', 'fail']]): The state that determines the color.
            Options: 'success', 'fail', or None (default blue)

    Returns:
        str: Text with the appropriate color applied
    """
    match state:
        case "success":
            return f"{Fore.GREEN}{text}{Style.RESET_ALL}"
        case "fail":
            return f"{Fore.RED}{text}{Style.RESET_ALL}"
        case _:
            # Default to blue for normal/processing states
            return f"{Fore.BLUE}{text}{Style.RESET_ALL}"


class Sutro:
    def __init__(
        self, api_key: str = None, base_url: str = "https://api.sutro.sh/"
    ):
        self.api_key = api_key or self.check_for_api_key()
        self.base_url = base_url
        self.HEARTBEAT_INTERVAL_SECONDS = 15  # Keep in sync with what the backend expects

    def check_for_api_key(self):
        """
        Check for an API key in the user's home directory.

        This method looks for a configuration file named 'config.json' in the
        '.sutro' directory within the user's home directory.
        If the file exists, it attempts to read the API key from it.

        Returns:
            str or None: The API key if found in the configuration file, or None if not found.

        Note:
            The expected structure of the config.json file is:
            {
                "api_key": "your_api_key_here"
            }
        """
        CONFIG_DIR = os.path.expanduser("~/.sutro")
        CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
        if os.path.exists(CONFIG_FILE):
            with open(CONFIG_FILE, "r") as f:
                config = json.load(f)
            return config.get("api_key")
        else:
            return None
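
    # Illustrative sketch (not part of the released file): check_for_api_key() reads
    # ~/.sutro/config.json, so a key saved there once is picked up automatically.
    # Assuming a placeholder key "sk-...", the file would look like:
    #
    #   {"api_key": "sk-..."}
    #
    # and a client could then be created without passing the key explicitly:
    #
    #   so = Sutro()  # api_key resolved from ~/.sutro/config.json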

    def set_api_key(self, api_key: str):
        """
        Set the API key for the Sutro API.

        This method allows you to set the API key for the Sutro API.
        The API key is used to authenticate requests to the API.

        Args:
            api_key (str): The API key to set.

        Returns:
            None
        """
        self.api_key = api_key

    def handle_data_helper(
        self, data: Union[List, pd.DataFrame, pl.DataFrame, str], column: str = None
    ):
        if isinstance(data, list):
            input_data = data
        elif isinstance(data, (pd.DataFrame, pl.DataFrame)):
            if column is None:
                raise ValueError("Column name must be specified for DataFrame input")
            input_data = data[column].to_list()
        elif isinstance(data, str):
            if data.startswith("stage-"):
                input_data = data + ":" + column
            else:
                file_ext = os.path.splitext(data)[1].lower()
                if file_ext == ".csv":
                    df = pl.read_csv(data)
                elif file_ext == ".parquet":
                    df = pl.read_parquet(data)
                elif file_ext in [".txt", ""]:
                    with open(data, "r") as file:
                        input_data = [line.strip() for line in file]
                else:
                    raise ValueError(f"Unsupported file type: {file_ext}")

                if file_ext in [".csv", ".parquet"]:
                    if column is None:
                        raise ValueError(
                            "Column name must be specified for CSV/Parquet input"
                        )
                    input_data = df[column].to_list()
        else:
            raise ValueError(
                "Unsupported data type. Please provide a list, DataFrame, or file path."
            )

        return input_data
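
    # Illustrative sketch (not part of the released file): handle_data_helper() normalizes
    # the accepted input forms into a list of prompts. Assuming hypothetical local files
    # and a hypothetical stage ID, the equivalent calls would be:
    #
    #   so.infer(["prompt one", "prompt two"])           # plain list
    #   so.infer(df, column="prompt")                    # pandas/polars DataFrame
    #   so.infer("prompts.csv", column="prompt")         # CSV/Parquet path
    #   so.infer("prompts.txt")                          # text file, one prompt per line
    #   so.infer("stage-1234", column="prompt")          # staged data, "stage-..." ID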

    def set_base_url(self, base_url: str):
        """
        Set the base URL for the Sutro API.

        This method allows you to set the base URL for the Sutro API.
        The base URL is used to route requests to the API.

        Args:
            base_url (str): The base URL to set.
        """
        self.base_url = base_url

    def infer(
        self,
        data: Union[List, pd.DataFrame, pl.DataFrame, str],
        model: str = "llama-3.1-8b",
        column: str = None,
        output_column: str = "inference_result",
        job_priority: int = 0,
        output_schema: Union[Dict[str, Any], BaseModel] = None,
        sampling_params: dict = None,
        system_prompt: str = None,
        dry_run: bool = False,
        stay_attached: bool = False,
        random_seed_per_input: bool = False,
        truncate_rows: bool = False,
    ):
        """
        Run inference on the provided data.

        This method allows you to run inference on the provided data using the Sutro API.
        It supports various data types such as lists, pandas DataFrames, polars DataFrames, file paths and stages.

        Args:
            data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
            model (str, optional): The model to use for inference. Defaults to "llama-3.1-8b".
            column (str, optional): The column name to use for inference. Required if data is a DataFrame, file path, or stage.
            output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
            job_priority (int, optional): The priority of the job. Defaults to 0.
            output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
                Can be either a dictionary representing a JSON schema or a pydantic BaseModel. Defaults to None.
            sampling_params (dict, optional): The sampling parameters to use at generation time, i.e. temperature, top_p, etc.
            system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
            dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
            stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping (priority 0) jobs, False otherwise.
            random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
            truncate_rows (bool, optional): If True, any rows whose token count exceeds the context window length of the selected model will be truncated to the maximum length that fits within the context window. Defaults to False.

        Returns:
            Union[List, pd.DataFrame, pl.DataFrame, str]: The results of the inference.
        """
        input_data = self.handle_data_helper(data, column)
        stay_attached = stay_attached or job_priority == 0

        # Convert BaseModel to dict if needed
        if output_schema is not None:
            if hasattr(output_schema, "model_json_schema"):  # Check for the pydantic Model interface
                json_schema = output_schema.model_json_schema()
            elif isinstance(output_schema, dict):
                json_schema = output_schema
            else:
                raise ValueError("Invalid output schema type. Must be a dictionary or a pydantic Model.")
        else:
            json_schema = None

        endpoint = f"{self.base_url}/batch-inference"
        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model,
            "inputs": input_data,
            "job_priority": job_priority,
            "json_schema": json_schema,
            "system_prompt": system_prompt,
            "dry_run": dry_run,
            "sampling_params": sampling_params,
            "random_seed_per_input": random_seed_per_input,
            "truncate_rows": truncate_rows,
        }
        if dry_run:
            spinner_text = to_colored_text("Retrieving cost estimates...")
        else:
            t = f"Creating priority {job_priority} job"
            spinner_text = to_colored_text(t)

        # There are two gotchas with yaspin:
        # 1. Can't use print while the spinner is running
        # 2. When writing to stdout via spinner.fail, spinner.write etc, there is a pretty strict
        #    limit for content length in Jupyter notebooks, where it will give an error about:
        #    Terminal size {self._terminal_width} is too small to display spinner with the given settings.
        #    https://github.com/pavdmyt/yaspin/blob/9c7430b499ab4611888ece39783a870e4a05fa45/yaspin/core.py#L568-L571
        job_id = None
        with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
            response = requests.post(
                endpoint, data=json.dumps(payload), headers=headers
            )
            response_data = response.json()
            if response.status_code != 200:
                spinner.write(
                    to_colored_text(f"Error: {response.status_code}", state="fail")
                )
                spinner.stop()
                print(to_colored_text(response.json(), state="fail"))
                return
            else:
                if dry_run:
                    spinner.write(
                        to_colored_text("✔ Cost estimates retrieved", state="success")
                    )
                    return response_data["results"]
                else:
                    job_id = response_data["results"]
                    spinner.write(
                        to_colored_text(
                            f"🛠️ Priority {job_priority} Job created with ID: {job_id}",
                            state="success",
                        )
                    )
                    if not stay_attached:
                        spinner.write(
                            to_colored_text(
                                f"Use `so.get_job_status('{job_id}')` to check the status of the job."
                            )
                        )
                        return job_id

        success = False
        if stay_attached and job_id is not None:
            s = requests.Session()
            payload = {
                "job_id": job_id,
            }
            pbar = None

            # Register for stream and get session token
            session_token = self.register_stream_listener(job_id)

            # Use the heartbeat session context manager
            with self.stream_heartbeat_session(job_id, session_token) as s:
                with s.get(
                    f"{self.base_url}/stream-job-progress/{job_id}?request_session_token={session_token}",
                    headers=headers,
                    stream=True,
                ) as streaming_response:
                    streaming_response.raise_for_status()
                    spinner = yaspin(
                        SPINNER,
                        text=to_colored_text("Awaiting status updates..."),
                        color=YASPIN_COLOR,
                    )
                    spinner.start()
                    for line in streaming_response.iter_lines():
                        if line:
                            try:
                                json_obj = json.loads(line)
                            except json.JSONDecodeError:
                                print("Error: ", line, flush=True)
                                continue

                            if json_obj["update_type"] == "progress":
                                if pbar is None:
                                    spinner.stop()
                                    postfix = f"Input tokens processed: 0"
                                    pbar = self.fancy_tqdm(
                                        total=len(input_data),
                                        desc="Progress",
                                        style=1,
                                        postfix=postfix,
                                    )
                                if json_obj["result"] > pbar.n:
                                    pbar.update(json_obj["result"] - pbar.n)
                                    pbar.refresh()
                                if json_obj["result"] == len(input_data):
                                    pbar.close()
                                    success = True
                            elif json_obj["update_type"] == "tokens":
                                if pbar is not None:
                                    pbar.postfix = f"Input tokens processed: {json_obj['result']['input_tokens']}, Tokens generated: {json_obj['result']['output_tokens']}, Total tokens/s: {json_obj['result'].get('total_tokens_processed_per_second')}"
                                    pbar.refresh()

                    if success:
                        spinner.text = to_colored_text(
                            "✔ Job succeeded. Obtaining results...", state="success"
                        )
                        spinner.start()

                        payload = {
                            "job_id": job_id,
                        }

                        # TODO: we implement retries in cases where the job hasn't written results yet;
                        # it would be better if we could receive a fully succeeded status from the job
                        # and not have such a race condition
                        max_retries = 20  # winds up being 100 seconds cumulative delay
                        retry_delay = 5  # initial delay in seconds

                        for _ in range(max_retries):
                            time.sleep(retry_delay)

                            job_results_response = s.post(
                                f"{self.base_url}/job-results",
                                headers=headers,
                                data=json.dumps(payload),
                            )
                            if job_results_response.status_code == 200:
                                break

                        if job_results_response.status_code != 200:
                            spinner.write(
                                to_colored_text(
                                    f"Job succeeded, but results are not yet available. Use `so.get_job_results('{job_id}')` to obtain results.",
                                    state="fail",
                                )
                            )
                            spinner.stop()
                            return

                        results = job_results_response.json()["results"]

                        spinner.write(
                            to_colored_text(
                                f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
                                state="success",
                            )
                        )
                        spinner.stop()

                        if isinstance(data, (pd.DataFrame, pl.DataFrame)):
                            sample_n = 1 if sampling_params is None else sampling_params["n"]
                            if sample_n > 1:
                                results = [
                                    results[i : i + sample_n]
                                    for i in range(0, len(results), sample_n)
                                ]
                            if isinstance(data, pd.DataFrame):
                                data[output_column] = results
                            elif isinstance(data, pl.DataFrame):
                                data = data.with_columns(pl.Series(output_column, results))
                            return data

                        return results
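
    # Illustrative sketch (not part of the released file): a minimal end-to-end call,
    # assuming a pandas DataFrame `df` with a "prompt" column and a valid API key.
    #
    #   so = Sutro()
    #   df = so.infer(
    #       df,
    #       column="prompt",
    #       system_prompt="Answer in one sentence.",
    #       output_column="answer",
    #   )
    #
    # With job_priority=0 (the default) the call stays attached and returns the DataFrame
    # with the new column; jobs submitted with a non-zero priority return a job ID to
    # poll instead (unless stay_attached=True).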

    def register_stream_listener(self, job_id: str) -> str:
        """Register a new stream listener and get a session token."""
        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }
        with requests.post(
            f"{self.base_url}/register-stream-listener/{job_id}",
            headers=headers,
        ) as response:
            response.raise_for_status()
            data = response.json()
            return data["request_session_token"]

    # This is a best-effort action and it is ok if it sometimes doesn't complete
    def unregister_stream_listener(self, job_id: str, session_token: str):
        """Explicitly unregister a stream listener."""
        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }
        with requests.post(
            f"{self.base_url}/unregister-stream-listener/{job_id}",
            headers=headers,
            json={"request_session_token": session_token},
        ) as response:
            response.raise_for_status()

    def start_heartbeat(
        self,
        job_id: str,
        session_token: str,
        session: requests.Session,
        stop_event: threading.Event,
    ):
        """Send heartbeats until stopped."""
        while not stop_event.is_set():
            try:
                headers = {
                    "Authorization": f"Key {self.api_key}",
                    "Content-Type": "application/json",
                }
                response = session.post(
                    f"{self.base_url}/stream-heartbeat/{job_id}",
                    headers=headers,
                    params={"request_session_token": session_token},
                )
                response.raise_for_status()
            except Exception as e:
                if not stop_event.is_set():  # Only log if we weren't stopping anyway
                    print(f"Heartbeat failed for job {job_id}: {e}")

            # Use time.sleep instead of asyncio.sleep since this is synchronous
            time.sleep(self.HEARTBEAT_INTERVAL_SECONDS)

    @contextmanager
    def stream_heartbeat_session(self, job_id: str, session_token: str) -> Generator[requests.Session, None, None]:
        """Context manager that handles session registration and heartbeat."""
        session = requests.Session()
        stop_heartbeat = threading.Event()

        # Run the heartbeat concurrently in a thread so it doesn't block the main SDK
        # path/behavior while heartbeat requests keep going out
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(
                self.start_heartbeat,
                job_id,
                session_token,
                session,
                stop_heartbeat,
            )

            try:
                yield session
            finally:
                # Signal stop and clean up
                stop_heartbeat.set()
                # Wait for the heartbeat to finish, with a timeout
                try:
                    future.result(timeout=1.0)
                except TimeoutError:
                    pass
                self.unregister_stream_listener(job_id, session_token)
                session.close()
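
    # Illustrative sketch (not part of the released file): the listener/heartbeat pair is
    # what infer() and attach() use internally to hold a progress stream open, assuming
    # `job_id` holds an existing job's ID.
    #
    #   token = so.register_stream_listener(job_id)
    #   with so.stream_heartbeat_session(job_id, token) as session:
    #       ...  # stream /stream-job-progress/{job_id} with this session
    #
    # On exit, the context manager stops the background heartbeat thread and unregisters
    # the listener.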

    def attach(self, job_id):
        """
        Attach to an existing job and stream its progress.

        Args:
            job_id (str): The ID of the job to attach to
        """

        s = requests.Session()
        payload = {
            "job_id": job_id,
        }
        pbar = None

        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }

        with yaspin(
            SPINNER,
            text=to_colored_text("Looking for job..."),
            color=YASPIN_COLOR,
        ) as spinner:
            # Get job information from list-jobs endpoint
            # TODO(cooper) we should add a get jobs endpoint:
            # GET /jobs/{job_id}
            jobs_response = s.get(
                f"{self.base_url}/list-jobs",
                headers=headers
            )
            jobs_response.raise_for_status()

            # Find the specific job we want to attach to
            job = next(
                (job for job in jobs_response.json()["jobs"] if job["job_id"] == job_id),
                None
            )

            if not job:
                spinner.write(to_colored_text(f"Job {job_id} not found", state="fail"))
                return

            match job.get("status"):
                case "SUCCEEDED":
                    spinner.write(
                        to_colored_text(
                            f"Job already completed. You can obtain the results with `sutro jobs results {job_id}`"
                        )
                    )
                    return
                case "FAILED":
                    spinner.write(to_colored_text("❌ Job is in failed state.", state="fail"))
                    return
                case "CANCELLED":
                    spinner.write(to_colored_text("❌ Job was cancelled.", state="fail"))
                    return
                case _:
                    spinner.write(to_colored_text("✔ Job found!", state="success"))

        total_rows = job["num_rows"]
        success = False

        session_token = self.register_stream_listener(job_id)

        with self.stream_heartbeat_session(job_id, session_token) as s:
            with s.get(
                f"{self.base_url}/stream-job-progress/{job_id}?request_session_token={session_token}",
                headers=headers,
                stream=True,
            ) as streaming_response:
                streaming_response.raise_for_status()
                spinner = yaspin(
                    SPINNER,
                    text=to_colored_text("Awaiting status updates..."),
                    color=YASPIN_COLOR,
                )
                spinner.start()
                for line in streaming_response.iter_lines():
                    if line:
                        try:
                            json_obj = json.loads(line)
                        except json.JSONDecodeError:
                            print("Error: ", line, flush=True)
                            continue

                        if json_obj["update_type"] == "progress":
                            if pbar is None:
                                spinner.stop()
                                postfix = f"Input tokens processed: 0"
                                pbar = self.fancy_tqdm(
                                    total=total_rows,
                                    desc="Progress",
                                    style=1,
                                    postfix=postfix,
                                )
                            if json_obj["result"] > pbar.n:
                                pbar.update(json_obj["result"] - pbar.n)
                                pbar.refresh()
                            if json_obj["result"] == total_rows:
                                pbar.close()
                                success = True
                        elif json_obj["update_type"] == "tokens":
                            if pbar is not None:
                                pbar.postfix = f"Input tokens processed: {json_obj['result']['input_tokens']}, Tokens generated: {json_obj['result']['output_tokens']}, Total tokens/s: {json_obj['result'].get('total_tokens_processed_per_second')}"
                                pbar.refresh()

                if success:
                    spinner.write(
                        to_colored_text(
                            f"✔ Job succeeded. Use `sutro jobs results {job_id}` to obtain results.",
                            state="success",
                        )
                    )
                    spinner.stop()
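
    # Illustrative sketch (not part of the released file): re-attaching to a job that was
    # started earlier (e.g. with stay_attached=False), assuming `job_id` holds its ID.
    #
    #   so = Sutro()
    #   so.attach(job_id)  # streams progress until the job finishes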

    def fancy_tqdm(
        self,
        total: int,
        desc: str = "Progress",
        color: str = "blue",
        style=1,
        postfix: str = None,
    ):
        """
        Creates a customized tqdm progress bar with different styling options.

        Args:
            total (int): Total iterations
            desc (str): Description for the progress bar
            color (str): Color of the progress bar (green, blue, red, yellow, magenta)
            style (int): Style preset (1-5)
            postfix (str): Postfix for the progress bar
        """

        # Style presets
        style_presets = {
            1: {
                "bar_format": "{l_bar}{bar:30}| {n_fmt}/{total_fmt} | {percentage:3.0f}% {postfix}",
                "ascii": "░▒█",
            },
            2: {
                "bar_format": "╢{l_bar}{bar:30}╟ {percentage:3.0f}%",
                "ascii": "▁▂▃▄▅▆▇█",
            },
            3: {
                "bar_format": "{desc}: |{bar}| {percentage:3.0f}% [{elapsed}<{remaining}]",
                "ascii": "◯◔◑◕●",
            },
            4: {
                "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
                "ascii": "⬜⬛",
            },
            5: {
                "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
                "ascii": "▏▎▍▌▋▊▉█",
            },
        }

        # Get style configuration
        style_config = style_presets.get(style, style_presets[1])

        return tqdm(
            total=total,
            desc=desc,
            colour=color,
            bar_format=style_config["bar_format"],
            ascii=style_config["ascii"],
            ncols=80,
            dynamic_ncols=True,
            smoothing=0.3,
            leave=True,
            postfix=postfix,
        )

    def list_jobs(self):
        """
        List all jobs.

        This method retrieves a list of all jobs associated with the API key.

        Returns:
            list: A list of job details.
        """
        endpoint = f"{self.base_url}/list-jobs"
        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }

        with yaspin(
            SPINNER, text=to_colored_text("Fetching jobs"), color=YASPIN_COLOR
        ) as spinner:
            response = requests.get(endpoint, headers=headers)
            if response.status_code != 200:
                spinner.write(
                    to_colored_text(
                        f"Bad status code: {response.status_code}", state="fail"
                    )
                )
                spinner.stop()
                print(to_colored_text(response.json(), state="fail"))
                return
            return response.json()["jobs"]

    def get_job_status(self, job_id: str):
        """
        Get the status of a job by its ID.

        This method retrieves the status of a job using its unique identifier.

        Args:
            job_id (str): The ID of the job to retrieve the status for.

        Returns:
            str: The status of the job.
        """
        endpoint = f"{self.base_url}/job-status/{job_id}"
        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }
        with yaspin(
            SPINNER,
            text=to_colored_text(f"Checking job status with ID: {job_id}"),
            color=YASPIN_COLOR,
        ) as spinner:
            response = requests.get(endpoint, headers=headers)
            if response.status_code != 200:
                spinner.write(
                    to_colored_text(
                        f"Bad status code: {response.status_code}", state="fail"
                    )
                )
                spinner.stop()
                print(to_colored_text(response.json(), state="fail"))
                return
            spinner.write(to_colored_text("✔ Job status retrieved!", state="success"))
            return response.json()["job_status"][job_id]
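
    # Illustrative sketch (not part of the released file): checking on jobs submitted
    # earlier, assuming `job_id` holds an ID returned by infer().
    #
    #   jobs = so.list_jobs()               # all jobs for this API key
    #   status = so.get_job_status(job_id)  # e.g. "SUCCEEDED", "FAILED", "CANCELLED"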

    def get_job_results(
        self,
        job_id: str,
        include_inputs: bool = False,
        include_cumulative_logprobs: bool = False,
    ):
        """
        Get the results of a job by its ID.

        This method retrieves the results of a job using its unique identifier.

        Args:
            job_id (str): The ID of the job to retrieve the results for.
            include_inputs (bool, optional): Whether to include the inputs in the results. Defaults to False.
            include_cumulative_logprobs (bool, optional): Whether to include the cumulative logprobs in the results. Defaults to False.

        Returns:
            list: The results of the job.
        """
        endpoint = f"{self.base_url}/job-results"
        payload = {
            "job_id": job_id,
            "include_inputs": include_inputs,
            "include_cumulative_logprobs": include_cumulative_logprobs,
        }
        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }
        with yaspin(
            SPINNER,
            text=to_colored_text(f"Gathering results from job: {job_id}"),
            color=YASPIN_COLOR,
        ) as spinner:
            response = requests.post(
                endpoint, data=json.dumps(payload), headers=headers
            )
            if response.status_code == 200:
                spinner.write(
                    to_colored_text("✔ Job results retrieved", state="success")
                )
            else:
                spinner.write(
                    to_colored_text(
                        f"Bad status code: {response.status_code}", state="fail"
                    )
                )
                spinner.stop()
                print(to_colored_text(response.json(), state="fail"))
                return
            return response.json()["results"]
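
    # Illustrative sketch (not part of the released file): fetching results after a
    # detached job finishes, assuming `job_id` holds its ID.
    #
    #   results = so.get_job_results(job_id, include_inputs=True)
    #   # `results` is a list; pair it back up with your inputs or load it into a DataFrame.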

    def cancel_job(self, job_id: str):
        """
        Cancel a job by its ID.

        This method allows you to cancel a job using its unique identifier.

        Args:
            job_id (str): The ID of the job to cancel.

        Returns:
            dict: The status of the job.
        """
        endpoint = f"{self.base_url}/job-cancel/{job_id}"
        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }
        with yaspin(
            SPINNER,
            text=to_colored_text(f"Cancelling job: {job_id}"),
            color=YASPIN_COLOR,
        ) as spinner:
            response = requests.get(endpoint, headers=headers)
            if response.status_code == 200:
                spinner.write(to_colored_text("✔ Job cancelled", state="success"))
            else:
                spinner.write(to_colored_text("Failed to cancel job", state="fail"))
                spinner.stop()
                print(to_colored_text(response.json(), state="fail"))
                return
            return response.json()

    def create_stage(self):
        """
        Create a new stage.

        This method creates a new stage and returns its ID.

        Returns:
            str: The ID of the new stage.
        """
        endpoint = f"{self.base_url}/create-stage"
        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }
        with yaspin(
            SPINNER, text=to_colored_text("Creating stage"), color=YASPIN_COLOR
        ) as spinner:
            response = requests.get(endpoint, headers=headers)
            if response.status_code != 200:
                spinner.write(
                    to_colored_text(
                        f"Bad status code: {response.status_code}", state="fail"
                    )
                )
                spinner.stop()
                print(to_colored_text(response.json(), state="fail"))
                return
            stage_id = response.json()["stage_id"]
            spinner.write(
                to_colored_text(f"✔ Stage created with ID: {stage_id}", state="success")
            )
            return stage_id

    def upload_to_stage(
        self,
        stage_id: Union[List[str], str] = None,
        file_paths: Union[List[str], str] = None,
        verify_ssl: bool = True,
    ):
        """
        Upload data to a stage.

        This method uploads files to a stage. It accepts a stage ID and file paths. If only a single parameter is provided, it will be interpreted as the file paths.

        Args:
            stage_id (str): The ID of the stage to upload to. If not provided, a new stage will be created.
            file_paths (Union[List[str], str]): A list of paths to the files to upload, or a single path to a collection of files.
            verify_ssl (bool): Whether to verify SSL certificates. Set to False to bypass SSL verification for troubleshooting.

        Returns:
            dict: The response from the API.
        """
        # when only a single parameter is provided, it is interpreted as the file paths
        if file_paths is None and stage_id is not None:
            file_paths = stage_id
            stage_id = None

        if file_paths is None:
            raise ValueError("File paths must be provided")

        if stage_id is None:
            stage_id = self.create_stage()

        endpoint = f"{self.base_url}/upload-to-stage"

        if isinstance(file_paths, str):
            # check if the file path is a directory
            if os.path.isdir(file_paths):
                file_paths = [
                    os.path.join(file_paths, f) for f in os.listdir(file_paths)
                ]
                if len(file_paths) == 0:
                    raise ValueError("No files found in the directory")
            else:
                file_paths = [file_paths]

        with yaspin(
            SPINNER,
            text=to_colored_text(f"Uploading files to stage: {stage_id}"),
            color=YASPIN_COLOR,
        ) as spinner:
            count = 0
            for file_path in file_paths:
                file_name = os.path.basename(file_path)

                files = {
                    "file": (
                        file_name,
                        open(file_path, "rb"),
                        "application/octet-stream",
                    )
                }

                payload = {
                    "stage_id": stage_id,
                }

                headers = {"Authorization": f"Key {self.api_key}"}

                count += 1
                spinner.write(
                    to_colored_text(
                        f"Uploading file {count}/{len(file_paths)} to stage: {stage_id}"
                    )
                )

                try:
                    response = requests.post(
                        endpoint, headers=headers, data=payload, files=files, verify=verify_ssl
                    )
                    if response.status_code != 200:
                        # Stop spinner before showing error to avoid terminal width error
                        spinner.stop()
                        print(
                            to_colored_text(
                                f"Error: HTTP {response.status_code}", state="fail"
                            )
                        )
                        print(to_colored_text(response.json(), state="fail"))
                        return

                except requests.exceptions.RequestException as e:
                    # Stop spinner before showing error to avoid terminal width error
                    spinner.stop()
                    print(to_colored_text(f"Upload failed: {str(e)}", state="fail"))
                    return

            spinner.write(
                to_colored_text(
                    f"✔ {count} files successfully uploaded to stage", state="success"
                )
            )
            return stage_id
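
    # Illustrative sketch (not part of the released file): staging a folder of files and
    # then running inference over a staged column, assuming a hypothetical local
    # directory ./data and a "prompt" column in the staged files.
    #
    #   stage_id = so.upload_to_stage("./data")  # creates a stage if none is given
    #   so.infer(stage_id, column="prompt")      # stage IDs are passed as "stage-..." strings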

    def list_stages(self):
        endpoint = f"{self.base_url}/list-stages"
        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }
        with yaspin(
            SPINNER, text=to_colored_text("Retrieving stages"), color=YASPIN_COLOR
        ) as spinner:
            response = requests.post(endpoint, headers=headers)
            if response.status_code != 200:
                spinner.fail(
                    to_colored_text(
                        f"Bad status code: {response.status_code}", state="fail"
                    )
                )
                print(to_colored_text(f"Error: {response.json()}", state="fail"))
                return
            spinner.write(to_colored_text("✔ Stages retrieved", state="success"))
            return response.json()["stages"]

    def list_stage_files(self, stage_id: str):
        endpoint = f"{self.base_url}/list-stage-files"
        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "stage_id": stage_id,
        }
        with yaspin(
            SPINNER,
            text=to_colored_text(f"Listing files in stage: {stage_id}"),
            color=YASPIN_COLOR,
        ) as spinner:
            response = requests.post(
                endpoint, headers=headers, data=json.dumps(payload)
            )
            if response.status_code != 200:
                spinner.fail(
                    to_colored_text(
                        f"Bad status code: {response.status_code}", state="fail"
                    )
                )
                print(to_colored_text(f"Error: {response.json()}", state="fail"))
                return
            spinner.write(
                to_colored_text(f"✔ Files listed in stage: {stage_id}", state="success")
            )
            return response.json()["files"]

    def download_from_stage(
        self,
        stage_id: str,
        files: Union[List[str], str] = None,
        output_path: str = None,
    ):
        endpoint = f"{self.base_url}/download-from-stage"

        if files is None:
            files = self.list_stage_files(stage_id)
        elif isinstance(files, str):
            files = [files]

        if not files:
            print(
                to_colored_text(
                    f"Couldn't find files for stage ID: {stage_id}", state="fail"
                )
            )
            return

        # if no output path is provided, save the files to the current working directory
        if output_path is None:
            output_path = os.getcwd()

        with yaspin(
            SPINNER,
            text=to_colored_text(f"Downloading files from stage: {stage_id}"),
            color=YASPIN_COLOR,
        ) as spinner:
            count = 0
            for file in files:
                headers = {
                    "Authorization": f"Key {self.api_key}",
                    "Content-Type": "application/json",
                }
                payload = {
                    "stage_id": stage_id,
                    "file_name": file,
                }
                spinner.text = to_colored_text(
                    f"Downloading file {count + 1}/{len(files)} from stage: {stage_id}"
                )
                response = requests.post(
                    endpoint, headers=headers, data=json.dumps(payload)
                )
                if response.status_code != 200:
                    spinner.fail(
                        to_colored_text(
                            f"Bad status code: {response.status_code}", state="fail"
                        )
                    )
                    print(to_colored_text(f"Error: {response.json()}", state="fail"))
                    return
                file_content = response.content
                with open(os.path.join(output_path, file), "wb") as f:
                    f.write(file_content)
                count += 1
            spinner.write(
                to_colored_text(
                    f"✔ {count} files successfully downloaded from stage: {stage_id}",
                    state="success",
                )
            )
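
    # Illustrative sketch (not part of the released file): pulling staged files back down,
    # assuming `stage_id` refers to an existing stage and ./downloads is a hypothetical
    # local directory.
    #
    #   names = so.list_stage_files(stage_id)
    #   so.download_from_stage(stage_id, files=names, output_path="./downloads")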

    def try_authentication(self, api_key: str):
        """
        Try to authenticate with the API key.

        This method allows you to authenticate with the API key.

        Args:
            api_key (str): The API key to authenticate with.

        Returns:
            dict: The status of the authentication.
        """
        endpoint = f"{self.base_url}/try-authentication"
        headers = {
            "Authorization": f"Key {api_key}",
            "Content-Type": "application/json",
        }
        with yaspin(
            SPINNER, text=to_colored_text("Checking API key"), color=YASPIN_COLOR
        ) as spinner:
            response = requests.get(endpoint, headers=headers)
            if response.status_code == 200:
                spinner.write(to_colored_text("✔"))
            else:
                spinner.write(
                    to_colored_text(
                        f"API key failed to authenticate: {response.status_code}",
                        state="fail",
                    )
                )
                return
            return response.json()

    def get_quotas(self):
        endpoint = f"{self.base_url}/get-quotas"
        headers = {
            "Authorization": f"Key {self.api_key}",
            "Content-Type": "application/json",
        }
        with yaspin(
            SPINNER, text=to_colored_text("Fetching quotas"), color=YASPIN_COLOR
        ) as spinner:
            response = requests.get(endpoint, headers=headers)
            if response.status_code != 200:
                spinner.fail(
                    to_colored_text(
                        f"Bad status code: {response.status_code}", state="fail"
                    )
                )
                print(to_colored_text(f"Error: {response.json()}", state="fail"))
                return
            return response.json()["quotas"]