snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/svm/svr.py
CHANGED
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class SVR(BaseTransformer):
|
58
70
|
r"""Epsilon-Support Vector Regression
|
59
71
|
For more details on this class, see [sklearn.svm.SVR]
|
@@ -185,7 +197,9 @@ class SVR(BaseTransformer):
|
|
185
197
|
self.set_label_cols(label_cols)
|
186
198
|
self.set_passthrough_cols(passthrough_cols)
|
187
199
|
self.set_drop_input_cols(drop_input_cols)
|
188
|
-
self.set_sample_weight_col(sample_weight_col)
|
200
|
+
self.set_sample_weight_col(sample_weight_col)
|
201
|
+
self._use_external_memory_version = False
|
202
|
+
self._batch_size = -1
|
189
203
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
190
204
|
|
191
205
|
self._deps = list(deps)
|
@@ -271,11 +285,6 @@ class SVR(BaseTransformer):
|
|
271
285
|
if isinstance(dataset, DataFrame):
|
272
286
|
session = dataset._session
|
273
287
|
assert session is not None # keep mypy happy
|
274
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
275
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
276
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
277
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
278
|
-
|
279
288
|
# Specify input columns so column pruning will be enforced
|
280
289
|
selected_cols = self._get_active_columns()
|
281
290
|
if len(selected_cols) > 0:
|
@@ -303,7 +312,9 @@ class SVR(BaseTransformer):
|
|
303
312
|
label_cols=self.label_cols,
|
304
313
|
sample_weight_col=self.sample_weight_col,
|
305
314
|
autogenerated=self._autogenerated,
|
306
|
-
subproject=_SUBPROJECT
|
315
|
+
subproject=_SUBPROJECT,
|
316
|
+
use_external_memory_version=self._use_external_memory_version,
|
317
|
+
batch_size=self._batch_size,
|
307
318
|
)
|
308
319
|
self._sklearn_object = model_trainer.train()
|
309
320
|
self._is_fitted = True
|
@@ -574,6 +585,22 @@ class SVR(BaseTransformer):
|
|
574
585
|
# each row containing a list of values.
|
575
586
|
expected_dtype = "ARRAY"
|
576
587
|
|
588
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
589
|
+
if expected_dtype == "":
|
590
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
591
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
592
|
+
expected_dtype = "ARRAY"
|
593
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
594
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
595
|
+
expected_dtype = "ARRAY"
|
596
|
+
else:
|
597
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
598
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
599
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
600
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
601
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
602
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
603
|
+
|
577
604
|
output_df = self._batch_inference(
|
578
605
|
dataset=dataset,
|
579
606
|
inference_method="transform",
|
@@ -589,8 +616,8 @@ class SVR(BaseTransformer):
|
|
589
616
|
|
590
617
|
return output_df
|
591
618
|
|
592
|
-
@available_if(
|
593
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
619
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
620
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
594
621
|
""" Method not supported for this class.
|
595
622
|
|
596
623
|
|
@@ -603,13 +630,21 @@ class SVR(BaseTransformer):
|
|
603
630
|
Returns:
|
604
631
|
Predicted dataset.
|
605
632
|
"""
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
633
|
+
self.fit(dataset)
|
634
|
+
assert self._sklearn_object is not None
|
635
|
+
return self._sklearn_object.labels_
|
636
|
+
|
637
|
+
|
638
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
639
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
640
|
+
"""
|
641
|
+
Returns:
|
642
|
+
Transformed dataset.
|
643
|
+
"""
|
644
|
+
self.fit(dataset)
|
645
|
+
assert self._sklearn_object is not None
|
646
|
+
return self._sklearn_object.embedding_
|
647
|
+
|
613
648
|
|
614
649
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
615
650
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class DecisionTreeClassifier(BaseTransformer):
|
58
70
|
r"""A decision tree classifier
|
59
71
|
For more details on this class, see [sklearn.tree.DecisionTreeClassifier]
|
@@ -251,7 +263,9 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
251
263
|
self.set_label_cols(label_cols)
|
252
264
|
self.set_passthrough_cols(passthrough_cols)
|
253
265
|
self.set_drop_input_cols(drop_input_cols)
|
254
|
-
self.set_sample_weight_col(sample_weight_col)
|
266
|
+
self.set_sample_weight_col(sample_weight_col)
|
267
|
+
self._use_external_memory_version = False
|
268
|
+
self._batch_size = -1
|
255
269
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
256
270
|
|
257
271
|
self._deps = list(deps)
|
@@ -338,11 +352,6 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
338
352
|
if isinstance(dataset, DataFrame):
|
339
353
|
session = dataset._session
|
340
354
|
assert session is not None # keep mypy happy
|
341
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
342
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
343
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
344
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
345
|
-
|
346
355
|
# Specify input columns so column pruning will be enforced
|
347
356
|
selected_cols = self._get_active_columns()
|
348
357
|
if len(selected_cols) > 0:
|
@@ -370,7 +379,9 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
370
379
|
label_cols=self.label_cols,
|
371
380
|
sample_weight_col=self.sample_weight_col,
|
372
381
|
autogenerated=self._autogenerated,
|
373
|
-
subproject=_SUBPROJECT
|
382
|
+
subproject=_SUBPROJECT,
|
383
|
+
use_external_memory_version=self._use_external_memory_version,
|
384
|
+
batch_size=self._batch_size,
|
374
385
|
)
|
375
386
|
self._sklearn_object = model_trainer.train()
|
376
387
|
self._is_fitted = True
|
@@ -641,6 +652,22 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
641
652
|
# each row containing a list of values.
|
642
653
|
expected_dtype = "ARRAY"
|
643
654
|
|
655
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
656
|
+
if expected_dtype == "":
|
657
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
658
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
659
|
+
expected_dtype = "ARRAY"
|
660
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
661
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
662
|
+
expected_dtype = "ARRAY"
|
663
|
+
else:
|
664
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
665
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
666
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
667
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
668
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
669
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
670
|
+
|
644
671
|
output_df = self._batch_inference(
|
645
672
|
dataset=dataset,
|
646
673
|
inference_method="transform",
|
@@ -656,8 +683,8 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
656
683
|
|
657
684
|
return output_df
|
658
685
|
|
659
|
-
@available_if(
|
660
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
686
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
687
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
661
688
|
""" Method not supported for this class.
|
662
689
|
|
663
690
|
|
@@ -670,13 +697,21 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
670
697
|
Returns:
|
671
698
|
Predicted dataset.
|
672
699
|
"""
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
700
|
+
self.fit(dataset)
|
701
|
+
assert self._sklearn_object is not None
|
702
|
+
return self._sklearn_object.labels_
|
703
|
+
|
704
|
+
|
705
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
706
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
707
|
+
"""
|
708
|
+
Returns:
|
709
|
+
Transformed dataset.
|
710
|
+
"""
|
711
|
+
self.fit(dataset)
|
712
|
+
assert self._sklearn_object is not None
|
713
|
+
return self._sklearn_object.embedding_
|
714
|
+
|
680
715
|
|
681
716
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
682
717
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class DecisionTreeRegressor(BaseTransformer):
|
58
70
|
r"""A decision tree regressor
|
59
71
|
For more details on this class, see [sklearn.tree.DecisionTreeRegressor]
|
@@ -234,7 +246,9 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
234
246
|
self.set_label_cols(label_cols)
|
235
247
|
self.set_passthrough_cols(passthrough_cols)
|
236
248
|
self.set_drop_input_cols(drop_input_cols)
|
237
|
-
self.set_sample_weight_col(sample_weight_col)
|
249
|
+
self.set_sample_weight_col(sample_weight_col)
|
250
|
+
self._use_external_memory_version = False
|
251
|
+
self._batch_size = -1
|
238
252
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
239
253
|
|
240
254
|
self._deps = list(deps)
|
@@ -320,11 +334,6 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
320
334
|
if isinstance(dataset, DataFrame):
|
321
335
|
session = dataset._session
|
322
336
|
assert session is not None # keep mypy happy
|
323
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
324
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
325
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
326
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
327
|
-
|
328
337
|
# Specify input columns so column pruning will be enforced
|
329
338
|
selected_cols = self._get_active_columns()
|
330
339
|
if len(selected_cols) > 0:
|
@@ -352,7 +361,9 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
352
361
|
label_cols=self.label_cols,
|
353
362
|
sample_weight_col=self.sample_weight_col,
|
354
363
|
autogenerated=self._autogenerated,
|
355
|
-
subproject=_SUBPROJECT
|
364
|
+
subproject=_SUBPROJECT,
|
365
|
+
use_external_memory_version=self._use_external_memory_version,
|
366
|
+
batch_size=self._batch_size,
|
356
367
|
)
|
357
368
|
self._sklearn_object = model_trainer.train()
|
358
369
|
self._is_fitted = True
|
@@ -623,6 +634,22 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
623
634
|
# each row containing a list of values.
|
624
635
|
expected_dtype = "ARRAY"
|
625
636
|
|
637
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
638
|
+
if expected_dtype == "":
|
639
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
640
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
641
|
+
expected_dtype = "ARRAY"
|
642
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
643
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
644
|
+
expected_dtype = "ARRAY"
|
645
|
+
else:
|
646
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
647
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
648
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
649
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
650
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
651
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
652
|
+
|
626
653
|
output_df = self._batch_inference(
|
627
654
|
dataset=dataset,
|
628
655
|
inference_method="transform",
|
@@ -638,8 +665,8 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
638
665
|
|
639
666
|
return output_df
|
640
667
|
|
641
|
-
@available_if(
|
642
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
668
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
669
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
643
670
|
""" Method not supported for this class.
|
644
671
|
|
645
672
|
|
@@ -652,13 +679,21 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
652
679
|
Returns:
|
653
680
|
Predicted dataset.
|
654
681
|
"""
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
682
|
+
self.fit(dataset)
|
683
|
+
assert self._sklearn_object is not None
|
684
|
+
return self._sklearn_object.labels_
|
685
|
+
|
686
|
+
|
687
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
688
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
689
|
+
"""
|
690
|
+
Returns:
|
691
|
+
Transformed dataset.
|
692
|
+
"""
|
693
|
+
self.fit(dataset)
|
694
|
+
assert self._sklearn_object is not None
|
695
|
+
return self._sklearn_object.embedding_
|
696
|
+
|
662
697
|
|
663
698
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
664
699
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class ExtraTreeClassifier(BaseTransformer):
|
58
70
|
r"""An extremely randomized tree classifier
|
59
71
|
For more details on this class, see [sklearn.tree.ExtraTreeClassifier]
|
@@ -243,7 +255,9 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
243
255
|
self.set_label_cols(label_cols)
|
244
256
|
self.set_passthrough_cols(passthrough_cols)
|
245
257
|
self.set_drop_input_cols(drop_input_cols)
|
246
|
-
self.set_sample_weight_col(sample_weight_col)
|
258
|
+
self.set_sample_weight_col(sample_weight_col)
|
259
|
+
self._use_external_memory_version = False
|
260
|
+
self._batch_size = -1
|
247
261
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
248
262
|
|
249
263
|
self._deps = list(deps)
|
@@ -330,11 +344,6 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
330
344
|
if isinstance(dataset, DataFrame):
|
331
345
|
session = dataset._session
|
332
346
|
assert session is not None # keep mypy happy
|
333
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
334
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
335
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
336
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
337
|
-
|
338
347
|
# Specify input columns so column pruning will be enforced
|
339
348
|
selected_cols = self._get_active_columns()
|
340
349
|
if len(selected_cols) > 0:
|
@@ -362,7 +371,9 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
362
371
|
label_cols=self.label_cols,
|
363
372
|
sample_weight_col=self.sample_weight_col,
|
364
373
|
autogenerated=self._autogenerated,
|
365
|
-
subproject=_SUBPROJECT
|
374
|
+
subproject=_SUBPROJECT,
|
375
|
+
use_external_memory_version=self._use_external_memory_version,
|
376
|
+
batch_size=self._batch_size,
|
366
377
|
)
|
367
378
|
self._sklearn_object = model_trainer.train()
|
368
379
|
self._is_fitted = True
|
@@ -633,6 +644,22 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
633
644
|
# each row containing a list of values.
|
634
645
|
expected_dtype = "ARRAY"
|
635
646
|
|
647
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
648
|
+
if expected_dtype == "":
|
649
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
650
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
651
|
+
expected_dtype = "ARRAY"
|
652
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
653
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
654
|
+
expected_dtype = "ARRAY"
|
655
|
+
else:
|
656
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
657
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
658
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
659
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
660
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
661
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
662
|
+
|
636
663
|
output_df = self._batch_inference(
|
637
664
|
dataset=dataset,
|
638
665
|
inference_method="transform",
|
@@ -648,8 +675,8 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
648
675
|
|
649
676
|
return output_df
|
650
677
|
|
651
|
-
@available_if(
|
652
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
678
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
679
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
653
680
|
""" Method not supported for this class.
|
654
681
|
|
655
682
|
|
@@ -662,13 +689,21 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
662
689
|
Returns:
|
663
690
|
Predicted dataset.
|
664
691
|
"""
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
692
|
+
self.fit(dataset)
|
693
|
+
assert self._sklearn_object is not None
|
694
|
+
return self._sklearn_object.labels_
|
695
|
+
|
696
|
+
|
697
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
698
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
699
|
+
"""
|
700
|
+
Returns:
|
701
|
+
Transformed dataset.
|
702
|
+
"""
|
703
|
+
self.fit(dataset)
|
704
|
+
assert self._sklearn_object is not None
|
705
|
+
return self._sklearn_object.embedding_
|
706
|
+
|
672
707
|
|
673
708
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
674
709
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class ExtraTreeRegressor(BaseTransformer):
|
58
70
|
r"""An extremely randomized tree regressor
|
59
71
|
For more details on this class, see [sklearn.tree.ExtraTreeRegressor]
|
@@ -226,7 +238,9 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
226
238
|
self.set_label_cols(label_cols)
|
227
239
|
self.set_passthrough_cols(passthrough_cols)
|
228
240
|
self.set_drop_input_cols(drop_input_cols)
|
229
|
-
self.set_sample_weight_col(sample_weight_col)
|
241
|
+
self.set_sample_weight_col(sample_weight_col)
|
242
|
+
self._use_external_memory_version = False
|
243
|
+
self._batch_size = -1
|
230
244
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
231
245
|
|
232
246
|
self._deps = list(deps)
|
@@ -312,11 +326,6 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
312
326
|
if isinstance(dataset, DataFrame):
|
313
327
|
session = dataset._session
|
314
328
|
assert session is not None # keep mypy happy
|
315
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
316
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
317
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
318
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
319
|
-
|
320
329
|
# Specify input columns so column pruning will be enforced
|
321
330
|
selected_cols = self._get_active_columns()
|
322
331
|
if len(selected_cols) > 0:
|
@@ -344,7 +353,9 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
344
353
|
label_cols=self.label_cols,
|
345
354
|
sample_weight_col=self.sample_weight_col,
|
346
355
|
autogenerated=self._autogenerated,
|
347
|
-
subproject=_SUBPROJECT
|
356
|
+
subproject=_SUBPROJECT,
|
357
|
+
use_external_memory_version=self._use_external_memory_version,
|
358
|
+
batch_size=self._batch_size,
|
348
359
|
)
|
349
360
|
self._sklearn_object = model_trainer.train()
|
350
361
|
self._is_fitted = True
|
@@ -615,6 +626,22 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
615
626
|
# each row containing a list of values.
|
616
627
|
expected_dtype = "ARRAY"
|
617
628
|
|
629
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
630
|
+
if expected_dtype == "":
|
631
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
632
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
633
|
+
expected_dtype = "ARRAY"
|
634
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
635
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
636
|
+
expected_dtype = "ARRAY"
|
637
|
+
else:
|
638
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
639
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
640
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
641
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
642
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
643
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
644
|
+
|
618
645
|
output_df = self._batch_inference(
|
619
646
|
dataset=dataset,
|
620
647
|
inference_method="transform",
|
@@ -630,8 +657,8 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
630
657
|
|
631
658
|
return output_df
|
632
659
|
|
633
|
-
@available_if(
|
634
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
660
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
661
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
635
662
|
""" Method not supported for this class.
|
636
663
|
|
637
664
|
|
@@ -644,13 +671,21 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
644
671
|
Returns:
|
645
672
|
Predicted dataset.
|
646
673
|
"""
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
674
|
+
self.fit(dataset)
|
675
|
+
assert self._sklearn_object is not None
|
676
|
+
return self._sklearn_object.labels_
|
677
|
+
|
678
|
+
|
679
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
680
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
681
|
+
"""
|
682
|
+
Returns:
|
683
|
+
Transformed dataset.
|
684
|
+
"""
|
685
|
+
self.fit(dataset)
|
686
|
+
assert self._sklearn_object is not None
|
687
|
+
return self._sklearn_object.embedding_
|
688
|
+
|
654
689
|
|
655
690
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
656
691
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|