snowflake-ml-python 1.7.1__py3-none-any.whl → 1.7.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (78)
  1. snowflake/cortex/__init__.py +16 -8
  2. snowflake/cortex/_classify_text.py +12 -1
  3. snowflake/cortex/_complete.py +82 -13
  4. snowflake/cortex/_embed_text_1024.py +9 -2
  5. snowflake/cortex/_embed_text_768.py +9 -2
  6. snowflake/cortex/_extract_answer.py +9 -2
  7. snowflake/cortex/_sentiment.py +9 -2
  8. snowflake/cortex/_summarize.py +9 -2
  9. snowflake/cortex/_translate.py +9 -2
  10. snowflake/ml/_internal/env_utils.py +7 -52
  11. snowflake/ml/_internal/utils/identifier.py +4 -2
  12. snowflake/ml/_internal/utils/jwt_generator.py +141 -0
  13. snowflake/ml/data/__init__.py +3 -0
  14. snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
  15. snowflake/ml/data/data_connector.py +53 -11
  16. snowflake/ml/data/data_ingestor.py +2 -1
  17. snowflake/ml/data/torch_utils.py +18 -5
  18. snowflake/ml/feature_store/examples/example_helper.py +2 -1
  19. snowflake/ml/fileset/fileset.py +18 -18
  20. snowflake/ml/model/_client/model/model_version_impl.py +24 -8
  21. snowflake/ml/model/_client/ops/model_ops.py +2 -6
  22. snowflake/ml/model/_client/ops/service_ops.py +12 -7
  23. snowflake/ml/model/_client/sql/model_version.py +11 -0
  24. snowflake/ml/model/_client/sql/stage.py +1 -1
  25. snowflake/ml/model/_model_composer/model_composer.py +8 -3
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
  27. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  28. snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
  29. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
  30. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
  31. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
  32. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  33. snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
  34. snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
  35. snowflake/ml/model/_packager/model_handlers/_utils.py +27 -2
  36. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
  37. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +5 -1
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
  39. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
  40. snowflake/ml/model/_packager/model_handlers/sklearn.py +10 -9
  41. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
  42. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
  45. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  46. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  47. snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
  48. snowflake/ml/model/_signatures/pandas_handler.py +1 -1
  49. snowflake/ml/model/_signatures/snowpark_handler.py +8 -2
  50. snowflake/ml/model/_signatures/utils.py +0 -1
  51. snowflake/ml/model/type_hints.py +1 -0
  52. snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
  53. snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
  54. snowflake/ml/modeling/pipeline/pipeline.py +6 -176
  55. snowflake/ml/modeling/xgboost/xgb_classifier.py +161 -88
  56. snowflake/ml/modeling/xgboost/xgb_regressor.py +160 -85
  57. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +160 -85
  58. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +160 -85
  59. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +5 -170
  60. snowflake/ml/monitoring/_manager/model_monitor_manager.py +9 -9
  61. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -2
  62. snowflake/ml/monitoring/model_monitor.py +26 -11
  63. snowflake/ml/registry/_manager/model_manager.py +70 -33
  64. snowflake/ml/registry/registry.py +53 -34
  65. snowflake/ml/utils/authentication.py +75 -0
  66. snowflake/ml/version.py +1 -1
  67. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/METADATA +120 -53
  68. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/RECORD +71 -74
  69. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/WHEEL +1 -1
  70. snowflake/ml/_internal/utils/retryable_http.py +0 -39
  71. snowflake/ml/fileset/parquet_parser.py +0 -170
  72. snowflake/ml/fileset/tf_dataset.py +0 -88
  73. snowflake/ml/fileset/torch_datapipe.py +0 -57
  74. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
  75. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
  76. snowflake/ml/monitoring/entities/output_score_type.py +0 -90
  77. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/LICENSE.txt +0 -0
  78. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/xgboost/xgbrf_regressor.py

@@ -125,111 +125,171 @@ class XGBRFRegressor(BaseTransformer):
  can seriously hurt performance in gradient boosting. Set the batch_size as large as possible
  based on the available memory.

- n_estimators: int
+ n_estimators: Optional[int]
  Number of trees in random forest to fit.

- max_depth: Optional[int]
+ max_depth: typing.Optional[int]
+
  Maximum tree depth for base learners.
- max_leaves :
+
+ max_leaves: typing.Optional[int]
+
  Maximum number of leaves; 0 indicates no limit.
- max_bin :
+
+ max_bin: typing.Optional[int]
+
  If using histogram-based algorithm, maximum number of bins per feature
- grow_policy :
- Tree growing policy. 0: favor splitting at nodes closest to the node, i.e. grow
- depth-wise. 1: favor splitting at nodes with highest loss change.
- learning_rate: Optional[float]
+
+ grow_policy: typing.Optional[str]
+
+ Tree growing policy.
+
+ - depthwise: Favors splitting at nodes closest to the node,
+ - lossguide: Favors splitting at nodes with highest loss change.
+
+ learning_rate: typing.Optional[float]
+
  Boosting learning rate (xgb's "eta")
- verbosity: Optional[int]
+
+ verbosity: typing.Optional[int]
+
  The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
- objective: typing.Union[str, typing.Callable[[numpy.ndarray, numpy.ndarray], typing.Tuple[numpy.ndarray, numpy.ndarray]], NoneType]
- Specify the learning task and the corresponding learning objective or
- a custom objective function to be used (see note below).
- booster: Optional[str]
- Specify which booster to use: gbtree, gblinear or dart.
- tree_method: Optional[str]
+
+ objective: typing.Union[str, xgboost.sklearn._SklObjWProto, typing.Callable[[typing.Any, typing.Any], typing.Tuple[numpy.ndarray, numpy.ndarray]], NoneType]
+
+ Specify the learning task and the corresponding learning objective or a custom
+ objective function to be used.
+
+ For custom objective, see :doc:`/tutorials/custom_metric_obj` and
+ :ref:`custom-obj-metric` for more information, along with the end note for
+ function signatures.
+
+ booster: typing.Optional[str]
+
+ Specify which booster to use: ``gbtree``, ``gblinear`` or ``dart``.
+
+ tree_method: typing.Optional[str]
+
  Specify which tree method to use. Default to auto. If this parameter is set to
  default, XGBoost will choose the most conservative option available. It's
  recommended to study this option from the parameters document :doc:`tree method
  </treemethod>`
- n_jobs: Optional[int]
+
+ n_jobs: typing.Optional[int]
+
  Number of parallel threads used to run xgboost. When used with other
  Scikit-Learn algorithms like grid search, you may choose which algorithm to
  parallelize and balance the threads. Creating thread contention will
  significantly slow down both algorithms.
- gamma: Optional[float]
- (min_split_loss) Minimum loss reduction required to make a further partition on a
- leaf node of the tree.
- min_child_weight: Optional[float]
+
+ gamma: typing.Optional[float]
+
+ (min_split_loss) Minimum loss reduction required to make a further partition on
+ a leaf node of the tree.
+
+ min_child_weight: typing.Optional[float]
+
  Minimum sum of instance weight(hessian) needed in a child.
- max_delta_step: Optional[float]
+
+ max_delta_step: typing.Optional[float]
+
  Maximum delta step we allow each tree's weight estimation to be.
- subsample: Optional[float]
+
+ subsample: typing.Optional[float]
+
  Subsample ratio of the training instance.
- sampling_method :
- Sampling method. Used only by `gpu_hist` tree method.
- - `uniform`: select random training instances uniformly.
- - `gradient_based` select random training instances with higher probability when
- the gradient and hessian are larger. (cf. CatBoost)
- colsample_bytree: Optional[float]
+
+ sampling_method: typing.Optional[str]
+
+ Sampling method. Used only by the GPU version of ``hist`` tree method.
+
+ - ``uniform``: Select random training instances uniformly.
+ - ``gradient_based``: Select random training instances with higher probability
+ when the gradient and hessian are larger. (cf. CatBoost)
+
+ colsample_bytree: typing.Optional[float]
+
  Subsample ratio of columns when constructing each tree.
- colsample_bylevel: Optional[float]
+
+ colsample_bylevel: typing.Optional[float]
+
  Subsample ratio of columns for each level.
- colsample_bynode: Optional[float]
+
+ colsample_bynode: typing.Optional[float]
+
  Subsample ratio of columns for each split.
- reg_alpha: Optional[float]
+
+ reg_alpha: typing.Optional[float]
+
  L1 regularization term on weights (xgb's alpha).
- reg_lambda: Optional[float]
+
+ reg_lambda: typing.Optional[float]
+
  L2 regularization term on weights (xgb's lambda).
- scale_pos_weight: Optional[float]
+
+ scale_pos_weight: typing.Optional[float]
  Balancing of positive and negative weights.
- base_score: Optional[float]
+
+ base_score: typing.Optional[float]
+
  The initial prediction score of all instances, global bias.
- random_state: Optional[Union[numpy.random.RandomState, int]]
+
+ random_state: typing.Union[numpy.random.mtrand.RandomState, numpy.random._generator.Generator, int, NoneType]
+
  Random number seed.

  Using gblinear booster with shotgun updater is nondeterministic as
  it uses Hogwild algorithm.

- missing: float, default np.nan
- Value in the data which needs to be present as a missing value.
- num_parallel_tree: Optional[int]
+ missing: float
+
+ Value in the data which needs to be present as a missing value. Default to
+ :py:data:`numpy.nan`.
+
+ num_parallel_tree: typing.Optional[int]
+
  Used for boosting random forest.
- monotone_constraints: Optional[Union[Dict[str, int], str]]
+
+ monotone_constraints: typing.Union[typing.Dict[str, int], str, NoneType]
+
  Constraint of variable monotonicity. See :doc:`tutorial </tutorials/monotonic>`
  for more information.
- interaction_constraints: Optional[Union[str, List[Tuple[str]]]]
+
+ interaction_constraints: typing.Union[str, typing.List[typing.Tuple[str]], NoneType]
+
  Constraints for interaction representing permitted interactions. The
  constraints must be specified in the form of a nested list, e.g. ``[[0, 1], [2,
  3, 4]]``, where each inner list is a group of indices of features that are
  allowed to interact with each other. See :doc:`tutorial
  </tutorials/feature_interaction_constraint>` for more information
- importance_type: Optional[str]
+
+ importance_type: typing.Optional[str]
+
  The feature importance type for the feature_importances\_ property:

  * For tree model, it's either "gain", "weight", "cover", "total_gain" or
  "total_cover".
- * For linear model, only "weight" is defined and it's the normalized coefficients
- without bias.
+ * For linear model, only "weight" is defined and it's the normalized
+ coefficients without bias.
+
+ device: typing.Optional[str]
+
+ Device ordinal, available options are `cpu`, `cuda`, and `gpu`.
+
+ validate_parameters: typing.Optional[bool]

- gpu_id: Optional[int]
- Device ordinal.
- validate_parameters: Optional[bool]
  Give warnings for unknown parameter.
- predictor: Optional[str]
- Force XGBoost to use specific predictor, available choices are [cpu_predictor,
- gpu_predictor].
+
  enable_categorical: bool

- Experimental support for categorical data. When enabled, cudf/pandas.DataFrame
- should be used to specify categorical data type. Also, JSON/UBJSON
- serialization format is required.
+ See the same parameter of :py:class:`DMatrix` for details.

- feature_types: FeatureTypes
+ feature_types: typing.Optional[typing.Sequence[str]]

  Used for specifying feature types without constructing a dataframe. See
  :py:class:`DMatrix` for details.

- max_cat_to_onehot: Optional[int]
+ max_cat_to_onehot: typing.Optional[int]

  A threshold for deciding whether XGBoost should use one-hot encoding based split
  for categorical data. When number of categories is lesser than the threshold
@@ -238,36 +298,41 @@ class XGBRFRegressor(BaseTransformer):
  categorical feature support. See :doc:`Categorical Data
  </tutorials/categorical>` and :ref:`cat-param` for details.

- max_cat_threshold: Optional[int]
+ max_cat_threshold: typing.Optional[int]

  Maximum number of categories considered for each split. Used only by
  partition-based splits for preventing over-fitting. Also, `enable_categorical`
  needs to be set to have categorical feature support. See :doc:`Categorical Data
  </tutorials/categorical>` and :ref:`cat-param` for details.

- eval_metric: Optional[Union[str, List[str], Callable]]
+ multi_strategy: typing.Optional[str]
+
+ The strategy used for training multi-target models, including multi-target
+ regression and multi-class classification. See :doc:`/tutorials/multioutput` for
+ more information.
+
+ - ``one_output_per_tree``: One model for each target.
+ - ``multi_output_tree``: Use multi-target trees.
+
+ eval_metric: typing.Union[str, typing.List[str], typing.Callable, NoneType]

  Metric used for monitoring the training result and early stopping. It can be a
  string or list of strings as names of predefined metric in XGBoost (See
- doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any other
- user defined metric that looks like `sklearn.metrics`.
+ doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any
+ other user defined metric that looks like `sklearn.metrics`.

  If custom objective is also provided, then custom metric should implement the
  corresponding reverse link function.

  Unlike the `scoring` parameter commonly used in scikit-learn, when a callable
- object is provided, it's assumed to be a cost function and by default XGBoost will
- minimize the result during early stopping.
-
- For advanced usage on Early stopping like directly choosing to maximize instead of
- minimize, see :py:obj:`xgboost.callback.EarlyStopping`.
+ object is provided, it's assumed to be a cost function and by default XGBoost
+ will minimize the result during early stopping.

- See :doc:`Custom Objective and Evaluation Metric </tutorials/custom_metric_obj>`
- for more.
+ For advanced usage on Early stopping like directly choosing to maximize instead
+ of minimize, see :py:obj:`xgboost.callback.EarlyStopping`.

- This parameter replaces `eval_metric` in :py:meth:`fit` method. The old one
- receives un-transformed prediction regardless of whether custom objective is
- being used.
+ See :doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more
+ information.

  from sklearn.datasets import load_diabetes
  from sklearn.metrics import mean_absolute_error
@@ -278,24 +343,29 @@ class XGBRFRegressor(BaseTransformer):
  )
  reg.fit(X, y, eval_set=[(X, y)])

- early_stopping_rounds: Optional[int]
+ early_stopping_rounds: typing.Optional[int]

- Activates early stopping. Validation metric needs to improve at least once in
- every **early_stopping_rounds** round(s) to continue training. Requires at least
- one item in **eval_set** in :py:meth:`fit`.
+ - Activates early stopping. Validation metric needs to improve at least once in
+ every **early_stopping_rounds** round(s) to continue training. Requires at
+ least one item in **eval_set** in :py:meth:`fit`.

- The method returns the model from the last iteration (not the best one). If
- there's more than one item in **eval_set**, the last entry will be used for early
- stopping. If there's more than one metric in **eval_metric**, the last metric
- will be used for early stopping.
+ - If early stopping occurs, the model will have two additional attributes:
+ :py:attr:`best_score` and :py:attr:`best_iteration`. These are used by the
+ :py:meth:`predict` and :py:meth:`apply` methods to determine the optimal
+ number of trees during inference. If users want to access the full model
+ (including trees built after early stopping), they can specify the
+ `iteration_range` in these inference methods. In addition, other utilities
+ like model plotting can also use the entire model.

- If early stopping occurs, the model will have three additional fields:
- :py:attr:`best_score`, :py:attr:`best_iteration` and
- :py:attr:`best_ntree_limit`.
+ - If you prefer to discard the trees after `best_iteration`, consider using the
+ callback function :py:class:`xgboost.callback.EarlyStopping`.

- This parameter replaces `early_stopping_rounds` in :py:meth:`fit` method.
+ - If there's more than one item in **eval_set**, the last entry will be used for
+ early stopping. If there's more than one metric in **eval_metric**, the last
+ metric will be used for early stopping.
+
+ callbacks: typing.Optional[typing.List[xgboost.callback.TrainingCallback]]

- callbacks: Optional[List[TrainingCallback]]
  List of callback functions that are applied at end of each iteration.
  It is possible to use predefined callbacks by using
  :ref:`Callback API <callback_api>`.
@@ -307,9 +377,11 @@ class XGBRFRegressor(BaseTransformer):
  for params in parameters_grid:
  # be sure to (re)initialize the callbacks before each run
  callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
- xgboost.train(params, Xy, callbacks=callbacks)
+ reg = xgboost.XGBRegressor(**params, callbacks=callbacks)
+ reg.fit(X, y)
+
+ kwargs: typing.Optional[typing.Any]

- kwargs: dict, optional
  Keyword arguments for XGBoost Booster object. Full documentation of parameters
  can be found :doc:`here </parameter>`.
  Attempting to set a parameter via the constructor args and \*\*kwargs
@@ -320,13 +392,16 @@ class XGBRFRegressor(BaseTransformer):
  with scikit-learn.

  A custom objective function can be provided for the ``objective``
- parameter. In this case, it should have the signature
- ``objective(y_true, y_pred) -> grad, hess``:
+ parameter. In this case, it should have the signature ``objective(y_true,
+ y_pred) -> [grad, hess]`` or ``objective(y_true, y_pred, *, sample_weight)
+ -> [grad, hess]``:

  y_true: array_like of shape [n_samples]
  The target values
  y_pred: array_like of shape [n_samples]
  The predicted values
+ sample_weight :
+ Optional sample weights.

  grad: array_like of shape [n_samples]
  The value of the gradient for each sample point.
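As context for the custom-objective signature documented above, here is a minimal sketch using the upstream xgboost scikit-learn API; the squared-error gradients are illustrative and not taken from the package.

    import numpy as np
    import xgboost as xgb

    def squared_error(y_true: np.ndarray, y_pred: np.ndarray):
        # Per-sample gradient and hessian of 0.5 * (y_pred - y_true) ** 2.
        grad = y_pred - y_true
        hess = np.ones_like(y_true)
        return grad, hess

    # The sklearn wrapper accepts the callable directly as ``objective``.
    reg = xgb.XGBRFRegressor(n_estimators=100, objective=squared_error)
    reg.fit(np.random.rand(32, 4), np.random.rand(32))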
snowflake/ml/monitoring/_client/model_monitor_sql_client.py

@@ -1,6 +1,4 @@
- import typing
- from collections import Counter
- from typing import Any, Dict, List, Mapping, Optional, Set
+ from typing import Any, Dict, List, Mapping, Optional

  from snowflake import snowpark
  from snowflake.ml._internal.utils import (
@@ -10,27 +8,12 @@ from snowflake.ml._internal.utils import (
  table_manager,
  )
  from snowflake.ml.model._client.sql import _base
- from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
  from snowflake.snowpark import session, types

- SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA"
-
  MODEL_JSON_COL_NAME = "model"
  MODEL_JSON_MODEL_NAME_FIELD = "model_name"
  MODEL_JSON_VERSION_NAME_FIELD = "version_name"

- MONITOR_NAME_COL_NAME = "MONITOR_NAME"
- SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME"
- FQ_MODEL_NAME_COL_NAME = "FULLY_QUALIFIED_MODEL_NAME"
- VERSION_NAME_COL_NAME = "MODEL_VERSION_NAME"
- FUNCTION_NAME_COL_NAME = "FUNCTION_NAME"
- TASK_COL_NAME = "TASK"
- MONITORING_ENABLED_COL_NAME = "IS_ENABLED"
- TIMESTAMP_COL_NAME_COL_NAME = "TIMESTAMP_COLUMN_NAME"
- PREDICTION_COL_NAMES_COL_NAME = "PREDICTION_COLUMN_NAMES"
- LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES"
- ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES"
-

  def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str:
  sql_list = ", ".join([f"'{column}'" for column in columns])
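To make the quoting in `_build_sql_list_from_columns` concrete, a small illustration with made-up column names (the function's return statement falls outside this hunk, so the wrapping parentheses seen in the DDL below are assumed to be added there):

    # Stand-ins for sql_identifier.SqlIdentifier values.
    columns = ["CUSTOMER_ID", "SESSION_ID"]
    sql_list = ", ".join(f"'{column}'" for column in columns)
    print(sql_list)  # 'CUSTOMER_ID', 'SESSION_ID'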
@@ -59,7 +42,7 @@ class ModelMonitorSQLClient:
  def _infer_qualified_schema(
  self, database_name: Optional[sql_identifier.SqlIdentifier], schema_name: Optional[sql_identifier.SqlIdentifier]
  ) -> str:
- return f"{database_name or self._database_name}.{schema_name or self._schema_name}"
+ return f"""{database_name or self._database_name}.{schema_name or self._schema_name}"""

  def create_model_monitor(
  self,
@@ -91,17 +74,17 @@ class ModelMonitorSQLClient:
  ) -> None:
  baseline_sql = ""
  if baseline:
- baseline_sql = f"BASELINE='{self._infer_qualified_schema(baseline_database, baseline_schema)}.{baseline}'"
+ baseline_sql = f"""BASELINE={self._infer_qualified_schema(baseline_database, baseline_schema)}.{baseline}"""
  query_result_checker.SqlResultValidator(
  self._sql_client._session,
  f"""
  CREATE MODEL MONITOR {self._infer_qualified_schema(monitor_database, monitor_schema)}.{monitor_name}
  WITH
- MODEL='{self._infer_qualified_schema(model_database, model_schema)}.{model_name}'
+ MODEL={self._infer_qualified_schema(model_database, model_schema)}.{model_name}
  VERSION='{version_name}'
  FUNCTION='{function_name}'
  WAREHOUSE='{warehouse_name}'
- SOURCE='{self._infer_qualified_schema(source_database, source_schema)}.{source}'
+ SOURCE={self._infer_qualified_schema(source_database, source_schema)}.{source}
  ID_COLUMNS={_build_sql_list_from_columns(id_columns)}
  PREDICTION_SCORE_COLUMNS={_build_sql_list_from_columns(prediction_score_columns)}
  PREDICTION_CLASS_COLUMNS={_build_sql_list_from_columns(prediction_class_columns)}
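Rendered with hypothetical names, the statement built above now emits MODEL and SOURCE as bare qualified identifiers, while VERSION, FUNCTION, and WAREHOUSE stay quoted string literals; a sketch of the resulting DDL, not output captured from the package:

    ddl = """
    CREATE MODEL MONITOR ML_DB.MONITORING.MY_MONITOR
        WITH
            MODEL=ML_DB.MODELS.MY_MODEL
            VERSION='V1'
            FUNCTION='predict'
            WAREHOUSE='MY_WH'
            SOURCE=ML_DB.DATA.INFERENCE_LOG
            ID_COLUMNS=('CUSTOMER_ID')
            PREDICTION_SCORE_COLUMNS=('SCORE')
            PREDICTION_CLASS_COLUMNS=('PREDICTED_CLASS')
    """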
@@ -146,19 +129,6 @@ class ModelMonitorSQLClient:
  .validate()
  )

- def _validate_unique_columns(
- self,
- timestamp_column: sql_identifier.SqlIdentifier,
- id_columns: List[sql_identifier.SqlIdentifier],
- prediction_columns: List[sql_identifier.SqlIdentifier],
- label_columns: List[sql_identifier.SqlIdentifier],
- ) -> None:
- all_columns = [*id_columns, *prediction_columns, *label_columns, timestamp_column]
- num_all_columns = len(all_columns)
- num_unique_columns = len(set(all_columns))
- if num_all_columns != num_unique_columns:
- raise ValueError("Column names must be unique across id, timestamp, prediction, and label columns.")
-
  def validate_existence_by_name(
  self,
  *,
@@ -244,125 +214,6 @@ class ModelMonitorSQLClient:
  if not all([column_name in source_column_schema for column_name in id_columns]):
  raise ValueError(f"ID column(s): {id_columns} do not exist in source.")

- def _validate_timestamp_column_type(
- self, table_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier
- ) -> None:
- """Ensures columns have the same type.
-
- Args:
- table_schema: Dictionary of column names and types in the source table.
- timestamp_column: Name of the timestamp column.
-
- Raises:
- ValueError: If the timestamp column is not of type TimestampType.
- """
- if not isinstance(table_schema[timestamp_column], types.TimestampType):
- raise ValueError(
- f"Timestamp column: {timestamp_column} must be TimestampType. "
- f"Found: {table_schema[timestamp_column]}"
- )
-
- def _validate_id_columns_types(
- self, table_schema: Mapping[str, types.DataType], id_columns: List[sql_identifier.SqlIdentifier]
- ) -> None:
- """Ensures id columns have the correct type.
-
- Args:
- table_schema: Dictionary of column names and types in the source table.
- id_columns: List of id column names.
-
- Raises:
- ValueError: If the id column is not of type StringType.
- """
- id_column_types = list({table_schema[column_name] for column_name in id_columns})
- all_id_columns_string = all([isinstance(column_type, types.StringType) for column_type in id_column_types])
- if not all_id_columns_string:
- raise ValueError(f"Id columns must all be StringType. Found: {id_column_types}")
-
- def _validate_prediction_columns_types(
- self, table_schema: Mapping[str, types.DataType], prediction_columns: List[sql_identifier.SqlIdentifier]
- ) -> None:
- """Ensures prediction columns have the same type.
-
- Args:
- table_schema: Dictionary of column names and types in the source table.
- prediction_columns: List of prediction column names.
-
- Raises:
- ValueError: If the prediction columns do not share the same type.
- """
-
- prediction_column_types = {table_schema[column_name] for column_name in prediction_columns}
- if len(prediction_column_types) > 1:
- raise ValueError(f"Prediction column types must be the same. Found: {prediction_column_types}")
-
- def _validate_label_columns_types(
- self,
- table_schema: Mapping[str, types.DataType],
- label_columns: List[sql_identifier.SqlIdentifier],
- ) -> None:
- """Ensures label columns have the same type, and the correct type for the score type.
-
- Args:
- table_schema: Dictionary of column names and types in the source table.
- label_columns: List of label column names.
-
- Raises:
- ValueError: If the label columns do not share the same type.
- """
- label_column_types = {table_schema[column_name] for column_name in label_columns}
- if len(label_column_types) > 1:
- raise ValueError(f"Label column types must be the same. Found: {label_column_types}")
-
- def _validate_column_types(
- self,
- *,
- table_schema: Mapping[str, types.DataType],
- timestamp_column: sql_identifier.SqlIdentifier,
- id_columns: List[sql_identifier.SqlIdentifier],
- prediction_columns: List[sql_identifier.SqlIdentifier],
- label_columns: List[sql_identifier.SqlIdentifier],
- ) -> None:
- """Ensures columns have the expected type.
-
- Args:
- table_schema: Dictionary of column names and types in the source table.
- timestamp_column: Name of the timestamp column.
- id_columns: List of id column names.
- prediction_columns: List of prediction column names.
- label_columns: List of label column names.
- """
- self._validate_timestamp_column_type(table_schema, timestamp_column)
- self._validate_id_columns_types(table_schema, id_columns)
- self._validate_prediction_columns_types(table_schema, prediction_columns)
- self._validate_label_columns_types(table_schema, label_columns)
- # TODO(SNOW-1646693): Validate label makes sense with model task
-
- def _validate_source_table_features_shape(
- self,
- table_schema: Mapping[str, types.DataType],
- special_columns: Set[sql_identifier.SqlIdentifier],
- model_function: model_manifest_schema.ModelFunctionInfo,
- ) -> None:
- table_schema_without_special_columns = {
- k: v for k, v in table_schema.items() if sql_identifier.SqlIdentifier(k) not in special_columns
- }
- schema_column_types_to_count: typing.Counter[types.DataType] = Counter()
- for column_type in table_schema_without_special_columns.values():
- schema_column_types_to_count[column_type] += 1
-
- inputs = model_function["signature"].inputs
- function_input_types = [input.as_snowpark_type() for input in inputs]
- function_input_types_to_count: typing.Counter[types.DataType] = Counter()
- for function_input_type in function_input_types:
- function_input_types_to_count[function_input_type] += 1
-
- if function_input_types_to_count != schema_column_types_to_count:
- raise ValueError(
- "Model function input types do not match the source table input columns types. "
- f"Model function expected: {inputs} but got {table_schema_without_special_columns}"
- )
-
  def validate_source(
  self,
  *,
@@ -395,22 +246,6 @@ class ModelMonitorSQLClient:
  id_columns=id_columns,
  )

- def delete_monitor_metadata(
- self,
- name: str,
- statement_params: Optional[Dict[str, Any]] = None,
- ) -> None:
- """Delete the row in the metadata table corresponding to the given monitor name.
-
- Args:
- name: Name of the model monitor whose metadata should be deleted.
- statement_params: Optional set of statement_params to include with query.
- """
- self._sql_client._session.sql(
- f"""DELETE FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
- WHERE {MONITOR_NAME_COL_NAME} = '{name}'""",
- ).collect(statement_params=statement_params)
-
  def _alter_monitor(
  self,
  operation: str,
snowflake/ml/monitoring/_manager/model_monitor_manager.py

@@ -14,15 +14,6 @@ from snowflake.snowpark import session
  class ModelMonitorManager:
  """Class to manage internal operations for Model Monitor workflows."""

- def _validate_task_from_model_version(
- self,
- model_version: model_version_impl.ModelVersion,
- ) -> type_hints.Task:
- task = model_version.get_model_task()
- if task == type_hints.Task.UNKNOWN:
- raise ValueError("Registry model must be logged with task in order to be monitored.")
- return task
-
  def __init__(
  self,
  session: session.Session,
@@ -51,6 +42,15 @@ class ModelMonitorManager:
  schema_name=self._schema_name,
  )

+ def _validate_task_from_model_version(
+ self,
+ model_version: model_version_impl.ModelVersion,
+ ) -> type_hints.Task:
+ task = model_version.get_model_task()
+ if task == type_hints.Task.UNKNOWN:
+ raise ValueError("Registry model must be logged with task in order to be monitored.")
+ return task
+
  def _validate_model_function_from_model_version(
  self, function: str, model_version: model_version_impl.ModelVersion
  ) -> None:
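Since the relocated `_validate_task_from_model_version` rejects `Task.UNKNOWN`, a model must be logged with an explicit task before a monitor can be attached to it. A minimal sketch, assuming the registry's `log_model` accepts a `task` argument in this release and that `session` and a fitted `model` already exist:

    from snowflake.ml.model import type_hints
    from snowflake.ml.registry import Registry

    registry = Registry(session=session)
    mv = registry.log_model(
        model,
        model_name="MY_MODEL",                    # hypothetical names
        version_name="V1",
        task=type_hints.Task.TABULAR_REGRESSION,  # anything but Task.UNKNOWN
    )
    assert mv.get_model_task() != type_hints.Task.UNKNOWN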