snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class FactorAnalysis(BaseTransformer):
|
71
64
|
r"""Factor Analysis (FA)
|
72
65
|
For more details on this class, see [sklearn.decomposition.FactorAnalysis]
|
@@ -251,12 +244,7 @@ class FactorAnalysis(BaseTransformer):
|
|
251
244
|
)
|
252
245
|
return selected_cols
|
253
246
|
|
254
|
-
|
255
|
-
project=_PROJECT,
|
256
|
-
subproject=_SUBPROJECT,
|
257
|
-
custom_tags=dict([("autogen", True)]),
|
258
|
-
)
|
259
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "FactorAnalysis":
|
247
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "FactorAnalysis":
|
260
248
|
"""Fit the FactorAnalysis model to X using SVD based approach
|
261
249
|
For more details on this function, see [sklearn.decomposition.FactorAnalysis.fit]
|
262
250
|
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.FactorAnalysis.html#sklearn.decomposition.FactorAnalysis.fit)
|
@@ -283,12 +271,14 @@ class FactorAnalysis(BaseTransformer):
|
|
283
271
|
|
284
272
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
285
273
|
|
286
|
-
|
274
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
287
275
|
if SNOWML_SPROC_ENV in os.environ:
|
288
276
|
statement_params = telemetry.get_function_usage_statement_params(
|
289
277
|
project=_PROJECT,
|
290
278
|
subproject=_SUBPROJECT,
|
291
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
279
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
280
|
+
inspect.currentframe(), FactorAnalysis.__class__.__name__
|
281
|
+
),
|
292
282
|
api_calls=[Session.call],
|
293
283
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
294
284
|
)
|
@@ -309,27 +299,24 @@ class FactorAnalysis(BaseTransformer):
|
|
309
299
|
)
|
310
300
|
self._sklearn_object = model_trainer.train()
|
311
301
|
self._is_fitted = True
|
312
|
-
self.
|
302
|
+
self._generate_model_signatures(dataset)
|
313
303
|
return self
|
314
304
|
|
315
305
|
def _batch_inference_validate_snowpark(
|
316
306
|
self,
|
317
307
|
dataset: DataFrame,
|
318
308
|
inference_method: str,
|
319
|
-
) ->
|
320
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
321
|
-
return the available package that exists in the snowflake anaconda channel
|
309
|
+
) -> None:
|
310
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
322
311
|
|
323
312
|
Args:
|
324
313
|
dataset: snowpark dataframe
|
325
314
|
inference_method: the inference method such as predict, score...
|
326
|
-
|
315
|
+
|
327
316
|
Raises:
|
328
317
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
329
318
|
SnowflakeMLException: If the session is None, raise error
|
330
319
|
|
331
|
-
Returns:
|
332
|
-
A list of available package that exists in the snowflake anaconda channel
|
333
320
|
"""
|
334
321
|
if not self._is_fitted:
|
335
322
|
raise exceptions.SnowflakeMLException(
|
@@ -347,9 +334,7 @@ class FactorAnalysis(BaseTransformer):
|
|
347
334
|
"Session must not specified for snowpark dataset."
|
348
335
|
),
|
349
336
|
)
|
350
|
-
|
351
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
352
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
337
|
+
|
353
338
|
|
354
339
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
355
340
|
@telemetry.send_api_usage_telemetry(
|
@@ -383,7 +368,9 @@ class FactorAnalysis(BaseTransformer):
|
|
383
368
|
# when it is classifier, infer the datatype from label columns
|
384
369
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
385
370
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
386
|
-
label_cols_signatures = [
|
371
|
+
label_cols_signatures = [
|
372
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
373
|
+
]
|
387
374
|
if len(label_cols_signatures) == 0:
|
388
375
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
389
376
|
raise exceptions.SnowflakeMLException(
|
@@ -391,25 +378,23 @@ class FactorAnalysis(BaseTransformer):
|
|
391
378
|
original_exception=ValueError(error_str),
|
392
379
|
)
|
393
380
|
|
394
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
395
|
-
label_cols_signatures[0].as_snowpark_type()
|
396
|
-
)
|
381
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
397
382
|
|
398
|
-
self.
|
399
|
-
|
383
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
384
|
+
self._deps = self._get_dependencies()
|
385
|
+
assert isinstance(
|
386
|
+
dataset._session, Session
|
387
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
400
388
|
|
401
389
|
transform_kwargs = dict(
|
402
|
-
session
|
403
|
-
dependencies
|
404
|
-
drop_input_cols
|
405
|
-
expected_output_cols_type
|
390
|
+
session=dataset._session,
|
391
|
+
dependencies=self._deps,
|
392
|
+
drop_input_cols=self._drop_input_cols,
|
393
|
+
expected_output_cols_type=expected_type_inferred,
|
406
394
|
)
|
407
395
|
|
408
396
|
elif isinstance(dataset, pd.DataFrame):
|
409
|
-
transform_kwargs = dict(
|
410
|
-
snowpark_input_cols = self._snowpark_cols,
|
411
|
-
drop_input_cols = self._drop_input_cols
|
412
|
-
)
|
397
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
413
398
|
|
414
399
|
transform_handlers = ModelTransformerBuilder.build(
|
415
400
|
dataset=dataset,
|
@@ -451,7 +436,7 @@ class FactorAnalysis(BaseTransformer):
|
|
451
436
|
Transformed dataset.
|
452
437
|
"""
|
453
438
|
super()._check_dataset_type(dataset)
|
454
|
-
inference_method="transform"
|
439
|
+
inference_method = "transform"
|
455
440
|
|
456
441
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
457
442
|
# are specific to the type of dataset used.
|
@@ -481,24 +466,19 @@ class FactorAnalysis(BaseTransformer):
|
|
481
466
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
482
467
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
483
468
|
|
484
|
-
self.
|
485
|
-
|
486
|
-
inference_method=inference_method,
|
487
|
-
)
|
469
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
470
|
+
self._deps = self._get_dependencies()
|
488
471
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
489
472
|
|
490
473
|
transform_kwargs = dict(
|
491
|
-
session
|
492
|
-
dependencies
|
493
|
-
drop_input_cols
|
494
|
-
expected_output_cols_type
|
474
|
+
session=dataset._session,
|
475
|
+
dependencies=self._deps,
|
476
|
+
drop_input_cols=self._drop_input_cols,
|
477
|
+
expected_output_cols_type=expected_dtype,
|
495
478
|
)
|
496
479
|
|
497
480
|
elif isinstance(dataset, pd.DataFrame):
|
498
|
-
transform_kwargs = dict(
|
499
|
-
snowpark_input_cols = self._snowpark_cols,
|
500
|
-
drop_input_cols = self._drop_input_cols
|
501
|
-
)
|
481
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
502
482
|
|
503
483
|
transform_handlers = ModelTransformerBuilder.build(
|
504
484
|
dataset=dataset,
|
@@ -517,7 +497,11 @@ class FactorAnalysis(BaseTransformer):
|
|
517
497
|
return output_df
|
518
498
|
|
519
499
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
520
|
-
def fit_predict(
|
500
|
+
def fit_predict(
|
501
|
+
self,
|
502
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
503
|
+
output_cols_prefix: str = "fit_predict_",
|
504
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
521
505
|
""" Method not supported for this class.
|
522
506
|
|
523
507
|
|
@@ -542,22 +526,106 @@ class FactorAnalysis(BaseTransformer):
|
|
542
526
|
)
|
543
527
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
544
528
|
drop_input_cols=self._drop_input_cols,
|
545
|
-
expected_output_cols_list=
|
529
|
+
expected_output_cols_list=(
|
530
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
531
|
+
),
|
546
532
|
)
|
547
533
|
self._sklearn_object = fitted_estimator
|
548
534
|
self._is_fitted = True
|
549
535
|
return output_result
|
550
536
|
|
537
|
+
|
538
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
539
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
540
|
+
""" Fit to data, then transform it
|
541
|
+
For more details on this function, see [sklearn.decomposition.FactorAnalysis.fit_transform]
|
542
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.FactorAnalysis.html#sklearn.decomposition.FactorAnalysis.fit_transform)
|
543
|
+
|
544
|
+
|
545
|
+
Raises:
|
546
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
551
547
|
|
552
|
-
|
553
|
-
|
554
|
-
|
548
|
+
Args:
|
549
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
550
|
+
Snowpark or Pandas DataFrame.
|
551
|
+
output_cols_prefix: Prefix for the response columns
|
555
552
|
Returns:
|
556
553
|
Transformed dataset.
|
557
554
|
"""
|
558
|
-
self.
|
559
|
-
|
560
|
-
|
555
|
+
self._infer_input_output_cols(dataset)
|
556
|
+
super()._check_dataset_type(dataset)
|
557
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
558
|
+
estimator=self._sklearn_object,
|
559
|
+
dataset=dataset,
|
560
|
+
input_cols=self.input_cols,
|
561
|
+
label_cols=self.label_cols,
|
562
|
+
sample_weight_col=self.sample_weight_col,
|
563
|
+
autogenerated=self._autogenerated,
|
564
|
+
subproject=_SUBPROJECT,
|
565
|
+
)
|
566
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
567
|
+
drop_input_cols=self._drop_input_cols,
|
568
|
+
expected_output_cols_list=self.output_cols,
|
569
|
+
)
|
570
|
+
self._sklearn_object = fitted_estimator
|
571
|
+
self._is_fitted = True
|
572
|
+
return output_result
|
573
|
+
|
574
|
+
|
575
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
576
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
577
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
578
|
+
"""
|
579
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
580
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
581
|
+
if output_cols:
|
582
|
+
output_cols = [
|
583
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
584
|
+
for c in output_cols
|
585
|
+
]
|
586
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
587
|
+
output_cols = [output_cols_prefix]
|
588
|
+
elif self._sklearn_object is not None:
|
589
|
+
classes = self._sklearn_object.classes_
|
590
|
+
if isinstance(classes, numpy.ndarray):
|
591
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
592
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
593
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
594
|
+
output_cols = []
|
595
|
+
for i, cl in enumerate(classes):
|
596
|
+
# For binary classification, there is only one output column for each class
|
597
|
+
# ndarray as the two classes are complementary.
|
598
|
+
if len(cl) == 2:
|
599
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
600
|
+
else:
|
601
|
+
output_cols.extend([
|
602
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
603
|
+
])
|
604
|
+
else:
|
605
|
+
output_cols = []
|
606
|
+
|
607
|
+
# Make sure column names are valid snowflake identifiers.
|
608
|
+
assert output_cols is not None # Make MyPy happy
|
609
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
610
|
+
|
611
|
+
return rv
|
612
|
+
|
613
|
+
def _align_expected_output_names(
|
614
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
615
|
+
) -> List[str]:
|
616
|
+
# in case the inferred output column names dimension is different
|
617
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
618
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
619
|
+
output_df_columns = list(output_df_pd.columns)
|
620
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
621
|
+
if self.sample_weight_col:
|
622
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
623
|
+
# if the dimension of inferred output column names is correct; use it
|
624
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
625
|
+
return expected_output_cols_list
|
626
|
+
# otherwise, use the sklearn estimator's output
|
627
|
+
else:
|
628
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
561
629
|
|
562
630
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
563
631
|
@telemetry.send_api_usage_telemetry(
|
@@ -589,24 +657,26 @@ class FactorAnalysis(BaseTransformer):
|
|
589
657
|
# are specific to the type of dataset used.
|
590
658
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
591
659
|
|
660
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
661
|
+
|
592
662
|
if isinstance(dataset, DataFrame):
|
593
|
-
self.
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
663
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
664
|
+
self._deps = self._get_dependencies()
|
665
|
+
assert isinstance(
|
666
|
+
dataset._session, Session
|
667
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
598
668
|
transform_kwargs = dict(
|
599
669
|
session=dataset._session,
|
600
670
|
dependencies=self._deps,
|
601
|
-
drop_input_cols
|
671
|
+
drop_input_cols=self._drop_input_cols,
|
602
672
|
expected_output_cols_type="float",
|
603
673
|
)
|
674
|
+
expected_output_cols = self._align_expected_output_names(
|
675
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
676
|
+
)
|
604
677
|
|
605
678
|
elif isinstance(dataset, pd.DataFrame):
|
606
|
-
transform_kwargs = dict(
|
607
|
-
snowpark_input_cols = self._snowpark_cols,
|
608
|
-
drop_input_cols = self._drop_input_cols
|
609
|
-
)
|
679
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
610
680
|
|
611
681
|
transform_handlers = ModelTransformerBuilder.build(
|
612
682
|
dataset=dataset,
|
@@ -618,7 +688,7 @@ class FactorAnalysis(BaseTransformer):
|
|
618
688
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
619
689
|
inference_method=inference_method,
|
620
690
|
input_cols=self.input_cols,
|
621
|
-
expected_output_cols=
|
691
|
+
expected_output_cols=expected_output_cols,
|
622
692
|
**transform_kwargs
|
623
693
|
)
|
624
694
|
return output_df
|
@@ -648,29 +718,30 @@ class FactorAnalysis(BaseTransformer):
|
|
648
718
|
Output dataset with log probability of the sample for each class in the model.
|
649
719
|
"""
|
650
720
|
super()._check_dataset_type(dataset)
|
651
|
-
inference_method="predict_log_proba"
|
721
|
+
inference_method = "predict_log_proba"
|
722
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
652
723
|
|
653
724
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
654
725
|
# are specific to the type of dataset used.
|
655
726
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
656
727
|
|
657
728
|
if isinstance(dataset, DataFrame):
|
658
|
-
self.
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
729
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
730
|
+
self._deps = self._get_dependencies()
|
731
|
+
assert isinstance(
|
732
|
+
dataset._session, Session
|
733
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
663
734
|
transform_kwargs = dict(
|
664
735
|
session=dataset._session,
|
665
736
|
dependencies=self._deps,
|
666
|
-
drop_input_cols
|
737
|
+
drop_input_cols=self._drop_input_cols,
|
667
738
|
expected_output_cols_type="float",
|
668
739
|
)
|
740
|
+
expected_output_cols = self._align_expected_output_names(
|
741
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
742
|
+
)
|
669
743
|
elif isinstance(dataset, pd.DataFrame):
|
670
|
-
transform_kwargs = dict(
|
671
|
-
snowpark_input_cols = self._snowpark_cols,
|
672
|
-
drop_input_cols = self._drop_input_cols
|
673
|
-
)
|
744
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
674
745
|
|
675
746
|
transform_handlers = ModelTransformerBuilder.build(
|
676
747
|
dataset=dataset,
|
@@ -683,7 +754,7 @@ class FactorAnalysis(BaseTransformer):
|
|
683
754
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
684
755
|
inference_method=inference_method,
|
685
756
|
input_cols=self.input_cols,
|
686
|
-
expected_output_cols=
|
757
|
+
expected_output_cols=expected_output_cols,
|
687
758
|
**transform_kwargs
|
688
759
|
)
|
689
760
|
return output_df
|
@@ -709,30 +780,32 @@ class FactorAnalysis(BaseTransformer):
|
|
709
780
|
Output dataset with results of the decision function for the samples in input dataset.
|
710
781
|
"""
|
711
782
|
super()._check_dataset_type(dataset)
|
712
|
-
inference_method="decision_function"
|
783
|
+
inference_method = "decision_function"
|
713
784
|
|
714
785
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
715
786
|
# are specific to the type of dataset used.
|
716
787
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
717
788
|
|
789
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
790
|
+
|
718
791
|
if isinstance(dataset, DataFrame):
|
719
|
-
self.
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
792
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
793
|
+
self._deps = self._get_dependencies()
|
794
|
+
assert isinstance(
|
795
|
+
dataset._session, Session
|
796
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
724
797
|
transform_kwargs = dict(
|
725
798
|
session=dataset._session,
|
726
799
|
dependencies=self._deps,
|
727
|
-
drop_input_cols
|
800
|
+
drop_input_cols=self._drop_input_cols,
|
728
801
|
expected_output_cols_type="float",
|
729
802
|
)
|
803
|
+
expected_output_cols = self._align_expected_output_names(
|
804
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
805
|
+
)
|
730
806
|
|
731
807
|
elif isinstance(dataset, pd.DataFrame):
|
732
|
-
transform_kwargs = dict(
|
733
|
-
snowpark_input_cols = self._snowpark_cols,
|
734
|
-
drop_input_cols = self._drop_input_cols
|
735
|
-
)
|
808
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
736
809
|
|
737
810
|
transform_handlers = ModelTransformerBuilder.build(
|
738
811
|
dataset=dataset,
|
@@ -745,7 +818,7 @@ class FactorAnalysis(BaseTransformer):
|
|
745
818
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
746
819
|
inference_method=inference_method,
|
747
820
|
input_cols=self.input_cols,
|
748
|
-
expected_output_cols=
|
821
|
+
expected_output_cols=expected_output_cols,
|
749
822
|
**transform_kwargs
|
750
823
|
)
|
751
824
|
return output_df
|
@@ -776,17 +849,17 @@ class FactorAnalysis(BaseTransformer):
|
|
776
849
|
Output dataset with probability of the sample for each class in the model.
|
777
850
|
"""
|
778
851
|
super()._check_dataset_type(dataset)
|
779
|
-
inference_method="score_samples"
|
852
|
+
inference_method = "score_samples"
|
780
853
|
|
781
854
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
782
855
|
# are specific to the type of dataset used.
|
783
856
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
784
857
|
|
858
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
859
|
+
|
785
860
|
if isinstance(dataset, DataFrame):
|
786
|
-
self.
|
787
|
-
|
788
|
-
inference_method=inference_method,
|
789
|
-
)
|
861
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
862
|
+
self._deps = self._get_dependencies()
|
790
863
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
791
864
|
transform_kwargs = dict(
|
792
865
|
session=dataset._session,
|
@@ -794,6 +867,9 @@ class FactorAnalysis(BaseTransformer):
|
|
794
867
|
drop_input_cols = self._drop_input_cols,
|
795
868
|
expected_output_cols_type="float",
|
796
869
|
)
|
870
|
+
expected_output_cols = self._align_expected_output_names(
|
871
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
872
|
+
)
|
797
873
|
|
798
874
|
elif isinstance(dataset, pd.DataFrame):
|
799
875
|
transform_kwargs = dict(
|
@@ -812,7 +888,7 @@ class FactorAnalysis(BaseTransformer):
|
|
812
888
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
813
889
|
inference_method=inference_method,
|
814
890
|
input_cols=self.input_cols,
|
815
|
-
expected_output_cols=
|
891
|
+
expected_output_cols=expected_output_cols,
|
816
892
|
**transform_kwargs
|
817
893
|
)
|
818
894
|
return output_df
|
@@ -847,17 +923,15 @@ class FactorAnalysis(BaseTransformer):
|
|
847
923
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
848
924
|
|
849
925
|
if isinstance(dataset, DataFrame):
|
850
|
-
self.
|
851
|
-
|
852
|
-
inference_method="score",
|
853
|
-
)
|
926
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
927
|
+
self._deps = self._get_dependencies()
|
854
928
|
selected_cols = self._get_active_columns()
|
855
929
|
if len(selected_cols) > 0:
|
856
930
|
dataset = dataset.select(selected_cols)
|
857
931
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
858
932
|
transform_kwargs = dict(
|
859
933
|
session=dataset._session,
|
860
|
-
dependencies=
|
934
|
+
dependencies=self._deps,
|
861
935
|
score_sproc_imports=['sklearn'],
|
862
936
|
)
|
863
937
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -922,11 +996,8 @@ class FactorAnalysis(BaseTransformer):
|
|
922
996
|
|
923
997
|
if isinstance(dataset, DataFrame):
|
924
998
|
|
925
|
-
self.
|
926
|
-
|
927
|
-
inference_method=inference_method,
|
928
|
-
|
929
|
-
)
|
999
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1000
|
+
self._deps = self._get_dependencies()
|
930
1001
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
931
1002
|
transform_kwargs = dict(
|
932
1003
|
session = dataset._session,
|
@@ -959,50 +1030,84 @@ class FactorAnalysis(BaseTransformer):
|
|
959
1030
|
)
|
960
1031
|
return output_df
|
961
1032
|
|
1033
|
+
|
1034
|
+
|
1035
|
+
def to_sklearn(self) -> Any:
|
1036
|
+
"""Get sklearn.decomposition.FactorAnalysis object.
|
1037
|
+
"""
|
1038
|
+
if self._sklearn_object is None:
|
1039
|
+
self._sklearn_object = self._create_sklearn_object()
|
1040
|
+
return self._sklearn_object
|
1041
|
+
|
1042
|
+
def to_xgboost(self) -> Any:
|
1043
|
+
raise exceptions.SnowflakeMLException(
|
1044
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1045
|
+
original_exception=AttributeError(
|
1046
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1047
|
+
"to_xgboost()",
|
1048
|
+
"to_sklearn()"
|
1049
|
+
)
|
1050
|
+
),
|
1051
|
+
)
|
962
1052
|
|
963
|
-
def
|
1053
|
+
def to_lightgbm(self) -> Any:
|
1054
|
+
raise exceptions.SnowflakeMLException(
|
1055
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1056
|
+
original_exception=AttributeError(
|
1057
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1058
|
+
"to_lightgbm()",
|
1059
|
+
"to_sklearn()"
|
1060
|
+
)
|
1061
|
+
),
|
1062
|
+
)
|
1063
|
+
|
1064
|
+
def _get_dependencies(self) -> List[str]:
|
1065
|
+
return self._deps
|
1066
|
+
|
1067
|
+
|
1068
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
964
1069
|
self._model_signature_dict = dict()
|
965
1070
|
|
966
1071
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
967
1072
|
|
968
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1073
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
969
1074
|
outputs: List[BaseFeatureSpec] = []
|
970
1075
|
if hasattr(self, "predict"):
|
971
1076
|
# keep mypy happy
|
972
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1077
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
973
1078
|
# For classifier, the type of predict is the same as the type of label
|
974
|
-
if self._sklearn_object._estimator_type ==
|
975
|
-
|
1079
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1080
|
+
# label columns is the desired type for output
|
976
1081
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
977
1082
|
# rename the output columns
|
978
1083
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
979
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
980
|
-
|
981
|
-
|
1084
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1085
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1086
|
+
)
|
982
1087
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
983
1088
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
984
|
-
# Clusterer returns int64 cluster labels.
|
1089
|
+
# Clusterer returns int64 cluster labels.
|
985
1090
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
986
1091
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
987
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
988
|
-
|
989
|
-
|
990
|
-
|
1092
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1093
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1094
|
+
)
|
1095
|
+
|
991
1096
|
# For regressor, the type of predict is float64
|
992
|
-
elif self._sklearn_object._estimator_type ==
|
1097
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
993
1098
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
994
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
995
|
-
|
996
|
-
|
997
|
-
|
1099
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1100
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1101
|
+
)
|
1102
|
+
|
998
1103
|
for prob_func in PROB_FUNCTIONS:
|
999
1104
|
if hasattr(self, prob_func):
|
1000
1105
|
output_cols_prefix: str = f"{prob_func}_"
|
1001
1106
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1002
1107
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1003
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1004
|
-
|
1005
|
-
|
1108
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1109
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1110
|
+
)
|
1006
1111
|
|
1007
1112
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1008
1113
|
items = list(self._model_signature_dict.items())
|
@@ -1015,10 +1120,10 @@ class FactorAnalysis(BaseTransformer):
|
|
1015
1120
|
"""Returns model signature of current class.
|
1016
1121
|
|
1017
1122
|
Raises:
|
1018
|
-
|
1123
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1019
1124
|
|
1020
1125
|
Returns:
|
1021
|
-
Dict
|
1126
|
+
Dict with each method and its input output signature
|
1022
1127
|
"""
|
1023
1128
|
if self._model_signature_dict is None:
|
1024
1129
|
raise exceptions.SnowflakeMLException(
|
@@ -1026,35 +1131,3 @@ class FactorAnalysis(BaseTransformer):
|
|
1026
1131
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1027
1132
|
)
|
1028
1133
|
return self._model_signature_dict
|
1029
|
-
|
1030
|
-
def to_sklearn(self) -> Any:
|
1031
|
-
"""Get sklearn.decomposition.FactorAnalysis object.
|
1032
|
-
"""
|
1033
|
-
if self._sklearn_object is None:
|
1034
|
-
self._sklearn_object = self._create_sklearn_object()
|
1035
|
-
return self._sklearn_object
|
1036
|
-
|
1037
|
-
def to_xgboost(self) -> Any:
|
1038
|
-
raise exceptions.SnowflakeMLException(
|
1039
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1040
|
-
original_exception=AttributeError(
|
1041
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1042
|
-
"to_xgboost()",
|
1043
|
-
"to_sklearn()"
|
1044
|
-
)
|
1045
|
-
),
|
1046
|
-
)
|
1047
|
-
|
1048
|
-
def to_lightgbm(self) -> Any:
|
1049
|
-
raise exceptions.SnowflakeMLException(
|
1050
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1051
|
-
original_exception=AttributeError(
|
1052
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1053
|
-
"to_lightgbm()",
|
1054
|
-
"to_sklearn()"
|
1055
|
-
)
|
1056
|
-
),
|
1057
|
-
)
|
1058
|
-
|
1059
|
-
def _get_dependencies(self) -> List[str]:
|
1060
|
-
return self._deps
|