snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class LinearSVR(BaseTransformer):
|
71
64
|
r"""Linear Support Vector Regression
|
72
65
|
For more details on this class, see [sklearn.svm.LinearSVR]
|
@@ -265,12 +258,7 @@ class LinearSVR(BaseTransformer):
|
|
265
258
|
)
|
266
259
|
return selected_cols
|
267
260
|
|
268
|
-
|
269
|
-
project=_PROJECT,
|
270
|
-
subproject=_SUBPROJECT,
|
271
|
-
custom_tags=dict([("autogen", True)]),
|
272
|
-
)
|
273
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LinearSVR":
|
261
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LinearSVR":
|
274
262
|
"""Fit the model according to the given training data
|
275
263
|
For more details on this function, see [sklearn.svm.LinearSVR.fit]
|
276
264
|
(https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVR.html#sklearn.svm.LinearSVR.fit)
|
@@ -297,12 +285,14 @@ class LinearSVR(BaseTransformer):
|
|
297
285
|
|
298
286
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
299
287
|
|
300
|
-
|
288
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
301
289
|
if SNOWML_SPROC_ENV in os.environ:
|
302
290
|
statement_params = telemetry.get_function_usage_statement_params(
|
303
291
|
project=_PROJECT,
|
304
292
|
subproject=_SUBPROJECT,
|
305
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
293
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
294
|
+
inspect.currentframe(), LinearSVR.__class__.__name__
|
295
|
+
),
|
306
296
|
api_calls=[Session.call],
|
307
297
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
308
298
|
)
|
@@ -323,27 +313,24 @@ class LinearSVR(BaseTransformer):
|
|
323
313
|
)
|
324
314
|
self._sklearn_object = model_trainer.train()
|
325
315
|
self._is_fitted = True
|
326
|
-
self.
|
316
|
+
self._generate_model_signatures(dataset)
|
327
317
|
return self
|
328
318
|
|
329
319
|
def _batch_inference_validate_snowpark(
|
330
320
|
self,
|
331
321
|
dataset: DataFrame,
|
332
322
|
inference_method: str,
|
333
|
-
) ->
|
334
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
335
|
-
return the available package that exists in the snowflake anaconda channel
|
323
|
+
) -> None:
|
324
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
336
325
|
|
337
326
|
Args:
|
338
327
|
dataset: snowpark dataframe
|
339
328
|
inference_method: the inference method such as predict, score...
|
340
|
-
|
329
|
+
|
341
330
|
Raises:
|
342
331
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
343
332
|
SnowflakeMLException: If the session is None, raise error
|
344
333
|
|
345
|
-
Returns:
|
346
|
-
A list of available package that exists in the snowflake anaconda channel
|
347
334
|
"""
|
348
335
|
if not self._is_fitted:
|
349
336
|
raise exceptions.SnowflakeMLException(
|
@@ -361,9 +348,7 @@ class LinearSVR(BaseTransformer):
|
|
361
348
|
"Session must not specified for snowpark dataset."
|
362
349
|
),
|
363
350
|
)
|
364
|
-
|
365
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
366
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
351
|
+
|
367
352
|
|
368
353
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
369
354
|
@telemetry.send_api_usage_telemetry(
|
@@ -399,7 +384,9 @@ class LinearSVR(BaseTransformer):
|
|
399
384
|
# when it is classifier, infer the datatype from label columns
|
400
385
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
401
386
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
402
|
-
label_cols_signatures = [
|
387
|
+
label_cols_signatures = [
|
388
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
389
|
+
]
|
403
390
|
if len(label_cols_signatures) == 0:
|
404
391
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
405
392
|
raise exceptions.SnowflakeMLException(
|
@@ -407,25 +394,23 @@ class LinearSVR(BaseTransformer):
|
|
407
394
|
original_exception=ValueError(error_str),
|
408
395
|
)
|
409
396
|
|
410
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
411
|
-
label_cols_signatures[0].as_snowpark_type()
|
412
|
-
)
|
397
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
413
398
|
|
414
|
-
self.
|
415
|
-
|
399
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
400
|
+
self._deps = self._get_dependencies()
|
401
|
+
assert isinstance(
|
402
|
+
dataset._session, Session
|
403
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
416
404
|
|
417
405
|
transform_kwargs = dict(
|
418
|
-
session
|
419
|
-
dependencies
|
420
|
-
drop_input_cols
|
421
|
-
expected_output_cols_type
|
406
|
+
session=dataset._session,
|
407
|
+
dependencies=self._deps,
|
408
|
+
drop_input_cols=self._drop_input_cols,
|
409
|
+
expected_output_cols_type=expected_type_inferred,
|
422
410
|
)
|
423
411
|
|
424
412
|
elif isinstance(dataset, pd.DataFrame):
|
425
|
-
transform_kwargs = dict(
|
426
|
-
snowpark_input_cols = self._snowpark_cols,
|
427
|
-
drop_input_cols = self._drop_input_cols
|
428
|
-
)
|
413
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
429
414
|
|
430
415
|
transform_handlers = ModelTransformerBuilder.build(
|
431
416
|
dataset=dataset,
|
@@ -465,7 +450,7 @@ class LinearSVR(BaseTransformer):
|
|
465
450
|
Transformed dataset.
|
466
451
|
"""
|
467
452
|
super()._check_dataset_type(dataset)
|
468
|
-
inference_method="transform"
|
453
|
+
inference_method = "transform"
|
469
454
|
|
470
455
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
471
456
|
# are specific to the type of dataset used.
|
@@ -495,24 +480,19 @@ class LinearSVR(BaseTransformer):
|
|
495
480
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
496
481
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
497
482
|
|
498
|
-
self.
|
499
|
-
|
500
|
-
inference_method=inference_method,
|
501
|
-
)
|
483
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
484
|
+
self._deps = self._get_dependencies()
|
502
485
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
503
486
|
|
504
487
|
transform_kwargs = dict(
|
505
|
-
session
|
506
|
-
dependencies
|
507
|
-
drop_input_cols
|
508
|
-
expected_output_cols_type
|
488
|
+
session=dataset._session,
|
489
|
+
dependencies=self._deps,
|
490
|
+
drop_input_cols=self._drop_input_cols,
|
491
|
+
expected_output_cols_type=expected_dtype,
|
509
492
|
)
|
510
493
|
|
511
494
|
elif isinstance(dataset, pd.DataFrame):
|
512
|
-
transform_kwargs = dict(
|
513
|
-
snowpark_input_cols = self._snowpark_cols,
|
514
|
-
drop_input_cols = self._drop_input_cols
|
515
|
-
)
|
495
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
516
496
|
|
517
497
|
transform_handlers = ModelTransformerBuilder.build(
|
518
498
|
dataset=dataset,
|
@@ -531,7 +511,11 @@ class LinearSVR(BaseTransformer):
|
|
531
511
|
return output_df
|
532
512
|
|
533
513
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
534
|
-
def fit_predict(
|
514
|
+
def fit_predict(
|
515
|
+
self,
|
516
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
517
|
+
output_cols_prefix: str = "fit_predict_",
|
518
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
535
519
|
""" Method not supported for this class.
|
536
520
|
|
537
521
|
|
@@ -556,22 +540,104 @@ class LinearSVR(BaseTransformer):
|
|
556
540
|
)
|
557
541
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
558
542
|
drop_input_cols=self._drop_input_cols,
|
559
|
-
expected_output_cols_list=
|
543
|
+
expected_output_cols_list=(
|
544
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
545
|
+
),
|
560
546
|
)
|
561
547
|
self._sklearn_object = fitted_estimator
|
562
548
|
self._is_fitted = True
|
563
549
|
return output_result
|
564
550
|
|
551
|
+
|
552
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
553
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
554
|
+
""" Method not supported for this class.
|
555
|
+
|
565
556
|
|
566
|
-
|
567
|
-
|
568
|
-
|
557
|
+
Raises:
|
558
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
559
|
+
|
560
|
+
Args:
|
561
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
562
|
+
Snowpark or Pandas DataFrame.
|
563
|
+
output_cols_prefix: Prefix for the response columns
|
569
564
|
Returns:
|
570
565
|
Transformed dataset.
|
571
566
|
"""
|
572
|
-
self.
|
573
|
-
|
574
|
-
|
567
|
+
self._infer_input_output_cols(dataset)
|
568
|
+
super()._check_dataset_type(dataset)
|
569
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
570
|
+
estimator=self._sklearn_object,
|
571
|
+
dataset=dataset,
|
572
|
+
input_cols=self.input_cols,
|
573
|
+
label_cols=self.label_cols,
|
574
|
+
sample_weight_col=self.sample_weight_col,
|
575
|
+
autogenerated=self._autogenerated,
|
576
|
+
subproject=_SUBPROJECT,
|
577
|
+
)
|
578
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
579
|
+
drop_input_cols=self._drop_input_cols,
|
580
|
+
expected_output_cols_list=self.output_cols,
|
581
|
+
)
|
582
|
+
self._sklearn_object = fitted_estimator
|
583
|
+
self._is_fitted = True
|
584
|
+
return output_result
|
585
|
+
|
586
|
+
|
587
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
588
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
589
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
590
|
+
"""
|
591
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
592
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
593
|
+
if output_cols:
|
594
|
+
output_cols = [
|
595
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
596
|
+
for c in output_cols
|
597
|
+
]
|
598
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
599
|
+
output_cols = [output_cols_prefix]
|
600
|
+
elif self._sklearn_object is not None:
|
601
|
+
classes = self._sklearn_object.classes_
|
602
|
+
if isinstance(classes, numpy.ndarray):
|
603
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
604
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
605
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
606
|
+
output_cols = []
|
607
|
+
for i, cl in enumerate(classes):
|
608
|
+
# For binary classification, there is only one output column for each class
|
609
|
+
# ndarray as the two classes are complementary.
|
610
|
+
if len(cl) == 2:
|
611
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
612
|
+
else:
|
613
|
+
output_cols.extend([
|
614
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
615
|
+
])
|
616
|
+
else:
|
617
|
+
output_cols = []
|
618
|
+
|
619
|
+
# Make sure column names are valid snowflake identifiers.
|
620
|
+
assert output_cols is not None # Make MyPy happy
|
621
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
622
|
+
|
623
|
+
return rv
|
624
|
+
|
625
|
+
def _align_expected_output_names(
|
626
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
627
|
+
) -> List[str]:
|
628
|
+
# in case the inferred output column names dimension is different
|
629
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
630
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
631
|
+
output_df_columns = list(output_df_pd.columns)
|
632
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
633
|
+
if self.sample_weight_col:
|
634
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
635
|
+
# if the dimension of inferred output column names is correct; use it
|
636
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
637
|
+
return expected_output_cols_list
|
638
|
+
# otherwise, use the sklearn estimator's output
|
639
|
+
else:
|
640
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
575
641
|
|
576
642
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
577
643
|
@telemetry.send_api_usage_telemetry(
|
@@ -603,24 +669,26 @@ class LinearSVR(BaseTransformer):
|
|
603
669
|
# are specific to the type of dataset used.
|
604
670
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
605
671
|
|
672
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
673
|
+
|
606
674
|
if isinstance(dataset, DataFrame):
|
607
|
-
self.
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
675
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
676
|
+
self._deps = self._get_dependencies()
|
677
|
+
assert isinstance(
|
678
|
+
dataset._session, Session
|
679
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
612
680
|
transform_kwargs = dict(
|
613
681
|
session=dataset._session,
|
614
682
|
dependencies=self._deps,
|
615
|
-
drop_input_cols
|
683
|
+
drop_input_cols=self._drop_input_cols,
|
616
684
|
expected_output_cols_type="float",
|
617
685
|
)
|
686
|
+
expected_output_cols = self._align_expected_output_names(
|
687
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
688
|
+
)
|
618
689
|
|
619
690
|
elif isinstance(dataset, pd.DataFrame):
|
620
|
-
transform_kwargs = dict(
|
621
|
-
snowpark_input_cols = self._snowpark_cols,
|
622
|
-
drop_input_cols = self._drop_input_cols
|
623
|
-
)
|
691
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
624
692
|
|
625
693
|
transform_handlers = ModelTransformerBuilder.build(
|
626
694
|
dataset=dataset,
|
@@ -632,7 +700,7 @@ class LinearSVR(BaseTransformer):
|
|
632
700
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
633
701
|
inference_method=inference_method,
|
634
702
|
input_cols=self.input_cols,
|
635
|
-
expected_output_cols=
|
703
|
+
expected_output_cols=expected_output_cols,
|
636
704
|
**transform_kwargs
|
637
705
|
)
|
638
706
|
return output_df
|
@@ -662,29 +730,30 @@ class LinearSVR(BaseTransformer):
|
|
662
730
|
Output dataset with log probability of the sample for each class in the model.
|
663
731
|
"""
|
664
732
|
super()._check_dataset_type(dataset)
|
665
|
-
inference_method="predict_log_proba"
|
733
|
+
inference_method = "predict_log_proba"
|
734
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
666
735
|
|
667
736
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
668
737
|
# are specific to the type of dataset used.
|
669
738
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
670
739
|
|
671
740
|
if isinstance(dataset, DataFrame):
|
672
|
-
self.
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
741
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
742
|
+
self._deps = self._get_dependencies()
|
743
|
+
assert isinstance(
|
744
|
+
dataset._session, Session
|
745
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
677
746
|
transform_kwargs = dict(
|
678
747
|
session=dataset._session,
|
679
748
|
dependencies=self._deps,
|
680
|
-
drop_input_cols
|
749
|
+
drop_input_cols=self._drop_input_cols,
|
681
750
|
expected_output_cols_type="float",
|
682
751
|
)
|
752
|
+
expected_output_cols = self._align_expected_output_names(
|
753
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
754
|
+
)
|
683
755
|
elif isinstance(dataset, pd.DataFrame):
|
684
|
-
transform_kwargs = dict(
|
685
|
-
snowpark_input_cols = self._snowpark_cols,
|
686
|
-
drop_input_cols = self._drop_input_cols
|
687
|
-
)
|
756
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
688
757
|
|
689
758
|
transform_handlers = ModelTransformerBuilder.build(
|
690
759
|
dataset=dataset,
|
@@ -697,7 +766,7 @@ class LinearSVR(BaseTransformer):
|
|
697
766
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
698
767
|
inference_method=inference_method,
|
699
768
|
input_cols=self.input_cols,
|
700
|
-
expected_output_cols=
|
769
|
+
expected_output_cols=expected_output_cols,
|
701
770
|
**transform_kwargs
|
702
771
|
)
|
703
772
|
return output_df
|
@@ -723,30 +792,32 @@ class LinearSVR(BaseTransformer):
|
|
723
792
|
Output dataset with results of the decision function for the samples in input dataset.
|
724
793
|
"""
|
725
794
|
super()._check_dataset_type(dataset)
|
726
|
-
inference_method="decision_function"
|
795
|
+
inference_method = "decision_function"
|
727
796
|
|
728
797
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
729
798
|
# are specific to the type of dataset used.
|
730
799
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
731
800
|
|
801
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
802
|
+
|
732
803
|
if isinstance(dataset, DataFrame):
|
733
|
-
self.
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
804
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
805
|
+
self._deps = self._get_dependencies()
|
806
|
+
assert isinstance(
|
807
|
+
dataset._session, Session
|
808
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
738
809
|
transform_kwargs = dict(
|
739
810
|
session=dataset._session,
|
740
811
|
dependencies=self._deps,
|
741
|
-
drop_input_cols
|
812
|
+
drop_input_cols=self._drop_input_cols,
|
742
813
|
expected_output_cols_type="float",
|
743
814
|
)
|
815
|
+
expected_output_cols = self._align_expected_output_names(
|
816
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
817
|
+
)
|
744
818
|
|
745
819
|
elif isinstance(dataset, pd.DataFrame):
|
746
|
-
transform_kwargs = dict(
|
747
|
-
snowpark_input_cols = self._snowpark_cols,
|
748
|
-
drop_input_cols = self._drop_input_cols
|
749
|
-
)
|
820
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
750
821
|
|
751
822
|
transform_handlers = ModelTransformerBuilder.build(
|
752
823
|
dataset=dataset,
|
@@ -759,7 +830,7 @@ class LinearSVR(BaseTransformer):
|
|
759
830
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
760
831
|
inference_method=inference_method,
|
761
832
|
input_cols=self.input_cols,
|
762
|
-
expected_output_cols=
|
833
|
+
expected_output_cols=expected_output_cols,
|
763
834
|
**transform_kwargs
|
764
835
|
)
|
765
836
|
return output_df
|
@@ -788,17 +859,17 @@ class LinearSVR(BaseTransformer):
|
|
788
859
|
Output dataset with probability of the sample for each class in the model.
|
789
860
|
"""
|
790
861
|
super()._check_dataset_type(dataset)
|
791
|
-
inference_method="score_samples"
|
862
|
+
inference_method = "score_samples"
|
792
863
|
|
793
864
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
794
865
|
# are specific to the type of dataset used.
|
795
866
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
796
867
|
|
868
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
869
|
+
|
797
870
|
if isinstance(dataset, DataFrame):
|
798
|
-
self.
|
799
|
-
|
800
|
-
inference_method=inference_method,
|
801
|
-
)
|
871
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
872
|
+
self._deps = self._get_dependencies()
|
802
873
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
803
874
|
transform_kwargs = dict(
|
804
875
|
session=dataset._session,
|
@@ -806,6 +877,9 @@ class LinearSVR(BaseTransformer):
|
|
806
877
|
drop_input_cols = self._drop_input_cols,
|
807
878
|
expected_output_cols_type="float",
|
808
879
|
)
|
880
|
+
expected_output_cols = self._align_expected_output_names(
|
881
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
882
|
+
)
|
809
883
|
|
810
884
|
elif isinstance(dataset, pd.DataFrame):
|
811
885
|
transform_kwargs = dict(
|
@@ -824,7 +898,7 @@ class LinearSVR(BaseTransformer):
|
|
824
898
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
825
899
|
inference_method=inference_method,
|
826
900
|
input_cols=self.input_cols,
|
827
|
-
expected_output_cols=
|
901
|
+
expected_output_cols=expected_output_cols,
|
828
902
|
**transform_kwargs
|
829
903
|
)
|
830
904
|
return output_df
|
@@ -859,17 +933,15 @@ class LinearSVR(BaseTransformer):
|
|
859
933
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
860
934
|
|
861
935
|
if isinstance(dataset, DataFrame):
|
862
|
-
self.
|
863
|
-
|
864
|
-
inference_method="score",
|
865
|
-
)
|
936
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
937
|
+
self._deps = self._get_dependencies()
|
866
938
|
selected_cols = self._get_active_columns()
|
867
939
|
if len(selected_cols) > 0:
|
868
940
|
dataset = dataset.select(selected_cols)
|
869
941
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
870
942
|
transform_kwargs = dict(
|
871
943
|
session=dataset._session,
|
872
|
-
dependencies=
|
944
|
+
dependencies=self._deps,
|
873
945
|
score_sproc_imports=['sklearn'],
|
874
946
|
)
|
875
947
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -934,11 +1006,8 @@ class LinearSVR(BaseTransformer):
|
|
934
1006
|
|
935
1007
|
if isinstance(dataset, DataFrame):
|
936
1008
|
|
937
|
-
self.
|
938
|
-
|
939
|
-
inference_method=inference_method,
|
940
|
-
|
941
|
-
)
|
1009
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1010
|
+
self._deps = self._get_dependencies()
|
942
1011
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
943
1012
|
transform_kwargs = dict(
|
944
1013
|
session = dataset._session,
|
@@ -971,50 +1040,84 @@ class LinearSVR(BaseTransformer):
|
|
971
1040
|
)
|
972
1041
|
return output_df
|
973
1042
|
|
1043
|
+
|
1044
|
+
|
1045
|
+
def to_sklearn(self) -> Any:
|
1046
|
+
"""Get sklearn.svm.LinearSVR object.
|
1047
|
+
"""
|
1048
|
+
if self._sklearn_object is None:
|
1049
|
+
self._sklearn_object = self._create_sklearn_object()
|
1050
|
+
return self._sklearn_object
|
1051
|
+
|
1052
|
+
def to_xgboost(self) -> Any:
|
1053
|
+
raise exceptions.SnowflakeMLException(
|
1054
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1055
|
+
original_exception=AttributeError(
|
1056
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1057
|
+
"to_xgboost()",
|
1058
|
+
"to_sklearn()"
|
1059
|
+
)
|
1060
|
+
),
|
1061
|
+
)
|
1062
|
+
|
1063
|
+
def to_lightgbm(self) -> Any:
|
1064
|
+
raise exceptions.SnowflakeMLException(
|
1065
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1066
|
+
original_exception=AttributeError(
|
1067
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1068
|
+
"to_lightgbm()",
|
1069
|
+
"to_sklearn()"
|
1070
|
+
)
|
1071
|
+
),
|
1072
|
+
)
|
1073
|
+
|
1074
|
+
def _get_dependencies(self) -> List[str]:
|
1075
|
+
return self._deps
|
1076
|
+
|
974
1077
|
|
975
|
-
def
|
1078
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
976
1079
|
self._model_signature_dict = dict()
|
977
1080
|
|
978
1081
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
979
1082
|
|
980
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1083
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
981
1084
|
outputs: List[BaseFeatureSpec] = []
|
982
1085
|
if hasattr(self, "predict"):
|
983
1086
|
# keep mypy happy
|
984
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1087
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
985
1088
|
# For classifier, the type of predict is the same as the type of label
|
986
|
-
if self._sklearn_object._estimator_type ==
|
987
|
-
|
1089
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1090
|
+
# label columns is the desired type for output
|
988
1091
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
989
1092
|
# rename the output columns
|
990
1093
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
991
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
992
|
-
|
993
|
-
|
1094
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1095
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1096
|
+
)
|
994
1097
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
995
1098
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
996
|
-
# Clusterer returns int64 cluster labels.
|
1099
|
+
# Clusterer returns int64 cluster labels.
|
997
1100
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
998
1101
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
999
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1102
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1103
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1104
|
+
)
|
1105
|
+
|
1003
1106
|
# For regressor, the type of predict is float64
|
1004
|
-
elif self._sklearn_object._estimator_type ==
|
1107
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1005
1108
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1006
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1109
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1110
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1111
|
+
)
|
1112
|
+
|
1010
1113
|
for prob_func in PROB_FUNCTIONS:
|
1011
1114
|
if hasattr(self, prob_func):
|
1012
1115
|
output_cols_prefix: str = f"{prob_func}_"
|
1013
1116
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1014
1117
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1015
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1016
|
-
|
1017
|
-
|
1118
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1119
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1120
|
+
)
|
1018
1121
|
|
1019
1122
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1020
1123
|
items = list(self._model_signature_dict.items())
|
@@ -1027,10 +1130,10 @@ class LinearSVR(BaseTransformer):
|
|
1027
1130
|
"""Returns model signature of current class.
|
1028
1131
|
|
1029
1132
|
Raises:
|
1030
|
-
|
1133
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1031
1134
|
|
1032
1135
|
Returns:
|
1033
|
-
Dict
|
1136
|
+
Dict with each method and its input output signature
|
1034
1137
|
"""
|
1035
1138
|
if self._model_signature_dict is None:
|
1036
1139
|
raise exceptions.SnowflakeMLException(
|
@@ -1038,35 +1141,3 @@ class LinearSVR(BaseTransformer):
|
|
1038
1141
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1039
1142
|
)
|
1040
1143
|
return self._model_signature_dict
|
1041
|
-
|
1042
|
-
def to_sklearn(self) -> Any:
|
1043
|
-
"""Get sklearn.svm.LinearSVR object.
|
1044
|
-
"""
|
1045
|
-
if self._sklearn_object is None:
|
1046
|
-
self._sklearn_object = self._create_sklearn_object()
|
1047
|
-
return self._sklearn_object
|
1048
|
-
|
1049
|
-
def to_xgboost(self) -> Any:
|
1050
|
-
raise exceptions.SnowflakeMLException(
|
1051
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1052
|
-
original_exception=AttributeError(
|
1053
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1054
|
-
"to_xgboost()",
|
1055
|
-
"to_sklearn()"
|
1056
|
-
)
|
1057
|
-
),
|
1058
|
-
)
|
1059
|
-
|
1060
|
-
def to_lightgbm(self) -> Any:
|
1061
|
-
raise exceptions.SnowflakeMLException(
|
1062
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1063
|
-
original_exception=AttributeError(
|
1064
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1065
|
-
"to_lightgbm()",
|
1066
|
-
"to_sklearn()"
|
1067
|
-
)
|
1068
|
-
),
|
1069
|
-
)
|
1070
|
-
|
1071
|
-
def _get_dependencies(self) -> List[str]:
|
1072
|
-
return self._deps
|