snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.6__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (41)
  1. snowflake/ml/_internal/telemetry.py +42 -16
  2. snowflake/ml/_internal/utils/connection_params.py +196 -0
  3. snowflake/ml/data/data_connector.py +1 -1
  4. snowflake/ml/jobs/__init__.py +2 -0
  5. snowflake/ml/jobs/_utils/constants.py +12 -2
  6. snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
  7. snowflake/ml/jobs/_utils/interop_utils.py +1 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +95 -39
  9. snowflake/ml/jobs/_utils/scripts/constants.py +22 -0
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +67 -2
  11. snowflake/ml/jobs/_utils/spec_utils.py +30 -6
  12. snowflake/ml/jobs/_utils/stage_utils.py +119 -0
  13. snowflake/ml/jobs/_utils/types.py +5 -1
  14. snowflake/ml/jobs/decorators.py +10 -7
  15. snowflake/ml/jobs/job.py +176 -28
  16. snowflake/ml/jobs/manager.py +119 -26
  17. snowflake/ml/model/_client/model/model_impl.py +58 -0
  18. snowflake/ml/model/_client/model/model_version_impl.py +90 -0
  19. snowflake/ml/model/_client/ops/model_ops.py +6 -3
  20. snowflake/ml/model/_client/ops/service_ops.py +24 -7
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
  22. snowflake/ml/model/_client/sql/model_version.py +1 -1
  23. snowflake/ml/model/_client/sql/service.py +73 -28
  24. snowflake/ml/model/_client/sql/stage.py +5 -2
  25. snowflake/ml/model/_model_composer/model_composer.py +3 -1
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
  28. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
  29. snowflake/ml/model/_signatures/core.py +24 -0
  30. snowflake/ml/monitoring/explain_visualize.py +160 -22
  31. snowflake/ml/monitoring/model_monitor.py +0 -4
  32. snowflake/ml/registry/registry.py +34 -14
  33. snowflake/ml/utils/connection_params.py +9 -3
  34. snowflake/ml/utils/html_utils.py +263 -0
  35. snowflake/ml/version.py +1 -1
  36. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +40 -13
  37. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +40 -37
  38. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +1 -1
  39. snowflake/ml/monitoring/model_monitor_version.py +0 -1
  40. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py CHANGED
@@ -1,3 +1,5 @@
+ import logging
+ import os
  import time
  from functools import cached_property
  from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
@@ -12,10 +14,12 @@ from snowflake.snowpark import Row, context as sp_context
  from snowflake.snowpark.exceptions import SnowparkSQLException

  _PROJECT = "MLJob"
- TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
+ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}

  T = TypeVar("T")

+ logger = logging.getLogger(__name__)
+

  class MLJob(Generic[T]):
      def __init__(
@@ -36,8 +40,15 @@ class MLJob(Generic[T]):
          return identifier.parse_schema_level_object_identifier(self.id)[-1]

      @cached_property
-     def num_instances(self) -> int:
-         return _get_num_instances(self._session, self.id)
+     def target_instances(self) -> int:
+         return _get_target_instances(self._session, self.id)
+
+     @cached_property
+     def min_instances(self) -> int:
+         try:
+             return int(self._container_spec["env"].get(constants.MIN_INSTANCES_ENV_VAR, 1))
+         except TypeError:
+             return 1

      @property
      def id(self) -> str:
@@ -52,6 +63,12 @@ class MLJob(Generic[T]):
          self._status = _get_status(self._session, self.id)
          return self._status

+     @cached_property
+     def _compute_pool(self) -> str:
+         """Get the job's compute pool name."""
+         row = _get_service_info(self._session, self.id)
+         return cast(str, row["compute_pool"])
+
      @property
      def _service_spec(self) -> dict[str, Any]:
          """Get the job's service spec."""
@@ -82,15 +99,34 @@ class MLJob(Generic[T]):
          return f"{self._stage_path}/{result_path}"

      @overload
-     def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[True]) -> list[str]:
+     def get_logs(
+         self,
+         limit: int = -1,
+         instance_id: Optional[int] = None,
+         *,
+         as_list: Literal[True],
+         verbose: bool = constants.DEFAULT_VERBOSE_LOG,
+     ) -> list[str]:
          ...

      @overload
-     def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[False] = False) -> str:
+     def get_logs(
+         self,
+         limit: int = -1,
+         instance_id: Optional[int] = None,
+         *,
+         as_list: Literal[False] = False,
+         verbose: bool = constants.DEFAULT_VERBOSE_LOG,
+     ) -> str:
          ...

      def get_logs(
-         self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: bool = False
+         self,
+         limit: int = -1,
+         instance_id: Optional[int] = None,
+         *,
+         as_list: bool = False,
+         verbose: bool = constants.DEFAULT_VERBOSE_LOG,
      ) -> Union[str, list[str]]:
          """
          Return the job's execution logs.
@@ -100,17 +136,20 @@ class MLJob(Generic[T]):
              instance_id: Optional instance ID to get logs from a specific instance.
                  If not provided, returns logs from the head node.
              as_list: If True, returns logs as a list of lines. Otherwise, returns logs as a single string.
+             verbose: Whether to return the full log or just the user log.

          Returns:
              The job's execution logs.
          """
-         logs = _get_logs(self._session, self.id, limit, instance_id)
+         logs = _get_logs(self._session, self.id, limit, instance_id, verbose)
          assert isinstance(logs, str)  # mypy
          if as_list:
              return logs.splitlines()
          return logs

-     def show_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> None:
+     def show_logs(
+         self, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = constants.DEFAULT_VERBOSE_LOG
+     ) -> None:
          """
          Display the job's execution logs.

@@ -118,8 +157,9 @@ class MLJob(Generic[T]):
              limit: The maximum number of lines to display. Negative values are treated as no limit.
              instance_id: Optional instance ID to get logs from a specific instance.
                  If not provided, displays logs from the head node.
+             verbose: Whether to return the full log or just the user log.
          """
-         print(self.get_logs(limit, instance_id, as_list=False))  # noqa: T201: we need to print here.
+         print(self.get_logs(limit, instance_id, as_list=False, verbose=verbose))  # noqa: T201: we need to print here.

      @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
      def wait(self, timeout: float = -1) -> types.JOB_STATUS:
@@ -137,11 +177,20 @@ class MLJob(Generic[T]):
          """
          delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS  # Start with 100ms delay
          start_time = time.monotonic()
-         while self.status not in TERMINAL_JOB_STATUSES:
+         warning_shown = False
+         while (status := self.status) not in TERMINAL_JOB_STATUSES:
+             if status == "PENDING" and not warning_shown:
+                 pool_info = _get_compute_pool_info(self._session, self._compute_pool)
+                 if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
+                     logger.warning(
+                         f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use, "
+                         f"{self.min_instances} nodes required). Job execution may be delayed."
+                     )
+                     warning_shown = True
              if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
                  raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
              time.sleep(delay)
-             delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
+             delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
          return self.status

      @snowpark._internal.utils.private_preview(version="1.8.2")
@@ -195,7 +244,9 @@ def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:


  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
- def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None) -> str:
+ def _get_logs(
+     session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = True
+ ) -> str:
      """
      Retrieve the job's execution logs.

@@ -204,24 +255,20 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
          limit: The maximum number of lines to return. Negative values are treated as no limit.
          session: The Snowpark session to use. If none specified, uses active session.
          instance_id: Optional instance ID to get logs from a specific instance.
+         verbose: Whether to return the full log or just the portion between START and END messages.

      Returns:
          The job's execution logs.

      Raises:
-         SnowparkSQLException: if the container is pending
          RuntimeError: if failed to get head instance_id
-
      """
      # If instance_id is not specified, try to get the head instance ID
      if instance_id is None:
          try:
              instance_id = _get_head_instance_id(session, job_id)
          except RuntimeError:
-             raise RuntimeError(
-                 "Failed to retrieve job logs. "
-                 "Logs may be inaccessible due to job expiration and can be retrieved from Event Table instead."
-             )
+             instance_id = None

      # Assemble params: [job_id, instance_id, container_name, (optional) limit]
      params: list[Any] = [
@@ -231,7 +278,6 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
      ]
      if limit > 0:
          params.append(limit)
-
      try:
          (row,) = session.sql(
              f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
@@ -239,9 +285,43 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
          ).collect()
      except SnowparkSQLException as e:
          if "Container Status: PENDING" in e.message:
-             return "Warning: Waiting for container to start. Logs will be shown when available."
-         raise
-     return str(row[0])
+             logger.warning("Waiting for container to start. Logs will be shown when available.")
+             return ""
+         else:
+             # event table accepts job name, not fully qualified name
+             # cast is to resolve the type check error
+             db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+             db = cast(str, db or session.get_current_database())
+             schema = cast(str, schema or session.get_current_schema())
+             logs = _get_service_log_from_event_table(
+                 session, db, schema, name, limit, instance_id if instance_id else None
+             )
+             if len(logs) == 0:
+                 raise RuntimeError(
+                     "No logs were found. Please verify that the database, schema, and job ID are correct."
+                 )
+             return os.linesep.join(row[0] for row in logs)
+
+     full_log = str(row[0])
+
+     # If verbose is True, return the complete log
+     if verbose:
+         return full_log
+
+     # Otherwise, extract only the portion between LOG_START_MSG and LOG_END_MSG
+     start_idx = full_log.find(constants.LOG_START_MSG)
+     if start_idx != -1:
+         start_idx += len(constants.LOG_START_MSG)
+     else:
+         # If start message not found, start from the beginning
+         start_idx = 0
+
+     end_idx = full_log.find(constants.LOG_END_MSG, start_idx)
+     if end_idx == -1:
+         # If end message not found, return everything after start
+         end_idx = len(full_log)
+
+     return full_log[start_idx:end_idx].strip()


  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
@@ -256,13 +336,25 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
      Returns:
          Optional[int]: The head instance ID of the job, or None if the head instance has not started yet.

-     Raises:
+     Raises:
          RuntimeError: If the instances died or if some instances disappeared.
      """
-     rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+
+     target_instances = _get_target_instances(session, job_id)
+
+     if target_instances == 1:
+         return 0
+
+     try:
+         rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+     except SnowparkSQLException:
+         # service may be deleted
+         raise RuntimeError("Couldn’t retrieve instances")
+
      if not rows:
          return None
-     if _get_num_instances(session, job_id) > len(rows):
+
+     if target_instances > len(rows):
          raise RuntimeError("Couldn’t retrieve head instance due to missing instances.")

      # Sort by start_time first, then by instance_id
@@ -270,7 +362,6 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
          sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
      except TypeError:
          raise RuntimeError("Job instance information unavailable.")
-
      head_instance = sorted_instances[0]
      if not head_instance["start_time"]:
          # If head instance hasn't started yet, return None
@@ -281,12 +372,69 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
      return 0


+ def _get_service_log_from_event_table(
+     session: snowpark.Session, database: str, schema: str, name: str, limit: int, instance_id: Optional[int]
+ ) -> list[Row]:
+     params: list[Any] = [
+         database,
+         schema,
+         name,
+     ]
+     query = [
+         "SELECT VALUE FROM snowflake.telemetry.events_view",
+         'WHERE RESOURCE_ATTRIBUTES:"snow.database.name" = ?',
+         'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
+         'AND RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
+     ]
+
+     if instance_id:
+         query.append('AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?')
+         params.append(instance_id)
+
+     query.append("AND RECORD_TYPE = 'LOG'")
+     # sort by TIMESTAMP; although OBSERVED_TIMESTAMP is for log, it is NONE currently when record_type is log
+     query.append("ORDER BY TIMESTAMP")
+
+     if limit > 0:
+         query.append("LIMIT ?")
+         params.append(limit)
+
+     rows = session.sql(
+         "\n".join(line for line in query if line),
+         params=params,
+     ).collect()
+     return rows
+
+
  def _get_service_info(session: snowpark.Session, job_id: str) -> Row:
      (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
      return row


+ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
+     """
+     Check if the compute pool has enough available instances.
+
+     Args:
+         session (Session): The Snowpark session to use.
+         compute_pool (str): The name of the compute pool.
+
+     Returns:
+         Row: The compute pool information.
+
+     Raises:
+         ValueError: If the compute pool is not found.
+     """
+     try:
+         (pool_info,) = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
+         return pool_info
+     except ValueError as e:
+         if "not enough values to unpack" in str(e):
+             raise ValueError(f"Compute pool '{compute_pool}' not found")
+         raise
+
+
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
- def _get_num_instances(session: snowpark.Session, job_id: str) -> int:
+ def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
      row = _get_service_info(session, job_id)
-     return int(row["target_instances"]) if row["target_instances"] else 0
+     return int(row["target_instances"])
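Taken together, the job.py changes rename num_instances to target_instances, add a min_instances property, warn while wait() polls a busy compute pool, add a verbose toggle to log retrieval, and fall back to the event table when service logs are unavailable. A minimal usage sketch follows; it assumes get_job is re-exported from snowflake.ml.jobs and uses a placeholder job identifier:

from snowflake.ml import jobs

job = jobs.get_job("MY_DB.MY_SCHEMA.MLJOB_ABC123")  # placeholder identifier
job.wait(timeout=600)  # logs a warning if the compute pool lacks min_instances free nodes
print(job.target_instances, job.min_instances)  # renamed / new instance properties
job.show_logs(verbose=False)  # only the portion between LOG_START_MSG and LOG_END_MSG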
snowflake/ml/jobs/manager.py CHANGED
@@ -87,13 +87,15 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
  def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
      """Delete a job service from the backend. Status and logs will be lost."""
-     if isinstance(job, jb.MLJob):
-         job_id = job.id
-         session = job._session or session
-     else:
-         job_id = job
-         session = session or get_active_session()
-     session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+     job = job if isinstance(job, jb.MLJob) else get_job(job, session=session)
+     session = job._session
+     try:
+         stage_path = job._stage_path
+         session.sql(f"REMOVE {stage_path}/").collect()
+         logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
+     except Exception as e:
+         logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
+     session.sql("DROP SERVICE IDENTIFIER(?)", params=(job.id,)).collect()


  @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -108,7 +110,8 @@ def submit_file(
      external_access_integrations: Optional[list[str]] = None,
      query_warehouse: Optional[str] = None,
      spec_overrides: Optional[dict[str, Any]] = None,
-     num_instances: Optional[int] = None,
+     target_instances: int = 1,
+     min_instances: Optional[int] = None,
      enable_metrics: bool = False,
      database: Optional[str] = None,
      schema: Optional[str] = None,
@@ -127,7 +130,9 @@ def submit_file(
          external_access_integrations: A list of external access integrations.
          query_warehouse: The query warehouse to use. Defaults to session warehouse.
          spec_overrides: Custom service specification overrides to apply.
-         num_instances: The number of instances to use for the job. If none specified, single node job is created.
+         target_instances: The number of instances to use for the job. If none specified, single node job is created.
+         min_instances: The minimum number of nodes required to start the job. If none specified,
+             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
          enable_metrics: Whether to enable metrics publishing for the job.
          database: The database to use.
          schema: The schema to use.
@@ -146,7 +151,8 @@ def submit_file(
          external_access_integrations=external_access_integrations,
          query_warehouse=query_warehouse,
          spec_overrides=spec_overrides,
-         num_instances=num_instances,
+         target_instances=target_instances,
+         min_instances=min_instances,
          enable_metrics=enable_metrics,
          database=database,
          schema=schema,
@@ -167,7 +173,8 @@ def submit_directory(
      external_access_integrations: Optional[list[str]] = None,
      query_warehouse: Optional[str] = None,
      spec_overrides: Optional[dict[str, Any]] = None,
-     num_instances: Optional[int] = None,
+     target_instances: int = 1,
+     min_instances: Optional[int] = None,
      enable_metrics: bool = False,
      database: Optional[str] = None,
      schema: Optional[str] = None,
@@ -187,7 +194,9 @@ def submit_directory(
          external_access_integrations: A list of external access integrations.
          query_warehouse: The query warehouse to use. Defaults to session warehouse.
          spec_overrides: Custom service specification overrides to apply.
-         num_instances: The number of instances to use for the job. If none specified, single node job is created.
+         target_instances: The number of instances to use for the job. If none specified, single node job is created.
+         min_instances: The minimum number of nodes required to start the job. If none specified,
+             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
          enable_metrics: Whether to enable metrics publishing for the job.
          database: The database to use.
          schema: The schema to use.
@@ -207,7 +216,74 @@ def submit_directory(
          external_access_integrations=external_access_integrations,
          query_warehouse=query_warehouse,
          spec_overrides=spec_overrides,
-         num_instances=num_instances,
+         target_instances=target_instances,
+         min_instances=min_instances,
+         enable_metrics=enable_metrics,
+         database=database,
+         schema=schema,
+         session=session,
+     )
+
+
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
+ def submit_from_stage(
+     source: str,
+     compute_pool: str,
+     *,
+     entrypoint: str,
+     stage_name: str,
+     args: Optional[list[str]] = None,
+     env_vars: Optional[dict[str, str]] = None,
+     pip_requirements: Optional[list[str]] = None,
+     external_access_integrations: Optional[list[str]] = None,
+     query_warehouse: Optional[str] = None,
+     spec_overrides: Optional[dict[str, Any]] = None,
+     target_instances: int = 1,
+     min_instances: Optional[int] = None,
+     enable_metrics: bool = False,
+     database: Optional[str] = None,
+     schema: Optional[str] = None,
+     session: Optional[snowpark.Session] = None,
+ ) -> jb.MLJob[None]:
+     """
+     Submit a directory containing Python script(s) as a job to the compute pool.
+
+     Args:
+         source: a stage path or a stage containing the job payload.
+         compute_pool: The compute pool to use for the job.
+         entrypoint: a stage path containing the entry point script inside the source directory.
+         stage_name: The name of the stage where the job payload will be uploaded.
+         args: A list of arguments to pass to the job.
+         env_vars: Environment variables to set in container
+         pip_requirements: A list of pip requirements for the job.
+         external_access_integrations: A list of external access integrations.
+         query_warehouse: The query warehouse to use. Defaults to session warehouse.
+         spec_overrides: Custom service specification overrides to apply.
+         target_instances: The number of instances to use for the job. If none specified, single node job is created.
+         min_instances: The minimum number of nodes required to start the job. If none specified,
+             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
+         enable_metrics: Whether to enable metrics publishing for the job.
+         database: The database to use.
+         schema: The schema to use.
+         session: The Snowpark session to use. If none specified, uses active session.
+
+
+     Returns:
+         An object representing the submitted job.
+     """
+     return _submit_job(
+         source=source,
+         entrypoint=entrypoint,
+         args=args,
+         compute_pool=compute_pool,
+         stage_name=stage_name,
+         env_vars=env_vars,
+         pip_requirements=pip_requirements,
+         external_access_integrations=external_access_integrations,
+         query_warehouse=query_warehouse,
+         spec_overrides=spec_overrides,
+         target_instances=target_instances,
+         min_instances=min_instances,
          enable_metrics=enable_metrics,
          database=database,
          schema=schema,
@@ -228,7 +304,8 @@ def _submit_job(
      external_access_integrations: Optional[list[str]] = None,
      query_warehouse: Optional[str] = None,
      spec_overrides: Optional[dict[str, Any]] = None,
-     num_instances: Optional[int] = None,
+     target_instances: int = 1,
+     min_instances: Optional[int] = None,
      enable_metrics: bool = False,
      database: Optional[str] = None,
      schema: Optional[str] = None,
@@ -250,7 +327,8 @@ def _submit_job(
      external_access_integrations: Optional[list[str]] = None,
      query_warehouse: Optional[str] = None,
      spec_overrides: Optional[dict[str, Any]] = None,
-     num_instances: Optional[int] = None,
+     target_instances: int = 1,
+     min_instances: Optional[int] = None,
      enable_metrics: bool = False,
      database: Optional[str] = None,
      schema: Optional[str] = None,
@@ -267,7 +345,7 @@ def _submit_job(
          # TODO: Log lengths of args, env_vars, and spec_overrides values
          "pip_requirements",
          "external_access_integrations",
-         "num_instances",
+         "target_instances",
          "enable_metrics",
      ],
  )
@@ -283,7 +361,8 @@ def _submit_job(
      external_access_integrations: Optional[list[str]] = None,
      query_warehouse: Optional[str] = None,
      spec_overrides: Optional[dict[str, Any]] = None,
-     num_instances: Optional[int] = None,
+     target_instances: int = 1,
+     min_instances: Optional[int] = None,
      enable_metrics: bool = False,
      database: Optional[str] = None,
      schema: Optional[str] = None,
@@ -303,7 +382,9 @@ def _submit_job(
          external_access_integrations: A list of external access integrations.
          query_warehouse: The query warehouse to use. Defaults to session warehouse.
          spec_overrides: Custom service specification overrides to apply.
-         num_instances: The number of instances to use for the job. If none specified, single node job is created.
+         target_instances: The number of instances to use for the job. If none specified, single node job is created.
+         min_instances: The minimum number of nodes required to start the job. If none specified,
+             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
          enable_metrics: Whether to enable metrics publishing for the job.
          database: The database to use.
          schema: The schema to use.
@@ -316,16 +397,27 @@ def _submit_job(
          RuntimeError: If required Snowflake features are not enabled.
          ValueError: If database or schema value(s) are invalid
      """
-     # Display warning about PrPr parameters
-     if num_instances is not None:
-         logger.warning(
-             "_submit_job() parameter 'num_instances' is in private preview since 1.8.2. Do not use it in production.",
-         )
      if database and not schema:
          raise ValueError("Schema must be specified if database is specified.")
+     if target_instances < 1:
+         raise ValueError("target_instances must be greater than 0.")
+
+     min_instances = target_instances if min_instances is None else min_instances
+     if not (0 < min_instances <= target_instances):
+         raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")

      session = session or get_active_session()

+     if min_instances > 1:
+         # Validate min_instances against compute pool max_nodes
+         pool_info = jb._get_compute_pool_info(session, compute_pool)
+         max_nodes = int(pool_info["max_nodes"])
+         if min_instances > max_nodes:
+             raise ValueError(
+                 f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
+                 f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
+             )
+
      # Validate database and schema identifiers on client side since
      # SQL parser for EXECUTE JOB SERVICE seems to struggle with this
      database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
@@ -350,7 +442,8 @@ def _submit_job(
          compute_pool=compute_pool,
          payload=uploaded_payload,
          args=args,
-         num_instances=num_instances,
+         target_instances=target_instances,
+         min_instances=min_instances,
          enable_metrics=enable_metrics,
      )
      spec_overrides = spec_utils.generate_spec_overrides(
@@ -381,9 +474,9 @@ def _submit_job(
      if query_warehouse:
          query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
          params.append(query_warehouse)
-     if num_instances:
+     if target_instances > 1:
          query.append("REPLICAS = ?")
-         params.append(num_instances)
+         params.append(target_instances)

      # Submit job
      query_text = "\n".join(line for line in query if line)
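The manager.py changes replace num_instances with target_instances plus min_instances (validated against the compute pool's max_nodes), make delete_job clean up staged files, and add submit_from_stage for payloads that already live on a stage. A hedged sketch of the new submission surface; it assumes these helpers are exposed under snowflake.ml.jobs, that the submit_file parameters not shown in this diff (file path, stage_name) keep their existing layout, and that the stage and pool names below are placeholders:

from snowflake.ml import jobs

# Multi-node job; submission fails fast if min_instances exceeds the pool's max_nodes.
job = jobs.submit_file(
    "train.py",
    "MY_COMPUTE_POOL",
    stage_name="payload_stage",
    target_instances=4,
    min_instances=2,
)

# Payload already uploaded to a stage.
job = jobs.submit_from_stage(
    "@payload_stage/train_job/",
    "MY_COMPUTE_POOL",
    entrypoint="@payload_stage/train_job/train.py",
    stage_name="payload_stage",
)
job.wait()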
snowflake/ml/model/_client/model/model_impl.py CHANGED
@@ -426,3 +426,61 @@ class Model:
              schema_name=new_schema or self._model_ops._model_client._schema_name,
          )
          self._model_name = new_model
+
+     def _repr_html_(self) -> str:
+         """Generate an HTML representation of the model.
+
+         Returns:
+             str: HTML string containing formatted model details.
+         """
+         from snowflake.ml.utils import html_utils
+
+         # Get default version
+         default_version = self.default.version_name
+
+         # Get versions info
+         try:
+             versions_df = self.show_versions()
+             versions_html = ""
+
+             for _, row in versions_df.iterrows():
+                 versions_html += html_utils.create_version_item(
+                     version_name=row["name"],
+                     created_on=str(row["created_on"]),
+                     comment=str(row.get("comment", "")),
+                     is_default=bool(row["is_default_version"]),
+                 )
+         except Exception:
+             versions_html = html_utils.create_error_message("Error retrieving versions")
+
+         # Get tags
+         try:
+             tags = self.show_tags()
+             if not tags:
+                 tags_html = html_utils.create_error_message("No tags available")
+             else:
+                 tags_html = ""
+                 for tag_name, tag_value in tags.items():
+                     tags_html += html_utils.create_tag_item(tag_name, tag_value)
+         except Exception:
+             tags_html = html_utils.create_error_message("Error retrieving tags")
+
+         # Create main content sections
+         main_info = html_utils.create_grid_section(
+             [
+                 ("Model Name", self.name),
+                 ("Full Name", self.fully_qualified_name),
+                 ("Description", self.description),
+                 ("Default Version", default_version),
+             ]
+         )
+
+         versions_section = html_utils.create_section_header("Versions") + html_utils.create_content_section(
+             versions_html
+         )
+
+         tags_section = html_utils.create_section_header("Tags") + html_utils.create_content_section(tags_html)
+
+         content = main_info + versions_section + tags_section
+
+         return html_utils.create_base_container("Model Details", content)
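The model_impl.py change only adds a _repr_html_ hook, so a registry Model now renders as a formatted details card (name, fully qualified name, description, default version, versions, tags) when it is the last expression in a notebook cell. A rough sketch, assuming a Registry lookup along these lines and a placeholder model name:

from snowflake.ml.registry import Registry

reg = Registry(session=session)  # `session` is an existing Snowpark session (assumed)
model = reg.get_model("MY_MODEL")  # placeholder model name
model  # in Jupyter, IPython calls model._repr_html_() to render the details card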