snowflake-ml-python 1.9.2__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. snowflake/ml/_internal/utils/service_logger.py +31 -17
  2. snowflake/ml/experiment/callback/keras.py +63 -0
  3. snowflake/ml/experiment/callback/lightgbm.py +59 -0
  4. snowflake/ml/experiment/callback/xgboost.py +67 -0
  5. snowflake/ml/experiment/utils.py +14 -0
  6. snowflake/ml/jobs/_utils/__init__.py +0 -0
  7. snowflake/ml/jobs/_utils/constants.py +4 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +55 -21
  9. snowflake/ml/jobs/_utils/query_helper.py +5 -1
  10. snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
  11. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -5
  13. snowflake/ml/jobs/_utils/spec_utils.py +41 -8
  14. snowflake/ml/jobs/_utils/stage_utils.py +22 -9
  15. snowflake/ml/jobs/_utils/types.py +5 -7
  16. snowflake/ml/jobs/job.py +1 -1
  17. snowflake/ml/jobs/manager.py +1 -13
  18. snowflake/ml/model/_client/model/model_version_impl.py +219 -55
  19. snowflake/ml/model/_client/ops/service_ops.py +230 -30
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
  21. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
  22. snowflake/ml/model/_model_composer/model_composer.py +1 -70
  23. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
  24. snowflake/ml/model/event_handler.py +87 -18
  25. snowflake/ml/model/inference_engine.py +5 -0
  26. snowflake/ml/model/models/huggingface_pipeline.py +74 -51
  27. snowflake/ml/model/type_hints.py +26 -1
  28. snowflake/ml/registry/_manager/model_manager.py +37 -70
  29. snowflake/ml/registry/_manager/model_parameter_reconciler.py +294 -0
  30. snowflake/ml/registry/registry.py +0 -19
  31. snowflake/ml/version.py +1 -1
  32. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +523 -491
  33. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +36 -29
  34. snowflake/ml/experiment/callback.py +0 -121
  35. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
  36. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
  37. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0

snowflake/ml/_internal/utils/service_logger.py

@@ -9,6 +9,15 @@ from typing import Optional
 
 import platformdirs
 
+# Module-level logger for operational messages that should appear on console
+stdout_handler = logging.StreamHandler(sys.stdout)
+stdout_handler.setFormatter(logging.Formatter("%(message)s"))
+
+console_logger = logging.getLogger(__name__)
+console_logger.addHandler(stdout_handler)
+console_logger.setLevel(logging.INFO)
+console_logger.propagate = False
+
 
 class LogColor(enum.Enum):
     GREY = "\x1b[38;20m"
@@ -109,42 +118,36 @@ def _get_or_create_parent_logger(operation_id: str) -> logging.Logger:
     """Get or create a parent logger with FileHandler for the operation."""
     parent_logger_name = f"snowflake_ml_operation_{operation_id}"
     parent_logger = logging.getLogger(parent_logger_name)
+    parent_logger.setLevel(logging.DEBUG)
+    parent_logger.propagate = False
 
-    # Only add handler if it doesn't exist yet
     if not parent_logger.handlers:
         log_file_path = _get_log_file_path(operation_id)
 
         if log_file_path:
-            # Successfully found a writable location
             try:
                 file_handler = logging.FileHandler(log_file_path)
                 file_handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s"))
                 parent_logger.addHandler(file_handler)
-                parent_logger.setLevel(logging.DEBUG)
-                parent_logger.propagate = False  # Don't propagate to root logger
 
-                # Log the file location
-                parent_logger.warning(f"Operation logs saved to: {log_file_path}")
+                console_logger.info(f"create_service logs saved to: {log_file_path}")
             except OSError as e:
-                # Even though we found a path, file creation failed
-                # Fall back to console-only logging
-                parent_logger.setLevel(logging.DEBUG)
-                parent_logger.propagate = False
-                parent_logger.warning(f"Could not create log file at {log_file_path}: {e}. Using console-only logging.")
+                console_logger.warning(f"Could not create log file at {log_file_path}: {e}.")
         else:
             # No writable location found, use console-only logging
-            parent_logger.setLevel(logging.DEBUG)
-            parent_logger.propagate = False
-            parent_logger.warning("Filesystem appears to be readonly. Using console-only logging.")
+            console_logger.warning("No writable location found for create_service log file.")
+
+        if logging.getLogger().level > logging.INFO:
+            console_logger.info(
+                "To see logs in console, set log level to INFO: logging.getLogger().setLevel(logging.INFO)"
+            )
 
     return parent_logger
 
 
 def get_logger(logger_name: str, info_color: LogColor, operation_id: Optional[str] = None) -> logging.Logger:
     logger = logging.getLogger(logger_name)
-    handler = logging.StreamHandler(sys.stdout)
-    handler.setFormatter(CustomFormatter(info_color))
-    logger.addHandler(handler)
+    root_logger = logging.getLogger()
 
     # If operation_id provided, set up parent logger with file handler
     if operation_id:
@@ -152,6 +155,17 @@ def get_logger(logger_name: str, info_color: LogColor, operation_id: Optional[st
         logger.parent = parent_logger
         logger.propagate = True
 
+        if root_logger.level <= logging.INFO:
+            handler = logging.StreamHandler(sys.stdout)
+            handler.setFormatter(CustomFormatter(info_color))
+            logger.addHandler(handler)
+    else:
+        # No operation_id - add console handler only if user wants verbose logging
+        if root_logger.level <= logging.INFO and not logger.handlers:
+            handler = logging.StreamHandler(sys.stdout)
+            handler.setFormatter(CustomFormatter(info_color))
+            logger.addHandler(handler)
+
     return logger
 
 

snowflake/ml/experiment/callback/keras.py

@@ -0,0 +1,63 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional
+from warnings import warn
+
+import keras
+
+from snowflake.ml.experiment import utils
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+    from snowflake.ml.model.model_signature import ModelSignature
+
+
+class SnowflakeKerasCallback(keras.callbacks.Callback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        log_every_n_epochs: int = 1,
+        model_name: Optional[str] = None,
+        model_signature: Optional["ModelSignature"] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        if log_every_n_epochs < 1:
+            raise ValueError("`log_every_n_epochs` must be positive.")
+        self.log_every_n_epochs = log_every_n_epochs
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+    def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
+        if self.log_params:
+            params = json.loads(self.model.to_json())
+            self._experiment_tracking.log_params(utils.flatten_nested_params(params))
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict[str, Any]] = None) -> None:
+        if self.log_metrics and logs and epoch % self.log_every_n_epochs == 0:
+            for key, value in logs.items():
+                try:
+                    value = float(value)
+                except Exception:
+                    pass
+                else:
+                    self._experiment_tracking.log_metric(key=key, value=value, step=epoch)
+
+    def on_train_end(self, logs: Optional[dict[str, Any]] = None) -> None:
+        if self.log_model:
+            if not self.model_signature:
+                warn(
+                    "Model will not be logged because model signature is missing. "
+                    "To autolog the model, please specify `model_signature` when constructing SnowflakeKerasCallback."
+                )
+                return
+            model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+            self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                model=self.model,
+                model_name=model_name,
+                signatures={"predict": self.model_signature},
+            )
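
For orientation, a minimal usage sketch of the new Keras callback follows. It is not part of the diff; `exp`, `sig`, `model`, and the training arrays are hypothetical placeholders for an `ExperimentTracking` instance, a `ModelSignature`, and a compiled Keras model created elsewhere.

    # Hypothetical usage sketch; exp (ExperimentTracking), sig (ModelSignature),
    # model (compiled keras.Model), and the arrays are assumed to exist already.
    from snowflake.ml.experiment.callback.keras import SnowflakeKerasCallback

    cb = SnowflakeKerasCallback(
        experiment_tracking=exp,
        log_every_n_epochs=5,           # metrics are logged on every 5th epoch
        model_name="demo_keras_model",
        model_signature=sig,            # required for the trained model to be logged
    )
    model.fit(x_train, y_train, epochs=20, callbacks=[cb])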

snowflake/ml/experiment/callback/lightgbm.py

@@ -0,0 +1,59 @@
+from typing import TYPE_CHECKING, Optional
+from warnings import warn
+
+import lightgbm as lgb
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+    from snowflake.ml.model.model_signature import ModelSignature
+
+
+class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        log_every_n_epochs: int = 1,
+        model_name: Optional[str] = None,
+        model_signature: Optional["ModelSignature"] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        if log_every_n_epochs < 1:
+            raise ValueError("`log_every_n_epochs` must be positive.")
+        self.log_every_n_epochs = log_every_n_epochs
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+        super().__init__(eval_result={})
+
+    def __call__(self, env: lgb.callback.CallbackEnv) -> None:
+        if self.log_params:
+            if env.iteration == env.begin_iteration:  # Log params only at the first iteration
+                self._experiment_tracking.log_params(env.params)
+
+        if self.log_metrics and env.iteration % self.log_every_n_epochs == 0:
+            super().__call__(env)
+            for dataset_name, metrics in self.eval_result.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=env.iteration)
+
+        if self.log_model:
+            if env.iteration == env.end_iteration - 1:  # Log model only at the last iteration
+                if self.model_signature:
+                    model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+                    self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                        model=env.model,
+                        model_name=model_name,
+                        signatures={"predict": self.model_signature},
+                    )
+                else:
+                    warn(
+                        "Model will not be logged because model signature is missing. To autolog the model, "
+                        "please specify `model_signature` when constructing SnowflakeLightgbmCallback."
+                    )
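
As above, a hypothetical sketch of how the LightGBM callback might be wired into training; `exp`, `sig`, and the datasets are placeholders, not part of the diff.

    # Hypothetical usage sketch; exp, sig, and the training/validation data are assumed.
    import lightgbm as lgb
    from snowflake.ml.experiment.callback.lightgbm import SnowflakeLightgbmCallback

    cb = SnowflakeLightgbmCallback(experiment_tracking=exp, model_signature=sig)
    booster = lgb.train(
        {"objective": "binary"},
        lgb.Dataset(x_train, label=y_train),
        valid_sets=[lgb.Dataset(x_valid, label=y_valid)],
        valid_names=["valid"],
        callbacks=[cb],                 # metrics are logged under "valid:<metric_name>"
    )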

snowflake/ml/experiment/callback/xgboost.py

@@ -0,0 +1,67 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional
+from warnings import warn
+
+import xgboost as xgb
+
+from snowflake.ml.experiment import utils
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+    from snowflake.ml.model.model_signature import ModelSignature
+
+
+class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        log_every_n_epochs: int = 1,
+        model_name: Optional[str] = None,
+        model_signature: Optional["ModelSignature"] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        if log_every_n_epochs < 1:
+            raise ValueError("`log_every_n_epochs` must be positive.")
+        self.log_every_n_epochs = log_every_n_epochs
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+    def before_training(self, model: xgb.Booster) -> xgb.Booster:
+        if self.log_params:
+            params = json.loads(model.save_config())
+            self._experiment_tracking.log_params(utils.flatten_nested_params(params))
+
+        return model
+
+    def after_iteration(self, model: Any, epoch: int, evals_log: dict[str, dict[str, Any]]) -> bool:
+        if self.log_metrics and epoch % self.log_every_n_epochs == 0:
+            for dataset_name, metrics in evals_log.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=epoch)
+
+        return False
+
+    def after_training(self, model: xgb.Booster) -> xgb.Booster:
+        if self.log_model:
+            if not self.model_signature:
+                warn(
+                    "Model will not be logged because model signature is missing. "
+                    "To autolog the model, please specify `model_signature` when constructing SnowflakeXgboostCallback."
+                )
+                return model
+
+            model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+            self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                model=model,
+                model_name=model_name,
+                signatures={"predict": self.model_signature},
+            )
+
+        return model
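
A hypothetical sketch for the XGBoost callback follows; as before, `exp`, `sig`, and the data are placeholders rather than part of the diff.

    # Hypothetical usage sketch; exp, sig, x_train, and y_train are assumed to exist.
    import xgboost as xgb
    from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback

    cb = SnowflakeXgboostCallback(experiment_tracking=exp, model_signature=sig)
    dtrain = xgb.DMatrix(x_train, label=y_train)
    booster = xgb.train(
        {"objective": "binary:logistic"},
        dtrain,
        num_boost_round=50,
        evals=[(dtrain, "train")],      # metrics are logged under "train:<metric_name>"
        callbacks=[cb],
    )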

snowflake/ml/experiment/utils.py

@@ -0,0 +1,14 @@
+from typing import Any, Union
+
+
+def flatten_nested_params(params: Union[list[Any], dict[str, Any]], prefix: str = "") -> dict[str, Any]:
+    flat_params = {}
+    items = params.items() if isinstance(params, dict) else enumerate(params)
+    for key, value in items:
+        key = str(key).replace(".", "_")  # Replace dots in keys to avoid collisions involving nested keys
+        new_prefix = f"{prefix}.{key}" if prefix else key
+        if isinstance(value, (dict, list)):
+            flat_params.update(flatten_nested_params(value, new_prefix))
+        else:
+            flat_params[new_prefix] = value
+    return flat_params
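
A small worked example of the helper above (input values are illustrative): nested dicts and lists are flattened into dot-joined keys, with dots inside original keys replaced by underscores.

    nested = {"optimizer": {"name": "adam", "lr": 0.01}, "layers": [{"units": 32}]}
    flatten_nested_params(nested)
    # -> {"optimizer.name": "adam", "optimizer.lr": 0.01, "layers.0.units": 32}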

snowflake/ml/jobs/_utils/__init__.py: File without changes

snowflake/ml/jobs/_utils/constants.py

@@ -28,7 +28,7 @@ OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "1.5.0"
+DEFAULT_IMAGE_TAG = "1.6.2"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
@@ -98,3 +98,6 @@ CLOUD_INSTANCE_FAMILIES = {
     SnowflakeCloudType.AWS: AWS_INSTANCE_FAMILIES,
     SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
 }
+
+# runtime version environment variable
+ENABLE_IMAGE_VERSION_ENV_VAR = "MLRS_ENABLE_RUNTIME_VERSIONS"
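
The new constant is read later in this diff by spec_utils._get_image_spec; a hedged illustration of the opt-in:

    # Illustrative only: opt in to server-advertised runtime images before submitting a job.
    import os

    os.environ["MLRS_ENABLE_RUNTIME_VERSIONS"] = "true"
    # With the flag set, _get_image_spec() queries SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()
    # and falls back to DEFAULT_IMAGE_REPO/DEFAULT_IMAGE_CPU|GPU:DEFAULT_IMAGE_TAG otherwise.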

snowflake/ml/jobs/_utils/payload_utils.py

@@ -1,4 +1,5 @@
 import functools
+import importlib
 import inspect
 import io
 import itertools
@@ -7,6 +8,7 @@ import logging
 import pickle
 import sys
 import textwrap
+from importlib.abc import Traversable
 from pathlib import Path, PurePath
 from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
 
@@ -63,6 +65,13 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
     ##### Set up Python environment #####
     export PYTHONPATH=/opt/env/site-packages/
+    MLRS_SYSTEM_REQUIREMENTS_FILE=${{MLRS_SYSTEM_REQUIREMENTS_FILE:-"${{SYSTEM_DIR}}/requirements.txt"}}
+
+    if [ -f "${{MLRS_SYSTEM_REQUIREMENTS_FILE}}" ]; then
+        echo "Installing packages from $MLRS_SYSTEM_REQUIREMENTS_FILE"
+        pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
+    fi
+
     MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
     if [ -f "${{MLRS_REQUIREMENTS_FILE}}" ]; then
         # TODO: Prevent collisions with MLRS packages using virtualenvs
@@ -255,11 +264,24 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
     # Manually traverse the directory and upload each file, since Snowflake PUT
     # can't handle directories. Reduce the number of PUT operations by using
     # wildcard patterns to batch upload files with the same extension.
-    for path in {
-        p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
-        for p in source_path.resolve().rglob("*")
-        if p.is_file()
-    }:
+    upload_path_patterns = set()
+    for p in source_path.resolve().rglob("*"):
+        if p.is_dir():
+            continue
+        if p.name.startswith("."):
+            # Hidden files: use .* pattern for batch upload
+            if p.suffix:
+                upload_path_patterns.add(p.parent.joinpath(f".*{p.suffix}"))
+            else:
+                upload_path_patterns.add(p.parent.joinpath(".*"))
+        else:
+            # Regular files: use * pattern for batch upload
+            if p.suffix:
+                upload_path_patterns.add(p.parent.joinpath(f"*{p.suffix}"))
+            else:
+                upload_path_patterns.add(p)
+
+    for path in upload_path_patterns:
         session.file.put(
             str(path),
             payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
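
The hunk above replaces the single wildcard set comprehension with explicit batching rules: hidden files get a ".*" style pattern, visible files with an extension get "*<ext>", and extensionless visible files are uploaded individually. A standalone illustration with hypothetical file names:

    # Standalone illustration of the new batching rules; file names are hypothetical.
    from pathlib import Path

    files = [Path("src/a.py"), Path("src/b.py"), Path("src/.env"), Path("src/.secret.cfg"), Path("src/Makefile")]
    patterns = set()
    for p in files:
        if p.name.startswith("."):
            patterns.add(p.parent / (f".*{p.suffix}" if p.suffix else ".*"))
        elif p.suffix:
            patterns.add(p.parent / f"*{p.suffix}")
        else:
            patterns.add(p)
    # On POSIX: patterns == {Path("src/*.py"), Path("src/.*"), Path("src/.*.cfg"), Path("src/Makefile")}
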
@@ -275,6 +297,27 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
         )
 
 
+def upload_system_resources(session: snowpark.Session, stage_path: PurePath) -> None:
+    resource_ref = importlib.resources.files(__package__).joinpath("scripts")
+
+    def upload_dir(ref: Traversable, relative_path: str = "") -> None:
+        for item in ref.iterdir():
+            current_path = Path(relative_path) / item.name if relative_path else Path(item.name)
+            if item.is_dir():
+                # Recursively process subdirectories
+                upload_dir(item, str(current_path))
+            elif item.is_file():
+                content = item.read_bytes()
+                session.file.put_stream(
+                    io.BytesIO(content),
+                    stage_path.joinpath(current_path).as_posix(),
+                    auto_compress=False,
+                    overwrite=True,
+                )
+
+    upload_dir(resource_ref)
+
+
 def resolve_source(
     source: Union[types.PayloadPath, Callable[..., Any]]
 ) -> Union[types.PayloadPath, Callable[..., Any]]:
@@ -454,8 +497,6 @@ class JobPayload:
                 overwrite=True,
             )
             source = Path(entrypoint.file_path.parent)
-            if not any(r.startswith("cloudpickle") for r in pip_requirements):
-                pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
 
         elif isinstance(source, stage_utils.StagePath):
             # copy payload to stage
@@ -470,19 +511,20 @@
 
         upload_payloads(session, app_stage_path, *additional_payload_specs)
 
-        # Upload requirements to app/ directory
-        # TODO: Check if payload includes both a requirements.txt file and pip_requirements
+        if not any(r.startswith("cloudpickle") for r in pip_requirements):
+            pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
+
+        # Upload system scripts and requirements.txt generated by pip_requirements to system/ directory
+        system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
         if pip_requirements:
             # Upload requirements.txt to stage
             session.file.put_stream(
                 io.BytesIO("\n".join(pip_requirements).encode()),
-                stage_location=app_stage_path.joinpath("requirements.txt").as_posix(),
+                stage_location=system_stage_path.joinpath("requirements.txt").as_posix(),
                 auto_compress=False,
                 overwrite=True,
            )
 
-        # Upload startup script to system/ directory within payload
-        system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
         # TODO: Make sure payload does not include file with same name
         session.file.put_stream(
             io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
@@ -491,15 +533,7 @@ class JobPayload:
             overwrite=False,  # FIXME
         )
 
-        scripts_dir = Path(__file__).parent.joinpath("scripts")
-        for script_file in scripts_dir.glob("*"):
-            if script_file.is_file():
-                session.file.put(
-                    script_file.as_posix(),
-                    system_stage_path.as_posix(),
-                    overwrite=True,
-                    auto_compress=False,
-                )
+        upload_system_resources(session, system_stage_path)
         python_entrypoint: list[Union[str, PurePath]] = [
             PurePath(f"{constants.SYSTEM_MOUNT_PATH}/mljob_launcher.py"),
             PurePath(f"{constants.APP_MOUNT_PATH}/{entrypoint.file_path.relative_to(source).as_posix()}"),

snowflake/ml/jobs/_utils/query_helper.py

@@ -4,6 +4,7 @@ from snowflake import snowpark
 from snowflake.snowpark import Row
 from snowflake.snowpark._internal import utils
 from snowflake.snowpark._internal.analyzer import snowflake_plan
+from snowflake.snowpark._internal.utils import is_in_stored_procedure
 
 
 def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> list[Row]:
@@ -14,7 +15,10 @@ def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> lis
 
 @snowflake_plan.SnowflakePlan.Decorator.wrap_exception  # type: ignore[misc]
 def run_query(session: snowpark.Session, query_text: str, params: Optional[Sequence[Any]] = None) -> list[Row]:
-    result = session._conn.run_query(query=query_text, params=params, _force_qmark_paramstyle=True)
+    kwargs: dict[str, Any] = {"query": query_text, "params": params}
+    if not is_in_stored_procedure():  # type: ignore[no-untyped-call]
+        kwargs["_force_qmark_paramstyle"] = True
+    result = session._conn.run_query(**kwargs)
     if not isinstance(result, dict) or "data" not in result:
         raise ValueError(f"Unprocessable result: {result}")
     return result_set_to_rows(session, result)

snowflake/ml/jobs/_utils/runtime_env_utils.py

@@ -0,0 +1,63 @@
+from typing import Any, Optional, Union
+
+from packaging.version import Version
+from pydantic import BaseModel, Field, RootModel, field_validator
+
+
+class SpcsContainerRuntime(BaseModel):
+    python_version: Version = Field(alias="pythonVersion")
+    hardware_type: str = Field(alias="hardwareType")
+    runtime_container_image: str = Field(alias="runtimeContainerImage")
+
+    @field_validator("python_version", mode="before")
+    @classmethod
+    def validate_python_version(cls, v: Union[str, Version]) -> Version:
+        if isinstance(v, Version):
+            return v
+        try:
+            return Version(v)
+        except Exception:
+            raise ValueError(f"Invalid Python version format: {v}")
+
+    class Config:
+        frozen = True
+        extra = "allow"
+        arbitrary_types_allowed = True
+
+
+class RuntimeEnvironmentEntry(BaseModel):
+    spcs_container_runtime: Optional[SpcsContainerRuntime] = Field(alias="spcsContainerRuntime", default=None)
+
+    class Config:
+        extra = "allow"
+        frozen = True
+
+
+class RuntimeEnvironmentsDict(RootModel[dict[str, RuntimeEnvironmentEntry]]):
+    @field_validator("root", mode="before")
+    @classmethod
+    def _filter_to_dict_entries(cls, data: Any) -> dict[str, dict[str, Any]]:
+        """
+        Pre-validation hook: keep only those items at the root level
+        whose values are dicts. Non-dict values will be dropped.
+
+        Args:
+            data: The input data to filter, expected to be a dictionary.
+
+        Returns:
+            A dictionary containing only the key-value pairs where values are dictionaries.
+
+        Raises:
+            ValueError: If input data is not a dictionary.
+        """
+        # If the entire root is not a dict, raise error immediately
+        if not isinstance(data, dict):
+            raise ValueError(f"Expected dictionary data, but got {type(data).__name__}: {data}")
+
+        # Filter out any key whose value is not a dict
+        return {key: value for key, value in data.items() if isinstance(value, dict)}
+
+    def get_spcs_container_runtimes(self) -> list[SpcsContainerRuntime]:
+        return [
+            entry.spcs_container_runtime for entry in self.root.values() if entry.spcs_container_runtime is not None
+        ]
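
A minimal parsing sketch for the model above; the JSON payload shape and key names are illustrative except for the aliases defined in the classes (pythonVersion, hardwareType, runtimeContainerImage, spcsContainerRuntime).

    # Hypothetical payload; entries whose values are not objects are dropped by the pre-validator.
    payload = (
        '{"py310_gpu": {"spcsContainerRuntime": {"pythonVersion": "3.10.14", '
        '"hardwareType": "GPU", "runtimeContainerImage": "repo/image:tag"}}, '
        '"comment": "ignored"}'
    )
    envs = RuntimeEnvironmentsDict.model_validate_json(payload)
    runtimes = envs.get_spcs_container_runtimes()
    # runtimes[0].python_version -> Version("3.10.14"); runtimes[0].runtime_container_image -> "repo/image:tag"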

snowflake/ml/jobs/_utils/scripts/get_instance_ip.py

@@ -47,8 +47,8 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str, str]]:
     if not result:
         return None
 
-    # Sort by start_time first, then by instance_id
-    sorted_instances = sorted(result, key=lambda x: (x["start_time"], int(x["instance_id"])))
+    # Sort by start_time first, then by instance_id. If start_time is null/empty, it will be sorted to the end.
+    sorted_instances = sorted(result, key=lambda x: (not bool(x["start_time"]), x["start_time"], int(x["instance_id"])))
     head_instance = sorted_instances[0]
     if not head_instance["instance_id"] or not head_instance["ip_address"]:
         return None

snowflake/ml/jobs/_utils/scripts/mljob_launcher.py

@@ -173,10 +173,10 @@
 
     start_time = time.time()
     current_interval = max(min(1, check_interval), 0.1)  # Default 1s, minimum 0.1s
-    logger.debug(
+    logger.info(
         "Waiting for instances to be ready "
-        "(min_instances={}, target_instances={}, timeout={}s, max_check_interval={}s)".format(
-            min_instances, target_instances, timeout, check_interval
+        "(min_instances={}, target_instances={}, min_wait_time={}s, timeout={}s, max_check_interval={}s)".format(
+            min_instances, target_instances, min_wait_time, timeout, check_interval
         )
     )
 
@@ -191,7 +191,7 @@
             logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
             return
 
-        logger.debug(
+        logger.info(
            f"Waiting for instances: current_instances={total_nodes}, min_instances={min_instances}, "
            f"target_instances={target_instances}, elapsed={elapsed:.1f}s, next check in {current_interval:.1f}s"
        )
@@ -199,7 +199,7 @@
         current_interval = min(current_interval * 2, check_interval)  # Exponential backoff
 
     raise TimeoutError(
-        f"Timed out after {timeout}s waiting for {min_instances} instances, only " f"{total_nodes} available"
+        f"Timed out after {elapsed}s waiting for {min_instances} instances, only " f"{total_nodes} available"
    )
 
 

snowflake/ml/jobs/_utils/spec_utils.py

@@ -1,12 +1,14 @@
 import logging
 import os
+import sys
 from math import ceil
 from pathlib import PurePath
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
 from snowflake.ml.jobs._utils import constants, query_helper, types
+from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict
 
 
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
@@ -28,22 +30,53 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
     )
 
 
+def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU", "GPU"]) -> Optional[str]:
+    rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
+    if not rows:
+        return None
+    try:
+        runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
+        spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes()
+    except Exception as e:
+        logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
+        return None
+
+    selected_runtime = next(
+        (
+            runtime
+            for runtime in spcs_container_runtimes
+            if (
+                runtime.hardware_type.lower() == target_hardware.lower()
+                and runtime.python_version.major == sys.version_info.major
+                and runtime.python_version.minor == sys.version_info.minor
+            )
+        ),
+        None,
+    )
+    return selected_runtime.runtime_container_image if selected_runtime else None
+
+
 def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
     # Retrieve compute pool node resources
     resources = _get_node_resources(session, compute_pool=compute_pool)
 
     # Use MLRuntime image
-    image_repo = constants.DEFAULT_IMAGE_REPO
-    image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-    image_tag = _get_runtime_image_tag()
+    hardware = "GPU" if resources.gpu > 0 else "CPU"
+    container_image = None
+    if os.environ.get(constants.ENABLE_IMAGE_VERSION_ENV_VAR, "").lower() == "true":
+        container_image = _get_runtime_image(session, hardware)  # type: ignore[arg-type]
+
+    if not container_image:
+        image_repo = constants.DEFAULT_IMAGE_REPO
+        image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
+        image_tag = _get_runtime_image_tag()
+        container_image = f"{image_repo}/{image_name}:{image_tag}"
 
     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
-        repo=image_repo,
-        image_name=image_name,
-        image_tag=image_tag,
         resource_requests=resources,
         resource_limits=resources,
+        container_image=container_image,
     )
 
 
@@ -220,7 +253,7 @@
         "containers": [
             {
                 "name": constants.DEFAULT_CONTAINER_NAME,
-                "image": image_spec.full_name,
+                "image": image_spec.container_image,
                 "command": ["/usr/local/bin/_entrypoint.sh"],
                 "args": [
                     (stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint