snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
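
The per-file hunks below are shown for five representative modules; the same regenerated template accounts for the uniform `+51 -16` counts across the `snowflake/ml/modeling/` files above. The most visible piece of that template is the pair of `_is_fit_predict_method_enabled` / `_is_fit_transform_method_enabled` helpers: each returns a check whose body is `False and callable(...)`, which short-circuits to `False`, so the decorated method stays hidden on classes that don't support it. A minimal sketch of how sklearn's `available_if` gating behaves (illustrative names only; this is not the package's code):

```python
# Sketch of sklearn's available_if gating, assuming only
# sklearn.utils.metaestimators.available_if; Demo and _impl are illustrative.
from sklearn.utils.metaestimators import available_if


def _method_enabled(attr: str):
    def check(self) -> bool:
        # `False and ...` short-circuits to False, so the decorated method is
        # always hidden; the template keeps the hook for classes that later
        # re-enable it by dropping the `False and`.
        return False and callable(getattr(self._impl, attr, None))
    return check


class Demo:
    _impl = None

    @available_if(_method_enabled("fit_predict"))
    def fit_predict(self, X):
        return self._impl.fit_predict(X)


print(hasattr(Demo(), "fit_predict"))  # False: available_if raises AttributeError
```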
snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class PolynomialCountSketch(BaseTransformer):
     r"""Polynomial kernel approximation via Tensor Sketch
     For more details on this class, see [sklearn.kernel_approximation.PolynomialCountSketch]
@@ -151,7 +163,9 @@ class PolynomialCountSketch(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -231,11 +245,6 @@ class PolynomialCountSketch(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -263,7 +272,9 @@ class PolynomialCountSketch(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -534,6 +545,22 @@ class PolynomialCountSketch(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -549,8 +576,8 @@ class PolynomialCountSketch(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -563,13 +590,21 @@ class PolynomialCountSketch(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
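
The `@@ -534,6 +545,22 @@` hunk above adds a fallback for transforms whose output type could not be assigned in the factory: clustering and decomposition estimators whose output-column count doesn't match `n_clusters`/`n_components` fall back to `ARRAY`, and a concrete Snowflake type is kept only when every inferred output type agrees and lines up one-to-one with the output columns. A standalone sketch of that homogeneity check over plain type-name strings (the real code maps Snowpark types through an internal converter; this is a hedged restatement, not the package's implementation):

```python
from typing import List, Sequence


def infer_expected_dtype(output_types: Sequence[str], output_cols: List[str]) -> str:
    """Keep a concrete type only when types are homogeneous and 1:1 with columns."""
    if output_types and all(t == output_types[0] for t in output_types) \
            and len(output_types) == len(output_cols):
        return output_types[0]   # homogeneous and 1:1 -> keep the concrete type
    return "ARRAY"               # otherwise fall back to a variant-style column


print(infer_expected_dtype(["FLOAT", "FLOAT"], ["OUT1", "OUT2"]))   # FLOAT
print(infer_expected_dtype(["FLOAT", "BIGINT"], ["OUT1", "OUT2"]))  # ARRAY
```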
snowflake/ml/modeling/kernel_approximation/rbf_sampler.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class RBFSampler(BaseTransformer):
     r"""Approximate a RBF kernel feature map using random Fourier features
     For more details on this class, see [sklearn.kernel_approximation.RBFSampler]
@@ -140,7 +152,9 @@ class RBFSampler(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -218,11 +232,6 @@ class RBFSampler(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -250,7 +259,9 @@ class RBFSampler(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -521,6 +532,22 @@ class RBFSampler(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -536,8 +563,8 @@ class RBFSampler(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -550,13 +577,21 @@ class RBFSampler(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SkewedChi2Sampler(BaseTransformer):
     r"""Approximate feature map for "skewed chi-squared" kernel
     For more details on this class, see [sklearn.kernel_approximation.SkewedChi2Sampler]
@@ -138,7 +150,9 @@ class SkewedChi2Sampler(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -216,11 +230,6 @@ class SkewedChi2Sampler(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -248,7 +257,9 @@ class SkewedChi2Sampler(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -519,6 +530,22 @@ class SkewedChi2Sampler(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -534,8 +561,8 @@ class SkewedChi2Sampler(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -548,13 +575,21 @@ class SkewedChi2Sampler(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
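
A note on the `TypeGuard[Callable[..., object]]` annotation carried by every new `check` helper: a function typed as returning `TypeGuard[T]` tells static checkers that a `True` result narrows its first argument to `T`. A self-contained example of the standard `typing_extensions` feature (unrelated to snowflake-ml internals):

```python
from typing import List
from typing_extensions import TypeGuard


def is_str_list(values: List[object]) -> TypeGuard[List[str]]:
    # A True return narrows `values` to List[str] for static type checkers.
    return all(isinstance(v, str) for v in values)


def shout(values: List[object]) -> None:
    if is_str_list(values):
        print(", ".join(v.upper() for v in values))  # v is typed as str here
```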
snowflake/ml/modeling/kernel_ridge/kernel_ridge.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_ridge".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class KernelRidge(BaseTransformer):
     r"""Kernel ridge regression
     For more details on this class, see [sklearn.kernel_ridge.KernelRidge]
@@ -171,7 +183,9 @@ class KernelRidge(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -252,11 +266,6 @@ class KernelRidge(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -284,7 +293,9 @@ class KernelRidge(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -555,6 +566,22 @@ class KernelRidge(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -570,8 +597,8 @@ class KernelRidge(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -584,13 +611,21 @@ class KernelRidge(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/lightgbm/lgbm_classifier.py

```diff
@@ -53,6 +53,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "lightgbm".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LGBMClassifier(BaseTransformer):
     r"""LightGBM classifier
     For more details on this class, see [lightgbm.LGBMClassifier]
@@ -144,7 +156,9 @@ class LGBMClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'lightgbm=={lightgbm.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -240,11 +254,6 @@ class LGBMClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -272,7 +281,9 @@ class LGBMClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -543,6 +554,22 @@ class LGBMClassifier(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -558,8 +585,8 @@ class LGBMClassifier(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -572,13 +599,21 @@ class LGBMClassifier(BaseTransformer):
        Returns:
            Predicted dataset.
        """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
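
Finally, the `use_external_memory_version=False` / `batch_size=-1` arguments threaded through every constructor above pair with the new `snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py` (+444) and the enlarged `model_trainer_builder.py` in the file list: an opt-in path for training XGBoost estimators on data that doesn't fit in memory, fed in batches. The upstream XGBoost pattern such a trainer builds on looks like this (this uses XGBoost's public `DataIter` API; the in-memory batch source is illustrative and is not the package's implementation):

```python
import numpy as np
import xgboost as xgb


class BatchIter(xgb.DataIter):
    """Streams (X, y) chunks into XGBoost's external-memory cache."""

    def __init__(self, batches):
        self._batches = list(batches)
        self._i = 0
        super().__init__(cache_prefix="xgb_cache")  # pages spill to disk here

    def next(self, input_data) -> int:
        if self._i == len(self._batches):
            return 0                     # no more batches
        X, y = self._batches[self._i]
        input_data(data=X, label=y)      # hand one chunk to XGBoost
        self._i += 1
        return 1

    def reset(self) -> None:
        self._i = 0


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 4)), rng.integers(0, 2, 256)) for _ in range(4)]
dtrain = xgb.DMatrix(BatchIter(batches))  # external-memory DMatrix
booster = xgb.train({"objective": "binary:logistic", "tree_method": "hist"}, dtrain)
```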