sutro 0.1.28__py3-none-any.whl → 0.1.30__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sutro might be problematic. Click here for more details.
- sutro/sdk.py +28 -8
- {sutro-0.1.28.dist-info → sutro-0.1.30.dist-info}/METADATA +1 -1
- sutro-0.1.30.dist-info/RECORD +8 -0
- sutro-0.1.28.dist-info/RECORD +0 -8
- {sutro-0.1.28.dist-info → sutro-0.1.30.dist-info}/WHEEL +0 -0
- {sutro-0.1.28.dist-info → sutro-0.1.30.dist-info}/entry_points.txt +0 -0
- {sutro-0.1.28.dist-info → sutro-0.1.30.dist-info}/licenses/LICENSE +0 -0
sutro/sdk.py
CHANGED
|
@@ -436,7 +436,7 @@ class Sutro:
|
|
|
436
436
|
def infer(
|
|
437
437
|
self,
|
|
438
438
|
data: Union[List, pd.DataFrame, pl.DataFrame, str],
|
|
439
|
-
model: Union[ModelOptions, List[ModelOptions]] = "
|
|
439
|
+
model: Union[ModelOptions, List[ModelOptions]] = "gemma-3-12b-it",
|
|
440
440
|
column: str = None,
|
|
441
441
|
output_column: str = "inference_result",
|
|
442
442
|
job_priority: int = 0,
|
|
@@ -815,13 +815,14 @@ class Sutro:
|
|
|
815
815
|
return None
|
|
816
816
|
|
|
817
817
|
def get_job_results(
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
818
|
+
self,
|
|
819
|
+
job_id: str,
|
|
820
|
+
include_inputs: bool = False,
|
|
821
|
+
include_cumulative_logprobs: bool = False,
|
|
822
|
+
with_original_df: pl.DataFrame | pd.DataFrame = None,
|
|
823
|
+
output_column: str = "inference_result",
|
|
824
|
+
disable_cache: bool = False,
|
|
825
|
+
unpack_json: bool = True,
|
|
825
826
|
):
|
|
826
827
|
"""
|
|
827
828
|
Get the results of a job by its ID.
|
|
@@ -834,6 +835,8 @@ class Sutro:
|
|
|
834
835
|
include_cumulative_logprobs (bool, optional): Whether to include the cumulative logprobs in the results. Defaults to False.
|
|
835
836
|
with_original_df (pd.DataFrame | pl.DataFrame, optional): Original DataFrame to concatenate with results. Defaults to None.
|
|
836
837
|
output_column (str, optional): Name of the output column. Defaults to "inference_result".
|
|
838
|
+
disable_cache (bool, optional): Whether to disable the cache. Defaults to False.
|
|
839
|
+
unpack_json (bool, optional): If the output_column is formatted as a JSON string, decides whether to unpack the top level JSON fields in the results into separate columns. Defaults to True.
|
|
837
840
|
|
|
838
841
|
Returns:
|
|
839
842
|
Union[pl.DataFrame, pd.DataFrame]: The results as a DataFrame. By default, returns polars.DataFrame; when with_original_df is an instance of pandas.DataFrame, returns pandas.DataFrame.
|
|
@@ -910,6 +913,23 @@ class Sutro:
|
|
|
910
913
|
|
|
911
914
|
results_df = results_df.select(columns_to_keep)
|
|
912
915
|
|
|
916
|
+
if unpack_json:
|
|
917
|
+
try:
|
|
918
|
+
first_row = json.loads(results_df.head(1)[output_column][0]) # checks if the first row can be json decoded
|
|
919
|
+
results_df = results_df.with_columns(
|
|
920
|
+
pl.col(output_column).str.json_decode().alias("output_column_json_decoded")
|
|
921
|
+
)
|
|
922
|
+
json_decoded_fields = first_row.keys()
|
|
923
|
+
for field in json_decoded_fields:
|
|
924
|
+
results_df = results_df.with_columns(
|
|
925
|
+
pl.col("output_column_json_decoded").struct.field(field).alias(field)
|
|
926
|
+
)
|
|
927
|
+
# drop the output_column and the json decoded column
|
|
928
|
+
results_df = results_df.drop([output_column, "output_column_json_decoded"])
|
|
929
|
+
except json.JSONDecodeError:
|
|
930
|
+
# if the first row cannot be json decoded, do nothing
|
|
931
|
+
pass
|
|
932
|
+
|
|
913
933
|
# Handle concatenation with original DataFrame
|
|
914
934
|
if with_original_df is not None:
|
|
915
935
|
if isinstance(with_original_df, pd.DataFrame):
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
|
|
2
|
+
sutro/cli.py,sha256=8DrJVbjoayCUz4iszlj35Tv1q1gzDVzx_CuF6gZHwuU,13636
|
|
3
|
+
sutro/sdk.py,sha256=Jt1jDS2nFNI-YbY3rIbzZrJLDcVPjlcf4Lgs4ijL4ag,55359
|
|
4
|
+
sutro-0.1.30.dist-info/METADATA,sha256=fyEcdCfJ2JFXKTYeo7KsK_nlXVBz1BRgje4ZcOWMJfM,700
|
|
5
|
+
sutro-0.1.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
sutro-0.1.30.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
|
|
7
|
+
sutro-0.1.30.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
+
sutro-0.1.30.dist-info/RECORD,,
|
sutro-0.1.28.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
sutro/__init__.py,sha256=yUiVwcZ8QamSqDdRHgzoANyTZ-x3cPzlt2Fs5OllR_w,402
|
|
2
|
-
sutro/cli.py,sha256=8DrJVbjoayCUz4iszlj35Tv1q1gzDVzx_CuF6gZHwuU,13636
|
|
3
|
-
sutro/sdk.py,sha256=0_KC8eOOEunEPChrpToIucbg6q-icyxSUeMwd1UKDSY,54150
|
|
4
|
-
sutro-0.1.28.dist-info/METADATA,sha256=Nu730OdGc5TMVMlw3_kqxZ03z8_3jsz57Og5PIhk7Wo,700
|
|
5
|
-
sutro-0.1.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
sutro-0.1.28.dist-info/entry_points.txt,sha256=eXvr4dvMV4UmZgR0zmrY8KOmNpo64cJkhNDywiadRFM,40
|
|
7
|
-
sutro-0.1.28.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
-
sutro-0.1.28.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|