snowflake-ml-python 1.7.1__py3-none-any.whl → 1.7.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/cortex/__init__.py +16 -8
- snowflake/cortex/_classify_text.py +12 -1
- snowflake/cortex/_complete.py +82 -13
- snowflake/cortex/_embed_text_1024.py +9 -2
- snowflake/cortex/_embed_text_768.py +9 -2
- snowflake/cortex/_extract_answer.py +9 -2
- snowflake/cortex/_sentiment.py +9 -2
- snowflake/cortex/_summarize.py +9 -2
- snowflake/cortex/_translate.py +9 -2
- snowflake/ml/_internal/env_utils.py +7 -52
- snowflake/ml/_internal/utils/identifier.py +4 -2
- snowflake/ml/_internal/utils/jwt_generator.py +141 -0
- snowflake/ml/data/__init__.py +3 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
- snowflake/ml/data/data_connector.py +53 -11
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/data/torch_utils.py +18 -5
- snowflake/ml/feature_store/examples/example_helper.py +2 -1
- snowflake/ml/fileset/fileset.py +18 -18
- snowflake/ml/model/_client/model/model_version_impl.py +24 -8
- snowflake/ml/model/_client/ops/model_ops.py +2 -6
- snowflake/ml/model/_client/ops/service_ops.py +12 -7
- snowflake/ml/model/_client/sql/model_version.py +11 -0
- snowflake/ml/model/_client/sql/stage.py +1 -1
- snowflake/ml/model/_model_composer/model_composer.py +8 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
- snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +27 -2
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +5 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +10 -9
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
- snowflake/ml/model/_signatures/pandas_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +8 -2
- snowflake/ml/model/_signatures/utils.py +0 -1
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
- snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
- snowflake/ml/modeling/pipeline/pipeline.py +6 -176
- snowflake/ml/modeling/xgboost/xgb_classifier.py +161 -88
- snowflake/ml/modeling/xgboost/xgb_regressor.py +160 -85
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +160 -85
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +160 -85
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +5 -170
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +9 -9
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -2
- snowflake/ml/monitoring/model_monitor.py +26 -11
- snowflake/ml/registry/_manager/model_manager.py +70 -33
- snowflake/ml/registry/registry.py +53 -34
- snowflake/ml/utils/authentication.py +75 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/METADATA +120 -53
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/RECORD +71 -74
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/utils/retryable_http.py +0 -39
- snowflake/ml/fileset/parquet_parser.py +0 -170
- snowflake/ml/fileset/tf_dataset.py +0 -88
- snowflake/ml/fileset/torch_datapipe.py +0 -57
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
- snowflake/ml/monitoring/entities/output_score_type.py +0 -90
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/top_level.txt +0 -0
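Before stepping through the per-file diffs below, it can help to confirm which of the two versions is installed locally. A minimal sketch, assuming a standard pip-managed environment:

    from importlib.metadata import version

    # Prints "1.7.1" before upgrading and "1.7.3" after.
    print(version("snowflake-ml-python"))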
snowflake/ml/modeling/xgboost/xgbrf_regressor.py

@@ -125,111 +125,171 @@ class XGBRFRegressor(BaseTransformer):
         can seriously hurt performance in gradient boosting. Set the batch_size as large as possible
         based on the available memory.
 
-    n_estimators: int
+    n_estimators: Optional[int]
         Number of trees in random forest to fit.
 
-    max_depth: Optional[int]
+    max_depth: typing.Optional[int]
+
         Maximum tree depth for base learners.
-
+
+    max_leaves: typing.Optional[int]
+
         Maximum number of leaves; 0 indicates no limit.
-
+
+    max_bin: typing.Optional[int]
+
         If using histogram-based algorithm, maximum number of bins per feature
-
-
-
-
+
+    grow_policy: typing.Optional[str]
+
+        Tree growing policy.
+
+        - depthwise: Favors splitting at nodes closest to the node,
+        - lossguide: Favors splitting at nodes with highest loss change.
+
+    learning_rate: typing.Optional[float]
+
         Boosting learning rate (xgb's "eta")
-
+
+    verbosity: typing.Optional[int]
+
         The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
-
-
-
-
-
-
+
+    objective: typing.Union[str, xgboost.sklearn._SklObjWProto, typing.Callable[[typing.Any, typing.Any], typing.Tuple[numpy.ndarray, numpy.ndarray]], NoneType]
+
+        Specify the learning task and the corresponding learning objective or a custom
+        objective function to be used.
+
+        For custom objective, see :doc:`/tutorials/custom_metric_obj` and
+        :ref:`custom-obj-metric` for more information, along with the end note for
+        function signatures.
+
+    booster: typing.Optional[str]
+
+        Specify which booster to use: ``gbtree``, ``gblinear`` or ``dart``.
+
+    tree_method: typing.Optional[str]
+
         Specify which tree method to use. Default to auto. If this parameter is set to
         default, XGBoost will choose the most conservative option available. It's
         recommended to study this option from the parameters document :doc:`tree method
         </treemethod>`
-
+
+    n_jobs: typing.Optional[int]
+
         Number of parallel threads used to run xgboost. When used with other
         Scikit-Learn algorithms like grid search, you may choose which algorithm to
         parallelize and balance the threads. Creating thread contention will
         significantly slow down both algorithms.
-
-
-
-
+
+    gamma: typing.Optional[float]
+
+        (min_split_loss) Minimum loss reduction required to make a further partition on
+        a leaf node of the tree.
+
+    min_child_weight: typing.Optional[float]
+
         Minimum sum of instance weight(hessian) needed in a child.
-
+
+    max_delta_step: typing.Optional[float]
+
         Maximum delta step we allow each tree's weight estimation to be.
-
+
+    subsample: typing.Optional[float]
+
         Subsample ratio of the training instance.
-
-
-
-
-
-
+
+    sampling_method: typing.Optional[str]
+
+        Sampling method. Used only by the GPU version of ``hist`` tree method.
+
+        - ``uniform``: Select random training instances uniformly.
+        - ``gradient_based``: Select random training instances with higher probability
+          when the gradient and hessian are larger. (cf. CatBoost)
+
+    colsample_bytree: typing.Optional[float]
+
         Subsample ratio of columns when constructing each tree.
-
+
+    colsample_bylevel: typing.Optional[float]
+
         Subsample ratio of columns for each level.
-
+
+    colsample_bynode: typing.Optional[float]
+
         Subsample ratio of columns for each split.
-
+
+    reg_alpha: typing.Optional[float]
+
         L1 regularization term on weights (xgb's alpha).
-
+
+    reg_lambda: typing.Optional[float]
+
         L2 regularization term on weights (xgb's lambda).
-
+
+    scale_pos_weight: typing.Optional[float]
         Balancing of positive and negative weights.
-
+
+    base_score: typing.Optional[float]
+
         The initial prediction score of all instances, global bias.
-
+
+    random_state: typing.Union[numpy.random.mtrand.RandomState, numpy.random._generator.Generator, int, NoneType]
+
         Random number seed.
 
         Using gblinear booster with shotgun updater is nondeterministic as
         it uses Hogwild algorithm.
 
-    missing: float
-
-
+    missing: float
+
+        Value in the data which needs to be present as a missing value. Default to
+        :py:data:`numpy.nan`.
+
+    num_parallel_tree: typing.Optional[int]
+
         Used for boosting random forest.
-
+
+    monotone_constraints: typing.Union[typing.Dict[str, int], str, NoneType]
+
         Constraint of variable monotonicity. See :doc:`tutorial </tutorials/monotonic>`
         for more information.
-
+
+    interaction_constraints: typing.Union[str, typing.List[typing.Tuple[str]], NoneType]
+
         Constraints for interaction representing permitted interactions. The
         constraints must be specified in the form of a nested list, e.g. ``[[0, 1], [2,
         3, 4]]``, where each inner list is a group of indices of features that are
         allowed to interact with each other. See :doc:`tutorial
         </tutorials/feature_interaction_constraint>` for more information
-
+
+    importance_type: typing.Optional[str]
+
         The feature importance type for the feature_importances\_ property:
 
         * For tree model, it's either "gain", "weight", "cover", "total_gain" or
           "total_cover".
-        * For linear model, only "weight" is defined and it's the normalized
-          without bias.
+        * For linear model, only "weight" is defined and it's the normalized
+          coefficients without bias.
+
+    device: typing.Optional[str]
+
+        Device ordinal, available options are `cpu`, `cuda`, and `gpu`.
+
+    validate_parameters: typing.Optional[bool]
 
-    gpu_id: Optional[int]
-        Device ordinal.
-    validate_parameters: Optional[bool]
         Give warnings for unknown parameter.
-
-        Force XGBoost to use specific predictor, available choices are [cpu_predictor,
-        gpu_predictor].
+
     enable_categorical: bool
 
-
-        should be used to specify categorical data type. Also, JSON/UBJSON
-        serialization format is required.
+        See the same parameter of :py:class:`DMatrix` for details.
 
-    feature_types:
+    feature_types: typing.Optional[typing.Sequence[str]]
 
         Used for specifying feature types without constructing a dataframe. See
         :py:class:`DMatrix` for details.
 
-    max_cat_to_onehot: Optional[int]
+    max_cat_to_onehot: typing.Optional[int]
 
         A threshold for deciding whether XGBoost should use one-hot encoding based split
         for categorical data. When number of categories is lesser than the threshold
@@ -238,36 +298,41 @@ class XGBRFRegressor(BaseTransformer):
         categorical feature support. See :doc:`Categorical Data
         </tutorials/categorical>` and :ref:`cat-param` for details.
 
-    max_cat_threshold: Optional[int]
+    max_cat_threshold: typing.Optional[int]
 
         Maximum number of categories considered for each split. Used only by
         partition-based splits for preventing over-fitting. Also, `enable_categorical`
         needs to be set to have categorical feature support. See :doc:`Categorical Data
         </tutorials/categorical>` and :ref:`cat-param` for details.
 
-
+    multi_strategy: typing.Optional[str]
+
+        The strategy used for training multi-target models, including multi-target
+        regression and multi-class classification. See :doc:`/tutorials/multioutput` for
+        more information.
+
+        - ``one_output_per_tree``: One model for each target.
+        - ``multi_output_tree``: Use multi-target trees.
+
+    eval_metric: typing.Union[str, typing.List[str], typing.Callable, NoneType]
 
         Metric used for monitoring the training result and early stopping. It can be a
         string or list of strings as names of predefined metric in XGBoost (See
-        doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any
-        user defined metric that looks like `sklearn.metrics`.
+        doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any
+        other user defined metric that looks like `sklearn.metrics`.
 
         If custom objective is also provided, then custom metric should implement the
         corresponding reverse link function.
 
         Unlike the `scoring` parameter commonly used in scikit-learn, when a callable
-        object is provided, it's assumed to be a cost function and by default XGBoost
-        minimize the result during early stopping.
-
-        For advanced usage on Early stopping like directly choosing to maximize instead of
-        minimize, see :py:obj:`xgboost.callback.EarlyStopping`.
+        object is provided, it's assumed to be a cost function and by default XGBoost
+        will minimize the result during early stopping.
 
-
-
+        For advanced usage on Early stopping like directly choosing to maximize instead
+        of minimize, see :py:obj:`xgboost.callback.EarlyStopping`.
 
-
-
-        being used.
+        See :doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more
+        information.
 
         from sklearn.datasets import load_diabetes
         from sklearn.metrics import mean_absolute_error
@@ -278,24 +343,29 @@ class XGBRFRegressor(BaseTransformer):
         )
         reg.fit(X, y, eval_set=[(X, y)])
 
-    early_stopping_rounds: Optional[int]
+    early_stopping_rounds: typing.Optional[int]
 
-        Activates early stopping. Validation metric needs to improve at least once in
-
-
+        - Activates early stopping. Validation metric needs to improve at least once in
+          every **early_stopping_rounds** round(s) to continue training. Requires at
+          least one item in **eval_set** in :py:meth:`fit`.
 
-
-
-
-
+        - If early stopping occurs, the model will have two additional attributes:
+          :py:attr:`best_score` and :py:attr:`best_iteration`. These are used by the
+          :py:meth:`predict` and :py:meth:`apply` methods to determine the optimal
+          number of trees during inference. If users want to access the full model
+          (including trees built after early stopping), they can specify the
+          `iteration_range` in these inference methods. In addition, other utilities
+          like model plotting can also use the entire model.
 
-        If
-
-        :py:attr:`best_ntree_limit`.
+        - If you prefer to discard the trees after `best_iteration`, consider using the
+          callback function :py:class:`xgboost.callback.EarlyStopping`.
 
-
+        - If there's more than one item in **eval_set**, the last entry will be used for
+          early stopping. If there's more than one metric in **eval_metric**, the last
+          metric will be used for early stopping.
+
+    callbacks: typing.Optional[typing.List[xgboost.callback.TrainingCallback]]
 
-    callbacks: Optional[List[TrainingCallback]]
         List of callback functions that are applied at end of each iteration.
         It is possible to use predefined callbacks by using
         :ref:`Callback API <callback_api>`.
@@ -307,9 +377,11 @@ class XGBRFRegressor(BaseTransformer):
         for params in parameters_grid:
             # be sure to (re)initialize the callbacks before each run
             callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
-            xgboost.
+            reg = xgboost.XGBRegressor(**params, callbacks=callbacks)
+            reg.fit(X, y)
+
+    kwargs: typing.Optional[typing.Any]
 
-    kwargs: dict, optional
         Keyword arguments for XGBoost Booster object. Full documentation of parameters
         can be found :doc:`here </parameter>`.
         Attempting to set a parameter via the constructor args and \*\*kwargs
@@ -320,13 +392,16 @@ class XGBRFRegressor(BaseTransformer):
         with scikit-learn.
 
         A custom objective function can be provided for the ``objective``
-        parameter. In this case, it should have the signature
-        ``objective(y_true, y_pred
+        parameter. In this case, it should have the signature ``objective(y_true,
+        y_pred) -> [grad, hess]`` or ``objective(y_true, y_pred, *, sample_weight)
+        -> [grad, hess]``:
 
         y_true: array_like of shape [n_samples]
            The target values
         y_pred: array_like of shape [n_samples]
            The predicted values
+        sample_weight :
+            Optional sample weights.
 
         grad: array_like of shape [n_samples]
            The value of the gradient for each sample point.
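The regenerated docstring above documents the custom-objective signature ``objective(y_true, y_pred) -> [grad, hess]``. A minimal sketch against the upstream xgboost estimator that this wrapper's docstring mirrors (plain squared error, whose gradient is the residual and whose hessian is constant; the dataset follows the docstring's own load_diabetes example):

    import numpy as np
    import xgboost
    from sklearn.datasets import load_diabetes

    X, y = load_diabetes(return_X_y=True)

    def squared_error(y_true: np.ndarray, y_pred: np.ndarray):
        # grad and hess of 0.5 * (y_pred - y_true) ** 2 with respect to y_pred
        return y_pred - y_true, np.ones_like(y_pred)

    reg = xgboost.XGBRFRegressor(n_estimators=50, objective=squared_error)
    reg.fit(X, y)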
snowflake/ml/monitoring/_client/model_monitor_sql_client.py

@@ -1,6 +1,4 @@
-import typing
-from collections import Counter
-from typing import Any, Dict, List, Mapping, Optional, Set
+from typing import Any, Dict, List, Mapping, Optional
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import (
@@ -10,27 +8,12 @@ from snowflake.ml._internal.utils import (
     table_manager,
 )
 from snowflake.ml.model._client.sql import _base
-from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.snowpark import session, types
 
-SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA"
-
 MODEL_JSON_COL_NAME = "model"
 MODEL_JSON_MODEL_NAME_FIELD = "model_name"
 MODEL_JSON_VERSION_NAME_FIELD = "version_name"
 
-MONITOR_NAME_COL_NAME = "MONITOR_NAME"
-SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME"
-FQ_MODEL_NAME_COL_NAME = "FULLY_QUALIFIED_MODEL_NAME"
-VERSION_NAME_COL_NAME = "MODEL_VERSION_NAME"
-FUNCTION_NAME_COL_NAME = "FUNCTION_NAME"
-TASK_COL_NAME = "TASK"
-MONITORING_ENABLED_COL_NAME = "IS_ENABLED"
-TIMESTAMP_COL_NAME_COL_NAME = "TIMESTAMP_COLUMN_NAME"
-PREDICTION_COL_NAMES_COL_NAME = "PREDICTION_COLUMN_NAMES"
-LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES"
-ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES"
-
 
 def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str:
     sql_list = ", ".join([f"'{column}'" for column in columns])
@@ -59,7 +42,7 @@ class ModelMonitorSQLClient:
     def _infer_qualified_schema(
         self, database_name: Optional[sql_identifier.SqlIdentifier], schema_name: Optional[sql_identifier.SqlIdentifier]
     ) -> str:
-        return f"{database_name or self._database_name}.{schema_name or self._schema_name}"
+        return f"""{database_name or self._database_name}.{schema_name or self._schema_name}"""
 
     def create_model_monitor(
         self,
@@ -91,17 +74,17 @@ class ModelMonitorSQLClient:
     ) -> None:
         baseline_sql = ""
         if baseline:
-            baseline_sql = f"BASELINE=
+            baseline_sql = f"""BASELINE={self._infer_qualified_schema(baseline_database, baseline_schema)}.{baseline}"""
         query_result_checker.SqlResultValidator(
             self._sql_client._session,
             f"""
             CREATE MODEL MONITOR {self._infer_qualified_schema(monitor_database, monitor_schema)}.{monitor_name}
                 WITH
-                    MODEL=
+                    MODEL={self._infer_qualified_schema(model_database, model_schema)}.{model_name}
                     VERSION='{version_name}'
                     FUNCTION='{function_name}'
                     WAREHOUSE='{warehouse_name}'
-                    SOURCE=
+                    SOURCE={self._infer_qualified_schema(source_database, source_schema)}.{source}
                    ID_COLUMNS={_build_sql_list_from_columns(id_columns)}
                    PREDICTION_SCORE_COLUMNS={_build_sql_list_from_columns(prediction_score_columns)}
                    PREDICTION_CLASS_COLUMNS={_build_sql_list_from_columns(prediction_class_columns)}
@@ -146,19 +129,6 @@ class ModelMonitorSQLClient:
             .validate()
         )
 
-    def _validate_unique_columns(
-        self,
-        timestamp_column: sql_identifier.SqlIdentifier,
-        id_columns: List[sql_identifier.SqlIdentifier],
-        prediction_columns: List[sql_identifier.SqlIdentifier],
-        label_columns: List[sql_identifier.SqlIdentifier],
-    ) -> None:
-        all_columns = [*id_columns, *prediction_columns, *label_columns, timestamp_column]
-        num_all_columns = len(all_columns)
-        num_unique_columns = len(set(all_columns))
-        if num_all_columns != num_unique_columns:
-            raise ValueError("Column names must be unique across id, timestamp, prediction, and label columns.")
-
     def validate_existence_by_name(
         self,
         *,
@@ -244,125 +214,6 @@ class ModelMonitorSQLClient:
         if not all([column_name in source_column_schema for column_name in id_columns]):
             raise ValueError(f"ID column(s): {id_columns} do not exist in source.")
 
-    def _validate_timestamp_column_type(
-        self, table_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier
-    ) -> None:
-        """Ensures columns have the same type.
-
-        Args:
-            table_schema: Dictionary of column names and types in the source table.
-            timestamp_column: Name of the timestamp column.
-
-        Raises:
-            ValueError: If the timestamp column is not of type TimestampType.
-        """
-        if not isinstance(table_schema[timestamp_column], types.TimestampType):
-            raise ValueError(
-                f"Timestamp column: {timestamp_column} must be TimestampType. "
-                f"Found: {table_schema[timestamp_column]}"
-            )
-
-    def _validate_id_columns_types(
-        self, table_schema: Mapping[str, types.DataType], id_columns: List[sql_identifier.SqlIdentifier]
-    ) -> None:
-        """Ensures id columns have the correct type.
-
-        Args:
-            table_schema: Dictionary of column names and types in the source table.
-            id_columns: List of id column names.
-
-        Raises:
-            ValueError: If the id column is not of type StringType.
-        """
-        id_column_types = list({table_schema[column_name] for column_name in id_columns})
-        all_id_columns_string = all([isinstance(column_type, types.StringType) for column_type in id_column_types])
-        if not all_id_columns_string:
-            raise ValueError(f"Id columns must all be StringType. Found: {id_column_types}")
-
-    def _validate_prediction_columns_types(
-        self, table_schema: Mapping[str, types.DataType], prediction_columns: List[sql_identifier.SqlIdentifier]
-    ) -> None:
-        """Ensures prediction columns have the same type.
-
-        Args:
-            table_schema: Dictionary of column names and types in the source table.
-            prediction_columns: List of prediction column names.
-
-        Raises:
-            ValueError: If the prediction columns do not share the same type.
-        """
-
-        prediction_column_types = {table_schema[column_name] for column_name in prediction_columns}
-        if len(prediction_column_types) > 1:
-            raise ValueError(f"Prediction column types must be the same. Found: {prediction_column_types}")
-
-    def _validate_label_columns_types(
-        self,
-        table_schema: Mapping[str, types.DataType],
-        label_columns: List[sql_identifier.SqlIdentifier],
-    ) -> None:
-        """Ensures label columns have the same type, and the correct type for the score type.
-
-        Args:
-            table_schema: Dictionary of column names and types in the source table.
-            label_columns: List of label column names.
-
-        Raises:
-            ValueError: If the label columns do not share the same type.
-        """
-        label_column_types = {table_schema[column_name] for column_name in label_columns}
-        if len(label_column_types) > 1:
-            raise ValueError(f"Label column types must be the same. Found: {label_column_types}")
-
-    def _validate_column_types(
-        self,
-        *,
-        table_schema: Mapping[str, types.DataType],
-        timestamp_column: sql_identifier.SqlIdentifier,
-        id_columns: List[sql_identifier.SqlIdentifier],
-        prediction_columns: List[sql_identifier.SqlIdentifier],
-        label_columns: List[sql_identifier.SqlIdentifier],
-    ) -> None:
-        """Ensures columns have the expected type.
-
-        Args:
-            table_schema: Dictionary of column names and types in the source table.
-            timestamp_column: Name of the timestamp column.
-            id_columns: List of id column names.
-            prediction_columns: List of prediction column names.
-            label_columns: List of label column names.
-        """
-        self._validate_timestamp_column_type(table_schema, timestamp_column)
-        self._validate_id_columns_types(table_schema, id_columns)
-        self._validate_prediction_columns_types(table_schema, prediction_columns)
-        self._validate_label_columns_types(table_schema, label_columns)
-        # TODO(SNOW-1646693): Validate label makes sense with model task
-
-    def _validate_source_table_features_shape(
-        self,
-        table_schema: Mapping[str, types.DataType],
-        special_columns: Set[sql_identifier.SqlIdentifier],
-        model_function: model_manifest_schema.ModelFunctionInfo,
-    ) -> None:
-        table_schema_without_special_columns = {
-            k: v for k, v in table_schema.items() if sql_identifier.SqlIdentifier(k) not in special_columns
-        }
-        schema_column_types_to_count: typing.Counter[types.DataType] = Counter()
-        for column_type in table_schema_without_special_columns.values():
-            schema_column_types_to_count[column_type] += 1
-
-        inputs = model_function["signature"].inputs
-        function_input_types = [input.as_snowpark_type() for input in inputs]
-        function_input_types_to_count: typing.Counter[types.DataType] = Counter()
-        for function_input_type in function_input_types:
-            function_input_types_to_count[function_input_type] += 1
-
-        if function_input_types_to_count != schema_column_types_to_count:
-            raise ValueError(
-                "Model function input types do not match the source table input columns types. "
-                f"Model function expected: {inputs} but got {table_schema_without_special_columns}"
-            )
-
     def validate_source(
         self,
         *,
@@ -395,22 +246,6 @@ class ModelMonitorSQLClient:
             id_columns=id_columns,
         )
 
-    def delete_monitor_metadata(
-        self,
-        name: str,
-        statement_params: Optional[Dict[str, Any]] = None,
-    ) -> None:
-        """Delete the row in the metadata table corresponding to the given monitor name.
-
-        Args:
-            name: Name of the model monitor whose metadata should be deleted.
-            statement_params: Optional set of statement_params to include with query.
-        """
-        self._sql_client._session.sql(
-            f"""DELETE FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
-            WHERE {MONITOR_NAME_COL_NAME} = '{name}'""",
-        ).collect(statement_params=statement_params)
-
     def _alter_monitor(
         self,
         operation: str,
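The net effect of the client changes above is that fully qualified names are interpolated directly into the triple-quoted f-string rather than built separately, and the client-side metadata table and its validation helpers are dropped in favor of the server-side MODEL MONITOR object. A minimal sketch of the resulting statement shape (the helper mirrors _build_sql_list_from_columns from the diff; database, schema, and object names are illustrative):

    from typing import List

    def build_sql_list(columns: List[str]) -> str:
        # Same rendering as _build_sql_list_from_columns: ['A', 'B'] -> "'A', 'B'"
        return ", ".join(f"'{column}'" for column in columns)

    db, schema = "ML_DB", "MONITORING"
    sql = f"""
    CREATE MODEL MONITOR {db}.{schema}.MY_MONITOR
        WITH
            MODEL={db}.{schema}.MY_MODEL
            VERSION='V1'
            FUNCTION='predict'
            WAREHOUSE='MY_WH'
            SOURCE={db}.{schema}.MY_SOURCE_TABLE
            ID_COLUMNS={build_sql_list(["ID"])}
            PREDICTION_SCORE_COLUMNS={build_sql_list(["SCORE"])}
            PREDICTION_CLASS_COLUMNS={build_sql_list(["CLASS"])}
    """
    print(sql)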
snowflake/ml/monitoring/_manager/model_monitor_manager.py

@@ -14,15 +14,6 @@ from snowflake.snowpark import session
 class ModelMonitorManager:
     """Class to manage internal operations for Model Monitor workflows."""
 
-    def _validate_task_from_model_version(
-        self,
-        model_version: model_version_impl.ModelVersion,
-    ) -> type_hints.Task:
-        task = model_version.get_model_task()
-        if task == type_hints.Task.UNKNOWN:
-            raise ValueError("Registry model must be logged with task in order to be monitored.")
-        return task
-
     def __init__(
         self,
         session: session.Session,
@@ -51,6 +42,15 @@ class ModelMonitorManager:
             schema_name=self._schema_name,
         )
 
+    def _validate_task_from_model_version(
+        self,
+        model_version: model_version_impl.ModelVersion,
+    ) -> type_hints.Task:
+        task = model_version.get_model_task()
+        if task == type_hints.Task.UNKNOWN:
+            raise ValueError("Registry model must be logged with task in order to be monitored.")
+        return task
+
     def _validate_model_function_from_model_version(
         self, function: str, model_version: model_version_impl.ModelVersion
     ) -> None: