sutro 0.0.0__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


sutro/sdk.py ADDED
@@ -0,0 +1,1097 @@
+ import threading
+ from concurrent.futures import ThreadPoolExecutor
+ from contextlib import contextmanager
+
+ import requests
+ import pandas as pd
+ import polars as pl
+ import json
+ from typing import Union, List, Optional, Literal, Generator, Dict, Any
+ import os
+ import sys
+ from yaspin import yaspin
+ from yaspin.spinners import Spinners
+ from colorama import init, Fore, Style
+ from tqdm import tqdm
+ import time
+ from pydantic import BaseModel
+
+ # Initialize colorama (required for Windows)
+ init()
+
+
+ # This mirrors how yaspin decides whether it is running under Jupyter
+ def is_jupyter() -> bool:
+     return not sys.stdout.isatty()
+
+
+ # `color` param not supported in Jupyter notebooks
+ YASPIN_COLOR = None if is_jupyter() else "blue"
+ SPINNER = Spinners.dots14
+
+
+ def to_colored_text(
+     text: str, state: Optional[Literal["success", "fail"]] = None
+ ) -> str:
+     """
+     Apply color to text based on state.
+
+     Args:
+         text (str): The text to color
+         state (Optional[Literal['success', 'fail']]): The state that determines the color.
+             Options: 'success', 'fail', or None (default blue)
+
+     Returns:
+         str: Text with appropriate color applied
+     """
+     match state:
+         case "success":
+             return f"{Fore.GREEN}{text}{Style.RESET_ALL}"
+         case "fail":
+             return f"{Fore.RED}{text}{Style.RESET_ALL}"
+         case _:
+             # Default to blue for normal/processing states
+             return f"{Fore.BLUE}{text}{Style.RESET_ALL}"
+
+
+ class Sutro:
+     def __init__(
+         self, api_key: str = None, base_url: str = "https://api.sutro.sh"
+     ):
+         self.api_key = api_key or self.check_for_api_key()
+         self.base_url = base_url
+         self.HEARTBEAT_INTERVAL_SECONDS = 15  # Keep in sync with what the backend expects
+
+     def check_for_api_key(self):
+         """
+         Check for an API key in the user's home directory.
+
+         This method looks for a configuration file named 'config.json' in the
+         '.sutro' directory within the user's home directory.
+         If the file exists, it attempts to read the API key from it.
+
+         Returns:
+             str or None: The API key if found in the configuration file, or None if not found.
+
+         Note:
+             The expected structure of the config.json file is:
+             {
+                 "api_key": "your_api_key_here"
+             }
+         """
+         CONFIG_DIR = os.path.expanduser("~/.sutro")
+         CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
+         if os.path.exists(CONFIG_FILE):
+             with open(CONFIG_FILE, "r") as f:
+                 config = json.load(f)
+                 return config.get("api_key")
+         else:
+             return None
+
+     def set_api_key(self, api_key: str):
+         """
+         Set the API key for the Sutro API.
+
+         The API key is used to authenticate requests to the API.
+
+         Args:
+             api_key (str): The API key to set.
+
+         Returns:
+             None
+         """
+         self.api_key = api_key
+
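A minimal construction sketch (illustrative: the import path follows the `sutro/sdk.py` module shown here, and the key values are placeholders). The client falls back to `~/.sutro/config.json` when no key is passed:

    from sutro.sdk import Sutro

    # Option 1: rely on ~/.sutro/config.json containing {"api_key": "..."}
    so = Sutro()

    # Option 2: pass or swap the key explicitly
    so = Sutro(api_key="YOUR_API_KEY")
    so.set_api_key("ANOTHER_API_KEY")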
+     def handle_data_helper(
+         self, data: Union[List, pd.DataFrame, pl.DataFrame, str], column: str = None
+     ):
+         if isinstance(data, list):
+             input_data = data
+         elif isinstance(data, (pd.DataFrame, pl.DataFrame)):
+             if column is None:
+                 raise ValueError("Column name must be specified for DataFrame input")
+             input_data = data[column].to_list()
+         elif isinstance(data, str):
+             if data.startswith("stage-"):
+                 input_data = data + ":" + column
+             else:
+                 file_ext = os.path.splitext(data)[1].lower()
+                 if file_ext == ".csv":
+                     df = pl.read_csv(data)
+                 elif file_ext == ".parquet":
+                     df = pl.read_parquet(data)
+                 elif file_ext in [".txt", ""]:
+                     with open(data, "r") as file:
+                         input_data = [line.strip() for line in file]
+                 else:
+                     raise ValueError(f"Unsupported file type: {file_ext}")
+
+                 if file_ext in [".csv", ".parquet"]:
+                     if column is None:
+                         raise ValueError(
+                             "Column name must be specified for CSV/Parquet input"
+                         )
+                     input_data = df[column].to_list()
+         else:
+             raise ValueError(
+                 "Unsupported data type. Please provide a list, DataFrame, or file path."
+             )
+
+         return input_data
+
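For reference, a sketch of the input forms this helper accepts (the file names, columns, and stage ID are hypothetical):

    from sutro.sdk import Sutro
    import polars as pl

    so = Sutro()
    so.handle_data_helper(["prompt one", "prompt two"])          # lists pass through unchanged
    so.handle_data_helper(pl.DataFrame({"q": ["a", "b"]}), "q")  # DataFrames require a column
    so.handle_data_helper("questions.csv", "q")                  # CSV/Parquet are read via polars
    so.handle_data_helper("stage-abc123", "q")                   # stage refs become "stage-abc123:q"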
+     def set_base_url(self, base_url: str):
+         """
+         Set the base URL for the Sutro API.
+
+         This method allows you to point the SDK at a different Sutro API
+         endpoint. All subsequent requests are sent to this URL.
+
+         Args:
+             base_url (str): The base URL to set.
+         """
+         self.base_url = base_url
+
+     def infer(
+         self,
+         data: Union[List, pd.DataFrame, pl.DataFrame, str],
+         model: str = "llama-3.1-8b",
+         column: str = None,
+         output_column: str = "inference_result",
+         job_priority: int = 0,
+         output_schema: Union[Dict[str, Any], BaseModel] = None,
+         sampling_params: dict = None,
+         system_prompt: str = None,
+         dry_run: bool = False,
+         stay_attached: bool = False,
+         random_seed_per_input: bool = False,
+         truncate_rows: bool = False,
+     ):
+         """
+         Run inference on the provided data.
+
+         This method runs inference on the provided data using the Sutro API.
+         It supports lists, pandas DataFrames, polars DataFrames, file paths, and stages.
+
+         Args:
+             data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
+             model (str, optional): The model to use for inference. Defaults to "llama-3.1-8b".
+             column (str, optional): The column name to use for inference. Required if data is a DataFrame, file path, or stage.
+             output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
+             job_priority (int, optional): The priority of the job. Defaults to 0.
+             output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
+                 Can be either a dictionary representing a JSON schema or a pydantic BaseModel. Defaults to None.
+             sampling_params (dict, optional): The sampling parameters to use at generation time, e.g. temperature, top_p, etc.
+             system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
+             dry_run (bool, optional): If True, the method returns cost estimates instead of running inference. Defaults to False.
+             stay_attached (bool, optional): If True, the method stays attached to the job until it is complete. Defaults to True for priority 0 jobs, False otherwise.
+             random_seed_per_input (bool, optional): If True, a different random seed is used for each input. Defaults to False.
+             truncate_rows (bool, optional): If True, any row whose token count exceeds the context window of the selected model is truncated to the maximum length that fits. Defaults to False.
+
+         Returns:
+             Union[List, pd.DataFrame, pl.DataFrame, str]: The results of the inference.
+         """
+         input_data = self.handle_data_helper(data, column)
+         stay_attached = stay_attached or job_priority == 0
+
+         # Convert BaseModel to dict if needed
+         if output_schema is not None:
+             if hasattr(output_schema, "model_json_schema"):  # pydantic Model interface
+                 json_schema = output_schema.model_json_schema()
+             elif isinstance(output_schema, dict):
+                 json_schema = output_schema
+             else:
+                 raise ValueError(
+                     "Invalid output schema type. Must be a dictionary or a pydantic Model."
+                 )
+         else:
+             json_schema = None
+
+         endpoint = f"{self.base_url}/batch-inference"
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+         payload = {
+             "model": model,
+             "inputs": input_data,
+             "job_priority": job_priority,
+             "json_schema": json_schema,
+             "system_prompt": system_prompt,
+             "dry_run": dry_run,
+             "sampling_params": sampling_params,
+             "random_seed_per_input": random_seed_per_input,
+             "truncate_rows": truncate_rows,
+         }
+         if dry_run:
+             spinner_text = to_colored_text("Retrieving cost estimates...")
+         else:
+             spinner_text = to_colored_text(f"Creating priority {job_priority} job")
+
+         # There are two gotchas with yaspin:
+         # 1. You can't use print() while the spinner is running.
+         # 2. When writing to stdout via spinner.fail, spinner.write, etc., there is a fairly
+         #    strict limit on content length in Jupyter notebooks, beyond which it errors with:
+         #    "Terminal size {self._terminal_width} is too small to display spinner with the given settings."
+         #    https://github.com/pavdmyt/yaspin/blob/9c7430b499ab4611888ece39783a870e4a05fa45/yaspin/core.py#L568-L571
+         job_id = None
+         with yaspin(SPINNER, text=spinner_text, color=YASPIN_COLOR) as spinner:
+             response = requests.post(
+                 endpoint, data=json.dumps(payload), headers=headers
+             )
+             response_data = response.json()
+             if response.status_code != 200:
+                 spinner.write(
+                     to_colored_text(f"Error: {response.status_code}", state="fail")
+                 )
+                 spinner.stop()
+                 print(to_colored_text(response_data, state="fail"))
+                 return
+             else:
+                 if dry_run:
+                     spinner.write(
+                         to_colored_text("✔ Cost estimates retrieved", state="success")
+                     )
+                     return response_data["results"]
+                 else:
+                     job_id = response_data["results"]
+                     spinner.write(
+                         to_colored_text(
+                             f"🛠️ Priority {job_priority} Job created with ID: {job_id}",
+                             state="success",
+                         )
+                     )
+                     if not stay_attached:
+                         spinner.write(
+                             to_colored_text(
+                                 f"Use `so.get_job_status('{job_id}')` to check the status of the job."
+                             )
+                         )
+                         return job_id
+
+         success = False
+         if stay_attached and job_id is not None:
+             pbar = None
+
+             # Register for the stream and get a session token
+             session_token = self.register_stream_listener(job_id)
+
+             # Use the heartbeat session context manager
+             with self.stream_heartbeat_session(job_id, session_token) as s:
+                 with s.get(
+                     f"{self.base_url}/stream-job-progress/{job_id}?request_session_token={session_token}",
+                     headers=headers,
+                     stream=True,
+                 ) as streaming_response:
+                     streaming_response.raise_for_status()
+                     spinner = yaspin(
+                         SPINNER,
+                         text=to_colored_text("Awaiting status updates..."),
+                         color=YASPIN_COLOR,
+                     )
+                     spinner.start()
+                     for line in streaming_response.iter_lines():
+                         if line:
+                             try:
+                                 json_obj = json.loads(line)
+                             except json.JSONDecodeError:
+                                 print("Error: ", line, flush=True)
+                                 continue
+
+                             if json_obj["update_type"] == "progress":
+                                 if pbar is None:
+                                     spinner.stop()
+                                     postfix = "Input tokens processed: 0"
+                                     pbar = self.fancy_tqdm(
+                                         total=len(input_data),
+                                         desc="Progress",
+                                         style=1,
+                                         postfix=postfix,
+                                     )
+                                 if json_obj["result"] > pbar.n:
+                                     pbar.update(json_obj["result"] - pbar.n)
+                                     pbar.refresh()
+                                 if json_obj["result"] == len(input_data):
+                                     pbar.close()
+                                     success = True
+                             elif json_obj["update_type"] == "tokens":
+                                 if pbar is not None:
+                                     pbar.postfix = f"Input tokens processed: {json_obj['result']['input_tokens']}, Tokens generated: {json_obj['result']['output_tokens']}, Total tokens/s: {json_obj['result'].get('total_tokens_processed_per_second')}"
+                                     pbar.refresh()
+
+                 if success:
+                     spinner.text = to_colored_text(
+                         "✔ Job succeeded. Obtaining results...", state="success"
+                     )
+                     spinner.start()
+
+                     payload = {
+                         "job_id": job_id,
+                     }
+
+                     # TODO: we retry in cases where the job hasn't written results yet.
+                     # It would be better to receive a fully-succeeded status from the job
+                     # and avoid this race condition entirely.
+                     max_retries = 20  # winds up being 100 seconds of cumulative delay
+                     retry_delay = 5  # delay in seconds between attempts
+
+                     for _ in range(max_retries):
+                         time.sleep(retry_delay)
+
+                         job_results_response = s.post(
+                             f"{self.base_url}/job-results",
+                             headers=headers,
+                             data=json.dumps(payload),
+                         )
+                         if job_results_response.status_code == 200:
+                             break
+
+                     if job_results_response.status_code != 200:
+                         spinner.write(
+                             to_colored_text(
+                                 f"Job succeeded, but results are not yet available. Use `so.get_job_results('{job_id}')` to obtain results.",
+                                 state="fail",
+                             )
+                         )
+                         spinner.stop()
+                         return
+
+                     results = job_results_response.json()["results"]
+
+                     spinner.write(
+                         to_colored_text(
+                             f"✔ Job results received. You can re-obtain the results with `so.get_job_results('{job_id}')`",
+                             state="success",
+                         )
+                     )
+                     spinner.stop()
+
+                     if isinstance(data, (pd.DataFrame, pl.DataFrame)):
+                         sample_n = 1 if sampling_params is None else sampling_params.get("n", 1)
+                         if sample_n > 1:
+                             results = [
+                                 results[i : i + sample_n]
+                                 for i in range(0, len(results), sample_n)
+                             ]
+                         if isinstance(data, pd.DataFrame):
+                             data[output_column] = results
+                         elif isinstance(data, pl.DataFrame):
+                             data = data.with_columns(pl.Series(output_column, results))
+                         return data
+
+                     return results
+
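A usage sketch for infer (the DataFrame, column names, and schema are hypothetical; the exact shape of the dry-run estimate depends on the API):

    import polars as pl
    from pydantic import BaseModel
    from sutro.sdk import Sutro

    class Sentiment(BaseModel):
        label: str
        confidence: float

    so = Sutro()
    df = pl.DataFrame({"review": ["great product", "arrived broken"]})

    # Estimate cost first, then run for real; priority 0 jobs stay attached by default
    estimate = so.infer(df, column="review", output_schema=Sentiment, dry_run=True)
    df = so.infer(
        df,
        column="review",
        output_column="sentiment",
        output_schema=Sentiment,  # a pydantic model or a plain JSON-schema dict
        system_prompt="Classify the sentiment of each review.",
    )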
+     def register_stream_listener(self, job_id: str) -> str:
+         """Register a new stream listener and get a session token."""
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+         with requests.post(
+             f"{self.base_url}/register-stream-listener/{job_id}",
+             headers=headers,
+         ) as response:
+             response.raise_for_status()
+             data = response.json()
+             return data["request_session_token"]
+
+     # Best effort: it is acceptable if this occasionally fails to complete
+     def unregister_stream_listener(self, job_id: str, session_token: str):
+         """Explicitly unregister a stream listener."""
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+         with requests.post(
+             f"{self.base_url}/unregister-stream-listener/{job_id}",
+             headers=headers,
+             json={"request_session_token": session_token},
+         ) as response:
+             response.raise_for_status()
+
+     def start_heartbeat(
+         self,
+         job_id: str,
+         session_token: str,
+         session: requests.Session,
+         stop_event: threading.Event,
+     ):
+         """Send heartbeats until stopped."""
+         while not stop_event.is_set():
+             try:
+                 headers = {
+                     "Authorization": f"Key {self.api_key}",
+                     "Content-Type": "application/json",
+                 }
+                 response = session.post(
+                     f"{self.base_url}/stream-heartbeat/{job_id}",
+                     headers=headers,
+                     params={"request_session_token": session_token},
+                 )
+                 response.raise_for_status()
+             except Exception as e:
+                 if not stop_event.is_set():  # Only log if we weren't stopping anyway
+                     print(f"Heartbeat failed for job {job_id}: {e}")
+
+             # Wait on the stop event rather than sleeping so shutdown stays responsive
+             stop_event.wait(self.HEARTBEAT_INTERVAL_SECONDS)
+
+     @contextmanager
+     def stream_heartbeat_session(self, job_id: str, session_token: str) -> Generator[requests.Session, None, None]:
+         """Context manager that handles session registration and heartbeat."""
+         session = requests.Session()
+         stop_heartbeat = threading.Event()
+
+         # Run the heartbeat in a worker thread so it doesn't block the main
+         # SDK path while the stream is being consumed
+         with ThreadPoolExecutor(max_workers=1) as executor:
+             future = executor.submit(
+                 self.start_heartbeat,
+                 job_id,
+                 session_token,
+                 session,
+                 stop_heartbeat,
+             )
+
+             try:
+                 yield session
+             finally:
+                 # Signal stop and clean up
+                 stop_heartbeat.set()
+                 future.result()  # Wait for the heartbeat thread to finish
+                 self.unregister_stream_listener(job_id, session_token)
+                 session.close()
+
+     def attach(self, job_id):
+         """
+         Attach to an existing job and stream its progress.
+
+         Args:
+             job_id (str): The ID of the job to attach to
+         """
+         s = requests.Session()
+         pbar = None
+
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+
+         with yaspin(
+             SPINNER,
+             text=to_colored_text("Looking for job..."),
+             color=YASPIN_COLOR,
+         ) as spinner:
+             # Get job information from the list-jobs endpoint
+             # TODO(cooper): we should add a get-job endpoint:
+             # GET /jobs/{job_id}
+             jobs_response = s.get(
+                 f"{self.base_url}/list-jobs",
+                 headers=headers,
+             )
+             jobs_response.raise_for_status()
+
+             # Find the specific job we want to attach to
+             job = next(
+                 (job for job in jobs_response.json()["jobs"] if job["job_id"] == job_id),
+                 None,
+             )
+
+             if not job:
+                 spinner.write(to_colored_text(f"Job {job_id} not found", state="fail"))
+                 return
+
+             match job.get("status"):
+                 case "SUCCEEDED":
+                     spinner.write(
+                         to_colored_text(
+                             f"Job already completed. You can obtain the results with `sutro jobs results {job_id}`"
+                         )
+                     )
+                     return
+                 case "FAILED":
+                     spinner.write(to_colored_text("❌ Job is in a failed state.", state="fail"))
+                     return
+                 case "CANCELLED":
+                     spinner.write(to_colored_text("❌ Job was cancelled.", state="fail"))
+                     return
+                 case _:
+                     spinner.write(to_colored_text("✔ Job found!", state="success"))
+
+         total_rows = job["num_rows"]
+         success = False
+
+         session_token = self.register_stream_listener(job_id)
+
+         with self.stream_heartbeat_session(job_id, session_token) as s:
+             with s.get(
+                 f"{self.base_url}/stream-job-progress/{job_id}?request_session_token={session_token}",
+                 headers=headers,
+                 stream=True,
+             ) as streaming_response:
+                 streaming_response.raise_for_status()
+                 spinner = yaspin(
+                     SPINNER,
+                     text=to_colored_text("Awaiting status updates..."),
+                     color=YASPIN_COLOR,
+                 )
+                 spinner.start()
+                 for line in streaming_response.iter_lines():
+                     if line:
+                         try:
+                             json_obj = json.loads(line)
+                         except json.JSONDecodeError:
+                             print("Error: ", line, flush=True)
+                             continue
+
+                         if json_obj["update_type"] == "progress":
+                             if pbar is None:
+                                 spinner.stop()
+                                 postfix = "Input tokens processed: 0"
+                                 pbar = self.fancy_tqdm(
+                                     total=total_rows,
+                                     desc="Progress",
+                                     style=1,
+                                     postfix=postfix,
+                                 )
+                             if json_obj["result"] > pbar.n:
+                                 pbar.update(json_obj["result"] - pbar.n)
+                                 pbar.refresh()
+                             if json_obj["result"] == total_rows:
+                                 pbar.close()
+                                 success = True
+                         elif json_obj["update_type"] == "tokens":
+                             if pbar is not None:
+                                 pbar.postfix = f"Input tokens processed: {json_obj['result']['input_tokens']}, Tokens generated: {json_obj['result']['output_tokens']}, Total tokens/s: {json_obj['result'].get('total_tokens_processed_per_second')}"
+                                 pbar.refresh()
+
+         if success:
+             spinner.write(
+                 to_colored_text(
+                     f"✔ Job succeeded. Use `sutro jobs results {job_id}` to obtain results.",
+                     state="success",
+                 )
+             )
+             spinner.stop()
+
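Continuing the sketch above: a job submitted at a non-zero priority returns immediately with a job ID, and attach re-streams its progress later (the priority value is illustrative):

    job_id = so.infer(df, column="review", job_priority=1)
    so.attach(job_id)  # streams progress until the job finishes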
+     def fancy_tqdm(
+         self,
+         total: int,
+         desc: str = "Progress",
+         color: str = "blue",
+         style=1,
+         postfix: str = None,
+     ):
+         """
+         Creates a customized tqdm progress bar with different styling options.
+
+         Args:
+             total (int): Total iterations
+             desc (str): Description for the progress bar
+             color (str): Color of the progress bar (green, blue, red, yellow, magenta)
+             style (int): Style preset (1-5)
+             postfix (str): Postfix for the progress bar
+         """
+         # Style presets
+         style_presets = {
+             1: {
+                 "bar_format": "{l_bar}{bar:30}| {n_fmt}/{total_fmt} | {percentage:3.0f}% {postfix}",
+                 "ascii": "░▒█",
+             },
+             2: {
+                 "bar_format": "╢{l_bar}{bar:30}╟ {percentage:3.0f}%",
+                 "ascii": "▁▂▃▄▅▆▇█",
+             },
+             3: {
+                 "bar_format": "{desc}: |{bar}| {percentage:3.0f}% [{elapsed}<{remaining}]",
+                 "ascii": "◯◔◑◕●",
+             },
+             4: {
+                 "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
+                 "ascii": "⬜⬛",
+             },
+             5: {
+                 "bar_format": "⏳ {desc} {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt}",
+                 "ascii": "▏▎▍▌▋▊▉█",
+             },
+         }
+
+         # Get style configuration
+         style_config = style_presets.get(style, style_presets[1])
+
+         return tqdm(
+             total=total,
+             desc=desc,
+             colour=color,
+             bar_format=style_config["bar_format"],
+             ascii=style_config["ascii"],
+             ncols=80,
+             dynamic_ncols=True,
+             smoothing=0.3,
+             leave=True,
+             postfix=postfix,
+         )
+
+     def list_jobs(self):
+         """
+         List all jobs.
+
+         This method retrieves a list of all jobs associated with the API key.
+
+         Returns:
+             list: A list of job details.
+         """
+         endpoint = f"{self.base_url}/list-jobs"
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+
+         with yaspin(
+             SPINNER, text=to_colored_text("Fetching jobs"), color=YASPIN_COLOR
+         ) as spinner:
+             response = requests.get(endpoint, headers=headers)
+             if response.status_code != 200:
+                 spinner.write(
+                     to_colored_text(
+                         f"Bad status code: {response.status_code}", state="fail"
+                     )
+                 )
+                 spinner.stop()
+                 print(to_colored_text(response.json(), state="fail"))
+                 return
+             return response.json()["jobs"]
+
+     def get_job_status(self, job_id: str):
+         """
+         Get the status of a job by its ID.
+
+         This method retrieves the status of a job using its unique identifier.
+
+         Args:
+             job_id (str): The ID of the job to retrieve the status for.
+
+         Returns:
+             str: The status of the job.
+         """
+         endpoint = f"{self.base_url}/job-status/{job_id}"
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+         with yaspin(
+             SPINNER,
+             text=to_colored_text(f"Checking job status with ID: {job_id}"),
+             color=YASPIN_COLOR,
+         ) as spinner:
+             response = requests.get(endpoint, headers=headers)
+             if response.status_code != 200:
+                 spinner.write(
+                     to_colored_text(
+                         f"Bad status code: {response.status_code}", state="fail"
+                     )
+                 )
+                 spinner.stop()
+                 print(to_colored_text(response.json(), state="fail"))
+                 return
+             spinner.write(to_colored_text("✔ Job status retrieved!", state="success"))
+             return response.json()["job_status"][job_id]
+
+     def get_job_results(
+         self,
+         job_id: str,
+         include_inputs: bool = False,
+         include_cumulative_logprobs: bool = False,
+     ):
+         """
+         Get the results of a job by its ID.
+
+         This method retrieves the results of a job using its unique identifier.
+
+         Args:
+             job_id (str): The ID of the job to retrieve the results for.
+             include_inputs (bool, optional): Whether to include the inputs in the results. Defaults to False.
+             include_cumulative_logprobs (bool, optional): Whether to include the cumulative logprobs in the results. Defaults to False.
+
+         Returns:
+             list: The results of the job.
+         """
+         endpoint = f"{self.base_url}/job-results"
+         payload = {
+             "job_id": job_id,
+             "include_inputs": include_inputs,
+             "include_cumulative_logprobs": include_cumulative_logprobs,
+         }
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+         with yaspin(
+             SPINNER,
+             text=to_colored_text(f"Gathering results from job: {job_id}"),
+             color=YASPIN_COLOR,
+         ) as spinner:
+             response = requests.post(
+                 endpoint, data=json.dumps(payload), headers=headers
+             )
+             if response.status_code == 200:
+                 spinner.write(
+                     to_colored_text("✔ Job results retrieved", state="success")
+                 )
+             else:
+                 spinner.write(
+                     to_colored_text(
+                         f"Bad status code: {response.status_code}", state="fail"
+                     )
+                 )
+                 spinner.stop()
+                 print(to_colored_text(response.json(), state="fail"))
+                 return
+             return response.json()["results"]
+
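A detached polling sketch built on these two methods (the status strings mirror the match arms in attach above; the sleep interval is arbitrary):

    import time

    job_id = so.infer(df, column="review", job_priority=1)
    while so.get_job_status(job_id) not in ("SUCCEEDED", "FAILED", "CANCELLED"):
        time.sleep(30)
    results = so.get_job_results(job_id, include_inputs=True)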
+     def cancel_job(self, job_id: str):
+         """
+         Cancel a job by its ID.
+
+         This method allows you to cancel a job using its unique identifier.
+
+         Args:
+             job_id (str): The ID of the job to cancel.
+
+         Returns:
+             dict: The status of the job.
+         """
+         endpoint = f"{self.base_url}/job-cancel/{job_id}"
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+         with yaspin(
+             SPINNER,
+             text=to_colored_text(f"Cancelling job: {job_id}"),
+             color=YASPIN_COLOR,
+         ) as spinner:
+             response = requests.get(endpoint, headers=headers)
+             if response.status_code == 200:
+                 spinner.write(to_colored_text("✔ Job cancelled", state="success"))
+             else:
+                 spinner.write(to_colored_text("Failed to cancel job", state="fail"))
+                 spinner.stop()
+                 print(to_colored_text(response.json(), state="fail"))
+                 return
+             return response.json()
+
+     def create_stage(self):
+         """
+         Create a new stage.
+
+         This method creates a new stage and returns its ID.
+
+         Returns:
+             str: The ID of the new stage.
+         """
+         endpoint = f"{self.base_url}/create-stage"
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+         with yaspin(
+             SPINNER, text=to_colored_text("Creating stage"), color=YASPIN_COLOR
+         ) as spinner:
+             response = requests.get(endpoint, headers=headers)
+             if response.status_code != 200:
+                 spinner.write(
+                     to_colored_text(
+                         f"Bad status code: {response.status_code}", state="fail"
+                     )
+                 )
+                 spinner.stop()
+                 print(to_colored_text(response.json(), state="fail"))
+                 return
+             stage_id = response.json()["stage_id"]
+             spinner.write(
+                 to_colored_text(f"✔ Stage created with ID: {stage_id}", state="success")
+             )
+             return stage_id
+
+     def upload_to_stage(
+         self,
+         stage_id: Union[List[str], str] = None,
+         file_paths: Union[List[str], str] = None,
+         verify_ssl: bool = True,
+     ):
+         """
+         Upload data to a stage.
+
+         This method uploads files to a stage. It accepts a stage ID and file paths.
+         If only a single argument is provided, it is interpreted as the file paths.
+
+         Args:
+             stage_id (str): The ID of the stage to upload to. If not provided, a new stage will be created.
+             file_paths (Union[List[str], str]): A list of paths to the files to upload, or a single path to a collection of files.
+             verify_ssl (bool): Whether to verify SSL certificates. Set to False to bypass SSL verification for troubleshooting.
+
+         Returns:
+             str: The ID of the stage the files were uploaded to.
+         """
+         # When only a single argument is provided, it is interpreted as the file paths
+         if file_paths is None and stage_id is not None:
+             file_paths = stage_id
+             stage_id = None
+
+         if file_paths is None:
+             raise ValueError("File paths must be provided")
+
+         if stage_id is None:
+             stage_id = self.create_stage()
+
+         endpoint = f"{self.base_url}/upload-to-stage"
+
+         if isinstance(file_paths, str):
+             # Check if the file path is a directory
+             if os.path.isdir(file_paths):
+                 file_paths = [
+                     os.path.join(file_paths, f) for f in os.listdir(file_paths)
+                 ]
+                 if len(file_paths) == 0:
+                     raise ValueError("No files found in the directory")
+             else:
+                 file_paths = [file_paths]
+
+         with yaspin(
+             SPINNER,
+             text=to_colored_text(f"Uploading files to stage: {stage_id}"),
+             color=YASPIN_COLOR,
+         ) as spinner:
+             count = 0
+             for file_path in file_paths:
+                 file_name = os.path.basename(file_path)
+
+                 files = {
+                     "file": (
+                         file_name,
+                         open(file_path, "rb"),
+                         "application/octet-stream",
+                     )
+                 }
+
+                 payload = {
+                     "stage_id": stage_id,
+                 }
+
+                 headers = {"Authorization": f"Key {self.api_key}"}
+
+                 count += 1
+                 spinner.write(
+                     to_colored_text(
+                         f"Uploading file {count}/{len(file_paths)} to stage: {stage_id}"
+                     )
+                 )
+
+                 try:
+                     response = requests.post(
+                         endpoint,
+                         headers=headers,
+                         data=payload,
+                         files=files,
+                         verify=verify_ssl,
+                     )
+                     if response.status_code != 200:
+                         # Stop spinner before showing error to avoid terminal width error
+                         spinner.stop()
+                         print(
+                             to_colored_text(
+                                 f"Error: HTTP {response.status_code}", state="fail"
+                             )
+                         )
+                         print(to_colored_text(response.json(), state="fail"))
+                         return
+
+                 except requests.exceptions.RequestException as e:
+                     # Stop spinner before showing error to avoid terminal width error
+                     spinner.stop()
+                     print(to_colored_text(f"Upload failed: {str(e)}", state="fail"))
+                     return
+
+             spinner.write(
+                 to_colored_text(
+                     f"✔ {count} files successfully uploaded to stage", state="success"
+                 )
+             )
+             return stage_id
+
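Both call forms in one sketch (the stage ID and paths are hypothetical):

    # Explicit stage ID plus a list of files
    stage_id = so.upload_to_stage("stage-abc123", ["a.parquet", "b.parquet"])

    # Single argument: treated as the file path(s); a new stage is created for you
    stage_id = so.upload_to_stage("./input_files/")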
+     def list_stages(self):
+         endpoint = f"{self.base_url}/list-stages"
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+         with yaspin(
+             SPINNER, text=to_colored_text("Retrieving stages"), color=YASPIN_COLOR
+         ) as spinner:
+             response = requests.post(endpoint, headers=headers)
+             if response.status_code != 200:
+                 spinner.fail(
+                     to_colored_text(
+                         f"Bad status code: {response.status_code}", state="fail"
+                     )
+                 )
+                 print(to_colored_text(f"Error: {response.json()}", state="fail"))
+                 return
+             spinner.write(to_colored_text("✔ Stages retrieved", state="success"))
+             return response.json()["stages"]
+
+     def list_stage_files(self, stage_id: str):
+         endpoint = f"{self.base_url}/list-stage-files"
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+         payload = {
+             "stage_id": stage_id,
+         }
+         with yaspin(
+             SPINNER,
+             text=to_colored_text(f"Listing files in stage: {stage_id}"),
+             color=YASPIN_COLOR,
+         ) as spinner:
+             response = requests.post(
+                 endpoint, headers=headers, data=json.dumps(payload)
+             )
+             if response.status_code != 200:
+                 spinner.fail(
+                     to_colored_text(
+                         f"Bad status code: {response.status_code}", state="fail"
+                     )
+                 )
+                 print(to_colored_text(f"Error: {response.json()}", state="fail"))
+                 return
+             spinner.write(
+                 to_colored_text(f"✔ Files listed in stage: {stage_id}", state="success")
+             )
+             return response.json()["files"]
+
+     def download_from_stage(
+         self,
+         stage_id: str,
+         files: Union[List[str], str] = None,
+         output_path: str = None,
+     ):
+         endpoint = f"{self.base_url}/download-from-stage"
+
+         if files is None:
+             files = self.list_stage_files(stage_id)
+         elif isinstance(files, str):
+             files = [files]
+
+         if not files:
+             print(
+                 to_colored_text(
+                     f"Couldn't find files for stage ID: {stage_id}", state="fail"
+                 )
+             )
+             return
+
+         # If no output path is provided, save the files to the current working directory
+         if output_path is None:
+             output_path = os.getcwd()
+
+         with yaspin(
+             SPINNER,
+             text=to_colored_text(f"Downloading files from stage: {stage_id}"),
+             color=YASPIN_COLOR,
+         ) as spinner:
+             count = 0
+             for file in files:
+                 headers = {
+                     "Authorization": f"Key {self.api_key}",
+                     "Content-Type": "application/json",
+                 }
+                 payload = {
+                     "stage_id": stage_id,
+                     "file_name": file,
+                 }
+                 spinner.text = to_colored_text(
+                     f"Downloading file {count + 1}/{len(files)} from stage: {stage_id}"
+                 )
+                 response = requests.post(
+                     endpoint, headers=headers, data=json.dumps(payload)
+                 )
+                 if response.status_code != 200:
+                     spinner.fail(
+                         to_colored_text(
+                             f"Bad status code: {response.status_code}", state="fail"
+                         )
+                     )
+                     print(to_colored_text(f"Error: {response.json()}", state="fail"))
+                     return
+                 file_content = response.content
+                 with open(os.path.join(output_path, file), "wb") as f:
+                     f.write(file_content)
+                 count += 1
+             spinner.write(
+                 to_colored_text(
+                     f"✔ {count} files successfully downloaded from stage: {stage_id}",
+                     state="success",
+                 )
+             )
+
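A full stage round trip, assuming the client from the earlier sketches (the file and directory names are hypothetical):

    stage_id = so.create_stage()
    so.upload_to_stage(stage_id, "inputs.parquet")
    print(so.list_stage_files(stage_id))
    so.download_from_stage(stage_id, output_path="./downloads")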
+     def try_authentication(self, api_key: str):
+         """
+         Check that an API key authenticates successfully against the API.
+
+         Args:
+             api_key (str): The API key to authenticate with.
+
+         Returns:
+             dict: The status of the authentication.
+         """
+         endpoint = f"{self.base_url}/try-authentication"
+         headers = {
+             "Authorization": f"Key {api_key}",
+             "Content-Type": "application/json",
+         }
+         with yaspin(
+             SPINNER, text=to_colored_text("Checking API key"), color=YASPIN_COLOR
+         ) as spinner:
+             response = requests.get(endpoint, headers=headers)
+             if response.status_code == 200:
+                 spinner.write(to_colored_text("✔"))
+             else:
+                 spinner.write(
+                     to_colored_text(
+                         f"API key failed to authenticate: {response.status_code}",
+                         state="fail",
+                     )
+                 )
+                 return
+             return response.json()
+
+     def get_quotas(self):
+         endpoint = f"{self.base_url}/get-quotas"
+         headers = {
+             "Authorization": f"Key {self.api_key}",
+             "Content-Type": "application/json",
+         }
+         with yaspin(
+             SPINNER, text=to_colored_text("Fetching quotas"), color=YASPIN_COLOR
+         ) as spinner:
+             response = requests.get(endpoint, headers=headers)
+             if response.status_code != 200:
+                 spinner.fail(
+                     to_colored_text(
+                         f"Bad status code: {response.status_code}", state="fail"
+                     )
+                 )
+                 print(to_colored_text(f"Error: {response.json()}", state="fail"))
+                 return
+             return response.json()["quotas"]