snowflake-ml-python 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +2 -0
- snowflake/ml/jobs/_utils/payload_utils.py +38 -18
- snowflake/ml/jobs/_utils/query_helper.py +8 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +58 -4
- snowflake/ml/jobs/_utils/spec_utils.py +0 -31
- snowflake/ml/jobs/_utils/stage_utils.py +2 -2
- snowflake/ml/jobs/_utils/types.py +22 -2
- snowflake/ml/jobs/job_definition.py +232 -0
- snowflake/ml/jobs/manager.py +16 -177
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +109 -32
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +45 -2
- snowflake/ml/model/_client/ops/param_utils.py +124 -0
- snowflake/ml/model/_client/ops/service_ops.py +81 -61
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +24 -9
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +4 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +30 -29
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +62 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/huggingface.py +54 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +52 -16
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +85 -0
- snowflake/ml/model/_signatures/utils.py +55 -0
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/openai_signatures.py +97 -0
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/METADATA +95 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/RECORD +52 -46
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/top_level.txt +0 -0
|
@@ -22,6 +22,14 @@ def _normalize_url_for_sql(url: str) -> str:
|
|
|
22
22
|
return f"'{url}'"
|
|
23
23
|
|
|
24
24
|
|
|
25
|
+
def _format_param_value(value: Any) -> str:
|
|
26
|
+
if isinstance(value, str):
|
|
27
|
+
return f"'{snowpark_utils.escape_single_quotes(value)}'" # type: ignore[no-untyped-call]
|
|
28
|
+
elif value is None:
|
|
29
|
+
return "NULL"
|
|
30
|
+
return str(value)
|
|
31
|
+
|
|
32
|
+
|
|
25
33
|
class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
26
34
|
FUNCTION_NAME_COL_NAME = "name"
|
|
27
35
|
FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
|
|
@@ -354,6 +362,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
354
362
|
input_args: list[sql_identifier.SqlIdentifier],
|
|
355
363
|
returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
|
356
364
|
statement_params: Optional[dict[str, Any]] = None,
|
|
365
|
+
params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
|
|
357
366
|
) -> dataframe.DataFrame:
|
|
358
367
|
with_statements = []
|
|
359
368
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
|
@@ -392,10 +401,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
392
401
|
|
|
393
402
|
args_sql = ", ".join(args_sql_list)
|
|
394
403
|
|
|
395
|
-
|
|
404
|
+
if params:
|
|
405
|
+
param_sql = ", ".join(_format_param_value(val) for _, val in params)
|
|
406
|
+
args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
|
|
407
|
+
|
|
408
|
+
total_args = len(input_args) + (len(params) if params else 0)
|
|
409
|
+
wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
|
|
396
410
|
if wide_input:
|
|
397
|
-
|
|
398
|
-
|
|
411
|
+
parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
|
|
412
|
+
if params:
|
|
413
|
+
parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
|
|
414
|
+
args_sql = f"object_construct_keep_null({', '.join(parts)})"
|
|
399
415
|
|
|
400
416
|
sql = textwrap.dedent(
|
|
401
417
|
f"""WITH {','.join(with_statements)}
|
|
@@ -439,6 +455,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
439
455
|
statement_params: Optional[dict[str, Any]] = None,
|
|
440
456
|
is_partitioned: bool = True,
|
|
441
457
|
explain_case_sensitive: bool = False,
|
|
458
|
+
params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
|
|
442
459
|
) -> dataframe.DataFrame:
|
|
443
460
|
with_statements = []
|
|
444
461
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
|
@@ -477,10 +494,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
477
494
|
|
|
478
495
|
args_sql = ", ".join(args_sql_list)
|
|
479
496
|
|
|
480
|
-
|
|
497
|
+
if params:
|
|
498
|
+
param_sql = ", ".join(_format_param_value(val) for _, val in params)
|
|
499
|
+
args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
|
|
500
|
+
|
|
501
|
+
total_args = len(input_args) + (len(params) if params else 0)
|
|
502
|
+
wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
|
|
481
503
|
if wide_input:
|
|
482
|
-
|
|
483
|
-
|
|
504
|
+
parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
|
|
505
|
+
if params:
|
|
506
|
+
parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
|
|
507
|
+
args_sql = f"object_construct_keep_null({', '.join(parts)})"
|
|
484
508
|
|
|
485
509
|
sql = textwrap.dedent(
|
|
486
510
|
f"""WITH {','.join(with_statements)}
|
|
@@ -20,6 +20,15 @@ from snowflake.snowpark._internal import utils as snowpark_utils
|
|
|
20
20
|
|
|
21
21
|
logger = logging.getLogger(__name__)
|
|
22
22
|
|
|
23
|
+
|
|
24
|
+
def _format_param_value(value: Any) -> str:
|
|
25
|
+
if isinstance(value, str):
|
|
26
|
+
return f"'{snowpark_utils.escape_single_quotes(value)}'" # type: ignore[no-untyped-call]
|
|
27
|
+
elif value is None:
|
|
28
|
+
return "NULL"
|
|
29
|
+
return str(value)
|
|
30
|
+
|
|
31
|
+
|
|
23
32
|
# Using this token instead of '?' to avoid escaping issues
|
|
24
33
|
# After quotes are escaped, we replace this token with '|| ? ||'
|
|
25
34
|
QMARK_RESERVED_TOKEN = "<QMARK_RESERVED_TOKEN>"
|
|
@@ -38,22 +47,6 @@ class ServiceStatus(enum.Enum):
|
|
|
38
47
|
INTERNAL_ERROR = "INTERNAL_ERROR"
|
|
39
48
|
|
|
40
49
|
|
|
41
|
-
class InstanceStatus(enum.Enum):
|
|
42
|
-
PENDING = "PENDING"
|
|
43
|
-
READY = "READY"
|
|
44
|
-
FAILED = "FAILED"
|
|
45
|
-
TERMINATING = "TERMINATING"
|
|
46
|
-
SUCCEEDED = "SUCCEEDED"
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
class ContainerStatus(enum.Enum):
|
|
50
|
-
PENDING = "PENDING"
|
|
51
|
-
READY = "READY"
|
|
52
|
-
DONE = "DONE"
|
|
53
|
-
FAILED = "FAILED"
|
|
54
|
-
UNKNOWN = "UNKNOWN"
|
|
55
|
-
|
|
56
|
-
|
|
57
50
|
@dataclasses.dataclass
|
|
58
51
|
class ServiceStatusInfo:
|
|
59
52
|
"""
|
|
@@ -63,8 +56,8 @@ class ServiceStatusInfo:
|
|
|
63
56
|
|
|
64
57
|
service_status: ServiceStatus
|
|
65
58
|
instance_id: Optional[int] = None
|
|
66
|
-
instance_status: Optional[
|
|
67
|
-
container_status: Optional[
|
|
59
|
+
instance_status: Optional[str] = None
|
|
60
|
+
container_status: Optional[str] = None
|
|
68
61
|
message: Optional[str] = None
|
|
69
62
|
|
|
70
63
|
|
|
@@ -140,6 +133,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
140
133
|
input_args: list[sql_identifier.SqlIdentifier],
|
|
141
134
|
returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
|
142
135
|
statement_params: Optional[dict[str, Any]] = None,
|
|
136
|
+
params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
|
|
143
137
|
) -> dataframe.DataFrame:
|
|
144
138
|
with_statements = []
|
|
145
139
|
actual_database_name = database_name or self._database_name
|
|
@@ -170,10 +164,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
170
164
|
args_sql_list.append(input_arg_value)
|
|
171
165
|
args_sql = ", ".join(args_sql_list)
|
|
172
166
|
|
|
173
|
-
|
|
167
|
+
if params:
|
|
168
|
+
param_sql = ", ".join(_format_param_value(val) for _, val in params)
|
|
169
|
+
args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
|
|
170
|
+
|
|
171
|
+
total_args = len(input_args) + (len(params) if params else 0)
|
|
172
|
+
wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
|
|
174
173
|
if wide_input:
|
|
175
|
-
|
|
176
|
-
|
|
174
|
+
parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
|
|
175
|
+
if params:
|
|
176
|
+
parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
|
|
177
|
+
args_sql = f"object_construct_keep_null({', '.join(parts)})"
|
|
177
178
|
|
|
178
179
|
fully_qualified_service_name = self.fully_qualified_object_name(
|
|
179
180
|
actual_database_name, actual_schema_name, service_name
|
|
@@ -255,17 +256,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
255
256
|
)
|
|
256
257
|
statuses = []
|
|
257
258
|
for r in rows:
|
|
258
|
-
instance_status, container_status = None, None
|
|
259
|
-
if r[ServiceSQLClient.INSTANCE_STATUS] is not None:
|
|
260
|
-
instance_status = InstanceStatus(r[ServiceSQLClient.INSTANCE_STATUS])
|
|
261
|
-
if r[ServiceSQLClient.CONTAINER_STATUS] is not None:
|
|
262
|
-
container_status = ContainerStatus(r[ServiceSQLClient.CONTAINER_STATUS])
|
|
263
259
|
statuses.append(
|
|
264
260
|
ServiceStatusInfo(
|
|
265
261
|
service_status=ServiceStatus(r[ServiceSQLClient.SERVICE_STATUS]),
|
|
266
262
|
instance_id=r[ServiceSQLClient.INSTANCE_ID],
|
|
267
|
-
instance_status=
|
|
268
|
-
container_status=
|
|
263
|
+
instance_status=r[ServiceSQLClient.INSTANCE_STATUS],
|
|
264
|
+
container_status=r[ServiceSQLClient.CONTAINER_STATUS],
|
|
269
265
|
message=r[ServiceSQLClient.MESSAGE] if include_message else None,
|
|
270
266
|
)
|
|
271
267
|
)
|
|
@@ -301,7 +297,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
301
297
|
False if service doesn't have proxy container
|
|
302
298
|
"""
|
|
303
299
|
try:
|
|
304
|
-
|
|
300
|
+
spec_yaml = row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME]
|
|
301
|
+
if spec_yaml is None:
|
|
302
|
+
return False
|
|
303
|
+
spec_raw = yaml.safe_load(spec_yaml)
|
|
304
|
+
if spec_raw is None:
|
|
305
|
+
return False
|
|
305
306
|
spec = cast(dict[str, Any], spec_raw)
|
|
306
307
|
|
|
307
308
|
proxy_container_spec = next(
|
|
@@ -131,7 +131,7 @@ class ModelComposer:
|
|
|
131
131
|
python_version: Optional[str] = None,
|
|
132
132
|
user_files: Optional[dict[str, list[str]]] = None,
|
|
133
133
|
ext_modules: Optional[list[ModuleType]] = None,
|
|
134
|
-
code_paths: Optional[list[
|
|
134
|
+
code_paths: Optional[list[model_types.CodePathLike]] = None,
|
|
135
135
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
|
136
136
|
experiment_info: Optional["ExperimentInfo"] = None,
|
|
137
137
|
options: Optional[model_types.ModelSaveOption] = None,
|
|
@@ -39,6 +39,10 @@ class ModelMethodSignatureFieldWithName(ModelMethodSignatureField):
|
|
|
39
39
|
name: Required[str]
|
|
40
40
|
|
|
41
41
|
|
|
42
|
+
class ModelMethodSignatureFieldWithNameAndDefault(ModelMethodSignatureFieldWithName):
|
|
43
|
+
default: Required[Any]
|
|
44
|
+
|
|
45
|
+
|
|
42
46
|
class ModelFunctionMethodDict(TypedDict):
|
|
43
47
|
name: Required[str]
|
|
44
48
|
runtime: Required[str]
|
|
@@ -46,6 +50,7 @@ class ModelFunctionMethodDict(TypedDict):
|
|
|
46
50
|
handler: Required[str]
|
|
47
51
|
inputs: Required[list[ModelMethodSignatureFieldWithName]]
|
|
48
52
|
outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
|
|
53
|
+
params: NotRequired[list[ModelMethodSignatureFieldWithNameAndDefault]]
|
|
49
54
|
volatility: NotRequired[str]
|
|
50
55
|
|
|
51
56
|
|
|
@@ -41,11 +41,29 @@ features = meta.signatures[TARGET_METHOD].inputs
|
|
|
41
41
|
input_cols = [feature.name for feature in features]
|
|
42
42
|
dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
|
43
43
|
|
|
44
|
+
# Load inference parameters from method signature (if any)
|
|
45
|
+
param_cols = []
|
|
46
|
+
param_defaults = {{}}
|
|
47
|
+
if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
|
|
48
|
+
for param_spec in meta.signatures[TARGET_METHOD].params:
|
|
49
|
+
param_cols.append(param_spec.name)
|
|
50
|
+
param_defaults[param_spec.name] = param_spec.default_value
|
|
51
|
+
|
|
44
52
|
|
|
45
53
|
# Actual function
|
|
46
54
|
@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
|
|
47
55
|
def {function_name}(df: pd.DataFrame) -> dict:
|
|
48
|
-
df.columns = input_cols
|
|
49
|
-
input_df = df.astype(dtype=dtype_map)
|
|
50
|
-
|
|
56
|
+
df.columns = input_cols + param_cols
|
|
57
|
+
input_df = df[input_cols].astype(dtype=dtype_map)
|
|
58
|
+
|
|
59
|
+
# Extract runtime param values, using defaults if None
|
|
60
|
+
method_params = {{}}
|
|
61
|
+
for col in param_cols:
|
|
62
|
+
val = df[col].iloc[0]
|
|
63
|
+
if val is None or pd.isna(val):
|
|
64
|
+
method_params[col] = param_defaults[col]
|
|
65
|
+
else:
|
|
66
|
+
method_params[col] = val
|
|
67
|
+
|
|
68
|
+
predictions_df = runner(input_df, **method_params)
|
|
51
69
|
return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
|
|
@@ -45,11 +45,29 @@ features = meta.signatures[TARGET_METHOD].inputs
|
|
|
45
45
|
input_cols = [feature.name for feature in features]
|
|
46
46
|
dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
|
47
47
|
|
|
48
|
+
# Load inference parameters from method signature (if any)
|
|
49
|
+
param_cols = []
|
|
50
|
+
param_defaults = {{}}
|
|
51
|
+
if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
|
|
52
|
+
for param_spec in meta.signatures[TARGET_METHOD].params:
|
|
53
|
+
param_cols.append(param_spec.name)
|
|
54
|
+
param_defaults[param_spec.name] = param_spec.default_value
|
|
55
|
+
|
|
48
56
|
|
|
49
57
|
# Actual table function
|
|
50
58
|
class {function_name}:
|
|
51
59
|
@vectorized(input=pd.DataFrame, flatten_object_input={wide_input})
|
|
52
60
|
def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
53
|
-
df.columns = input_cols
|
|
54
|
-
input_df = df.astype(dtype=dtype_map)
|
|
55
|
-
|
|
61
|
+
df.columns = input_cols + param_cols
|
|
62
|
+
input_df = df[input_cols].astype(dtype=dtype_map)
|
|
63
|
+
|
|
64
|
+
# Extract runtime param values, using defaults if None
|
|
65
|
+
method_params = {{}}
|
|
66
|
+
for col in param_cols:
|
|
67
|
+
val = df[col].iloc[0]
|
|
68
|
+
if val is None or pd.isna(val):
|
|
69
|
+
method_params[col] = param_defaults[col]
|
|
70
|
+
else:
|
|
71
|
+
method_params[col] = val
|
|
72
|
+
|
|
73
|
+
return runner(input_df, **method_params)
|
|
@@ -40,11 +40,29 @@ features = meta.signatures[TARGET_METHOD].inputs
|
|
|
40
40
|
input_cols = [feature.name for feature in features]
|
|
41
41
|
dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
|
42
42
|
|
|
43
|
+
# Load inference parameters from method signature (if any)
|
|
44
|
+
param_cols = []
|
|
45
|
+
param_defaults = {{}}
|
|
46
|
+
if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
|
|
47
|
+
for param_spec in meta.signatures[TARGET_METHOD].params:
|
|
48
|
+
param_cols.append(param_spec.name)
|
|
49
|
+
param_defaults[param_spec.name] = param_spec.default_value
|
|
50
|
+
|
|
43
51
|
|
|
44
52
|
# Actual table function
|
|
45
53
|
class {function_name}:
|
|
46
54
|
@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
|
|
47
55
|
def process(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
48
|
-
df.columns = input_cols
|
|
49
|
-
input_df = df.astype(dtype=dtype_map)
|
|
50
|
-
|
|
56
|
+
df.columns = input_cols + param_cols
|
|
57
|
+
input_df = df[input_cols].astype(dtype=dtype_map)
|
|
58
|
+
|
|
59
|
+
# Extract runtime param values, using defaults if None
|
|
60
|
+
method_params = {{}}
|
|
61
|
+
for col in param_cols:
|
|
62
|
+
val = df[col].iloc[0]
|
|
63
|
+
if val is None or pd.isna(val):
|
|
64
|
+
method_params[col] = param_defaults[col]
|
|
65
|
+
else:
|
|
66
|
+
method_params[col] = val
|
|
67
|
+
|
|
68
|
+
return runner(input_df, **method_params)
|
|
@@ -105,7 +105,7 @@ class ModelMethod:
|
|
|
105
105
|
except ValueError as e:
|
|
106
106
|
raise ValueError(
|
|
107
107
|
f"Your target method {self.target_method} cannot be resolved as valid SQL identifier. "
|
|
108
|
-
"Try
|
|
108
|
+
"Try specifying `case_sensitive` as True."
|
|
109
109
|
) from e
|
|
110
110
|
|
|
111
111
|
if self.target_method not in self.model_meta.signatures.keys():
|
|
@@ -127,12 +127,42 @@ class ModelMethod:
|
|
|
127
127
|
except ValueError as e:
|
|
128
128
|
raise ValueError(
|
|
129
129
|
f"Your feature {feature.name} cannot be resolved as valid SQL identifier. "
|
|
130
|
-
"Try
|
|
130
|
+
"Try specifying `case_sensitive` as True."
|
|
131
131
|
) from e
|
|
132
132
|
return model_manifest_schema.ModelMethodSignatureFieldWithName(
|
|
133
133
|
name=feature_name.resolved(), type=type_utils.convert_sp_to_sf_type(feature.as_snowpark_type())
|
|
134
134
|
)
|
|
135
135
|
|
|
136
|
+
@staticmethod
|
|
137
|
+
def _flatten_params(params: list[model_signature.BaseParamSpec]) -> list[model_signature.ParamSpec]:
|
|
138
|
+
"""Flatten ParamGroupSpec into leaf ParamSpec items."""
|
|
139
|
+
result: list[model_signature.ParamSpec] = []
|
|
140
|
+
for param in params:
|
|
141
|
+
if isinstance(param, model_signature.ParamSpec):
|
|
142
|
+
result.append(param)
|
|
143
|
+
elif isinstance(param, model_signature.ParamGroupSpec):
|
|
144
|
+
result.extend(ModelMethod._flatten_params(param.specs))
|
|
145
|
+
return result
|
|
146
|
+
|
|
147
|
+
@staticmethod
|
|
148
|
+
def _get_method_arg_from_param(
|
|
149
|
+
param_spec: model_signature.ParamSpec,
|
|
150
|
+
case_sensitive: bool = False,
|
|
151
|
+
) -> model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault:
|
|
152
|
+
try:
|
|
153
|
+
param_name = sql_identifier.SqlIdentifier(param_spec.name, case_sensitive=case_sensitive)
|
|
154
|
+
except ValueError as e:
|
|
155
|
+
raise ValueError(
|
|
156
|
+
f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
|
|
157
|
+
"Try specifying `case_sensitive` as True."
|
|
158
|
+
) from e
|
|
159
|
+
default_value = param_spec.default_value if param_spec.default_value is None else str(param_spec.default_value)
|
|
160
|
+
return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
|
|
161
|
+
name=param_name.resolved(),
|
|
162
|
+
type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
|
|
163
|
+
default=default_value,
|
|
164
|
+
)
|
|
165
|
+
|
|
136
166
|
def save(
|
|
137
167
|
self, workspace_path: pathlib.Path, options: Optional[function_generator.FunctionGenerateOptions] = None
|
|
138
168
|
) -> model_manifest_schema.ModelMethodDict:
|
|
@@ -182,6 +212,36 @@ class ModelMethod:
|
|
|
182
212
|
inputs=input_list,
|
|
183
213
|
outputs=outputs,
|
|
184
214
|
)
|
|
215
|
+
|
|
216
|
+
# Add parameters if signature has parameters
|
|
217
|
+
if self.model_meta.signatures[self.target_method].params:
|
|
218
|
+
flat_params = ModelMethod._flatten_params(list(self.model_meta.signatures[self.target_method].params))
|
|
219
|
+
param_list = [
|
|
220
|
+
ModelMethod._get_method_arg_from_param(
|
|
221
|
+
param_spec, case_sensitive=self.options.get("case_sensitive", False)
|
|
222
|
+
)
|
|
223
|
+
for param_spec in flat_params
|
|
224
|
+
]
|
|
225
|
+
param_name_counter = collections.Counter([param_info["name"] for param_info in param_list])
|
|
226
|
+
dup_param_names = [k for k, v in param_name_counter.items() if v > 1]
|
|
227
|
+
if dup_param_names:
|
|
228
|
+
raise ValueError(
|
|
229
|
+
f"Found duplicate parameter named resolved as {', '.join(dup_param_names)} in the method"
|
|
230
|
+
f" {self.target_method}. This might be because you have parameters with same letters but "
|
|
231
|
+
"different cases. In this case, set case_sensitive as True for those methods to distinguish them."
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
# Check for name collisions between parameters and inputs using existing counters
|
|
235
|
+
collision_names = [name for name in param_name_counter if name in input_name_counter]
|
|
236
|
+
if collision_names:
|
|
237
|
+
raise ValueError(
|
|
238
|
+
f"Found parameter(s) with the same name as input feature(s): {', '.join(sorted(collision_names))} "
|
|
239
|
+
f"in the method {self.target_method}. Parameters and inputs must have distinct names. "
|
|
240
|
+
"Try using case_sensitive=True if the names differ only by case."
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
method_dict["params"] = param_list
|
|
244
|
+
|
|
185
245
|
should_set_volatility = (
|
|
186
246
|
platform_capabilities.PlatformCapabilities.get_instance().is_set_module_functions_volatility_from_manifest()
|
|
187
247
|
)
|
|
@@ -86,6 +86,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
|
86
86
|
get_prediction_fn=get_prediction,
|
|
87
87
|
)
|
|
88
88
|
|
|
89
|
+
# Add parameters extracted from custom model inference methods to signatures
|
|
90
|
+
cls._add_method_parameters_to_signatures(model, model_meta)
|
|
91
|
+
|
|
89
92
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
|
90
93
|
os.makedirs(model_blob_path, exist_ok=True)
|
|
91
94
|
if model.context.artifacts:
|
|
@@ -188,6 +191,55 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
|
188
191
|
assert isinstance(model, custom_model.CustomModel)
|
|
189
192
|
return model
|
|
190
193
|
|
|
194
|
+
@classmethod
|
|
195
|
+
def _add_method_parameters_to_signatures(
|
|
196
|
+
cls,
|
|
197
|
+
model: "custom_model.CustomModel",
|
|
198
|
+
model_meta: model_meta_api.ModelMetadata,
|
|
199
|
+
) -> None:
|
|
200
|
+
"""Extract parameters from custom model inference methods and add them to signatures.
|
|
201
|
+
|
|
202
|
+
For each inference method, if the signature doesn't already have parameters and the method
|
|
203
|
+
has keyword-only parameters with defaults, create ParamSpecs and add them to the signature.
|
|
204
|
+
|
|
205
|
+
Args:
|
|
206
|
+
model: The custom model instance.
|
|
207
|
+
model_meta: The model metadata containing signatures to augment.
|
|
208
|
+
"""
|
|
209
|
+
for method in model._get_infer_methods():
|
|
210
|
+
method_name = method.__name__
|
|
211
|
+
if method_name not in model_meta.signatures:
|
|
212
|
+
continue
|
|
213
|
+
|
|
214
|
+
sig = model_meta.signatures[method_name]
|
|
215
|
+
|
|
216
|
+
# Skip if the signature already has parameters (user-provided or previously set)
|
|
217
|
+
if sig.params:
|
|
218
|
+
continue
|
|
219
|
+
|
|
220
|
+
# Extract parameters from the method
|
|
221
|
+
method_params = custom_model.get_method_parameters(method)
|
|
222
|
+
if not method_params:
|
|
223
|
+
continue
|
|
224
|
+
|
|
225
|
+
# Create ParamSpecs from the method parameters
|
|
226
|
+
param_specs = []
|
|
227
|
+
for param_name, param_type, param_default in method_params:
|
|
228
|
+
dtype = model_signature.DataType.from_python_type(param_type)
|
|
229
|
+
param_spec = model_signature.ParamSpec(
|
|
230
|
+
name=param_name,
|
|
231
|
+
dtype=dtype,
|
|
232
|
+
default_value=param_default,
|
|
233
|
+
)
|
|
234
|
+
param_specs.append(param_spec)
|
|
235
|
+
|
|
236
|
+
# Create a new signature with parameters
|
|
237
|
+
model_meta.signatures[method_name] = model_signature.ModelSignature(
|
|
238
|
+
inputs=sig.inputs,
|
|
239
|
+
outputs=sig.outputs,
|
|
240
|
+
params=param_specs,
|
|
241
|
+
)
|
|
242
|
+
|
|
191
243
|
@classmethod
|
|
192
244
|
def convert_as_custom_model(
|
|
193
245
|
cls,
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import io
|
|
1
2
|
import json
|
|
2
3
|
import logging
|
|
3
4
|
import os
|
|
@@ -28,7 +29,10 @@ from snowflake.ml.model._packager.model_meta import (
|
|
|
28
29
|
model_meta as model_meta_api,
|
|
29
30
|
model_meta_schema,
|
|
30
31
|
)
|
|
31
|
-
from snowflake.ml.model._signatures import
|
|
32
|
+
from snowflake.ml.model._signatures import (
|
|
33
|
+
core as model_signature_core,
|
|
34
|
+
utils as model_signature_utils,
|
|
35
|
+
)
|
|
32
36
|
from snowflake.ml.model.models import (
|
|
33
37
|
huggingface as huggingface_base,
|
|
34
38
|
huggingface_pipeline,
|
|
@@ -530,7 +534,10 @@ class TransformersPipelineHandler(
|
|
|
530
534
|
# verify when the target method is __call__ and
|
|
531
535
|
# if the signature is default text-generation signature
|
|
532
536
|
# then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
|
|
533
|
-
if
|
|
537
|
+
if (
|
|
538
|
+
signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC
|
|
539
|
+
or signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC_WITH_CONTENT_FORMAT_STRING
|
|
540
|
+
):
|
|
534
541
|
wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)
|
|
535
542
|
|
|
536
543
|
temp_res = X.apply(
|
|
@@ -554,6 +561,19 @@ class TransformersPipelineHandler(
|
|
|
554
561
|
else:
|
|
555
562
|
input_data = X[signature.inputs[0].name].to_list()
|
|
556
563
|
temp_res = getattr(raw_model, target_method)(input_data)
|
|
564
|
+
elif isinstance(raw_model, transformers.ImageClassificationPipeline):
|
|
565
|
+
# Image classification expects PIL Images. Convert bytes to PIL Images.
|
|
566
|
+
from PIL import Image
|
|
567
|
+
|
|
568
|
+
input_col = signature.inputs[0].name
|
|
569
|
+
images = [Image.open(io.BytesIO(img_bytes)) for img_bytes in X[input_col].to_list()]
|
|
570
|
+
temp_res = getattr(raw_model, target_method)(images)
|
|
571
|
+
elif isinstance(raw_model, transformers.AutomaticSpeechRecognitionPipeline):
|
|
572
|
+
# ASR pipeline accepts a single audio input (bytes, str, np.ndarray, or dict),
|
|
573
|
+
# not a list. Process each audio input individually.
|
|
574
|
+
input_col = signature.inputs[0].name
|
|
575
|
+
audio_inputs = X[input_col].to_list()
|
|
576
|
+
temp_res = [getattr(raw_model, target_method)(audio) for audio in audio_inputs]
|
|
557
577
|
else:
|
|
558
578
|
# TODO: remove conversational pipeline code
|
|
559
579
|
# For others, we could offer the whole dataframe as a list.
|
|
@@ -615,11 +635,14 @@ class TransformersPipelineHandler(
|
|
|
615
635
|
temp_res = [[conv.generated_responses] for conv in temp_res]
|
|
616
636
|
|
|
617
637
|
# To concat those who outputs a list with one input.
|
|
618
|
-
if
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
638
|
+
# if `signature.outputs` is single valued and is a FeatureGroupSpec,
|
|
639
|
+
# we create a DataFrame with one column and the values are stored as a dictionary.
|
|
640
|
+
# Otherwise, we create a DataFrame with the output as the column.
|
|
641
|
+
if len(signature.outputs) == 1 and isinstance(
|
|
642
|
+
signature.outputs[0], model_signature_core.FeatureGroupSpec
|
|
643
|
+
):
|
|
644
|
+
# creating a dataframe with one column
|
|
645
|
+
res = pd.DataFrame({signature.outputs[0].name: temp_res})
|
|
623
646
|
else:
|
|
624
647
|
res = pd.DataFrame(temp_res)
|
|
625
648
|
|
|
@@ -702,7 +725,6 @@ class HuggingFaceOpenAICompatibleModel:
|
|
|
702
725
|
self.pipeline = pipeline
|
|
703
726
|
self.model = self.pipeline.model
|
|
704
727
|
self.tokenizer = self.pipeline.tokenizer
|
|
705
|
-
|
|
706
728
|
self.model_name = self.pipeline.model.name_or_path
|
|
707
729
|
|
|
708
730
|
if self.tokenizer.pad_token is None:
|
|
@@ -724,11 +746,33 @@ class HuggingFaceOpenAICompatibleModel:
|
|
|
724
746
|
Returns:
|
|
725
747
|
The formatted prompt string ready for model input.
|
|
726
748
|
"""
|
|
749
|
+
|
|
750
|
+
final_messages = []
|
|
751
|
+
for message in messages:
|
|
752
|
+
if isinstance(message.get("content", ""), str):
|
|
753
|
+
final_messages.append({"role": message.get("role", "user"), "content": message.get("content", "")})
|
|
754
|
+
else:
|
|
755
|
+
# extract only the text from the content
|
|
756
|
+
# sample data:
|
|
757
|
+
# {
|
|
758
|
+
# "role": "user",
|
|
759
|
+
# "content": [
|
|
760
|
+
# {"type": "text", "text": "Hello, how are you?"}, # extracted
|
|
761
|
+
# {"type": "image", "image": "https://example.com/image.png"}, # not extracted
|
|
762
|
+
# ],
|
|
763
|
+
# }
|
|
764
|
+
for content_part in message.get("content", []):
|
|
765
|
+
if content_part.get("type", "") == "text":
|
|
766
|
+
final_messages.append(
|
|
767
|
+
{"role": message.get("role", "user"), "content": content_part.get("text", "")}
|
|
768
|
+
)
|
|
769
|
+
# TODO: implement other content types
|
|
770
|
+
|
|
727
771
|
# Use the tokenizer's apply_chat_template method.
|
|
728
772
|
# We ensured a template exists in __init__.
|
|
729
773
|
if hasattr(self.tokenizer, "apply_chat_template"):
|
|
730
774
|
return self.tokenizer.apply_chat_template( # type: ignore[no-any-return]
|
|
731
|
-
|
|
775
|
+
final_messages,
|
|
732
776
|
tokenize=False,
|
|
733
777
|
add_generation_prompt=True,
|
|
734
778
|
)
|
|
@@ -736,7 +780,7 @@ class HuggingFaceOpenAICompatibleModel:
|
|
|
736
780
|
# Fallback for very old transformers without apply_chat_template
|
|
737
781
|
# Manually apply ChatML-like formatting
|
|
738
782
|
prompt = ""
|
|
739
|
-
for message in
|
|
783
|
+
for message in final_messages:
|
|
740
784
|
role = message.get("role", "user")
|
|
741
785
|
content = message.get("content", "")
|
|
742
786
|
prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
|