sutro 0.1.20__tar.gz → 0.1.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sutro might be problematic. Click here for more details.
- {sutro-0.1.20 → sutro-0.1.22}/PKG-INFO +1 -1
- {sutro-0.1.20 → sutro-0.1.22}/pyproject.toml +1 -1
- {sutro-0.1.20 → sutro-0.1.22}/sutro/sdk.py +96 -54
- {sutro-0.1.20 → sutro-0.1.22}/.gitignore +0 -0
- {sutro-0.1.20 → sutro-0.1.22}/LICENSE +0 -0
- {sutro-0.1.20 → sutro-0.1.22}/README.md +0 -0
- {sutro-0.1.20 → sutro-0.1.22}/sutro/__init__.py +0 -0
- {sutro-0.1.20 → sutro-0.1.22}/sutro/cli.py +0 -0
|
@@ -205,60 +205,22 @@ class Sutro:
|
|
|
205
205
|
"""
|
|
206
206
|
self.base_url = base_url
|
|
207
207
|
|
|
208
|
-
def
|
|
208
|
+
def _run_one_batch_inference(
|
|
209
209
|
self,
|
|
210
210
|
data: Union[List, pd.DataFrame, pl.DataFrame, str],
|
|
211
|
-
model: ModelOptions
|
|
212
|
-
column: str
|
|
213
|
-
output_column: str
|
|
214
|
-
job_priority: int
|
|
215
|
-
|
|
216
|
-
sampling_params: dict
|
|
217
|
-
system_prompt: str
|
|
218
|
-
dry_run: bool
|
|
219
|
-
stay_attached: Optional[bool]
|
|
220
|
-
random_seed_per_input: bool
|
|
221
|
-
truncate_rows: bool
|
|
211
|
+
model: ModelOptions,
|
|
212
|
+
column: str,
|
|
213
|
+
output_column: str,
|
|
214
|
+
job_priority: int,
|
|
215
|
+
json_schema: Dict[str, Any],
|
|
216
|
+
sampling_params: dict,
|
|
217
|
+
system_prompt: str,
|
|
218
|
+
dry_run: bool,
|
|
219
|
+
stay_attached: Optional[bool],
|
|
220
|
+
random_seed_per_input: bool,
|
|
221
|
+
truncate_rows: bool
|
|
222
222
|
):
|
|
223
|
-
"""
|
|
224
|
-
Run inference on the provided data.
|
|
225
|
-
|
|
226
|
-
This method allows you to run inference on the provided data using the Sutro API.
|
|
227
|
-
It supports various data types such as lists, pandas DataFrames, polars DataFrames, file paths and datasets.
|
|
228
|
-
|
|
229
|
-
Args:
|
|
230
|
-
data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
|
|
231
|
-
model (ModelOptions, optional): The model to use for inference. Defaults to "llama-3.1-8b".
|
|
232
|
-
column (str, optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset.
|
|
233
|
-
output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
|
|
234
|
-
job_priority (int, optional): The priority of the job. Defaults to 0.
|
|
235
|
-
output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
|
|
236
|
-
Can be either a dictionary representing a JSON schema or a pydantic BaseModel. Defaults to None.
|
|
237
|
-
sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
|
|
238
|
-
system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
|
|
239
|
-
dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
|
|
240
|
-
stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
|
|
241
|
-
random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
|
|
242
|
-
truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to False.
|
|
243
|
-
|
|
244
|
-
Returns:
|
|
245
|
-
Union[List, pd.DataFrame, pl.DataFrame, str]: The results of the inference.
|
|
246
|
-
|
|
247
|
-
"""
|
|
248
223
|
input_data = self.handle_data_helper(data, column)
|
|
249
|
-
stay_attached = stay_attached if stay_attached is not None else job_priority == 0
|
|
250
|
-
|
|
251
|
-
# Convert BaseModel to dict if needed
|
|
252
|
-
if output_schema is not None:
|
|
253
|
-
if hasattr(output_schema, 'model_json_schema'): # Check for pydantic Model interface
|
|
254
|
-
json_schema = output_schema.model_json_schema()
|
|
255
|
-
elif isinstance(output_schema, dict):
|
|
256
|
-
json_schema = output_schema
|
|
257
|
-
else:
|
|
258
|
-
raise ValueError("Invalid output schema type. Must be a dictionary or a pydantic Model.")
|
|
259
|
-
else:
|
|
260
|
-
json_schema = None
|
|
261
|
-
|
|
262
224
|
endpoint = f"{self.base_url}/batch-inference"
|
|
263
225
|
headers = {
|
|
264
226
|
"Authorization": f"Key {self.api_key}",
|
|
@@ -319,9 +281,10 @@ class Sutro:
|
|
|
319
281
|
)
|
|
320
282
|
)
|
|
321
283
|
if not stay_attached:
|
|
284
|
+
clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
|
|
322
285
|
spinner.write(
|
|
323
286
|
to_colored_text(
|
|
324
|
-
f"Use `so.get_job_status('{job_id}')` to check the status of the job
|
|
287
|
+
f"Use `so.get_job_status('{job_id}')` to check the status of the job, or monitor progress at {clickable_link}"
|
|
325
288
|
)
|
|
326
289
|
)
|
|
327
290
|
return job_id
|
|
@@ -334,7 +297,8 @@ class Sutro:
|
|
|
334
297
|
success = False
|
|
335
298
|
if stay_attached and job_id is not None:
|
|
336
299
|
spinner.write(to_colored_text("Awaiting job start...", ))
|
|
337
|
-
|
|
300
|
+
clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
|
|
301
|
+
spinner.write(to_colored_text(f'Progress can also be monitored at: {clickable_link}'))
|
|
338
302
|
started = self._await_job_start(job_id)
|
|
339
303
|
if not started:
|
|
340
304
|
failure_reason = self._get_failure_reason(job_id)
|
|
@@ -473,6 +437,82 @@ class Sutro:
|
|
|
473
437
|
return None
|
|
474
438
|
return None
|
|
475
439
|
|
|
440
|
+
def infer(
    self,
    data: Union[List, pd.DataFrame, pl.DataFrame, str],
    model: Union[ModelOptions, List[ModelOptions]] = "llama-3.1-8b",
    column: str = None,
    output_column: str = "inference_result",
    job_priority: int = 0,
    output_schema: Union[Dict[str, Any], BaseModel] = None,
    sampling_params: dict = None,
    system_prompt: str = None,
    dry_run: bool = False,
    stay_attached: Optional[bool] = None,
    random_seed_per_input: bool = False,
    truncate_rows: bool = False
):
    """
    Run inference on the provided data.

    Normalizes the arguments and delegates to ``_run_one_batch_inference``
    once per requested model.

    Args:
        data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
        model (Union[ModelOptions, List[ModelOptions]], optional): The model(s) to use for inference. Defaults to "llama-3.1-8b". You can pass a single model or a list of models. When a list is passed, one job is submitted per model and stay_attached is forced to False.
        column (str, optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset.
        output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
        job_priority (int, optional): The priority of the job. Defaults to 0.
        output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
            Can be either a dictionary representing a JSON schema or a pydantic BaseModel. Defaults to None.
        sampling_params (dict, optional): The sampling parameters to use at generation time, i.e. temperature, top_p etc.
        system_prompt (str, optional): A system prompt to add to all inputs. Defaults to None.
        dry_run (bool, optional): Forwarded to the batch runner; intended to return cost estimates instead of running inference. Defaults to False.
        stay_attached (bool, optional): If True, stay attached to the job until it is complete. When left as None it defaults to True for job_priority 0 and False otherwise. Ignored (treated as False) when a list of models is passed.
        random_seed_per_input (bool, optional): Forwarded to the batch runner. Defaults to False.
        truncate_rows (bool, optional): Forwarded to the batch runner. Defaults to False.

    Returns:
        For a single model: whatever ``_run_one_batch_inference`` returns
        (the job id when detached). For a list of models: a list with one
        such result per model, in the order the models were given.

    Raises:
        ValueError: If output_schema is neither a dict nor an object
            exposing ``model_json_schema``.
    """
    multiple_models = isinstance(model, list)
    if multiple_models:
        model_list = model
        # Several jobs cannot all be followed interactively, so detach.
        stay_attached = False
    else:
        model_list = [model]
        if stay_attached is None:
            stay_attached = job_priority == 0

    # Convert BaseModel to dict if needed
    if output_schema is None:
        json_schema = None
    elif hasattr(output_schema, 'model_json_schema'):  # Check for pydantic Model interface
        json_schema = output_schema.model_json_schema()
    elif isinstance(output_schema, dict):
        json_schema = output_schema
    else:
        raise ValueError("Invalid output schema type. Must be a dictionary or a pydantic Model.")

    # Keep every result: previously detached runs fell through and returned
    # None, discarding the value the caller needs to look the job up later.
    results = [
        self._run_one_batch_inference(
            data,
            one_model,
            column,
            output_column,
            job_priority,
            json_schema,
            sampling_params,
            system_prompt,
            dry_run,
            stay_attached,
            random_seed_per_input,
            truncate_rows
        )
        for one_model in model_list
    ]

    if multiple_models:
        return results
    return results[0]
|
|
515
|
+
|
|
476
516
|
|
|
477
517
|
def attach(self, job_id):
|
|
478
518
|
"""
|
|
@@ -549,7 +589,8 @@ class Sutro:
|
|
|
549
589
|
text=to_colored_text("Awaiting status updates..."),
|
|
550
590
|
color=YASPIN_COLOR,
|
|
551
591
|
)
|
|
552
|
-
|
|
592
|
+
clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
|
|
593
|
+
spinner.write(to_colored_text(f'Progress can also be monitored at: {clickable_link}'))
|
|
553
594
|
spinner.start()
|
|
554
595
|
for line in streaming_response.iter_lines():
|
|
555
596
|
if line:
|
|
@@ -1233,7 +1274,8 @@ class Sutro:
|
|
|
1233
1274
|
with yaspin(
|
|
1234
1275
|
SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
|
|
1235
1276
|
) as spinner:
|
|
1236
|
-
|
|
1277
|
+
clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
|
|
1278
|
+
spinner.write(to_colored_text(f'Progress can also be monitored at: {clickable_link}'))
|
|
1237
1279
|
while (time.time() - start_time) < timeout:
|
|
1238
1280
|
try:
|
|
1239
1281
|
status = self._fetch_job_status(job_id)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|