snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/linear_model/tweedie_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class TweedieRegressor(BaseTransformer):
     r"""Generalized Linear Model with a Tweedie distribution
     For more details on this class, see [sklearn.linear_model.TweedieRegressor]

@@ -206,7 +218,9 @@ class TweedieRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -290,11 +304,6 @@ class TweedieRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -322,7 +331,9 @@ class TweedieRegressor(BaseTransformer):
                 label_cols=self.label_cols,
                 sample_weight_col=self.sample_weight_col,
                 autogenerated=self._autogenerated,
-                subproject=_SUBPROJECT
+                subproject=_SUBPROJECT,
+                use_external_memory_version=self._use_external_memory_version,
+                batch_size=self._batch_size,
             )
             self._sklearn_object = model_trainer.train()
             self._is_fitted = True

@@ -593,6 +604,22 @@ class TweedieRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -608,8 +635,8 @@ class TweedieRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -622,13 +649,21 @@ class TweedieRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
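Across the autogenerated estimators in this release, `fit_predict`/`fit_transform` are now gated with module-level predicates passed to scikit-learn's `available_if`; a predicate hard-coded to `False and ...` keeps the method hidden on classes that don't support it (as for `TweedieRegressor` above), while `True and ...` exposes it when the wrapped estimator has it. A minimal sketch of the pattern with illustrative names (not the package's `BaseTransformer`); assumes scikit-learn >= 1.0:

# Minimal sketch of the available_if gating pattern; names are illustrative.
from sklearn.utils.metaestimators import available_if


def _gate(flag: bool):
    def check(self) -> bool:
        # `flag and callable(...)` short-circuits: a hard-coded False disables
        # the method regardless of what the wrapped estimator supports.
        return flag and callable(getattr(self.wrapped, "fit_predict", None))

    return check


class Wrapped:
    def fit_predict(self, X):
        return X


class Wrapper:
    def __init__(self, wrapped):
        self.wrapped = wrapped

    @available_if(_gate(False))
    def fit_predict(self, X):
        return self.wrapped.fit_predict(X)


# available_if raises AttributeError when the check is falsy, so the method
# is invisible to hasattr/dir-based feature detection.
print(hasattr(Wrapper(Wrapped()), "fit_predict"))  # False: gated off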
snowflake/ml/modeling/manifold/isomap.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class Isomap(BaseTransformer):
     r"""Isomap Embedding
     For more details on this class, see [sklearn.manifold.Isomap]

@@ -199,7 +211,9 @@ class Isomap(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -286,11 +300,6 @@ class Isomap(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -318,7 +327,9 @@ class Isomap(BaseTransformer):
                 label_cols=self.label_cols,
                 sample_weight_col=self.sample_weight_col,
                 autogenerated=self._autogenerated,
-                subproject=_SUBPROJECT
+                subproject=_SUBPROJECT,
+                use_external_memory_version=self._use_external_memory_version,
+                batch_size=self._batch_size,
             )
             self._sklearn_object = model_trainer.train()
             self._is_fitted = True

@@ -589,6 +600,22 @@ class Isomap(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -604,8 +631,8 @@ class Isomap(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -618,13 +645,21 @@ class Isomap(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
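The generated `fit_transform` above returns `self._sklearn_object.embedding_` after fitting; for scikit-learn manifold learners that attribute holds the transformed training data, so fit-then-read-`embedding_` matches `fit_transform` on the underlying estimator. A plain scikit-learn illustration (no Snowflake session involved; the dense eigensolver is chosen here only to make the two runs deterministic):

# Illustration: for manifold learners, fit(X).embedding_ == fit_transform(X).
import numpy as np
from sklearn.manifold import Isomap

X = np.random.RandomState(0).rand(50, 5)

via_fit_transform = Isomap(n_neighbors=5, n_components=2, eigen_solver="dense").fit_transform(X)
via_embedding = Isomap(n_neighbors=5, n_components=2, eigen_solver="dense").fit(X).embedding_

print(via_fit_transform.shape)                        # (50, 2)
print(np.allclose(via_fit_transform, via_embedding))  # True with the dense solver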
snowflake/ml/modeling/manifold/mds.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MDS(BaseTransformer):
     r"""Multidimensional scaling
     For more details on this class, see [sklearn.manifold.MDS]

@@ -184,7 +196,9 @@ class MDS(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -269,11 +283,6 @@ class MDS(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -301,7 +310,9 @@ class MDS(BaseTransformer):
                 label_cols=self.label_cols,
                 sample_weight_col=self.sample_weight_col,
                 autogenerated=self._autogenerated,
-                subproject=_SUBPROJECT
+                subproject=_SUBPROJECT,
+                use_external_memory_version=self._use_external_memory_version,
+                batch_size=self._batch_size,
             )
             self._sklearn_object = model_trainer.train()
             self._is_fitted = True

@@ -570,6 +581,22 @@ class MDS(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -585,8 +612,8 @@ class MDS(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -599,13 +626,21 @@ class MDS(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/manifold/spectral_embedding.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SpectralEmbedding(BaseTransformer):
     r"""Spectral embedding for non-linear dimensionality reduction
     For more details on this class, see [sklearn.manifold.SpectralEmbedding]

@@ -188,7 +200,9 @@ class SpectralEmbedding(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -271,11 +285,6 @@ class SpectralEmbedding(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -303,7 +312,9 @@ class SpectralEmbedding(BaseTransformer):
                 label_cols=self.label_cols,
                 sample_weight_col=self.sample_weight_col,
                 autogenerated=self._autogenerated,
-                subproject=_SUBPROJECT
+                subproject=_SUBPROJECT,
+                use_external_memory_version=self._use_external_memory_version,
+                batch_size=self._batch_size,
             )
             self._sklearn_object = model_trainer.train()
             self._is_fitted = True

@@ -572,6 +583,22 @@ class SpectralEmbedding(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -587,8 +614,8 @@ class SpectralEmbedding(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -601,13 +628,21 @@ class SpectralEmbedding(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/manifold/tsne.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class TSNE(BaseTransformer):
     r"""T-distributed Stochastic Neighbor Embedding
     For more details on this class, see [sklearn.manifold.TSNE]

@@ -240,7 +252,9 @@ class TSNE(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -330,11 +344,6 @@ class TSNE(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -362,7 +371,9 @@ class TSNE(BaseTransformer):
                 label_cols=self.label_cols,
                 sample_weight_col=self.sample_weight_col,
                 autogenerated=self._autogenerated,
-                subproject=_SUBPROJECT
+                subproject=_SUBPROJECT,
+                use_external_memory_version=self._use_external_memory_version,
+                batch_size=self._batch_size,
             )
             self._sklearn_object = model_trainer.train()
             self._is_fitted = True

@@ -631,6 +642,22 @@ class TSNE(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -646,8 +673,8 @@ class TSNE(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -660,13 +687,21 @@ class TSNE(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
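The `expected_dtype` block repeated in each autogenerated transformer above applies three rules when the factory could not fix an output type: mismatched cluster or component counts force an ARRAY, and otherwise the input type is copied only when all inputs share one type and the column counts line up. A simplified, self-contained restatement (the helper name and plain-string types are hypothetical; the real code maps Snowpark types via `convert_sp_to_sf_type`):

# Hypothetical helper restating the inference rules with plain strings.
from typing import List, Optional


def infer_expected_dtype(
    n_clusters: Optional[int],
    n_components: Optional[int],
    input_types: List[str],
    n_output_cols: int,
) -> str:
    # Clustering/decomposition transforms whose output width differs from the
    # declared output columns must be returned as an ARRAY.
    if n_clusters is not None and n_clusters != n_output_cols:
        return "ARRAY"
    if n_components is not None and n_components != n_output_cols:
        return "ARRAY"
    # Copy the input type only when all inputs share one type and the
    # input/output column counts match one-to-one.
    if input_types and all(t == input_types[0] for t in input_types) and len(input_types) == n_output_cols:
        return input_types[0]
    return "ARRAY"


print(infer_expected_dtype(8, None, ["FLOAT"] * 4, 4))     # ARRAY (8 clusters -> 4 cols)
print(infer_expected_dtype(None, None, ["FLOAT"] * 4, 4))  # FLOAT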
snowflake/ml/modeling/metrics/classification.py

@@ -228,16 +228,15 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
     Returns:
         Name of the UDTF.
     """
+    batch_size = metrics_utils.BATCH_SIZE
 
     class ConfusionMatrixComputer:
-        BATCH_SIZE = 1000
-
         def __init__(self) -> None:
             self._initialized = False
             self._confusion_matrix = np.zeros((1, 1))
-            # 2d array containing a batch of input rows. A batch contains
+            # 2d array containing a batch of input rows. A batch contains metrics_utils.BATCH_SIZE rows.
             # [sample_weight, y_true, y_pred]
-            self._batched_rows = np.zeros((
+            self._batched_rows = np.zeros((batch_size, 1))
             # Number of columns in the dataset.
             self._n_cols = -1
             # Running count of number of rows added to self._batched_rows.

@@ -255,7 +254,7 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
             # 1. Initialize variables.
             if not self._initialized:
                 self._n_cols = len(input_row)
-                self._batched_rows = np.zeros((
+                self._batched_rows = np.zeros((batch_size, self._n_cols))
                 self._n_label = n_label
                 self._confusion_matrix = np.zeros((self._n_label, self._n_label))
                 self._initialized = True

@@ -264,7 +263,7 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
             self._cur_count += 1
 
             # 2. Compute incremental confusion matrix for the batch.
-            if self._cur_count >=
+            if self._cur_count >= batch_size:
                 self.update_confusion_matrix()
                 self._cur_count = 0
 
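The change above hoists the batch size out of the inline UDTF class so `classification.py` and `metrics_utils.py` share one `BATCH_SIZE` constant. A minimal sketch of the same buffer-then-flush accumulation strategy (illustrative class name, plain numpy, no UDTF registration):

# Sketch of batched confusion-matrix accumulation: buffer rows in a fixed-size
# array and fold them into the running matrix once the buffer fills.
import numpy as np

BATCH_SIZE = 1000  # mirrors the shared constant in metrics_utils


class BatchedConfusionMatrix:
    def __init__(self, n_labels: int) -> None:
        self._cm = np.zeros((n_labels, n_labels))
        self._buf = np.zeros((BATCH_SIZE, 3))  # [sample_weight, y_true, y_pred]
        self._count = 0

    def add(self, weight: float, y_true: int, y_pred: int) -> None:
        self._buf[self._count] = (weight, y_true, y_pred)
        self._count += 1
        if self._count >= BATCH_SIZE:
            self._flush()

    def _flush(self) -> None:
        batch = self._buf[: self._count]
        # np.add.at accumulates the weights at each (true, pred) index pair.
        np.add.at(self._cm, (batch[:, 1].astype(int), batch[:, 2].astype(int)), batch[:, 0])
        self._count = 0

    def result(self) -> np.ndarray:
        self._flush()
        return self._cm


acc = BatchedConfusionMatrix(n_labels=3)
for t, p in [(0, 0), (1, 2), (2, 2), (1, 1)]:
    acc.add(1.0, t, p)
print(acc.result())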
snowflake/ml/modeling/metrics/metrics_utils.py

@@ -15,6 +15,7 @@ from snowflake.snowpark import Session, functions as F, types as T
 
 LABEL = "LABEL"
 INDEX = "INDEX"
+BATCH_SIZE = 1000
 
 
 def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, Any]) -> str:

@@ -82,7 +83,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dic
     """This class is registered as a UDTF and computes the sum and dot product
     of columns for each partition of rows. The computations across all the partitions happens
     in parallel using the nodes in the warehouse. In order to avoid keeping the entire partition
-    in memory, we batch the rows
+    in memory, we batch the rows and maintain a running sum and dot prod in self._sum_by_count,
     self._sum_by_countd and self._dot_prod respectively. We return these at the end of the partition.
     """
 

@@ -95,7 +96,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dic
            # delta degree of freedom
            self._ddof = 0
            # Setting the batch size to 1000 based on experimentation. Can be fine tuned later.
-           self._batch_size =
+           self._batch_size = BATCH_SIZE
            # 2d array containing a batch of input rows. A batch contains self._batch_size rows.
            self._batched_rows = np.zeros((self._batch_size, 1))
            # 1d array of length = # of cols. Contains sum(col/count) for each column.

@@ -224,7 +225,7 @@ def check_label_columns(
        TypeError: `y_true_col_names` and `y_pred_col_names` are of different types.
        ValueError: Multilabel `y_true_col_names` and `y_pred_col_names` are of different lengths.
    """
-   if type(y_true_col_names)
+   if type(y_true_col_names) is not type(y_pred_col_names):
        raise TypeError(
            "Label columns should be of the same type."
            f"Got y_true_col_names={type(y_true_col_names)} vs y_pred_col_names={type(y_pred_col_names)}."

@@ -300,6 +301,7 @@ def validate_average_pos_label(average: Optional[str] = None, pos_label: Union[s
            "average != 'binary' (got %r). You may use "
            "labels=[pos_label] to specify a single positive class." % (pos_label, average),
            UserWarning,
+           stacklevel=2,
        )
 
 
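Note on the `stacklevel=2` addition above: it makes the `UserWarning` point at the code that called `validate_average_pos_label` rather than at the validation helper itself. A minimal standalone sketch (hypothetical function names):

# Sketch of stacklevel attribution for warnings raised inside a helper.
import warnings


def _validate(pos_label):
    if pos_label != 1:
        # stacklevel=2 attributes the warning to our caller's line, which is
        # far more useful in user logs than pointing inside this helper.
        warnings.warn(f"pos_label={pos_label!r} will be ignored", UserWarning, stacklevel=2)


def user_code():
    _validate(pos_label="spam")  # the emitted warning is reported at this line


user_code()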