sutro 0.1.21__tar.gz → 0.1.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sutro might be problematic.
- {sutro-0.1.21 → sutro-0.1.22}/PKG-INFO +1 -1
- {sutro-0.1.21 → sutro-0.1.22}/pyproject.toml +1 -1
- {sutro-0.1.21 → sutro-0.1.22}/sutro/sdk.py +90 -51
- {sutro-0.1.21 → sutro-0.1.22}/.gitignore +0 -0
- {sutro-0.1.21 → sutro-0.1.22}/LICENSE +0 -0
- {sutro-0.1.21 → sutro-0.1.22}/README.md +0 -0
- {sutro-0.1.21 → sutro-0.1.22}/sutro/__init__.py +0 -0
- {sutro-0.1.21 → sutro-0.1.22}/sutro/cli.py +0 -0
@@ -205,60 +205,22 @@ class Sutro:
         """
         self.base_url = base_url

-    def infer(
+    def _run_one_batch_inference(
         self,
         data: Union[List, pd.DataFrame, pl.DataFrame, str],
-        model: ModelOptions = "llama-3.1-8b",
-        column: str = None,
-        output_column: str = "inference_result",
-        job_priority: int = 0,
-        output_schema: Union[Dict[str, Any], BaseModel] = None,
-        sampling_params: dict = None,
-        system_prompt: str = None,
-        dry_run: bool = False,
-        stay_attached: Optional[bool] = None,
-        random_seed_per_input: bool = False,
-        truncate_rows: bool = False
+        model: ModelOptions,
+        column: str,
+        output_column: str,
+        job_priority: int,
+        json_schema: Dict[str, Any],
+        sampling_params: dict,
+        system_prompt: str,
+        dry_run: bool,
+        stay_attached: Optional[bool],
+        random_seed_per_input: bool,
+        truncate_rows: bool
     ):
-        """
-        Run inference on the provided data.
-
-        This method allows you to run inference on the provided data using the Sutro API.
-        It supports various data types such as lists, pandas DataFrames, polars DataFrames, file paths and datasets.
-
-        Args:
-            data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
-            model (ModelOptions, optional): The model to use for inference. Defaults to "llama-3.1-8b".
-            column (str, optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset.
-            output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
-            job_priority (int, optional): The priority of the job. Defaults to 0.
-            output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
-                Can be either a dictionary representing a JSON schema or a pydantic BaseModel. Defaults to None.
-            sampling_params (dict, optional): The sampling parameters to use at generation time, i.e. temperature, top_p, etc.
-            system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
-            dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
-            stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
-            random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
-            truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to False.
-
-        Returns:
-            Union[List, pd.DataFrame, pl.DataFrame, str]: The results of the inference.
-
-        """
         input_data = self.handle_data_helper(data, column)
-        stay_attached = stay_attached if stay_attached is not None else job_priority == 0
-
-        # Convert BaseModel to dict if needed
-        if output_schema is not None:
-            if hasattr(output_schema, 'model_json_schema'):  # Check for pydantic Model interface
-                json_schema = output_schema.model_json_schema()
-            elif isinstance(output_schema, dict):
-                json_schema = output_schema
-            else:
-                raise ValueError("Invalid output schema type. Must be a dictionary or a pydantic Model.")
-        else:
-            json_schema = None
-
         endpoint = f"{self.base_url}/batch-inference"
         headers = {
             "Authorization": f"Key {self.api_key}",
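The net effect of this hunk: the old public `infer` body becomes the private `_run_one_batch_inference`, which drops the defaults and docstring and takes an already-converted `json_schema` dict in place of the user-facing `output_schema`. For context, the pydantic branch of the conversion that moved out of this method works as in the following minimal sketch (the `SentimentResult` model is a hypothetical stand-in, not part of the sutro package):

    from pydantic import BaseModel

    class SentimentResult(BaseModel):
        label: str
        confidence: float

    # Pydantic v2 models expose their JSON schema via model_json_schema(),
    # the duck-typed attribute the SDK checks with hasattr() above.
    json_schema = SentimentResult.model_json_schema()
    print(json_schema["properties"])  # {'label': {...}, 'confidence': {...}}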
@@ -319,9 +281,10 @@ class Sutro:
             )
         )
         if not stay_attached:
+            clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
             spinner.write(
                 to_colored_text(
-                    f"Use `so.get_job_status('{job_id}')` to check the status of the job"
+                    f"Use `so.get_job_status('{job_id}')` to check the status of the job, or monitor progress at {clickable_link}"
                 )
             )
             return job_id
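The only functional change here is that detached jobs now also print a clickable dashboard URL. The diff does not show `make_clickable_link` itself; a common way to implement such a helper is the OSC 8 terminal hyperlink escape sequence, sketched below as an assumption about the helper, not code from this release:

    def make_clickable_link(url, text=None):
        # Wrap a URL in OSC 8 escape sequences so supporting terminals
        # (iTerm2, recent GNOME Terminal, Windows Terminal, etc.)
        # render the text as a clickable hyperlink.
        text = text or url
        return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"

    print(make_clickable_link("https://app.sutro.sh/jobs/example-job-id"))

Terminals without OSC 8 support generally ignore the escapes and show the plain text, so the fallback is harmless.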
@@ -474,6 +437,82 @@ class Sutro:
             return None
         return None

+    def infer(
+        self,
+        data: Union[List, pd.DataFrame, pl.DataFrame, str],
+        model: Union[ModelOptions, List[ModelOptions]] = "llama-3.1-8b",
+        column: str = None,
+        output_column: str = "inference_result",
+        job_priority: int = 0,
+        output_schema: Union[Dict[str, Any], BaseModel] = None,
+        sampling_params: dict = None,
+        system_prompt: str = None,
+        dry_run: bool = False,
+        stay_attached: Optional[bool] = None,
+        random_seed_per_input: bool = False,
+        truncate_rows: bool = False
+    ):
+        """
+        Run inference on the provided data.
+
+        This method allows you to run inference on the provided data using the Sutro API.
+        It supports various data types such as lists, pandas DataFrames, polars DataFrames, file paths and datasets.
+
+        Args:
+            data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
+            model (Union[ModelOptions, List[ModelOptions]], optional): The model(s) to use for inference. Defaults to "llama-3.1-8b". You can pass a single model or a list of models. In the case of a list, the inference will be run in parallel for each model and stay_attached will be set to False.
+            column (str, optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset.
+            output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
+            job_priority (int, optional): The priority of the job. Defaults to 0.
+            output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
+                Can be either a dictionary representing a JSON schema or a pydantic BaseModel. Defaults to None.
+            sampling_params (dict, optional): The sampling parameters to use at generation time, i.e. temperature, top_p, etc.
+            system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
+            dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
+            stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
+            random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
+            truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to False.
+
+        Returns:
+            Union[List, pd.DataFrame, pl.DataFrame, str]: The results of the inference.
+
+        """
+        if isinstance(model, list) == False:
+            model_list = [model]
+            stay_attached = stay_attached if stay_attached is not None else job_priority == 0
+        else:
+            model_list = model
+            stay_attached = False
+
+        # Convert BaseModel to dict if needed
+        if output_schema is not None:
+            if hasattr(output_schema, 'model_json_schema'):  # Check for pydantic Model interface
+                json_schema = output_schema.model_json_schema()
+            elif isinstance(output_schema, dict):
+                json_schema = output_schema
+            else:
+                raise ValueError("Invalid output schema type. Must be a dictionary or a pydantic Model.")
+        else:
+            json_schema = None
+
+        for model in model_list:
+            res = self._run_one_batch_inference(
+                data,
+                model,
+                column,
+                output_column,
+                job_priority,
+                json_schema,
+                sampling_params,
+                system_prompt,
+                dry_run,
+                stay_attached,
+                random_seed_per_input,
+                truncate_rows
+            )
+            if stay_attached:
+                return res
+

     def attach(self, job_id):
         """
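Taken together, the release turns `infer` into a thin fan-out wrapper: schema conversion and the single-vs-list `stay_attached` decision happen once, then `_run_one_batch_inference` is called per model. A hedged usage sketch follows (the `Sutro(api_key=...)` constructor and the second model name are assumptions based on identifiers visible in this diff):

    import sutro

    so = sutro.Sutro(api_key="YOUR_API_KEY")
    reviews = ["great product", "terrible support", "works as advertised"]

    # Single model: unchanged behavior; attaches for priority-0 jobs
    # and returns the inference results.
    results = so.infer(reviews, model="llama-3.1-8b")

    # List of models: stay_attached is forced to False and one detached
    # job is submitted per model.
    so.infer(reviews, model=["llama-3.1-8b", "llama-3.1-70b"])

Note that in the detached path `infer` discards the job ID returned by `_run_one_batch_inference` (it only returns `res` when `stay_attached` is true), so the printed `so.get_job_status(...)` hint and the dashboard link are the caller's only handles on the submitted jobs.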
The remaining files (.gitignore, LICENSE, README.md, sutro/__init__.py, sutro/cli.py) are unchanged between versions.