snowflake-ml-python 1.9.1__py3-none-any.whl → 1.9.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- snowflake/ml/_internal/utils/mixins.py +6 -4
- snowflake/ml/_internal/utils/service_logger.py +101 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
- snowflake/ml/data/data_connector.py +4 -34
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/dataset/dataset_reader.py +2 -8
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/callback.py +121 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +150 -49
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +125 -22
- snowflake/ml/jobs/_utils/spec_utils.py +1 -1
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +22 -6
- snowflake/ml/jobs/manager.py +5 -3
- snowflake/ml/model/_client/ops/service_ops.py +17 -2
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +42 -4
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +30 -28
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/utils/mixins.py CHANGED

```diff
@@ -21,10 +21,12 @@ class SerializableSessionMixin:
 
     def __getstate__(self) -> dict[str, Any]:
         """Customize pickling to exclude non-serializable session and related components."""
-        if hasattr(super(), "__getstate__"):
-            state = super().__getstate__()
-        else:
-            state = self.__dict__.copy()
+        parent_state = (
+            super().__getstate__()  # type: ignore[misc] # object.__getstate__ appears in 3.11
+            if hasattr(super(), "__getstate__")
+            else self.__dict__
+        )
+        state = dict(parent_state)  # Create a copy so we can safely modify the state
 
         # Save session metadata for validation during unpickling
         session = state.pop(_SESSION_KEY, None)
```
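
The refactor matters because `object.__getstate__` (added in Python 3.11) can return `self.__dict__` itself rather than a copy, so popping keys from it would mutate the live object. Below is a self-contained sketch of the pattern, with a hypothetical `_session` attribute standing in for the real session handling (the actual mixin also stashes session metadata for validation during unpickling):

```python
import pickle
from typing import Any


class SerializableMixinSketch:
    def __getstate__(self) -> dict[str, Any]:
        # Same shape as the 1.9.2 mixin: prefer the parent's __getstate__,
        # fall back to __dict__, and copy before mutating.
        parent_state = (
            super().__getstate__()  # type: ignore[misc]
            if hasattr(super(), "__getstate__")
            else self.__dict__
        )
        state = dict(parent_state)  # pops below must not touch self.__dict__
        state.pop("_session", None)  # hypothetical non-picklable member
        return state


class Client(SerializableMixinSketch):
    def __init__(self) -> None:
        self._session = object()  # stand-in for a live Snowpark session
        self.name = "demo"


original = Client()
restored = pickle.loads(pickle.dumps(original))
print(restored.name)                  # demo
print(hasattr(restored, "_session"))  # False: dropped from the pickled state
print(hasattr(original, "_session"))  # True: the copy left the original intact
```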
snowflake/ml/_internal/utils/service_logger.py CHANGED

```diff
@@ -1,6 +1,13 @@
 import enum
 import logging
+import os
 import sys
+import tempfile
+import time
+import uuid
+from typing import Optional
+
+import platformdirs
 
 
 class LogColor(enum.Enum):

@@ -57,9 +64,102 @@ class CustomFormatter(logging.Formatter):
         return "\n".join(formatted_lines)
 
 
-def get_logger(logger_name: str, info_color: LogColor) -> logging.Logger:
+def _test_writability(directory: str) -> bool:
+    """Test if a directory is writable by creating and removing a test file."""
+    try:
+        os.makedirs(directory, exist_ok=True)
+        test_file = os.path.join(directory, f".write_test_{uuid.uuid4().hex[:8]}")
+        with open(test_file, "w") as f:
+            f.write("test")
+        os.remove(test_file)
+        return True
+    except OSError:
+        return False
+
+
+def _try_log_location(log_dir: str, operation_id: str) -> Optional[str]:
+    """Try to create a log file in the given directory if it's writable."""
+    if _test_writability(log_dir):
+        return os.path.join(log_dir, f"{operation_id}.log")
+    return None
+
+
+def _get_log_file_path(operation_id: str) -> Optional[str]:
+    """Get platform-independent log file path. Returns None if no writable location found."""
+    # Try locations in order of preference
+    locations = [
+        # Primary: User log directory
+        platformdirs.user_log_dir("snowflake-ml", "Snowflake"),
+        # Fallback 1: System temp directory
+        os.path.join(tempfile.gettempdir(), "snowflake-ml-logs"),
+        # Fallback 2: Current working directory
+        ".",
+    ]
+
+    for location in locations:
+        log_file_path = _try_log_location(location, operation_id)
+        if log_file_path:
+            return log_file_path
+
+    # No writable location found
+    return None
+
+
+def _get_or_create_parent_logger(operation_id: str) -> logging.Logger:
+    """Get or create a parent logger with FileHandler for the operation."""
+    parent_logger_name = f"snowflake_ml_operation_{operation_id}"
+    parent_logger = logging.getLogger(parent_logger_name)
+
+    # Only add handler if it doesn't exist yet
+    if not parent_logger.handlers:
+        log_file_path = _get_log_file_path(operation_id)
+
+        if log_file_path:
+            # Successfully found a writable location
+            try:
+                file_handler = logging.FileHandler(log_file_path)
+                file_handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s"))
+                parent_logger.addHandler(file_handler)
+                parent_logger.setLevel(logging.DEBUG)
+                parent_logger.propagate = False  # Don't propagate to root logger
+
+                # Log the file location
+                parent_logger.warning(f"Operation logs saved to: {log_file_path}")
+            except OSError as e:
+                # Even though we found a path, file creation failed
+                # Fall back to console-only logging
+                parent_logger.setLevel(logging.DEBUG)
+                parent_logger.propagate = False
+                parent_logger.warning(f"Could not create log file at {log_file_path}: {e}. Using console-only logging.")
+        else:
+            # No writable location found, use console-only logging
+            parent_logger.setLevel(logging.DEBUG)
+            parent_logger.propagate = False
+            parent_logger.warning("Filesystem appears to be readonly. Using console-only logging.")
+
+    return parent_logger
+
+
+def get_logger(logger_name: str, info_color: LogColor, operation_id: Optional[str] = None) -> logging.Logger:
     logger = logging.getLogger(logger_name)
     handler = logging.StreamHandler(sys.stdout)
     handler.setFormatter(CustomFormatter(info_color))
     logger.addHandler(handler)
+
+    # If operation_id provided, set up parent logger with file handler
+    if operation_id:
+        parent_logger = _get_or_create_parent_logger(operation_id)
+        logger.parent = parent_logger
+        logger.propagate = True
+
     return logger
+
+
+def get_operation_id() -> str:
+    """Generate a unique operation ID."""
+    return f"model_deploy_{uuid.uuid4().hex[:8]}_{int(time.time())}"
+
+
+def get_log_file_location(operation_id: str) -> Optional[str]:
+    """Get the log file path for an operation ID. Returns None if no writable location available."""
+    return _get_log_file_path(operation_id)
```
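
The new helpers give each operation its own log file, preferring platformdirs' per-user log directory and falling back to the system temp directory and then the current directory. A hedged usage sketch follows; the `LogColor` member name is an assumption (use whichever members the enum actually defines):

```python
from snowflake.ml._internal.utils import service_logger

op_id = service_logger.get_operation_id()  # e.g. "model_deploy_1f2e3d4c_1723456789"
logger = service_logger.get_logger(
    "my_deploy_logger",
    service_logger.LogColor.GREY,  # assumed member name
    operation_id=op_id,
)
logger.info("deploy step started")  # colored on stdout, plain text in the operation log file
print(service_logger.get_log_file_location(op_id))  # None if no location is writable
```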
snowflake/ml/data/_internal/arrow_ingestor.py CHANGED

```diff
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
     import ray
 
 from snowflake import snowpark
+from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_ingestor, data_source, ingestor_utils
 
 _EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])

@@ -44,7 +45,7 @@ class _RecordBatchesBuffer:
         return popped
 
 
-class ArrowIngestor(data_ingestor.DataIngestor):
+class ArrowIngestor(data_ingestor.DataIngestor, mixins.SerializableSessionMixin):
     """Read and parse the data sources into an Arrow Dataset and yield batched numpy array in dict."""
 
     def __init__(

@@ -71,6 +72,8 @@ class ArrowIngestor(data_ingestor.DataIngestor):
 
     @classmethod
     def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "ArrowIngestor":
+        if session is None:
+            raise ValueError("Session is required")
         return cls(session, sources)
 
     @classmethod
```
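
Two behavioral changes land here: `ArrowIngestor` picks up pickling support from `SerializableSessionMixin` (replacing the hand-rolled `__getstate__`/`__setstate__` that the `data_connector.py` hunks below delete), and `from_sources` now fails fast on a missing session. A minimal sketch of the new guard:

```python
from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor

try:
    ArrowIngestor.from_sources(None, [])  # type: ignore[arg-type]
except ValueError as exc:
    print(exc)  # Session is required
```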
snowflake/ml/data/data_connector.py CHANGED

```diff
@@ -6,10 +6,9 @@ from typing_extensions import deprecated
 
 from snowflake import snowpark
 from snowflake.ml._internal import env, telemetry
-from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_ingestor, data_source
 from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
-from snowflake.snowpark import context as
+from snowflake.snowpark import context as sp_context
 
 if TYPE_CHECKING:
     import pandas as pd

@@ -22,13 +21,11 @@ if TYPE_CHECKING:
     from snowflake.ml import dataset
 
 _PROJECT = "DataConnector"
-_INGESTOR_KEY = "_ingestor"
-_INGESTOR_SOURCES_KEY = "ingestor$sources"
 
 DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
 
 
-class DataConnector(mixins.SerializableSessionMixin):
+class DataConnector:
     """Snowflake data reader which provides application integration connectors"""
 
     DEFAULT_INGESTOR_CLASS: type[data_ingestor.DataIngestor] = ArrowIngestor

@@ -36,11 +33,8 @@ class DataConnector(mixins.SerializableSessionMixin):
     def __init__(
         self,
         ingestor: data_ingestor.DataIngestor,
-        *,
-        session: Optional[snowpark.Session] = None,
         **kwargs: Any,
     ) -> None:
-        self._session = session
         self._ingestor = ingestor
         self._kwargs = kwargs
 

@@ -63,7 +57,7 @@ class DataConnector(mixins.SerializableSessionMixin):
         ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
-        session = session or
+        session = session or sp_context.get_active_session()
         source = data_source.DataFrameInfo(query)
         return cls.from_sources(session, [source], ingestor_class=ingestor_class, **kwargs)
 

@@ -107,31 +101,7 @@ class DataConnector(mixins.SerializableSessionMixin):
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
         ingestor = ingestor_class.from_sources(session, sources)
-        return cls(ingestor, **kwargs
-
-    def __getstate__(self) -> dict[str, Any]:
-        """Customize pickling to exclude non-serializable session and related components."""
-        if hasattr(super(), "__getstate__"):
-            state = super().__getstate__()
-        else:
-            state = self.__dict__.copy()
-
-        ingestor = state.pop(_INGESTOR_KEY)
-        state[_INGESTOR_SOURCES_KEY] = ingestor.data_sources
-
-        return state
-
-    def __setstate__(self, state: dict[str, Any]) -> None:
-        """Restore session from context during unpickling."""
-        data_sources = state.pop(_INGESTOR_SOURCES_KEY)
-
-        if hasattr(super(), "__setstate__"):
-            super().__setstate__(state)
-        else:
-            self.__dict__.update(state)
-
-        assert self._session is not None
-        self._ingestor = self.DEFAULT_INGESTOR_CLASS.from_sources(self._session, data_sources)
+        return cls(ingestor, **kwargs)
 
     @property
     def data_sources(self) -> list[data_source.DataSource]:
```
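
With session tracking and serialization delegated to the ingestor, `DataConnector` shrinks to a thin wrapper around it. A sketch of the net API difference between the two releases (`my_ingestor` is a placeholder for any `DataIngestor` instance):

```python
from snowflake.ml.data.data_connector import DataConnector

# 1.9.1: DataConnector(ingestor, *, session=None, **kwargs) kept its own _session
# 1.9.2: the connector holds only the ingestor; the session lives inside it
dc = DataConnector(my_ingestor)
print(dc.data_sources)  # still exposed, backed by the ingestor
```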
snowflake/ml/dataset/dataset.py CHANGED

```diff
@@ -177,7 +177,7 @@ class Dataset(lineage_node.LineageNode):
                 original_exception=RuntimeError("No Dataset version selected."),
             )
         if self._reader is None:
-            self._reader = dataset_reader.DatasetReader.from_dataset(self)
+            self._reader = dataset_reader.DatasetReader.from_dataset(self, snowpark_session=self._session)
         return self._reader
 
     @staticmethod
```
snowflake/ml/dataset/dataset_reader.py CHANGED

```diff
@@ -1,5 +1,4 @@
 from typing import Any, Optional
-from warnings import warn
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry

@@ -21,16 +20,11 @@ class DatasetReader(data_connector.DataConnector, mixins.SerializableSessionMixin):
         self,
         ingestor: data_ingestor.DataIngestor,
         *,
-        session: snowpark.Session,
         snowpark_session: Optional[snowpark.Session] = None,
     ) -> None:
-
-        warn(
-            "Argument snowpark_session is deprecated and will be removed in a future release. Use session instead."
-        )
-        session = snowpark_session
-        super().__init__(ingestor, session=session)
+        super().__init__(ingestor)
 
+        self._session = snowpark_session
         self._fs_cached: Optional[snowfs.SnowFileSystem] = None
         self._files: Optional[list[str]] = None
 
```
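
Together with the `dataset.py` hunk above, the session now flows explicitly: `Dataset.read` forwards its own session as `snowpark_session`, and `DatasetReader` stores it directly instead of warning that the argument is deprecated. A sketch of the 1.9.2 construction path (`my_ingestor` and `my_session` are placeholders):

```python
from snowflake.ml.dataset.dataset_reader import DatasetReader

reader = DatasetReader(my_ingestor, snowpark_session=my_session)
# Equivalent to what Dataset.read now does internally:
#     dataset_reader.DatasetReader.from_dataset(self, snowpark_session=self._session)
```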
snowflake/ml/experiment/callback.py ADDED

```diff
@@ -0,0 +1,121 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional, Union
+from warnings import warn
+
+import lightgbm as lgb
+import xgboost as xgb
+
+from snowflake.ml.model.model_signature import ModelSignature
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+
+
+class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        model_name: Optional[str] = None,
+        model_signature: Optional[ModelSignature] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+    def before_training(self, model: xgb.Booster) -> xgb.Booster:
+        def _flatten_nested_params(params: Union[list[Any], dict[str, Any]], prefix: str = "") -> dict[str, Any]:
+            flat_params = {}
+            items = params.items() if isinstance(params, dict) else enumerate(params)
+            for key, value in items:
+                new_prefix = f"{prefix}.{key}" if prefix else str(key)
+                if isinstance(value, (dict, list)):
+                    flat_params.update(_flatten_nested_params(value, new_prefix))
+                else:
+                    flat_params[new_prefix] = value
+            return flat_params
+
+        if self.log_params:
+            params = json.loads(model.save_config())
+            self._experiment_tracking.log_params(_flatten_nested_params(params))
+
+        return model
+
+    def after_iteration(self, model: Any, epoch: int, evals_log: dict[str, dict[str, Any]]) -> bool:
+        if self.log_metrics:
+            for dataset_name, metrics in evals_log.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=epoch)
+
+        return False
+
+    def after_training(self, model: xgb.Booster) -> xgb.Booster:
+        if self.log_model:
+            if not self.model_signature:
+                warn(
+                    "Model will not be logged because model signature is missing. "
+                    "To autolog the model, please specify `model_signature` when constructing SnowflakeXgboostCallback."
+                )
+                return model
+
+            model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+            self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                model=model,
+                model_name=model_name,
+                signatures={"predict": self.model_signature},
+            )
+
+        return model
+
+
+class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        model_name: Optional[str] = None,
+        model_signature: Optional[ModelSignature] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+        super().__init__(eval_result={})
+
+    def __call__(self, env: lgb.callback.CallbackEnv) -> None:
+        if self.log_params:
+            if env.iteration == env.begin_iteration:  # Log params only at the first iteration
+                self._experiment_tracking.log_params(env.params)
+
+        if self.log_metrics:
+            super().__call__(env)
+            for dataset_name, metrics in self.eval_result.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=env.iteration)
+
+        if self.log_model:
+            if env.iteration == env.end_iteration - 1:  # Log model only at the last iteration
+                if self.model_signature:
+                    model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+                    self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                        model=env.model,
+                        model_name=model_name,
+                        signatures={"predict": self.model_signature},
+                    )
+                else:
+                    warn(
+                        "Model will not be logged because model signature is missing. To autolog the model, "
+                        "please specify `model_signature` when constructing SnowflakeLightgbmCallback."
+                    )
```
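
The new module hooks experiment tracking into both boosters' native callback mechanisms. A hedged usage sketch for the XGBoost side, where `exp` is assumed to be an `ExperimentTracking` instance and `dtrain` an `xgb.DMatrix`; without a `model_signature` the callback logs params and metrics but warns instead of autologging the model:

```python
import xgboost as xgb

from snowflake.ml.experiment.callback import SnowflakeXgboostCallback

callback = SnowflakeXgboostCallback(
    experiment_tracking=exp,  # assumed: an ExperimentTracking bound to your session
    log_model=False,          # skip autologging since no signature is provided here
)
booster = xgb.train(
    {"objective": "reg:squarederror", "max_depth": 3},
    dtrain,
    num_boost_round=10,
    evals=[(dtrain, "train")],  # surfaces in evals_log as log_metric("train:rmse", ...)
    callbacks=[callback],
)
```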
snowflake/ml/jobs/_utils/constants.py CHANGED

```diff
@@ -6,10 +6,23 @@ DEFAULT_CONTAINER_NAME = "main"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
 RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
 MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
+TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"
 RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
-
+# Base mount path
+STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
+
+# Stage subdirectory paths
+APP_STAGE_SUBPATH = "app"
+SYSTEM_STAGE_SUBPATH = "system"
+OUTPUT_STAGE_SUBPATH = "output"
+
+# Complete mount paths (automatically generated from base + subpath)
+APP_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{APP_STAGE_SUBPATH}"
+SYSTEM_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{SYSTEM_STAGE_SUBPATH}"
+OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"
+
 
 # Default container image information
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"

@@ -46,9 +59,7 @@ ENABLE_HEALTH_CHECKS = "false"
 JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
 JOB_POLL_MAX_DELAY_SECONDS = 30
 
-
-IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
-RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"
+RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_MOUNT_PATH}/mljob_result.pkl"
 
 # Log start and end messages
 LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
```