snowflake-ml-python 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff represents the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/_internal/telemetry.py +3 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
- snowflake/ml/experiment/callback/keras.py +3 -0
- snowflake/ml/experiment/callback/lightgbm.py +3 -0
- snowflake/ml/experiment/callback/xgboost.py +3 -0
- snowflake/ml/experiment/experiment_tracking.py +19 -7
- snowflake/ml/feature_store/feature_store.py +236 -61
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +16 -2
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +8 -2
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +2 -1
- snowflake/ml/jobs/_utils/stage_utils.py +4 -0
- snowflake/ml/jobs/_utils/types.py +15 -0
- snowflake/ml/jobs/job.py +186 -40
- snowflake/ml/jobs/manager.py +48 -39
- snowflake/ml/model/__init__.py +19 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
- snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
- snowflake/ml/model/_client/model/model_version_impl.py +168 -18
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/ops/service_ops.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +11 -3
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +22 -5
- snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
- snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +7 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/type_hints.py +16 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +50 -4
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +54 -45
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/prophet.py (new file)
@@ -0,0 +1,566 @@
+import logging
+import os
+from typing import TYPE_CHECKING, Callable, Optional, cast, final
+
+import cloudpickle
+import pandas as pd
+from typing_extensions import TypeGuard, Unpack
+
+from snowflake.ml._internal import type_utils
+from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
+from snowflake.ml.model._packager.model_env import model_env
+from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
+from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+from snowflake.ml.model._packager.model_meta import (
+    model_blob_meta,
+    model_meta as model_meta_api,
+)
+from snowflake.ml.model._signatures import utils as model_signature_utils
+
+if TYPE_CHECKING:
+    import prophet
+
+logger = logging.getLogger(__name__)
+
+
+def _normalize_column_names(
+    data: pd.DataFrame,
+    date_column: Optional[str] = None,
+    target_column: Optional[str] = None,
+) -> pd.DataFrame:
+    """Normalize user column names to Prophet's required 'ds' and 'y' format.
+
+    Args:
+        data: Input DataFrame with user's column names
+        date_column: Name of the date column to map to 'ds'
+        target_column: Name of the target column to map to 'y'
+
+    Returns:
+        DataFrame with columns renamed to Prophet format
+
+    Raises:
+        ValueError: If specified columns don't exist or if there are naming conflicts
+    """
+    if date_column is None and target_column is None:
+        return data
+
+    data = data.copy()
+
+    if date_column is not None:
+        if date_column not in data.columns:
+            raise ValueError(
+                f"Specified date_column '{date_column}' not found in DataFrame. "
+                f"Available columns: {list(data.columns)}"
+            )
+        if date_column != "ds":
+            # Check if 'ds' already exists as a different column
+            if "ds" in data.columns:
+                raise ValueError(
+                    f"Cannot rename '{date_column}' to 'ds' because 'ds' already exists in the DataFrame. "
+                    f"Please either: (1) rename or remove the existing 'ds' column, or "
+                    f"(2) if 'ds' is your date column, set date_column='ds' instead."
+                )
+            data = data.rename(columns={date_column: "ds"})
+
+    if target_column is not None:
+        if target_column not in data.columns:
+            raise ValueError(
+                f"Specified target_column '{target_column}' not found in DataFrame. "
+                f"Available columns: {list(data.columns)}"
+            )
+        if target_column != "y":
+            # Check if 'y' already exists as a different column
+            if "y" in data.columns:
+                raise ValueError(
+                    f"Cannot rename '{target_column}' to 'y' because 'y' already exists in the DataFrame. "
+                    f"Please either: (1) rename or remove the existing 'y' column, or "
+                    f"(2) if 'y' is your target column, set target_column='y' instead."
+                )
+            data = data.rename(columns={target_column: "y"})
+
+    return data
+
+
+def _sanitize_prophet_output(predictions: pd.DataFrame) -> pd.DataFrame:
+    """Sanitize Prophet prediction output to have SQL-safe column names.
+
+    Prophet may include holiday columns with names containing spaces (e.g., "Christmas Day")
+    which cannot be used as unquoted SQL identifiers in Snowflake. This function normalizes all
+    column names to be valid unquoted SQL identifiers by replacing spaces with underscores and
+    removing special characters.
+
+    Args:
+        predictions: Raw prediction DataFrame from Prophet
+
+    Returns:
+        DataFrame with normalized SQL-safe column names
+
+    Raises:
+        ValueError: If predictions DataFrame is empty, has no columns, or is missing required
+            columns 'ds' and 'yhat'
+    """
+    # Check if predictions is empty or has no columns
+    if predictions is None or len(predictions.columns) == 0:
+        raise ValueError(
+            f"Prophet predictions DataFrame is empty or has no columns. "
+            f"DataFrame shape: {predictions.shape if predictions is not None else 'None'}, "
+            f"Type: {type(predictions)}"
+        )
+
+    if "ds" not in predictions.columns or "yhat" not in predictions.columns:
+        raise ValueError(
+            f"Prophet predictions missing required columns 'ds' and 'yhat'. "
+            f"Available columns: {list(predictions.columns)}"
+        )
+
+    # Normalize all column names to be SQL-safe
+    normalized_columns = {col: handlers_utils.normalize_column_name(col) for col in predictions.columns}
+
+    # Check for conflicts after normalization
+    normalized_values = list(normalized_columns.values())
+    if len(normalized_values) != len(set(normalized_values)):
+        # Find duplicates
+        seen = set()
+        duplicates = []
+        for val in normalized_values:
+            if val in seen:
+                duplicates.append(val)
+            seen.add(val)
+
+        logger.warning(
+            f"Column name normalization resulted in duplicates: {duplicates}. "
+            f"Original columns: {[k for k, v in normalized_columns.items() if v in duplicates]}"
+        )
+
+    # Rename columns
+    sanitized_predictions = predictions.rename(columns=normalized_columns)
+
+    return sanitized_predictions
+
+
+def _validate_prophet_data_format(data: model_types.SupportedDataType) -> pd.DataFrame:
+    """Validate that input data follows Prophet's required format with 'ds' and 'y' columns.
+
+    Args:
+        data: Input data to validate
+
+    Returns:
+        DataFrame with validated Prophet format and proper data types
+
+    Raises:
+        ValueError: If data doesn't meet Prophet requirements
+    """
+    if not isinstance(data, pd.DataFrame):
+        raise ValueError("Prophet models require pandas DataFrame input with 'ds' and 'y' columns")
+
+    if "ds" not in data.columns:
+        raise ValueError(
+            "Prophet models require a 'ds' column containing dates. "
+            "If your date column has a different name, use the 'date_column' parameter "
+            "when saving the model to map it to 'ds'."
+        )
+
+    if "y" not in data.columns:
+        # Allow 'y' column with NaN values for future predictions
+        raise ValueError(
+            "Prophet models require a 'y' column containing values (can be NaN for future periods). "
+            "If your target column has a different name, use the 'target_column' parameter "
+            "when saving the model to map it to 'y'."
+        )
+
+    validated_data = data.copy()
+
+    # Convert datetime column - this handles string timestamps from Snowflake
+    try:
+        validated_data["ds"] = pd.to_datetime(validated_data["ds"])
+    except Exception as e:
+        raise ValueError(f"'ds' column must contain valid datetime values: {e}")
+
+    # Convert numeric columns to proper float types
+    for col in validated_data.columns:
+        if col != "ds":
+            try:
+                # Convert to numeric, coercing errors to NaN
+                original_col = validated_data[col]
+                validated_data[col] = pd.to_numeric(validated_data[col], errors="coerce")
+
+                # Force explicit dtype conversion to ensure numpy operations work
+                validated_data[col] = validated_data[col].astype("float64")
+
+                logger.debug(f"Converted column '{col}' from {original_col.dtype} to {validated_data[col].dtype}")
+
+            except Exception as e:
+                # If conversion fails completely, provide detailed error
+                raise ValueError(
+                    f"Column '{col}' contains data that cannot be converted to numeric: {e}. "
+                    f"Original dtype: {validated_data[col].dtype}, sample values: {validated_data[col].head().tolist()}"
+                )
+
+    return validated_data
+
+
+@final
+class ProphetHandler(_base.BaseModelHandler["prophet.Prophet"]):
+    """Handler for prophet time series forecasting models."""
+
+    HANDLER_TYPE = "prophet"
+    HANDLER_VERSION = "2025-01-01"
+    _MIN_SNOWPARK_ML_VERSION = "1.8.0"
+    _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
+
+    MODEL_BLOB_FILE_OR_DIR = "model.pkl"
+    DEFAULT_TARGET_METHODS = ["predict"]
+
+    # Prophet models require sample data to infer signatures because the data may contain regressors.
+    IS_AUTO_SIGNATURE = False
+
+    @classmethod
+    def can_handle(
+        cls,
+        model: model_types.SupportedModelType,
+    ) -> TypeGuard["prophet.Prophet"]:
+        """Check if this handler can process the given model.
+
+        Args:
+            model: The model object to check
+
+        Returns:
+            True if this is a Prophet model, False otherwise
+        """
+        return type_utils.LazyType("prophet.Prophet").isinstance(model) and any(
+            (hasattr(model, method) and callable(getattr(model, method, None))) for method in cls.DEFAULT_TARGET_METHODS
+        )
+
+    @classmethod
+    def cast_model(
+        cls,
+        model: model_types.SupportedModelType,
+    ) -> "prophet.Prophet":
+        """Cast the model to Prophet type.
+
+        Args:
+            model: The model object
+
+        Returns:
+            The model cast as Prophet
+        """
+        import prophet
+
+        assert isinstance(model, prophet.Prophet)
+        return cast("prophet.Prophet", model)
+
+    @classmethod
+    def save_model(
+        cls,
+        name: str,
+        model: "prophet.Prophet",
+        model_meta: model_meta_api.ModelMetadata,
+        model_blobs_dir_path: str,
+        sample_input_data: Optional[model_types.SupportedDataType] = None,
+        is_sub_model: Optional[bool] = False,
+        **kwargs: Unpack[model_types.ProphetSaveOptions],
+    ) -> None:
+        """Save Prophet model and metadata.
+
+        Args:
+            name: Name of the model
+            model: The Prophet model object
+            model_meta: The model metadata
+            model_blobs_dir_path: Directory to save model files
+            sample_input_data: Sample input data for signature inference
+            is_sub_model: Whether this is a sub-model
+            **kwargs: Additional save options including date_column and target_column for column mapping
+
+        Raises:
+            ValueError: If sample_input_data is not a pandas DataFrame or if column mapping fails
+        """
+        import prophet
+
+        assert isinstance(model, prophet.Prophet)
+
+        date_column = kwargs.pop("date_column", None)
+        target_column = kwargs.pop("target_column", None)
+
+        if not is_sub_model:
+            # Validate sample input data if provided
+            if sample_input_data is not None:
+                if isinstance(sample_input_data, pd.DataFrame):
+                    # Normalize for validation purposes
+                    normalized_sample = _normalize_column_names(
+                        sample_input_data.copy(),
+                        date_column=date_column,
+                        target_column=target_column,
+                    )
+                    _validate_prophet_data_format(normalized_sample)
+                else:
+                    raise ValueError("Prophet models require pandas DataFrame sample input data")
+
+            target_methods = handlers_utils.get_target_methods(
+                model=model,
+                target_methods=kwargs.pop("target_methods", None),
+                default_target_methods=cls.DEFAULT_TARGET_METHODS,
+            )
+
+            def get_prediction(
+                target_method_name: str,
+                sample_input_data: model_types.SupportedLocalDataType,
+            ) -> model_types.SupportedLocalDataType:
+                """Generate predictions for signature inference."""
+                if not isinstance(sample_input_data, pd.DataFrame):
+                    raise ValueError("Prophet requires pandas DataFrame input")
+
+                normalized_data = _normalize_column_names(
+                    sample_input_data,
+                    date_column=date_column,
+                    target_column=target_column,
+                )
+                validated_data = _validate_prophet_data_format(normalized_data)
+
+                target_method = getattr(model, target_method_name, None)
+                if not callable(target_method):
+                    raise ValueError(f"Method {target_method_name} not found on Prophet model")
+
+                if target_method_name == "predict":
+                    # Use the input data as the future dataframe for prediction
+                    try:
+                        predictions = target_method(validated_data)
+                        predictions = _sanitize_prophet_output(predictions)
+                        return predictions
+                    except Exception as e:
+                        if "numpy._core.numeric" in str(e):
+                            raise RuntimeError(
+                                f"Prophet model logging failed due to NumPy compatibility issue: {e}. "
+                                f"Try using compatible NumPy versions (e.g., 1.24.x or 1.26.x) in pip_requirements "
+                                f"with relax_version=False."
+                            ) from e
+                        else:
+                            raise
+                elif target_method_name == "predict_components":
+                    predictions = target_method(validated_data)
+                    predictions = _sanitize_prophet_output(predictions)
+                else:
+                    raise ValueError(f"Unsupported target method: {target_method_name}")
+
+                return predictions
+
+            model_meta = handlers_utils.validate_signature(
+                model=model,
+                model_meta=model_meta,
+                target_methods=target_methods,
+                sample_input_data=sample_input_data,
+                get_prediction_fn=get_prediction,
+            )
+
+        model_meta.task = model_types.Task.UNKNOWN  # Prophet is forecasting, which isn't in standard tasks
+
+        # Save the Prophet model using cloudpickle
+        model_blob_path = os.path.join(model_blobs_dir_path, name)
+        os.makedirs(model_blob_path, exist_ok=True)
+
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
+            cloudpickle.dump(model, f)
+
+        # Create model blob metadata with column mapping options
+        from snowflake.ml.model._packager.model_meta import model_meta_schema
+
+        options: model_meta_schema.ProphetModelBlobOptions = {}
+        if date_column is not None:
+            options["date_column"] = date_column
+        if target_column is not None:
+            options["target_column"] = target_column
+
+        base_meta = model_blob_meta.ModelBlobMeta(
+            name=name,
+            model_type=cls.HANDLER_TYPE,
+            handler_version=cls.HANDLER_VERSION,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
+            options=options if options else {},
+        )
+        model_meta.models[name] = base_meta
+        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
+
+        # Add Prophet dependencies
+        model_meta.env.include_if_absent(
+            [
+                model_env.ModelDependency(requirement="prophet", pip_name="prophet"),
+                model_env.ModelDependency(requirement="pandas", pip_name="pandas"),
+                model_env.ModelDependency(requirement="numpy", pip_name="numpy"),
+            ],
+            check_local_version=True,
+        )
+
+    @classmethod
+    def load_model(
+        cls,
+        name: str,
+        model_meta: model_meta_api.ModelMetadata,
+        model_blobs_dir_path: str,
+        **kwargs: Unpack[model_types.ProphetLoadOptions],
+    ) -> "prophet.Prophet":
+        """Load Prophet model from storage.
+
+        Args:
+            name: Name of the model
+            model_meta: The model metadata
+            model_blobs_dir_path: Directory containing model files
+            **kwargs: Additional load options
+
+        Returns:
+            The loaded Prophet model
+        """
+        import prophet
+
+        model_blob_path = os.path.join(model_blobs_dir_path, name)
+        model_blobs_metadata = model_meta.models
+        model_blob_metadata = model_blobs_metadata[name]
+        model_blob_filename = model_blob_metadata.path
+
+        with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
+            model = cloudpickle.load(f)
+
+        assert isinstance(model, prophet.Prophet)
+        return model
+
+    @classmethod
+    def convert_as_custom_model(
+        cls,
+        raw_model: "prophet.Prophet",
+        model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
+        **kwargs: Unpack[model_types.ProphetLoadOptions],
+    ) -> custom_model.CustomModel:
+        """Convert Prophet model to CustomModel for unified inference interface.
+
+        Args:
+            raw_model: The original Prophet model
+            model_meta: The model metadata
+            background_data: Background data for explanations (not used for Prophet)
+            **kwargs: Additional options
+
+        Returns:
+            CustomModel wrapper for the Prophet model
+        """
+        from snowflake.ml.model import custom_model
+
+        model_blob_meta = next(iter(model_meta.models.values()))
+        date_column: Optional[str] = cast(Optional[str], model_blob_meta.options.get("date_column", None))
+        target_column: Optional[str] = cast(Optional[str], model_blob_meta.options.get("target_column", None))
+
+        def _create_custom_model(
+            raw_model: "prophet.Prophet",
+            model_meta: model_meta_api.ModelMetadata,
+            date_column: Optional[str],
+            target_column: Optional[str],
+        ) -> type[custom_model.CustomModel]:
+            """Create custom model class for Prophet."""
+
+            def fn_factory(
+                raw_model: "prophet.Prophet",
+                signature: model_signature.ModelSignature,
+                target_method: str,
+            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
+                """Factory function to create method implementations."""
+
+                @custom_model.inference_api
+                def predict_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                    """Predict method for Prophet forecasting.
+
+                    For forecasting, users should provide a DataFrame with:
+                    - 'ds' column (or custom date column name): dates for which to generate forecasts
+                    - 'y' column (or custom target column name): can be NaN for future periods to forecast
+                    - Additional regressor columns if the model was trained with them
+
+                    Args:
+                        self: The custom model instance
+                        X: Input DataFrame with dates and optional target values for forecasting
+
+                    Returns:
+                        DataFrame containing Prophet predictions with columns like ds, yhat, yhat_lower, etc.
+
+                    Raises:
+                        ValueError: If column normalization fails or method is unsupported
+                        RuntimeError: If NumPy compatibility issues are detected during validation or prediction
+                    """
+                    try:
+                        normalized_data = _normalize_column_names(
+                            X, date_column=date_column, target_column=target_column
+                        )
+                    except Exception as e:
+                        raise ValueError(f"Failed to normalize column names: {e}") from e
+
+                    # Validate input format with runtime error handling
+                    try:
+                        validated_data = _validate_prophet_data_format(normalized_data)
+                    except Exception as e:
+                        if "numpy._core.numeric" in str(e):
+                            raise RuntimeError(
+                                f"Prophet input validation failed in Snowflake runtime due to "
+                                f"NumPy compatibility: {e}. Redeploy model with compatible dependency versions "
+                                f"(e.g., NumPy 1.24.x or 1.26.x, Prophet 1.1.x) in pip_requirements "
+                                f"with relax_version=False."
+                            ) from e
+                        else:
+                            raise
+
+                    # Generate predictions using Prophet with runtime error handling
+                    if target_method == "predict":
+                        try:
+                            predictions = raw_model.predict(validated_data)
+                            # Sanitize output to remove columns with problematic names
+                            predictions = _sanitize_prophet_output(predictions)
+                        except Exception as e:
+                            if "numpy._core.numeric" in str(e) or "np.float_" in str(e):
+                                raise RuntimeError(
+                                    f"Prophet prediction failed in Snowflake runtime due to NumPy compatibility: {e}. "
+                                    f"This indicates Prophet's internal NumPy operations are incompatible. "
+                                    f"Redeploy with compatible dependency versions in pip_requirements."
+                                ) from e
+                            else:
+                                raise
+                    else:
+                        raise ValueError(f"Unsupported method: {target_method}")
+
+                    # Prophet returns many columns, but we only want the ones in our signature
+                    # Filter to only the columns we expect based on our signature
+                    expected_columns = [spec.name for spec in signature.outputs]
+                    available_columns = [col for col in expected_columns if col in predictions.columns]
+
+                    # Fill missing columns with zeros to match the expected signature
+                    filtered_predictions = predictions[available_columns].copy()
+
+                    # Add missing columns with zeros if they're not present
+                    for col_name in expected_columns:
+                        if col_name not in filtered_predictions.columns:
+                            # Add missing seasonal component columns with zeros
+                            if col_name in ["weekly", "yearly", "daily"]:
+                                filtered_predictions[col_name] = 0.0
+                            else:
+                                # For required columns like ds, yhat, etc., this would be an error
+                                raise ValueError(f"Required column '{col_name}' missing from Prophet output")
+
+                    # Reorder columns to match signature order
+                    filtered_predictions = filtered_predictions[expected_columns]
+
+                    # Ensure the output matches the expected signature
+                    return model_signature_utils.rename_pandas_df(filtered_predictions, signature.outputs)
+
+                return predict_fn
+
+            # Create method dictionary for the custom model class
+            type_method_dict = {}
+            for target_method_name, sig in model_meta.signatures.items():
+                type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
+
+            # Create the custom model class
+            _ProphetModel = type(
+                "_ProphetModel",
+                (custom_model.CustomModel,),
+                type_method_dict,
+            )
+
+            return _ProphetModel
+
+        _ProphetModel = _create_custom_model(raw_model, model_meta, date_column, target_column)
+        prophet_model = _ProphetModel(custom_model.ModelContext())
+
+        return prophet_model
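The new handler maps user column names onto Prophet's required 'ds'/'y' schema before fitting or predicting, and renames SQL-unsafe forecast columns on the way out. A minimal sketch of that contract using only the public prophet and pandas APIs; the DataFrame and column names here are illustrative, not taken from the package:

import pandas as pd
from prophet import Prophet

# Training frame with user-named columns rather than Prophet's ds/y.
history = pd.DataFrame(
    {
        "order_date": pd.date_range("2024-01-01", periods=90, freq="D"),
        "units_sold": [10.0 + (i % 7) for i in range(90)],
    }
)

# Equivalent of _normalize_column_names(history, date_column="order_date",
# target_column="units_sold"): rename onto the ds/y schema Prophet expects.
history = history.rename(columns={"order_date": "ds", "units_sold": "y"})

model = Prophet()
model.fit(history)

# At inference the handler normalizes/validates the input the same way (it
# additionally requires a 'y' column, which may be NaN for future rows), calls
# predict, then renames SQL-unsafe output columns (e.g. holiday names with
# spaces) via _sanitize_prophet_output.
future = model.make_future_dataframe(periods=7)
forecast = model.predict(future)
print(forecast[["ds", "yhat", "yhat_lower", "yhat_upper"]].tail())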
snowflake/ml/model/_packager/model_meta/model_meta.py
@@ -116,6 +116,8 @@ def create_model_metadata(
     if embed_local_ml_library:
         env.snowpark_ml_version = f"{snowml_version.VERSION}+{file_utils.hash_directory(path_to_copy)}"

+    # Persist full method_options
+    method_options: dict[str, dict[str, Any]] = kwargs.pop("method_options", {})
     model_meta = ModelMetadata(
         name=name,
         env=env,
@@ -124,6 +126,7 @@ def create_model_metadata(
         signatures=signatures,
         function_properties=function_properties,
         task=task,
+        method_options=method_options,
     )

     code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR)
@@ -256,6 +259,7 @@ class ModelMetadata:
         original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
         task: model_types.Task = model_types.Task.UNKNOWN,
         explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
+        method_options: Optional[dict[str, dict[str, Any]]] = None,
     ) -> None:
         self.name = name
         self.signatures: dict[str, model_signature.ModelSignature] = dict()
@@ -283,6 +287,7 @@ class ModelMetadata:

         self.task: model_types.Task = task
         self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
+        self.method_options: dict[str, dict[str, Any]] = method_options or {}

     @property
     def min_snowpark_ml_version(self) -> str:
@@ -342,6 +347,7 @@ class ModelMetadata:
                     else None
                 ),
                 "function_properties": self.function_properties,
+                "method_options": self.method_options,
             }
         )
         with open(model_yaml_path, "w", encoding="utf-8") as out:
@@ -381,6 +387,7 @@ class ModelMetadata:
             task=loaded_meta.get("task", model_types.Task.UNKNOWN.value),
             explainability=loaded_meta.get("explainability", None),
             function_properties=loaded_meta.get("function_properties", {}),
+            method_options=loaded_meta.get("method_options", {}),
         )

     @classmethod
@@ -436,4 +443,5 @@ class ModelMetadata:
             task=model_types.Task(model_dict.get("task", model_types.Task.UNKNOWN.value)),
             explain_algorithm=explanation_algorithm,
             function_properties=model_dict.get("function_properties", {}),
+            method_options=model_dict.get("method_options", {}),
         )
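These hunks thread a new method_options mapping (one options dict per target method) from create_model_metadata's kwargs through ModelMetadata and into the serialized metadata dict. A small sketch of the round trip the added lines imply; the option keys are illustrative, since the schema only constrains the shape to dict[str, dict[str, Any]]:

from typing import Any

# Per-method options of the persisted shape; keys shown are illustrative.
method_options: dict[str, dict[str, Any]] = {
    "predict": {"max_batch_size": 100},
}

# create_model_metadata pops this from kwargs, ModelMetadata stores it, and
# the save/load paths mirror each other through the metadata dict:
serialized = {"method_options": method_options}   # written during save
restored = serialized.get("method_options", {})   # read back, defaulting to {}
assert restored == method_options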
snowflake/ml/model/_packager/model_meta/model_meta_schema.py
@@ -84,6 +84,11 @@ class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
     batch_size: Required[int]


+class ProphetModelBlobOptions(BaseModelBlobOptions):
+    date_column: NotRequired[Optional[str]]
+    target_column: NotRequired[Optional[str]]
+
+
 ModelBlobOptions = Union[
     BaseModelBlobOptions,
     CatBoostModelBlobOptions,
@@ -94,6 +99,7 @@ ModelBlobOptions = Union[
     TorchScriptModelBlobOptions,
     TensorflowModelBlobOptions,
     SentenceTransformersModelBlobOptions,
+    ProphetModelBlobOptions,
 ]


@@ -125,6 +131,7 @@ class ModelMetadataDict(TypedDict):
     task: Required[str]
     explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
     function_properties: NotRequired[dict[str, dict[str, Any]]]
+    method_options: NotRequired[dict[str, dict[str, Any]]]


 class ModelExplainAlgorithm(Enum):
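Because both fields of ProphetModelBlobOptions are NotRequired, an empty dict is a valid value, which is why the handler only writes keys when a mapping was supplied and reads them back with .get(..., None). A sketch with a simplified stand-in for BaseModelBlobOptions (illustrative only):

from typing import Optional
from typing_extensions import NotRequired, TypedDict

# Simplified stand-in for BaseModelBlobOptions, for illustration only.
class ProphetModelBlobOptions(TypedDict):
    date_column: NotRequired[Optional[str]]
    target_column: NotRequired[Optional[str]]

options: ProphetModelBlobOptions = {}    # valid: both keys are optional
options["date_column"] = "order_date"    # written only when a mapping exists
assert options.get("target_column", None) is None  # absent key reads as None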
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
@@ -24,11 +24,11 @@ REQUIREMENTS = [
     "scikit-learn<1.8",
     "scipy>=1.9,<2",
     "shap>=0.46.0,<1",
-    "snowflake-connector-python>=3.
+    "snowflake-connector-python>=3.17.0,<4",
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
     "tqdm<5",
     "typing-extensions>=4.1.0,<5",
-    "xgboost
+    "xgboost<4",
 ]