snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class ARDRegression(BaseTransformer):
|
71
64
|
r"""Bayesian ARD regression
|
72
65
|
For more details on this class, see [sklearn.linear_model.ARDRegression]
|
@@ -258,12 +251,7 @@ class ARDRegression(BaseTransformer):
|
|
258
251
|
)
|
259
252
|
return selected_cols
|
260
253
|
|
261
|
-
|
262
|
-
project=_PROJECT,
|
263
|
-
subproject=_SUBPROJECT,
|
264
|
-
custom_tags=dict([("autogen", True)]),
|
265
|
-
)
|
266
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ARDRegression":
|
254
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ARDRegression":
|
267
255
|
"""Fit the model according to the given training data and parameters
|
268
256
|
For more details on this function, see [sklearn.linear_model.ARDRegression.fit]
|
269
257
|
(https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ARDRegression.html#sklearn.linear_model.ARDRegression.fit)
|
@@ -290,12 +278,14 @@ class ARDRegression(BaseTransformer):
|
|
290
278
|
|
291
279
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
292
280
|
|
293
|
-
|
281
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
294
282
|
if SNOWML_SPROC_ENV in os.environ:
|
295
283
|
statement_params = telemetry.get_function_usage_statement_params(
|
296
284
|
project=_PROJECT,
|
297
285
|
subproject=_SUBPROJECT,
|
298
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
286
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
287
|
+
inspect.currentframe(), ARDRegression.__class__.__name__
|
288
|
+
),
|
299
289
|
api_calls=[Session.call],
|
300
290
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
301
291
|
)
|
@@ -316,27 +306,24 @@ class ARDRegression(BaseTransformer):
|
|
316
306
|
)
|
317
307
|
self._sklearn_object = model_trainer.train()
|
318
308
|
self._is_fitted = True
|
319
|
-
self.
|
309
|
+
self._generate_model_signatures(dataset)
|
320
310
|
return self
|
321
311
|
|
322
312
|
def _batch_inference_validate_snowpark(
|
323
313
|
self,
|
324
314
|
dataset: DataFrame,
|
325
315
|
inference_method: str,
|
326
|
-
) ->
|
327
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
328
|
-
return the available package that exists in the snowflake anaconda channel
|
316
|
+
) -> None:
|
317
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
329
318
|
|
330
319
|
Args:
|
331
320
|
dataset: snowpark dataframe
|
332
321
|
inference_method: the inference method such as predict, score...
|
333
|
-
|
322
|
+
|
334
323
|
Raises:
|
335
324
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
336
325
|
SnowflakeMLException: If the session is None, raise error
|
337
326
|
|
338
|
-
Returns:
|
339
|
-
A list of available package that exists in the snowflake anaconda channel
|
340
327
|
"""
|
341
328
|
if not self._is_fitted:
|
342
329
|
raise exceptions.SnowflakeMLException(
|
@@ -354,9 +341,7 @@ class ARDRegression(BaseTransformer):
|
|
354
341
|
"Session must not specified for snowpark dataset."
|
355
342
|
),
|
356
343
|
)
|
357
|
-
|
358
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
359
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
344
|
+
|
360
345
|
|
361
346
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
362
347
|
@telemetry.send_api_usage_telemetry(
|
@@ -392,7 +377,9 @@ class ARDRegression(BaseTransformer):
|
|
392
377
|
# when it is classifier, infer the datatype from label columns
|
393
378
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
394
379
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
395
|
-
label_cols_signatures = [
|
380
|
+
label_cols_signatures = [
|
381
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
382
|
+
]
|
396
383
|
if len(label_cols_signatures) == 0:
|
397
384
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
398
385
|
raise exceptions.SnowflakeMLException(
|
@@ -400,25 +387,23 @@ class ARDRegression(BaseTransformer):
|
|
400
387
|
original_exception=ValueError(error_str),
|
401
388
|
)
|
402
389
|
|
403
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
404
|
-
label_cols_signatures[0].as_snowpark_type()
|
405
|
-
)
|
390
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
406
391
|
|
407
|
-
self.
|
408
|
-
|
392
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
393
|
+
self._deps = self._get_dependencies()
|
394
|
+
assert isinstance(
|
395
|
+
dataset._session, Session
|
396
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
409
397
|
|
410
398
|
transform_kwargs = dict(
|
411
|
-
session
|
412
|
-
dependencies
|
413
|
-
drop_input_cols
|
414
|
-
expected_output_cols_type
|
399
|
+
session=dataset._session,
|
400
|
+
dependencies=self._deps,
|
401
|
+
drop_input_cols=self._drop_input_cols,
|
402
|
+
expected_output_cols_type=expected_type_inferred,
|
415
403
|
)
|
416
404
|
|
417
405
|
elif isinstance(dataset, pd.DataFrame):
|
418
|
-
transform_kwargs = dict(
|
419
|
-
snowpark_input_cols = self._snowpark_cols,
|
420
|
-
drop_input_cols = self._drop_input_cols
|
421
|
-
)
|
406
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
422
407
|
|
423
408
|
transform_handlers = ModelTransformerBuilder.build(
|
424
409
|
dataset=dataset,
|
@@ -458,7 +443,7 @@ class ARDRegression(BaseTransformer):
|
|
458
443
|
Transformed dataset.
|
459
444
|
"""
|
460
445
|
super()._check_dataset_type(dataset)
|
461
|
-
inference_method="transform"
|
446
|
+
inference_method = "transform"
|
462
447
|
|
463
448
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
464
449
|
# are specific to the type of dataset used.
|
@@ -488,24 +473,19 @@ class ARDRegression(BaseTransformer):
|
|
488
473
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
489
474
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
490
475
|
|
491
|
-
self.
|
492
|
-
|
493
|
-
inference_method=inference_method,
|
494
|
-
)
|
476
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
477
|
+
self._deps = self._get_dependencies()
|
495
478
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
496
479
|
|
497
480
|
transform_kwargs = dict(
|
498
|
-
session
|
499
|
-
dependencies
|
500
|
-
drop_input_cols
|
501
|
-
expected_output_cols_type
|
481
|
+
session=dataset._session,
|
482
|
+
dependencies=self._deps,
|
483
|
+
drop_input_cols=self._drop_input_cols,
|
484
|
+
expected_output_cols_type=expected_dtype,
|
502
485
|
)
|
503
486
|
|
504
487
|
elif isinstance(dataset, pd.DataFrame):
|
505
|
-
transform_kwargs = dict(
|
506
|
-
snowpark_input_cols = self._snowpark_cols,
|
507
|
-
drop_input_cols = self._drop_input_cols
|
508
|
-
)
|
488
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
509
489
|
|
510
490
|
transform_handlers = ModelTransformerBuilder.build(
|
511
491
|
dataset=dataset,
|
@@ -524,7 +504,11 @@ class ARDRegression(BaseTransformer):
|
|
524
504
|
return output_df
|
525
505
|
|
526
506
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
527
|
-
def fit_predict(
|
507
|
+
def fit_predict(
|
508
|
+
self,
|
509
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
510
|
+
output_cols_prefix: str = "fit_predict_",
|
511
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
528
512
|
""" Method not supported for this class.
|
529
513
|
|
530
514
|
|
@@ -549,22 +533,104 @@ class ARDRegression(BaseTransformer):
|
|
549
533
|
)
|
550
534
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
551
535
|
drop_input_cols=self._drop_input_cols,
|
552
|
-
expected_output_cols_list=
|
536
|
+
expected_output_cols_list=(
|
537
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
538
|
+
),
|
553
539
|
)
|
554
540
|
self._sklearn_object = fitted_estimator
|
555
541
|
self._is_fitted = True
|
556
542
|
return output_result
|
557
543
|
|
544
|
+
|
545
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
546
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
547
|
+
""" Method not supported for this class.
|
548
|
+
|
558
549
|
|
559
|
-
|
560
|
-
|
561
|
-
|
550
|
+
Raises:
|
551
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
552
|
+
|
553
|
+
Args:
|
554
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
555
|
+
Snowpark or Pandas DataFrame.
|
556
|
+
output_cols_prefix: Prefix for the response columns
|
562
557
|
Returns:
|
563
558
|
Transformed dataset.
|
564
559
|
"""
|
565
|
-
self.
|
566
|
-
|
567
|
-
|
560
|
+
self._infer_input_output_cols(dataset)
|
561
|
+
super()._check_dataset_type(dataset)
|
562
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
563
|
+
estimator=self._sklearn_object,
|
564
|
+
dataset=dataset,
|
565
|
+
input_cols=self.input_cols,
|
566
|
+
label_cols=self.label_cols,
|
567
|
+
sample_weight_col=self.sample_weight_col,
|
568
|
+
autogenerated=self._autogenerated,
|
569
|
+
subproject=_SUBPROJECT,
|
570
|
+
)
|
571
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
572
|
+
drop_input_cols=self._drop_input_cols,
|
573
|
+
expected_output_cols_list=self.output_cols,
|
574
|
+
)
|
575
|
+
self._sklearn_object = fitted_estimator
|
576
|
+
self._is_fitted = True
|
577
|
+
return output_result
|
578
|
+
|
579
|
+
|
580
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
581
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
582
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
583
|
+
"""
|
584
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
585
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
586
|
+
if output_cols:
|
587
|
+
output_cols = [
|
588
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
589
|
+
for c in output_cols
|
590
|
+
]
|
591
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
592
|
+
output_cols = [output_cols_prefix]
|
593
|
+
elif self._sklearn_object is not None:
|
594
|
+
classes = self._sklearn_object.classes_
|
595
|
+
if isinstance(classes, numpy.ndarray):
|
596
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
597
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
598
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
599
|
+
output_cols = []
|
600
|
+
for i, cl in enumerate(classes):
|
601
|
+
# For binary classification, there is only one output column for each class
|
602
|
+
# ndarray as the two classes are complementary.
|
603
|
+
if len(cl) == 2:
|
604
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
605
|
+
else:
|
606
|
+
output_cols.extend([
|
607
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
608
|
+
])
|
609
|
+
else:
|
610
|
+
output_cols = []
|
611
|
+
|
612
|
+
# Make sure column names are valid snowflake identifiers.
|
613
|
+
assert output_cols is not None # Make MyPy happy
|
614
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
615
|
+
|
616
|
+
return rv
|
617
|
+
|
618
|
+
def _align_expected_output_names(
|
619
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
620
|
+
) -> List[str]:
|
621
|
+
# in case the inferred output column names dimension is different
|
622
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
623
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
624
|
+
output_df_columns = list(output_df_pd.columns)
|
625
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
626
|
+
if self.sample_weight_col:
|
627
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
628
|
+
# if the dimension of inferred output column names is correct; use it
|
629
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
630
|
+
return expected_output_cols_list
|
631
|
+
# otherwise, use the sklearn estimator's output
|
632
|
+
else:
|
633
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
568
634
|
|
569
635
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
570
636
|
@telemetry.send_api_usage_telemetry(
|
@@ -596,24 +662,26 @@ class ARDRegression(BaseTransformer):
|
|
596
662
|
# are specific to the type of dataset used.
|
597
663
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
598
664
|
|
665
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
666
|
+
|
599
667
|
if isinstance(dataset, DataFrame):
|
600
|
-
self.
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
668
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
669
|
+
self._deps = self._get_dependencies()
|
670
|
+
assert isinstance(
|
671
|
+
dataset._session, Session
|
672
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
605
673
|
transform_kwargs = dict(
|
606
674
|
session=dataset._session,
|
607
675
|
dependencies=self._deps,
|
608
|
-
drop_input_cols
|
676
|
+
drop_input_cols=self._drop_input_cols,
|
609
677
|
expected_output_cols_type="float",
|
610
678
|
)
|
679
|
+
expected_output_cols = self._align_expected_output_names(
|
680
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
681
|
+
)
|
611
682
|
|
612
683
|
elif isinstance(dataset, pd.DataFrame):
|
613
|
-
transform_kwargs = dict(
|
614
|
-
snowpark_input_cols = self._snowpark_cols,
|
615
|
-
drop_input_cols = self._drop_input_cols
|
616
|
-
)
|
684
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
617
685
|
|
618
686
|
transform_handlers = ModelTransformerBuilder.build(
|
619
687
|
dataset=dataset,
|
@@ -625,7 +693,7 @@ class ARDRegression(BaseTransformer):
|
|
625
693
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
626
694
|
inference_method=inference_method,
|
627
695
|
input_cols=self.input_cols,
|
628
|
-
expected_output_cols=
|
696
|
+
expected_output_cols=expected_output_cols,
|
629
697
|
**transform_kwargs
|
630
698
|
)
|
631
699
|
return output_df
|
@@ -655,29 +723,30 @@ class ARDRegression(BaseTransformer):
|
|
655
723
|
Output dataset with log probability of the sample for each class in the model.
|
656
724
|
"""
|
657
725
|
super()._check_dataset_type(dataset)
|
658
|
-
inference_method="predict_log_proba"
|
726
|
+
inference_method = "predict_log_proba"
|
727
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
659
728
|
|
660
729
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
661
730
|
# are specific to the type of dataset used.
|
662
731
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
663
732
|
|
664
733
|
if isinstance(dataset, DataFrame):
|
665
|
-
self.
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
734
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
735
|
+
self._deps = self._get_dependencies()
|
736
|
+
assert isinstance(
|
737
|
+
dataset._session, Session
|
738
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
670
739
|
transform_kwargs = dict(
|
671
740
|
session=dataset._session,
|
672
741
|
dependencies=self._deps,
|
673
|
-
drop_input_cols
|
742
|
+
drop_input_cols=self._drop_input_cols,
|
674
743
|
expected_output_cols_type="float",
|
675
744
|
)
|
745
|
+
expected_output_cols = self._align_expected_output_names(
|
746
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
747
|
+
)
|
676
748
|
elif isinstance(dataset, pd.DataFrame):
|
677
|
-
transform_kwargs = dict(
|
678
|
-
snowpark_input_cols = self._snowpark_cols,
|
679
|
-
drop_input_cols = self._drop_input_cols
|
680
|
-
)
|
749
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
681
750
|
|
682
751
|
transform_handlers = ModelTransformerBuilder.build(
|
683
752
|
dataset=dataset,
|
@@ -690,7 +759,7 @@ class ARDRegression(BaseTransformer):
|
|
690
759
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
691
760
|
inference_method=inference_method,
|
692
761
|
input_cols=self.input_cols,
|
693
|
-
expected_output_cols=
|
762
|
+
expected_output_cols=expected_output_cols,
|
694
763
|
**transform_kwargs
|
695
764
|
)
|
696
765
|
return output_df
|
@@ -716,30 +785,32 @@ class ARDRegression(BaseTransformer):
|
|
716
785
|
Output dataset with results of the decision function for the samples in input dataset.
|
717
786
|
"""
|
718
787
|
super()._check_dataset_type(dataset)
|
719
|
-
inference_method="decision_function"
|
788
|
+
inference_method = "decision_function"
|
720
789
|
|
721
790
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
722
791
|
# are specific to the type of dataset used.
|
723
792
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
724
793
|
|
794
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
795
|
+
|
725
796
|
if isinstance(dataset, DataFrame):
|
726
|
-
self.
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
797
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
798
|
+
self._deps = self._get_dependencies()
|
799
|
+
assert isinstance(
|
800
|
+
dataset._session, Session
|
801
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
731
802
|
transform_kwargs = dict(
|
732
803
|
session=dataset._session,
|
733
804
|
dependencies=self._deps,
|
734
|
-
drop_input_cols
|
805
|
+
drop_input_cols=self._drop_input_cols,
|
735
806
|
expected_output_cols_type="float",
|
736
807
|
)
|
808
|
+
expected_output_cols = self._align_expected_output_names(
|
809
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
810
|
+
)
|
737
811
|
|
738
812
|
elif isinstance(dataset, pd.DataFrame):
|
739
|
-
transform_kwargs = dict(
|
740
|
-
snowpark_input_cols = self._snowpark_cols,
|
741
|
-
drop_input_cols = self._drop_input_cols
|
742
|
-
)
|
813
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
743
814
|
|
744
815
|
transform_handlers = ModelTransformerBuilder.build(
|
745
816
|
dataset=dataset,
|
@@ -752,7 +823,7 @@ class ARDRegression(BaseTransformer):
|
|
752
823
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
753
824
|
inference_method=inference_method,
|
754
825
|
input_cols=self.input_cols,
|
755
|
-
expected_output_cols=
|
826
|
+
expected_output_cols=expected_output_cols,
|
756
827
|
**transform_kwargs
|
757
828
|
)
|
758
829
|
return output_df
|
@@ -781,17 +852,17 @@ class ARDRegression(BaseTransformer):
|
|
781
852
|
Output dataset with probability of the sample for each class in the model.
|
782
853
|
"""
|
783
854
|
super()._check_dataset_type(dataset)
|
784
|
-
inference_method="score_samples"
|
855
|
+
inference_method = "score_samples"
|
785
856
|
|
786
857
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
787
858
|
# are specific to the type of dataset used.
|
788
859
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
789
860
|
|
861
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
862
|
+
|
790
863
|
if isinstance(dataset, DataFrame):
|
791
|
-
self.
|
792
|
-
|
793
|
-
inference_method=inference_method,
|
794
|
-
)
|
864
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
865
|
+
self._deps = self._get_dependencies()
|
795
866
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
796
867
|
transform_kwargs = dict(
|
797
868
|
session=dataset._session,
|
@@ -799,6 +870,9 @@ class ARDRegression(BaseTransformer):
|
|
799
870
|
drop_input_cols = self._drop_input_cols,
|
800
871
|
expected_output_cols_type="float",
|
801
872
|
)
|
873
|
+
expected_output_cols = self._align_expected_output_names(
|
874
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
875
|
+
)
|
802
876
|
|
803
877
|
elif isinstance(dataset, pd.DataFrame):
|
804
878
|
transform_kwargs = dict(
|
@@ -817,7 +891,7 @@ class ARDRegression(BaseTransformer):
|
|
817
891
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
818
892
|
inference_method=inference_method,
|
819
893
|
input_cols=self.input_cols,
|
820
|
-
expected_output_cols=
|
894
|
+
expected_output_cols=expected_output_cols,
|
821
895
|
**transform_kwargs
|
822
896
|
)
|
823
897
|
return output_df
|
@@ -852,17 +926,15 @@ class ARDRegression(BaseTransformer):
|
|
852
926
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
853
927
|
|
854
928
|
if isinstance(dataset, DataFrame):
|
855
|
-
self.
|
856
|
-
|
857
|
-
inference_method="score",
|
858
|
-
)
|
929
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
930
|
+
self._deps = self._get_dependencies()
|
859
931
|
selected_cols = self._get_active_columns()
|
860
932
|
if len(selected_cols) > 0:
|
861
933
|
dataset = dataset.select(selected_cols)
|
862
934
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
863
935
|
transform_kwargs = dict(
|
864
936
|
session=dataset._session,
|
865
|
-
dependencies=
|
937
|
+
dependencies=self._deps,
|
866
938
|
score_sproc_imports=['sklearn'],
|
867
939
|
)
|
868
940
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -927,11 +999,8 @@ class ARDRegression(BaseTransformer):
|
|
927
999
|
|
928
1000
|
if isinstance(dataset, DataFrame):
|
929
1001
|
|
930
|
-
self.
|
931
|
-
|
932
|
-
inference_method=inference_method,
|
933
|
-
|
934
|
-
)
|
1002
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1003
|
+
self._deps = self._get_dependencies()
|
935
1004
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
936
1005
|
transform_kwargs = dict(
|
937
1006
|
session = dataset._session,
|
@@ -964,50 +1033,84 @@ class ARDRegression(BaseTransformer):
|
|
964
1033
|
)
|
965
1034
|
return output_df
|
966
1035
|
|
1036
|
+
|
1037
|
+
|
1038
|
+
def to_sklearn(self) -> Any:
|
1039
|
+
"""Get sklearn.linear_model.ARDRegression object.
|
1040
|
+
"""
|
1041
|
+
if self._sklearn_object is None:
|
1042
|
+
self._sklearn_object = self._create_sklearn_object()
|
1043
|
+
return self._sklearn_object
|
1044
|
+
|
1045
|
+
def to_xgboost(self) -> Any:
|
1046
|
+
raise exceptions.SnowflakeMLException(
|
1047
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1048
|
+
original_exception=AttributeError(
|
1049
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1050
|
+
"to_xgboost()",
|
1051
|
+
"to_sklearn()"
|
1052
|
+
)
|
1053
|
+
),
|
1054
|
+
)
|
1055
|
+
|
1056
|
+
def to_lightgbm(self) -> Any:
|
1057
|
+
raise exceptions.SnowflakeMLException(
|
1058
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1059
|
+
original_exception=AttributeError(
|
1060
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1061
|
+
"to_lightgbm()",
|
1062
|
+
"to_sklearn()"
|
1063
|
+
)
|
1064
|
+
),
|
1065
|
+
)
|
1066
|
+
|
1067
|
+
def _get_dependencies(self) -> List[str]:
|
1068
|
+
return self._deps
|
1069
|
+
|
967
1070
|
|
968
|
-
def
|
1071
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
969
1072
|
self._model_signature_dict = dict()
|
970
1073
|
|
971
1074
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
972
1075
|
|
973
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1076
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
974
1077
|
outputs: List[BaseFeatureSpec] = []
|
975
1078
|
if hasattr(self, "predict"):
|
976
1079
|
# keep mypy happy
|
977
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1080
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
978
1081
|
# For classifier, the type of predict is the same as the type of label
|
979
|
-
if self._sklearn_object._estimator_type ==
|
980
|
-
|
1082
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1083
|
+
# label columns is the desired type for output
|
981
1084
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
982
1085
|
# rename the output columns
|
983
1086
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
984
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
985
|
-
|
986
|
-
|
1087
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1088
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1089
|
+
)
|
987
1090
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
988
1091
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
989
|
-
# Clusterer returns int64 cluster labels.
|
1092
|
+
# Clusterer returns int64 cluster labels.
|
990
1093
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
991
1094
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
992
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
993
|
-
|
994
|
-
|
995
|
-
|
1095
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1096
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1097
|
+
)
|
1098
|
+
|
996
1099
|
# For regressor, the type of predict is float64
|
997
|
-
elif self._sklearn_object._estimator_type ==
|
1100
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
998
1101
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
999
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1102
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1103
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1104
|
+
)
|
1105
|
+
|
1003
1106
|
for prob_func in PROB_FUNCTIONS:
|
1004
1107
|
if hasattr(self, prob_func):
|
1005
1108
|
output_cols_prefix: str = f"{prob_func}_"
|
1006
1109
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1007
1110
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1008
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1009
|
-
|
1010
|
-
|
1111
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1112
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1113
|
+
)
|
1011
1114
|
|
1012
1115
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1013
1116
|
items = list(self._model_signature_dict.items())
|
@@ -1020,10 +1123,10 @@ class ARDRegression(BaseTransformer):
|
|
1020
1123
|
"""Returns model signature of current class.
|
1021
1124
|
|
1022
1125
|
Raises:
|
1023
|
-
|
1126
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1024
1127
|
|
1025
1128
|
Returns:
|
1026
|
-
Dict
|
1129
|
+
Dict with each method and its input output signature
|
1027
1130
|
"""
|
1028
1131
|
if self._model_signature_dict is None:
|
1029
1132
|
raise exceptions.SnowflakeMLException(
|
@@ -1031,35 +1134,3 @@ class ARDRegression(BaseTransformer):
|
|
1031
1134
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1032
1135
|
)
|
1033
1136
|
return self._model_signature_dict
|
1034
|
-
|
1035
|
-
def to_sklearn(self) -> Any:
|
1036
|
-
"""Get sklearn.linear_model.ARDRegression object.
|
1037
|
-
"""
|
1038
|
-
if self._sklearn_object is None:
|
1039
|
-
self._sklearn_object = self._create_sklearn_object()
|
1040
|
-
return self._sklearn_object
|
1041
|
-
|
1042
|
-
def to_xgboost(self) -> Any:
|
1043
|
-
raise exceptions.SnowflakeMLException(
|
1044
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1045
|
-
original_exception=AttributeError(
|
1046
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1047
|
-
"to_xgboost()",
|
1048
|
-
"to_sklearn()"
|
1049
|
-
)
|
1050
|
-
),
|
1051
|
-
)
|
1052
|
-
|
1053
|
-
def to_lightgbm(self) -> Any:
|
1054
|
-
raise exceptions.SnowflakeMLException(
|
1055
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1056
|
-
original_exception=AttributeError(
|
1057
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1058
|
-
"to_lightgbm()",
|
1059
|
-
"to_sklearn()"
|
1060
|
-
)
|
1061
|
-
),
|
1062
|
-
)
|
1063
|
-
|
1064
|
-
def _get_dependencies(self) -> List[str]:
|
1065
|
-
return self._deps
|