snowflake-ml-python 1.9.1__py3-none-any.whl → 1.9.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (30)
  1. snowflake/ml/_internal/utils/mixins.py +6 -4
  2. snowflake/ml/_internal/utils/service_logger.py +101 -1
  3. snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
  4. snowflake/ml/data/data_connector.py +4 -34
  5. snowflake/ml/dataset/dataset.py +1 -1
  6. snowflake/ml/dataset/dataset_reader.py +2 -8
  7. snowflake/ml/experiment/__init__.py +3 -0
  8. snowflake/ml/experiment/callback.py +121 -0
  9. snowflake/ml/jobs/_utils/constants.py +15 -4
  10. snowflake/ml/jobs/_utils/payload_utils.py +150 -49
  11. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +125 -22
  13. snowflake/ml/jobs/_utils/spec_utils.py +1 -1
  14. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  15. snowflake/ml/jobs/_utils/types.py +64 -4
  16. snowflake/ml/jobs/job.py +22 -6
  17. snowflake/ml/jobs/manager.py +5 -3
  18. snowflake/ml/model/_client/ops/service_ops.py +17 -2
  19. snowflake/ml/model/_client/sql/service.py +1 -38
  20. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  21. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
  22. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  23. snowflake/ml/model/_signatures/utils.py +4 -0
  24. snowflake/ml/model/model_signature.py +2 -0
  25. snowflake/ml/version.py +1 -1
  26. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +42 -4
  27. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +30 -28
  28. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
  29. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
  30. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/utils/mixins.py

@@ -21,10 +21,12 @@ class SerializableSessionMixin:
 
     def __getstate__(self) -> dict[str, Any]:
         """Customize pickling to exclude non-serializable session and related components."""
-        if hasattr(super(), "__getstate__"):
-            state: dict[str, Any] = super().__getstate__()  # type: ignore[misc]
-        else:
-            state = self.__dict__.copy()
+        parent_state = (
+            super().__getstate__()  # type: ignore[misc] # object.__getstate__ appears in 3.11
+            if hasattr(super(), "__getstate__")
+            else self.__dict__
+        )
+        state = dict(parent_state)  # Create a copy so we can safely modify the state
 
         # Save session metadata for validation during unpickling
         session = state.pop(_SESSION_KEY, None)
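A minimal sketch of why the copy matters, using a hypothetical `_Holder` class (not from the package): on Python 3.11+, `object.__getstate__` can hand back the live `__dict__`, so popping keys from it directly would also mutate the object being pickled.

    import pickle

    class _Holder:
        def __init__(self) -> None:
            self.session = object()  # stand-in for a non-picklable snowpark session
            self.payload = [1, 2, 3]

        def __getstate__(self) -> dict:
            parent_state = (
                super().__getstate__() if hasattr(super(), "__getstate__") else self.__dict__
            )
            state = dict(parent_state)  # copy, so the pop below never touches the live object
            state.pop("session", None)  # drop the non-serializable piece
            return state

    holder = _Holder()
    clone = pickle.loads(pickle.dumps(holder))
    assert holder.session is not None     # original object untouched
    assert not hasattr(clone, "session")  # session excluded from the pickle
    assert clone.payload == [1, 2, 3]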
snowflake/ml/_internal/utils/service_logger.py

@@ -1,6 +1,13 @@
 import enum
 import logging
+import os
 import sys
+import tempfile
+import time
+import uuid
+from typing import Optional
+
+import platformdirs
 
 
 class LogColor(enum.Enum):
@@ -57,9 +64,102 @@ class CustomFormatter(logging.Formatter):
         return "\n".join(formatted_lines)
 
 
-def get_logger(logger_name: str, info_color: LogColor) -> logging.Logger:
+def _test_writability(directory: str) -> bool:
+    """Test if a directory is writable by creating and removing a test file."""
+    try:
+        os.makedirs(directory, exist_ok=True)
+        test_file = os.path.join(directory, f".write_test_{uuid.uuid4().hex[:8]}")
+        with open(test_file, "w") as f:
+            f.write("test")
+        os.remove(test_file)
+        return True
+    except OSError:
+        return False
+
+
+def _try_log_location(log_dir: str, operation_id: str) -> Optional[str]:
+    """Try to create a log file in the given directory if it's writable."""
+    if _test_writability(log_dir):
+        return os.path.join(log_dir, f"{operation_id}.log")
+    return None
+
+
+def _get_log_file_path(operation_id: str) -> Optional[str]:
+    """Get platform-independent log file path. Returns None if no writable location found."""
+    # Try locations in order of preference
+    locations = [
+        # Primary: User log directory
+        platformdirs.user_log_dir("snowflake-ml", "Snowflake"),
+        # Fallback 1: System temp directory
+        os.path.join(tempfile.gettempdir(), "snowflake-ml-logs"),
+        # Fallback 2: Current working directory
+        ".",
+    ]
+
+    for location in locations:
+        log_file_path = _try_log_location(location, operation_id)
+        if log_file_path:
+            return log_file_path
+
+    # No writable location found
+    return None
+
+
+def _get_or_create_parent_logger(operation_id: str) -> logging.Logger:
+    """Get or create a parent logger with FileHandler for the operation."""
+    parent_logger_name = f"snowflake_ml_operation_{operation_id}"
+    parent_logger = logging.getLogger(parent_logger_name)
+
+    # Only add handler if it doesn't exist yet
+    if not parent_logger.handlers:
+        log_file_path = _get_log_file_path(operation_id)
+
+        if log_file_path:
+            # Successfully found a writable location
+            try:
+                file_handler = logging.FileHandler(log_file_path)
+                file_handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s"))
+                parent_logger.addHandler(file_handler)
+                parent_logger.setLevel(logging.DEBUG)
+                parent_logger.propagate = False  # Don't propagate to root logger
+
+                # Log the file location
+                parent_logger.warning(f"Operation logs saved to: {log_file_path}")
+            except OSError as e:
+                # Even though we found a path, file creation failed
+                # Fall back to console-only logging
+                parent_logger.setLevel(logging.DEBUG)
+                parent_logger.propagate = False
+                parent_logger.warning(f"Could not create log file at {log_file_path}: {e}. Using console-only logging.")
+        else:
+            # No writable location found, use console-only logging
+            parent_logger.setLevel(logging.DEBUG)
+            parent_logger.propagate = False
+            parent_logger.warning("Filesystem appears to be readonly. Using console-only logging.")
+
+    return parent_logger
+
+
+def get_logger(logger_name: str, info_color: LogColor, operation_id: Optional[str] = None) -> logging.Logger:
     logger = logging.getLogger(logger_name)
     handler = logging.StreamHandler(sys.stdout)
     handler.setFormatter(CustomFormatter(info_color))
     logger.addHandler(handler)
+
+    # If operation_id provided, set up parent logger with file handler
+    if operation_id:
+        parent_logger = _get_or_create_parent_logger(operation_id)
+        logger.parent = parent_logger
+        logger.propagate = True
+
     return logger
+
+
+def get_operation_id() -> str:
+    """Generate a unique operation ID."""
+    return f"model_deploy_{uuid.uuid4().hex[:8]}_{int(time.time())}"
+
+
+def get_log_file_location(operation_id: str) -> Optional[str]:
+    """Get the log file path for an operation ID. Returns None if no writable location available."""
+    return _get_log_file_path(operation_id)
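A hedged sketch of the intended call pattern (the `LogColor` member name is assumed, and the resolved log path varies by platform):

    from snowflake.ml._internal.utils import service_logger

    op_id = service_logger.get_operation_id()  # e.g. "model_deploy_1a2b3c4d_1723456789"
    logger = service_logger.get_logger(
        __name__, service_logger.LogColor.GREY, operation_id=op_id  # enum member name assumed
    )
    logger.info("deploying model...")  # goes to stdout and, if a location is writable, the op log file
    print(service_logger.get_log_file_location(op_id))  # path chosen by the fallback chain, or None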
snowflake/ml/data/_internal/arrow_ingestor.py

@@ -14,6 +14,7 @@ if TYPE_CHECKING:
     import ray
 
 from snowflake import snowpark
+from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_ingestor, data_source, ingestor_utils
 
 _EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
@@ -44,7 +45,7 @@ class _RecordBatchesBuffer:
         return popped
 
 
-class ArrowIngestor(data_ingestor.DataIngestor):
+class ArrowIngestor(data_ingestor.DataIngestor, mixins.SerializableSessionMixin):
    """Read and parse the data sources into an Arrow Dataset and yield batched numpy array in dict."""
 
     def __init__(
@@ -71,6 +72,8 @@ class ArrowIngestor(data_ingestor.DataIngestor):
 
     @classmethod
     def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "ArrowIngestor":
+        if session is None:
+            raise ValueError("Session is required")
         return cls(session, sources)
 
     @classmethod
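A short sketch of the new guard's effect (requires the package installed; no live connection needed): constructing an ingestor without a session now fails fast instead of surfacing later during ingestion.

    from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor

    try:
        ArrowIngestor.from_sources(None, [])  # no active session supplied
    except ValueError as exc:
        print(exc)  # "Session is required"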
snowflake/ml/data/data_connector.py

@@ -6,10 +6,9 @@ from typing_extensions import deprecated
 
 from snowflake import snowpark
 from snowflake.ml._internal import env, telemetry
-from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_ingestor, data_source
 from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
-from snowflake.snowpark import context as sf_context
+from snowflake.snowpark import context as sp_context
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -22,13 +21,11 @@ if TYPE_CHECKING:
     from snowflake.ml import dataset
 
 _PROJECT = "DataConnector"
-_INGESTOR_KEY = "_ingestor"
-_INGESTOR_SOURCES_KEY = "ingestor$sources"
 
 DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
 
 
-class DataConnector(mixins.SerializableSessionMixin):
+class DataConnector:
     """Snowflake data reader which provides application integration connectors"""
 
     DEFAULT_INGESTOR_CLASS: type[data_ingestor.DataIngestor] = ArrowIngestor
@@ -36,11 +33,8 @@ class DataConnector(mixins.SerializableSessionMixin):
     def __init__(
         self,
         ingestor: data_ingestor.DataIngestor,
-        *,
-        session: Optional[snowpark.Session] = None,
         **kwargs: Any,
     ) -> None:
-        self._session = session
         self._ingestor = ingestor
         self._kwargs = kwargs
 
@@ -63,7 +57,7 @@
         ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
-        session = session or sf_context.get_active_session()
+        session = session or sp_context.get_active_session()
         source = data_source.DataFrameInfo(query)
         return cls.from_sources(session, [source], ingestor_class=ingestor_class, **kwargs)
 
@@ -107,31 +101,7 @@
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
         ingestor = ingestor_class.from_sources(session, sources)
-        return cls(ingestor, **kwargs, session=session)
-
-    def __getstate__(self) -> dict[str, Any]:
-        """Customize pickling to exclude non-serializable session and related components."""
-        if hasattr(super(), "__getstate__"):
-            state = super().__getstate__()
-        else:
-            state = self.__dict__.copy()
-
-        ingestor = state.pop(_INGESTOR_KEY)
-        state[_INGESTOR_SOURCES_KEY] = ingestor.data_sources
-
-        return state
-
-    def __setstate__(self, state: dict[str, Any]) -> None:
-        """Restore session from context during unpickling."""
-        data_sources = state.pop(_INGESTOR_SOURCES_KEY)
-
-        if hasattr(super(), "__setstate__"):
-            super().__setstate__(state)
-        else:
-            self.__dict__.update(state)
-
-        assert self._session is not None
-        self._ingestor = self.DEFAULT_INGESTOR_CLASS.from_sources(self._session, data_sources)
+        return cls(ingestor, **kwargs)
 
     @property
     def data_sources(self) -> list[data_source.DataSource]:
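Taken together with the ArrowIngestor change above, session ownership moves from DataConnector to the ingestor, whose SerializableSessionMixin now handles pickling. A hedged usage sketch (assumes an active snowpark `session` in scope; `from_sql` is the public factory per the API, though its name does not appear in this hunk):

    import pickle

    from snowflake.ml.data.data_connector import DataConnector

    connector = DataConnector.from_sql("SELECT * FROM my_table", session=session)
    # No custom __getstate__/__setstate__ on the connector anymore; the ingestor's
    # mixin drops the session on dump and restores it on load.
    clone = pickle.loads(pickle.dumps(connector))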
snowflake/ml/dataset/dataset.py

@@ -177,7 +177,7 @@ class Dataset(lineage_node.LineageNode):
                 original_exception=RuntimeError("No Dataset version selected."),
             )
         if self._reader is None:
-            self._reader = dataset_reader.DatasetReader.from_dataset(self)
+            self._reader = dataset_reader.DatasetReader.from_dataset(self, snowpark_session=self._session)
         return self._reader
 
     @staticmethod
snowflake/ml/dataset/dataset_reader.py

@@ -1,5 +1,4 @@
 from typing import Any, Optional
-from warnings import warn
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
@@ -21,16 +20,11 @@ class DatasetReader(data_connector.DataConnector, mixins.SerializableSessionMixin):
         self,
         ingestor: data_ingestor.DataIngestor,
         *,
-        session: snowpark.Session,
         snowpark_session: Optional[snowpark.Session] = None,
     ) -> None:
-        if snowpark_session is not None:
-            warn(
-                "Argument snowpark_session is deprecated and will be removed in a future release. Use session instead."
-            )
-            session = snowpark_session
-        super().__init__(ingestor, session=session)
+        super().__init__(ingestor)
 
+        self._session = snowpark_session
         self._fs_cached: Optional[snowfs.SnowFileSystem] = None
         self._files: Optional[list[str]] = None
 
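Sketched net effect of the dataset.py and dataset_reader.py hunks together (assumes an `ingestor` and an active `session` in scope): the session now rides in on the formerly deprecated keyword, which is the sole remaining one, and is stored on the reader itself rather than passed up to DataConnector.

    from snowflake.ml.dataset.dataset_reader import DatasetReader

    reader = DatasetReader(ingestor, snowpark_session=session)  # the `session` keyword is gone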
snowflake/ml/experiment/__init__.py

@@ -0,0 +1,3 @@
+from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+
+__all__ = ["ExperimentTracking"]
snowflake/ml/experiment/callback.py

@@ -0,0 +1,121 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional, Union
+from warnings import warn
+
+import lightgbm as lgb
+import xgboost as xgb
+
+from snowflake.ml.model.model_signature import ModelSignature
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+
+
+class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        model_name: Optional[str] = None,
+        model_signature: Optional[ModelSignature] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+    def before_training(self, model: xgb.Booster) -> xgb.Booster:
+        def _flatten_nested_params(params: Union[list[Any], dict[str, Any]], prefix: str = "") -> dict[str, Any]:
+            flat_params = {}
+            items = params.items() if isinstance(params, dict) else enumerate(params)
+            for key, value in items:
+                new_prefix = f"{prefix}.{key}" if prefix else str(key)
+                if isinstance(value, (dict, list)):
+                    flat_params.update(_flatten_nested_params(value, new_prefix))
+                else:
+                    flat_params[new_prefix] = value
+            return flat_params
+
+        if self.log_params:
+            params = json.loads(model.save_config())
+            self._experiment_tracking.log_params(_flatten_nested_params(params))
+
+        return model
+
+    def after_iteration(self, model: Any, epoch: int, evals_log: dict[str, dict[str, Any]]) -> bool:
+        if self.log_metrics:
+            for dataset_name, metrics in evals_log.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=epoch)
+
+        return False
+
+    def after_training(self, model: xgb.Booster) -> xgb.Booster:
+        if self.log_model:
+            if not self.model_signature:
+                warn(
+                    "Model will not be logged because model signature is missing. "
+                    "To autolog the model, please specify `model_signature` when constructing SnowflakeXgboostCallback."
+                )
+                return model
+
+            model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+            self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                model=model,
+                model_name=model_name,
+                signatures={"predict": self.model_signature},
+            )
+
+        return model
+
+
+class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        model_name: Optional[str] = None,
+        model_signature: Optional[ModelSignature] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+        super().__init__(eval_result={})
+
+    def __call__(self, env: lgb.callback.CallbackEnv) -> None:
+        if self.log_params:
+            if env.iteration == env.begin_iteration:  # Log params only at the first iteration
+                self._experiment_tracking.log_params(env.params)
+
+        if self.log_metrics:
+            super().__call__(env)
+            for dataset_name, metrics in self.eval_result.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=env.iteration)
+
+        if self.log_model:
+            if env.iteration == env.end_iteration - 1:  # Log model only at the last iteration
+                if self.model_signature:
+                    model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+                    self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                        model=env.model,
+                        model_name=model_name,
+                        signatures={"predict": self.model_signature},
+                    )
+                else:
+                    warn(
+                        "Model will not be logged because model signature is missing. To autolog the model, "
+                        "please specify `model_signature` when constructing SnowflakeLightgbmCallback."
+                    )
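A hedged sketch of how the XGBoost callback is meant to be wired up (assumes an `ExperimentTracking` instance `exp`, a `ModelSignature` `sig`, and a prepared `DMatrix` `dtrain`):

    import xgboost as xgb

    from snowflake.ml.experiment.callback import SnowflakeXgboostCallback

    callback = SnowflakeXgboostCallback(
        experiment_tracking=exp,    # snowflake.ml.experiment.ExperimentTracking
        model_name="housing_xgb",   # hypothetical; defaults to "<experiment name>_model"
        model_signature=sig,        # without this, the trained model is not autologged
    )
    booster = xgb.train(
        {"objective": "reg:squarederror"},
        dtrain,                     # xgb.DMatrix, assumed prepared
        evals=[(dtrain, "train")],  # metrics logged per iteration under "train:<metric>"
        callbacks=[callback],       # params logged before training, model after
    )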
snowflake/ml/jobs/_utils/constants.py

@@ -6,10 +6,23 @@ DEFAULT_CONTAINER_NAME = "main"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
 RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
 MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
+TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"
 RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
-STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
+# Base mount path
+STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
+
+# Stage subdirectory paths
+APP_STAGE_SUBPATH = "app"
+SYSTEM_STAGE_SUBPATH = "system"
+OUTPUT_STAGE_SUBPATH = "output"
+
+# Complete mount paths (automatically generated from base + subpath)
+APP_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{APP_STAGE_SUBPATH}"
+SYSTEM_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{SYSTEM_STAGE_SUBPATH}"
+OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"
+
 
 # Default container image information
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
@@ -46,9 +59,7 @@ ENABLE_HEALTH_CHECKS = "false"
 JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
 JOB_POLL_MAX_DELAY_SECONDS = 30
 
-# Magic attributes
-IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
-RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"
+RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_MOUNT_PATH}/mljob_result.pkl"
 
 # Log start and end messages
 LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
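For reference, the new constants compose as follows (values copied from the diff): the default result path moves from a bare relative file name to an absolute path under the new output mount.

    STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
    OUTPUT_STAGE_SUBPATH = "output"
    OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"
    RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_MOUNT_PATH}/mljob_result.pkl"

    assert RESULT_PATH_DEFAULT_VALUE == "/mnt/job_stage/output/mljob_result.pkl"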