snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklea
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return True and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class TSNE(BaseTransformer):
|
71
64
|
r"""T-distributed Stochastic Neighbor Embedding
|
72
65
|
For more details on this class, see [sklearn.manifold.TSNE]
|
@@ -322,12 +315,7 @@ class TSNE(BaseTransformer):
|
|
322
315
|
)
|
323
316
|
return selected_cols
|
324
317
|
|
325
|
-
|
326
|
-
project=_PROJECT,
|
327
|
-
subproject=_SUBPROJECT,
|
328
|
-
custom_tags=dict([("autogen", True)]),
|
329
|
-
)
|
330
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "TSNE":
|
318
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "TSNE":
|
331
319
|
"""Fit X into an embedded space
|
332
320
|
For more details on this function, see [sklearn.manifold.TSNE.fit]
|
333
321
|
(https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html#sklearn.manifold.TSNE.fit)
|
@@ -354,12 +342,14 @@ class TSNE(BaseTransformer):
|
|
354
342
|
|
355
343
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
356
344
|
|
357
|
-
|
345
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
358
346
|
if SNOWML_SPROC_ENV in os.environ:
|
359
347
|
statement_params = telemetry.get_function_usage_statement_params(
|
360
348
|
project=_PROJECT,
|
361
349
|
subproject=_SUBPROJECT,
|
362
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
350
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
351
|
+
inspect.currentframe(), TSNE.__class__.__name__
|
352
|
+
),
|
363
353
|
api_calls=[Session.call],
|
364
354
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
365
355
|
)
|
@@ -380,27 +370,24 @@ class TSNE(BaseTransformer):
|
|
380
370
|
)
|
381
371
|
self._sklearn_object = model_trainer.train()
|
382
372
|
self._is_fitted = True
|
383
|
-
self.
|
373
|
+
self._generate_model_signatures(dataset)
|
384
374
|
return self
|
385
375
|
|
386
376
|
def _batch_inference_validate_snowpark(
|
387
377
|
self,
|
388
378
|
dataset: DataFrame,
|
389
379
|
inference_method: str,
|
390
|
-
) ->
|
391
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
392
|
-
return the available package that exists in the snowflake anaconda channel
|
380
|
+
) -> None:
|
381
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
393
382
|
|
394
383
|
Args:
|
395
384
|
dataset: snowpark dataframe
|
396
385
|
inference_method: the inference method such as predict, score...
|
397
|
-
|
386
|
+
|
398
387
|
Raises:
|
399
388
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
400
389
|
SnowflakeMLException: If the session is None, raise error
|
401
390
|
|
402
|
-
Returns:
|
403
|
-
A list of available package that exists in the snowflake anaconda channel
|
404
391
|
"""
|
405
392
|
if not self._is_fitted:
|
406
393
|
raise exceptions.SnowflakeMLException(
|
@@ -418,9 +405,7 @@ class TSNE(BaseTransformer):
|
|
418
405
|
"Session must not specified for snowpark dataset."
|
419
406
|
),
|
420
407
|
)
|
421
|
-
|
422
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
423
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
408
|
+
|
424
409
|
|
425
410
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
426
411
|
@telemetry.send_api_usage_telemetry(
|
@@ -454,7 +439,9 @@ class TSNE(BaseTransformer):
|
|
454
439
|
# when it is classifier, infer the datatype from label columns
|
455
440
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
456
441
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
457
|
-
label_cols_signatures = [
|
442
|
+
label_cols_signatures = [
|
443
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
444
|
+
]
|
458
445
|
if len(label_cols_signatures) == 0:
|
459
446
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
460
447
|
raise exceptions.SnowflakeMLException(
|
@@ -462,25 +449,23 @@ class TSNE(BaseTransformer):
|
|
462
449
|
original_exception=ValueError(error_str),
|
463
450
|
)
|
464
451
|
|
465
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
466
|
-
label_cols_signatures[0].as_snowpark_type()
|
467
|
-
)
|
452
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
468
453
|
|
469
|
-
self.
|
470
|
-
|
454
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
455
|
+
self._deps = self._get_dependencies()
|
456
|
+
assert isinstance(
|
457
|
+
dataset._session, Session
|
458
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
471
459
|
|
472
460
|
transform_kwargs = dict(
|
473
|
-
session
|
474
|
-
dependencies
|
475
|
-
drop_input_cols
|
476
|
-
expected_output_cols_type
|
461
|
+
session=dataset._session,
|
462
|
+
dependencies=self._deps,
|
463
|
+
drop_input_cols=self._drop_input_cols,
|
464
|
+
expected_output_cols_type=expected_type_inferred,
|
477
465
|
)
|
478
466
|
|
479
467
|
elif isinstance(dataset, pd.DataFrame):
|
480
|
-
transform_kwargs = dict(
|
481
|
-
snowpark_input_cols = self._snowpark_cols,
|
482
|
-
drop_input_cols = self._drop_input_cols
|
483
|
-
)
|
468
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
484
469
|
|
485
470
|
transform_handlers = ModelTransformerBuilder.build(
|
486
471
|
dataset=dataset,
|
@@ -520,7 +505,7 @@ class TSNE(BaseTransformer):
|
|
520
505
|
Transformed dataset.
|
521
506
|
"""
|
522
507
|
super()._check_dataset_type(dataset)
|
523
|
-
inference_method="transform"
|
508
|
+
inference_method = "transform"
|
524
509
|
|
525
510
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
526
511
|
# are specific to the type of dataset used.
|
@@ -550,24 +535,19 @@ class TSNE(BaseTransformer):
|
|
550
535
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
551
536
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
552
537
|
|
553
|
-
self.
|
554
|
-
|
555
|
-
inference_method=inference_method,
|
556
|
-
)
|
538
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
539
|
+
self._deps = self._get_dependencies()
|
557
540
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
558
541
|
|
559
542
|
transform_kwargs = dict(
|
560
|
-
session
|
561
|
-
dependencies
|
562
|
-
drop_input_cols
|
563
|
-
expected_output_cols_type
|
543
|
+
session=dataset._session,
|
544
|
+
dependencies=self._deps,
|
545
|
+
drop_input_cols=self._drop_input_cols,
|
546
|
+
expected_output_cols_type=expected_dtype,
|
564
547
|
)
|
565
548
|
|
566
549
|
elif isinstance(dataset, pd.DataFrame):
|
567
|
-
transform_kwargs = dict(
|
568
|
-
snowpark_input_cols = self._snowpark_cols,
|
569
|
-
drop_input_cols = self._drop_input_cols
|
570
|
-
)
|
550
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
571
551
|
|
572
552
|
transform_handlers = ModelTransformerBuilder.build(
|
573
553
|
dataset=dataset,
|
@@ -586,7 +566,11 @@ class TSNE(BaseTransformer):
|
|
586
566
|
return output_df
|
587
567
|
|
588
568
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
589
|
-
def fit_predict(
|
569
|
+
def fit_predict(
|
570
|
+
self,
|
571
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
572
|
+
output_cols_prefix: str = "fit_predict_",
|
573
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
590
574
|
""" Method not supported for this class.
|
591
575
|
|
592
576
|
|
@@ -611,22 +595,106 @@ class TSNE(BaseTransformer):
|
|
611
595
|
)
|
612
596
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
613
597
|
drop_input_cols=self._drop_input_cols,
|
614
|
-
expected_output_cols_list=
|
598
|
+
expected_output_cols_list=(
|
599
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
600
|
+
),
|
615
601
|
)
|
616
602
|
self._sklearn_object = fitted_estimator
|
617
603
|
self._is_fitted = True
|
618
604
|
return output_result
|
619
605
|
|
606
|
+
|
607
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
608
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
609
|
+
""" Fit X into an embedded space and return that transformed output
|
610
|
+
For more details on this function, see [sklearn.manifold.TSNE.fit_transform]
|
611
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html#sklearn.manifold.TSNE.fit_transform)
|
612
|
+
|
613
|
+
|
614
|
+
Raises:
|
615
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
620
616
|
|
621
|
-
|
622
|
-
|
623
|
-
|
617
|
+
Args:
|
618
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
619
|
+
Snowpark or Pandas DataFrame.
|
620
|
+
output_cols_prefix: Prefix for the response columns
|
624
621
|
Returns:
|
625
622
|
Transformed dataset.
|
626
623
|
"""
|
627
|
-
self.
|
628
|
-
|
629
|
-
|
624
|
+
self._infer_input_output_cols(dataset)
|
625
|
+
super()._check_dataset_type(dataset)
|
626
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
627
|
+
estimator=self._sklearn_object,
|
628
|
+
dataset=dataset,
|
629
|
+
input_cols=self.input_cols,
|
630
|
+
label_cols=self.label_cols,
|
631
|
+
sample_weight_col=self.sample_weight_col,
|
632
|
+
autogenerated=self._autogenerated,
|
633
|
+
subproject=_SUBPROJECT,
|
634
|
+
)
|
635
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
636
|
+
drop_input_cols=self._drop_input_cols,
|
637
|
+
expected_output_cols_list=self.output_cols,
|
638
|
+
)
|
639
|
+
self._sklearn_object = fitted_estimator
|
640
|
+
self._is_fitted = True
|
641
|
+
return output_result
|
642
|
+
|
643
|
+
|
644
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
645
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
646
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
647
|
+
"""
|
648
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
649
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
650
|
+
if output_cols:
|
651
|
+
output_cols = [
|
652
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
653
|
+
for c in output_cols
|
654
|
+
]
|
655
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
656
|
+
output_cols = [output_cols_prefix]
|
657
|
+
elif self._sklearn_object is not None:
|
658
|
+
classes = self._sklearn_object.classes_
|
659
|
+
if isinstance(classes, numpy.ndarray):
|
660
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
661
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
662
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
663
|
+
output_cols = []
|
664
|
+
for i, cl in enumerate(classes):
|
665
|
+
# For binary classification, there is only one output column for each class
|
666
|
+
# ndarray as the two classes are complementary.
|
667
|
+
if len(cl) == 2:
|
668
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
669
|
+
else:
|
670
|
+
output_cols.extend([
|
671
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
672
|
+
])
|
673
|
+
else:
|
674
|
+
output_cols = []
|
675
|
+
|
676
|
+
# Make sure column names are valid snowflake identifiers.
|
677
|
+
assert output_cols is not None # Make MyPy happy
|
678
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
679
|
+
|
680
|
+
return rv
|
681
|
+
|
682
|
+
def _align_expected_output_names(
|
683
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
684
|
+
) -> List[str]:
|
685
|
+
# in case the inferred output column names dimension is different
|
686
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
687
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
688
|
+
output_df_columns = list(output_df_pd.columns)
|
689
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
690
|
+
if self.sample_weight_col:
|
691
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
692
|
+
# if the dimension of inferred output column names is correct; use it
|
693
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
694
|
+
return expected_output_cols_list
|
695
|
+
# otherwise, use the sklearn estimator's output
|
696
|
+
else:
|
697
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
630
698
|
|
631
699
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
632
700
|
@telemetry.send_api_usage_telemetry(
|
@@ -658,24 +726,26 @@ class TSNE(BaseTransformer):
|
|
658
726
|
# are specific to the type of dataset used.
|
659
727
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
660
728
|
|
729
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
730
|
+
|
661
731
|
if isinstance(dataset, DataFrame):
|
662
|
-
self.
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
732
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
733
|
+
self._deps = self._get_dependencies()
|
734
|
+
assert isinstance(
|
735
|
+
dataset._session, Session
|
736
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
667
737
|
transform_kwargs = dict(
|
668
738
|
session=dataset._session,
|
669
739
|
dependencies=self._deps,
|
670
|
-
drop_input_cols
|
740
|
+
drop_input_cols=self._drop_input_cols,
|
671
741
|
expected_output_cols_type="float",
|
672
742
|
)
|
743
|
+
expected_output_cols = self._align_expected_output_names(
|
744
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
745
|
+
)
|
673
746
|
|
674
747
|
elif isinstance(dataset, pd.DataFrame):
|
675
|
-
transform_kwargs = dict(
|
676
|
-
snowpark_input_cols = self._snowpark_cols,
|
677
|
-
drop_input_cols = self._drop_input_cols
|
678
|
-
)
|
748
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
679
749
|
|
680
750
|
transform_handlers = ModelTransformerBuilder.build(
|
681
751
|
dataset=dataset,
|
@@ -687,7 +757,7 @@ class TSNE(BaseTransformer):
|
|
687
757
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
688
758
|
inference_method=inference_method,
|
689
759
|
input_cols=self.input_cols,
|
690
|
-
expected_output_cols=
|
760
|
+
expected_output_cols=expected_output_cols,
|
691
761
|
**transform_kwargs
|
692
762
|
)
|
693
763
|
return output_df
|
@@ -717,29 +787,30 @@ class TSNE(BaseTransformer):
|
|
717
787
|
Output dataset with log probability of the sample for each class in the model.
|
718
788
|
"""
|
719
789
|
super()._check_dataset_type(dataset)
|
720
|
-
inference_method="predict_log_proba"
|
790
|
+
inference_method = "predict_log_proba"
|
791
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
721
792
|
|
722
793
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
723
794
|
# are specific to the type of dataset used.
|
724
795
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
725
796
|
|
726
797
|
if isinstance(dataset, DataFrame):
|
727
|
-
self.
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
798
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
799
|
+
self._deps = self._get_dependencies()
|
800
|
+
assert isinstance(
|
801
|
+
dataset._session, Session
|
802
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
732
803
|
transform_kwargs = dict(
|
733
804
|
session=dataset._session,
|
734
805
|
dependencies=self._deps,
|
735
|
-
drop_input_cols
|
806
|
+
drop_input_cols=self._drop_input_cols,
|
736
807
|
expected_output_cols_type="float",
|
737
808
|
)
|
809
|
+
expected_output_cols = self._align_expected_output_names(
|
810
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
811
|
+
)
|
738
812
|
elif isinstance(dataset, pd.DataFrame):
|
739
|
-
transform_kwargs = dict(
|
740
|
-
snowpark_input_cols = self._snowpark_cols,
|
741
|
-
drop_input_cols = self._drop_input_cols
|
742
|
-
)
|
813
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
743
814
|
|
744
815
|
transform_handlers = ModelTransformerBuilder.build(
|
745
816
|
dataset=dataset,
|
@@ -752,7 +823,7 @@ class TSNE(BaseTransformer):
|
|
752
823
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
753
824
|
inference_method=inference_method,
|
754
825
|
input_cols=self.input_cols,
|
755
|
-
expected_output_cols=
|
826
|
+
expected_output_cols=expected_output_cols,
|
756
827
|
**transform_kwargs
|
757
828
|
)
|
758
829
|
return output_df
|
@@ -778,30 +849,32 @@ class TSNE(BaseTransformer):
|
|
778
849
|
Output dataset with results of the decision function for the samples in input dataset.
|
779
850
|
"""
|
780
851
|
super()._check_dataset_type(dataset)
|
781
|
-
inference_method="decision_function"
|
852
|
+
inference_method = "decision_function"
|
782
853
|
|
783
854
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
784
855
|
# are specific to the type of dataset used.
|
785
856
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
786
857
|
|
858
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
859
|
+
|
787
860
|
if isinstance(dataset, DataFrame):
|
788
|
-
self.
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
861
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
862
|
+
self._deps = self._get_dependencies()
|
863
|
+
assert isinstance(
|
864
|
+
dataset._session, Session
|
865
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
793
866
|
transform_kwargs = dict(
|
794
867
|
session=dataset._session,
|
795
868
|
dependencies=self._deps,
|
796
|
-
drop_input_cols
|
869
|
+
drop_input_cols=self._drop_input_cols,
|
797
870
|
expected_output_cols_type="float",
|
798
871
|
)
|
872
|
+
expected_output_cols = self._align_expected_output_names(
|
873
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
874
|
+
)
|
799
875
|
|
800
876
|
elif isinstance(dataset, pd.DataFrame):
|
801
|
-
transform_kwargs = dict(
|
802
|
-
snowpark_input_cols = self._snowpark_cols,
|
803
|
-
drop_input_cols = self._drop_input_cols
|
804
|
-
)
|
877
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
805
878
|
|
806
879
|
transform_handlers = ModelTransformerBuilder.build(
|
807
880
|
dataset=dataset,
|
@@ -814,7 +887,7 @@ class TSNE(BaseTransformer):
|
|
814
887
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
815
888
|
inference_method=inference_method,
|
816
889
|
input_cols=self.input_cols,
|
817
|
-
expected_output_cols=
|
890
|
+
expected_output_cols=expected_output_cols,
|
818
891
|
**transform_kwargs
|
819
892
|
)
|
820
893
|
return output_df
|
@@ -843,17 +916,17 @@ class TSNE(BaseTransformer):
|
|
843
916
|
Output dataset with probability of the sample for each class in the model.
|
844
917
|
"""
|
845
918
|
super()._check_dataset_type(dataset)
|
846
|
-
inference_method="score_samples"
|
919
|
+
inference_method = "score_samples"
|
847
920
|
|
848
921
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
849
922
|
# are specific to the type of dataset used.
|
850
923
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
851
924
|
|
925
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
926
|
+
|
852
927
|
if isinstance(dataset, DataFrame):
|
853
|
-
self.
|
854
|
-
|
855
|
-
inference_method=inference_method,
|
856
|
-
)
|
928
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
929
|
+
self._deps = self._get_dependencies()
|
857
930
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
858
931
|
transform_kwargs = dict(
|
859
932
|
session=dataset._session,
|
@@ -861,6 +934,9 @@ class TSNE(BaseTransformer):
|
|
861
934
|
drop_input_cols = self._drop_input_cols,
|
862
935
|
expected_output_cols_type="float",
|
863
936
|
)
|
937
|
+
expected_output_cols = self._align_expected_output_names(
|
938
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
939
|
+
)
|
864
940
|
|
865
941
|
elif isinstance(dataset, pd.DataFrame):
|
866
942
|
transform_kwargs = dict(
|
@@ -879,7 +955,7 @@ class TSNE(BaseTransformer):
|
|
879
955
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
880
956
|
inference_method=inference_method,
|
881
957
|
input_cols=self.input_cols,
|
882
|
-
expected_output_cols=
|
958
|
+
expected_output_cols=expected_output_cols,
|
883
959
|
**transform_kwargs
|
884
960
|
)
|
885
961
|
return output_df
|
@@ -912,17 +988,15 @@ class TSNE(BaseTransformer):
|
|
912
988
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
913
989
|
|
914
990
|
if isinstance(dataset, DataFrame):
|
915
|
-
self.
|
916
|
-
|
917
|
-
inference_method="score",
|
918
|
-
)
|
991
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
992
|
+
self._deps = self._get_dependencies()
|
919
993
|
selected_cols = self._get_active_columns()
|
920
994
|
if len(selected_cols) > 0:
|
921
995
|
dataset = dataset.select(selected_cols)
|
922
996
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
923
997
|
transform_kwargs = dict(
|
924
998
|
session=dataset._session,
|
925
|
-
dependencies=
|
999
|
+
dependencies=self._deps,
|
926
1000
|
score_sproc_imports=['sklearn'],
|
927
1001
|
)
|
928
1002
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -987,11 +1061,8 @@ class TSNE(BaseTransformer):
|
|
987
1061
|
|
988
1062
|
if isinstance(dataset, DataFrame):
|
989
1063
|
|
990
|
-
self.
|
991
|
-
|
992
|
-
inference_method=inference_method,
|
993
|
-
|
994
|
-
)
|
1064
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1065
|
+
self._deps = self._get_dependencies()
|
995
1066
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
996
1067
|
transform_kwargs = dict(
|
997
1068
|
session = dataset._session,
|
@@ -1024,50 +1095,84 @@ class TSNE(BaseTransformer):
|
|
1024
1095
|
)
|
1025
1096
|
return output_df
|
1026
1097
|
|
1098
|
+
|
1099
|
+
|
1100
|
+
def to_sklearn(self) -> Any:
|
1101
|
+
"""Get sklearn.manifold.TSNE object.
|
1102
|
+
"""
|
1103
|
+
if self._sklearn_object is None:
|
1104
|
+
self._sklearn_object = self._create_sklearn_object()
|
1105
|
+
return self._sklearn_object
|
1106
|
+
|
1107
|
+
def to_xgboost(self) -> Any:
|
1108
|
+
raise exceptions.SnowflakeMLException(
|
1109
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1110
|
+
original_exception=AttributeError(
|
1111
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1112
|
+
"to_xgboost()",
|
1113
|
+
"to_sklearn()"
|
1114
|
+
)
|
1115
|
+
),
|
1116
|
+
)
|
1027
1117
|
|
1028
|
-
def
|
1118
|
+
def to_lightgbm(self) -> Any:
|
1119
|
+
raise exceptions.SnowflakeMLException(
|
1120
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1121
|
+
original_exception=AttributeError(
|
1122
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1123
|
+
"to_lightgbm()",
|
1124
|
+
"to_sklearn()"
|
1125
|
+
)
|
1126
|
+
),
|
1127
|
+
)
|
1128
|
+
|
1129
|
+
def _get_dependencies(self) -> List[str]:
|
1130
|
+
return self._deps
|
1131
|
+
|
1132
|
+
|
1133
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1029
1134
|
self._model_signature_dict = dict()
|
1030
1135
|
|
1031
1136
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1032
1137
|
|
1033
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1138
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1034
1139
|
outputs: List[BaseFeatureSpec] = []
|
1035
1140
|
if hasattr(self, "predict"):
|
1036
1141
|
# keep mypy happy
|
1037
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1142
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1038
1143
|
# For classifier, the type of predict is the same as the type of label
|
1039
|
-
if self._sklearn_object._estimator_type ==
|
1040
|
-
|
1144
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1145
|
+
# label columns is the desired type for output
|
1041
1146
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1042
1147
|
# rename the output columns
|
1043
1148
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1044
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1045
|
-
|
1046
|
-
|
1149
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1150
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1151
|
+
)
|
1047
1152
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1048
1153
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1049
|
-
# Clusterer returns int64 cluster labels.
|
1154
|
+
# Clusterer returns int64 cluster labels.
|
1050
1155
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1051
1156
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1052
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1053
|
-
|
1054
|
-
|
1055
|
-
|
1157
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1158
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1159
|
+
)
|
1160
|
+
|
1056
1161
|
# For regressor, the type of predict is float64
|
1057
|
-
elif self._sklearn_object._estimator_type ==
|
1162
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1058
1163
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1059
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1164
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1165
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1166
|
+
)
|
1167
|
+
|
1063
1168
|
for prob_func in PROB_FUNCTIONS:
|
1064
1169
|
if hasattr(self, prob_func):
|
1065
1170
|
output_cols_prefix: str = f"{prob_func}_"
|
1066
1171
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1067
1172
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1068
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1069
|
-
|
1070
|
-
|
1173
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1174
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1175
|
+
)
|
1071
1176
|
|
1072
1177
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1073
1178
|
items = list(self._model_signature_dict.items())
|
@@ -1080,10 +1185,10 @@ class TSNE(BaseTransformer):
|
|
1080
1185
|
"""Returns model signature of current class.
|
1081
1186
|
|
1082
1187
|
Raises:
|
1083
|
-
|
1188
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1084
1189
|
|
1085
1190
|
Returns:
|
1086
|
-
Dict
|
1191
|
+
Dict with each method and its input output signature
|
1087
1192
|
"""
|
1088
1193
|
if self._model_signature_dict is None:
|
1089
1194
|
raise exceptions.SnowflakeMLException(
|
@@ -1091,35 +1196,3 @@ class TSNE(BaseTransformer):
|
|
1091
1196
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1092
1197
|
)
|
1093
1198
|
return self._model_signature_dict
|
1094
|
-
|
1095
|
-
def to_sklearn(self) -> Any:
|
1096
|
-
"""Get sklearn.manifold.TSNE object.
|
1097
|
-
"""
|
1098
|
-
if self._sklearn_object is None:
|
1099
|
-
self._sklearn_object = self._create_sklearn_object()
|
1100
|
-
return self._sklearn_object
|
1101
|
-
|
1102
|
-
def to_xgboost(self) -> Any:
|
1103
|
-
raise exceptions.SnowflakeMLException(
|
1104
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1105
|
-
original_exception=AttributeError(
|
1106
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1107
|
-
"to_xgboost()",
|
1108
|
-
"to_sklearn()"
|
1109
|
-
)
|
1110
|
-
),
|
1111
|
-
)
|
1112
|
-
|
1113
|
-
def to_lightgbm(self) -> Any:
|
1114
|
-
raise exceptions.SnowflakeMLException(
|
1115
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1116
|
-
original_exception=AttributeError(
|
1117
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1118
|
-
"to_lightgbm()",
|
1119
|
-
"to_sklearn()"
|
1120
|
-
)
|
1121
|
-
),
|
1122
|
-
)
|
1123
|
-
|
1124
|
-
def _get_dependencies(self) -> List[str]:
|
1125
|
-
return self._deps
|