snowflake-ml-python 1.9.1__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. snowflake/ml/_internal/utils/mixins.py +6 -4
  2. snowflake/ml/_internal/utils/service_logger.py +118 -4
  3. snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
  4. snowflake/ml/data/data_connector.py +4 -34
  5. snowflake/ml/dataset/dataset.py +1 -1
  6. snowflake/ml/dataset/dataset_reader.py +2 -8
  7. snowflake/ml/experiment/__init__.py +3 -0
  8. snowflake/ml/experiment/callback/lightgbm.py +55 -0
  9. snowflake/ml/experiment/callback/xgboost.py +63 -0
  10. snowflake/ml/experiment/utils.py +14 -0
  11. snowflake/ml/jobs/_utils/constants.py +15 -4
  12. snowflake/ml/jobs/_utils/payload_utils.py +159 -52
  13. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  14. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +126 -23
  15. snowflake/ml/jobs/_utils/spec_utils.py +1 -1
  16. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  17. snowflake/ml/jobs/_utils/types.py +64 -4
  18. snowflake/ml/jobs/job.py +22 -6
  19. snowflake/ml/jobs/manager.py +5 -3
  20. snowflake/ml/model/_client/model/model_version_impl.py +56 -48
  21. snowflake/ml/model/_client/ops/service_ops.py +194 -14
  22. snowflake/ml/model/_client/sql/service.py +1 -38
  23. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  24. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
  25. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  26. snowflake/ml/model/_signatures/utils.py +4 -0
  27. snowflake/ml/model/event_handler.py +87 -18
  28. snowflake/ml/model/model_signature.py +2 -0
  29. snowflake/ml/model/models/huggingface_pipeline.py +71 -49
  30. snowflake/ml/model/type_hints.py +26 -1
  31. snowflake/ml/registry/_manager/model_manager.py +30 -35
  32. snowflake/ml/registry/_manager/model_parameter_reconciler.py +105 -0
  33. snowflake/ml/registry/registry.py +0 -19
  34. snowflake/ml/version.py +1 -1
  35. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/METADATA +542 -491
  36. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/RECORD +39 -34
  37. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/WHEEL +0 -0
  38. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/licenses/LICENSE.txt +0 -0
  39. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/top_level.txt +0 -0

snowflake/ml/_internal/utils/mixins.py
@@ -21,10 +21,12 @@ class SerializableSessionMixin:
 
     def __getstate__(self) -> dict[str, Any]:
         """Customize pickling to exclude non-serializable session and related components."""
-        if hasattr(super(), "__getstate__"):
-            state: dict[str, Any] = super().__getstate__()  # type: ignore[misc]
-        else:
-            state = self.__dict__.copy()
+        parent_state = (
+            super().__getstate__()  # type: ignore[misc] # object.__getstate__ appears in 3.11
+            if hasattr(super(), "__getstate__")
+            else self.__dict__
+        )
+        state = dict(parent_state)  # Create a copy so we can safely modify the state
 
         # Save session metadata for validation during unpickling
         session = state.pop(_SESSION_KEY, None)
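
On Python 3.11+, `object.__getstate__` exists, so the old `hasattr` branch started returning the parent's state directly; copying it with `dict(...)` before popping keys keeps the live `__dict__` unmodified. A minimal standalone sketch of the pattern (simplified; `_SESSION_KEY` and the real mixin's metadata handling are reduced to a single pop):

    import pickle
    from typing import Any

    _SESSION_KEY = "_session"  # same key name the mixin uses

    class PickleSafe:
        def __getstate__(self) -> dict[str, Any]:
            # object.__getstate__ appears in Python 3.11; fall back to __dict__ before that
            parent_state = (
                super().__getstate__() if hasattr(super(), "__getstate__") else self.__dict__
            )
            state = dict(parent_state)  # copy so the pop below cannot mutate self.__dict__
            state.pop(_SESSION_KEY, None)  # drop the session before pickling
            return state

    obj = PickleSafe()
    obj._session = object()  # stand-in for a live session
    obj.name = "example"
    restored = pickle.loads(pickle.dumps(obj))
    assert restored.name == "example" and not hasattr(restored, "_session")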

snowflake/ml/_internal/utils/service_logger.py
@@ -1,6 +1,22 @@
 import enum
 import logging
+import os
 import sys
+import tempfile
+import time
+import uuid
+from typing import Optional
+
+import platformdirs
+
+# Module-level logger for operational messages that should appear on console
+stdout_handler = logging.StreamHandler(sys.stdout)
+stdout_handler.setFormatter(logging.Formatter("%(message)s"))
+
+console_logger = logging.getLogger(__name__)
+console_logger.addHandler(stdout_handler)
+console_logger.setLevel(logging.INFO)
+console_logger.propagate = False
 
 
 class LogColor(enum.Enum):
@@ -57,9 +73,107 @@ class CustomFormatter(logging.Formatter):
         return "\n".join(formatted_lines)
 
 
-def get_logger(logger_name: str, info_color: LogColor) -> logging.Logger:
+def _test_writability(directory: str) -> bool:
+    """Test if a directory is writable by creating and removing a test file."""
+    try:
+        os.makedirs(directory, exist_ok=True)
+        test_file = os.path.join(directory, f".write_test_{uuid.uuid4().hex[:8]}")
+        with open(test_file, "w") as f:
+            f.write("test")
+        os.remove(test_file)
+        return True
+    except OSError:
+        return False
+
+
+def _try_log_location(log_dir: str, operation_id: str) -> Optional[str]:
+    """Try to create a log file in the given directory if it's writable."""
+    if _test_writability(log_dir):
+        return os.path.join(log_dir, f"{operation_id}.log")
+    return None
+
+
+def _get_log_file_path(operation_id: str) -> Optional[str]:
+    """Get platform-independent log file path. Returns None if no writable location found."""
+    # Try locations in order of preference
+    locations = [
+        # Primary: User log directory
+        platformdirs.user_log_dir("snowflake-ml", "Snowflake"),
+        # Fallback 1: System temp directory
+        os.path.join(tempfile.gettempdir(), "snowflake-ml-logs"),
+        # Fallback 2: Current working directory
+        ".",
+    ]
+
+    for location in locations:
+        log_file_path = _try_log_location(location, operation_id)
+        if log_file_path:
+            return log_file_path
+
+    # No writable location found
+    return None
+
+
+def _get_or_create_parent_logger(operation_id: str) -> logging.Logger:
+    """Get or create a parent logger with FileHandler for the operation."""
+    parent_logger_name = f"snowflake_ml_operation_{operation_id}"
+    parent_logger = logging.getLogger(parent_logger_name)
+    parent_logger.setLevel(logging.DEBUG)
+    parent_logger.propagate = False
+
+    if not parent_logger.handlers:
+        log_file_path = _get_log_file_path(operation_id)
+
+        if log_file_path:
+            try:
+                file_handler = logging.FileHandler(log_file_path)
+                file_handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s"))
+                parent_logger.addHandler(file_handler)
+
+                console_logger.info(f"create_service logs saved to: {log_file_path}")
+            except OSError as e:
+                console_logger.warning(f"Could not create log file at {log_file_path}: {e}.")
+        else:
+            # No writable location found, use console-only logging
+            console_logger.warning("No writable location found for create_service log file.")
+
+        if logging.getLogger().level > logging.INFO:
+            console_logger.info(
+                "To see logs in console, set log level to INFO: logging.getLogger().setLevel(logging.INFO)"
+            )
+
+    return parent_logger
+
+
+def get_logger(logger_name: str, info_color: LogColor, operation_id: Optional[str] = None) -> logging.Logger:
     logger = logging.getLogger(logger_name)
-    handler = logging.StreamHandler(sys.stdout)
-    handler.setFormatter(CustomFormatter(info_color))
-    logger.addHandler(handler)
+    root_logger = logging.getLogger()
+
+    # If operation_id provided, set up parent logger with file handler
+    if operation_id:
+        parent_logger = _get_or_create_parent_logger(operation_id)
+        logger.parent = parent_logger
+        logger.propagate = True
+
+        if root_logger.level <= logging.INFO:
+            handler = logging.StreamHandler(sys.stdout)
+            handler.setFormatter(CustomFormatter(info_color))
+            logger.addHandler(handler)
+    else:
+        # No operation_id - add console handler only if user wants verbose logging
+        if root_logger.level <= logging.INFO and not logger.handlers:
+            handler = logging.StreamHandler(sys.stdout)
+            handler.setFormatter(CustomFormatter(info_color))
+            logger.addHandler(handler)
+
     return logger
+
+
+def get_operation_id() -> str:
+    """Generate a unique operation ID."""
+    return f"model_deploy_{uuid.uuid4().hex[:8]}_{int(time.time())}"
+
+
+def get_log_file_location(operation_id: str) -> Optional[str]:
+    """Get the log file path for an operation ID. Returns None if no writable location available."""
+    return _get_log_file_path(operation_id)
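
End to end, these helpers give each long-running operation its own log file plus minimal console output. A hedged usage sketch (the `LogColor` member name is an assumption, since the enum's values are outside this hunk, and the file location is platform-dependent):

    from snowflake.ml._internal.utils import service_logger

    op_id = service_logger.get_operation_id()  # e.g. "model_deploy_1a2b3c4d_1758000000"
    logger = service_logger.get_logger(
        "snowflake.ml.service_example", service_logger.LogColor.GREEN, operation_id=op_id
    )
    logger.info("deployment started")  # colored on console (at INFO), plain in the file
    print(service_logger.get_log_file_location(op_id))  # same path reported by console_logger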

snowflake/ml/data/_internal/arrow_ingestor.py
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
     import ray
 
 from snowflake import snowpark
+from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_ingestor, data_source, ingestor_utils
 
 _EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
@@ -44,7 +45,7 @@ class _RecordBatchesBuffer:
         return popped
 
 
-class ArrowIngestor(data_ingestor.DataIngestor):
+class ArrowIngestor(data_ingestor.DataIngestor, mixins.SerializableSessionMixin):
     """Read and parse the data sources into an Arrow Dataset and yield batched numpy array in dict."""
 
     def __init__(
@@ -71,6 +72,8 @@ class ArrowIngestor(data_ingestor.DataIngestor):
 
     @classmethod
    def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "ArrowIngestor":
+        if session is None:
+            raise ValueError("Session is required")
         return cls(session, sources)
 
     @classmethod
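
The new guard makes a missing session fail at construction time instead of surfacing later during ingestion:

    from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor

    try:
        ArrowIngestor.from_sources(None, [])  # type: ignore[arg-type]
    except ValueError as exc:
        print(exc)  # Session is required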

snowflake/ml/data/data_connector.py
@@ -6,10 +6,9 @@ from typing_extensions import deprecated
 
 from snowflake import snowpark
 from snowflake.ml._internal import env, telemetry
-from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_ingestor, data_source
 from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
-from snowflake.snowpark import context as sf_context
+from snowflake.snowpark import context as sp_context
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -22,13 +21,11 @@ if TYPE_CHECKING:
     from snowflake.ml import dataset
 
 _PROJECT = "DataConnector"
-_INGESTOR_KEY = "_ingestor"
-_INGESTOR_SOURCES_KEY = "ingestor$sources"
 
 DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
 
 
-class DataConnector(mixins.SerializableSessionMixin):
+class DataConnector:
     """Snowflake data reader which provides application integration connectors"""
 
     DEFAULT_INGESTOR_CLASS: type[data_ingestor.DataIngestor] = ArrowIngestor
@@ -36,11 +33,8 @@ class DataConnector(mixins.SerializableSessionMixin):
     def __init__(
         self,
         ingestor: data_ingestor.DataIngestor,
-        *,
-        session: Optional[snowpark.Session] = None,
         **kwargs: Any,
     ) -> None:
-        self._session = session
         self._ingestor = ingestor
         self._kwargs = kwargs
 
@@ -63,7 +57,7 @@ class DataConnector(mixins.SerializableSessionMixin):
         ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
-        session = session or sf_context.get_active_session()
+        session = session or sp_context.get_active_session()
         source = data_source.DataFrameInfo(query)
         return cls.from_sources(session, [source], ingestor_class=ingestor_class, **kwargs)
 
@@ -107,31 +101,7 @@ class DataConnector(mixins.SerializableSessionMixin):
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
         ingestor = ingestor_class.from_sources(session, sources)
-        return cls(ingestor, **kwargs, session=session)
-
-    def __getstate__(self) -> dict[str, Any]:
-        """Customize pickling to exclude non-serializable session and related components."""
-        if hasattr(super(), "__getstate__"):
-            state = super().__getstate__()
-        else:
-            state = self.__dict__.copy()
-
-        ingestor = state.pop(_INGESTOR_KEY)
-        state[_INGESTOR_SOURCES_KEY] = ingestor.data_sources
-
-        return state
-
-    def __setstate__(self, state: dict[str, Any]) -> None:
-        """Restore session from context during unpickling."""
-        data_sources = state.pop(_INGESTOR_SOURCES_KEY)
-
-        if hasattr(super(), "__setstate__"):
-            super().__setstate__(state)
-        else:
-            self.__dict__.update(state)
-
-        assert self._session is not None
-        self._ingestor = self.DEFAULT_INGESTOR_CLASS.from_sources(self._session, data_sources)
+        return cls(ingestor, **kwargs)
 
     @property
     def data_sources(self) -> list[data_source.DataSource]:
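
Net effect of the arrow_ingestor and data_connector changes: DataConnector no longer carries a session or custom pickling; both now come from the ingestor, which inherits SerializableSessionMixin. A hedged round-trip sketch (`session` is an assumed existing Snowpark session, and the classmethod name `from_sql` is inferred, since the hunk above shows only its body):

    import pickle

    from snowflake.ml.data.data_connector import DataConnector

    dc = DataConnector.from_sql("SELECT * FROM MY_TABLE", session=session)
    payload = pickle.dumps(dc)        # the mixin excludes the session from the ingestor's state
    restored = pickle.loads(payload)  # the mixin restores the session from the active context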

snowflake/ml/dataset/dataset.py
@@ -177,7 +177,7 @@ class Dataset(lineage_node.LineageNode):
                 original_exception=RuntimeError("No Dataset version selected."),
             )
         if self._reader is None:
-            self._reader = dataset_reader.DatasetReader.from_dataset(self)
+            self._reader = dataset_reader.DatasetReader.from_dataset(self, snowpark_session=self._session)
         return self._reader
 
     @staticmethod

snowflake/ml/dataset/dataset_reader.py
@@ -1,5 +1,4 @@
 from typing import Any, Optional
-from warnings import warn
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
@@ -21,16 +20,11 @@ class DatasetReader(data_connector.DataConnector, mixins.SerializableSessionMixi
         self,
         ingestor: data_ingestor.DataIngestor,
         *,
-        session: snowpark.Session,
         snowpark_session: Optional[snowpark.Session] = None,
     ) -> None:
-        if snowpark_session is not None:
-            warn(
-                "Argument snowpark_session is deprecated and will be removed in a future release. Use session instead."
-            )
-            session = snowpark_session
-        super().__init__(ingestor, session=session)
+        super().__init__(ingestor)
 
+        self._session = snowpark_session
         self._fs_cached: Optional[snowfs.SnowFileSystem] = None
         self._files: Optional[list[str]] = None
 

snowflake/ml/experiment/__init__.py
@@ -0,0 +1,3 @@
+from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+
+__all__ = ["ExperimentTracking"]
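
With the new initializer, the tracker is importable from the package root rather than only from its defining module:

    from snowflake.ml.experiment import ExperimentTracking  # same class as the deep import above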

snowflake/ml/experiment/callback/lightgbm.py
@@ -0,0 +1,55 @@
+from typing import TYPE_CHECKING, Optional
+from warnings import warn
+
+import lightgbm as lgb
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+    from snowflake.ml.model.model_signature import ModelSignature
+
+
+class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        model_name: Optional[str] = None,
+        model_signature: Optional["ModelSignature"] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+        super().__init__(eval_result={})
+
+    def __call__(self, env: lgb.callback.CallbackEnv) -> None:
+        if self.log_params:
+            if env.iteration == env.begin_iteration:  # Log params only at the first iteration
+                self._experiment_tracking.log_params(env.params)
+
+        if self.log_metrics:
+            super().__call__(env)
+            for dataset_name, metrics in self.eval_result.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=env.iteration)
+
+        if self.log_model:
+            if env.iteration == env.end_iteration - 1:  # Log model only at the last iteration
+                if self.model_signature:
+                    model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+                    self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                        model=env.model,
+                        model_name=model_name,
+                        signatures={"predict": self.model_signature},
+                    )
+                else:
+                    warn(
+                        "Model will not be logged because model signature is missing. To autolog the model, "
+                        "please specify `model_signature` when constructing SnowflakeLightgbmCallback."
+                    )
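
For orientation, a hedged sketch of wiring the callback into lgb.train; the ExperimentTracking constructor arguments, the training DataFrames, and the infer_signature call are assumptions, not part of this diff:

    import lightgbm as lgb

    from snowflake.ml.experiment import ExperimentTracking
    from snowflake.ml.experiment.callback.lightgbm import SnowflakeLightgbmCallback
    from snowflake.ml.model.model_signature import infer_signature

    exp = ExperimentTracking(session)  # assumed: an existing Snowpark session
    exp.set_experiment("lgbm_demo")

    callback = SnowflakeLightgbmCallback(
        experiment_tracking=exp,
        model_signature=infer_signature(X_train, y_train),  # required for model autologging
    )
    booster = lgb.train(
        {"objective": "regression", "metric": "l2"},
        lgb.Dataset(X_train, label=y_train),
        valid_sets=[lgb.Dataset(X_valid, label=y_valid)],
        valid_names=["valid"],  # metrics are logged under keys like "valid:l2"
        callbacks=[callback],
    )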

snowflake/ml/experiment/callback/xgboost.py
@@ -0,0 +1,63 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional
+from warnings import warn
+
+import xgboost as xgb
+
+from snowflake.ml.experiment import utils
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+    from snowflake.ml.model.model_signature import ModelSignature
+
+
+class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        model_name: Optional[str] = None,
+        model_signature: Optional["ModelSignature"] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+    def before_training(self, model: xgb.Booster) -> xgb.Booster:
+        if self.log_params:
+            params = json.loads(model.save_config())
+            self._experiment_tracking.log_params(utils.flatten_nested_params(params))
+
+        return model
+
+    def after_iteration(self, model: Any, epoch: int, evals_log: dict[str, dict[str, Any]]) -> bool:
+        if self.log_metrics:
+            for dataset_name, metrics in evals_log.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=epoch)
+
+        return False
+
+    def after_training(self, model: xgb.Booster) -> xgb.Booster:
+        if self.log_model:
+            if not self.model_signature:
+                warn(
+                    "Model will not be logged because model signature is missing. "
+                    "To autolog the model, please specify `model_signature` when constructing SnowflakeXgboostCallback."
+                )
+                return model
+
+            model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+            self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                model=model,
+                model_name=model_name,
+                signatures={"predict": self.model_signature},
+            )
+
+        return model
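
A matching sketch for XGBoost, under the same assumptions as the LightGBM example:

    import xgboost as xgb

    from snowflake.ml.experiment import ExperimentTracking
    from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback
    from snowflake.ml.model.model_signature import infer_signature

    exp = ExperimentTracking(session)
    exp.set_experiment("xgb_demo")

    callback = SnowflakeXgboostCallback(
        experiment_tracking=exp,
        model_signature=infer_signature(X_train, y_train),
    )
    booster = xgb.train(
        {"objective": "reg:squarederror"},
        xgb.DMatrix(X_train, label=y_train),
        evals=[(xgb.DMatrix(X_valid, label=y_valid), "valid")],  # evals_log keys become "valid:<metric>"
        callbacks=[callback],
    )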

snowflake/ml/experiment/utils.py
@@ -0,0 +1,14 @@
+from typing import Any, Union
+
+
+def flatten_nested_params(params: Union[list[Any], dict[str, Any]], prefix: str = "") -> dict[str, Any]:
+    flat_params = {}
+    items = params.items() if isinstance(params, dict) else enumerate(params)
+    for key, value in items:
+        key = str(key).replace(".", "_")  # Replace dots in keys to avoid collisions involving nested keys
+        new_prefix = f"{prefix}.{key}" if prefix else key
+        if isinstance(value, (dict, list)):
+            flat_params.update(flatten_nested_params(value, new_prefix))
+        else:
+            flat_params[new_prefix] = value
+    return flat_params
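
Behavior sketch: nested dicts and lists flatten into dot-joined keys (list indices become key segments), and dots inside original keys are first rewritten to underscores:

    params = {"learner": {"eta": [0.3, 0.1], "tree.method": "hist"}}
    flatten_nested_params(params)
    # {'learner.eta.0': 0.3, 'learner.eta.1': 0.1, 'learner.tree_method': 'hist'}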

snowflake/ml/jobs/_utils/constants.py
@@ -6,10 +6,23 @@ DEFAULT_CONTAINER_NAME = "main"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
 RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
 MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
+TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"
 RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
-STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
+# Base mount path
+STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
+
+# Stage subdirectory paths
+APP_STAGE_SUBPATH = "app"
+SYSTEM_STAGE_SUBPATH = "system"
+OUTPUT_STAGE_SUBPATH = "output"
+
+# Complete mount paths (automatically generated from base + subpath)
+APP_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{APP_STAGE_SUBPATH}"
+SYSTEM_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{SYSTEM_STAGE_SUBPATH}"
+OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"
+
 
 # Default container image information
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
@@ -46,9 +59,7 @@ ENABLE_HEALTH_CHECKS = "false"
 JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
 JOB_POLL_MAX_DELAY_SECONDS = 30
 
-# Magic attributes
-IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
-RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"
+RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_MOUNT_PATH}/mljob_result.pkl"
 
 # Log start and end messages
 LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
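
Evaluated, the new constants resolve to:

    APP_MOUNT_PATH             == "/mnt/job_stage/app"
    SYSTEM_MOUNT_PATH          == "/mnt/job_stage/system"
    OUTPUT_MOUNT_PATH          == "/mnt/job_stage/output"
    RESULT_PATH_DEFAULT_VALUE  == "/mnt/job_stage/output/mljob_result.pkl"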