snowflake-ml-python 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. snowflake/ml/_internal/utils/url.py +42 -0
  2. snowflake/ml/jobs/__init__.py +2 -0
  3. snowflake/ml/jobs/_utils/constants.py +2 -0
  4. snowflake/ml/jobs/_utils/payload_utils.py +38 -18
  5. snowflake/ml/jobs/_utils/query_helper.py +8 -1
  6. snowflake/ml/jobs/_utils/runtime_env_utils.py +58 -4
  7. snowflake/ml/jobs/_utils/spec_utils.py +0 -31
  8. snowflake/ml/jobs/_utils/stage_utils.py +2 -2
  9. snowflake/ml/jobs/_utils/types.py +22 -2
  10. snowflake/ml/jobs/job_definition.py +232 -0
  11. snowflake/ml/jobs/manager.py +16 -177
  12. snowflake/ml/lineage/lineage_node.py +1 -1
  13. snowflake/ml/model/__init__.py +6 -0
  14. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  15. snowflake/ml/model/_client/model/model_version_impl.py +109 -32
  16. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  17. snowflake/ml/model/_client/ops/model_ops.py +45 -2
  18. snowflake/ml/model/_client/ops/param_utils.py +124 -0
  19. snowflake/ml/model/_client/ops/service_ops.py +81 -61
  20. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +24 -9
  22. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +4 -0
  23. snowflake/ml/model/_client/sql/model_version.py +30 -6
  24. snowflake/ml/model/_client/sql/service.py +30 -29
  25. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  27. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
  28. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
  29. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
  30. snowflake/ml/model/_model_composer/model_method/model_method.py +62 -2
  31. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  32. snowflake/ml/model/_packager/model_handlers/huggingface.py +54 -10
  33. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +52 -16
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  36. snowflake/ml/model/_packager/model_packager.py +1 -1
  37. snowflake/ml/model/_signatures/core.py +85 -0
  38. snowflake/ml/model/_signatures/utils.py +55 -0
  39. snowflake/ml/model/code_path.py +104 -0
  40. snowflake/ml/model/custom_model.py +55 -13
  41. snowflake/ml/model/model_signature.py +13 -1
  42. snowflake/ml/model/openai_signatures.py +97 -0
  43. snowflake/ml/model/type_hints.py +2 -0
  44. snowflake/ml/registry/_manager/model_manager.py +230 -15
  45. snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
  46. snowflake/ml/registry/registry.py +4 -4
  47. snowflake/ml/version.py +1 -1
  48. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/METADATA +95 -1
  49. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/RECORD +52 -46
  50. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/WHEEL +0 -0
  51. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/licenses/LICENSE.txt +0 -0
  52. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,14 @@ def _normalize_url_for_sql(url: str) -> str:
22
22
  return f"'{url}'"
23
23
 
24
24
 
25
+ def _format_param_value(value: Any) -> str:
26
+ if isinstance(value, str):
27
+ return f"'{snowpark_utils.escape_single_quotes(value)}'" # type: ignore[no-untyped-call]
28
+ elif value is None:
29
+ return "NULL"
30
+ return str(value)
31
+
32
+
25
33
  class ModelVersionSQLClient(_base._BaseSQLClient):
26
34
  FUNCTION_NAME_COL_NAME = "name"
27
35
  FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
@@ -354,6 +362,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
354
362
  input_args: list[sql_identifier.SqlIdentifier],
355
363
  returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
356
364
  statement_params: Optional[dict[str, Any]] = None,
365
+ params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
357
366
  ) -> dataframe.DataFrame:
358
367
  with_statements = []
359
368
  if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -392,10 +401,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
392
401
 
393
402
  args_sql = ", ".join(args_sql_list)
394
403
 
395
- wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
404
+ if params:
405
+ param_sql = ", ".join(_format_param_value(val) for _, val in params)
406
+ args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
407
+
408
+ total_args = len(input_args) + (len(params) if params else 0)
409
+ wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
396
410
  if wide_input:
397
- input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
398
- args_sql = f"object_construct_keep_null({input_args_sql})"
411
+ parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
412
+ if params:
413
+ parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
414
+ args_sql = f"object_construct_keep_null({', '.join(parts)})"
399
415
 
400
416
  sql = textwrap.dedent(
401
417
  f"""WITH {','.join(with_statements)}
@@ -439,6 +455,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
439
455
  statement_params: Optional[dict[str, Any]] = None,
440
456
  is_partitioned: bool = True,
441
457
  explain_case_sensitive: bool = False,
458
+ params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
442
459
  ) -> dataframe.DataFrame:
443
460
  with_statements = []
444
461
  if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -477,10 +494,17 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
477
494
 
478
495
  args_sql = ", ".join(args_sql_list)
479
496
 
480
- wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
497
+ if params:
498
+ param_sql = ", ".join(_format_param_value(val) for _, val in params)
499
+ args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
500
+
501
+ total_args = len(input_args) + (len(params) if params else 0)
502
+ wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
481
503
  if wide_input:
482
- input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
483
- args_sql = f"object_construct_keep_null({input_args_sql})"
504
+ parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
505
+ if params:
506
+ parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
507
+ args_sql = f"object_construct_keep_null({', '.join(parts)})"
484
508
 
485
509
  sql = textwrap.dedent(
486
510
  f"""WITH {','.join(with_statements)}
@@ -20,6 +20,15 @@ from snowflake.snowpark._internal import utils as snowpark_utils
20
20
 
21
21
  logger = logging.getLogger(__name__)
22
22
 
23
+
24
+ def _format_param_value(value: Any) -> str:
25
+ if isinstance(value, str):
26
+ return f"'{snowpark_utils.escape_single_quotes(value)}'" # type: ignore[no-untyped-call]
27
+ elif value is None:
28
+ return "NULL"
29
+ return str(value)
30
+
31
+
23
32
  # Using this token instead of '?' to avoid escaping issues
24
33
  # After quotes are escaped, we replace this token with '|| ? ||'
25
34
  QMARK_RESERVED_TOKEN = "<QMARK_RESERVED_TOKEN>"
@@ -38,22 +47,6 @@ class ServiceStatus(enum.Enum):
38
47
  INTERNAL_ERROR = "INTERNAL_ERROR"
39
48
 
40
49
 
41
- class InstanceStatus(enum.Enum):
42
- PENDING = "PENDING"
43
- READY = "READY"
44
- FAILED = "FAILED"
45
- TERMINATING = "TERMINATING"
46
- SUCCEEDED = "SUCCEEDED"
47
-
48
-
49
- class ContainerStatus(enum.Enum):
50
- PENDING = "PENDING"
51
- READY = "READY"
52
- DONE = "DONE"
53
- FAILED = "FAILED"
54
- UNKNOWN = "UNKNOWN"
55
-
56
-
57
50
  @dataclasses.dataclass
58
51
  class ServiceStatusInfo:
59
52
  """
@@ -63,8 +56,8 @@ class ServiceStatusInfo:
63
56
 
64
57
  service_status: ServiceStatus
65
58
  instance_id: Optional[int] = None
66
- instance_status: Optional[InstanceStatus] = None
67
- container_status: Optional[ContainerStatus] = None
59
+ instance_status: Optional[str] = None
60
+ container_status: Optional[str] = None
68
61
  message: Optional[str] = None
69
62
 
70
63
 
@@ -140,6 +133,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
140
133
  input_args: list[sql_identifier.SqlIdentifier],
141
134
  returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
142
135
  statement_params: Optional[dict[str, Any]] = None,
136
+ params: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None,
143
137
  ) -> dataframe.DataFrame:
144
138
  with_statements = []
145
139
  actual_database_name = database_name or self._database_name
@@ -170,10 +164,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
170
164
  args_sql_list.append(input_arg_value)
171
165
  args_sql = ", ".join(args_sql_list)
172
166
 
173
- wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
167
+ if params:
168
+ param_sql = ", ".join(_format_param_value(val) for _, val in params)
169
+ args_sql = f"{args_sql}, {param_sql}" if args_sql else param_sql
170
+
171
+ total_args = len(input_args) + (len(params) if params else 0)
172
+ wide_input = total_args > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
174
173
  if wide_input:
175
- input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
176
- args_sql = f"object_construct_keep_null({input_args_sql})"
174
+ parts = [f"'{arg}', {arg.identifier()}" for arg in input_args]
175
+ if params:
176
+ parts.extend(f"'{name}', {_format_param_value(val)}" for name, val in params)
177
+ args_sql = f"object_construct_keep_null({', '.join(parts)})"
177
178
 
178
179
  fully_qualified_service_name = self.fully_qualified_object_name(
179
180
  actual_database_name, actual_schema_name, service_name
@@ -255,17 +256,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
255
256
  )
256
257
  statuses = []
257
258
  for r in rows:
258
- instance_status, container_status = None, None
259
- if r[ServiceSQLClient.INSTANCE_STATUS] is not None:
260
- instance_status = InstanceStatus(r[ServiceSQLClient.INSTANCE_STATUS])
261
- if r[ServiceSQLClient.CONTAINER_STATUS] is not None:
262
- container_status = ContainerStatus(r[ServiceSQLClient.CONTAINER_STATUS])
263
259
  statuses.append(
264
260
  ServiceStatusInfo(
265
261
  service_status=ServiceStatus(r[ServiceSQLClient.SERVICE_STATUS]),
266
262
  instance_id=r[ServiceSQLClient.INSTANCE_ID],
267
- instance_status=instance_status,
268
- container_status=container_status,
263
+ instance_status=r[ServiceSQLClient.INSTANCE_STATUS],
264
+ container_status=r[ServiceSQLClient.CONTAINER_STATUS],
269
265
  message=r[ServiceSQLClient.MESSAGE] if include_message else None,
270
266
  )
271
267
  )
@@ -301,7 +297,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
301
297
  False if service doesn't have proxy container
302
298
  """
303
299
  try:
304
- spec_raw = yaml.safe_load(row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME])
300
+ spec_yaml = row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME]
301
+ if spec_yaml is None:
302
+ return False
303
+ spec_raw = yaml.safe_load(spec_yaml)
304
+ if spec_raw is None:
305
+ return False
305
306
  spec = cast(dict[str, Any], spec_raw)
306
307
 
307
308
  proxy_container_spec = next(
@@ -131,7 +131,7 @@ class ModelComposer:
131
131
  python_version: Optional[str] = None,
132
132
  user_files: Optional[dict[str, list[str]]] = None,
133
133
  ext_modules: Optional[list[ModuleType]] = None,
134
- code_paths: Optional[list[str]] = None,
134
+ code_paths: Optional[list[model_types.CodePathLike]] = None,
135
135
  task: model_types.Task = model_types.Task.UNKNOWN,
136
136
  experiment_info: Optional["ExperimentInfo"] = None,
137
137
  options: Optional[model_types.ModelSaveOption] = None,
@@ -39,6 +39,10 @@ class ModelMethodSignatureFieldWithName(ModelMethodSignatureField):
39
39
  name: Required[str]
40
40
 
41
41
 
42
+ class ModelMethodSignatureFieldWithNameAndDefault(ModelMethodSignatureFieldWithName):
43
+ default: Required[Any]
44
+
45
+
42
46
  class ModelFunctionMethodDict(TypedDict):
43
47
  name: Required[str]
44
48
  runtime: Required[str]
@@ -46,6 +50,7 @@ class ModelFunctionMethodDict(TypedDict):
46
50
  handler: Required[str]
47
51
  inputs: Required[list[ModelMethodSignatureFieldWithName]]
48
52
  outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
53
+ params: NotRequired[list[ModelMethodSignatureFieldWithNameAndDefault]]
49
54
  volatility: NotRequired[str]
50
55
 
51
56
 
@@ -41,11 +41,29 @@ features = meta.signatures[TARGET_METHOD].inputs
41
41
  input_cols = [feature.name for feature in features]
42
42
  dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
43
43
 
44
+ # Load inference parameters from method signature (if any)
45
+ param_cols = []
46
+ param_defaults = {{}}
47
+ if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
48
+ for param_spec in meta.signatures[TARGET_METHOD].params:
49
+ param_cols.append(param_spec.name)
50
+ param_defaults[param_spec.name] = param_spec.default_value
51
+
44
52
 
45
53
  # Actual function
46
54
  @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
47
55
  def {function_name}(df: pd.DataFrame) -> dict:
48
- df.columns = input_cols
49
- input_df = df.astype(dtype=dtype_map)
50
- predictions_df = runner(input_df[input_cols])
56
+ df.columns = input_cols + param_cols
57
+ input_df = df[input_cols].astype(dtype=dtype_map)
58
+
59
+ # Extract runtime param values, using defaults if None
60
+ method_params = {{}}
61
+ for col in param_cols:
62
+ val = df[col].iloc[0]
63
+ if val is None or pd.isna(val):
64
+ method_params[col] = param_defaults[col]
65
+ else:
66
+ method_params[col] = val
67
+
68
+ predictions_df = runner(input_df, **method_params)
51
69
  return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
@@ -45,11 +45,29 @@ features = meta.signatures[TARGET_METHOD].inputs
45
45
  input_cols = [feature.name for feature in features]
46
46
  dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
47
47
 
48
+ # Load inference parameters from method signature (if any)
49
+ param_cols = []
50
+ param_defaults = {{}}
51
+ if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
52
+ for param_spec in meta.signatures[TARGET_METHOD].params:
53
+ param_cols.append(param_spec.name)
54
+ param_defaults[param_spec.name] = param_spec.default_value
55
+
48
56
 
49
57
  # Actual table function
50
58
  class {function_name}:
51
59
  @vectorized(input=pd.DataFrame, flatten_object_input={wide_input})
52
60
  def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
53
- df.columns = input_cols
54
- input_df = df.astype(dtype=dtype_map)
55
- return runner(input_df[input_cols])
61
+ df.columns = input_cols + param_cols
62
+ input_df = df[input_cols].astype(dtype=dtype_map)
63
+
64
+ # Extract runtime param values, using defaults if None
65
+ method_params = {{}}
66
+ for col in param_cols:
67
+ val = df[col].iloc[0]
68
+ if val is None or pd.isna(val):
69
+ method_params[col] = param_defaults[col]
70
+ else:
71
+ method_params[col] = val
72
+
73
+ return runner(input_df, **method_params)
@@ -40,11 +40,29 @@ features = meta.signatures[TARGET_METHOD].inputs
40
40
  input_cols = [feature.name for feature in features]
41
41
  dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
42
42
 
43
+ # Load inference parameters from method signature (if any)
44
+ param_cols = []
45
+ param_defaults = {{}}
46
+ if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
47
+ for param_spec in meta.signatures[TARGET_METHOD].params:
48
+ param_cols.append(param_spec.name)
49
+ param_defaults[param_spec.name] = param_spec.default_value
50
+
43
51
 
44
52
  # Actual table function
45
53
  class {function_name}:
46
54
  @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
47
55
  def process(self, df: pd.DataFrame) -> pd.DataFrame:
48
- df.columns = input_cols
49
- input_df = df.astype(dtype=dtype_map)
50
- return runner(input_df[input_cols])
56
+ df.columns = input_cols + param_cols
57
+ input_df = df[input_cols].astype(dtype=dtype_map)
58
+
59
+ # Extract runtime param values, using defaults if None
60
+ method_params = {{}}
61
+ for col in param_cols:
62
+ val = df[col].iloc[0]
63
+ if val is None or pd.isna(val):
64
+ method_params[col] = param_defaults[col]
65
+ else:
66
+ method_params[col] = val
67
+
68
+ return runner(input_df, **method_params)
@@ -105,7 +105,7 @@ class ModelMethod:
105
105
  except ValueError as e:
106
106
  raise ValueError(
107
107
  f"Your target method {self.target_method} cannot be resolved as valid SQL identifier. "
108
- "Try specify `case_sensitive` as True."
108
+ "Try specifying `case_sensitive` as True."
109
109
  ) from e
110
110
 
111
111
  if self.target_method not in self.model_meta.signatures.keys():
@@ -127,12 +127,42 @@ class ModelMethod:
127
127
  except ValueError as e:
128
128
  raise ValueError(
129
129
  f"Your feature {feature.name} cannot be resolved as valid SQL identifier. "
130
- "Try specify `case_sensitive` as True."
130
+ "Try specifying `case_sensitive` as True."
131
131
  ) from e
132
132
  return model_manifest_schema.ModelMethodSignatureFieldWithName(
133
133
  name=feature_name.resolved(), type=type_utils.convert_sp_to_sf_type(feature.as_snowpark_type())
134
134
  )
135
135
 
136
+ @staticmethod
137
+ def _flatten_params(params: list[model_signature.BaseParamSpec]) -> list[model_signature.ParamSpec]:
138
+ """Flatten ParamGroupSpec into leaf ParamSpec items."""
139
+ result: list[model_signature.ParamSpec] = []
140
+ for param in params:
141
+ if isinstance(param, model_signature.ParamSpec):
142
+ result.append(param)
143
+ elif isinstance(param, model_signature.ParamGroupSpec):
144
+ result.extend(ModelMethod._flatten_params(param.specs))
145
+ return result
146
+
147
+ @staticmethod
148
+ def _get_method_arg_from_param(
149
+ param_spec: model_signature.ParamSpec,
150
+ case_sensitive: bool = False,
151
+ ) -> model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault:
152
+ try:
153
+ param_name = sql_identifier.SqlIdentifier(param_spec.name, case_sensitive=case_sensitive)
154
+ except ValueError as e:
155
+ raise ValueError(
156
+ f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
157
+ "Try specifying `case_sensitive` as True."
158
+ ) from e
159
+ default_value = param_spec.default_value if param_spec.default_value is None else str(param_spec.default_value)
160
+ return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
161
+ name=param_name.resolved(),
162
+ type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
163
+ default=default_value,
164
+ )
165
+
136
166
  def save(
137
167
  self, workspace_path: pathlib.Path, options: Optional[function_generator.FunctionGenerateOptions] = None
138
168
  ) -> model_manifest_schema.ModelMethodDict:
@@ -182,6 +212,36 @@ class ModelMethod:
182
212
  inputs=input_list,
183
213
  outputs=outputs,
184
214
  )
215
+
216
+ # Add parameters if signature has parameters
217
+ if self.model_meta.signatures[self.target_method].params:
218
+ flat_params = ModelMethod._flatten_params(list(self.model_meta.signatures[self.target_method].params))
219
+ param_list = [
220
+ ModelMethod._get_method_arg_from_param(
221
+ param_spec, case_sensitive=self.options.get("case_sensitive", False)
222
+ )
223
+ for param_spec in flat_params
224
+ ]
225
+ param_name_counter = collections.Counter([param_info["name"] for param_info in param_list])
226
+ dup_param_names = [k for k, v in param_name_counter.items() if v > 1]
227
+ if dup_param_names:
228
+ raise ValueError(
229
+ f"Found duplicate parameter named resolved as {', '.join(dup_param_names)} in the method"
230
+ f" {self.target_method}. This might be because you have parameters with same letters but "
231
+ "different cases. In this case, set case_sensitive as True for those methods to distinguish them."
232
+ )
233
+
234
+ # Check for name collisions between parameters and inputs using existing counters
235
+ collision_names = [name for name in param_name_counter if name in input_name_counter]
236
+ if collision_names:
237
+ raise ValueError(
238
+ f"Found parameter(s) with the same name as input feature(s): {', '.join(sorted(collision_names))} "
239
+ f"in the method {self.target_method}. Parameters and inputs must have distinct names. "
240
+ "Try using case_sensitive=True if the names differ only by case."
241
+ )
242
+
243
+ method_dict["params"] = param_list
244
+
185
245
  should_set_volatility = (
186
246
  platform_capabilities.PlatformCapabilities.get_instance().is_set_module_functions_volatility_from_manifest()
187
247
  )
@@ -86,6 +86,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
86
86
  get_prediction_fn=get_prediction,
87
87
  )
88
88
 
89
+ # Add parameters extracted from custom model inference methods to signatures
90
+ cls._add_method_parameters_to_signatures(model, model_meta)
91
+
89
92
  model_blob_path = os.path.join(model_blobs_dir_path, name)
90
93
  os.makedirs(model_blob_path, exist_ok=True)
91
94
  if model.context.artifacts:
@@ -188,6 +191,55 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
188
191
  assert isinstance(model, custom_model.CustomModel)
189
192
  return model
190
193
 
194
+ @classmethod
195
+ def _add_method_parameters_to_signatures(
196
+ cls,
197
+ model: "custom_model.CustomModel",
198
+ model_meta: model_meta_api.ModelMetadata,
199
+ ) -> None:
200
+ """Extract parameters from custom model inference methods and add them to signatures.
201
+
202
+ For each inference method, if the signature doesn't already have parameters and the method
203
+ has keyword-only parameters with defaults, create ParamSpecs and add them to the signature.
204
+
205
+ Args:
206
+ model: The custom model instance.
207
+ model_meta: The model metadata containing signatures to augment.
208
+ """
209
+ for method in model._get_infer_methods():
210
+ method_name = method.__name__
211
+ if method_name not in model_meta.signatures:
212
+ continue
213
+
214
+ sig = model_meta.signatures[method_name]
215
+
216
+ # Skip if the signature already has parameters (user-provided or previously set)
217
+ if sig.params:
218
+ continue
219
+
220
+ # Extract parameters from the method
221
+ method_params = custom_model.get_method_parameters(method)
222
+ if not method_params:
223
+ continue
224
+
225
+ # Create ParamSpecs from the method parameters
226
+ param_specs = []
227
+ for param_name, param_type, param_default in method_params:
228
+ dtype = model_signature.DataType.from_python_type(param_type)
229
+ param_spec = model_signature.ParamSpec(
230
+ name=param_name,
231
+ dtype=dtype,
232
+ default_value=param_default,
233
+ )
234
+ param_specs.append(param_spec)
235
+
236
+ # Create a new signature with parameters
237
+ model_meta.signatures[method_name] = model_signature.ModelSignature(
238
+ inputs=sig.inputs,
239
+ outputs=sig.outputs,
240
+ params=param_specs,
241
+ )
242
+
191
243
  @classmethod
192
244
  def convert_as_custom_model(
193
245
  cls,
@@ -1,3 +1,4 @@
1
+ import io
1
2
  import json
2
3
  import logging
3
4
  import os
@@ -28,7 +29,10 @@ from snowflake.ml.model._packager.model_meta import (
28
29
  model_meta as model_meta_api,
29
30
  model_meta_schema,
30
31
  )
31
- from snowflake.ml.model._signatures import utils as model_signature_utils
32
+ from snowflake.ml.model._signatures import (
33
+ core as model_signature_core,
34
+ utils as model_signature_utils,
35
+ )
32
36
  from snowflake.ml.model.models import (
33
37
  huggingface as huggingface_base,
34
38
  huggingface_pipeline,
@@ -530,7 +534,10 @@ class TransformersPipelineHandler(
530
534
  # verify when the target method is __call__ and
531
535
  # if the signature is default text-generation signature
532
536
  # then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
533
- if signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC:
537
+ if (
538
+ signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC
539
+ or signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC_WITH_CONTENT_FORMAT_STRING
540
+ ):
534
541
  wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)
535
542
 
536
543
  temp_res = X.apply(
@@ -554,6 +561,19 @@ class TransformersPipelineHandler(
554
561
  else:
555
562
  input_data = X[signature.inputs[0].name].to_list()
556
563
  temp_res = getattr(raw_model, target_method)(input_data)
564
+ elif isinstance(raw_model, transformers.ImageClassificationPipeline):
565
+ # Image classification expects PIL Images. Convert bytes to PIL Images.
566
+ from PIL import Image
567
+
568
+ input_col = signature.inputs[0].name
569
+ images = [Image.open(io.BytesIO(img_bytes)) for img_bytes in X[input_col].to_list()]
570
+ temp_res = getattr(raw_model, target_method)(images)
571
+ elif isinstance(raw_model, transformers.AutomaticSpeechRecognitionPipeline):
572
+ # ASR pipeline accepts a single audio input (bytes, str, np.ndarray, or dict),
573
+ # not a list. Process each audio input individually.
574
+ input_col = signature.inputs[0].name
575
+ audio_inputs = X[input_col].to_list()
576
+ temp_res = [getattr(raw_model, target_method)(audio) for audio in audio_inputs]
557
577
  else:
558
578
  # TODO: remove conversational pipeline code
559
579
  # For others, we could offer the whole dataframe as a list.
@@ -615,11 +635,14 @@ class TransformersPipelineHandler(
615
635
  temp_res = [[conv.generated_responses] for conv in temp_res]
616
636
 
617
637
  # To concat those who outputs a list with one input.
618
- if isinstance(temp_res[0], list):
619
- if isinstance(temp_res[0][0], dict):
620
- res = pd.DataFrame({0: temp_res})
621
- else:
622
- res = pd.DataFrame(temp_res)
638
+ # if `signature.outputs` is single valued and is a FeatureGroupSpec,
639
+ # we create a DataFrame with one column and the values are stored as a dictionary.
640
+ # Otherwise, we create a DataFrame with the output as the column.
641
+ if len(signature.outputs) == 1 and isinstance(
642
+ signature.outputs[0], model_signature_core.FeatureGroupSpec
643
+ ):
644
+ # creating a dataframe with one column
645
+ res = pd.DataFrame({signature.outputs[0].name: temp_res})
623
646
  else:
624
647
  res = pd.DataFrame(temp_res)
625
648
 
@@ -702,7 +725,6 @@ class HuggingFaceOpenAICompatibleModel:
702
725
  self.pipeline = pipeline
703
726
  self.model = self.pipeline.model
704
727
  self.tokenizer = self.pipeline.tokenizer
705
-
706
728
  self.model_name = self.pipeline.model.name_or_path
707
729
 
708
730
  if self.tokenizer.pad_token is None:
@@ -724,11 +746,33 @@ class HuggingFaceOpenAICompatibleModel:
724
746
  Returns:
725
747
  The formatted prompt string ready for model input.
726
748
  """
749
+
750
+ final_messages = []
751
+ for message in messages:
752
+ if isinstance(message.get("content", ""), str):
753
+ final_messages.append({"role": message.get("role", "user"), "content": message.get("content", "")})
754
+ else:
755
+ # extract only the text from the content
756
+ # sample data:
757
+ # {
758
+ # "role": "user",
759
+ # "content": [
760
+ # {"type": "text", "text": "Hello, how are you?"}, # extracted
761
+ # {"type": "image", "image": "https://example.com/image.png"}, # not extracted
762
+ # ],
763
+ # }
764
+ for content_part in message.get("content", []):
765
+ if content_part.get("type", "") == "text":
766
+ final_messages.append(
767
+ {"role": message.get("role", "user"), "content": content_part.get("text", "")}
768
+ )
769
+ # TODO: implement other content types
770
+
727
771
  # Use the tokenizer's apply_chat_template method.
728
772
  # We ensured a template exists in __init__.
729
773
  if hasattr(self.tokenizer, "apply_chat_template"):
730
774
  return self.tokenizer.apply_chat_template( # type: ignore[no-any-return]
731
- messages,
775
+ final_messages,
732
776
  tokenize=False,
733
777
  add_generation_prompt=True,
734
778
  )
@@ -736,7 +780,7 @@ class HuggingFaceOpenAICompatibleModel:
736
780
  # Fallback for very old transformers without apply_chat_template
737
781
  # Manually apply ChatML-like formatting
738
782
  prompt = ""
739
- for message in messages:
783
+ for message in final_messages:
740
784
  role = message.get("role", "user")
741
785
  content = message.get("content", "")
742
786
  prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"