snowflake-ml-python 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (205)
  1. snowflake/cortex/_complete.py +3 -2
  2. snowflake/ml/_internal/utils/service_logger.py +26 -1
  3. snowflake/ml/experiment/_client/artifact.py +76 -0
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
  5. snowflake/ml/experiment/callback/keras.py +63 -0
  6. snowflake/ml/experiment/callback/lightgbm.py +5 -1
  7. snowflake/ml/experiment/callback/xgboost.py +5 -1
  8. snowflake/ml/experiment/experiment_tracking.py +89 -4
  9. snowflake/ml/feature_store/feature_store.py +1150 -131
  10. snowflake/ml/feature_store/feature_view.py +122 -0
  11. snowflake/ml/jobs/_utils/__init__.py +0 -0
  12. snowflake/ml/jobs/_utils/constants.py +9 -14
  13. snowflake/ml/jobs/_utils/feature_flags.py +16 -0
  14. snowflake/ml/jobs/_utils/payload_utils.py +61 -19
  15. snowflake/ml/jobs/_utils/query_helper.py +5 -1
  16. snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
  17. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
  18. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +15 -7
  19. snowflake/ml/jobs/_utils/spec_utils.py +44 -13
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -9
  21. snowflake/ml/jobs/_utils/types.py +7 -8
  22. snowflake/ml/jobs/job.py +34 -18
  23. snowflake/ml/jobs/manager.py +107 -24
  24. snowflake/ml/model/__init__.py +6 -1
  25. snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
  26. snowflake/ml/model/_client/model/model_version_impl.py +225 -73
  27. snowflake/ml/model/_client/ops/service_ops.py +128 -174
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +123 -64
  29. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -9
  30. snowflake/ml/model/_model_composer/model_composer.py +1 -70
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
  32. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
  33. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
  34. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  35. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  36. snowflake/ml/model/_signatures/utils.py +4 -2
  37. snowflake/ml/model/inference_engine.py +5 -0
  38. snowflake/ml/model/models/huggingface_pipeline.py +4 -3
  39. snowflake/ml/model/openai_signatures.py +57 -0
  40. snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
  41. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
  42. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  43. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  44. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  45. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  46. snowflake/ml/modeling/cluster/birch.py +1 -1
  47. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  48. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  49. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  50. snowflake/ml/modeling/cluster/k_means.py +1 -1
  51. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  52. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  53. snowflake/ml/modeling/cluster/optics.py +1 -1
  54. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  55. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  56. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  57. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  60. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  61. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  62. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  63. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  64. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  65. snowflake/ml/modeling/covariance/oas.py +1 -1
  66. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  67. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  68. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  69. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  70. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  71. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  72. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  73. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  74. snowflake/ml/modeling/decomposition/pca.py +1 -1
  75. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  76. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  77. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  78. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  79. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  82. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  83. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  84. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  85. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  86. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  88. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  89. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  90. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  91. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  92. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  93. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  94. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  95. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  96. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  97. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  98. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  99. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  100. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  101. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  102. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  105. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  106. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  107. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  116. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  118. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  122. snowflake/ml/modeling/linear_model/lars.py +1 -1
  123. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  124. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  129. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  142. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  151. snowflake/ml/modeling/manifold/isomap.py +1 -1
  152. snowflake/ml/modeling/manifold/mds.py +1 -1
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  154. snowflake/ml/modeling/manifold/tsne.py +1 -1
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  157. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  158. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  159. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  160. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  161. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  162. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  163. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  164. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  165. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  166. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  167. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  168. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  169. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  170. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  171. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  172. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  173. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  174. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  175. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  176. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  177. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  178. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  179. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  180. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  181. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  182. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  183. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  184. snowflake/ml/modeling/svm/svc.py +1 -1
  185. snowflake/ml/modeling/svm/svr.py +1 -1
  186. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  187. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  188. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  189. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  190. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  191. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  192. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  193. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  194. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
  195. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  196. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  197. snowflake/ml/monitoring/model_monitor.py +26 -0
  198. snowflake/ml/registry/_manager/model_manager.py +7 -35
  199. snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
  200. snowflake/ml/version.py +1 -1
  201. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +87 -7
  202. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +205 -197
  203. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
  204. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
snowflake/ml/monitoring/_client/model_monitor_sql_client.py CHANGED
@@ -1,3 +1,4 @@
+ from enum import Enum, auto
  from typing import Any, Mapping, Optional

  from snowflake import snowpark
@@ -15,6 +16,25 @@ MODEL_JSON_MODEL_NAME_FIELD = "model_name"
  MODEL_JSON_VERSION_NAME_FIELD = "version_name"


+ class MonitorOperation(Enum):
+     SUSPEND = auto()
+     RESUME = auto()
+     ADD = auto()
+     DROP = auto()
+
+     @property
+     def supported_target_properties(self) -> frozenset[str]:
+         return _OPERATION_SUPPORTED_PROPS[self]
+
+
+ _OPERATION_SUPPORTED_PROPS: dict[MonitorOperation, frozenset[str]] = {
+     MonitorOperation.SUSPEND: frozenset(),
+     MonitorOperation.RESUME: frozenset(),
+     MonitorOperation.ADD: frozenset({"SEGMENT_COLUMN"}),
+     MonitorOperation.DROP: frozenset({"SEGMENT_COLUMN"}),
+ }
+
+
  def _build_sql_list_from_columns(columns: list[sql_identifier.SqlIdentifier]) -> str:
      sql_list = ", ".join([f"'{column}'" for column in columns])
      return f"({sql_list})"
@@ -70,11 +90,17 @@ class ModelMonitorSQLClient:
          baseline_database: Optional[sql_identifier.SqlIdentifier] = None,
          baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
          baseline: Optional[sql_identifier.SqlIdentifier] = None,
+         segment_columns: Optional[list[sql_identifier.SqlIdentifier]] = None,
          statement_params: Optional[dict[str, Any]] = None,
      ) -> None:
          baseline_sql = ""
          if baseline:
              baseline_sql = f"""BASELINE={self._infer_qualified_schema(baseline_database, baseline_schema)}.{baseline}"""
+
+         segment_columns_sql = ""
+         if segment_columns:
+             segment_columns_sql = f"SEGMENT_COLUMNS={_build_sql_list_from_columns(segment_columns)}"
+
          query_result_checker.SqlResultValidator(
              self._sql_client._session,
              f"""
@@ -93,6 +119,7 @@ class ModelMonitorSQLClient:
              TIMESTAMP_COLUMN='{timestamp_column}'
              REFRESH_INTERVAL='{refresh_interval}'
              AGGREGATION_WINDOW='{aggregation_window}'
+             {segment_columns_sql}
              {baseline_sql}""",
              statement_params=statement_params,
          ).has_column("status").has_dimensions(1, 1).validate()
@@ -182,6 +209,7 @@ class ModelMonitorSQLClient:
          actual_score_columns: list[sql_identifier.SqlIdentifier],
          actual_class_columns: list[sql_identifier.SqlIdentifier],
          id_columns: list[sql_identifier.SqlIdentifier],
+         segment_columns: Optional[list[sql_identifier.SqlIdentifier]] = None,
      ) -> None:
          """Ensures all columns exist in the source table.

@@ -193,11 +221,14 @@ class ModelMonitorSQLClient:
              actual_score_columns: List of actual score column names.
              actual_class_columns: List of actual class column names.
              id_columns: List of id column names.
+             segment_columns: List of segment column names.

          Raises:
              ValueError: If any of the columns do not exist in the source.
          """

+         segment_columns = [] if segment_columns is None else segment_columns
+
          if timestamp_column not in source_column_schema:
              raise ValueError(f"Timestamp column {timestamp_column} does not exist in source.")

@@ -214,6 +245,9 @@ class ModelMonitorSQLClient:
          if not all([column_name in source_column_schema for column_name in id_columns]):
              raise ValueError(f"ID column(s): {id_columns} do not exist in source.")

+         if not all([column_name in source_column_schema for column_name in segment_columns]):
+             raise ValueError(f"Segment column(s): {segment_columns} do not exist in source.")
+
      def validate_source(
          self,
          *,
@@ -226,7 +260,9 @@ class ModelMonitorSQLClient:
          actual_score_columns: list[sql_identifier.SqlIdentifier],
          actual_class_columns: list[sql_identifier.SqlIdentifier],
          id_columns: list[sql_identifier.SqlIdentifier],
+         segment_columns: Optional[list[sql_identifier.SqlIdentifier]] = None,
      ) -> None:
+
          source_database = source_database or self._database_name
          source_schema = source_schema or self._schema_name
          # Get Schema of the source. Implicitly validates that the source exists.
@@ -244,19 +280,38 @@ class ModelMonitorSQLClient:
              actual_score_columns=actual_score_columns,
              actual_class_columns=actual_class_columns,
              id_columns=id_columns,
+             segment_columns=segment_columns,
          )

      def _alter_monitor(
          self,
-         operation: str,
+         operation: MonitorOperation,
          monitor_name: sql_identifier.SqlIdentifier,
+         target_property: Optional[str] = None,
+         target_value: Optional[sql_identifier.SqlIdentifier] = None,
          statement_params: Optional[dict[str, Any]] = None,
      ) -> None:
-         if operation not in {"SUSPEND", "RESUME"}:
-             raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables")
+         supported_target_properties = operation.supported_target_properties
+
+         if supported_target_properties:
+             if target_property is None or target_value is None:
+                 raise ValueError(f"Target property and value must be provided for {operation.name} operation")
+
+             if target_property not in supported_target_properties:
+                 raise ValueError(
+                     f"Only {', '.join(supported_target_properties)} supported as target property "
+                     f"for {operation.name} operation"
+                 )
+
+         property_clause = f"{target_property}={target_value}" if target_property and target_value else ""
+         alter_momo_sql = (
+             f"""ALTER MODEL MONITOR {self._database_name}.{self._schema_name}.{monitor_name} """
+             f"""{operation.name} {property_clause}"""
+         )
+
          query_result_checker.SqlResultValidator(
              self._sql_client._session,
-             f"""ALTER MODEL MONITOR {self._database_name}.{self._schema_name}.{monitor_name} {operation}""",
+             alter_momo_sql,
              statement_params=statement_params,
          ).has_column("status").has_dimensions(1, 1).validate()

@@ -266,7 +321,7 @@ class ModelMonitorSQLClient:
          statement_params: Optional[dict[str, Any]] = None,
      ) -> None:
          self._alter_monitor(
-             operation="SUSPEND",
+             operation=MonitorOperation.SUSPEND,
              monitor_name=monitor_name,
              statement_params=statement_params,
          )
@@ -277,7 +332,37 @@ class ModelMonitorSQLClient:
          statement_params: Optional[dict[str, Any]] = None,
      ) -> None:
          self._alter_monitor(
-             operation="RESUME",
+             operation=MonitorOperation.RESUME,
+             monitor_name=monitor_name,
+             statement_params=statement_params,
+         )
+
+     def add_segment_column(
+         self,
+         monitor_name: sql_identifier.SqlIdentifier,
+         segment_column: sql_identifier.SqlIdentifier,
+         statement_params: Optional[dict[str, Any]] = None,
+     ) -> None:
+         """Add a segment column to the Model Monitor"""
+         self._alter_monitor(
+             operation=MonitorOperation.ADD,
+             monitor_name=monitor_name,
+             target_property="SEGMENT_COLUMN",
+             target_value=segment_column,
+             statement_params=statement_params,
+         )
+
+     def drop_segment_column(
+         self,
+         monitor_name: sql_identifier.SqlIdentifier,
+         segment_column: sql_identifier.SqlIdentifier,
+         statement_params: Optional[dict[str, Any]] = None,
+     ) -> None:
+         """Drop a segment column from the Model Monitor"""
+         self._alter_monitor(
+             operation=MonitorOperation.DROP,
              monitor_name=monitor_name,
+             target_property="SEGMENT_COLUMN",
+             target_value=segment_column,
              statement_params=statement_params,
          )
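
The `_alter_monitor` rewrite above routes every ALTER through the new `MonitorOperation` enum and its `supported_target_properties` mapping. Below is a minimal standalone sketch of the statement shapes this produces; it is not package code, and `build_alter_monitor_sql`, the monitor `DB.SCHEMA.MY_MONITOR`, and the column `REGION` are illustrative placeholders.

```python
from enum import Enum, auto
from typing import Optional


class MonitorOperation(Enum):
    SUSPEND = auto()
    RESUME = auto()
    ADD = auto()
    DROP = auto()


def build_alter_monitor_sql(
    fq_monitor_name: str,
    operation: MonitorOperation,
    target_property: Optional[str] = None,
    target_value: Optional[str] = None,
) -> str:
    # Same shape as the statement assembled in _alter_monitor above.
    property_clause = f"{target_property}={target_value}" if target_property and target_value else ""
    return f"ALTER MODEL MONITOR {fq_monitor_name} {operation.name} {property_clause}".rstrip()


print(build_alter_monitor_sql("DB.SCHEMA.MY_MONITOR", MonitorOperation.SUSPEND))
# ALTER MODEL MONITOR DB.SCHEMA.MY_MONITOR SUSPEND
print(build_alter_monitor_sql("DB.SCHEMA.MY_MONITOR", MonitorOperation.ADD, "SEGMENT_COLUMN", "REGION"))
# ALTER MODEL MONITOR DB.SCHEMA.MY_MONITOR ADD SEGMENT_COLUMN=REGION
```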
snowflake/ml/monitoring/_manager/model_monitor_manager.py CHANGED
@@ -108,6 +108,7 @@ class ModelMonitorManager:
          prediction_class_columns = self._build_column_list_from_input(source_config.prediction_class_columns)
          actual_score_columns = self._build_column_list_from_input(source_config.actual_score_columns)
          actual_class_columns = self._build_column_list_from_input(source_config.actual_class_columns)
+         segment_columns = self._build_column_list_from_input(source_config.segment_columns)

          id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in source_config.id_columns]
          ts_column = sql_identifier.SqlIdentifier(source_config.timestamp_column)
@@ -123,6 +124,7 @@ class ModelMonitorManager:
              actual_score_columns=actual_score_columns,
              actual_class_columns=actual_class_columns,
              id_columns=id_columns,
+             segment_columns=segment_columns,
          )

          self._model_monitor_client.create_model_monitor(
@@ -144,6 +146,7 @@ class ModelMonitorManager:
              prediction_class_columns=prediction_class_columns,
              actual_score_columns=actual_score_columns,
              actual_class_columns=actual_class_columns,
+             segment_columns=segment_columns,
              refresh_interval=model_monitor_config.refresh_interval,
              aggregation_window=model_monitor_config.aggregation_window,
              baseline_database=baseline_database_name_id,
snowflake/ml/monitoring/entities/model_monitor_config.py CHANGED
@@ -33,6 +33,9 @@ class ModelMonitorSourceConfig:
      baseline: Optional[str] = None
      """Name of table containing the baseline data."""

+     segment_columns: Optional[list[str]] = None
+     """List of columns in the source containing segment information for grouped monitoring."""
+

  @dataclass
  class ModelMonitorConfig:
snowflake/ml/monitoring/model_monitor.py CHANGED
@@ -46,3 +46,29 @@ class ModelMonitor:
              telemetry.TelemetrySubProject.MONITORING.value,
          )
          self._model_monitor_client.resume_monitor(self.name, statement_params=statement_params)
+
+     @telemetry.send_api_usage_telemetry(
+         project=telemetry.TelemetryProject.MLOPS.value,
+         subproject=telemetry.TelemetrySubProject.MONITORING.value,
+     )
+     def add_segment_column(self, segment_column: str) -> None:
+         """Add a segment column to the Model Monitor"""
+         statement_params = telemetry.get_statement_params(
+             telemetry.TelemetryProject.MLOPS.value,
+             telemetry.TelemetrySubProject.MONITORING.value,
+         )
+         segment_column_id = sql_identifier.SqlIdentifier(segment_column)
+         self._model_monitor_client.add_segment_column(self.name, segment_column_id, statement_params=statement_params)
+
+     @telemetry.send_api_usage_telemetry(
+         project=telemetry.TelemetryProject.MLOPS.value,
+         subproject=telemetry.TelemetrySubProject.MONITORING.value,
+     )
+     def drop_segment_column(self, segment_column: str) -> None:
+         """Drop a segment column from the Model Monitor"""
+         statement_params = telemetry.get_statement_params(
+             telemetry.TelemetryProject.MLOPS.value,
+             telemetry.TelemetrySubProject.MONITORING.value,
+         )
+         segment_column_id = sql_identifier.SqlIdentifier(segment_column)
+         self._model_monitor_client.drop_segment_column(self.name, segment_column_id, statement_params=statement_params)
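
These `ModelMonitor` methods, together with the `segment_columns` field added to `ModelMonitorSourceConfig` above, surface segment-level monitoring in the public API. A hedged usage sketch follows; the connection settings, the `Registry.get_monitor` retrieval path, and the names `MY_MONITOR`/`REGION` are placeholders and are not part of this diff.

```python
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

# Placeholder connection; any configured Snowpark connection works here.
session = Session.builder.configs({"connection_name": "default"}).create()
registry = Registry(session=session, database_name="ML_DB", schema_name="MONITORING")

# Assumed retrieval path for an existing monitor; only the two segment-column
# methods below are introduced by this diff.
monitor = registry.get_monitor(name="MY_MONITOR")
monitor.add_segment_column("REGION")   # ALTER MODEL MONITOR ... ADD SEGMENT_COLUMN=REGION
monitor.drop_segment_column("REGION")  # ALTER MODEL MONITOR ... DROP SEGMENT_COLUMN=REGION
```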
snowflake/ml/registry/_manager/model_manager.py CHANGED
@@ -4,15 +4,14 @@ from typing import TYPE_CHECKING, Any, Optional, Union
  import pandas as pd
  from absl.logging import logging

- from snowflake.ml._internal import env, platform_capabilities, telemetry
+ from snowflake.ml._internal import platform_capabilities, telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions
  from snowflake.ml._internal.human_readable_id import hrid_generator
  from snowflake.ml._internal.utils import sql_identifier
- from snowflake.ml.model import model_signature, target_platform, task, type_hints
+ from snowflake.ml.model import model_signature, task, type_hints
  from snowflake.ml.model._client.model import model_impl, model_version_impl
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
  from snowflake.ml.model._model_composer import model_composer
- from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
  from snowflake.ml.model._packager.model_meta import model_meta
  from snowflake.ml.registry._manager import model_parameter_reconciler
  from snowflake.snowpark import exceptions as snowpark_exceptions, session
@@ -221,37 +220,8 @@ class ModelManager:
              statement_params=statement_params,
          )

-         platforms = None
-         # User specified target platforms are defaulted to None and will not show up in the generated manifest.
-         if target_platforms:
-             # Convert any string target platforms to TargetPlatform objects
-             platforms = [type_hints.TargetPlatform(platform) for platform in target_platforms]
-         else:
-             # Default the target platform to warehouse if not specified and any table function exists
-             if options and (
-                 options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
-                 or (
-                     any(
-                         opt.get("function_type") == "TABLE_FUNCTION"
-                         for opt in options.get("method_options", {}).values()
-                     )
-                 )
-             ):
-                 logger.info(
-                     "Logging a partitioned model with a table function without specifying `target_platforms`. "
-                     'Default to `target_platforms=["WAREHOUSE"]`.'
-                 )
-                 platforms = [target_platform.TargetPlatform.WAREHOUSE]
-
-             # Default the target platform to SPCS if not specified when running in ML runtime
-             if not platforms and env.IN_ML_RUNTIME:
-                 logger.info(
-                     "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
-                     'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
-                 )
-                 platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
-
          reconciler = model_parameter_reconciler.ModelParameterReconciler(
+             session=self._model_ops._session,
              database_name=self._database_name,
              schema_name=self._schema_name,
              conda_dependencies=conda_dependencies,
@@ -259,6 +229,8 @@ class ModelManager:
              target_platforms=target_platforms,
              artifact_repository_map=artifact_repository_map,
              options=options,
+             python_version=python_version,
+             statement_params=statement_params,
          )

          model_params = reconciler.reconcile()
@@ -293,12 +265,12 @@ class ModelManager:
              pip_requirements=pip_requirements,
              artifact_repository_map=artifact_repository_map,
              resource_constraint=resource_constraint,
-             target_platforms=platforms,
+             target_platforms=model_params.target_platforms,
              python_version=python_version,
              user_files=user_files,
              code_paths=code_paths,
              ext_modules=ext_modules,
-             options=options,
+             options=model_params.options,
              task=task,
              experiment_info=experiment_info,
          )
snowflake/ml/registry/_manager/model_parameter_reconciler.py CHANGED
@@ -1,9 +1,20 @@
  import warnings
  from dataclasses import dataclass
- from typing import Optional
+ from typing import Any, Optional

+ from absl.logging import logging
+ from packaging import requirements
+
+ from snowflake.ml import version as snowml_version
+ from snowflake.ml._internal import env, env as snowml_env, env_utils
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
  from snowflake.ml._internal.utils import sql_identifier
- from snowflake.ml.model import type_hints as model_types
+ from snowflake.ml.model import target_platform, type_hints as model_types
+ from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
+ from snowflake.snowpark import Session
+ from snowflake.snowpark._internal import utils as snowpark_utils
+
+ logger = logging.getLogger(__name__)


  @dataclass
@@ -12,7 +23,7 @@ class ReconciledParameters:

      conda_dependencies: Optional[list[str]] = None
      pip_requirements: Optional[list[str]] = None
-     target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None
+     target_platforms: Optional[list[model_types.TargetPlatform]] = None
      artifact_repository_map: Optional[dict[str, str]] = None
      options: Optional[model_types.ModelSaveOption] = None
      save_location: Optional[str] = None
@@ -23,6 +34,7 @@ class ModelParameterReconciler:

      def __init__(
          self,
+         session: Session,
          database_name: sql_identifier.SqlIdentifier,
          schema_name: sql_identifier.SqlIdentifier,
          conda_dependencies: Optional[list[str]] = None,
@@ -30,7 +42,10 @@ class ModelParameterReconciler:
          target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
          artifact_repository_map: Optional[dict[str, str]] = None,
          options: Optional[model_types.ModelSaveOption] = None,
+         python_version: Optional[str] = None,
+         statement_params: Optional[dict[str, str]] = None,
      ) -> None:
+         self._session = session
          self._database_name = database_name
          self._schema_name = schema_name
          self._conda_dependencies = conda_dependencies
@@ -38,20 +53,27 @@ class ModelParameterReconciler:
          self._target_platforms = target_platforms
          self._artifact_repository_map = artifact_repository_map
          self._options = options
+         self._python_version = python_version
+         self._statement_params = statement_params

      def reconcile(self) -> ReconciledParameters:
          """Perform all parameter reconciliation and return clean parameters."""
+
          reconciled_artifact_repository_map = self._reconcile_artifact_repository_map()
          reconciled_save_location = self._extract_save_location()

          self._validate_pip_requirements_warehouse_compatibility(reconciled_artifact_repository_map)

+         reconciled_target_platforms = self._reconcile_target_platforms()
+         reconciled_options = self._reconcile_explainability_options(reconciled_target_platforms)
+         reconciled_options = self._reconcile_relax_version(reconciled_options, reconciled_target_platforms)
+
          return ReconciledParameters(
              conda_dependencies=self._conda_dependencies,
              pip_requirements=self._pip_requirements,
-             target_platforms=self._target_platforms,
+             target_platforms=reconciled_target_platforms,
              artifact_repository_map=reconciled_artifact_repository_map,
-             options=self._options,
+             options=reconciled_options,
              save_location=reconciled_save_location,
          )

@@ -82,6 +104,45 @@ class ModelParameterReconciler:

          return None

+     def _reconcile_target_platforms(self) -> Optional[list[model_types.TargetPlatform]]:
+         """Reconcile target platforms with proper defaulting logic."""
+         # User specified target platforms are defaulted to None and will not show up in the generated manifest.
+         if self._target_platforms:
+             # Convert any string target platforms to TargetPlatform objects
+             return [model_types.TargetPlatform(platform) for platform in self._target_platforms]
+
+         # Default the target platform to warehouse if not specified and any table function exists
+         if self._has_table_function():
+             logger.info(
+                 "Logging a partitioned model with a table function without specifying `target_platforms`. "
+                 'Default to `target_platforms=["WAREHOUSE"]`.'
+             )
+             return [target_platform.TargetPlatform.WAREHOUSE]
+
+         # Default the target platform to SPCS if not specified when running in ML runtime
+         if env.IN_ML_RUNTIME:
+             logger.info(
+                 "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
+                 'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
+             )
+             return [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
+
+         return None
+
+     def _has_table_function(self) -> bool:
+         """Check if any table function exists in options."""
+         if self._options is None:
+             return False
+
+         if self._options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
+             return True
+
+         for opt in self._options.get("method_options", {}).values():
+             if opt.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
+                 return True
+
+         return False
+
      def _validate_pip_requirements_warehouse_compatibility(
          self, artifact_repository_map: Optional[dict[str, str]]
      ) -> None:
@@ -103,3 +164,131 @@ class ModelParameterReconciler:
              or model_types.TargetPlatform.WAREHOUSE in target_platforms
              or "WAREHOUSE" in target_platforms
          )
+
+     def _reconcile_explainability_options(
+         self, target_platforms: Optional[list[model_types.TargetPlatform]]
+     ) -> model_types.ModelSaveOption:
+         """Reconcile explainability settings and embed_local_ml_library based on warehouse runnability."""
+         options = self._options.copy() if self._options else model_types.BaseModelSaveOption()
+
+         conda_dep_dict = env_utils.validate_conda_dependency_string_list(self._conda_dependencies or [])
+
+         enable_explainability = options.get("enable_explainability", None)
+
+         # Handle case where user explicitly disabled explainability
+         if enable_explainability is False:
+             return self._handle_embed_local_ml_library(options, target_platforms)
+
+         target_platform_set = set(target_platforms) if target_platforms else set()
+
+         is_warehouse_runnable = self._is_warehouse_runnable(conda_dep_dict)
+         only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY)
+         has_both_platforms = target_platform_set == set(target_platform.BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES)
+
+         # Handle case where user explicitly requested explainability
+         if enable_explainability:
+             if only_spcs or not is_warehouse_runnable:
+                 raise ValueError(
+                     "`enable_explainability` cannot be set to True when the model is not runnable in WH "
+                     "or the target platforms include SPCS."
+                 )
+             elif has_both_platforms:
+                 warnings.warn(
+                     ("Explain function will only be available for model deployed to warehouse."),
+                     category=UserWarning,
+                     stacklevel=2,
+                 )
+
+         # Handle case where explainability is not specified (None) - set default behavior
+         if enable_explainability is None:
+             if only_spcs or not is_warehouse_runnable:
+                 options["enable_explainability"] = False
+
+         return self._handle_embed_local_ml_library(options, target_platforms)
+
+     def _handle_embed_local_ml_library(
+         self, options: model_types.ModelSaveOption, target_platforms: Optional[list[model_types.TargetPlatform]]
+     ) -> model_types.ModelSaveOption:
+         """Handle embed_local_ml_library logic."""
+         if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
+             model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
+         ]:
+             snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
+                 self._session,
+                 reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
+                 python_version=self._python_version or snowml_env.PYTHON_VERSION,
+                 statement_params=self._statement_params,
+             ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
+
+             if len(snowml_matched_versions) < 1 and not options.get("embed_local_ml_library", False):
+                 logging.info(
+                     f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
+                     " which is not available in the Snowflake server, embedding local ML library automatically."
+                 )
+                 options["embed_local_ml_library"] = True
+
+         return options
+
+     def _is_warehouse_runnable(self, conda_dep_dict: dict[str, list[Any]]) -> bool:
+         """Check if model can run in warehouse based on conda channels and pip requirements."""
+         # If pip requirements are present but no artifact repository map, model cannot run in warehouse
+         if self._pip_requirements and not self._artifact_repository_map:
+             return False
+
+         # If no conda dependencies, model can run in warehouse
+         if not conda_dep_dict:
+             return True
+
+         # Check if all conda channels are warehouse-compatible
+         warehouse_compatible_channels = {env_utils.DEFAULT_CHANNEL_NAME, env_utils.SNOWFLAKE_CONDA_CHANNEL_URL}
+         for channel in conda_dep_dict:
+             if channel not in warehouse_compatible_channels:
+                 return False
+
+         return True
+
+     def _reconcile_relax_version(
+         self,
+         options: model_types.ModelSaveOption,
+         target_platforms: Optional[list[model_types.TargetPlatform]],
+     ) -> model_types.ModelSaveOption:
+         """Reconcile relax_version setting based on pip requirements and target platforms."""
+         target_platform_set = set(target_platforms) if target_platforms else set()
+         has_pip_requirements = bool(self._pip_requirements)
+         only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY)
+
+         if "relax_version" not in options:
+             if has_pip_requirements or only_spcs:
+                 logger.info(
+                     "Setting `relax_version=False` as this model will run in Snowpark Container Services "
+                     "or in Warehouse with a specified artifact_repository_map where exact version "
+                     " specifications will be honored."
+                 )
+                 relax_version = False
+             else:
+                 warnings.warn(
+                     (
+                         "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
+                         " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
+                         " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
+                     ),
+                     category=UserWarning,
+                     stacklevel=2,
+                 )
+                 relax_version = True
+             options["relax_version"] = relax_version
+             return options
+
+         # Handle case where relax_version is already set
+         relax_version = options["relax_version"]
+         if relax_version and (has_pip_requirements or only_spcs):
+             raise exceptions.SnowflakeMLException(
+                 error_code=error_codes.INVALID_ARGUMENT,
+                 original_exception=ValueError(
+                     "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
+                     "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
+                     "targeting only Snowpark Container Services."
+                 ),
+             )
+
+         return options
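
`ModelParameterReconciler` now owns the target-platform, explainability, and `relax_version` defaulting that `ModelManager.log_model` previously inlined. For callers who leave `relax_version` unset, the outcome reduces to the small rule below, an illustrative restatement of `_reconcile_relax_version` (the helper `default_relax_version` is not a package API).

```python
def default_relax_version(has_pip_requirements: bool, only_spcs: bool) -> bool:
    # Mirrors _reconcile_relax_version when the caller leaves relax_version unset:
    # pinned (==x.y.z) versions are honored for pip-based or SPCS-only models,
    # otherwise versions are relaxed so warehouse conda resolution can succeed.
    return not (has_pip_requirements or only_spcs)


assert default_relax_version(has_pip_requirements=False, only_spcs=False) is True
assert default_relax_version(has_pip_requirements=True, only_spcs=False) is False
assert default_relax_version(has_pip_requirements=False, only_spcs=True) is False
```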
snowflake/ml/version.py CHANGED
@@ -1,2 +1,2 @@
  # This is parsed by regex in conda recipe meta file. Make sure not to break it.
- VERSION = "1.10.0"
+ VERSION = "1.12.0"