snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class OrthogonalMatchingPursuit(BaseTransformer):
|
71
64
|
r"""Orthogonal Matching Pursuit model (OMP)
|
72
65
|
For more details on this class, see [sklearn.linear_model.OrthogonalMatchingPursuit]
|
@@ -227,12 +220,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
227
220
|
)
|
228
221
|
return selected_cols
|
229
222
|
|
230
|
-
|
231
|
-
project=_PROJECT,
|
232
|
-
subproject=_SUBPROJECT,
|
233
|
-
custom_tags=dict([("autogen", True)]),
|
234
|
-
)
|
235
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "OrthogonalMatchingPursuit":
|
223
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "OrthogonalMatchingPursuit":
|
236
224
|
"""Fit the model using X, y as training data
|
237
225
|
For more details on this function, see [sklearn.linear_model.OrthogonalMatchingPursuit.fit]
|
238
226
|
(https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.OrthogonalMatchingPursuit.html#sklearn.linear_model.OrthogonalMatchingPursuit.fit)
|
@@ -259,12 +247,14 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
259
247
|
|
260
248
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
261
249
|
|
262
|
-
|
250
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
263
251
|
if SNOWML_SPROC_ENV in os.environ:
|
264
252
|
statement_params = telemetry.get_function_usage_statement_params(
|
265
253
|
project=_PROJECT,
|
266
254
|
subproject=_SUBPROJECT,
|
267
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
255
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
256
|
+
inspect.currentframe(), OrthogonalMatchingPursuit.__class__.__name__
|
257
|
+
),
|
268
258
|
api_calls=[Session.call],
|
269
259
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
270
260
|
)
|
@@ -285,27 +275,24 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
285
275
|
)
|
286
276
|
self._sklearn_object = model_trainer.train()
|
287
277
|
self._is_fitted = True
|
288
|
-
self.
|
278
|
+
self._generate_model_signatures(dataset)
|
289
279
|
return self
|
290
280
|
|
291
281
|
def _batch_inference_validate_snowpark(
|
292
282
|
self,
|
293
283
|
dataset: DataFrame,
|
294
284
|
inference_method: str,
|
295
|
-
) ->
|
296
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
297
|
-
return the available package that exists in the snowflake anaconda channel
|
285
|
+
) -> None:
|
286
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
298
287
|
|
299
288
|
Args:
|
300
289
|
dataset: snowpark dataframe
|
301
290
|
inference_method: the inference method such as predict, score...
|
302
|
-
|
291
|
+
|
303
292
|
Raises:
|
304
293
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
305
294
|
SnowflakeMLException: If the session is None, raise error
|
306
295
|
|
307
|
-
Returns:
|
308
|
-
A list of available package that exists in the snowflake anaconda channel
|
309
296
|
"""
|
310
297
|
if not self._is_fitted:
|
311
298
|
raise exceptions.SnowflakeMLException(
|
@@ -323,9 +310,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
323
310
|
"Session must not specified for snowpark dataset."
|
324
311
|
),
|
325
312
|
)
|
326
|
-
|
327
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
328
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
313
|
+
|
329
314
|
|
330
315
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
331
316
|
@telemetry.send_api_usage_telemetry(
|
@@ -361,7 +346,9 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
361
346
|
# when it is classifier, infer the datatype from label columns
|
362
347
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
363
348
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
364
|
-
label_cols_signatures = [
|
349
|
+
label_cols_signatures = [
|
350
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
351
|
+
]
|
365
352
|
if len(label_cols_signatures) == 0:
|
366
353
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
367
354
|
raise exceptions.SnowflakeMLException(
|
@@ -369,25 +356,23 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
369
356
|
original_exception=ValueError(error_str),
|
370
357
|
)
|
371
358
|
|
372
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
373
|
-
label_cols_signatures[0].as_snowpark_type()
|
374
|
-
)
|
359
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
375
360
|
|
376
|
-
self.
|
377
|
-
|
361
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
362
|
+
self._deps = self._get_dependencies()
|
363
|
+
assert isinstance(
|
364
|
+
dataset._session, Session
|
365
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
378
366
|
|
379
367
|
transform_kwargs = dict(
|
380
|
-
session
|
381
|
-
dependencies
|
382
|
-
drop_input_cols
|
383
|
-
expected_output_cols_type
|
368
|
+
session=dataset._session,
|
369
|
+
dependencies=self._deps,
|
370
|
+
drop_input_cols=self._drop_input_cols,
|
371
|
+
expected_output_cols_type=expected_type_inferred,
|
384
372
|
)
|
385
373
|
|
386
374
|
elif isinstance(dataset, pd.DataFrame):
|
387
|
-
transform_kwargs = dict(
|
388
|
-
snowpark_input_cols = self._snowpark_cols,
|
389
|
-
drop_input_cols = self._drop_input_cols
|
390
|
-
)
|
375
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
391
376
|
|
392
377
|
transform_handlers = ModelTransformerBuilder.build(
|
393
378
|
dataset=dataset,
|
@@ -427,7 +412,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
427
412
|
Transformed dataset.
|
428
413
|
"""
|
429
414
|
super()._check_dataset_type(dataset)
|
430
|
-
inference_method="transform"
|
415
|
+
inference_method = "transform"
|
431
416
|
|
432
417
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
433
418
|
# are specific to the type of dataset used.
|
@@ -457,24 +442,19 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
457
442
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
458
443
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
459
444
|
|
460
|
-
self.
|
461
|
-
|
462
|
-
inference_method=inference_method,
|
463
|
-
)
|
445
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
446
|
+
self._deps = self._get_dependencies()
|
464
447
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
465
448
|
|
466
449
|
transform_kwargs = dict(
|
467
|
-
session
|
468
|
-
dependencies
|
469
|
-
drop_input_cols
|
470
|
-
expected_output_cols_type
|
450
|
+
session=dataset._session,
|
451
|
+
dependencies=self._deps,
|
452
|
+
drop_input_cols=self._drop_input_cols,
|
453
|
+
expected_output_cols_type=expected_dtype,
|
471
454
|
)
|
472
455
|
|
473
456
|
elif isinstance(dataset, pd.DataFrame):
|
474
|
-
transform_kwargs = dict(
|
475
|
-
snowpark_input_cols = self._snowpark_cols,
|
476
|
-
drop_input_cols = self._drop_input_cols
|
477
|
-
)
|
457
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
478
458
|
|
479
459
|
transform_handlers = ModelTransformerBuilder.build(
|
480
460
|
dataset=dataset,
|
@@ -493,7 +473,11 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
493
473
|
return output_df
|
494
474
|
|
495
475
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
496
|
-
def fit_predict(
|
476
|
+
def fit_predict(
|
477
|
+
self,
|
478
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
479
|
+
output_cols_prefix: str = "fit_predict_",
|
480
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
497
481
|
""" Method not supported for this class.
|
498
482
|
|
499
483
|
|
@@ -518,22 +502,104 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
518
502
|
)
|
519
503
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
520
504
|
drop_input_cols=self._drop_input_cols,
|
521
|
-
expected_output_cols_list=
|
505
|
+
expected_output_cols_list=(
|
506
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
507
|
+
),
|
522
508
|
)
|
523
509
|
self._sklearn_object = fitted_estimator
|
524
510
|
self._is_fitted = True
|
525
511
|
return output_result
|
526
512
|
|
513
|
+
|
514
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
515
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
516
|
+
""" Method not supported for this class.
|
517
|
+
|
527
518
|
|
528
|
-
|
529
|
-
|
530
|
-
|
519
|
+
Raises:
|
520
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
521
|
+
|
522
|
+
Args:
|
523
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
524
|
+
Snowpark or Pandas DataFrame.
|
525
|
+
output_cols_prefix: Prefix for the response columns
|
531
526
|
Returns:
|
532
527
|
Transformed dataset.
|
533
528
|
"""
|
534
|
-
self.
|
535
|
-
|
536
|
-
|
529
|
+
self._infer_input_output_cols(dataset)
|
530
|
+
super()._check_dataset_type(dataset)
|
531
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
532
|
+
estimator=self._sklearn_object,
|
533
|
+
dataset=dataset,
|
534
|
+
input_cols=self.input_cols,
|
535
|
+
label_cols=self.label_cols,
|
536
|
+
sample_weight_col=self.sample_weight_col,
|
537
|
+
autogenerated=self._autogenerated,
|
538
|
+
subproject=_SUBPROJECT,
|
539
|
+
)
|
540
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
541
|
+
drop_input_cols=self._drop_input_cols,
|
542
|
+
expected_output_cols_list=self.output_cols,
|
543
|
+
)
|
544
|
+
self._sklearn_object = fitted_estimator
|
545
|
+
self._is_fitted = True
|
546
|
+
return output_result
|
547
|
+
|
548
|
+
|
549
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
550
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
551
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
552
|
+
"""
|
553
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
554
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
555
|
+
if output_cols:
|
556
|
+
output_cols = [
|
557
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
558
|
+
for c in output_cols
|
559
|
+
]
|
560
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
561
|
+
output_cols = [output_cols_prefix]
|
562
|
+
elif self._sklearn_object is not None:
|
563
|
+
classes = self._sklearn_object.classes_
|
564
|
+
if isinstance(classes, numpy.ndarray):
|
565
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
566
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
567
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
568
|
+
output_cols = []
|
569
|
+
for i, cl in enumerate(classes):
|
570
|
+
# For binary classification, there is only one output column for each class
|
571
|
+
# ndarray as the two classes are complementary.
|
572
|
+
if len(cl) == 2:
|
573
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
574
|
+
else:
|
575
|
+
output_cols.extend([
|
576
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
577
|
+
])
|
578
|
+
else:
|
579
|
+
output_cols = []
|
580
|
+
|
581
|
+
# Make sure column names are valid snowflake identifiers.
|
582
|
+
assert output_cols is not None # Make MyPy happy
|
583
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
584
|
+
|
585
|
+
return rv
|
586
|
+
|
587
|
+
def _align_expected_output_names(
|
588
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
589
|
+
) -> List[str]:
|
590
|
+
# in case the inferred output column names dimension is different
|
591
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
592
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
593
|
+
output_df_columns = list(output_df_pd.columns)
|
594
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
595
|
+
if self.sample_weight_col:
|
596
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
597
|
+
# if the dimension of inferred output column names is correct; use it
|
598
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
599
|
+
return expected_output_cols_list
|
600
|
+
# otherwise, use the sklearn estimator's output
|
601
|
+
else:
|
602
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
537
603
|
|
538
604
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
539
605
|
@telemetry.send_api_usage_telemetry(
|
@@ -565,24 +631,26 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
565
631
|
# are specific to the type of dataset used.
|
566
632
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
567
633
|
|
634
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
635
|
+
|
568
636
|
if isinstance(dataset, DataFrame):
|
569
|
-
self.
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
637
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
638
|
+
self._deps = self._get_dependencies()
|
639
|
+
assert isinstance(
|
640
|
+
dataset._session, Session
|
641
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
574
642
|
transform_kwargs = dict(
|
575
643
|
session=dataset._session,
|
576
644
|
dependencies=self._deps,
|
577
|
-
drop_input_cols
|
645
|
+
drop_input_cols=self._drop_input_cols,
|
578
646
|
expected_output_cols_type="float",
|
579
647
|
)
|
648
|
+
expected_output_cols = self._align_expected_output_names(
|
649
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
650
|
+
)
|
580
651
|
|
581
652
|
elif isinstance(dataset, pd.DataFrame):
|
582
|
-
transform_kwargs = dict(
|
583
|
-
snowpark_input_cols = self._snowpark_cols,
|
584
|
-
drop_input_cols = self._drop_input_cols
|
585
|
-
)
|
653
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
586
654
|
|
587
655
|
transform_handlers = ModelTransformerBuilder.build(
|
588
656
|
dataset=dataset,
|
@@ -594,7 +662,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
594
662
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
595
663
|
inference_method=inference_method,
|
596
664
|
input_cols=self.input_cols,
|
597
|
-
expected_output_cols=
|
665
|
+
expected_output_cols=expected_output_cols,
|
598
666
|
**transform_kwargs
|
599
667
|
)
|
600
668
|
return output_df
|
@@ -624,29 +692,30 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
624
692
|
Output dataset with log probability of the sample for each class in the model.
|
625
693
|
"""
|
626
694
|
super()._check_dataset_type(dataset)
|
627
|
-
inference_method="predict_log_proba"
|
695
|
+
inference_method = "predict_log_proba"
|
696
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
628
697
|
|
629
698
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
630
699
|
# are specific to the type of dataset used.
|
631
700
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
632
701
|
|
633
702
|
if isinstance(dataset, DataFrame):
|
634
|
-
self.
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
703
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
704
|
+
self._deps = self._get_dependencies()
|
705
|
+
assert isinstance(
|
706
|
+
dataset._session, Session
|
707
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
639
708
|
transform_kwargs = dict(
|
640
709
|
session=dataset._session,
|
641
710
|
dependencies=self._deps,
|
642
|
-
drop_input_cols
|
711
|
+
drop_input_cols=self._drop_input_cols,
|
643
712
|
expected_output_cols_type="float",
|
644
713
|
)
|
714
|
+
expected_output_cols = self._align_expected_output_names(
|
715
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
716
|
+
)
|
645
717
|
elif isinstance(dataset, pd.DataFrame):
|
646
|
-
transform_kwargs = dict(
|
647
|
-
snowpark_input_cols = self._snowpark_cols,
|
648
|
-
drop_input_cols = self._drop_input_cols
|
649
|
-
)
|
718
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
650
719
|
|
651
720
|
transform_handlers = ModelTransformerBuilder.build(
|
652
721
|
dataset=dataset,
|
@@ -659,7 +728,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
659
728
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
660
729
|
inference_method=inference_method,
|
661
730
|
input_cols=self.input_cols,
|
662
|
-
expected_output_cols=
|
731
|
+
expected_output_cols=expected_output_cols,
|
663
732
|
**transform_kwargs
|
664
733
|
)
|
665
734
|
return output_df
|
@@ -685,30 +754,32 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
685
754
|
Output dataset with results of the decision function for the samples in input dataset.
|
686
755
|
"""
|
687
756
|
super()._check_dataset_type(dataset)
|
688
|
-
inference_method="decision_function"
|
757
|
+
inference_method = "decision_function"
|
689
758
|
|
690
759
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
691
760
|
# are specific to the type of dataset used.
|
692
761
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
693
762
|
|
763
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
764
|
+
|
694
765
|
if isinstance(dataset, DataFrame):
|
695
|
-
self.
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
766
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
767
|
+
self._deps = self._get_dependencies()
|
768
|
+
assert isinstance(
|
769
|
+
dataset._session, Session
|
770
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
700
771
|
transform_kwargs = dict(
|
701
772
|
session=dataset._session,
|
702
773
|
dependencies=self._deps,
|
703
|
-
drop_input_cols
|
774
|
+
drop_input_cols=self._drop_input_cols,
|
704
775
|
expected_output_cols_type="float",
|
705
776
|
)
|
777
|
+
expected_output_cols = self._align_expected_output_names(
|
778
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
779
|
+
)
|
706
780
|
|
707
781
|
elif isinstance(dataset, pd.DataFrame):
|
708
|
-
transform_kwargs = dict(
|
709
|
-
snowpark_input_cols = self._snowpark_cols,
|
710
|
-
drop_input_cols = self._drop_input_cols
|
711
|
-
)
|
782
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
712
783
|
|
713
784
|
transform_handlers = ModelTransformerBuilder.build(
|
714
785
|
dataset=dataset,
|
@@ -721,7 +792,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
721
792
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
722
793
|
inference_method=inference_method,
|
723
794
|
input_cols=self.input_cols,
|
724
|
-
expected_output_cols=
|
795
|
+
expected_output_cols=expected_output_cols,
|
725
796
|
**transform_kwargs
|
726
797
|
)
|
727
798
|
return output_df
|
@@ -750,17 +821,17 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
750
821
|
Output dataset with probability of the sample for each class in the model.
|
751
822
|
"""
|
752
823
|
super()._check_dataset_type(dataset)
|
753
|
-
inference_method="score_samples"
|
824
|
+
inference_method = "score_samples"
|
754
825
|
|
755
826
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
756
827
|
# are specific to the type of dataset used.
|
757
828
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
758
829
|
|
830
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
831
|
+
|
759
832
|
if isinstance(dataset, DataFrame):
|
760
|
-
self.
|
761
|
-
|
762
|
-
inference_method=inference_method,
|
763
|
-
)
|
833
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
834
|
+
self._deps = self._get_dependencies()
|
764
835
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
765
836
|
transform_kwargs = dict(
|
766
837
|
session=dataset._session,
|
@@ -768,6 +839,9 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
768
839
|
drop_input_cols = self._drop_input_cols,
|
769
840
|
expected_output_cols_type="float",
|
770
841
|
)
|
842
|
+
expected_output_cols = self._align_expected_output_names(
|
843
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
844
|
+
)
|
771
845
|
|
772
846
|
elif isinstance(dataset, pd.DataFrame):
|
773
847
|
transform_kwargs = dict(
|
@@ -786,7 +860,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
786
860
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
787
861
|
inference_method=inference_method,
|
788
862
|
input_cols=self.input_cols,
|
789
|
-
expected_output_cols=
|
863
|
+
expected_output_cols=expected_output_cols,
|
790
864
|
**transform_kwargs
|
791
865
|
)
|
792
866
|
return output_df
|
@@ -821,17 +895,15 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
821
895
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
822
896
|
|
823
897
|
if isinstance(dataset, DataFrame):
|
824
|
-
self.
|
825
|
-
|
826
|
-
inference_method="score",
|
827
|
-
)
|
898
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
899
|
+
self._deps = self._get_dependencies()
|
828
900
|
selected_cols = self._get_active_columns()
|
829
901
|
if len(selected_cols) > 0:
|
830
902
|
dataset = dataset.select(selected_cols)
|
831
903
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
832
904
|
transform_kwargs = dict(
|
833
905
|
session=dataset._session,
|
834
|
-
dependencies=
|
906
|
+
dependencies=self._deps,
|
835
907
|
score_sproc_imports=['sklearn'],
|
836
908
|
)
|
837
909
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -896,11 +968,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
896
968
|
|
897
969
|
if isinstance(dataset, DataFrame):
|
898
970
|
|
899
|
-
self.
|
900
|
-
|
901
|
-
inference_method=inference_method,
|
902
|
-
|
903
|
-
)
|
971
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
972
|
+
self._deps = self._get_dependencies()
|
904
973
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
905
974
|
transform_kwargs = dict(
|
906
975
|
session = dataset._session,
|
@@ -933,50 +1002,84 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
933
1002
|
)
|
934
1003
|
return output_df
|
935
1004
|
|
1005
|
+
|
1006
|
+
|
1007
|
+
def to_sklearn(self) -> Any:
|
1008
|
+
"""Get sklearn.linear_model.OrthogonalMatchingPursuit object.
|
1009
|
+
"""
|
1010
|
+
if self._sklearn_object is None:
|
1011
|
+
self._sklearn_object = self._create_sklearn_object()
|
1012
|
+
return self._sklearn_object
|
1013
|
+
|
1014
|
+
def to_xgboost(self) -> Any:
|
1015
|
+
raise exceptions.SnowflakeMLException(
|
1016
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1017
|
+
original_exception=AttributeError(
|
1018
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1019
|
+
"to_xgboost()",
|
1020
|
+
"to_sklearn()"
|
1021
|
+
)
|
1022
|
+
),
|
1023
|
+
)
|
1024
|
+
|
1025
|
+
def to_lightgbm(self) -> Any:
|
1026
|
+
raise exceptions.SnowflakeMLException(
|
1027
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1028
|
+
original_exception=AttributeError(
|
1029
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1030
|
+
"to_lightgbm()",
|
1031
|
+
"to_sklearn()"
|
1032
|
+
)
|
1033
|
+
),
|
1034
|
+
)
|
1035
|
+
|
1036
|
+
def _get_dependencies(self) -> List[str]:
|
1037
|
+
return self._deps
|
1038
|
+
|
936
1039
|
|
937
|
-
def
|
1040
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
938
1041
|
self._model_signature_dict = dict()
|
939
1042
|
|
940
1043
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
941
1044
|
|
942
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1045
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
943
1046
|
outputs: List[BaseFeatureSpec] = []
|
944
1047
|
if hasattr(self, "predict"):
|
945
1048
|
# keep mypy happy
|
946
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1049
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
947
1050
|
# For classifier, the type of predict is the same as the type of label
|
948
|
-
if self._sklearn_object._estimator_type ==
|
949
|
-
|
1051
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1052
|
+
# label columns is the desired type for output
|
950
1053
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
951
1054
|
# rename the output columns
|
952
1055
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
953
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
954
|
-
|
955
|
-
|
1056
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1057
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1058
|
+
)
|
956
1059
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
957
1060
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
958
|
-
# Clusterer returns int64 cluster labels.
|
1061
|
+
# Clusterer returns int64 cluster labels.
|
959
1062
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
960
1063
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
961
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
962
|
-
|
963
|
-
|
964
|
-
|
1064
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1065
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1066
|
+
)
|
1067
|
+
|
965
1068
|
# For regressor, the type of predict is float64
|
966
|
-
elif self._sklearn_object._estimator_type ==
|
1069
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
967
1070
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
968
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
969
|
-
|
970
|
-
|
971
|
-
|
1071
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1072
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1073
|
+
)
|
1074
|
+
|
972
1075
|
for prob_func in PROB_FUNCTIONS:
|
973
1076
|
if hasattr(self, prob_func):
|
974
1077
|
output_cols_prefix: str = f"{prob_func}_"
|
975
1078
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
976
1079
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
977
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
978
|
-
|
979
|
-
|
1080
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1081
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1082
|
+
)
|
980
1083
|
|
981
1084
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
982
1085
|
items = list(self._model_signature_dict.items())
|
@@ -989,10 +1092,10 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
989
1092
|
"""Returns model signature of current class.
|
990
1093
|
|
991
1094
|
Raises:
|
992
|
-
|
1095
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
993
1096
|
|
994
1097
|
Returns:
|
995
|
-
Dict
|
1098
|
+
Dict with each method and its input output signature
|
996
1099
|
"""
|
997
1100
|
if self._model_signature_dict is None:
|
998
1101
|
raise exceptions.SnowflakeMLException(
|
@@ -1000,35 +1103,3 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
1000
1103
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1001
1104
|
)
|
1002
1105
|
return self._model_signature_dict
|
1003
|
-
|
1004
|
-
def to_sklearn(self) -> Any:
|
1005
|
-
"""Get sklearn.linear_model.OrthogonalMatchingPursuit object.
|
1006
|
-
"""
|
1007
|
-
if self._sklearn_object is None:
|
1008
|
-
self._sklearn_object = self._create_sklearn_object()
|
1009
|
-
return self._sklearn_object
|
1010
|
-
|
1011
|
-
def to_xgboost(self) -> Any:
|
1012
|
-
raise exceptions.SnowflakeMLException(
|
1013
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1014
|
-
original_exception=AttributeError(
|
1015
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1016
|
-
"to_xgboost()",
|
1017
|
-
"to_sklearn()"
|
1018
|
-
)
|
1019
|
-
),
|
1020
|
-
)
|
1021
|
-
|
1022
|
-
def to_lightgbm(self) -> Any:
|
1023
|
-
raise exceptions.SnowflakeMLException(
|
1024
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1025
|
-
original_exception=AttributeError(
|
1026
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1027
|
-
"to_lightgbm()",
|
1028
|
-
"to_sklearn()"
|
1029
|
-
)
|
1030
|
-
),
|
1031
|
-
)
|
1032
|
-
|
1033
|
-
def _get_dependencies(self) -> List[str]:
|
1034
|
-
return self._deps
|