snowflake-ml-python 1.17.0__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (33)
  1. snowflake/ml/_internal/telemetry.py +3 -2
  2. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
  3. snowflake/ml/experiment/callback/keras.py +3 -0
  4. snowflake/ml/experiment/callback/lightgbm.py +3 -0
  5. snowflake/ml/experiment/callback/xgboost.py +3 -0
  6. snowflake/ml/experiment/experiment_tracking.py +19 -7
  7. snowflake/ml/feature_store/feature_store.py +236 -61
  8. snowflake/ml/jobs/_utils/constants.py +12 -1
  9. snowflake/ml/jobs/_utils/payload_utils.py +7 -1
  10. snowflake/ml/jobs/_utils/stage_utils.py +4 -0
  11. snowflake/ml/jobs/_utils/types.py +5 -0
  12. snowflake/ml/jobs/job.py +16 -2
  13. snowflake/ml/jobs/manager.py +12 -1
  14. snowflake/ml/model/__init__.py +19 -0
  15. snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
  16. snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
  17. snowflake/ml/model/_client/model/model_version_impl.py +129 -11
  18. snowflake/ml/model/_client/ops/service_ops.py +3 -0
  19. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  20. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
  22. snowflake/ml/model/_model_composer/model_method/model_method.py +4 -1
  23. snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
  24. snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
  25. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -0
  26. snowflake/ml/model/type_hints.py +16 -0
  27. snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
  28. snowflake/ml/version.py +1 -1
  29. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +25 -1
  30. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +33 -32
  31. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
  32. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
  33. {snowflake_ml_python-1.17.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/prophet.py ADDED
@@ -0,0 +1,566 @@
+ import logging
+ import os
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
+
+ import cloudpickle
+ import pandas as pd
+ from typing_extensions import TypeGuard, Unpack
+
+ from snowflake.ml._internal import type_utils
+ from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
+ from snowflake.ml.model._packager.model_env import model_env
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+ from snowflake.ml.model._packager.model_meta import (
+     model_blob_meta,
+     model_meta as model_meta_api,
+ )
+ from snowflake.ml.model._signatures import utils as model_signature_utils
+
+ if TYPE_CHECKING:
+     import prophet
+
+ logger = logging.getLogger(__name__)
+
+
+ def _normalize_column_names(
+     data: pd.DataFrame,
+     date_column: Optional[str] = None,
+     target_column: Optional[str] = None,
+ ) -> pd.DataFrame:
+     """Normalize user column names to Prophet's required 'ds' and 'y' format.
+
+     Args:
+         data: Input DataFrame with user's column names
+         date_column: Name of the date column to map to 'ds'
+         target_column: Name of the target column to map to 'y'
+
+     Returns:
+         DataFrame with columns renamed to Prophet format
+
+     Raises:
+         ValueError: If specified columns don't exist or if there are naming conflicts
+     """
+     if date_column is None and target_column is None:
+         return data
+
+     data = data.copy()
+
+     if date_column is not None:
+         if date_column not in data.columns:
+             raise ValueError(
+                 f"Specified date_column '{date_column}' not found in DataFrame. "
+                 f"Available columns: {list(data.columns)}"
+             )
+         if date_column != "ds":
+             # Check if 'ds' already exists as a different column
+             if "ds" in data.columns:
+                 raise ValueError(
+                     f"Cannot rename '{date_column}' to 'ds' because 'ds' already exists in the DataFrame. "
+                     f"Please either: (1) rename or remove the existing 'ds' column, or "
+                     f"(2) if 'ds' is your date column, set date_column='ds' instead."
+                 )
+             data = data.rename(columns={date_column: "ds"})
+
+     if target_column is not None:
+         if target_column not in data.columns:
+             raise ValueError(
+                 f"Specified target_column '{target_column}' not found in DataFrame. "
+                 f"Available columns: {list(data.columns)}"
+             )
+         if target_column != "y":
+             # Check if 'y' already exists as a different column
+             if "y" in data.columns:
+                 raise ValueError(
+                     f"Cannot rename '{target_column}' to 'y' because 'y' already exists in the DataFrame. "
+                     f"Please either: (1) rename or remove the existing 'y' column, or "
+                     f"(2) if 'y' is your target column, set target_column='y' instead."
+                 )
+             data = data.rename(columns={target_column: "y"})
+
+     return data
+
+
+ def _sanitize_prophet_output(predictions: pd.DataFrame) -> pd.DataFrame:
+     """Sanitize Prophet prediction output to have SQL-safe column names.
+
+     Prophet may include holiday columns with names containing spaces (e.g., "Christmas Day")
+     which cannot be used as unquoted SQL identifiers in Snowflake. This function normalizes all
+     column names to be valid unquoted SQL identifiers by replacing spaces with underscores and
+     removing special characters.
+
+     Args:
+         predictions: Raw prediction DataFrame from Prophet
+
+     Returns:
+         DataFrame with normalized SQL-safe column names
+
+     Raises:
+         ValueError: If predictions DataFrame is empty, has no columns, or is missing required
+             columns 'ds' and 'yhat'
+     """
+     # Check if predictions is empty or has no columns
+     if predictions is None or len(predictions.columns) == 0:
+         raise ValueError(
+             f"Prophet predictions DataFrame is empty or has no columns. "
+             f"DataFrame shape: {predictions.shape if predictions is not None else 'None'}, "
+             f"Type: {type(predictions)}"
+         )
+
+     if "ds" not in predictions.columns or "yhat" not in predictions.columns:
+         raise ValueError(
+             f"Prophet predictions missing required columns 'ds' and 'yhat'. "
+             f"Available columns: {list(predictions.columns)}"
+         )
+
+     # Normalize all column names to be SQL-safe
+     normalized_columns = {col: handlers_utils.normalize_column_name(col) for col in predictions.columns}
+
+     # Check for conflicts after normalization
+     normalized_values = list(normalized_columns.values())
+     if len(normalized_values) != len(set(normalized_values)):
+         # Find duplicates
+         seen = set()
+         duplicates = []
+         for val in normalized_values:
+             if val in seen:
+                 duplicates.append(val)
+             seen.add(val)
+
+         logger.warning(
+             f"Column name normalization resulted in duplicates: {duplicates}. "
+             f"Original columns: {[k for k, v in normalized_columns.items() if v in duplicates]}"
+         )
+
+     # Rename columns
+     sanitized_predictions = predictions.rename(columns=normalized_columns)
+
+     return sanitized_predictions
+
+
+ def _validate_prophet_data_format(data: model_types.SupportedDataType) -> pd.DataFrame:
+     """Validate that input data follows Prophet's required format with 'ds' and 'y' columns.
+
+     Args:
+         data: Input data to validate
+
+     Returns:
+         DataFrame with validated Prophet format and proper data types
+
+     Raises:
+         ValueError: If data doesn't meet Prophet requirements
+     """
+     if not isinstance(data, pd.DataFrame):
+         raise ValueError("Prophet models require pandas DataFrame input with 'ds' and 'y' columns")
+
+     if "ds" not in data.columns:
+         raise ValueError(
+             "Prophet models require a 'ds' column containing dates. "
+             "If your date column has a different name, use the 'date_column' parameter "
+             "when saving the model to map it to 'ds'."
+         )
+
+     if "y" not in data.columns:
+         # Allow 'y' column with NaN values for future predictions
+         raise ValueError(
+             "Prophet models require a 'y' column containing values (can be NaN for future periods). "
+             "If your target column has a different name, use the 'target_column' parameter "
+             "when saving the model to map it to 'y'."
+         )
+
+     validated_data = data.copy()
+
+     # Convert datetime column - this handles string timestamps from Snowflake
+     try:
+         validated_data["ds"] = pd.to_datetime(validated_data["ds"])
+     except Exception as e:
+         raise ValueError(f"'ds' column must contain valid datetime values: {e}")
+
+     # Convert numeric columns to proper float types
+     for col in validated_data.columns:
+         if col != "ds":
+             try:
+                 # Convert to numeric, coercing errors to NaN
+                 original_col = validated_data[col]
+                 validated_data[col] = pd.to_numeric(validated_data[col], errors="coerce")
+
+                 # Force explicit dtype conversion to ensure numpy operations work
+                 validated_data[col] = validated_data[col].astype("float64")
+
+                 logger.debug(f"Converted column '{col}' from {original_col.dtype} to {validated_data[col].dtype}")
+
+             except Exception as e:
+                 # If conversion fails completely, provide detailed error
+                 raise ValueError(
+                     f"Column '{col}' contains data that cannot be converted to numeric: {e}. "
+                     f"Original dtype: {validated_data[col].dtype}, sample values: {validated_data[col].head().tolist()}"
+                 )
+
+     return validated_data
+
+
+ @final
+ class ProphetHandler(_base.BaseModelHandler["prophet.Prophet"]):
+     """Handler for prophet time series forecasting models."""
+
+     HANDLER_TYPE = "prophet"
+     HANDLER_VERSION = "2025-01-01"
+     _MIN_SNOWPARK_ML_VERSION = "1.8.0"
+     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
+
+     MODEL_BLOB_FILE_OR_DIR = "model.pkl"
+     DEFAULT_TARGET_METHODS = ["predict"]
+
+     # Prophet models require sample data to infer signatures because the data may contain regressors.
+     IS_AUTO_SIGNATURE = False
+
+     @classmethod
+     def can_handle(
+         cls,
+         model: model_types.SupportedModelType,
+     ) -> TypeGuard["prophet.Prophet"]:
+         """Check if this handler can process the given model.
+
+         Args:
+             model: The model object to check
+
+         Returns:
+             True if this is a Prophet model, False otherwise
+         """
+         return type_utils.LazyType("prophet.Prophet").isinstance(model) and any(
+             (hasattr(model, method) and callable(getattr(model, method, None))) for method in cls.DEFAULT_TARGET_METHODS
+         )
+
+     @classmethod
+     def cast_model(
+         cls,
+         model: model_types.SupportedModelType,
+     ) -> "prophet.Prophet":
+         """Cast the model to Prophet type.
+
+         Args:
+             model: The model object
+
+         Returns:
+             The model cast as Prophet
+         """
+         import prophet
+
+         assert isinstance(model, prophet.Prophet)
+         return cast("prophet.Prophet", model)
+
+     @classmethod
+     def save_model(
+         cls,
+         name: str,
+         model: "prophet.Prophet",
+         model_meta: model_meta_api.ModelMetadata,
+         model_blobs_dir_path: str,
+         sample_input_data: Optional[model_types.SupportedDataType] = None,
+         is_sub_model: Optional[bool] = False,
+         **kwargs: Unpack[model_types.ProphetSaveOptions],
+     ) -> None:
+         """Save Prophet model and metadata.
+
+         Args:
+             name: Name of the model
+             model: The Prophet model object
+             model_meta: The model metadata
+             model_blobs_dir_path: Directory to save model files
+             sample_input_data: Sample input data for signature inference
+             is_sub_model: Whether this is a sub-model
+             **kwargs: Additional save options including date_column and target_column for column mapping
+
+         Raises:
+             ValueError: If sample_input_data is not a pandas DataFrame or if column mapping fails
+         """
+         import prophet
+
+         assert isinstance(model, prophet.Prophet)
+
+         date_column = kwargs.pop("date_column", None)
+         target_column = kwargs.pop("target_column", None)
+
+         if not is_sub_model:
+             # Validate sample input data if provided
+             if sample_input_data is not None:
+                 if isinstance(sample_input_data, pd.DataFrame):
+                     # Normalize for validation purposes
+                     normalized_sample = _normalize_column_names(
+                         sample_input_data.copy(),
+                         date_column=date_column,
+                         target_column=target_column,
+                     )
+                     _validate_prophet_data_format(normalized_sample)
+                 else:
+                     raise ValueError("Prophet models require pandas DataFrame sample input data")
+
+             target_methods = handlers_utils.get_target_methods(
+                 model=model,
+                 target_methods=kwargs.pop("target_methods", None),
+                 default_target_methods=cls.DEFAULT_TARGET_METHODS,
+             )
+
+             def get_prediction(
+                 target_method_name: str,
+                 sample_input_data: model_types.SupportedLocalDataType,
+             ) -> model_types.SupportedLocalDataType:
+                 """Generate predictions for signature inference."""
+                 if not isinstance(sample_input_data, pd.DataFrame):
+                     raise ValueError("Prophet requires pandas DataFrame input")
+
+                 normalized_data = _normalize_column_names(
+                     sample_input_data,
+                     date_column=date_column,
+                     target_column=target_column,
+                 )
+                 validated_data = _validate_prophet_data_format(normalized_data)
+
+                 target_method = getattr(model, target_method_name, None)
+                 if not callable(target_method):
+                     raise ValueError(f"Method {target_method_name} not found on Prophet model")
+
+                 if target_method_name == "predict":
+                     # Use the input data as the future dataframe for prediction
+                     try:
+                         predictions = target_method(validated_data)
+                         predictions = _sanitize_prophet_output(predictions)
+                         return predictions
+                     except Exception as e:
+                         if "numpy._core.numeric" in str(e):
+                             raise RuntimeError(
+                                 f"Prophet model logging failed due to NumPy compatibility issue: {e}. "
+                                 f"Try using compatible NumPy versions (e.g., 1.24.x or 1.26.x) in pip_requirements "
+                                 f"with relax_version=False."
+                             ) from e
+                         else:
+                             raise
+                 elif target_method_name == "predict_components":
+                     predictions = target_method(validated_data)
+                     predictions = _sanitize_prophet_output(predictions)
+                 else:
+                     raise ValueError(f"Unsupported target method: {target_method_name}")
+
+                 return predictions
+
+             model_meta = handlers_utils.validate_signature(
+                 model=model,
+                 model_meta=model_meta,
+                 target_methods=target_methods,
+                 sample_input_data=sample_input_data,
+                 get_prediction_fn=get_prediction,
+             )
+
+         model_meta.task = model_types.Task.UNKNOWN  # Prophet is forecasting, which isn't in standard tasks
+
+         # Save the Prophet model using cloudpickle
+         model_blob_path = os.path.join(model_blobs_dir_path, name)
+         os.makedirs(model_blob_path, exist_ok=True)
+
+         with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
+             cloudpickle.dump(model, f)
+
+         # Create model blob metadata with column mapping options
+         from snowflake.ml.model._packager.model_meta import model_meta_schema
+
+         options: model_meta_schema.ProphetModelBlobOptions = {}
+         if date_column is not None:
+             options["date_column"] = date_column
+         if target_column is not None:
+             options["target_column"] = target_column
+
+         base_meta = model_blob_meta.ModelBlobMeta(
+             name=name,
+             model_type=cls.HANDLER_TYPE,
+             handler_version=cls.HANDLER_VERSION,
+             path=cls.MODEL_BLOB_FILE_OR_DIR,
+             options=options if options else {},
+         )
+         model_meta.models[name] = base_meta
+         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
+
+         # Add Prophet dependencies
+         model_meta.env.include_if_absent(
+             [
+                 model_env.ModelDependency(requirement="prophet", pip_name="prophet"),
+                 model_env.ModelDependency(requirement="pandas", pip_name="pandas"),
+                 model_env.ModelDependency(requirement="numpy", pip_name="numpy"),
+             ],
+             check_local_version=True,
+         )
+
+     @classmethod
+     def load_model(
+         cls,
+         name: str,
+         model_meta: model_meta_api.ModelMetadata,
+         model_blobs_dir_path: str,
+         **kwargs: Unpack[model_types.ProphetLoadOptions],
+     ) -> "prophet.Prophet":
+         """Load Prophet model from storage.
+
+         Args:
+             name: Name of the model
+             model_meta: The model metadata
+             model_blobs_dir_path: Directory containing model files
+             **kwargs: Additional load options
+
+         Returns:
+             The loaded Prophet model
+         """
+         import prophet
+
+         model_blob_path = os.path.join(model_blobs_dir_path, name)
+         model_blobs_metadata = model_meta.models
+         model_blob_metadata = model_blobs_metadata[name]
+         model_blob_filename = model_blob_metadata.path
+
+         with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
+             model = cloudpickle.load(f)
+
+         assert isinstance(model, prophet.Prophet)
+         return model
+
+     @classmethod
+     def convert_as_custom_model(
+         cls,
+         raw_model: "prophet.Prophet",
+         model_meta: model_meta_api.ModelMetadata,
+         background_data: Optional[pd.DataFrame] = None,
+         **kwargs: Unpack[model_types.ProphetLoadOptions],
+     ) -> custom_model.CustomModel:
+         """Convert Prophet model to CustomModel for unified inference interface.
+
+         Args:
+             raw_model: The original Prophet model
+             model_meta: The model metadata
+             background_data: Background data for explanations (not used for Prophet)
+             **kwargs: Additional options
+
+         Returns:
+             CustomModel wrapper for the Prophet model
+         """
+         from snowflake.ml.model import custom_model
+
+         model_blob_meta = next(iter(model_meta.models.values()))
+         date_column: Optional[str] = cast(Optional[str], model_blob_meta.options.get("date_column", None))
+         target_column: Optional[str] = cast(Optional[str], model_blob_meta.options.get("target_column", None))
+
+         def _create_custom_model(
+             raw_model: "prophet.Prophet",
+             model_meta: model_meta_api.ModelMetadata,
+             date_column: Optional[str],
+             target_column: Optional[str],
+         ) -> type[custom_model.CustomModel]:
+             """Create custom model class for Prophet."""
+
+             def fn_factory(
+                 raw_model: "prophet.Prophet",
+                 signature: model_signature.ModelSignature,
+                 target_method: str,
+             ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
+                 """Factory function to create method implementations."""
+
+                 @custom_model.inference_api
+                 def predict_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                     """Predict method for Prophet forecasting.
+
+                     For forecasting, users should provide a DataFrame with:
+                     - 'ds' column (or custom date column name): dates for which to generate forecasts
+                     - 'y' column (or custom target column name): can be NaN for future periods to forecast
+                     - Additional regressor columns if the model was trained with them
+
+                     Args:
+                         self: The custom model instance
+                         X: Input DataFrame with dates and optional target values for forecasting
+
+                     Returns:
+                         DataFrame containing Prophet predictions with columns like ds, yhat, yhat_lower, etc.
+
+                     Raises:
+                         ValueError: If column normalization fails or method is unsupported
+                         RuntimeError: If NumPy compatibility issues are detected during validation or prediction
+                     """
+                     try:
+                         normalized_data = _normalize_column_names(
+                             X, date_column=date_column, target_column=target_column
+                         )
+                     except Exception as e:
+                         raise ValueError(f"Failed to normalize column names: {e}") from e
+
+                     # Validate input format with runtime error handling
+                     try:
+                         validated_data = _validate_prophet_data_format(normalized_data)
+                     except Exception as e:
+                         if "numpy._core.numeric" in str(e):
+                             raise RuntimeError(
+                                 f"Prophet input validation failed in Snowflake runtime due to "
+                                 f"NumPy compatibility: {e}. Redeploy model with compatible dependency versions "
+                                 f"(e.g., NumPy 1.24.x or 1.26.x, Prophet 1.1.x) in pip_requirements "
+                                 f"with relax_version=False."
+                             ) from e
+                         else:
+                             raise
+
+                     # Generate predictions using Prophet with runtime error handling
+                     if target_method == "predict":
+                         try:
+                             predictions = raw_model.predict(validated_data)
+                             # Sanitize output to remove columns with problematic names
+                             predictions = _sanitize_prophet_output(predictions)
+                         except Exception as e:
+                             if "numpy._core.numeric" in str(e) or "np.float_" in str(e):
+                                 raise RuntimeError(
+                                     f"Prophet prediction failed in Snowflake runtime due to NumPy compatibility: {e}. "
+                                     f"This indicates Prophet's internal NumPy operations are incompatible. "
+                                     f"Redeploy with compatible dependency versions in pip_requirements."
+                                 ) from e
+                             else:
+                                 raise
+                     else:
+                         raise ValueError(f"Unsupported method: {target_method}")
+
+                     # Prophet returns many columns, but we only want the ones in our signature
+                     # Filter to only the columns we expect based on our signature
+                     expected_columns = [spec.name for spec in signature.outputs]
+                     available_columns = [col for col in expected_columns if col in predictions.columns]
+
+                     # Fill missing columns with zeros to match the expected signature
+                     filtered_predictions = predictions[available_columns].copy()
+
+                     # Add missing columns with zeros if they're not present
+                     for col_name in expected_columns:
+                         if col_name not in filtered_predictions.columns:
+                             # Add missing seasonal component columns with zeros
+                             if col_name in ["weekly", "yearly", "daily"]:
+                                 filtered_predictions[col_name] = 0.0
+                             else:
+                                 # For required columns like ds, yhat, etc., this would be an error
+                                 raise ValueError(f"Required column '{col_name}' missing from Prophet output")
+
+                     # Reorder columns to match signature order
+                     filtered_predictions = filtered_predictions[expected_columns]
+
+                     # Ensure the output matches the expected signature
+                     return model_signature_utils.rename_pandas_df(filtered_predictions, signature.outputs)
+
+                 return predict_fn
+
+             # Create method dictionary for the custom model class
+             type_method_dict = {}
+             for target_method_name, sig in model_meta.signatures.items():
+                 type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
+
+             # Create the custom model class
+             _ProphetModel = type(
+                 "_ProphetModel",
+                 (custom_model.CustomModel,),
+                 type_method_dict,
+             )
+
+             return _ProphetModel
+
+         _ProphetModel = _create_custom_model(raw_model, model_meta, date_column, target_column)
+         prophet_model = _ProphetModel(custom_model.ModelContext())
+
+         return prophet_model
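Together with the type hint and schema additions below, this new handler lets Prophet models be logged to the Model Registry like any other supported model type. A minimal usage sketch, assuming the existing Registry.log_model API and an existing Snowpark session; the DataFrame, column names, and model/version names here are hypothetical:

import pandas as pd
from prophet import Prophet
from snowflake.ml.registry import Registry

# Hypothetical data: 'order_date'/'revenue' stand in for Prophet's 'ds'/'y'.
df = pd.DataFrame({
    "order_date": pd.date_range("2024-01-01", periods=90, freq="D"),
    "revenue": [float(i) for i in range(90)],
})

m = Prophet()
m.fit(df.rename(columns={"order_date": "ds", "revenue": "y"}))  # Prophet itself still requires ds/y

reg = Registry(session=session)  # assumes an existing Snowpark Session
mv = reg.log_model(
    m,
    model_name="SALES_FORECAST",
    version_name="V1",
    sample_input_data=df,  # needed for signature inference: IS_AUTO_SIGNATURE is False for this handler
    options={"date_column": "order_date", "target_column": "revenue"},
)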
snowflake/ml/model/_packager/model_meta/model_meta_schema.py CHANGED
@@ -84,6 +84,11 @@ class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
      batch_size: Required[int]


+ class ProphetModelBlobOptions(BaseModelBlobOptions):
+     date_column: NotRequired[Optional[str]]
+     target_column: NotRequired[Optional[str]]
+
+
  ModelBlobOptions = Union[
      BaseModelBlobOptions,
      CatBoostModelBlobOptions,
@@ -94,6 +99,7 @@ ModelBlobOptions = Union[
      TorchScriptModelBlobOptions,
      TensorflowModelBlobOptions,
      SentenceTransformersModelBlobOptions,
+     ProphetModelBlobOptions,
  ]


snowflake/ml/model/type_hints.py CHANGED
@@ -24,6 +24,7 @@ if TYPE_CHECKING:
      import mlflow
      import numpy as np
      import pandas as pd
+     import prophet
      import sentence_transformers
      import sklearn.base
      import sklearn.pipeline
@@ -81,6 +82,7 @@ SupportedRequireSignatureModelType = Union[
      "catboost.CatBoost",
      "lightgbm.LGBMModel",
      "lightgbm.Booster",
+     "prophet.Prophet",
      "snowflake.ml.model.custom_model.CustomModel",
      "sklearn.base.BaseEstimator",
      "sklearn.pipeline.Pipeline",
@@ -113,6 +115,7 @@ Here is all acceptable types of Snowflake native model packaging and its handler
  | snowflake.ml.model.custom_model.CustomModel | custom.py | _CustomModelHandler |
  | sklearn.base.BaseEstimator | sklearn.py | _SKLModelHandler |
  | sklearn.pipeline.Pipeline | sklearn.py | _SKLModelHandler |
+ | prophet.Prophet | prophet.py | ProphetHandler |
  | xgboost.XGBModel | xgboost.py | _XGBModelHandler |
  | xgboost.Booster | xgboost.py | _XGBModelHandler |
  | lightgbm.LGBMModel | lightgbm.py | _LGBMModelHandler |
@@ -134,6 +137,7 @@ SupportedModelHandlerType = Literal[
      "huggingface_pipeline",
      "lightgbm",
      "mlflow",
+     "prophet",
      "pytorch",
      "sentence_transformers",
      "sklearn",
@@ -248,11 +252,18 @@ class KerasSaveOptions(BaseModelSaveOption):
      cuda_version: NotRequired[str]


+ class ProphetSaveOptions(BaseModelSaveOption):
+     target_methods: NotRequired[Sequence[str]]
+     date_column: NotRequired[str]
+     target_column: NotRequired[str]
+
+
  ModelSaveOption = Union[
      BaseModelSaveOption,
      CatBoostModelSaveOptions,
      CustomModelSaveOption,
      LGBMModelSaveOptions,
+     ProphetSaveOptions,
      SKLModelSaveOptions,
      XGBModelSaveOptions,
      SNOWModelSaveOptions,
@@ -327,11 +338,16 @@ class KerasLoadOptions(BaseModelLoadOption):
      use_gpu: NotRequired[bool]


+ class ProphetLoadOptions(BaseModelLoadOption):
+     ...
+
+
  ModelLoadOption = Union[
      BaseModelLoadOption,
      CatBoostModelLoadOptions,
      CustomModelLoadOption,
      LGBMModelLoadOptions,
+     ProphetLoadOptions,
      SKLModelLoadOptions,
      XGBModelLoadOptions,
      SNOWModelLoadOptions,
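The new TypedDicts make the Prophet column-mapping options type-checkable at the call site. A small sketch of how they might be populated (the keys mirror the diff above; the values are illustrative):

from snowflake.ml.model import type_hints as model_types

# Illustrative values; date_column/target_column are renamed to 'ds'/'y' at save time.
prophet_save_options: model_types.ProphetSaveOptions = {
    "target_methods": ["predict"],
    "date_column": "order_date",
    "target_column": "revenue",
}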
snowflake/ml/modeling/metrics/metrics_utils.py CHANGED
@@ -4,6 +4,7 @@ from typing import Any, Collection, Iterable, Optional, Union

  import cloudpickle
  import numpy as np
+ from packaging import version

  import snowflake.snowpark._internal.utils as snowpark_utils
  from snowflake import snowpark
@@ -59,7 +60,10 @@ def register_accumulator_udtf(*, session: Session, statement_params: dict[str, A
              ]
          ),
          input_types=[T.BinaryType()],
-         packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
+         packages=[
+             f"numpy=={version.parse(np.__version__).major}.*",
+             f"cloudpickle=={version.parse(cloudpickle.__version__).major}.*",
+         ],
          imports=[],  # Prevents unnecessary import resolution.
          name=accumulator,
          is_permanent=False,
@@ -175,7 +179,10 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: dic
              ]
          ),
          input_types=[T.ArrayType(), T.IntegerType(), T.IntegerType()],
-         packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
+         packages=[
+             f"numpy=={version.parse(np.__version__).major}.*",
+             f"cloudpickle=={version.parse(cloudpickle.__version__).major}.*",
+         ],
          imports=[],  # Prevents unnecessary import resolution.
          name=sharded_dot_and_sum_computer,
          is_permanent=False,
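The effect of this change is to pin the UDTFs' server-side packages to the local major version rather than the exact local version, tolerating minor/patch drift between the client and the warehouse. A quick illustration of the new format strings (the version values are examples):

from packaging import version

np_ver, cp_ver = "1.26.4", "2.2.1"  # example local versions
print(f"numpy=={version.parse(np_ver).major}.*")        # -> numpy==1.*
print(f"cloudpickle=={version.parse(cp_ver).major}.*")  # -> cloudpickle==2.*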
snowflake/ml/version.py CHANGED
@@ -1,2 +1,2 @@
  # This is parsed by regex in conda recipe meta file. Make sure not to break it.
- VERSION = "1.17.0"
+ VERSION = "1.18.0"