snowflake-ml-python 1.8.5__py3-none-any.whl → 1.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. snowflake/ml/_internal/telemetry.py +6 -9
  2. snowflake/ml/_internal/utils/connection_params.py +196 -0
  3. snowflake/ml/jobs/__init__.py +2 -0
  4. snowflake/ml/jobs/_utils/constants.py +3 -2
  5. snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
  6. snowflake/ml/jobs/_utils/payload_utils.py +83 -35
  7. snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
  8. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
  9. snowflake/ml/jobs/_utils/spec_utils.py +23 -1
  10. snowflake/ml/jobs/_utils/stage_utils.py +119 -0
  11. snowflake/ml/jobs/_utils/types.py +5 -1
  12. snowflake/ml/jobs/decorators.py +6 -7
  13. snowflake/ml/jobs/job.py +24 -9
  14. snowflake/ml/jobs/manager.py +102 -19
  15. snowflake/ml/model/_client/model/model_impl.py +58 -0
  16. snowflake/ml/model/_client/model/model_version_impl.py +90 -0
  17. snowflake/ml/model/_client/ops/model_ops.py +6 -3
  18. snowflake/ml/model/_client/ops/service_ops.py +19 -4
  19. snowflake/ml/model/_client/sql/service.py +68 -20
  20. snowflake/ml/model/_client/sql/stage.py +5 -2
  21. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
  22. snowflake/ml/model/_signatures/core.py +24 -0
  23. snowflake/ml/monitoring/explain_visualize.py +2 -2
  24. snowflake/ml/monitoring/model_monitor.py +0 -4
  25. snowflake/ml/registry/registry.py +34 -14
  26. snowflake/ml/utils/connection_params.py +1 -1
  27. snowflake/ml/utils/html_utils.py +263 -0
  28. snowflake/ml/version.py +1 -1
  29. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +14 -5
  30. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +33 -30
  31. snowflake/ml/monitoring/model_monitor_version.py +0 -1
  32. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +0 -0
  33. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
  34. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/stage_utils.py ADDED
@@ -0,0 +1,119 @@
+ import os
+ import re
+ from os import PathLike
+ from pathlib import Path, PurePath
+ from typing import Union
+
+ from snowflake.ml._internal.utils import identifier
+
+ PROTOCOL_NAME = "snow"
+ _SNOWURL_PATH_RE = re.compile(
+     rf"^(?:(?:{PROTOCOL_NAME}://)?"
+     r"(?<!@)(?P<domain>\w+)/"
+     rf"(?P<name>(?:{identifier._SF_IDENTIFIER}\.){{,2}}{identifier._SF_IDENTIFIER})/)?"
+     r"(?P<path>versions(?:/(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)$"
+ )
+
+ _STAGEF_PATH_RE = re.compile(r"^@(?P<stage>~|%?\w+)(?:/(?P<relpath>[\w\-./]*))?$")
+
+
+ class StagePath:
+     def __init__(self, path: str) -> None:
+         stage_match = _SNOWURL_PATH_RE.fullmatch(path) or _STAGEF_PATH_RE.fullmatch(path)
+         if not stage_match:
+             raise ValueError(f"{path} is not a valid stage path")
+         path = path.strip()
+         self._raw_path = path
+         relpath = stage_match.group("relpath")
+         start, _ = stage_match.span("relpath")
+         self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
+         self._path = Path(relpath or "")
+
+     @property
+     def parent(self) -> "StagePath":
+         if self._path.parent == Path(""):
+             return StagePath(self._root)
+         else:
+             return StagePath(f"{self._root}/{self._path.parent}")
+
+     @property
+     def root(self) -> str:
+         return self._root
+
+     @property
+     def suffix(self) -> str:
+         return self._path.suffix
+
+     def _compose_path(self, path: Path) -> str:
+         # in pathlib, Path("") = "."
+         if path == Path(""):
+             return self.root
+         else:
+             return f"{self.root}/{path}"
+
+     def is_relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> bool:
+         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+         if stage_path.root == self.root:
+             return self._path.is_relative_to(stage_path._path)
+         else:
+             return False
+
+     def relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> PurePath:
+         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+         if self.root == stage_path.root:
+             return self._path.relative_to(stage_path._path)
+         raise ValueError(f"{self._raw_path} does not start with {stage_path._raw_path}")
+
+     def absolute(self) -> "StagePath":
+         return self
+
+     def as_posix(self) -> str:
+         return self._compose_path(self._path)
+
+     # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+     def exists(self) -> bool:
+         return True
+
+     # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+     def is_file(self) -> bool:
+         return True
+
+     # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+     def is_dir(self) -> bool:
+         return True
+
+     def is_absolute(self) -> bool:
+         return True
+
+     def __str__(self) -> str:
+         return self.as_posix()
+
+     def __eq__(self, other: object) -> bool:
+         if not isinstance(other, StagePath):
+             raise NotImplementedError
+         return bool(self.root == other.root and self._path == other._path)
+
+     def __fspath__(self) -> str:
+         return self._compose_path(self._path)
+
+     def joinpath(self, *args: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+         path = self
+         for arg in args:
+             path = path._make_child(arg)
+         return path
+
+     def _make_child(self, path: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+         if self.root == stage_path.root:
+             child_path = self._path.joinpath(stage_path._path)
+             return StagePath(self._compose_path(child_path))
+         else:
+             return stage_path
+
+
+ def identify_stage_path(path: str) -> Union[StagePath, Path]:
+     try:
+         stage_path = StagePath(path)
+     except ValueError:
+         return Path(path)
+     return stage_path
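
For orientation, a minimal usage sketch of the new helper (the stage and file names are made up; stage_utils is an internal module, so this illustrates behavior inferred from the code above rather than a supported API):

    from pathlib import Path
    from snowflake.ml.jobs._utils.stage_utils import StagePath, identify_stage_path

    # Stage URIs parse into StagePath; anything else falls back to pathlib.Path.
    staged = identify_stage_path("@payload_stage/project/train.py")
    local = identify_stage_path("/tmp/project/train.py")
    assert isinstance(staged, StagePath) and isinstance(local, Path)

    print(staged.root)                              # @payload_stage
    print(staged.suffix)                            # .py
    print(staged.parent)                            # @payload_stage/project
    print(staged.is_relative_to("@payload_stage"))  # True
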
snowflake/ml/jobs/_utils/types.py CHANGED
@@ -2,18 +2,22 @@ from dataclasses import dataclass
  from pathlib import PurePath
  from typing import Literal, Optional, Union
 
+ from snowflake.ml.jobs._utils import stage_utils
+
  JOB_STATUS = Literal[
      "PENDING",
      "RUNNING",
      "FAILED",
      "DONE",
+     "CANCELLING",
+     "CANCELLED",
      "INTERNAL_ERROR",
  ]
 
 
  @dataclass(frozen=True)
  class PayloadEntrypoint:
-     file_path: PurePath
+     file_path: Union[PurePath, stage_utils.StagePath]
      main_func: Optional[str]
 
 
snowflake/ml/jobs/decorators.py CHANGED
@@ -7,7 +7,7 @@ from typing_extensions import ParamSpec
  from snowflake import snowpark
  from snowflake.ml._internal import telemetry
  from snowflake.ml.jobs import job as jb, manager as jm
- from snowflake.ml.jobs._utils import constants
+ from snowflake.ml.jobs._utils import payload_utils
 
  _PROJECT = "MLJob"
 
@@ -25,7 +25,7 @@ def remote(
      query_warehouse: Optional[str] = None,
      env_vars: Optional[dict[str, str]] = None,
      target_instances: int = 1,
-     min_instances: int = 1,
+     min_instances: Optional[int] = None,
      enable_metrics: bool = False,
      database: Optional[str] = None,
      schema: Optional[str] = None,
@@ -42,8 +42,8 @@ def remote(
          query_warehouse: The query warehouse to use. Defaults to session warehouse.
          env_vars: Environment variables to set in container
          target_instances: The number of nodes in the job. If none specified, create a single node job.
-         min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
-             If set, the job will not start until the minimum number of nodes is available.
+         min_instances: The minimum number of nodes required to start the job. If none specified,
+             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
          enable_metrics: Whether to enable metrics publishing for the job.
          database: The database to use for the job.
          schema: The schema to use for the job.
@@ -62,8 +62,7 @@ def remote(
 
      @functools.wraps(func)
      def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
-         payload = functools.partial(func, *args, **kwargs)
-         setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
+         payload = payload_utils.create_function_payload(func, *args, **kwargs)
          job = jm._submit_job(
              source=payload,
              stage_name=stage_name,
@@ -77,7 +76,7 @@ def remote(
              enable_metrics=enable_metrics,
              database=database,
              schema=schema,
-             session=session,
+             session=payload.session or session,
          )
          assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
         return job
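
A usage sketch of the decorator under the new default (compute pool and stage names are placeholders):

    from snowflake.ml import jobs

    # With target_instances=4 and min_instances left unset, 1.8.6 now waits for
    # all 4 nodes before starting (min_instances defaults to target_instances,
    # no longer to 1).
    @jobs.remote("MY_POOL", stage_name="payload_stage", target_instances=4)
    def train(data_path: str) -> float:
        ...

    job = train("@payload_stage/data")  # returns an MLJob handle
    job.wait()
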
snowflake/ml/jobs/job.py CHANGED
@@ -14,7 +14,7 @@ from snowflake.snowpark import Row, context as sp_context
  from snowflake.snowpark.exceptions import SnowparkSQLException
 
  _PROJECT = "MLJob"
- TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
+ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}
 
  T = TypeVar("T")
 
@@ -183,14 +183,14 @@ class MLJob(Generic[T]):
                  pool_info = _get_compute_pool_info(self._session, self._compute_pool)
                  if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
                      logger.warning(
-                         f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use)."
-                         " Job execution may be delayed."
+                         f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use, "
+                         f"{self.min_instances} nodes required). Job execution may be delayed."
                      )
                      warning_shown = True
              if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
                  raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
              time.sleep(delay)
-             delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
+             delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
          return self.status
 
      @snowpark._internal.utils.private_preview(version="1.8.2")
@@ -338,16 +338,23 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[int]
 
      Raises:
          RuntimeError: If the instances died or if some instances disappeared.
-
      """
+
+     target_instances = _get_target_instances(session, job_id)
+
+     if target_instances == 1:
+         return 0
+
      try:
          rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
      except SnowparkSQLException:
          # service may be deleted
          raise RuntimeError("Couldn’t retrieve instances")
+
      if not rows:
          return None
-     if _get_target_instances(session, job_id) > len(rows):
+
+     if target_instances > len(rows):
          raise RuntimeError("Couldn’t retrieve head instance due to missing instances.")
 
      # Sort by start_time first, then by instance_id
@@ -414,12 +421,20 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
 
      Returns:
          Row: The compute pool information.
+
+     Raises:
+         ValueError: If the compute pool is not found.
      """
-     (pool_info,) = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
-     return pool_info
+     try:
+         (pool_info,) = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
+         return pool_info
+     except ValueError as e:
+         if "not enough values to unpack" in str(e):
+             raise ValueError(f"Compute pool '{compute_pool}' not found")
+         raise
 
 
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
  def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
      row = _get_service_info(session, job_id)
-     return int(row["target_instances"]) if row["target_instances"] else 0
+     return int(row["target_instances"])
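
The polling backoff multiplier drops from 2x to 1.2x, so the delay between status checks grows far more gradually before hitting the cap. A self-contained sketch of the resulting schedule (the 0.1 s initial delay and 30 s cap here are illustrative assumptions, not values read from constants.py):

    def poll_delays(initial: float = 0.1, cap: float = 30.0, steps: int = 10) -> list[float]:
        """Reproduce the backoff arithmetic from MLJob.wait()."""
        delays, delay = [], initial
        for _ in range(steps):
            delays.append(delay)
            delay = min(delay * 1.2, cap)  # was delay * 2 in 1.8.5
        return delays

    print([round(d, 2) for d in poll_delays()])
    # [0.1, 0.12, 0.14, 0.17, 0.21, 0.25, 0.3, 0.36, 0.43, 0.52]
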
snowflake/ml/jobs/manager.py CHANGED
@@ -87,13 +87,15 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
  def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
      """Delete a job service from the backend. Status and logs will be lost."""
-     if isinstance(job, jb.MLJob):
-         job_id = job.id
-         session = job._session or session
-     else:
-         job_id = job
-         session = session or get_active_session()
-     session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+     job = job if isinstance(job, jb.MLJob) else get_job(job, session=session)
+     session = job._session
+     try:
+         stage_path = job._stage_path
+         session.sql(f"REMOVE {stage_path}/").collect()
+         logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
+     except Exception as e:
+         logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
+     session.sql("DROP SERVICE IDENTIFIER(?)", params=(job.id,)).collect()
 
 
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -109,7 +111,7 @@ def submit_file(
      query_warehouse: Optional[str] = None,
      spec_overrides: Optional[dict[str, Any]] = None,
      target_instances: int = 1,
-     min_instances: int = 1,
+     min_instances: Optional[int] = None,
      enable_metrics: bool = False,
      database: Optional[str] = None,
      schema: Optional[str] = None,
@@ -129,7 +131,8 @@ def submit_file(
          query_warehouse: The query warehouse to use. Defaults to session warehouse.
          spec_overrides: Custom service specification overrides to apply.
          target_instances: The number of instances to use for the job. If none specified, single node job is created.
-         min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
+         min_instances: The minimum number of nodes required to start the job. If none specified,
+             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
          enable_metrics: Whether to enable metrics publishing for the job.
          database: The database to use.
          schema: The schema to use.
@@ -171,7 +174,7 @@ def submit_directory(
      query_warehouse: Optional[str] = None,
      spec_overrides: Optional[dict[str, Any]] = None,
      target_instances: int = 1,
-     min_instances: int = 1,
+     min_instances: Optional[int] = None,
      enable_metrics: bool = False,
      database: Optional[str] = None,
      schema: Optional[str] = None,
@@ -192,7 +195,8 @@ def submit_directory(
          query_warehouse: The query warehouse to use. Defaults to session warehouse.
          spec_overrides: Custom service specification overrides to apply.
          target_instances: The number of instances to use for the job. If none specified, single node job is created.
-         min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
+         min_instances: The minimum number of nodes required to start the job. If none specified,
+             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
          enable_metrics: Whether to enable metrics publishing for the job.
          database: The database to use.
          schema: The schema to use.
@@ -221,6 +225,72 @@ def submit_directory(
      )
 
 
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
+ def submit_from_stage(
+     source: str,
+     compute_pool: str,
+     *,
+     entrypoint: str,
+     stage_name: str,
+     args: Optional[list[str]] = None,
+     env_vars: Optional[dict[str, str]] = None,
+     pip_requirements: Optional[list[str]] = None,
+     external_access_integrations: Optional[list[str]] = None,
+     query_warehouse: Optional[str] = None,
+     spec_overrides: Optional[dict[str, Any]] = None,
+     target_instances: int = 1,
+     min_instances: Optional[int] = None,
+     enable_metrics: bool = False,
+     database: Optional[str] = None,
+     schema: Optional[str] = None,
+     session: Optional[snowpark.Session] = None,
+ ) -> jb.MLJob[None]:
+     """
+     Submit a directory containing Python script(s) as a job to the compute pool.
+
+     Args:
+         source: a stage path or a stage containing the job payload.
+         compute_pool: The compute pool to use for the job.
+         entrypoint: a stage path containing the entry point script inside the source directory.
+         stage_name: The name of the stage where the job payload will be uploaded.
+         args: A list of arguments to pass to the job.
+         env_vars: Environment variables to set in container
+         pip_requirements: A list of pip requirements for the job.
+         external_access_integrations: A list of external access integrations.
+         query_warehouse: The query warehouse to use. Defaults to session warehouse.
+         spec_overrides: Custom service specification overrides to apply.
+         target_instances: The number of instances to use for the job. If none specified, single node job is created.
+         min_instances: The minimum number of nodes required to start the job. If none specified,
+             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
+         enable_metrics: Whether to enable metrics publishing for the job.
+         database: The database to use.
+         schema: The schema to use.
+         session: The Snowpark session to use. If none specified, uses active session.
+
+
+     Returns:
+         An object representing the submitted job.
+     """
+     return _submit_job(
+         source=source,
+         entrypoint=entrypoint,
+         args=args,
+         compute_pool=compute_pool,
+         stage_name=stage_name,
+         env_vars=env_vars,
+         pip_requirements=pip_requirements,
+         external_access_integrations=external_access_integrations,
+         query_warehouse=query_warehouse,
+         spec_overrides=spec_overrides,
+         target_instances=target_instances,
+         min_instances=min_instances,
+         enable_metrics=enable_metrics,
+         database=database,
+         schema=schema,
+         session=session,
+     )
+
+
  @overload
  def _submit_job(
      source: str,
@@ -235,7 +305,7 @@ def _submit_job(
      query_warehouse: Optional[str] = None,
      spec_overrides: Optional[dict[str, Any]] = None,
      target_instances: int = 1,
-     min_instances: int = 1,
+     min_instances: Optional[int] = None,
      enable_metrics: bool = False,
      database: Optional[str] = None,
      schema: Optional[str] = None,
@@ -258,7 +328,7 @@ def _submit_job(
      query_warehouse: Optional[str] = None,
      spec_overrides: Optional[dict[str, Any]] = None,
      target_instances: int = 1,
-     min_instances: int = 1,
+     min_instances: Optional[int] = None,
      enable_metrics: bool = False,
      database: Optional[str] = None,
      schema: Optional[str] = None,
@@ -292,7 +362,7 @@ def _submit_job(
      query_warehouse: Optional[str] = None,
      spec_overrides: Optional[dict[str, Any]] = None,
      target_instances: int = 1,
-     min_instances: int = 1,
+     min_instances: Optional[int] = None,
      enable_metrics: bool = False,
      database: Optional[str] = None,
      schema: Optional[str] = None,
@@ -313,7 +383,8 @@ def _submit_job(
          query_warehouse: The query warehouse to use. Defaults to session warehouse.
          spec_overrides: Custom service specification overrides to apply.
          target_instances: The number of instances to use for the job. If none specified, single node job is created.
-         min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
+         min_instances: The minimum number of nodes required to start the job. If none specified,
+             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
          enable_metrics: Whether to enable metrics publishing for the job.
          database: The database to use.
          schema: The schema to use.
@@ -328,13 +399,25 @@ def _submit_job(
      """
      if database and not schema:
          raise ValueError("Schema must be specified if database is specified.")
-     if target_instances < 1 or min_instances < 1:
-         raise ValueError("target_instances and min_instances must be greater than 0.")
-     if min_instances > target_instances:
-         raise ValueError("min_instances must be less than or equal to target_instances.")
+     if target_instances < 1:
+         raise ValueError("target_instances must be greater than 0.")
+
+     min_instances = target_instances if min_instances is None else min_instances
+     if not (0 < min_instances <= target_instances):
+         raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
 
      session = session or get_active_session()
 
+     if min_instances > 1:
+         # Validate min_instances against compute pool max_nodes
+         pool_info = jb._get_compute_pool_info(session, compute_pool)
+         max_nodes = int(pool_info["max_nodes"])
+         if min_instances > max_nodes:
+             raise ValueError(
+                 f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
+                 f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
+             )
+
      # Validate database and schema identifiers on client side since
      # SQL parser for EXECUTE JOB SERVICE seems to struggle with this
      database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
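
A usage sketch for the new submit_from_stage entry point (stage, pool, and script names are placeholders; this assumes the function is re-exported from snowflake.ml.jobs, which the __init__.py change in this release suggests):

    from snowflake.ml import jobs

    job = jobs.submit_from_stage(
        "@ML_STAGE/payloads/my_project",   # payload already uploaded to a stage
        "MY_POOL",
        entrypoint="@ML_STAGE/payloads/my_project/main.py",
        stage_name="ML_STAGE",
        args=["--epochs", "10"],
        target_instances=2,                # min_instances now defaults to 2 as well
    )
    job.wait()
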
snowflake/ml/model/_client/model/model_impl.py CHANGED
@@ -426,3 +426,61 @@ class Model:
              schema_name=new_schema or self._model_ops._model_client._schema_name,
          )
          self._model_name = new_model
+
+     def _repr_html_(self) -> str:
+         """Generate an HTML representation of the model.
+
+         Returns:
+             str: HTML string containing formatted model details.
+         """
+         from snowflake.ml.utils import html_utils
+
+         # Get default version
+         default_version = self.default.version_name
+
+         # Get versions info
+         try:
+             versions_df = self.show_versions()
+             versions_html = ""
+
+             for _, row in versions_df.iterrows():
+                 versions_html += html_utils.create_version_item(
+                     version_name=row["name"],
+                     created_on=str(row["created_on"]),
+                     comment=str(row.get("comment", "")),
+                     is_default=bool(row["is_default_version"]),
+                 )
+         except Exception:
+             versions_html = html_utils.create_error_message("Error retrieving versions")
+
+         # Get tags
+         try:
+             tags = self.show_tags()
+             if not tags:
+                 tags_html = html_utils.create_error_message("No tags available")
+             else:
+                 tags_html = ""
+                 for tag_name, tag_value in tags.items():
+                     tags_html += html_utils.create_tag_item(tag_name, tag_value)
+         except Exception:
+             tags_html = html_utils.create_error_message("Error retrieving tags")
+
+         # Create main content sections
+         main_info = html_utils.create_grid_section(
+             [
+                 ("Model Name", self.name),
+                 ("Full Name", self.fully_qualified_name),
+                 ("Description", self.description),
+                 ("Default Version", default_version),
+             ]
+         )
+
+         versions_section = html_utils.create_section_header("Versions") + html_utils.create_content_section(
+             versions_html
+         )
+
+         tags_section = html_utils.create_section_header("Tags") + html_utils.create_content_section(tags_html)
+
+         content = main_info + versions_section + tags_section
+
+         return html_utils.create_base_container("Model Details", content)
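
This hook (and the ModelVersion one in the next file) follows IPython's standard rich-display protocol, so notebooks render the HTML card automatically when the object is the last expression in a cell. A sketch (assumes an existing Snowpark session; the model name is a placeholder):

    from snowflake.ml.registry import Registry

    reg = Registry(session=session)
    model = reg.get_model("MY_MODEL")

    model                       # in a notebook, IPython calls model._repr_html_()
    html = model._repr_html_()  # scripts can grab the raw HTML string directly
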
snowflake/ml/model/_client/model/model_version_impl.py CHANGED
@@ -38,6 +38,96 @@ class ModelVersion(lineage_node.LineageNode):
      def __init__(self) -> None:
          raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
 
+     def _repr_html_(self) -> str:
+         """Generate an HTML representation of the model version.
+
+         Returns:
+             str: HTML string containing formatted model version details.
+         """
+         from snowflake.ml.utils import html_utils
+
+         # Get task
+         try:
+             task = self.get_model_task().value
+         except Exception:
+             task = (
+                 html_utils.create_error_message("Not available")
+                 .replace('<em style="color: #888; font-style: italic;">', "")
+                 .replace("</em>", "")
+             )
+
+         # Get functions info for display
+         try:
+             functions = self.show_functions()
+             if not functions:
+                 functions_html = html_utils.create_error_message("No functions available")
+             else:
+                 functions_list = []
+                 for func in functions:
+                     try:
+                         sig_html = func["signature"]._repr_html_()
+                     except Exception:
+                         # Fallback to simple display if can't display signature
+                         sig_html = f"<pre style='margin: 5px 0;'>{func['signature']}</pre>"
+
+                     function_content = f"""
+                     <div style="margin: 5px 0;">
+                         <strong>Target Method:</strong> {func['target_method']}
+                     </div>
+                     <div style="margin: 5px 0;">
+                         <strong>Function Type:</strong> {func.get('target_method_function_type', 'N/A')}
+                     </div>
+                     <div style="margin: 5px 0;">
+                         <strong>Partitioned:</strong> {func.get('is_partitioned', False)}
+                     </div>
+                     <div style="margin: 10px 0;">
+                         <strong>Signature:</strong>
+                         {sig_html}
+                     </div>
+                     """
+
+                     functions_list.append(
+                         html_utils.create_collapsible_section(
+                             title=func["name"], content=function_content, open_by_default=False
+                         )
+                     )
+                 functions_html = "".join(functions_list)
+         except Exception:
+             functions_html = html_utils.create_error_message("Error retrieving functions")
+
+         # Get metrics for display
+         try:
+             metrics = self.show_metrics()
+             if not metrics:
+                 metrics_html = html_utils.create_error_message("No metrics available")
+             else:
+                 metrics_html = ""
+                 for metric_name, value in metrics.items():
+                     metrics_html += html_utils.create_metric_item(metric_name, value)
+         except Exception:
+             metrics_html = html_utils.create_error_message("Error retrieving metrics")
+
+         # Create main content sections
+         main_info = html_utils.create_grid_section(
+             [
+                 ("Model Name", self.model_name),
+                 ("Version", f'<strong style="color: #28a745;">{self.version_name}</strong>'),
+                 ("Full Name", self.fully_qualified_model_name),
+                 ("Description", self.description),
+                 ("Task", task),
+             ]
+         )
+
+         functions_section = html_utils.create_section_header("Functions") + html_utils.create_content_section(
+             functions_html
+         )
+
+         metrics_section = html_utils.create_section_header("Metrics") + html_utils.create_content_section(metrics_html)
+
+         content = main_info + functions_section + metrics_section
+
+         return html_utils.create_base_container("Model Version Details", content)
+
      @classmethod
      def _ref(
          cls,
snowflake/ml/model/_client/ops/model_ops.py CHANGED
@@ -643,14 +643,17 @@ class ModelOperator:
          # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
          fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
 
-         result = []
-
+         result: list[ServiceInfo] = []
          for fully_qualified_service_name in fully_qualified_service_names:
              ingress_url: Optional[str] = None
              db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
-             service_status, _ = self._service_client.get_service_status(
+             statuses = self._service_client.get_service_container_statuses(
                  database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
              )
+             if len(statuses) == 0:
+                 return result
+
+             service_status = statuses[0].service_status
              for res_row in self._service_client.show_endpoints(
                  database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
              ):