snowflake-ml-python 1.24.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. snowflake/ml/_internal/utils/mixins.py +26 -1
  2. snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
  3. snowflake/ml/data/data_connector.py +2 -2
  4. snowflake/ml/data/data_ingestor.py +2 -1
  5. snowflake/ml/experiment/_experiment_info.py +3 -3
  6. snowflake/ml/jobs/_interop/data_utils.py +8 -8
  7. snowflake/ml/jobs/_interop/dto_schema.py +52 -7
  8. snowflake/ml/jobs/_interop/protocols.py +124 -7
  9. snowflake/ml/jobs/_interop/utils.py +92 -33
  10. snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
  11. snowflake/ml/jobs/_utils/constants.py +4 -0
  12. snowflake/ml/jobs/_utils/feature_flags.py +97 -13
  13. snowflake/ml/jobs/_utils/payload_utils.py +6 -40
  14. snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
  15. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
  16. snowflake/ml/jobs/decorators.py +17 -22
  17. snowflake/ml/jobs/job.py +25 -10
  18. snowflake/ml/jobs/job_definition.py +100 -8
  19. snowflake/ml/model/_client/model/model_version_impl.py +25 -14
  20. snowflake/ml/model/_client/ops/service_ops.py +6 -6
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  22. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  23. snowflake/ml/model/models/huggingface_pipeline.py +3 -0
  24. snowflake/ml/model/openai_signatures.py +154 -0
  25. snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
  26. snowflake/ml/version.py +1 -1
  27. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +41 -2
  28. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +31 -32
  29. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
  30. snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
  31. snowflake/ml/jobs/_utils/spec_utils.py +0 -22
  32. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
  33. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/decorators.py CHANGED
@@ -1,13 +1,12 @@
 import copy
-import functools
 from typing import Any, Callable, Optional, TypeVar
 
 from typing_extensions import ParamSpec
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
-from snowflake.ml.jobs import job as jb, manager as jm
-from snowflake.ml.jobs._utils import payload_utils
+from snowflake.ml.jobs import job_definition as jd
+from snowflake.ml.jobs._utils import arg_protocol, constants
 
 _PROJECT = "MLJob"
 
@@ -25,7 +24,7 @@ def remote(
     external_access_integrations: Optional[list[str]] = None,
     session: Optional[snowpark.Session] = None,
     **kwargs: Any,
-) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
+) -> Callable[[Callable[_Args, _ReturnValue]], jd.MLJobDefinition[_Args, _ReturnValue]]:
     """
     Submit a job to the compute pool.
 
@@ -51,29 +50,25 @@ def remote(
         Decorator that dispatches invocations of the decorated function as remote jobs.
     """
 
-    def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob[_ReturnValue]]:
+    def decorator(func: Callable[_Args, _ReturnValue]) -> jd.MLJobDefinition[_Args, _ReturnValue]:
         # Copy the function to avoid modifying the original
         # We need to modify the line number of the function to exclude the
         # decorator from the copied source code
         wrapped_func = copy.copy(func)
        wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
 
-        @functools.wraps(func)
-        def wrapper(*_args: _Args.args, **_kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
-            payload = payload_utils.create_function_payload(func, *_args, **_kwargs)
-            job = jm._submit_job(
-                source=payload,
-                stage_name=stage_name,
-                compute_pool=compute_pool,
-                target_instances=target_instances,
-                pip_requirements=pip_requirements,
-                external_access_integrations=external_access_integrations,
-                session=payload.session or session,
-                **kwargs,
-            )
-            assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
-            return job
-
-        return wrapper
+        setattr(wrapped_func, constants.IS_MLJOB_REMOTE_ATTR, True)
+        return jd.MLJobDefinition.register(
+            source=wrapped_func,
+            compute_pool=compute_pool,
+            stage_name=stage_name,
+            target_instances=target_instances,
+            pip_requirements=pip_requirements,
+            external_access_integrations=external_access_integrations,
+            session=session or snowpark.context.get_active_session(),
+            arg_protocol=arg_protocol.ArgProtocol.PICKLE,
+            generate_suffix=True,
+            **kwargs,
+        )
 
     return decorator
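
The net effect is that `@remote` now returns an `MLJobDefinition` rather than a per-call wrapper function, though invoking the decorated function still dispatches a job. A minimal usage sketch under that reading (pool and stage names are placeholders, and an active Snowpark session is assumed to exist):

```python
# Hedged sketch of the new decorator semantics; names are placeholders.
from snowflake.ml.jobs import remote

@remote("MY_COMPUTE_POOL", stage_name="payload_stage")
def train(n_epochs: int) -> float:
    return 0.95  # stand-in for real training logic

job = train(10)      # MLJobDefinition.__call__ submits an MLJob
print(job.result())  # blocks until completion, then loads the pickled result
```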
snowflake/ml/jobs/job.py CHANGED
@@ -123,26 +123,41 @@ class MLJob(Generic[T], SerializableSessionMixin):
 
         return self._transform_path(result_path_str)
 
-    def _transform_path(self, path_str: str) -> str:
+    # After introducing ML Job definitions, we have additional stage mount for result path
+    # the result path is like @payload_stage/{job_definition_name}/{job_name}/mljob_result
+    @property
+    def _result_stage_path(self) -> Optional[str]:
+        volumes = self._service_spec["spec"]["volumes"]
+        stage_volume = next((v for v in volumes if v["name"] == constants.RESULT_VOLUME_NAME), None)
+        if stage_volume is None:
+            return self._stage_path
+        elif "stageConfig" in stage_volume:
+            return cast(str, stage_volume["stageConfig"]["name"])
+        else:
+            return cast(str, stage_volume["source"])
+
+    def _transform_path(
+        self,
+        path_str: str,
+    ) -> str:
         """Transform a local path within the container to a stage path."""
         path = stage_utils.resolve_path(path_str)
         if isinstance(path, stage_utils.StagePath):
-            # Stage paths need no transformation
             return path.as_posix()
         if not path.is_absolute():
-            # Assume relative paths are relative to stage mount path
-            return f"{self._stage_path}/{path.as_posix()}"
+            return f"{self._result_stage_path}/{path.as_posix()}"
 
-        # If result path is absolute, rebase it onto the stage mount path
-        # TODO: Rather than matching by name, use the longest mount path which matches
         volume_mounts = self._container_spec["volumeMounts"]
-        stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
+        stage_volume = next((v for v in volume_mounts if v["name"] == constants.RESULT_VOLUME_NAME), None)
+        if stage_volume is None:
+            stage_volume = next(v for v in volume_mounts if v["name"] == constants.STAGE_VOLUME_NAME)
+        stage_mount_str = stage_volume["mountPath"]
         stage_mount = Path(stage_mount_str)
         try:
             relative_path = path.relative_to(stage_mount)
-            return f"{self._stage_path}/{relative_path.as_posix()}"
+            return f"{self._result_stage_path}/{relative_path.as_posix()}"
         except ValueError:
-            raise ValueError(f"Result path {path} is absolute, but should be relative to stage mount {stage_mount}")
+            raise ValueError(f"Result Path {path} is absolute, but should be relative to stage mount {stage_mount}")
 
     @overload
     def get_logs(
@@ -279,7 +294,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
         if self._result is None:
             self.wait(timeout)
         try:
-            self._result = interop_utils.load_result(
+            self._result = interop_utils.load(
                 self._result_path, session=self._session, path_transform=self._transform_path
             )
         except Exception as e:
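
The path handling above is the core of this change: absolute container paths are rebased from the stage mount onto the result stage path, while relative paths are joined to it directly. A standalone sketch of that logic, with hypothetical mount and stage values:

```python
# Isolated illustration of the rebasing performed by MLJob._transform_path;
# not the actual method, and the paths below are made up.
from pathlib import Path

def rebase(path_str: str, stage_mount: str, result_stage_path: str) -> str:
    path = Path(path_str)
    if not path.is_absolute():
        # Relative paths are assumed to be relative to the result stage path.
        return f"{result_stage_path}/{path.as_posix()}"
    # Absolute container paths are rebased onto the stage mount.
    relative_path = path.relative_to(stage_mount)
    return f"{result_stage_path}/{relative_path.as_posix()}"

print(rebase("/mnt/job_stage/mljob_result", "/mnt/job_stage", "@payload_stage/my_def/run_01"))
# -> @payload_stage/my_def/run_01/mljob_result
```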
snowflake/ml/jobs/job_definition.py CHANGED
@@ -14,11 +14,14 @@ from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
 from snowflake.ml.jobs import job as jb
+from snowflake.ml.jobs._interop import utils as interop_utils
 from snowflake.ml.jobs._utils import (
+    arg_protocol,
     constants,
     feature_flags,
     payload_utils,
     query_helper,
+    runtime_env_utils,
     types,
 )
 from snowflake.snowpark import context as sp_context
@@ -40,6 +43,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         compute_pool: str,
         name: str,
         entrypoint_args: list[Any],
+        arg_protocol: Optional[arg_protocol.ArgProtocol] = arg_protocol.ArgProtocol.NONE,
+        default_args: Optional[list[Any]] = None,
         database: Optional[str] = None,
         schema: Optional[str] = None,
         session: Optional[snowpark.Session] = None,
@@ -49,12 +54,22 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         self.spec_options = spec_options
         self.compute_pool = compute_pool
         self.session = session or sp_context.get_active_session()
-        self.database = database or self.session.get_current_database()
-        self.schema = schema or self.session.get_current_schema()
+        resolved_database = database or self.session.get_current_database()
+        resolved_schema = schema or self.session.get_current_schema()
+        if resolved_database is None:
+            raise ValueError("Database must be specified either in the session context or as a parameter.")
+        if resolved_schema is None:
+            raise ValueError("Schema must be specified either in the session context or as a parameter.")
+        self.database = identifier.resolve_identifier(resolved_database)
+        self.schema = identifier.resolve_identifier(resolved_schema)
         self.job_definition_id = identifier.get_schema_level_object_identifier(self.database, self.schema, name)
         self.entrypoint_args = entrypoint_args
+        self.arg_protocol = arg_protocol
+        self.default_args = default_args
 
     def delete(self) -> None:
+        if self.session is None:
+            raise RuntimeError("Session is required to delete job definition")
         if self.stage_name:
             try:
                 self.session.sql(f"REMOVE {self.stage_name}/").collect()
@@ -62,9 +77,27 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             except Exception as e:
                 logger.warning(f"Failed to clean up stage files for job definition {self.stage_name}: {e}")
 
-    def _prepare_arguments(self, *args: _Args.args, **kwargs: _Args.kwargs) -> list[Any]:
-        # TODO: Add ArgProtocol and respective logics
-        return [arg for arg in args]
+    def _prepare_arguments(self, *args: _Args.args, **kwargs: _Args.kwargs) -> Optional[list[Any]]:
+        if self.arg_protocol == arg_protocol.ArgProtocol.NONE:
+            if len(kwargs) > 0:
+                raise ValueError(f"Keyword arguments are not supported with {self.arg_protocol}")
+            return list(args)
+        elif self.arg_protocol == arg_protocol.ArgProtocol.CLI:
+            return _combine_runtime_arguments(self.default_args, *args, **kwargs)
+        elif self.arg_protocol == arg_protocol.ArgProtocol.PICKLE:
+            if not args and not kwargs:
+                return []
+            uid = uuid4().hex[:8]
+            rel_path = f"{uid}/function_args"
+            file_path = f"{self.stage_name}/{constants.APP_STAGE_SUBPATH}/{rel_path}"
+            payload = interop_utils.save_result(
+                (args, kwargs), file_path, session=self.session, max_inline_size=interop_utils._MAX_INLINE_SIZE
+            )
+            if payload is not None:
+                return [f"--function_args={payload.decode('utf-8')}"]
+            return [f"--function_args={rel_path}"]
+        else:
+            raise ValueError(f"Invalid arg_protocol: {self.arg_protocol}")
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT)
     def __call__(self, *args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
@@ -98,6 +131,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             json.dumps(job_options_dict),
         ]
         query_template = "CALL SYSTEM$EXECUTE_ML_JOB(%s, %s, %s, %s)"
+        assert self.session is not None, "Session is required to generate MLJob SQL query"
         sql = self.session._conn._cursor._preprocess_pyformat_query(query_template, params)
         return sql
 
@@ -123,6 +157,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         entrypoint: Optional[Union[str, list[str]]] = None,
         target_instances: int = 1,
         generate_suffix: bool = True,
+        arg_protocol: Optional[arg_protocol.ArgProtocol] = arg_protocol.ArgProtocol.NONE,
         **kwargs: Any,
     ) -> "MLJobDefinition[_Args, _ReturnValue]":
         # Use kwargs for less common optional parameters
@@ -142,6 +177,7 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         )
         overwrite = kwargs.pop("overwrite", False)
         name = kwargs.pop("name", None)
+        default_args = kwargs.pop("default_args", None)
         # Warn if there are unknown kwargs
         if kwargs:
             logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
@@ -149,6 +185,11 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
         # Validate parameters
         if database and not schema:
             raise ValueError("Schema must be specified if database is specified.")
+
+        compute_pool = identifier.resolve_identifier(compute_pool)
+        if query_warehouse is not None:
+            query_warehouse = identifier.resolve_identifier(query_warehouse)
+
         if target_instances < 1:
             raise ValueError("target_instances must be greater than 0.")
         if not (0 < min_instances <= target_instances):
@@ -190,10 +231,11 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             )
             raise
 
-        if runtime_environment is None and feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled(default=True):
+        if runtime_environment is None and feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
             # Pass a JSON object for runtime versions so it serializes as nested JSON in options
             runtime_environment = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
 
+        runtime = runtime_env_utils.get_runtime_image(session, compute_pool, runtime_environment)
         combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
         entrypoint_args = [v.as_posix() if isinstance(v, PurePath) else v for v in uploaded_payload.entrypoint]
         spec_options = types.SpecOptions(
@@ -203,8 +245,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             env_vars=combined_env_vars,
             enable_metrics=enable_metrics,
             spec_overrides=spec_overrides,
-            runtime=runtime_environment if runtime_environment else None,
-            enable_stage_mount_v2=feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(default=True),
+            runtime=runtime,
+            enable_stage_mount_v2=feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(),
         )
 
         job_options = types.JobOptions(
@@ -222,6 +264,8 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
             compute_pool=compute_pool,
             entrypoint_args=entrypoint_args,
             session=session,
+            arg_protocol=arg_protocol,
+            default_args=default_args,
             database=database,
             schema=schema,
             name=name,
@@ -230,3 +274,51 @@ class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
 
 def _generate_suffix() -> str:
     return str(uuid4().hex)[:8]
+
+
+def _combine_runtime_arguments(
+    default_runtime_args: Optional[list[Any]] = None, *args: Any, **kwargs: Any
+) -> list[Any]:
+    """Merge default CLI arguments with runtime overrides into a flat argument list.
+
+    Parses `default_runtime_args` for flags (e.g., `--key value`) and merges them with
+    `kwargs`. Keyword arguments override defaults unless their value is None. Positional
+    arguments from both `default_args` and `*args` are preserved in order.
+
+    Args:
+        default_runtime_args: Optional list of default CLI arguments to parse for flags and positional args.
+        *args: Additional positional arguments to include in the output.
+        **kwargs: Keyword arguments that override default flags. Values of None are ignored.
+
+    Returns:
+        A list of CLI-style arguments: positional args followed by `--key value` pairs.
+    """
+    cli_args = list(args)
+    flags: dict[str, Any] = {}
+    if default_runtime_args:
+        i = 0
+        while i < len(default_runtime_args):
+            arg = default_runtime_args[i]
+            if isinstance(arg, str) and arg.startswith("--"):
+                key = arg[2:]
+                # Check if next arg is a value (not a flag)
+                if i + 1 < len(default_runtime_args):
+                    next_arg = default_runtime_args[i + 1]
+                    if not (isinstance(next_arg, str) and next_arg.startswith("--")):
+                        flags[key] = next_arg
+                        i += 2
+                        continue
+
+                flags[key] = None
+            else:
+                cli_args.append(arg)
+            i += 1
+    # Prioritize kwargs over default_args. Explicit None values in kwargs
+    # serve as overrides and are converted to the string "None" to match
+    # CLI flag conventions (--key=value)
+    # Downstream logic must handle the parsing of these string-based nulls.
+    for k, v in kwargs.items():
+        flags[k] = v
+    for k, v in flags.items():
+        cli_args.extend([f"--{k}", str(v)])
+    return cli_args
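
The flag-merging behavior is easiest to see with concrete values. A hedged illustration that calls the private helper directly (for exposition only; the import path is inferred from the file list above):

```python
# Hypothetical invocation of the private CLI-merge helper shown above.
from snowflake.ml.jobs.job_definition import _combine_runtime_arguments

defaults = ["input.csv", "--epochs", "10", "--verbose"]
merged = _combine_runtime_arguments(defaults, "extra.csv", epochs=20)

# Runtime positionals come first, then the defaults' positionals, then flags;
# the kwarg overrides --epochs, and the valueless --verbose flag is emitted
# with the string "None", which downstream parsing must handle.
assert merged == ["extra.csv", "input.csv", "--epochs", "20", "--verbose", "None"]
```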
snowflake/ml/model/_client/model/model_version_impl.py CHANGED
@@ -33,6 +33,12 @@ _BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
 VLLM_SUPPORTED_TASKS = [
     "text-generation",
     "image-text-to-text",
+    "video-text-to-text",
+    "audio-text-to-text",
+]
+VALID_OPENAI_SIGNATURES = [
+    openai_signatures.OPENAI_CHAT_SIGNATURE,
+    openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
 ]
 
 
@@ -1140,16 +1146,11 @@ class ModelVersion(lineage_node.LineageNode):
             func_name: core.ModelSignature.from_dict(sig_dict) for func_name, sig_dict in signatures_dict.items()
         }
 
-        if deserialized_signatures not in [
-            openai_signatures.OPENAI_CHAT_SIGNATURE,
-            openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
-        ]:
+        if deserialized_signatures not in VALID_OPENAI_SIGNATURES:
             raise ValueError(
-                "Inference engine requires the model to be logged with openai_signatures.OPENAI_CHAT_SIGNATURE or "
-                "openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING. "
+                "Inference engine requires the model to be logged with one of the following signatures: "
+                f"{VALID_OPENAI_SIGNATURES}. Please log the model again with one of these supported signatures."
                 f"Found signatures: {signatures_dict}. "
-                "Please log the model again with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE or "
-                "signatures=openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING"
             )
 
     @overload
@@ -1161,6 +1162,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -1187,8 +1189,10 @@
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-            max_instances: The maximum number of inference service instances to run. The same value it set to
-                MIN_INSTANCES property of the service.
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
@@ -1224,6 +1228,7 @@
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -1250,8 +1255,10 @@
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-            max_instances: The maximum number of inference service instances to run. The same value it set to
-                MIN_INSTANCES property of the service.
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
@@ -1301,6 +1308,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_compute_pool: str,
         image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -1328,8 +1336,10 @@
                 will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
-            max_instances: The maximum number of inference service instances to run. The same value it set to
-                MIN_INSTANCES property of the service.
+            min_instances: The minimum number of instances for the inference service. The service will automatically
+                scale between min_instances and max_instances based on traffic and hardware utilization. If set to
+                0 (default), the service will automatically suspend after a period of inactivity.
+            max_instances: The maximum number of instances for the inference service.
             cpu_requests: The cpu limit for CPU based inference. Can be an integer, fractional or string values. If
                 None, we attempt to utilize all the vCPU of the node.
             memory_requests: The memory limit with for CPU based inference. Can be an integer or a fractional value, but
@@ -1419,6 +1429,7 @@
             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
             image_repo_name=image_repo,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu_requests=cpu_requests,
             memory_requests=memory_requests,
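
With `min_instances` threaded through the overloads into the deployment spec, a service can either pin a floor of running instances or opt into scale-to-zero. A hedged sketch (the registry handle and all names are placeholders):

```python
# Hypothetical deployment call exercising the new min_instances parameter;
# `registry` is assumed to be an existing snowflake.ml.registry.Registry.
mv = registry.get_model("MY_MODEL").version("V1")
mv.create_service(
    service_name="MY_SERVICE",
    service_compute_pool="MY_POOL",
    ingress_enabled=True,
    min_instances=0,  # 0 (default): suspend after a period of inactivity
    max_instances=3,  # autoscaling ceiling
)
```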
snowflake/ml/model/_client/ops/service_ops.py CHANGED
@@ -175,6 +175,7 @@ class ServiceOperator:
         service_compute_pool_name: sql_identifier.SqlIdentifier,
         image_repo_name: Optional[str],
         ingress_enabled: bool,
+        min_instances: int,
         max_instances: int,
         cpu_requests: Optional[str],
         memory_requests: Optional[str],
@@ -241,6 +242,7 @@ class ServiceOperator:
             service_name=service_name,
             inference_compute_pool_name=service_compute_pool_name,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu=cpu_requests,
             memory=memory_requests,
@@ -829,15 +831,13 @@ class ServiceOperator:
         service_seen_before = False
 
         while True:
-            # Check if async job has failed (but don't return on success - we need specific service status)
+            # Check if async job has completed
             if async_job.is_done():
                 try:
                     async_job.result()
-                    # Async job completed successfully, but we're waiting for a specific service status
-                    # This might mean the service completed and was cleaned up
-                    module_logger.debug(
-                        f"Async job completed but we're still waiting for {service_name} to reach {target_status.value}"
-                    )
+                    # Async job completed successfully - deployment is done
+                    module_logger.debug(f"Async job completed successfully, returning from wait for {service_name}")
+                    return
                 except Exception as e:
                     raise RuntimeError(f"Service deployment failed: {e}")
 
snowflake/ml/model/_client/service/model_deployment_spec.py CHANGED
@@ -140,6 +140,7 @@ class ModelDeploymentSpec:
         service_database_name: Optional[sql_identifier.SqlIdentifier] = None,
         service_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
         ingress_enabled: bool = True,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu: Optional[str] = None,
         memory: Optional[str] = None,
@@ -156,6 +157,7 @@ class ModelDeploymentSpec:
             service_database_name: Database name for the service.
             service_schema_name: Schema name for the service.
             ingress_enabled: Whether ingress is enabled.
+            min_instances: Minimum number of service instances.
             max_instances: Maximum number of service instances.
             cpu: CPU requirement.
             memory: Memory requirement.
@@ -187,6 +189,7 @@ class ModelDeploymentSpec:
             name=fq_service_name,
             compute_pool=inference_compute_pool_name.identifier(),
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             autocapture=autocapture,
             **self._inference_spec,
snowflake/ml/model/_client/service/model_deployment_spec_schema.py CHANGED
@@ -26,6 +26,7 @@ class Service(BaseModel):
     name: str
     compute_pool: str
     ingress_enabled: bool
+    min_instances: int
     max_instances: int
     cpu: Optional[str] = None
    memory: Optional[str] = None
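
Abridged to the fields visible in this hunk, the spec model now looks roughly like the sketch below (the real class defines additional fields, e.g. the `autocapture` value passed in the spec-building hunk above):

```python
# Abridged sketch of the Service spec model after this change; the actual
# class defines more fields than shown in the hunk above.
from typing import Optional

from pydantic import BaseModel

class Service(BaseModel):
    name: str
    compute_pool: str
    ingress_enabled: bool
    min_instances: int  # new in 1.25.0
    max_instances: int
    cpu: Optional[str] = None
    memory: Optional[str] = None
```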
snowflake/ml/model/models/huggingface_pipeline.py CHANGED
@@ -105,6 +105,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
         image_repo: Optional[str] = None,
         image_build_compute_pool: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -133,6 +134,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
             image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
                 the service compute pool if None.
             ingress_enabled: Whether ingress is enabled. Defaults to False.
+            min_instances: Minimum number of instances. Defaults to 0.
             max_instances: Maximum number of instances. Defaults to 1.
             cpu_requests: CPU requests configuration. Defaults to None.
             memory_requests: Memory requests configuration. Defaults to None.
@@ -225,6 +227,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
             image_repo_name=image_repo,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu_requests=cpu_requests,
             memory_requests=memory_requests,