snowflake-ml-python 1.21.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. snowflake/ml/_internal/utils/url.py +42 -0
  2. snowflake/ml/experiment/callback/__init__.py +0 -0
  3. snowflake/ml/jobs/_utils/constants.py +1 -0
  4. snowflake/ml/jobs/_utils/spec_utils.py +0 -31
  5. snowflake/ml/lineage/lineage_node.py +1 -1
  6. snowflake/ml/model/__init__.py +6 -0
  7. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  8. snowflake/ml/model/_client/model/model_version_impl.py +63 -0
  9. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  10. snowflake/ml/model/_client/ops/model_ops.py +61 -2
  11. snowflake/ml/model/_client/ops/service_ops.py +23 -48
  12. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  13. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  14. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  15. snowflake/ml/model/_client/sql/model_version.py +30 -6
  16. snowflake/ml/model/_client/sql/service.py +26 -4
  17. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  18. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  20. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  21. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  22. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  23. snowflake/ml/model/_packager/model_packager.py +1 -1
  24. snowflake/ml/model/_signatures/core.py +85 -0
  25. snowflake/ml/model/code_path.py +104 -0
  26. snowflake/ml/model/custom_model.py +55 -13
  27. snowflake/ml/model/model_signature.py +13 -1
  28. snowflake/ml/model/type_hints.py +2 -0
  29. snowflake/ml/registry/_manager/model_manager.py +230 -15
  30. snowflake/ml/registry/registry.py +4 -4
  31. snowflake/ml/version.py +1 -1
  32. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +29 -1
  33. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +36 -32
  34. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  35. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  36. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  37. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0

snowflake/ml/model/_packager/model_meta/model_meta.py

@@ -20,6 +20,7 @@ from snowflake.ml.model._packager.model_env import model_env
 from snowflake.ml.model._packager.model_meta import model_blob_meta, model_meta_schema
 from snowflake.ml.model._packager.model_meta_migrator import migrator_plans
 from snowflake.ml.model._packager.model_runtime import model_runtime
+from snowflake.ml.model.code_path import CodePath

 MODEL_METADATA_FILE = "model.yaml"
 MODEL_CODE_DIR = "code"
@@ -39,7 +40,7 @@ def create_model_metadata(
     signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
     function_properties: Optional[dict[str, dict[str, Any]]] = None,
     metadata: Optional[dict[str, str]] = None,
-    code_paths: Optional[list[str]] = None,
+    code_paths: Optional[list[model_types.CodePathLike]] = None,
     ext_modules: Optional[list[ModuleType]] = None,
     conda_dependencies: Optional[list[str]] = None,
     pip_requirements: Optional[list[str]] = None,
@@ -77,7 +78,8 @@ def create_model_metadata(
         **kwargs: Dict of attributes and values of the metadata. Used when loading from file.

     Raises:
-        ValueError: Raised when the code path contains reserved file or directory.
+        ValueError: Raised when a code path contains a reserved file or directory, or when destination paths conflict.
+        FileNotFoundError: Raised when a code path does not exist.

     Yields:
         A model metadata object.
@@ -134,13 +136,44 @@
     os.makedirs(code_dir_path, exist_ok=True)

     if code_paths:
+        # Resolve all code paths and check for conflicts
+        resolved_paths: list[tuple[str, str]] = []  # (source, destination_relative)
         for code_path in code_paths:
-            # This part is to prevent users from providing code following our naming and overwrite our code.
-            if (
-                os.path.isfile(code_path) and os.path.splitext(os.path.basename(code_path))[0] == _SNOWFLAKE_PKG_NAME
-            ) or (os.path.isdir(code_path) and os.path.basename(code_path) == _SNOWFLAKE_PKG_NAME):
+            if isinstance(code_path, CodePath):
+                source, dest_relative = code_path._resolve()
+            else:
+                # String path: keep existing behavior
+                source = os.path.normpath(os.path.abspath(code_path))
+                if not os.path.exists(source):
+                    raise FileNotFoundError(f"Code path '{code_path}' does not exist (resolved to {source}).")
+                dest_relative = os.path.basename(source)
+            resolved_paths.append((source, dest_relative))
+
+        # Check for destination conflicts
+        seen: dict[str, str] = {}
+        for source, dest in resolved_paths:
+            if dest in seen:
+                raise ValueError(
+                    f"Destination path conflict: '{dest}' is targeted by both '{seen[dest]}' and '{source}'."
+                )
+            seen[dest] = source
+
+        # Copy files
+        for source, dest_relative in resolved_paths:
+            # Prevent reserved name conflicts
+            dest_name = dest_relative.split(os.sep)[0] if os.sep in dest_relative else dest_relative
+            if (os.path.isfile(source) and os.path.splitext(dest_name)[0] == _SNOWFLAKE_PKG_NAME) or (
+                os.path.isdir(source) and dest_name == _SNOWFLAKE_PKG_NAME
+            ):
                 raise ValueError("`snowflake` is a reserved name and you cannot contain that into code path.")
-            file_utils.copy_file_or_tree(code_path, code_dir_path)
+
+            parent_dir = (
+                os.path.join(code_dir_path, os.path.dirname(dest_relative))
+                if os.path.dirname(dest_relative)
+                else code_dir_path
+            )
+            os.makedirs(parent_dir, exist_ok=True)
+            file_utils.copy_file_or_tree(source, parent_dir)

     try:
         imported_modules = []
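With this change, `code_paths` accepts plain strings alongside the new `CodePath` objects, resolves each entry to a destination under the model's `code/` directory, and rejects entries that target the same destination. A hedged sketch of what a caller might pass; the `project/...` paths and the commented-out `log_model` call are illustrative, not taken from this diff:

```python
# Illustrative only: the project/... paths below are made up.
from snowflake.ml.model.code_path import CodePath

code_paths = [
    "project/utils",                                # str: copied to code/utils/
    CodePath("project/src"),                        # no filter: copied to code/src/
    CodePath("project/src", filter="lib/helpers"),  # copied to code/lib/helpers/
]

# Two entries may not resolve to the same destination; for example, adding
# CodePath("other/utils") alongside "project/utils" would raise ValueError at save time.
# reg.log_model(model, model_name="my_model", version_name="v1", code_paths=code_paths)
```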

snowflake/ml/model/_packager/model_packager.py

@@ -49,7 +49,7 @@ class ModelPackager:
         target_platforms: Optional[list[model_types.TargetPlatform]] = None,
         python_version: Optional[str] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        code_paths: Optional[list[str]] = None,
+        code_paths: Optional[list[model_types.CodePathLike]] = None,
         options: model_types.ModelSaveOption,
         task: model_types.Task = model_types.Task.UNKNOWN,
     ) -> model_meta.ModelMetadata:

snowflake/ml/model/_signatures/core.py

@@ -191,6 +191,35 @@ class DataType(Enum):
             original_exception=NotImplementedError(f"Type {snowpark_type} is not supported as a DataType."),
         )

+    @classmethod
+    def from_python_type(cls, python_type: type) -> "DataType":
+        """Translate Python built-in type to DataType for signature definition.
+
+        Args:
+            python_type: A Python built-in type (int, float, str, bool).
+
+        Raises:
+            SnowflakeMLException: NotImplementedError: Raised when the given Python type is not supported.
+
+        Returns:
+            Corresponding DataType.
+        """
+        python_to_snowml_type_mapping: dict[type, "DataType"] = {
+            int: DataType.INT64,
+            float: DataType.DOUBLE,
+            str: DataType.STRING,
+            bool: DataType.BOOL,
+        }
+        if python_type in python_to_snowml_type_mapping:
+            return python_to_snowml_type_mapping[python_type]
+        raise snowml_exceptions.SnowflakeMLException(
+            error_code=error_codes.NOT_IMPLEMENTED,
+            original_exception=NotImplementedError(
+                f"Python type {python_type} is not supported as a DataType. "
+                f"Supported types are: {list(python_to_snowml_type_mapping.keys())}."
+            ),
+        )
+

 class BaseFeatureSpec(ABC):
     """Abstract Class for specification of a feature."""
@@ -764,10 +793,17 @@ class ModelSignature:
                 the output of the model.
             params: A sequence of parameter specifications and parameter group specifications that will compose
                 the parameters of the model. Defaults to None.
+
+        Raises:
+            SnowflakeMLException: ValueError: When the parameters have duplicate names or the same
+                names as input features.
+
+        # noqa: DAR402
         """
         self._inputs = inputs
         self._outputs = outputs
         self._params = params or []
+        self._name_validation()

     @property
     def inputs(self) -> Sequence[BaseFeatureSpec]:
@@ -879,6 +915,55 @@ class ModelSignature:

         return html_utils.create_base_container("Model Signature", content)

+    def _name_validation(self) -> None:
+        """Validate the names of the inputs and parameters.
+
+        Names are compared case-insensitively (matches Snowflake identifier behavior).
+
+        Raises:
+            SnowflakeMLException: ValueError: When the parameters have duplicate names or the same
+                names as input features.
+        """
+        input_names: set[str] = set()
+        for input_spec in self._inputs:
+            names = (
+                [spec.name.upper() for spec in input_spec._specs]
+                if isinstance(input_spec, FeatureGroupSpec)
+                else [input_spec.name.upper()]
+            )
+            input_names.update(names)
+
+        param_names: set[str] = set()
+        dup_params: set[str] = set()
+        collision_names: set[str] = set()
+
+        for param in self._params:
+            names = [spec.name for spec in param.specs] if isinstance(param, ParamGroupSpec) else [param.name]
+            for name in names:
+                if name.upper() in param_names:
+                    dup_params.add(name)
+                if name.upper() in input_names:
+                    collision_names.add(name)
+                param_names.add(name.upper())
+
+        if dup_params:
+            raise snowml_exceptions.SnowflakeMLException(
+                error_code=error_codes.INVALID_ARGUMENT,
+                original_exception=ValueError(
+                    f"Found duplicate parameter(s) named {', '.join(sorted(dup_params))}."
+                    " Parameters must have distinct names (case-insensitive)."
+                ),
+            )
+
+        if collision_names:
+            raise snowml_exceptions.SnowflakeMLException(
+                error_code=error_codes.INVALID_ARGUMENT,
+                original_exception=ValueError(
+                    f"Found parameter(s) with the same name as input feature(s): {', '.join(sorted(collision_names))}."
+                    " Parameters and inputs must have distinct names (case-insensitive)."
+                ),
+            )
+
     @classmethod
     def from_mlflow_sig(cls, mlflow_sig: "mlflow.models.ModelSignature") -> "ModelSignature":
         return ModelSignature(
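The validation above is case-insensitive. A standalone sketch of the same rule in plain Python, with made-up names, showing which inputs would trigger each error:

```python
# Mirrors the checks in _name_validation: names are compared upper-cased.
input_names = {"FEATURE_A", "TEMPERATURE"}      # upper-cased input feature names
param_names: set[str] = set()

for name in ["temperature", "top_k", "TOP_K"]:  # candidate parameter names
    if name.upper() in param_names:
        print(f"duplicate parameter: {name}")          # TOP_K duplicates top_k
    if name.upper() in input_names:
        print(f"collides with input feature: {name}")  # temperature vs TEMPERATURE
    param_names.add(name.upper())
```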

snowflake/ml/model/code_path.py (new file)

@@ -0,0 +1,104 @@
+"""CodePath class for selective code packaging in model registry."""
+
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+_ERR_ROOT_NOT_FOUND = "CodePath: root '{root}' does not exist (resolved to {resolved})."
+_ERR_WILDCARDS_NOT_SUPPORTED = "CodePath: Wildcards are not supported in filter. Got '{filter}'. Use exact paths only."
+_ERR_FILTER_MUST_BE_RELATIVE = "CodePath: filter must be a relative path, got absolute path '{filter}'."
+_ERR_FILTER_HOME_PATH = "CodePath: filter must be a relative path, got home directory path '{filter}'."
+_ERR_FILTER_ON_FILE_ROOT = (
+    "CodePath: cannot apply filter to a file root. " "Root '{root}' is a file. Use filter only with directory roots."
+)
+_ERR_FILTER_ESCAPES_ROOT = (
+    "CodePath: filter '{filter}' escapes root directory '{root}'. " "Relative paths must stay within root."
+)
+_ERR_FILTER_NOT_FOUND = "CodePath: filter '{filter}' under root '{root}' does not exist (resolved to {resolved})."
+
+
+@dataclass(frozen=True)
+class CodePath:
+    """Specifies a code path with optional filtering for selective inclusion.
+
+    Args:
+        root: The root directory or file path (absolute or relative to cwd).
+        filter: Optional relative path under root to select a subdirectory or file.
+            The filter also determines the destination path under code/.
+
+    Examples:
+        CodePath("project/src/")                        # Copy entire src/ to code/src/
+        CodePath("project/src/", filter="utils")        # Copy utils/ to code/utils/
+        CodePath("project/src/", filter="lib/helpers")  # Copy to code/lib/helpers/
+    """
+
+    root: str
+    filter: Optional[str] = None
+
+    def __post_init__(self) -> None:
+        if self.filter == "":
+            object.__setattr__(self, "filter", None)
+
+    def __repr__(self) -> str:
+        if self.filter:
+            return f"CodePath({self.root!r}, filter={self.filter!r})"
+        return f"CodePath({self.root!r})"
+
+    def _validate_filter(self) -> Optional[str]:
+        """Validate and normalize filter, returning the normalized filter or None.
+
+        Returns:
+            Normalized filter path, or None if no filter is set.
+
+        Raises:
+            ValueError: If filter contains wildcards or is an absolute path.
+        """
+        if self.filter is None:
+            return None
+
+        if any(c in self.filter for c in ["*", "?", "[", "]"]):
+            raise ValueError(_ERR_WILDCARDS_NOT_SUPPORTED.format(filter=self.filter))
+
+        if self.filter.startswith("~"):
+            raise ValueError(_ERR_FILTER_HOME_PATH.format(filter=self.filter))
+
+        filter_normalized = os.path.normpath(self.filter)
+
+        if os.path.isabs(filter_normalized):
+            raise ValueError(_ERR_FILTER_MUST_BE_RELATIVE.format(filter=self.filter))
+
+        return filter_normalized
+
+    def _resolve(self) -> tuple[str, str]:
+        """Resolve the source path and destination path.
+
+        Returns:
+            Tuple of (source_path, destination_relative_path).
+
+        Raises:
+            FileNotFoundError: If the root or filter path does not exist.
+            ValueError: If filter is invalid (wildcards, absolute, escapes root, or applied to a file).
+        """
+        filter_normalized = self._validate_filter()
+        root_normalized = os.path.normpath(os.path.abspath(self.root))
+
+        if filter_normalized is None:
+            if not os.path.exists(root_normalized):
+                raise FileNotFoundError(_ERR_ROOT_NOT_FOUND.format(root=self.root, resolved=root_normalized))
+            return root_normalized, os.path.basename(root_normalized)
+
+        if not os.path.exists(root_normalized):
+            raise FileNotFoundError(_ERR_ROOT_NOT_FOUND.format(root=self.root, resolved=root_normalized))
+
+        if os.path.isfile(root_normalized):
+            raise ValueError(_ERR_FILTER_ON_FILE_ROOT.format(root=self.root))
+
+        source = os.path.normpath(os.path.join(root_normalized, filter_normalized))
+
+        if not (source.startswith(root_normalized + os.sep) or source == root_normalized):
+            raise ValueError(_ERR_FILTER_ESCAPES_ROOT.format(filter=self.filter, root=self.root))
+
+        if not os.path.exists(source):
+            raise FileNotFoundError(_ERR_FILTER_NOT_FOUND.format(filter=self.filter, root=self.root, resolved=source))
+
+        return source, filter_normalized
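A self-contained look at how `CodePath._resolve` maps `root` and `filter` to a `(source, destination)` pair. The temporary directory layout is fabricated for the example, and `_resolve` is private, so this only illustrates the mapping described in the docstring above:

```python
import os
import tempfile

from snowflake.ml.model.code_path import CodePath

with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "src", "lib", "helpers"))

    _, dest = CodePath(os.path.join(tmp, "src"))._resolve()
    print(dest)  # "src" -> contents land in code/src/

    _, dest = CodePath(os.path.join(tmp, "src"), filter="lib/helpers")._resolve()
    print(dest)  # "lib/helpers" -> contents land in code/lib/helpers/

    # A filter that escapes the root raises ValueError:
    # CodePath(os.path.join(tmp, "src"), filter="../outside")._resolve()
```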

snowflake/ml/model/custom_model.py

@@ -4,10 +4,13 @@ from typing import Any, Callable, Coroutine, Generator, Optional, Union

 import anyio
 import pandas as pd
-from typing_extensions import deprecated
+from typing_extensions import Concatenate, ParamSpec, deprecated

 from snowflake.ml.model import type_hints as model_types

+# Captures additional keyword-only parameters for inference methods
+InferenceParams = ParamSpec("InferenceParams")
+

 class MethodRef:
     """Represents a method invocation of an instance of `ModelRef`.
@@ -217,7 +220,7 @@ class CustomModel:

     def _get_infer_methods(
         self,
-    ) -> Generator[Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame], None, None]:
+    ) -> Generator[Callable[..., pd.DataFrame], None, None]:
         """Returns all methods in CLS with `inference_api` decorator as the outermost decorator."""
         for cls_method_str in dir(self):
             cls_method = getattr(self, cls_method_str)
@@ -240,7 +243,7 @@   return rv
         return rv


-def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]) -> None:
+def _validate_predict_function(func: Callable[..., pd.DataFrame]) -> None:
     """Validate the user provided predict method.

     Args:
@@ -248,19 +251,22 @@ def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.D

     Raises:
         TypeError: Raised when the method is not a callable object.
-        TypeError: Raised when the method does not have 2 arguments (self and X).
+        TypeError: Raised when the method does not have at least 2 arguments (self and X).
         TypeError: Raised when the method does not have typing annotation.
         TypeError: Raised when the method's input (X) does not have type pd.DataFrame.
         TypeError: Raised when the method's output does not have type pd.DataFrame.
+        TypeError: Raised when additional parameters are not keyword-only with defaults.
     """
     if not callable(func):
         raise TypeError("Predict method is not callable.")

     func_signature = inspect.signature(func)
-    if len(func_signature.parameters) != 2:
-        raise TypeError("Predict method should have exact 2 arguments.")
+    func_signature_params = list(func_signature.parameters.values())
+
+    if len(func_signature_params) < 2:
+        raise TypeError("Predict method should have at least 2 arguments.")

-    input_annotation = list(func_signature.parameters.values())[1].annotation
+    input_annotation = func_signature_params[1].annotation
     output_annotation = func_signature.return_annotation

     if input_annotation == inspect.Parameter.empty or output_annotation == inspect.Signature.empty:
@@ -275,17 +281,53 @@ def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.D
     ):
         raise TypeError("Output for predict method should have type pandas.DataFrame.")

+    # Validate additional parameters (beyond self and input) are keyword-only with defaults
+    for func_signature_param in func_signature_params[2:]:
+        _validate_parameter(func_signature_param)
+
+
+def _validate_parameter(param: inspect.Parameter) -> None:
+    """Validate a parameter."""
+    if param.kind != inspect.Parameter.KEYWORD_ONLY:
+        raise TypeError(f"Parameter '{param.name}' must be keyword-only (defined after '*' in signature).")
+    if param.default == inspect.Parameter.empty:
+        raise TypeError(f"Parameter '{param.name}' must have a default value.")
+    if param.annotation == inspect.Parameter.empty:
+        raise TypeError(f"Parameter '{param.name}' must have a type annotation.")
+
+    # Validate annotation is a supported type
+    supported_types = {int, float, str, bool}
+    if param.annotation not in supported_types:
+        raise TypeError(
+            f"Parameter '{param.name}' has unsupported type annotation '{param.annotation}'. "
+            f"Supported types are: int, float, str, bool"
+        )
+
+
+def get_method_parameters(func: Callable[..., Any]) -> list[tuple[str, Any, Any]]:
+    """Extract keyword-only parameters with defaults from an inference method.
+
+    Args:
+        func: The inference method.
+
+    Returns:
+        A list of tuples (name, type annotation, default value) for each keyword-only parameter.
+    """
+    func_signature = inspect.signature(func)
+    params = list(func_signature.parameters.values())
+    return [(param.name, param.annotation, param.default) for param in params[2:]]
+

 def inference_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     return func


 def partitioned_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     func.__dict__["_is_partitioned_api"] = True
     return func
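`get_method_parameters` skips the first two positional parameters (`self` and the input frame) and reports the keyword-only ones, which is the shape `_validate_parameter` enforces. A small sketch with a made-up `predict` function:

```python
import pandas as pd

from snowflake.ml.model import custom_model


def predict(self, X: pd.DataFrame, *, temperature: float = 1.0, top_k: int = 5) -> pd.DataFrame:
    return X


print(custom_model.get_method_parameters(predict))
# [('temperature', <class 'float'>, 1.0), ('top_k', <class 'int'>, 5)]
```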
@@ -296,8 +338,8 @@ def partitioned_api(
     " Use snowflake.ml.custom_model.partitioned_api instead."
 )
 def partitioned_inference_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
-) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func: Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame],
+) -> Callable[Concatenate[model_types.CustomModelType, pd.DataFrame, InferenceParams], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     func.__dict__["_is_partitioned_api"] = True
     return func
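Putting the custom_model.py changes together: a minimal, hypothetical `CustomModel` whose predict method declares an extra keyword-only parameter of a supported type, which is what the relaxed `_validate_predict_function` now accepts (the class and data are invented for illustration; registering the model still goes through the registry as usual):

```python
import pandas as pd

from snowflake.ml.model import custom_model


class ScalingModel(custom_model.CustomModel):
    @custom_model.inference_api
    def predict(self, X: pd.DataFrame, *, scale: float = 1.0) -> pd.DataFrame:
        # Extra parameters must be keyword-only, annotated with int/float/str/bool,
        # and carry a default value.
        return X * scale


model = ScalingModel(custom_model.ModelContext())
print(model.predict(pd.DataFrame({"a": [1.0, 2.0]}), scale=2.0))
```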

snowflake/ml/model/model_signature.py

@@ -34,6 +34,9 @@ DataType = core.DataType
 BaseFeatureSpec = core.BaseFeatureSpec
 FeatureSpec = core.FeatureSpec
 FeatureGroupSpec = core.FeatureGroupSpec
+BaseParamSpec = core.BaseParamSpec
+ParamSpec = core.ParamSpec
+ParamGroupSpec = core.ParamGroupSpec
 ModelSignature = core.ModelSignature


@@ -711,6 +714,7 @@ def infer_signature(
     output_feature_names: Optional[list[str]] = None,
     input_data_limit: Optional[int] = 100,
     output_data_limit: Optional[int] = 100,
+    params: Optional[Sequence[core.BaseParamSpec]] = None,
 ) -> core.ModelSignature:
     """
     Infer model signature from given input and output sample data.
@@ -740,12 +744,20 @@
         output_data_limit: Limit the number of rows to be used in signature inference in the output data. Defaults to
             100. If None, all rows are used. If the number of rows in the output data is less than the limit, all rows
             are used.
+        params: Optional sequence of parameter specifications to include in the signature. Parameters define
+            optional configuration values that can be passed to model inference. Defaults to None.
+
+    Raises:
+        SnowflakeMLException: ValueError: Raised when input data contains columns matching parameter names.

     Returns:
         A model signature inferred from the given input and output sample data.
+
+    # noqa: DAR402
     """
     inputs = _infer_signature(_truncate_data(input_data, input_data_limit), role="input")
     inputs = utils.rename_features(inputs, input_feature_names)
+
     outputs = _infer_signature(_truncate_data(output_data, output_data_limit), role="output")
     outputs = utils.rename_features(outputs, output_feature_names)
-    return core.ModelSignature(inputs, outputs)
+    return core.ModelSignature(inputs, outputs, params=params)
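`infer_signature` now forwards `params` into the resulting `ModelSignature`. A hedged sketch of the call shape; no `ParamSpec`/`ParamGroupSpec` objects are built here because their constructors are not shown in this diff, and parameter names must not collide with input column names per the validation added in core.py:

```python
import pandas as pd

from snowflake.ml.model import model_signature

df_in = pd.DataFrame({"feature_a": [1.0, 2.0], "feature_b": [3.0, 4.0]})
df_out = pd.DataFrame({"output": [0.0, 1.0]})

# params defaults to None; pass a sequence of BaseParamSpec objects
# (ParamSpec / ParamGroupSpec, re-exported above) to declare inference parameters.
sig = model_signature.infer_signature(input_data=df_in, output_data=df_out, params=None)
print(sig)
```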

snowflake/ml/model/type_hints.py

@@ -13,6 +13,7 @@ from typing import (
 import numpy.typing as npt
 from typing_extensions import NotRequired

+from snowflake.ml.model.code_path import CodePath
 from snowflake.ml.model.compute_pool import (
     DEFAULT_CPU_COMPUTE_POOL,
     DEFAULT_GPU_COMPUTE_POOL,
@@ -366,6 +367,7 @@ ModelLoadOption = Union[


 SupportedTargetPlatformType = Union[TargetPlatform, str]
+CodePathLike = Union[str, CodePath]


 class ProgressStatus(Protocol):