sutro 0.1.20__tar.gz → 0.1.22__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This is a potentially problematic release.


This version of sutro might be problematic; review the advisory details below before upgrading.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sutro
3
- Version: 0.1.20
3
+ Version: 0.1.22
4
4
  Summary: Sutro Python SDK
5
5
  Project-URL: Homepage, https://sutro.sh
6
6
  Project-URL: Documentation, https://docs.sutro.sh
@@ -9,7 +9,7 @@ installer = "uv"
9
9
 
10
10
  [project]
11
11
  name = "sutro"
12
- version = "0.1.20"
12
+ version = "0.1.22"
13
13
  description = "Sutro Python SDK"
14
14
  readme = "README.md"
15
15
  requires-python = ">=3.10"
@@ -205,60 +205,22 @@ class Sutro:
205
205
  """
206
206
  self.base_url = base_url
207
207
 
208
- def infer(
208
+ def _run_one_batch_inference(
209
209
  self,
210
210
  data: Union[List, pd.DataFrame, pl.DataFrame, str],
211
- model: ModelOptions = "llama-3.1-8b",
212
- column: str = None,
213
- output_column: str = "inference_result",
214
- job_priority: int = 0,
215
- output_schema: Union[Dict[str, Any], BaseModel] = None,
216
- sampling_params: dict = None,
217
- system_prompt: str = None,
218
- dry_run: bool = False,
219
- stay_attached: Optional[bool] = None,
220
- random_seed_per_input: bool = False,
221
- truncate_rows: bool = False
211
+ model: ModelOptions,
212
+ column: str,
213
+ output_column: str,
214
+ job_priority: int,
215
+ json_schema: Dict[str, Any],
216
+ sampling_params: dict,
217
+ system_prompt: str,
218
+ dry_run: bool,
219
+ stay_attached: Optional[bool],
220
+ random_seed_per_input: bool,
221
+ truncate_rows: bool
222
222
  ):
223
- """
224
- Run inference on the provided data.
225
-
226
- This method allows you to run inference on the provided data using the Sutro API.
227
- It supports various data types such as lists, pandas DataFrames, polars DataFrames, file paths and datasets.
228
-
229
- Args:
230
- data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
231
- model (ModelOptions, optional): The model to use for inference. Defaults to "llama-3.1-8b".
232
- column (str, optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset.
233
- output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
234
- job_priority (int, optional): The priority of the job. Defaults to 0.
235
- output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
236
- Can be either a dictionary representing a JSON schema or a pydantic BaseModel. Defaults to None.
237
- sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
238
- system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
239
- dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
240
- stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
241
- random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
242
- truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to False.
243
-
244
- Returns:
245
- Union[List, pd.DataFrame, pl.DataFrame, str]: The results of the inference.
246
-
247
- """
248
223
  input_data = self.handle_data_helper(data, column)
249
- stay_attached = stay_attached if stay_attached is not None else job_priority == 0
250
-
251
- # Convert BaseModel to dict if needed
252
- if output_schema is not None:
253
- if hasattr(output_schema, 'model_json_schema'): # Check for pydantic Model interface
254
- json_schema = output_schema.model_json_schema()
255
- elif isinstance(output_schema, dict):
256
- json_schema = output_schema
257
- else:
258
- raise ValueError("Invalid output schema type. Must be a dictionary or a pydantic Model.")
259
- else:
260
- json_schema = None
261
-
262
224
  endpoint = f"{self.base_url}/batch-inference"
263
225
  headers = {
264
226
  "Authorization": f"Key {self.api_key}",
@@ -319,9 +281,10 @@ class Sutro:
319
281
  )
320
282
  )
321
283
  if not stay_attached:
284
+ clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
322
285
  spinner.write(
323
286
  to_colored_text(
324
- f"Use `so.get_job_status('{job_id}')` to check the status of the job."
287
+ f"Use `so.get_job_status('{job_id}')` to check the status of the job, or monitor progress at {clickable_link}"
325
288
  )
326
289
  )
327
290
  return job_id
@@ -334,7 +297,8 @@ class Sutro:
334
297
  success = False
335
298
  if stay_attached and job_id is not None:
336
299
  spinner.write(to_colored_text("Awaiting job start...", ))
337
- spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
300
+ clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
301
+ spinner.write(to_colored_text(f'Progress can also be monitored at: {clickable_link}'))
338
302
  started = self._await_job_start(job_id)
339
303
  if not started:
340
304
  failure_reason = self._get_failure_reason(job_id)
@@ -473,6 +437,82 @@ class Sutro:
473
437
  return None
474
438
  return None
475
439
 
440
+ def infer(
441
+ self,
442
+ data: Union[List, pd.DataFrame, pl.DataFrame, str],
443
+ model: Union[ModelOptions, List[ModelOptions]] = "llama-3.1-8b",
444
+ column: str = None,
445
+ output_column: str = "inference_result",
446
+ job_priority: int = 0,
447
+ output_schema: Union[Dict[str, Any], BaseModel] = None,
448
+ sampling_params: dict = None,
449
+ system_prompt: str = None,
450
+ dry_run: bool = False,
451
+ stay_attached: Optional[bool] = None,
452
+ random_seed_per_input: bool = False,
453
+ truncate_rows: bool = False
454
+ ):
455
+ """
456
+ Run inference on the provided data.
457
+
458
+ This method allows you to run inference on the provided data using the Sutro API.
459
+ It supports various data types such as lists, pandas DataFrames, polars DataFrames, file paths and datasets.
460
+
461
+ Args:
462
+ data (Union[List, pd.DataFrame, pl.DataFrame, str]): The data to run inference on.
463
+ model (Union[ModelOptions, List[ModelOptions]], optional): The model(s) to use for inference. Defaults to "llama-3.1-8b". You can pass a single model or a list of models. In the case of a list, the inference will be run in parallel for each model and stay_attached will be set to False.
464
+ column (str, optional): The column name to use for inference. Required if data is a DataFrame, file path, or dataset.
465
+ output_column (str, optional): The column name to store the inference results in if the input is a DataFrame. Defaults to "inference_result".
466
+ job_priority (int, optional): The priority of the job. Defaults to 0.
467
+ output_schema (Union[Dict[str, Any], BaseModel], optional): A structured schema for the output.
468
+ Can be either a dictionary representing a JSON schema or a pydantic BaseModel. Defaults to None.
469
+ sampling_params: (dict, optional): The sampling parameters to use at generation time, ie temperature, top_p etc.
470
+ system_prompt (str, optional): A system prompt to add to all inputs. This allows you to define the behavior of the model. Defaults to None.
471
+ dry_run (bool, optional): If True, the method will return cost estimates instead of running inference. Defaults to False.
472
+ stay_attached (bool, optional): If True, the method will stay attached to the job until it is complete. Defaults to True for prototyping jobs, False otherwise.
473
+ random_seed_per_input (bool, optional): If True, the method will use a different random seed for each input. Defaults to False.
474
+ truncate_rows (bool, optional): If True, any rows that have a token count exceeding the context window length of the selected model will be truncated to the max length that will fit within the context window. Defaults to False.
475
+
476
+ Returns:
477
+ Union[List, pd.DataFrame, pl.DataFrame, str]: The results of the inference.
478
+
479
+ """
480
+ if isinstance(model, list) == False:
481
+ model_list = [model]
482
+ stay_attached = stay_attached if stay_attached is not None else job_priority == 0
483
+ else:
484
+ model_list = model
485
+ stay_attached = False
486
+
487
+ # Convert BaseModel to dict if needed
488
+ if output_schema is not None:
489
+ if hasattr(output_schema, 'model_json_schema'): # Check for pydantic Model interface
490
+ json_schema = output_schema.model_json_schema()
491
+ elif isinstance(output_schema, dict):
492
+ json_schema = output_schema
493
+ else:
494
+ raise ValueError("Invalid output schema type. Must be a dictionary or a pydantic Model.")
495
+ else:
496
+ json_schema = None
497
+
498
+ for model in model_list:
499
+ res = self._run_one_batch_inference(
500
+ data,
501
+ model,
502
+ column,
503
+ output_column,
504
+ job_priority,
505
+ json_schema,
506
+ sampling_params,
507
+ system_prompt,
508
+ dry_run,
509
+ stay_attached,
510
+ random_seed_per_input,
511
+ truncate_rows
512
+ )
513
+ if stay_attached:
514
+ return res
515
+
476
516
 
477
517
  def attach(self, job_id):
478
518
  """
@@ -549,7 +589,8 @@ class Sutro:
549
589
  text=to_colored_text("Awaiting status updates..."),
550
590
  color=YASPIN_COLOR,
551
591
  )
552
- spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
592
+ clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
593
+ spinner.write(to_colored_text(f'Progress can also be monitored at: {clickable_link}'))
553
594
  spinner.start()
554
595
  for line in streaming_response.iter_lines():
555
596
  if line:
@@ -1233,7 +1274,8 @@ class Sutro:
1233
1274
  with yaspin(
1234
1275
  SPINNER, text=to_colored_text("Awaiting job completion"), color=YASPIN_COLOR
1235
1276
  ) as spinner:
1236
- spinner.write(to_colored_text(f'Progress can also be monitored at: {make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')}'))
1277
+ clickable_link = make_clickable_link(f'https://app.sutro.sh/jobs/{job_id}')
1278
+ spinner.write(to_colored_text(f'Progress can also be monitored at: {clickable_link}'))
1237
1279
  while (time.time() - start_time) < timeout:
1238
1280
  try:
1239
1281
  status = self._fetch_job_status(job_id)
File without changes
File without changes
File without changes
File without changes
File without changes