snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class SGDOneClassSVM(BaseTransformer):
|
71
64
|
r"""Solves linear One-Class SVM using Stochastic Gradient Descent
|
72
65
|
For more details on this class, see [sklearn.linear_model.SGDOneClassSVM]
|
@@ -286,12 +279,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
286
279
|
)
|
287
280
|
return selected_cols
|
288
281
|
|
289
|
-
|
290
|
-
project=_PROJECT,
|
291
|
-
subproject=_SUBPROJECT,
|
292
|
-
custom_tags=dict([("autogen", True)]),
|
293
|
-
)
|
294
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "SGDOneClassSVM":
|
282
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "SGDOneClassSVM":
|
295
283
|
"""Fit linear One-Class SVM with Stochastic Gradient Descent
|
296
284
|
For more details on this function, see [sklearn.linear_model.SGDOneClassSVM.fit]
|
297
285
|
(https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDOneClassSVM.html#sklearn.linear_model.SGDOneClassSVM.fit)
|
@@ -318,12 +306,14 @@ class SGDOneClassSVM(BaseTransformer):
|
|
318
306
|
|
319
307
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
320
308
|
|
321
|
-
|
309
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
322
310
|
if SNOWML_SPROC_ENV in os.environ:
|
323
311
|
statement_params = telemetry.get_function_usage_statement_params(
|
324
312
|
project=_PROJECT,
|
325
313
|
subproject=_SUBPROJECT,
|
326
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
314
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
315
|
+
inspect.currentframe(), SGDOneClassSVM.__class__.__name__
|
316
|
+
),
|
327
317
|
api_calls=[Session.call],
|
328
318
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
329
319
|
)
|
@@ -344,27 +334,24 @@ class SGDOneClassSVM(BaseTransformer):
|
|
344
334
|
)
|
345
335
|
self._sklearn_object = model_trainer.train()
|
346
336
|
self._is_fitted = True
|
347
|
-
self.
|
337
|
+
self._generate_model_signatures(dataset)
|
348
338
|
return self
|
349
339
|
|
350
340
|
def _batch_inference_validate_snowpark(
|
351
341
|
self,
|
352
342
|
dataset: DataFrame,
|
353
343
|
inference_method: str,
|
354
|
-
) ->
|
355
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
356
|
-
return the available package that exists in the snowflake anaconda channel
|
344
|
+
) -> None:
|
345
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
357
346
|
|
358
347
|
Args:
|
359
348
|
dataset: snowpark dataframe
|
360
349
|
inference_method: the inference method such as predict, score...
|
361
|
-
|
350
|
+
|
362
351
|
Raises:
|
363
352
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
364
353
|
SnowflakeMLException: If the session is None, raise error
|
365
354
|
|
366
|
-
Returns:
|
367
|
-
A list of available package that exists in the snowflake anaconda channel
|
368
355
|
"""
|
369
356
|
if not self._is_fitted:
|
370
357
|
raise exceptions.SnowflakeMLException(
|
@@ -382,9 +369,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
382
369
|
"Session must not specified for snowpark dataset."
|
383
370
|
),
|
384
371
|
)
|
385
|
-
|
386
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
387
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
372
|
+
|
388
373
|
|
389
374
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
390
375
|
@telemetry.send_api_usage_telemetry(
|
@@ -420,7 +405,9 @@ class SGDOneClassSVM(BaseTransformer):
|
|
420
405
|
# when it is classifier, infer the datatype from label columns
|
421
406
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
422
407
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
423
|
-
label_cols_signatures = [
|
408
|
+
label_cols_signatures = [
|
409
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
410
|
+
]
|
424
411
|
if len(label_cols_signatures) == 0:
|
425
412
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
426
413
|
raise exceptions.SnowflakeMLException(
|
@@ -428,25 +415,23 @@ class SGDOneClassSVM(BaseTransformer):
|
|
428
415
|
original_exception=ValueError(error_str),
|
429
416
|
)
|
430
417
|
|
431
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
432
|
-
label_cols_signatures[0].as_snowpark_type()
|
433
|
-
)
|
418
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
434
419
|
|
435
|
-
self.
|
436
|
-
|
420
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
421
|
+
self._deps = self._get_dependencies()
|
422
|
+
assert isinstance(
|
423
|
+
dataset._session, Session
|
424
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
437
425
|
|
438
426
|
transform_kwargs = dict(
|
439
|
-
session
|
440
|
-
dependencies
|
441
|
-
drop_input_cols
|
442
|
-
expected_output_cols_type
|
427
|
+
session=dataset._session,
|
428
|
+
dependencies=self._deps,
|
429
|
+
drop_input_cols=self._drop_input_cols,
|
430
|
+
expected_output_cols_type=expected_type_inferred,
|
443
431
|
)
|
444
432
|
|
445
433
|
elif isinstance(dataset, pd.DataFrame):
|
446
|
-
transform_kwargs = dict(
|
447
|
-
snowpark_input_cols = self._snowpark_cols,
|
448
|
-
drop_input_cols = self._drop_input_cols
|
449
|
-
)
|
434
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
450
435
|
|
451
436
|
transform_handlers = ModelTransformerBuilder.build(
|
452
437
|
dataset=dataset,
|
@@ -486,7 +471,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
486
471
|
Transformed dataset.
|
487
472
|
"""
|
488
473
|
super()._check_dataset_type(dataset)
|
489
|
-
inference_method="transform"
|
474
|
+
inference_method = "transform"
|
490
475
|
|
491
476
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
492
477
|
# are specific to the type of dataset used.
|
@@ -516,24 +501,19 @@ class SGDOneClassSVM(BaseTransformer):
|
|
516
501
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
517
502
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
518
503
|
|
519
|
-
self.
|
520
|
-
|
521
|
-
inference_method=inference_method,
|
522
|
-
)
|
504
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
505
|
+
self._deps = self._get_dependencies()
|
523
506
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
524
507
|
|
525
508
|
transform_kwargs = dict(
|
526
|
-
session
|
527
|
-
dependencies
|
528
|
-
drop_input_cols
|
529
|
-
expected_output_cols_type
|
509
|
+
session=dataset._session,
|
510
|
+
dependencies=self._deps,
|
511
|
+
drop_input_cols=self._drop_input_cols,
|
512
|
+
expected_output_cols_type=expected_dtype,
|
530
513
|
)
|
531
514
|
|
532
515
|
elif isinstance(dataset, pd.DataFrame):
|
533
|
-
transform_kwargs = dict(
|
534
|
-
snowpark_input_cols = self._snowpark_cols,
|
535
|
-
drop_input_cols = self._drop_input_cols
|
536
|
-
)
|
516
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
537
517
|
|
538
518
|
transform_handlers = ModelTransformerBuilder.build(
|
539
519
|
dataset=dataset,
|
@@ -552,7 +532,11 @@ class SGDOneClassSVM(BaseTransformer):
|
|
552
532
|
return output_df
|
553
533
|
|
554
534
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
555
|
-
def fit_predict(
|
535
|
+
def fit_predict(
|
536
|
+
self,
|
537
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
538
|
+
output_cols_prefix: str = "fit_predict_",
|
539
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
556
540
|
""" Perform fit on X and returns labels for X
|
557
541
|
For more details on this function, see [sklearn.linear_model.SGDOneClassSVM.fit_predict]
|
558
542
|
(https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDOneClassSVM.html#sklearn.linear_model.SGDOneClassSVM.fit_predict)
|
@@ -579,22 +563,104 @@ class SGDOneClassSVM(BaseTransformer):
|
|
579
563
|
)
|
580
564
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
581
565
|
drop_input_cols=self._drop_input_cols,
|
582
|
-
expected_output_cols_list=
|
566
|
+
expected_output_cols_list=(
|
567
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
568
|
+
),
|
583
569
|
)
|
584
570
|
self._sklearn_object = fitted_estimator
|
585
571
|
self._is_fitted = True
|
586
572
|
return output_result
|
587
573
|
|
574
|
+
|
575
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
576
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
577
|
+
""" Method not supported for this class.
|
578
|
+
|
579
|
+
|
580
|
+
Raises:
|
581
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
588
582
|
|
589
|
-
|
590
|
-
|
591
|
-
|
583
|
+
Args:
|
584
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
585
|
+
Snowpark or Pandas DataFrame.
|
586
|
+
output_cols_prefix: Prefix for the response columns
|
592
587
|
Returns:
|
593
588
|
Transformed dataset.
|
594
589
|
"""
|
595
|
-
self.
|
596
|
-
|
597
|
-
|
590
|
+
self._infer_input_output_cols(dataset)
|
591
|
+
super()._check_dataset_type(dataset)
|
592
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
593
|
+
estimator=self._sklearn_object,
|
594
|
+
dataset=dataset,
|
595
|
+
input_cols=self.input_cols,
|
596
|
+
label_cols=self.label_cols,
|
597
|
+
sample_weight_col=self.sample_weight_col,
|
598
|
+
autogenerated=self._autogenerated,
|
599
|
+
subproject=_SUBPROJECT,
|
600
|
+
)
|
601
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
602
|
+
drop_input_cols=self._drop_input_cols,
|
603
|
+
expected_output_cols_list=self.output_cols,
|
604
|
+
)
|
605
|
+
self._sklearn_object = fitted_estimator
|
606
|
+
self._is_fitted = True
|
607
|
+
return output_result
|
608
|
+
|
609
|
+
|
610
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
611
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
612
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
613
|
+
"""
|
614
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
615
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
616
|
+
if output_cols:
|
617
|
+
output_cols = [
|
618
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
619
|
+
for c in output_cols
|
620
|
+
]
|
621
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
622
|
+
output_cols = [output_cols_prefix]
|
623
|
+
elif self._sklearn_object is not None:
|
624
|
+
classes = self._sklearn_object.classes_
|
625
|
+
if isinstance(classes, numpy.ndarray):
|
626
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
627
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
628
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
629
|
+
output_cols = []
|
630
|
+
for i, cl in enumerate(classes):
|
631
|
+
# For binary classification, there is only one output column for each class
|
632
|
+
# ndarray as the two classes are complementary.
|
633
|
+
if len(cl) == 2:
|
634
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
635
|
+
else:
|
636
|
+
output_cols.extend([
|
637
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
638
|
+
])
|
639
|
+
else:
|
640
|
+
output_cols = []
|
641
|
+
|
642
|
+
# Make sure column names are valid snowflake identifiers.
|
643
|
+
assert output_cols is not None # Make MyPy happy
|
644
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
645
|
+
|
646
|
+
return rv
|
647
|
+
|
648
|
+
def _align_expected_output_names(
|
649
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
650
|
+
) -> List[str]:
|
651
|
+
# in case the inferred output column names dimension is different
|
652
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
653
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
654
|
+
output_df_columns = list(output_df_pd.columns)
|
655
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
656
|
+
if self.sample_weight_col:
|
657
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
658
|
+
# if the dimension of inferred output column names is correct; use it
|
659
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
660
|
+
return expected_output_cols_list
|
661
|
+
# otherwise, use the sklearn estimator's output
|
662
|
+
else:
|
663
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
598
664
|
|
599
665
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
600
666
|
@telemetry.send_api_usage_telemetry(
|
@@ -626,24 +692,26 @@ class SGDOneClassSVM(BaseTransformer):
|
|
626
692
|
# are specific to the type of dataset used.
|
627
693
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
628
694
|
|
695
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
696
|
+
|
629
697
|
if isinstance(dataset, DataFrame):
|
630
|
-
self.
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
698
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
699
|
+
self._deps = self._get_dependencies()
|
700
|
+
assert isinstance(
|
701
|
+
dataset._session, Session
|
702
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
635
703
|
transform_kwargs = dict(
|
636
704
|
session=dataset._session,
|
637
705
|
dependencies=self._deps,
|
638
|
-
drop_input_cols
|
706
|
+
drop_input_cols=self._drop_input_cols,
|
639
707
|
expected_output_cols_type="float",
|
640
708
|
)
|
709
|
+
expected_output_cols = self._align_expected_output_names(
|
710
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
711
|
+
)
|
641
712
|
|
642
713
|
elif isinstance(dataset, pd.DataFrame):
|
643
|
-
transform_kwargs = dict(
|
644
|
-
snowpark_input_cols = self._snowpark_cols,
|
645
|
-
drop_input_cols = self._drop_input_cols
|
646
|
-
)
|
714
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
647
715
|
|
648
716
|
transform_handlers = ModelTransformerBuilder.build(
|
649
717
|
dataset=dataset,
|
@@ -655,7 +723,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
655
723
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
656
724
|
inference_method=inference_method,
|
657
725
|
input_cols=self.input_cols,
|
658
|
-
expected_output_cols=
|
726
|
+
expected_output_cols=expected_output_cols,
|
659
727
|
**transform_kwargs
|
660
728
|
)
|
661
729
|
return output_df
|
@@ -685,29 +753,30 @@ class SGDOneClassSVM(BaseTransformer):
|
|
685
753
|
Output dataset with log probability of the sample for each class in the model.
|
686
754
|
"""
|
687
755
|
super()._check_dataset_type(dataset)
|
688
|
-
inference_method="predict_log_proba"
|
756
|
+
inference_method = "predict_log_proba"
|
757
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
689
758
|
|
690
759
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
691
760
|
# are specific to the type of dataset used.
|
692
761
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
693
762
|
|
694
763
|
if isinstance(dataset, DataFrame):
|
695
|
-
self.
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
764
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
765
|
+
self._deps = self._get_dependencies()
|
766
|
+
assert isinstance(
|
767
|
+
dataset._session, Session
|
768
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
700
769
|
transform_kwargs = dict(
|
701
770
|
session=dataset._session,
|
702
771
|
dependencies=self._deps,
|
703
|
-
drop_input_cols
|
772
|
+
drop_input_cols=self._drop_input_cols,
|
704
773
|
expected_output_cols_type="float",
|
705
774
|
)
|
775
|
+
expected_output_cols = self._align_expected_output_names(
|
776
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
777
|
+
)
|
706
778
|
elif isinstance(dataset, pd.DataFrame):
|
707
|
-
transform_kwargs = dict(
|
708
|
-
snowpark_input_cols = self._snowpark_cols,
|
709
|
-
drop_input_cols = self._drop_input_cols
|
710
|
-
)
|
779
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
711
780
|
|
712
781
|
transform_handlers = ModelTransformerBuilder.build(
|
713
782
|
dataset=dataset,
|
@@ -720,7 +789,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
720
789
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
721
790
|
inference_method=inference_method,
|
722
791
|
input_cols=self.input_cols,
|
723
|
-
expected_output_cols=
|
792
|
+
expected_output_cols=expected_output_cols,
|
724
793
|
**transform_kwargs
|
725
794
|
)
|
726
795
|
return output_df
|
@@ -748,30 +817,32 @@ class SGDOneClassSVM(BaseTransformer):
|
|
748
817
|
Output dataset with results of the decision function for the samples in input dataset.
|
749
818
|
"""
|
750
819
|
super()._check_dataset_type(dataset)
|
751
|
-
inference_method="decision_function"
|
820
|
+
inference_method = "decision_function"
|
752
821
|
|
753
822
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
754
823
|
# are specific to the type of dataset used.
|
755
824
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
756
825
|
|
826
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
827
|
+
|
757
828
|
if isinstance(dataset, DataFrame):
|
758
|
-
self.
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
829
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
830
|
+
self._deps = self._get_dependencies()
|
831
|
+
assert isinstance(
|
832
|
+
dataset._session, Session
|
833
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
763
834
|
transform_kwargs = dict(
|
764
835
|
session=dataset._session,
|
765
836
|
dependencies=self._deps,
|
766
|
-
drop_input_cols
|
837
|
+
drop_input_cols=self._drop_input_cols,
|
767
838
|
expected_output_cols_type="float",
|
768
839
|
)
|
840
|
+
expected_output_cols = self._align_expected_output_names(
|
841
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
842
|
+
)
|
769
843
|
|
770
844
|
elif isinstance(dataset, pd.DataFrame):
|
771
|
-
transform_kwargs = dict(
|
772
|
-
snowpark_input_cols = self._snowpark_cols,
|
773
|
-
drop_input_cols = self._drop_input_cols
|
774
|
-
)
|
845
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
775
846
|
|
776
847
|
transform_handlers = ModelTransformerBuilder.build(
|
777
848
|
dataset=dataset,
|
@@ -784,7 +855,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
784
855
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
785
856
|
inference_method=inference_method,
|
786
857
|
input_cols=self.input_cols,
|
787
|
-
expected_output_cols=
|
858
|
+
expected_output_cols=expected_output_cols,
|
788
859
|
**transform_kwargs
|
789
860
|
)
|
790
861
|
return output_df
|
@@ -815,17 +886,17 @@ class SGDOneClassSVM(BaseTransformer):
|
|
815
886
|
Output dataset with probability of the sample for each class in the model.
|
816
887
|
"""
|
817
888
|
super()._check_dataset_type(dataset)
|
818
|
-
inference_method="score_samples"
|
889
|
+
inference_method = "score_samples"
|
819
890
|
|
820
891
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
821
892
|
# are specific to the type of dataset used.
|
822
893
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
823
894
|
|
895
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
896
|
+
|
824
897
|
if isinstance(dataset, DataFrame):
|
825
|
-
self.
|
826
|
-
|
827
|
-
inference_method=inference_method,
|
828
|
-
)
|
898
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
899
|
+
self._deps = self._get_dependencies()
|
829
900
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
830
901
|
transform_kwargs = dict(
|
831
902
|
session=dataset._session,
|
@@ -833,6 +904,9 @@ class SGDOneClassSVM(BaseTransformer):
|
|
833
904
|
drop_input_cols = self._drop_input_cols,
|
834
905
|
expected_output_cols_type="float",
|
835
906
|
)
|
907
|
+
expected_output_cols = self._align_expected_output_names(
|
908
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
909
|
+
)
|
836
910
|
|
837
911
|
elif isinstance(dataset, pd.DataFrame):
|
838
912
|
transform_kwargs = dict(
|
@@ -851,7 +925,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
851
925
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
852
926
|
inference_method=inference_method,
|
853
927
|
input_cols=self.input_cols,
|
854
|
-
expected_output_cols=
|
928
|
+
expected_output_cols=expected_output_cols,
|
855
929
|
**transform_kwargs
|
856
930
|
)
|
857
931
|
return output_df
|
@@ -884,17 +958,15 @@ class SGDOneClassSVM(BaseTransformer):
|
|
884
958
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
885
959
|
|
886
960
|
if isinstance(dataset, DataFrame):
|
887
|
-
self.
|
888
|
-
|
889
|
-
inference_method="score",
|
890
|
-
)
|
961
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
962
|
+
self._deps = self._get_dependencies()
|
891
963
|
selected_cols = self._get_active_columns()
|
892
964
|
if len(selected_cols) > 0:
|
893
965
|
dataset = dataset.select(selected_cols)
|
894
966
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
895
967
|
transform_kwargs = dict(
|
896
968
|
session=dataset._session,
|
897
|
-
dependencies=
|
969
|
+
dependencies=self._deps,
|
898
970
|
score_sproc_imports=['sklearn'],
|
899
971
|
)
|
900
972
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -959,11 +1031,8 @@ class SGDOneClassSVM(BaseTransformer):
|
|
959
1031
|
|
960
1032
|
if isinstance(dataset, DataFrame):
|
961
1033
|
|
962
|
-
self.
|
963
|
-
|
964
|
-
inference_method=inference_method,
|
965
|
-
|
966
|
-
)
|
1034
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1035
|
+
self._deps = self._get_dependencies()
|
967
1036
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
968
1037
|
transform_kwargs = dict(
|
969
1038
|
session = dataset._session,
|
@@ -996,50 +1065,84 @@ class SGDOneClassSVM(BaseTransformer):
|
|
996
1065
|
)
|
997
1066
|
return output_df
|
998
1067
|
|
1068
|
+
|
1069
|
+
|
1070
|
+
def to_sklearn(self) -> Any:
|
1071
|
+
"""Get sklearn.linear_model.SGDOneClassSVM object.
|
1072
|
+
"""
|
1073
|
+
if self._sklearn_object is None:
|
1074
|
+
self._sklearn_object = self._create_sklearn_object()
|
1075
|
+
return self._sklearn_object
|
1076
|
+
|
1077
|
+
def to_xgboost(self) -> Any:
|
1078
|
+
raise exceptions.SnowflakeMLException(
|
1079
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1080
|
+
original_exception=AttributeError(
|
1081
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1082
|
+
"to_xgboost()",
|
1083
|
+
"to_sklearn()"
|
1084
|
+
)
|
1085
|
+
),
|
1086
|
+
)
|
999
1087
|
|
1000
|
-
def
|
1088
|
+
def to_lightgbm(self) -> Any:
|
1089
|
+
raise exceptions.SnowflakeMLException(
|
1090
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1091
|
+
original_exception=AttributeError(
|
1092
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1093
|
+
"to_lightgbm()",
|
1094
|
+
"to_sklearn()"
|
1095
|
+
)
|
1096
|
+
),
|
1097
|
+
)
|
1098
|
+
|
1099
|
+
def _get_dependencies(self) -> List[str]:
|
1100
|
+
return self._deps
|
1101
|
+
|
1102
|
+
|
1103
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1001
1104
|
self._model_signature_dict = dict()
|
1002
1105
|
|
1003
1106
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1004
1107
|
|
1005
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1108
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1006
1109
|
outputs: List[BaseFeatureSpec] = []
|
1007
1110
|
if hasattr(self, "predict"):
|
1008
1111
|
# keep mypy happy
|
1009
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1112
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1010
1113
|
# For classifier, the type of predict is the same as the type of label
|
1011
|
-
if self._sklearn_object._estimator_type ==
|
1012
|
-
|
1114
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1115
|
+
# label columns is the desired type for output
|
1013
1116
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1014
1117
|
# rename the output columns
|
1015
1118
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1016
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1017
|
-
|
1018
|
-
|
1119
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1120
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1121
|
+
)
|
1019
1122
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1020
1123
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1021
|
-
# Clusterer returns int64 cluster labels.
|
1124
|
+
# Clusterer returns int64 cluster labels.
|
1022
1125
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1023
1126
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1024
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1127
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1128
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1129
|
+
)
|
1130
|
+
|
1028
1131
|
# For regressor, the type of predict is float64
|
1029
|
-
elif self._sklearn_object._estimator_type ==
|
1132
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1030
1133
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1031
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1134
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1135
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1136
|
+
)
|
1137
|
+
|
1035
1138
|
for prob_func in PROB_FUNCTIONS:
|
1036
1139
|
if hasattr(self, prob_func):
|
1037
1140
|
output_cols_prefix: str = f"{prob_func}_"
|
1038
1141
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1039
1142
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1040
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1041
|
-
|
1042
|
-
|
1143
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1144
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1145
|
+
)
|
1043
1146
|
|
1044
1147
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1045
1148
|
items = list(self._model_signature_dict.items())
|
@@ -1052,10 +1155,10 @@ class SGDOneClassSVM(BaseTransformer):
|
|
1052
1155
|
"""Returns model signature of current class.
|
1053
1156
|
|
1054
1157
|
Raises:
|
1055
|
-
|
1158
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1056
1159
|
|
1057
1160
|
Returns:
|
1058
|
-
Dict
|
1161
|
+
Dict with each method and its input output signature
|
1059
1162
|
"""
|
1060
1163
|
if self._model_signature_dict is None:
|
1061
1164
|
raise exceptions.SnowflakeMLException(
|
@@ -1063,35 +1166,3 @@ class SGDOneClassSVM(BaseTransformer):
|
|
1063
1166
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1064
1167
|
)
|
1065
1168
|
return self._model_signature_dict
|
1066
|
-
|
1067
|
-
def to_sklearn(self) -> Any:
|
1068
|
-
"""Get sklearn.linear_model.SGDOneClassSVM object.
|
1069
|
-
"""
|
1070
|
-
if self._sklearn_object is None:
|
1071
|
-
self._sklearn_object = self._create_sklearn_object()
|
1072
|
-
return self._sklearn_object
|
1073
|
-
|
1074
|
-
def to_xgboost(self) -> Any:
|
1075
|
-
raise exceptions.SnowflakeMLException(
|
1076
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1077
|
-
original_exception=AttributeError(
|
1078
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1079
|
-
"to_xgboost()",
|
1080
|
-
"to_sklearn()"
|
1081
|
-
)
|
1082
|
-
),
|
1083
|
-
)
|
1084
|
-
|
1085
|
-
def to_lightgbm(self) -> Any:
|
1086
|
-
raise exceptions.SnowflakeMLException(
|
1087
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1088
|
-
original_exception=AttributeError(
|
1089
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1090
|
-
"to_lightgbm()",
|
1091
|
-
"to_sklearn()"
|
1092
|
-
)
|
1093
|
-
),
|
1094
|
-
)
|
1095
|
-
|
1096
|
-
def _get_dependencies(self) -> List[str]:
|
1097
|
-
return self._deps
|