snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class TruncatedSVD(BaseTransformer):
|
71
64
|
r"""Dimensionality reduction using truncated SVD (aka LSA)
|
72
65
|
For more details on this class, see [sklearn.decomposition.TruncatedSVD]
|
@@ -240,12 +233,7 @@ class TruncatedSVD(BaseTransformer):
|
|
240
233
|
)
|
241
234
|
return selected_cols
|
242
235
|
|
243
|
-
|
244
|
-
project=_PROJECT,
|
245
|
-
subproject=_SUBPROJECT,
|
246
|
-
custom_tags=dict([("autogen", True)]),
|
247
|
-
)
|
248
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "TruncatedSVD":
|
236
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "TruncatedSVD":
|
249
237
|
"""Fit model on training data X
|
250
238
|
For more details on this function, see [sklearn.decomposition.TruncatedSVD.fit]
|
251
239
|
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.TruncatedSVD.html#sklearn.decomposition.TruncatedSVD.fit)
|
@@ -272,12 +260,14 @@ class TruncatedSVD(BaseTransformer):
|
|
272
260
|
|
273
261
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
274
262
|
|
275
|
-
|
263
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
276
264
|
if SNOWML_SPROC_ENV in os.environ:
|
277
265
|
statement_params = telemetry.get_function_usage_statement_params(
|
278
266
|
project=_PROJECT,
|
279
267
|
subproject=_SUBPROJECT,
|
280
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
268
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
269
|
+
inspect.currentframe(), TruncatedSVD.__class__.__name__
|
270
|
+
),
|
281
271
|
api_calls=[Session.call],
|
282
272
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
283
273
|
)
|
@@ -298,27 +288,24 @@ class TruncatedSVD(BaseTransformer):
|
|
298
288
|
)
|
299
289
|
self._sklearn_object = model_trainer.train()
|
300
290
|
self._is_fitted = True
|
301
|
-
self.
|
291
|
+
self._generate_model_signatures(dataset)
|
302
292
|
return self
|
303
293
|
|
304
294
|
def _batch_inference_validate_snowpark(
|
305
295
|
self,
|
306
296
|
dataset: DataFrame,
|
307
297
|
inference_method: str,
|
308
|
-
) ->
|
309
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
310
|
-
return the available package that exists in the snowflake anaconda channel
|
298
|
+
) -> None:
|
299
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
311
300
|
|
312
301
|
Args:
|
313
302
|
dataset: snowpark dataframe
|
314
303
|
inference_method: the inference method such as predict, score...
|
315
|
-
|
304
|
+
|
316
305
|
Raises:
|
317
306
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
318
307
|
SnowflakeMLException: If the session is None, raise error
|
319
308
|
|
320
|
-
Returns:
|
321
|
-
A list of available package that exists in the snowflake anaconda channel
|
322
309
|
"""
|
323
310
|
if not self._is_fitted:
|
324
311
|
raise exceptions.SnowflakeMLException(
|
@@ -336,9 +323,7 @@ class TruncatedSVD(BaseTransformer):
|
|
336
323
|
"Session must not specified for snowpark dataset."
|
337
324
|
),
|
338
325
|
)
|
339
|
-
|
340
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
341
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
326
|
+
|
342
327
|
|
343
328
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
344
329
|
@telemetry.send_api_usage_telemetry(
|
@@ -372,7 +357,9 @@ class TruncatedSVD(BaseTransformer):
|
|
372
357
|
# when it is classifier, infer the datatype from label columns
|
373
358
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
374
359
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
375
|
-
label_cols_signatures = [
|
360
|
+
label_cols_signatures = [
|
361
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
362
|
+
]
|
376
363
|
if len(label_cols_signatures) == 0:
|
377
364
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
378
365
|
raise exceptions.SnowflakeMLException(
|
@@ -380,25 +367,23 @@ class TruncatedSVD(BaseTransformer):
|
|
380
367
|
original_exception=ValueError(error_str),
|
381
368
|
)
|
382
369
|
|
383
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
384
|
-
label_cols_signatures[0].as_snowpark_type()
|
385
|
-
)
|
370
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
386
371
|
|
387
|
-
self.
|
388
|
-
|
372
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
373
|
+
self._deps = self._get_dependencies()
|
374
|
+
assert isinstance(
|
375
|
+
dataset._session, Session
|
376
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
389
377
|
|
390
378
|
transform_kwargs = dict(
|
391
|
-
session
|
392
|
-
dependencies
|
393
|
-
drop_input_cols
|
394
|
-
expected_output_cols_type
|
379
|
+
session=dataset._session,
|
380
|
+
dependencies=self._deps,
|
381
|
+
drop_input_cols=self._drop_input_cols,
|
382
|
+
expected_output_cols_type=expected_type_inferred,
|
395
383
|
)
|
396
384
|
|
397
385
|
elif isinstance(dataset, pd.DataFrame):
|
398
|
-
transform_kwargs = dict(
|
399
|
-
snowpark_input_cols = self._snowpark_cols,
|
400
|
-
drop_input_cols = self._drop_input_cols
|
401
|
-
)
|
386
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
402
387
|
|
403
388
|
transform_handlers = ModelTransformerBuilder.build(
|
404
389
|
dataset=dataset,
|
@@ -440,7 +425,7 @@ class TruncatedSVD(BaseTransformer):
|
|
440
425
|
Transformed dataset.
|
441
426
|
"""
|
442
427
|
super()._check_dataset_type(dataset)
|
443
|
-
inference_method="transform"
|
428
|
+
inference_method = "transform"
|
444
429
|
|
445
430
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
446
431
|
# are specific to the type of dataset used.
|
@@ -470,24 +455,19 @@ class TruncatedSVD(BaseTransformer):
|
|
470
455
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
471
456
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
472
457
|
|
473
|
-
self.
|
474
|
-
|
475
|
-
inference_method=inference_method,
|
476
|
-
)
|
458
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
459
|
+
self._deps = self._get_dependencies()
|
477
460
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
478
461
|
|
479
462
|
transform_kwargs = dict(
|
480
|
-
session
|
481
|
-
dependencies
|
482
|
-
drop_input_cols
|
483
|
-
expected_output_cols_type
|
463
|
+
session=dataset._session,
|
464
|
+
dependencies=self._deps,
|
465
|
+
drop_input_cols=self._drop_input_cols,
|
466
|
+
expected_output_cols_type=expected_dtype,
|
484
467
|
)
|
485
468
|
|
486
469
|
elif isinstance(dataset, pd.DataFrame):
|
487
|
-
transform_kwargs = dict(
|
488
|
-
snowpark_input_cols = self._snowpark_cols,
|
489
|
-
drop_input_cols = self._drop_input_cols
|
490
|
-
)
|
470
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
491
471
|
|
492
472
|
transform_handlers = ModelTransformerBuilder.build(
|
493
473
|
dataset=dataset,
|
@@ -506,7 +486,11 @@ class TruncatedSVD(BaseTransformer):
|
|
506
486
|
return output_df
|
507
487
|
|
508
488
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
509
|
-
def fit_predict(
|
489
|
+
def fit_predict(
|
490
|
+
self,
|
491
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
492
|
+
output_cols_prefix: str = "fit_predict_",
|
493
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
510
494
|
""" Method not supported for this class.
|
511
495
|
|
512
496
|
|
@@ -531,22 +515,106 @@ class TruncatedSVD(BaseTransformer):
|
|
531
515
|
)
|
532
516
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
533
517
|
drop_input_cols=self._drop_input_cols,
|
534
|
-
expected_output_cols_list=
|
518
|
+
expected_output_cols_list=(
|
519
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
520
|
+
),
|
535
521
|
)
|
536
522
|
self._sklearn_object = fitted_estimator
|
537
523
|
self._is_fitted = True
|
538
524
|
return output_result
|
539
525
|
|
526
|
+
|
527
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
528
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
529
|
+
""" Fit model to X and perform dimensionality reduction on X
|
530
|
+
For more details on this function, see [sklearn.decomposition.TruncatedSVD.fit_transform]
|
531
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.TruncatedSVD.html#sklearn.decomposition.TruncatedSVD.fit_transform)
|
532
|
+
|
533
|
+
|
534
|
+
Raises:
|
535
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
540
536
|
|
541
|
-
|
542
|
-
|
543
|
-
|
537
|
+
Args:
|
538
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
539
|
+
Snowpark or Pandas DataFrame.
|
540
|
+
output_cols_prefix: Prefix for the response columns
|
544
541
|
Returns:
|
545
542
|
Transformed dataset.
|
546
543
|
"""
|
547
|
-
self.
|
548
|
-
|
549
|
-
|
544
|
+
self._infer_input_output_cols(dataset)
|
545
|
+
super()._check_dataset_type(dataset)
|
546
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
547
|
+
estimator=self._sklearn_object,
|
548
|
+
dataset=dataset,
|
549
|
+
input_cols=self.input_cols,
|
550
|
+
label_cols=self.label_cols,
|
551
|
+
sample_weight_col=self.sample_weight_col,
|
552
|
+
autogenerated=self._autogenerated,
|
553
|
+
subproject=_SUBPROJECT,
|
554
|
+
)
|
555
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
556
|
+
drop_input_cols=self._drop_input_cols,
|
557
|
+
expected_output_cols_list=self.output_cols,
|
558
|
+
)
|
559
|
+
self._sklearn_object = fitted_estimator
|
560
|
+
self._is_fitted = True
|
561
|
+
return output_result
|
562
|
+
|
563
|
+
|
564
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
565
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
566
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
567
|
+
"""
|
568
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
569
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
570
|
+
if output_cols:
|
571
|
+
output_cols = [
|
572
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
573
|
+
for c in output_cols
|
574
|
+
]
|
575
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
576
|
+
output_cols = [output_cols_prefix]
|
577
|
+
elif self._sklearn_object is not None:
|
578
|
+
classes = self._sklearn_object.classes_
|
579
|
+
if isinstance(classes, numpy.ndarray):
|
580
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
581
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
582
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
583
|
+
output_cols = []
|
584
|
+
for i, cl in enumerate(classes):
|
585
|
+
# For binary classification, there is only one output column for each class
|
586
|
+
# ndarray as the two classes are complementary.
|
587
|
+
if len(cl) == 2:
|
588
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
589
|
+
else:
|
590
|
+
output_cols.extend([
|
591
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
592
|
+
])
|
593
|
+
else:
|
594
|
+
output_cols = []
|
595
|
+
|
596
|
+
# Make sure column names are valid snowflake identifiers.
|
597
|
+
assert output_cols is not None # Make MyPy happy
|
598
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
599
|
+
|
600
|
+
return rv
|
601
|
+
|
602
|
+
def _align_expected_output_names(
|
603
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
604
|
+
) -> List[str]:
|
605
|
+
# in case the inferred output column names dimension is different
|
606
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
607
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
608
|
+
output_df_columns = list(output_df_pd.columns)
|
609
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
610
|
+
if self.sample_weight_col:
|
611
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
612
|
+
# if the dimension of inferred output column names is correct; use it
|
613
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
614
|
+
return expected_output_cols_list
|
615
|
+
# otherwise, use the sklearn estimator's output
|
616
|
+
else:
|
617
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
550
618
|
|
551
619
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
552
620
|
@telemetry.send_api_usage_telemetry(
|
@@ -578,24 +646,26 @@ class TruncatedSVD(BaseTransformer):
|
|
578
646
|
# are specific to the type of dataset used.
|
579
647
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
580
648
|
|
649
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
650
|
+
|
581
651
|
if isinstance(dataset, DataFrame):
|
582
|
-
self.
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
652
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
653
|
+
self._deps = self._get_dependencies()
|
654
|
+
assert isinstance(
|
655
|
+
dataset._session, Session
|
656
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
587
657
|
transform_kwargs = dict(
|
588
658
|
session=dataset._session,
|
589
659
|
dependencies=self._deps,
|
590
|
-
drop_input_cols
|
660
|
+
drop_input_cols=self._drop_input_cols,
|
591
661
|
expected_output_cols_type="float",
|
592
662
|
)
|
663
|
+
expected_output_cols = self._align_expected_output_names(
|
664
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
665
|
+
)
|
593
666
|
|
594
667
|
elif isinstance(dataset, pd.DataFrame):
|
595
|
-
transform_kwargs = dict(
|
596
|
-
snowpark_input_cols = self._snowpark_cols,
|
597
|
-
drop_input_cols = self._drop_input_cols
|
598
|
-
)
|
668
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
599
669
|
|
600
670
|
transform_handlers = ModelTransformerBuilder.build(
|
601
671
|
dataset=dataset,
|
@@ -607,7 +677,7 @@ class TruncatedSVD(BaseTransformer):
|
|
607
677
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
608
678
|
inference_method=inference_method,
|
609
679
|
input_cols=self.input_cols,
|
610
|
-
expected_output_cols=
|
680
|
+
expected_output_cols=expected_output_cols,
|
611
681
|
**transform_kwargs
|
612
682
|
)
|
613
683
|
return output_df
|
@@ -637,29 +707,30 @@ class TruncatedSVD(BaseTransformer):
|
|
637
707
|
Output dataset with log probability of the sample for each class in the model.
|
638
708
|
"""
|
639
709
|
super()._check_dataset_type(dataset)
|
640
|
-
inference_method="predict_log_proba"
|
710
|
+
inference_method = "predict_log_proba"
|
711
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
641
712
|
|
642
713
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
643
714
|
# are specific to the type of dataset used.
|
644
715
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
645
716
|
|
646
717
|
if isinstance(dataset, DataFrame):
|
647
|
-
self.
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
718
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
719
|
+
self._deps = self._get_dependencies()
|
720
|
+
assert isinstance(
|
721
|
+
dataset._session, Session
|
722
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
652
723
|
transform_kwargs = dict(
|
653
724
|
session=dataset._session,
|
654
725
|
dependencies=self._deps,
|
655
|
-
drop_input_cols
|
726
|
+
drop_input_cols=self._drop_input_cols,
|
656
727
|
expected_output_cols_type="float",
|
657
728
|
)
|
729
|
+
expected_output_cols = self._align_expected_output_names(
|
730
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
731
|
+
)
|
658
732
|
elif isinstance(dataset, pd.DataFrame):
|
659
|
-
transform_kwargs = dict(
|
660
|
-
snowpark_input_cols = self._snowpark_cols,
|
661
|
-
drop_input_cols = self._drop_input_cols
|
662
|
-
)
|
733
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
663
734
|
|
664
735
|
transform_handlers = ModelTransformerBuilder.build(
|
665
736
|
dataset=dataset,
|
@@ -672,7 +743,7 @@ class TruncatedSVD(BaseTransformer):
|
|
672
743
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
673
744
|
inference_method=inference_method,
|
674
745
|
input_cols=self.input_cols,
|
675
|
-
expected_output_cols=
|
746
|
+
expected_output_cols=expected_output_cols,
|
676
747
|
**transform_kwargs
|
677
748
|
)
|
678
749
|
return output_df
|
@@ -698,30 +769,32 @@ class TruncatedSVD(BaseTransformer):
|
|
698
769
|
Output dataset with results of the decision function for the samples in input dataset.
|
699
770
|
"""
|
700
771
|
super()._check_dataset_type(dataset)
|
701
|
-
inference_method="decision_function"
|
772
|
+
inference_method = "decision_function"
|
702
773
|
|
703
774
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
704
775
|
# are specific to the type of dataset used.
|
705
776
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
706
777
|
|
778
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
779
|
+
|
707
780
|
if isinstance(dataset, DataFrame):
|
708
|
-
self.
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
781
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
782
|
+
self._deps = self._get_dependencies()
|
783
|
+
assert isinstance(
|
784
|
+
dataset._session, Session
|
785
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
713
786
|
transform_kwargs = dict(
|
714
787
|
session=dataset._session,
|
715
788
|
dependencies=self._deps,
|
716
|
-
drop_input_cols
|
789
|
+
drop_input_cols=self._drop_input_cols,
|
717
790
|
expected_output_cols_type="float",
|
718
791
|
)
|
792
|
+
expected_output_cols = self._align_expected_output_names(
|
793
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
794
|
+
)
|
719
795
|
|
720
796
|
elif isinstance(dataset, pd.DataFrame):
|
721
|
-
transform_kwargs = dict(
|
722
|
-
snowpark_input_cols = self._snowpark_cols,
|
723
|
-
drop_input_cols = self._drop_input_cols
|
724
|
-
)
|
797
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
725
798
|
|
726
799
|
transform_handlers = ModelTransformerBuilder.build(
|
727
800
|
dataset=dataset,
|
@@ -734,7 +807,7 @@ class TruncatedSVD(BaseTransformer):
|
|
734
807
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
735
808
|
inference_method=inference_method,
|
736
809
|
input_cols=self.input_cols,
|
737
|
-
expected_output_cols=
|
810
|
+
expected_output_cols=expected_output_cols,
|
738
811
|
**transform_kwargs
|
739
812
|
)
|
740
813
|
return output_df
|
@@ -763,17 +836,17 @@ class TruncatedSVD(BaseTransformer):
|
|
763
836
|
Output dataset with probability of the sample for each class in the model.
|
764
837
|
"""
|
765
838
|
super()._check_dataset_type(dataset)
|
766
|
-
inference_method="score_samples"
|
839
|
+
inference_method = "score_samples"
|
767
840
|
|
768
841
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
769
842
|
# are specific to the type of dataset used.
|
770
843
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
771
844
|
|
845
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
846
|
+
|
772
847
|
if isinstance(dataset, DataFrame):
|
773
|
-
self.
|
774
|
-
|
775
|
-
inference_method=inference_method,
|
776
|
-
)
|
848
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
849
|
+
self._deps = self._get_dependencies()
|
777
850
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
778
851
|
transform_kwargs = dict(
|
779
852
|
session=dataset._session,
|
@@ -781,6 +854,9 @@ class TruncatedSVD(BaseTransformer):
|
|
781
854
|
drop_input_cols = self._drop_input_cols,
|
782
855
|
expected_output_cols_type="float",
|
783
856
|
)
|
857
|
+
expected_output_cols = self._align_expected_output_names(
|
858
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
859
|
+
)
|
784
860
|
|
785
861
|
elif isinstance(dataset, pd.DataFrame):
|
786
862
|
transform_kwargs = dict(
|
@@ -799,7 +875,7 @@ class TruncatedSVD(BaseTransformer):
|
|
799
875
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
800
876
|
inference_method=inference_method,
|
801
877
|
input_cols=self.input_cols,
|
802
|
-
expected_output_cols=
|
878
|
+
expected_output_cols=expected_output_cols,
|
803
879
|
**transform_kwargs
|
804
880
|
)
|
805
881
|
return output_df
|
@@ -832,17 +908,15 @@ class TruncatedSVD(BaseTransformer):
|
|
832
908
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
833
909
|
|
834
910
|
if isinstance(dataset, DataFrame):
|
835
|
-
self.
|
836
|
-
|
837
|
-
inference_method="score",
|
838
|
-
)
|
911
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
912
|
+
self._deps = self._get_dependencies()
|
839
913
|
selected_cols = self._get_active_columns()
|
840
914
|
if len(selected_cols) > 0:
|
841
915
|
dataset = dataset.select(selected_cols)
|
842
916
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
843
917
|
transform_kwargs = dict(
|
844
918
|
session=dataset._session,
|
845
|
-
dependencies=
|
919
|
+
dependencies=self._deps,
|
846
920
|
score_sproc_imports=['sklearn'],
|
847
921
|
)
|
848
922
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -907,11 +981,8 @@ class TruncatedSVD(BaseTransformer):
|
|
907
981
|
|
908
982
|
if isinstance(dataset, DataFrame):
|
909
983
|
|
910
|
-
self.
|
911
|
-
|
912
|
-
inference_method=inference_method,
|
913
|
-
|
914
|
-
)
|
984
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
985
|
+
self._deps = self._get_dependencies()
|
915
986
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
916
987
|
transform_kwargs = dict(
|
917
988
|
session = dataset._session,
|
@@ -944,50 +1015,84 @@ class TruncatedSVD(BaseTransformer):
|
|
944
1015
|
)
|
945
1016
|
return output_df
|
946
1017
|
|
1018
|
+
|
1019
|
+
|
1020
|
+
def to_sklearn(self) -> Any:
|
1021
|
+
"""Get sklearn.decomposition.TruncatedSVD object.
|
1022
|
+
"""
|
1023
|
+
if self._sklearn_object is None:
|
1024
|
+
self._sklearn_object = self._create_sklearn_object()
|
1025
|
+
return self._sklearn_object
|
1026
|
+
|
1027
|
+
def to_xgboost(self) -> Any:
|
1028
|
+
raise exceptions.SnowflakeMLException(
|
1029
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1030
|
+
original_exception=AttributeError(
|
1031
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1032
|
+
"to_xgboost()",
|
1033
|
+
"to_sklearn()"
|
1034
|
+
)
|
1035
|
+
),
|
1036
|
+
)
|
947
1037
|
|
948
|
-
def
|
1038
|
+
def to_lightgbm(self) -> Any:
|
1039
|
+
raise exceptions.SnowflakeMLException(
|
1040
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1041
|
+
original_exception=AttributeError(
|
1042
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1043
|
+
"to_lightgbm()",
|
1044
|
+
"to_sklearn()"
|
1045
|
+
)
|
1046
|
+
),
|
1047
|
+
)
|
1048
|
+
|
1049
|
+
def _get_dependencies(self) -> List[str]:
|
1050
|
+
return self._deps
|
1051
|
+
|
1052
|
+
|
1053
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
949
1054
|
self._model_signature_dict = dict()
|
950
1055
|
|
951
1056
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
952
1057
|
|
953
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1058
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
954
1059
|
outputs: List[BaseFeatureSpec] = []
|
955
1060
|
if hasattr(self, "predict"):
|
956
1061
|
# keep mypy happy
|
957
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1062
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
958
1063
|
# For classifier, the type of predict is the same as the type of label
|
959
|
-
if self._sklearn_object._estimator_type ==
|
960
|
-
|
1064
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1065
|
+
# label columns is the desired type for output
|
961
1066
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
962
1067
|
# rename the output columns
|
963
1068
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
964
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
965
|
-
|
966
|
-
|
1069
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1070
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1071
|
+
)
|
967
1072
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
968
1073
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
969
|
-
# Clusterer returns int64 cluster labels.
|
1074
|
+
# Clusterer returns int64 cluster labels.
|
970
1075
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
971
1076
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
972
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
973
|
-
|
974
|
-
|
975
|
-
|
1077
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1078
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1079
|
+
)
|
1080
|
+
|
976
1081
|
# For regressor, the type of predict is float64
|
977
|
-
elif self._sklearn_object._estimator_type ==
|
1082
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
978
1083
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
979
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
980
|
-
|
981
|
-
|
982
|
-
|
1084
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1085
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1086
|
+
)
|
1087
|
+
|
983
1088
|
for prob_func in PROB_FUNCTIONS:
|
984
1089
|
if hasattr(self, prob_func):
|
985
1090
|
output_cols_prefix: str = f"{prob_func}_"
|
986
1091
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
987
1092
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
988
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
989
|
-
|
990
|
-
|
1093
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1094
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1095
|
+
)
|
991
1096
|
|
992
1097
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
993
1098
|
items = list(self._model_signature_dict.items())
|
@@ -1000,10 +1105,10 @@ class TruncatedSVD(BaseTransformer):
|
|
1000
1105
|
"""Returns model signature of current class.
|
1001
1106
|
|
1002
1107
|
Raises:
|
1003
|
-
|
1108
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1004
1109
|
|
1005
1110
|
Returns:
|
1006
|
-
Dict
|
1111
|
+
Dict with each method and its input output signature
|
1007
1112
|
"""
|
1008
1113
|
if self._model_signature_dict is None:
|
1009
1114
|
raise exceptions.SnowflakeMLException(
|
@@ -1011,35 +1116,3 @@ class TruncatedSVD(BaseTransformer):
|
|
1011
1116
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1012
1117
|
)
|
1013
1118
|
return self._model_signature_dict
|
1014
|
-
|
1015
|
-
def to_sklearn(self) -> Any:
|
1016
|
-
"""Get sklearn.decomposition.TruncatedSVD object.
|
1017
|
-
"""
|
1018
|
-
if self._sklearn_object is None:
|
1019
|
-
self._sklearn_object = self._create_sklearn_object()
|
1020
|
-
return self._sklearn_object
|
1021
|
-
|
1022
|
-
def to_xgboost(self) -> Any:
|
1023
|
-
raise exceptions.SnowflakeMLException(
|
1024
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1025
|
-
original_exception=AttributeError(
|
1026
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1027
|
-
"to_xgboost()",
|
1028
|
-
"to_sklearn()"
|
1029
|
-
)
|
1030
|
-
),
|
1031
|
-
)
|
1032
|
-
|
1033
|
-
def to_lightgbm(self) -> Any:
|
1034
|
-
raise exceptions.SnowflakeMLException(
|
1035
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1036
|
-
original_exception=AttributeError(
|
1037
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1038
|
-
"to_lightgbm()",
|
1039
|
-
"to_sklearn()"
|
1040
|
-
)
|
1041
|
-
),
|
1042
|
-
)
|
1043
|
-
|
1044
|
-
def _get_dependencies(self) -> List[str]:
|
1045
|
-
return self._deps
|