snowflake-ml-python 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. snowflake/ml/_internal/platform_capabilities.py +0 -4
  2. snowflake/ml/_internal/utils/mixins.py +26 -1
  3. snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
  4. snowflake/ml/data/data_connector.py +2 -2
  5. snowflake/ml/data/data_ingestor.py +2 -1
  6. snowflake/ml/experiment/_experiment_info.py +3 -3
  7. snowflake/ml/feature_store/__init__.py +2 -0
  8. snowflake/ml/feature_store/aggregation.py +367 -0
  9. snowflake/ml/feature_store/feature.py +366 -0
  10. snowflake/ml/feature_store/feature_store.py +234 -20
  11. snowflake/ml/feature_store/feature_view.py +189 -4
  12. snowflake/ml/feature_store/metadata_manager.py +425 -0
  13. snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
  14. snowflake/ml/jobs/_interop/data_utils.py +8 -8
  15. snowflake/ml/jobs/_interop/dto_schema.py +52 -7
  16. snowflake/ml/jobs/_interop/protocols.py +124 -7
  17. snowflake/ml/jobs/_interop/utils.py +92 -33
  18. snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
  19. snowflake/ml/jobs/_utils/constants.py +4 -0
  20. snowflake/ml/jobs/_utils/feature_flags.py +97 -13
  21. snowflake/ml/jobs/_utils/payload_utils.py +6 -40
  22. snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
  23. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
  24. snowflake/ml/jobs/decorators.py +17 -22
  25. snowflake/ml/jobs/job.py +25 -10
  26. snowflake/ml/jobs/job_definition.py +100 -8
  27. snowflake/ml/model/__init__.py +4 -0
  28. snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
  29. snowflake/ml/model/_client/model/model_version_impl.py +56 -28
  30. snowflake/ml/model/_client/ops/model_ops.py +2 -8
  31. snowflake/ml/model/_client/ops/service_ops.py +6 -11
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  34. snowflake/ml/model/_client/sql/service.py +21 -29
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -1
  36. snowflake/ml/model/_packager/model_handlers/huggingface.py +20 -0
  37. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +70 -14
  38. snowflake/ml/model/_signatures/utils.py +76 -1
  39. snowflake/ml/model/models/huggingface_pipeline.py +3 -0
  40. snowflake/ml/model/openai_signatures.py +154 -0
  41. snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
  42. snowflake/ml/version.py +1 -1
  43. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +79 -2
  44. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +47 -44
  45. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
  46. snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
  47. snowflake/ml/jobs/_utils/spec_utils.py +0 -22
  48. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job_definition.py

@@ -14,11 +14,14 @@ from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
 from snowflake.ml.jobs import job as jb
+from snowflake.ml.jobs._interop import utils as interop_utils
 from snowflake.ml.jobs._utils import (
+    arg_protocol,
     constants,
     feature_flags,
     payload_utils,
     query_helper,
+    runtime_env_utils,
     types,
 )
 from snowflake.snowpark import context as sp_context
@@ -40,6 +43,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         compute_pool: str,
         name: str,
         entrypoint_args: list[Any],
+        arg_protocol: Optional[arg_protocol.ArgProtocol] = arg_protocol.ArgProtocol.NONE,
+        default_args: Optional[list[Any]] = None,
         database: Optional[str] = None,
         schema: Optional[str] = None,
         session: Optional[snowpark.Session] = None,
@@ -49,12 +54,22 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         self.spec_options = spec_options
         self.compute_pool = compute_pool
         self.session = session or sp_context.get_active_session()
-        self.database = database or self.session.get_current_database()
-        self.schema = schema or self.session.get_current_schema()
+        resolved_database = database or self.session.get_current_database()
+        resolved_schema = schema or self.session.get_current_schema()
+        if resolved_database is None:
+            raise ValueError("Database must be specified either in the session context or as a parameter.")
+        if resolved_schema is None:
+            raise ValueError("Schema must be specified either in the session context or as a parameter.")
+        self.database = identifier.resolve_identifier(resolved_database)
+        self.schema = identifier.resolve_identifier(resolved_schema)
         self.job_definition_id = identifier.get_schema_level_object_identifier(self.database, self.schema, name)
         self.entrypoint_args = entrypoint_args
+        self.arg_protocol = arg_protocol
+        self.default_args = default_args

     def delete(self) -> None:
+        if self.session is None:
+            raise RuntimeError("Session is required to delete job definition")
         if self.stage_name:
             try:
                 self.session.sql(f"REMOVE {self.stage_name}/").collect()
@@ -62,9 +77,27 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             except Exception as e:
                 logger.warning(f"Failed to clean up stage files for job definition {self.stage_name}: {e}")

-    def _prepare_arguments(self, *args: _Args.args, **kwargs: _Args.kwargs) -> list[Any]:
-        # TODO: Add ArgProtocol and respective logics
-        return [arg for arg in args]
+    def _prepare_arguments(self, *args: _Args.args, **kwargs: _Args.kwargs) -> Optional[list[Any]]:
+        if self.arg_protocol == arg_protocol.ArgProtocol.NONE:
+            if len(kwargs) > 0:
+                raise ValueError(f"Keyword arguments are not supported with {self.arg_protocol}")
+            return list(args)
+        elif self.arg_protocol == arg_protocol.ArgProtocol.CLI:
+            return _combine_runtime_arguments(self.default_args, *args, **kwargs)
+        elif self.arg_protocol == arg_protocol.ArgProtocol.PICKLE:
+            if not args and not kwargs:
+                return []
+            uid = uuid4().hex[:8]
+            rel_path = f"{uid}/function_args"
+            file_path = f"{self.stage_name}/{constants.APP_STAGE_SUBPATH}/{rel_path}"
+            payload = interop_utils.save_result(
+                (args, kwargs), file_path, session=self.session, max_inline_size=interop_utils._MAX_INLINE_SIZE
+            )
+            if payload is not None:
+                return [f"--function_args={payload.decode('utf-8')}"]
+            return [f"--function_args={rel_path}"]
+        else:
+            raise ValueError(f"Invalid arg_protocol: {self.arg_protocol}")

     @telemetry.send_api_usage_telemetry(project=_PROJECT)
     def __call__(self, *args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
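
For reference, a rough sketch of what the three protocols produce for a call like definition(42, mode="fast"). The values below are illustrative only, derived from the branches above; they are not emitted by the package itself:

    # ArgProtocol.NONE   -> raises ValueError (keyword arguments are not supported)
    # ArgProtocol.CLI    -> positional args first, then one "--key value" pair per kwarg
    # ArgProtocol.PICKLE -> (args, kwargs) is serialized and referenced via a single flag
    expected = {
        "CLI": [42, "--mode", "fast"],
        "PICKLE": ["--function_args=<inline payload or stage-relative path>"],
    }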
@@ -98,6 +131,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             json.dumps(job_options_dict),
         ]
         query_template = "CALL SYSTEM$EXECUTE_ML_JOB(%s, %s, %s, %s)"
+        assert self.session is not None, "Session is required to generate MLJob SQL query"
         sql = self.session._conn._cursor._preprocess_pyformat_query(query_template, params)
         return sql

@@ -123,6 +157,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         entrypoint: Optional[Union[str, list[str]]] = None,
         target_instances: int = 1,
         generate_suffix: bool = True,
+        arg_protocol: Optional[arg_protocol.ArgProtocol] = arg_protocol.ArgProtocol.NONE,
         **kwargs: Any,
     ) -> "MLJobDefinition[_Args, _ReturnValue]":
         # Use kwargs for less common optional parameters
@@ -142,6 +177,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         )
         overwrite = kwargs.pop("overwrite", False)
         name = kwargs.pop("name", None)
+        default_args = kwargs.pop("default_args", None)
         # Warn if there are unknown kwargs
         if kwargs:
             logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
@@ -149,6 +185,11 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         # Validate parameters
         if database and not schema:
             raise ValueError("Schema must be specified if database is specified.")
+
+        compute_pool = identifier.resolve_identifier(compute_pool)
+        if query_warehouse is not None:
+            query_warehouse = identifier.resolve_identifier(query_warehouse)
+
         if target_instances < 1:
             raise ValueError("target_instances must be greater than 0.")
         if not (0 < min_instances <= target_instances):
@@ -190,10 +231,11 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
            )
            raise

-        if runtime_environment is None and feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled(default=True):
+        if runtime_environment is None and feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
            # Pass a JSON object for runtime versions so it serializes as nested JSON in options
            runtime_environment = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})

+        runtime = runtime_env_utils.get_runtime_image(session, compute_pool, runtime_environment)
        combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
        entrypoint_args = [v.as_posix() if isinstance(v, PurePath) else v for v in uploaded_payload.entrypoint]
        spec_options = types.SpecOptions(
@@ -203,8 +245,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
            env_vars=combined_env_vars,
            enable_metrics=enable_metrics,
            spec_overrides=spec_overrides,
-            runtime=runtime_environment if runtime_environment else None,
-            enable_stage_mount_v2=feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(default=True),
+            runtime=runtime,
+            enable_stage_mount_v2=feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(),
        )

        job_options = types.JobOptions(
@@ -222,6 +264,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
            compute_pool=compute_pool,
            entrypoint_args=entrypoint_args,
            session=session,
+            arg_protocol=arg_protocol,
+            default_args=default_args,
            database=database,
            schema=schema,
            name=name,
@@ -230,3 +274,51 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):

 def _generate_suffix() -> str:
     return str(uuid4().hex)[:8]
+
+
+def _combine_runtime_arguments(
+    default_runtime_args: Optional[list[Any]] = None, *args: Any, **kwargs: Any
+) -> list[Any]:
+    """Merge default CLI arguments with runtime overrides into a flat argument list.
+
+    Parses `default_runtime_args` for flags (e.g., `--key value`) and merges them with
+    `kwargs`. Keyword arguments override defaults unless their value is None. Positional
+    arguments from both `default_args` and `*args` are preserved in order.
+
+    Args:
+        default_runtime_args: Optional list of default CLI arguments to parse for flags and positional args.
+        *args: Additional positional arguments to include in the output.
+        **kwargs: Keyword arguments that override default flags. Values of None are ignored.
+
+    Returns:
+        A list of CLI-style arguments: positional args followed by `--key value` pairs.
+    """
+    cli_args = list(args)
+    flags: dict[str, Any] = {}
+    if default_runtime_args:
+        i = 0
+        while i < len(default_runtime_args):
+            arg = default_runtime_args[i]
+            if isinstance(arg, str) and arg.startswith("--"):
+                key = arg[2:]
+                # Check if next arg is a value (not a flag)
+                if i + 1 < len(default_runtime_args):
+                    next_arg = default_runtime_args[i + 1]
+                    if not (isinstance(next_arg, str) and next_arg.startswith("--")):
+                        flags[key] = next_arg
+                        i += 2
+                        continue
+
+                flags[key] = None
+            else:
+                cli_args.append(arg)
+            i += 1
+    # Prioritize kwargs over default_args. Explicit None values in kwargs
+    # serve as overrides and are converted to the string "None" to match
+    # CLI flag conventions (--key=value)
+    # Downstream logic must handle the parsing of these string-based nulls.
+    for k, v in kwargs.items():
+        flags[k] = v
+    for k, v in flags.items():
+        cli_args.extend([f"--{k}", str(v)])
+    return cli_args
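
The merge order matters: runtime positional args come first, default positional args follow, and keyword arguments win over any matching flag parsed from the defaults. A minimal sketch, assuming the module is importable as snowflake.ml.jobs.job_definition (this is a private helper, not public API):

    from snowflake.ml.jobs.job_definition import _combine_runtime_arguments

    merged = _combine_runtime_arguments(
        ["input.csv", "--epochs", "5", "--verbose"],  # default_args stored on the definition
        "run-2024",                                   # extra positional supplied at call time
        epochs=10,                                    # overrides the default --epochs flag
    )
    # -> ['run-2024', 'input.csv', '--epochs', '10', '--verbose', 'None']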
snowflake/ml/model/__init__.py

@@ -4,6 +4,8 @@ import warnings
 from snowflake.ml.model._client.model.batch_inference_specs import (
     ColumnHandlingOptions,
     FileEncoding,
+    InputFormat,
+    InputSpec,
     JobSpec,
     OutputSpec,
     SaveMode,
@@ -20,6 +22,8 @@ __all__ = [
     "ModelVersion",
     "ExportMode",
     "HuggingFacePipelineModel",
+    "InputSpec",
+    "InputFormat",
     "JobSpec",
     "OutputSpec",
     "SaveMode",
snowflake/ml/model/_client/model/batch_inference_specs.py

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Optional
+from typing import Any, Optional

 from pydantic import BaseModel
 from typing_extensions import TypedDict
@@ -19,6 +19,12 @@ class SaveMode(str, Enum):
     ERROR = "error"


+class InputFormat(str, Enum):
+    """The format of the input column data."""
+
+    FULL_STAGE_PATH = "full_stage_path"
+
+
 class FileEncoding(str, Enum):
     """The encoding of the file content that will be passed to the custom model."""

@@ -30,7 +36,37 @@ class FileEncoding(str, Enum):
 class ColumnHandlingOptions(TypedDict):
     """Options for handling specific columns during run_batch for file I/O."""

-    encoding: FileEncoding
+    input_format: InputFormat
+    convert_to: FileEncoding
+
+
+class InputSpec(BaseModel):
+    """Specification for batch inference input options.
+
+    Defines optional configuration for processing input data during batch inference.
+
+    Attributes:
+        params (Optional[dict[str, Any]]): Optional dictionary of model inference parameters
+            (e.g., temperature, top_k for LLMs). These are passed as keyword arguments to the
+            model's inference method. Defaults to None.
+        column_handling (Optional[dict[str, ColumnHandlingOptions]]): Optional dictionary
+            specifying how to handle specific columns during file I/O. Maps column names to their
+            input format and file encoding configuration.
+
+    Example:
+        >>> input_spec = InputSpec(
+        ...     params={"temperature": 0.7, "top_k": 50},
+        ...     column_handling={
+        ...         "image_col": {
+        ...             "input_format": InputFormat.FULL_STAGE_PATH,
+        ...             "convert_to": FileEncoding.BASE64
+        ...         }
+        ...     }
+        ... )
+    """
+
+    params: Optional[dict[str, Any]] = None
+    column_handling: Optional[dict[str, ColumnHandlingOptions]] = None


 class OutputSpec(BaseModel):
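
Note that the single encoding key of ColumnHandlingOptions is replaced by the input_format / convert_to pair. A minimal sketch of the new shape (the column name "image_col" and the parameter values are illustrative):

    from snowflake.ml.model._client.model.batch_inference_specs import (
        FileEncoding,
        InputFormat,
        InputSpec,
    )

    spec = InputSpec(
        params={"temperature": 0.7},
        column_handling={
            "image_col": {
                "input_format": InputFormat.FULL_STAGE_PATH,  # how the column references the file
                "convert_to": FileEncoding.BASE64,            # replaces the former "encoding" key
            }
        },
    )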
snowflake/ml/model/_client/model/model_version_impl.py

@@ -33,6 +33,12 @@ _BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
 VLLM_SUPPORTED_TASKS = [
     "text-generation",
     "image-text-to-text",
+    "video-text-to-text",
+    "audio-text-to-text",
+]
+VALID_OPENAI_SIGNATURES = [
+    openai_signatures.OPENAI_CHAT_SIGNATURE,
+    openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
 ]


@@ -661,13 +667,12 @@ class ModelVersion(lineage_node.LineageNode):
     @snowpark._internal.utils.private_preview(version="1.18.0")
     def run_batch(
         self,
+        X: dataframe.DataFrame,
         *,
         compute_pool: str,
-        input_spec: dataframe.DataFrame,
+        input_spec: Optional[batch_inference_specs.InputSpec] = None,
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
-        params: Optional[dict[str, Any]] = None,
-        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]] = None,
         inference_engine_options: Optional[dict[str, Any]] = None,
     ) -> job.MLJob[Any]:
         """Execute batch inference on datasets as an SPCS job.
@@ -675,19 +680,16 @@ class ModelVersion(lineage_node.LineageNode):
         Args:
             compute_pool (str): Name of the compute pool to use for building the image containers and batch
                 inference execution.
-            input_spec (dataframe.DataFrame): Snowpark DataFrame containing the input data for inference.
+            X (dataframe.DataFrame): Snowpark DataFrame containing the input data for inference.
                 The DataFrame should contain all required features for model prediction and passthrough columns.
             output_spec (batch_inference_specs.OutputSpec): Configuration for where and how to save
                 the inference results. Specifies the stage location and file handling behavior.
+            input_spec (Optional[batch_inference_specs.InputSpec]): Optional configuration for input
+                processing including model inference parameters and column handling options.
+                If None, default values will be used for params and column_handling.
             job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
                 execution parameters such as compute resources, worker counts, and job naming.
                 If None, default values will be used.
-            params (Optional[dict[str, Any]]): Optional dictionary of model inference parameters
-                (e.g., temperature, top_k for LLMs). These are passed as keyword arguments to the
-                model's inference method. Defaults to None.
-            column_handling (Optional[dict[str, batch_inference_specs.FileEncoding]]): Optional dictionary
-                specifying how to handle specific columns during file I/O. Maps column names to their
-                file encoding configuration.
             inference_engine_options: Options for the service creation with custom inference engine.
                 Supports `engine` and `engine_args_override`.
                 `engine` is the type of the inference engine to use.
@@ -699,7 +701,7 @@ class ModelVersion(lineage_node.LineageNode):

         Raises:
             ValueError: If warehouse is not set in job_spec and no current warehouse is available.
-            RuntimeError: If the input_spec cannot be processed or written to the staging location.
+            RuntimeError: If the input data cannot be processed or written to the staging location.

         Example:
             >>> # Prepare input data - Example 1: From a table
@@ -732,10 +734,24 @@ class ModelVersion(lineage_node.LineageNode):
             >>> # Run batch inference
             >>> job = model_version.run_batch(
             ...     compute_pool="my_compute_pool",
-            ...     input_spec=input_df,
+            ...     X=input_df,
             ...     output_spec=output_spec,
             ...     job_spec=job_spec
             ... )
+            >>>
+            >>> # Run batch inference with InputSpec for additional options
+            >>> from snowflake.ml.model._client.model.batch_inference_specs import InputSpec, FileEncoding
+            >>> input_spec = InputSpec(
+            ...     params={"temperature": 0.7, "top_k": 50},
+            ...     column_handling={"image_col": {"encoding": FileEncoding.BASE64}}
+            ... )
+            >>> job = model_version.run_batch(
+            ...     compute_pool="my_compute_pool",
+            ...     X=input_df,
+            ...     output_spec=output_spec,
+            ...     input_spec=input_spec,
+            ...     job_spec=job_spec
+            ... )

         Note:
             This method is currently in private preview and requires Snowflake version 1.18.0 or later.
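
For callers upgrading from 1.23.x, the DataFrame that used to be passed as input_spec is now the X argument, and the former params / column_handling keyword arguments are wrapped in an InputSpec. A rough before/after sketch (model_version, input_df, and output_spec are placeholders from the example above):

    from snowflake.ml.model import InputSpec

    # 1.23.x
    # job = model_version.run_batch(
    #     compute_pool="my_compute_pool", input_spec=input_df,
    #     output_spec=output_spec, params={"temperature": 0.7})

    # 1.25.x
    job = model_version.run_batch(
        X=input_df,
        compute_pool="my_compute_pool",
        output_spec=output_spec,
        input_spec=InputSpec(params={"temperature": 0.7}),
    )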
@@ -747,6 +763,13 @@ class ModelVersion(lineage_node.LineageNode):
             subproject=_TELEMETRY_SUBPROJECT,
         )

+        # Extract params and column_handling from input_spec if provided
+        if input_spec is None:
+            input_spec = batch_inference_specs.InputSpec()
+
+        params = input_spec.params
+        column_handling = input_spec.column_handling
+
         if job_spec is None:
             job_spec = batch_inference_specs.JobSpec()

@@ -772,10 +795,10 @@ class ModelVersion(lineage_node.LineageNode):
         self._service_ops._enforce_save_mode(output_spec.mode, output_stage_location)

         try:
-            input_spec.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
+            X.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
             # todo: be specific about the type of errors to provide better error messages.
         except Exception as e:
-            raise RuntimeError(f"Failed to process input_spec: {e}")
+            raise RuntimeError(f"Failed to process input data: {e}")

         if job_spec.job_name is None:
             # Same as the MLJob ID generation logic with a different prefix
@@ -1123,16 +1146,11 @@ class ModelVersion(lineage_node.LineageNode):
             func_name: core.ModelSignature.from_dict(sig_dict) for func_name, sig_dict in signatures_dict.items()
         }

-        if deserialized_signatures not in [
-            openai_signatures.OPENAI_CHAT_SIGNATURE,
-            openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
-        ]:
+        if deserialized_signatures not in VALID_OPENAI_SIGNATURES:
             raise ValueError(
-                "Inference engine requires the model to be logged with openai_signatures.OPENAI_CHAT_SIGNATURE or "
-                "openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING. "
+                "Inference engine requires the model to be logged with one of the following signatures: "
+                f"{VALID_OPENAI_SIGNATURES}. Please log the model again with one of these supported signatures."
                 f"Found signatures: {signatures_dict}. "
-                "Please log the model again with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE or "
-                "signatures=openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING"
             )

     @overload
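
When the check above fails, the model has to be re-logged with one of the accepted OpenAI chat signatures. A minimal sketch, assuming an existing registry handle reg and a Hugging Face pipeline wrapper pipe (both placeholders):

    from snowflake.ml.model import openai_signatures

    mv = reg.log_model(
        pipe,
        model_name="my_llm",
        version_name="v2",
        signatures=openai_signatures.OPENAI_CHAT_SIGNATURE,
    )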
@@ -1144,6 +1162,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -1170,8 +1189,10 @@ class ModelVersion(lineage_node.LineageNode):
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-            max_instances: The maximum number of inference service instances to run. The same value it set to
-                MIN_INSTANCES property of the service.
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
@@ -1207,6 +1228,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -1233,8 +1255,10 @@ class ModelVersion(lineage_node.LineageNode):
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-            max_instances: The maximum number of inference service instances to run. The same value it set to
-                MIN_INSTANCES property of the service.
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
@@ -1284,6 +1308,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -1311,8 +1336,10 @@ class ModelVersion(lineage_node.LineageNode):
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-            max_instances: The maximum number of inference service instances to run. The same value it set to
-                MIN_INSTANCES property of the service.
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
@@ -1402,6 +1429,7 @@ class ModelVersion(lineage_node.LineageNode):
             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
             image_repo_name=image_repo,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu_requests=cpu_requests,
             memory_requests=memory_requests,
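
A minimal sketch of the new autoscaling knobs (assumes an existing ModelVersion handle mv; the service, pool, and instance counts are placeholders):

    mv.create_service(
        service_name="my_service",
        service_compute_pool="my_gpu_pool",
        ingress_enabled=True,
        min_instances=0,   # 0 (the default) lets the service suspend when idle
        max_instances=3,   # upper bound for traffic-based scale-out
    )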
snowflake/ml/model/_client/ops/model_ops.py

@@ -10,7 +10,6 @@ from typing import Any, Literal, Optional, TypedDict, Union, cast, overload
 import yaml
 from typing_extensions import NotRequired

-from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.utils import formatting, identifier, sql_identifier, url
 from snowflake.ml.model import model_signature, type_hints
@@ -698,9 +697,6 @@ class ModelOperator:

         result: list[ServiceInfo] = []
         is_privatelink_connection = self._is_privatelink_connection()
-        is_autocapture_param_enabled = (
-            platform_capabilities.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
-        )

         for fully_qualified_service_name in fully_qualified_service_names:
             port: Optional[int] = None
@@ -742,10 +738,8 @@ class ModelOperator:
                inference_endpoint=inference_endpoint,
                internal_endpoint=f"http://{internal_dns}:{port}" if port is not None else None,
            )
-            if is_autocapture_param_enabled and self._service_client.DESC_SERVICE_SPEC_COL_NAME in service_description:
-                # Include column only if parameter is enabled and spec exists for service owner caller
-                autocapture_enabled = self._service_client.get_proxy_container_autocapture(service_description)
-                service_info["autocapture_enabled"] = autocapture_enabled
+            autocapture_enabled = self._service_client.is_autocapture_enabled(service_description)
+            service_info["autocapture_enabled"] = autocapture_enabled

            result.append(service_info)

snowflake/ml/model/_client/ops/service_ops.py

@@ -155,7 +155,6 @@ class ServiceOperator:
         self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
             workspace_path=pathlib.Path(self._workspace.name)
         )
-        self._inference_autocapture_enabled = pc.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()

     def __eq__(self, __value: object) -> bool:
         if not isinstance(__value, ServiceOperator):
@@ -176,6 +175,7 @@ class ServiceOperator:
         service_compute_pool_name: sql_identifier.SqlIdentifier,
         image_repo_name: Optional[str],
         ingress_enabled: bool,
+        min_instances: int,
         max_instances: int,
         cpu_requests: Optional[str],
         memory_requests: Optional[str],
@@ -216,10 +216,6 @@ class ServiceOperator:
         progress_status.update("preparing deployment artifacts...")
         progress_status.increment()

-        # If autocapture param is disabled, don't allow create service with autocapture
-        if not self._inference_autocapture_enabled and autocapture:
-            raise ValueError("Invalid Argument: Autocapture feature is not supported.")
-
         if self._workspace:
             stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
         else:
@@ -246,6 +242,7 @@ class ServiceOperator:
             service_name=service_name,
             inference_compute_pool_name=service_compute_pool_name,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu=cpu_requests,
             memory=memory_requests,
@@ -834,15 +831,13 @@ class ServiceOperator:
         service_seen_before = False

         while True:
-            # Check if async job has failed (but don't return on success - we need specific service status)
+            # Check if async job has completed
             if async_job.is_done():
                 try:
                     async_job.result()
-                    # Async job completed successfully, but we're waiting for a specific service status
-                    # This might mean the service completed and was cleaned up
-                    module_logger.debug(
-                        f"Async job completed but we're still waiting for {service_name} to reach {target_status.value}"
-                    )
+                    # Async job completed successfully - deployment is done
+                    module_logger.debug(f"Async job completed successfully, returning from wait for {service_name}")
+                    return
                 except Exception as e:
                     raise RuntimeError(f"Service deployment failed: {e}")
snowflake/ml/model/_client/service/model_deployment_spec.py

@@ -140,6 +140,7 @@ class ModelDeploymentSpec:
         service_database_name: Optional[sql_identifier.SqlIdentifier] = None,
         service_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
         ingress_enabled: bool = True,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu: Optional[str] = None,
         memory: Optional[str] = None,
@@ -156,6 +157,7 @@ class ModelDeploymentSpec:
             service_database_name: Database name for the service.
             service_schema_name: Schema name for the service.
             ingress_enabled: Whether ingress is enabled.
+            min_instances: Minimum number of service instances.
             max_instances: Maximum number of service instances.
             cpu: CPU requirement.
             memory: Memory requirement.
@@ -187,6 +189,7 @@ class ModelDeploymentSpec:
             name=fq_service_name,
             compute_pool=inference_compute_pool_name.identifier(),
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             autocapture=autocapture,
             **self._inference_spec,
snowflake/ml/model/_client/service/model_deployment_spec_schema.py

@@ -26,6 +26,7 @@ class Service(BaseModel):
     name: str
     compute_pool: str
     ingress_enabled: bool
+    min_instances: int
     max_instances: int
     cpu: Optional[str] = None
     memory: Optional[str] = None