snowflake-ml-python 1.17.0__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +3 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
- snowflake/ml/experiment/callback/keras.py +3 -0
- snowflake/ml/experiment/callback/lightgbm.py +3 -0
- snowflake/ml/experiment/callback/xgboost.py +3 -0
- snowflake/ml/experiment/experiment_tracking.py +19 -7
- snowflake/ml/feature_store/feature_store.py +236 -61
- snowflake/ml/jobs/_utils/constants.py +12 -1
- snowflake/ml/jobs/_utils/payload_utils.py +7 -1
- snowflake/ml/jobs/_utils/stage_utils.py +4 -0
- snowflake/ml/jobs/_utils/types.py +5 -0
- snowflake/ml/jobs/job.py +16 -2
- snowflake/ml/jobs/manager.py +12 -1
- snowflake/ml/model/__init__.py +19 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
- snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
- snowflake/ml/model/_client/model/model_version_impl.py +129 -11
- snowflake/ml/model/_client/ops/service_ops.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +4 -1
- snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
- snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -0
- snowflake/ml/model/type_hints.py +16 -0
- snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +25 -1
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +33 -32
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/prophet.py
ADDED

@@ -0,0 +1,566 @@
+import logging
+import os
+from typing import TYPE_CHECKING, Callable, Optional, cast, final
+
+import cloudpickle
+import pandas as pd
+from typing_extensions import TypeGuard, Unpack
+
+from snowflake.ml._internal import type_utils
+from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
+from snowflake.ml.model._packager.model_env import model_env
+from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
+from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+from snowflake.ml.model._packager.model_meta import (
+    model_blob_meta,
+    model_meta as model_meta_api,
+)
+from snowflake.ml.model._signatures import utils as model_signature_utils
+
+if TYPE_CHECKING:
+    import prophet
+
+logger = logging.getLogger(__name__)
+
+
+def _normalize_column_names(
+    data: pd.DataFrame,
+    date_column: Optional[str] = None,
+    target_column: Optional[str] = None,
+) -> pd.DataFrame:
+    """Normalize user column names to Prophet's required 'ds' and 'y' format.
+
+    Args:
+        data: Input DataFrame with user's column names
+        date_column: Name of the date column to map to 'ds'
+        target_column: Name of the target column to map to 'y'
+
+    Returns:
+        DataFrame with columns renamed to Prophet format
+
+    Raises:
+        ValueError: If specified columns don't exist or if there are naming conflicts
+    """
+    if date_column is None and target_column is None:
+        return data
+
+    data = data.copy()
+
+    if date_column is not None:
+        if date_column not in data.columns:
+            raise ValueError(
+                f"Specified date_column '{date_column}' not found in DataFrame. "
+                f"Available columns: {list(data.columns)}"
+            )
+        if date_column != "ds":
+            # Check if 'ds' already exists as a different column
+            if "ds" in data.columns:
+                raise ValueError(
+                    f"Cannot rename '{date_column}' to 'ds' because 'ds' already exists in the DataFrame. "
+                    f"Please either: (1) rename or remove the existing 'ds' column, or "
+                    f"(2) if 'ds' is your date column, set date_column='ds' instead."
+                )
+            data = data.rename(columns={date_column: "ds"})
+
+    if target_column is not None:
+        if target_column not in data.columns:
+            raise ValueError(
+                f"Specified target_column '{target_column}' not found in DataFrame. "
+                f"Available columns: {list(data.columns)}"
+            )
+        if target_column != "y":
+            # Check if 'y' already exists as a different column
+            if "y" in data.columns:
+                raise ValueError(
+                    f"Cannot rename '{target_column}' to 'y' because 'y' already exists in the DataFrame. "
+                    f"Please either: (1) rename or remove the existing 'y' column, or "
+                    f"(2) if 'y' is your target column, set target_column='y' instead."
+                )
+            data = data.rename(columns={target_column: "y"})
+
+    return data
+
+
+def _sanitize_prophet_output(predictions: pd.DataFrame) -> pd.DataFrame:
+    """Sanitize Prophet prediction output to have SQL-safe column names.
+
+    Prophet may include holiday columns with names containing spaces (e.g., "Christmas Day")
+    which cannot be used as unquoted SQL identifiers in Snowflake. This function normalizes all
+    column names to be valid unquoted SQL identifiers by replacing spaces with underscores and
+    removing special characters.
+
+    Args:
+        predictions: Raw prediction DataFrame from Prophet
+
+    Returns:
+        DataFrame with normalized SQL-safe column names
+
+    Raises:
+        ValueError: If predictions DataFrame is empty, has no columns, or is missing required
+            columns 'ds' and 'yhat'
+    """
+    # Check if predictions is empty or has no columns
+    if predictions is None or len(predictions.columns) == 0:
+        raise ValueError(
+            f"Prophet predictions DataFrame is empty or has no columns. "
+            f"DataFrame shape: {predictions.shape if predictions is not None else 'None'}, "
+            f"Type: {type(predictions)}"
+        )
+
+    if "ds" not in predictions.columns or "yhat" not in predictions.columns:
+        raise ValueError(
+            f"Prophet predictions missing required columns 'ds' and 'yhat'. "
+            f"Available columns: {list(predictions.columns)}"
+        )
+
+    # Normalize all column names to be SQL-safe
+    normalized_columns = {col: handlers_utils.normalize_column_name(col) for col in predictions.columns}
+
+    # Check for conflicts after normalization
+    normalized_values = list(normalized_columns.values())
+    if len(normalized_values) != len(set(normalized_values)):
+        # Find duplicates
+        seen = set()
+        duplicates = []
+        for val in normalized_values:
+            if val in seen:
+                duplicates.append(val)
+            seen.add(val)
+
+        logger.warning(
+            f"Column name normalization resulted in duplicates: {duplicates}. "
+            f"Original columns: {[k for k, v in normalized_columns.items() if v in duplicates]}"
+        )
+
+    # Rename columns
+    sanitized_predictions = predictions.rename(columns=normalized_columns)
+
+    return sanitized_predictions
+
+
+def _validate_prophet_data_format(data: model_types.SupportedDataType) -> pd.DataFrame:
+    """Validate that input data follows Prophet's required format with 'ds' and 'y' columns.
+
+    Args:
+        data: Input data to validate
+
+    Returns:
+        DataFrame with validated Prophet format and proper data types
+
+    Raises:
+        ValueError: If data doesn't meet Prophet requirements
+    """
+    if not isinstance(data, pd.DataFrame):
+        raise ValueError("Prophet models require pandas DataFrame input with 'ds' and 'y' columns")
+
+    if "ds" not in data.columns:
+        raise ValueError(
+            "Prophet models require a 'ds' column containing dates. "
+            "If your date column has a different name, use the 'date_column' parameter "
+            "when saving the model to map it to 'ds'."
+        )
+
+    if "y" not in data.columns:
+        # Allow 'y' column with NaN values for future predictions
+        raise ValueError(
+            "Prophet models require a 'y' column containing values (can be NaN for future periods). "
+            "If your target column has a different name, use the 'target_column' parameter "
+            "when saving the model to map it to 'y'."
+        )
+
+    validated_data = data.copy()
+
+    # Convert datetime column - this handles string timestamps from Snowflake
+    try:
+        validated_data["ds"] = pd.to_datetime(validated_data["ds"])
+    except Exception as e:
+        raise ValueError(f"'ds' column must contain valid datetime values: {e}")
+
+    # Convert numeric columns to proper float types
+    for col in validated_data.columns:
+        if col != "ds":
+            try:
+                # Convert to numeric, coercing errors to NaN
+                original_col = validated_data[col]
+                validated_data[col] = pd.to_numeric(validated_data[col], errors="coerce")
+
+                # Force explicit dtype conversion to ensure numpy operations work
+                validated_data[col] = validated_data[col].astype("float64")
+
+                logger.debug(f"Converted column '{col}' from {original_col.dtype} to {validated_data[col].dtype}")
+
+            except Exception as e:
+                # If conversion fails completely, provide detailed error
+                raise ValueError(
+                    f"Column '{col}' contains data that cannot be converted to numeric: {e}. "
+                    f"Original dtype: {validated_data[col].dtype}, sample values: {validated_data[col].head().tolist()}"
+                )
+
+    return validated_data
+
+
+@final
+class ProphetHandler(_base.BaseModelHandler["prophet.Prophet"]):
+    """Handler for prophet time series forecasting models."""
+
+    HANDLER_TYPE = "prophet"
+    HANDLER_VERSION = "2025-01-01"
+    _MIN_SNOWPARK_ML_VERSION = "1.8.0"
+    _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
+
+    MODEL_BLOB_FILE_OR_DIR = "model.pkl"
+    DEFAULT_TARGET_METHODS = ["predict"]
+
+    # Prophet models require sample data to infer signatures because the data may contain regressors.
+    IS_AUTO_SIGNATURE = False
+
+    @classmethod
+    def can_handle(
+        cls,
+        model: model_types.SupportedModelType,
+    ) -> TypeGuard["prophet.Prophet"]:
+        """Check if this handler can process the given model.
+
+        Args:
+            model: The model object to check
+
+        Returns:
+            True if this is a Prophet model, False otherwise
+        """
+        return type_utils.LazyType("prophet.Prophet").isinstance(model) and any(
+            (hasattr(model, method) and callable(getattr(model, method, None))) for method in cls.DEFAULT_TARGET_METHODS
+        )
+
+    @classmethod
+    def cast_model(
+        cls,
+        model: model_types.SupportedModelType,
+    ) -> "prophet.Prophet":
+        """Cast the model to Prophet type.
+
+        Args:
+            model: The model object
+
+        Returns:
+            The model cast as Prophet
+        """
+        import prophet
+
+        assert isinstance(model, prophet.Prophet)
+        return cast("prophet.Prophet", model)
+
+    @classmethod
+    def save_model(
+        cls,
+        name: str,
+        model: "prophet.Prophet",
+        model_meta: model_meta_api.ModelMetadata,
+        model_blobs_dir_path: str,
+        sample_input_data: Optional[model_types.SupportedDataType] = None,
+        is_sub_model: Optional[bool] = False,
+        **kwargs: Unpack[model_types.ProphetSaveOptions],
+    ) -> None:
+        """Save Prophet model and metadata.
+
+        Args:
+            name: Name of the model
+            model: The Prophet model object
+            model_meta: The model metadata
+            model_blobs_dir_path: Directory to save model files
+            sample_input_data: Sample input data for signature inference
+            is_sub_model: Whether this is a sub-model
+            **kwargs: Additional save options including date_column and target_column for column mapping
+
+        Raises:
+            ValueError: If sample_input_data is not a pandas DataFrame or if column mapping fails
+        """
+        import prophet
+
+        assert isinstance(model, prophet.Prophet)
+
+        date_column = kwargs.pop("date_column", None)
+        target_column = kwargs.pop("target_column", None)
+
+        if not is_sub_model:
+            # Validate sample input data if provided
+            if sample_input_data is not None:
+                if isinstance(sample_input_data, pd.DataFrame):
+                    # Normalize for validation purposes
+                    normalized_sample = _normalize_column_names(
+                        sample_input_data.copy(),
+                        date_column=date_column,
+                        target_column=target_column,
+                    )
+                    _validate_prophet_data_format(normalized_sample)
+                else:
+                    raise ValueError("Prophet models require pandas DataFrame sample input data")
+
+            target_methods = handlers_utils.get_target_methods(
+                model=model,
+                target_methods=kwargs.pop("target_methods", None),
+                default_target_methods=cls.DEFAULT_TARGET_METHODS,
+            )
+
+            def get_prediction(
+                target_method_name: str,
+                sample_input_data: model_types.SupportedLocalDataType,
+            ) -> model_types.SupportedLocalDataType:
+                """Generate predictions for signature inference."""
+                if not isinstance(sample_input_data, pd.DataFrame):
+                    raise ValueError("Prophet requires pandas DataFrame input")
+
+                normalized_data = _normalize_column_names(
+                    sample_input_data,
+                    date_column=date_column,
+                    target_column=target_column,
+                )
+                validated_data = _validate_prophet_data_format(normalized_data)
+
+                target_method = getattr(model, target_method_name, None)
+                if not callable(target_method):
+                    raise ValueError(f"Method {target_method_name} not found on Prophet model")
+
+                if target_method_name == "predict":
+                    # Use the input data as the future dataframe for prediction
+                    try:
+                        predictions = target_method(validated_data)
+                        predictions = _sanitize_prophet_output(predictions)
+                        return predictions
+                    except Exception as e:
+                        if "numpy._core.numeric" in str(e):
+                            raise RuntimeError(
+                                f"Prophet model logging failed due to NumPy compatibility issue: {e}. "
+                                f"Try using compatible NumPy versions (e.g., 1.24.x or 1.26.x) in pip_requirements "
+                                f"with relax_version=False."
+                            ) from e
+                        else:
+                            raise
+                elif target_method_name == "predict_components":
+                    predictions = target_method(validated_data)
+                    predictions = _sanitize_prophet_output(predictions)
+                else:
+                    raise ValueError(f"Unsupported target method: {target_method_name}")
+
+                return predictions
+
+            model_meta = handlers_utils.validate_signature(
+                model=model,
+                model_meta=model_meta,
+                target_methods=target_methods,
+                sample_input_data=sample_input_data,
+                get_prediction_fn=get_prediction,
+            )
+
+        model_meta.task = model_types.Task.UNKNOWN  # Prophet is forecasting, which isn't in standard tasks
+
+        # Save the Prophet model using cloudpickle
+        model_blob_path = os.path.join(model_blobs_dir_path, name)
+        os.makedirs(model_blob_path, exist_ok=True)
+
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
+            cloudpickle.dump(model, f)
+
+        # Create model blob metadata with column mapping options
+        from snowflake.ml.model._packager.model_meta import model_meta_schema
+
+        options: model_meta_schema.ProphetModelBlobOptions = {}
+        if date_column is not None:
+            options["date_column"] = date_column
+        if target_column is not None:
+            options["target_column"] = target_column
+
+        base_meta = model_blob_meta.ModelBlobMeta(
+            name=name,
+            model_type=cls.HANDLER_TYPE,
+            handler_version=cls.HANDLER_VERSION,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
+            options=options if options else {},
+        )
+        model_meta.models[name] = base_meta
+        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
+
+        # Add Prophet dependencies
+        model_meta.env.include_if_absent(
+            [
+                model_env.ModelDependency(requirement="prophet", pip_name="prophet"),
+                model_env.ModelDependency(requirement="pandas", pip_name="pandas"),
+                model_env.ModelDependency(requirement="numpy", pip_name="numpy"),
+            ],
+            check_local_version=True,
+        )
+
+    @classmethod
+    def load_model(
+        cls,
+        name: str,
+        model_meta: model_meta_api.ModelMetadata,
+        model_blobs_dir_path: str,
+        **kwargs: Unpack[model_types.ProphetLoadOptions],
+    ) -> "prophet.Prophet":
+        """Load Prophet model from storage.
+
+        Args:
+            name: Name of the model
+            model_meta: The model metadata
+            model_blobs_dir_path: Directory containing model files
+            **kwargs: Additional load options
+
+        Returns:
+            The loaded Prophet model
+        """
+        import prophet
+
+        model_blob_path = os.path.join(model_blobs_dir_path, name)
+        model_blobs_metadata = model_meta.models
+        model_blob_metadata = model_blobs_metadata[name]
+        model_blob_filename = model_blob_metadata.path
+
+        with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
+            model = cloudpickle.load(f)
+
+        assert isinstance(model, prophet.Prophet)
+        return model
+
+    @classmethod
+    def convert_as_custom_model(
+        cls,
+        raw_model: "prophet.Prophet",
+        model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
+        **kwargs: Unpack[model_types.ProphetLoadOptions],
+    ) -> custom_model.CustomModel:
+        """Convert Prophet model to CustomModel for unified inference interface.
+
+        Args:
+            raw_model: The original Prophet model
+            model_meta: The model metadata
+            background_data: Background data for explanations (not used for Prophet)
+            **kwargs: Additional options
+
+        Returns:
+            CustomModel wrapper for the Prophet model
+        """
+        from snowflake.ml.model import custom_model
+
+        model_blob_meta = next(iter(model_meta.models.values()))
+        date_column: Optional[str] = cast(Optional[str], model_blob_meta.options.get("date_column", None))
+        target_column: Optional[str] = cast(Optional[str], model_blob_meta.options.get("target_column", None))
+
+        def _create_custom_model(
+            raw_model: "prophet.Prophet",
+            model_meta: model_meta_api.ModelMetadata,
+            date_column: Optional[str],
+            target_column: Optional[str],
+        ) -> type[custom_model.CustomModel]:
+            """Create custom model class for Prophet."""
+
+            def fn_factory(
+                raw_model: "prophet.Prophet",
+                signature: model_signature.ModelSignature,
+                target_method: str,
+            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
+                """Factory function to create method implementations."""
+
+                @custom_model.inference_api
+                def predict_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                    """Predict method for Prophet forecasting.
+
+                    For forecasting, users should provide a DataFrame with:
+                    - 'ds' column (or custom date column name): dates for which to generate forecasts
+                    - 'y' column (or custom target column name): can be NaN for future periods to forecast
+                    - Additional regressor columns if the model was trained with them
+
+                    Args:
+                        self: The custom model instance
+                        X: Input DataFrame with dates and optional target values for forecasting
+
+                    Returns:
+                        DataFrame containing Prophet predictions with columns like ds, yhat, yhat_lower, etc.
+
+                    Raises:
+                        ValueError: If column normalization fails or method is unsupported
+                        RuntimeError: If NumPy compatibility issues are detected during validation or prediction
+                    """
+                    try:
+                        normalized_data = _normalize_column_names(
+                            X, date_column=date_column, target_column=target_column
+                        )
+                    except Exception as e:
+                        raise ValueError(f"Failed to normalize column names: {e}") from e
+
+                    # Validate input format with runtime error handling
+                    try:
+                        validated_data = _validate_prophet_data_format(normalized_data)
+                    except Exception as e:
+                        if "numpy._core.numeric" in str(e):
+                            raise RuntimeError(
+                                f"Prophet input validation failed in Snowflake runtime due to "
+                                f"NumPy compatibility: {e}. Redeploy model with compatible dependency versions "
+                                f"(e.g., NumPy 1.24.x or 1.26.x, Prophet 1.1.x) in pip_requirements "
+                                f"with relax_version=False."
+                            ) from e
+                        else:
+                            raise
+
+                    # Generate predictions using Prophet with runtime error handling
+                    if target_method == "predict":
+                        try:
+                            predictions = raw_model.predict(validated_data)
+                            # Sanitize output to remove columns with problematic names
+                            predictions = _sanitize_prophet_output(predictions)
+                        except Exception as e:
+                            if "numpy._core.numeric" in str(e) or "np.float_" in str(e):
+                                raise RuntimeError(
+                                    f"Prophet prediction failed in Snowflake runtime due to NumPy compatibility: {e}. "
+                                    f"This indicates Prophet's internal NumPy operations are incompatible. "
+                                    f"Redeploy with compatible dependency versions in pip_requirements."
+                                ) from e
+                            else:
+                                raise
+                    else:
+                        raise ValueError(f"Unsupported method: {target_method}")
+
+                    # Prophet returns many columns, but we only want the ones in our signature
+                    # Filter to only the columns we expect based on our signature
+                    expected_columns = [spec.name for spec in signature.outputs]
+                    available_columns = [col for col in expected_columns if col in predictions.columns]
+
+                    # Fill missing columns with zeros to match the expected signature
+                    filtered_predictions = predictions[available_columns].copy()
+
+                    # Add missing columns with zeros if they're not present
+                    for col_name in expected_columns:
+                        if col_name not in filtered_predictions.columns:
+                            # Add missing seasonal component columns with zeros
+                            if col_name in ["weekly", "yearly", "daily"]:
+                                filtered_predictions[col_name] = 0.0
+                            else:
+                                # For required columns like ds, yhat, etc., this would be an error
+                                raise ValueError(f"Required column '{col_name}' missing from Prophet output")
+
+                    # Reorder columns to match signature order
+                    filtered_predictions = filtered_predictions[expected_columns]
+
+                    # Ensure the output matches the expected signature
+                    return model_signature_utils.rename_pandas_df(filtered_predictions, signature.outputs)
+
+                return predict_fn
+
+            # Create method dictionary for the custom model class
+            type_method_dict = {}
+            for target_method_name, sig in model_meta.signatures.items():
+                type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
+
+            # Create the custom model class
+            _ProphetModel = type(
+                "_ProphetModel",
+                (custom_model.CustomModel,),
+                type_method_dict,
+            )
+
+            return _ProphetModel
+
+        _ProphetModel = _create_custom_model(raw_model, model_meta, date_column, target_column)
+        prophet_model = _ProphetModel(custom_model.ModelContext())
+
+        return prophet_model
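For context, a minimal sketch of how this handler is exercised end to end. This is not part of the diff: it assumes an existing Snowpark `session`, uses hypothetical model and column names, and assumes `Registry.log_model` routes the `date_column`/`target_column` keys from the ProphetSaveOptions introduced below.

import pandas as pd
import prophet
from snowflake.ml.registry import Registry

# Hypothetical training frame with user-named columns instead of Prophet's 'ds'/'y'.
df = pd.DataFrame({
    "sale_date": pd.date_range("2024-01-01", periods=90, freq="D"),
    "revenue": [float(i) for i in range(90)],
})

m = prophet.Prophet()
m.fit(df.rename(columns={"sale_date": "ds", "revenue": "y"}))

reg = Registry(session=session)  # assumes an existing Snowpark session
mv = reg.log_model(
    m,
    model_name="REVENUE_FORECAST",  # hypothetical name
    version_name="v1",
    sample_input_data=df,  # required: IS_AUTO_SIGNATURE is False for Prophet
    options={"date_column": "sale_date", "target_column": "revenue"},
)
forecast = mv.run(df, function_name="predict")  # columns like ds, yhat, yhat_lower, ...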
snowflake/ml/model/_packager/model_meta/model_meta_schema.py
CHANGED

@@ -84,6 +84,11 @@ class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
     batch_size: Required[int]


+class ProphetModelBlobOptions(BaseModelBlobOptions):
+    date_column: NotRequired[Optional[str]]
+    target_column: NotRequired[Optional[str]]
+
+
 ModelBlobOptions = Union[
     BaseModelBlobOptions,
     CatBoostModelBlobOptions,
@@ -94,6 +99,7 @@ ModelBlobOptions = Union[
     TorchScriptModelBlobOptions,
     TensorflowModelBlobOptions,
     SentenceTransformersModelBlobOptions,
+    ProphetModelBlobOptions,
 ]

snowflake/ml/model/type_hints.py
CHANGED

@@ -24,6 +24,7 @@ if TYPE_CHECKING:
     import mlflow
     import numpy as np
     import pandas as pd
+    import prophet
     import sentence_transformers
     import sklearn.base
     import sklearn.pipeline
@@ -81,6 +82,7 @@ SupportedRequireSignatureModelType = Union[
     "catboost.CatBoost",
     "lightgbm.LGBMModel",
     "lightgbm.Booster",
+    "prophet.Prophet",
     "snowflake.ml.model.custom_model.CustomModel",
     "sklearn.base.BaseEstimator",
     "sklearn.pipeline.Pipeline",
@@ -113,6 +115,7 @@ Here is all acceptable types of Snowflake native model packaging and its handler
 | snowflake.ml.model.custom_model.CustomModel | custom.py | _CustomModelHandler |
 | sklearn.base.BaseEstimator | sklearn.py | _SKLModelHandler |
 | sklearn.pipeline.Pipeline | sklearn.py | _SKLModelHandler |
+| prophet.Prophet | prophet.py | ProphetHandler |
 | xgboost.XGBModel | xgboost.py | _XGBModelHandler |
 | xgboost.Booster | xgboost.py | _XGBModelHandler |
 | lightgbm.LGBMModel | lightgbm.py | _LGBMModelHandler |
@@ -134,6 +137,7 @@ SupportedModelHandlerType = Literal[
     "huggingface_pipeline",
     "lightgbm",
     "mlflow",
+    "prophet",
     "pytorch",
     "sentence_transformers",
     "sklearn",
@@ -248,11 +252,18 @@ class KerasSaveOptions(BaseModelSaveOption):
     cuda_version: NotRequired[str]


+class ProphetSaveOptions(BaseModelSaveOption):
+    target_methods: NotRequired[Sequence[str]]
+    date_column: NotRequired[str]
+    target_column: NotRequired[str]
+
+
 ModelSaveOption = Union[
     BaseModelSaveOption,
     CatBoostModelSaveOptions,
     CustomModelSaveOption,
     LGBMModelSaveOptions,
+    ProphetSaveOptions,
     SKLModelSaveOptions,
     XGBModelSaveOptions,
     SNOWModelSaveOptions,
@@ -327,11 +338,16 @@ class KerasLoadOptions(BaseModelLoadOption):
     use_gpu: NotRequired[bool]


+class ProphetLoadOptions(BaseModelLoadOption):
+    ...
+
+
 ModelLoadOption = Union[
     BaseModelLoadOption,
     CatBoostModelLoadOptions,
     CustomModelLoadOption,
     LGBMModelLoadOptions,
+    ProphetLoadOptions,
     SKLModelLoadOptions,
     XGBModelLoadOptions,
     SNOWModelLoadOptions,

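Since these option classes are TypedDicts, callers pass plain dict literals with the matching keys. A small illustrative sketch (hypothetical column names) of options conforming to the new ProphetSaveOptions:

from snowflake.ml.model import type_hints

# Illustrative only: keys mirror ProphetSaveOptions as defined above.
opts: type_hints.ProphetSaveOptions = {
    "target_methods": ["predict"],
    "date_column": "sale_date",   # hypothetical user column mapped to 'ds'
    "target_column": "revenue",   # hypothetical user column mapped to 'y'
}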
snowflake/ml/modeling/metrics/metrics_utils.py
CHANGED

@@ -4,6 +4,7 @@ from typing import Any, Collection, Iterable, Optional, Union

 import cloudpickle
 import numpy as np
+from packaging import version

 import snowflake.snowpark._internal.utils as snowpark_utils
 from snowflake import snowpark
@@ -59,7 +60,10 @@ def register_accumulator_udtf(*, session: Session, statement_params: dict[str, A
             ]
         ),
         input_types=[T.BinaryType()],
-        packages=[
+        packages=[
+            f"numpy=={version.parse(np.__version__).major}.*",
+            f"cloudpickle=={version.parse(cloudpickle.__version__).major}.*",
+        ],
         imports=[],  # Prevents unnecessary import resolution.
         name=accumulator,
         is_permanent=False,
@@ -175,7 +179,10 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: dic
             ]
         ),
         input_types=[T.ArrayType(), T.IntegerType(), T.IntegerType()],
-        packages=[
+        packages=[
+            f"numpy=={version.parse(np.__version__).major}.*",
+            f"cloudpickle=={version.parse(cloudpickle.__version__).major}.*",
+        ],
         imports=[],  # Prevents unnecessary import resolution.
         name=sharded_dot_and_sum_computer,
         is_permanent=False,

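The pin built here keeps only the installed major version, so the UDTF environment tracks the client's packages without locking to an exact release. A standalone illustration of what the expressions above evaluate to (values depend on the locally installed packages):

import cloudpickle
import numpy as np
from packaging import version

# e.g. with numpy 1.26.4 installed locally this prints "numpy==1.*"
print(f"numpy=={version.parse(np.__version__).major}.*")
print(f"cloudpickle=={version.parse(cloudpickle.__version__).major}.*")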
snowflake/ml/version.py
CHANGED

@@ -1,2 +1,2 @@
 # This is parsed by regex in conda recipe meta file. Make sure not to break it.
-VERSION = "1.17.0"
+VERSION = "1.18.0"