snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LassoLarsCV(BaseTransformer):
|
70
64
|
r"""Cross-validated Lasso, using the LARS algorithm
|
71
65
|
For more details on this class, see [sklearn.linear_model.LassoLarsCV]
|
@@ -344,20 +338,17 @@ class LassoLarsCV(BaseTransformer):
|
|
344
338
|
self,
|
345
339
|
dataset: DataFrame,
|
346
340
|
inference_method: str,
|
347
|
-
) ->
|
348
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
349
|
-
return the available package that exists in the snowflake anaconda channel
|
341
|
+
) -> None:
|
342
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
350
343
|
|
351
344
|
Args:
|
352
345
|
dataset: snowpark dataframe
|
353
346
|
inference_method: the inference method such as predict, score...
|
354
|
-
|
347
|
+
|
355
348
|
Raises:
|
356
349
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
357
350
|
SnowflakeMLException: If the session is None, raise error
|
358
351
|
|
359
|
-
Returns:
|
360
|
-
A list of available package that exists in the snowflake anaconda channel
|
361
352
|
"""
|
362
353
|
if not self._is_fitted:
|
363
354
|
raise exceptions.SnowflakeMLException(
|
@@ -375,9 +366,7 @@ class LassoLarsCV(BaseTransformer):
|
|
375
366
|
"Session must not specified for snowpark dataset."
|
376
367
|
),
|
377
368
|
)
|
378
|
-
|
379
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
380
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
369
|
+
|
381
370
|
|
382
371
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
383
372
|
@telemetry.send_api_usage_telemetry(
|
@@ -425,7 +414,8 @@ class LassoLarsCV(BaseTransformer):
|
|
425
414
|
|
426
415
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
427
416
|
|
428
|
-
self.
|
417
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
418
|
+
self._deps = self._get_dependencies()
|
429
419
|
assert isinstance(
|
430
420
|
dataset._session, Session
|
431
421
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -508,10 +498,8 @@ class LassoLarsCV(BaseTransformer):
|
|
508
498
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
509
499
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
510
500
|
|
511
|
-
self.
|
512
|
-
|
513
|
-
inference_method=inference_method,
|
514
|
-
)
|
501
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
502
|
+
self._deps = self._get_dependencies()
|
515
503
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
516
504
|
|
517
505
|
transform_kwargs = dict(
|
@@ -578,16 +566,40 @@ class LassoLarsCV(BaseTransformer):
|
|
578
566
|
self._is_fitted = True
|
579
567
|
return output_result
|
580
568
|
|
569
|
+
|
570
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
571
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
572
|
+
""" Method not supported for this class.
|
581
573
|
|
582
|
-
|
583
|
-
|
584
|
-
|
574
|
+
|
575
|
+
Raises:
|
576
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
577
|
+
|
578
|
+
Args:
|
579
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
580
|
+
Snowpark or Pandas DataFrame.
|
581
|
+
output_cols_prefix: Prefix for the response columns
|
585
582
|
Returns:
|
586
583
|
Transformed dataset.
|
587
584
|
"""
|
588
|
-
self.
|
589
|
-
|
590
|
-
|
585
|
+
self._infer_input_output_cols(dataset)
|
586
|
+
super()._check_dataset_type(dataset)
|
587
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
588
|
+
estimator=self._sklearn_object,
|
589
|
+
dataset=dataset,
|
590
|
+
input_cols=self.input_cols,
|
591
|
+
label_cols=self.label_cols,
|
592
|
+
sample_weight_col=self.sample_weight_col,
|
593
|
+
autogenerated=self._autogenerated,
|
594
|
+
subproject=_SUBPROJECT,
|
595
|
+
)
|
596
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
597
|
+
drop_input_cols=self._drop_input_cols,
|
598
|
+
expected_output_cols_list=self.output_cols,
|
599
|
+
)
|
600
|
+
self._sklearn_object = fitted_estimator
|
601
|
+
self._is_fitted = True
|
602
|
+
return output_result
|
591
603
|
|
592
604
|
|
593
605
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -678,10 +690,8 @@ class LassoLarsCV(BaseTransformer):
|
|
678
690
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
679
691
|
|
680
692
|
if isinstance(dataset, DataFrame):
|
681
|
-
self.
|
682
|
-
|
683
|
-
inference_method=inference_method,
|
684
|
-
)
|
693
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
694
|
+
self._deps = self._get_dependencies()
|
685
695
|
assert isinstance(
|
686
696
|
dataset._session, Session
|
687
697
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -746,10 +756,8 @@ class LassoLarsCV(BaseTransformer):
|
|
746
756
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
747
757
|
|
748
758
|
if isinstance(dataset, DataFrame):
|
749
|
-
self.
|
750
|
-
|
751
|
-
inference_method=inference_method,
|
752
|
-
)
|
759
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
760
|
+
self._deps = self._get_dependencies()
|
753
761
|
assert isinstance(
|
754
762
|
dataset._session, Session
|
755
763
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -811,10 +819,8 @@ class LassoLarsCV(BaseTransformer):
|
|
811
819
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
812
820
|
|
813
821
|
if isinstance(dataset, DataFrame):
|
814
|
-
self.
|
815
|
-
|
816
|
-
inference_method=inference_method,
|
817
|
-
)
|
822
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
823
|
+
self._deps = self._get_dependencies()
|
818
824
|
assert isinstance(
|
819
825
|
dataset._session, Session
|
820
826
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -880,10 +886,8 @@ class LassoLarsCV(BaseTransformer):
|
|
880
886
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
881
887
|
|
882
888
|
if isinstance(dataset, DataFrame):
|
883
|
-
self.
|
884
|
-
|
885
|
-
inference_method=inference_method,
|
886
|
-
)
|
889
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
890
|
+
self._deps = self._get_dependencies()
|
887
891
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
888
892
|
transform_kwargs = dict(
|
889
893
|
session=dataset._session,
|
@@ -947,17 +951,15 @@ class LassoLarsCV(BaseTransformer):
|
|
947
951
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
948
952
|
|
949
953
|
if isinstance(dataset, DataFrame):
|
950
|
-
self.
|
951
|
-
|
952
|
-
inference_method="score",
|
953
|
-
)
|
954
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
955
|
+
self._deps = self._get_dependencies()
|
954
956
|
selected_cols = self._get_active_columns()
|
955
957
|
if len(selected_cols) > 0:
|
956
958
|
dataset = dataset.select(selected_cols)
|
957
959
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
958
960
|
transform_kwargs = dict(
|
959
961
|
session=dataset._session,
|
960
|
-
dependencies=
|
962
|
+
dependencies=self._deps,
|
961
963
|
score_sproc_imports=['sklearn'],
|
962
964
|
)
|
963
965
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1022,11 +1024,8 @@ class LassoLarsCV(BaseTransformer):
|
|
1022
1024
|
|
1023
1025
|
if isinstance(dataset, DataFrame):
|
1024
1026
|
|
1025
|
-
self.
|
1026
|
-
|
1027
|
-
inference_method=inference_method,
|
1028
|
-
|
1029
|
-
)
|
1027
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1028
|
+
self._deps = self._get_dependencies()
|
1030
1029
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1031
1030
|
transform_kwargs = dict(
|
1032
1031
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LassoLarsIC(BaseTransformer):
|
70
64
|
r"""Lasso model fit with Lars using BIC or AIC for model selection
|
71
65
|
For more details on this class, see [sklearn.linear_model.LassoLarsIC]
|
@@ -327,20 +321,17 @@ class LassoLarsIC(BaseTransformer):
|
|
327
321
|
self,
|
328
322
|
dataset: DataFrame,
|
329
323
|
inference_method: str,
|
330
|
-
) ->
|
331
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
332
|
-
return the available package that exists in the snowflake anaconda channel
|
324
|
+
) -> None:
|
325
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
333
326
|
|
334
327
|
Args:
|
335
328
|
dataset: snowpark dataframe
|
336
329
|
inference_method: the inference method such as predict, score...
|
337
|
-
|
330
|
+
|
338
331
|
Raises:
|
339
332
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
340
333
|
SnowflakeMLException: If the session is None, raise error
|
341
334
|
|
342
|
-
Returns:
|
343
|
-
A list of available package that exists in the snowflake anaconda channel
|
344
335
|
"""
|
345
336
|
if not self._is_fitted:
|
346
337
|
raise exceptions.SnowflakeMLException(
|
@@ -358,9 +349,7 @@ class LassoLarsIC(BaseTransformer):
|
|
358
349
|
"Session must not specified for snowpark dataset."
|
359
350
|
),
|
360
351
|
)
|
361
|
-
|
362
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
363
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
352
|
+
|
364
353
|
|
365
354
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
366
355
|
@telemetry.send_api_usage_telemetry(
|
@@ -408,7 +397,8 @@ class LassoLarsIC(BaseTransformer):
|
|
408
397
|
|
409
398
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
410
399
|
|
411
|
-
self.
|
400
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
401
|
+
self._deps = self._get_dependencies()
|
412
402
|
assert isinstance(
|
413
403
|
dataset._session, Session
|
414
404
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -491,10 +481,8 @@ class LassoLarsIC(BaseTransformer):
|
|
491
481
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
492
482
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
493
483
|
|
494
|
-
self.
|
495
|
-
|
496
|
-
inference_method=inference_method,
|
497
|
-
)
|
484
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
485
|
+
self._deps = self._get_dependencies()
|
498
486
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
499
487
|
|
500
488
|
transform_kwargs = dict(
|
@@ -561,16 +549,40 @@ class LassoLarsIC(BaseTransformer):
|
|
561
549
|
self._is_fitted = True
|
562
550
|
return output_result
|
563
551
|
|
552
|
+
|
553
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
554
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
555
|
+
""" Method not supported for this class.
|
564
556
|
|
565
|
-
|
566
|
-
|
567
|
-
|
557
|
+
|
558
|
+
Raises:
|
559
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
560
|
+
|
561
|
+
Args:
|
562
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
563
|
+
Snowpark or Pandas DataFrame.
|
564
|
+
output_cols_prefix: Prefix for the response columns
|
568
565
|
Returns:
|
569
566
|
Transformed dataset.
|
570
567
|
"""
|
571
|
-
self.
|
572
|
-
|
573
|
-
|
568
|
+
self._infer_input_output_cols(dataset)
|
569
|
+
super()._check_dataset_type(dataset)
|
570
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
571
|
+
estimator=self._sklearn_object,
|
572
|
+
dataset=dataset,
|
573
|
+
input_cols=self.input_cols,
|
574
|
+
label_cols=self.label_cols,
|
575
|
+
sample_weight_col=self.sample_weight_col,
|
576
|
+
autogenerated=self._autogenerated,
|
577
|
+
subproject=_SUBPROJECT,
|
578
|
+
)
|
579
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
580
|
+
drop_input_cols=self._drop_input_cols,
|
581
|
+
expected_output_cols_list=self.output_cols,
|
582
|
+
)
|
583
|
+
self._sklearn_object = fitted_estimator
|
584
|
+
self._is_fitted = True
|
585
|
+
return output_result
|
574
586
|
|
575
587
|
|
576
588
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -661,10 +673,8 @@ class LassoLarsIC(BaseTransformer):
|
|
661
673
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
662
674
|
|
663
675
|
if isinstance(dataset, DataFrame):
|
664
|
-
self.
|
665
|
-
|
666
|
-
inference_method=inference_method,
|
667
|
-
)
|
676
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
677
|
+
self._deps = self._get_dependencies()
|
668
678
|
assert isinstance(
|
669
679
|
dataset._session, Session
|
670
680
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -729,10 +739,8 @@ class LassoLarsIC(BaseTransformer):
|
|
729
739
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
730
740
|
|
731
741
|
if isinstance(dataset, DataFrame):
|
732
|
-
self.
|
733
|
-
|
734
|
-
inference_method=inference_method,
|
735
|
-
)
|
742
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
743
|
+
self._deps = self._get_dependencies()
|
736
744
|
assert isinstance(
|
737
745
|
dataset._session, Session
|
738
746
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -794,10 +802,8 @@ class LassoLarsIC(BaseTransformer):
|
|
794
802
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
795
803
|
|
796
804
|
if isinstance(dataset, DataFrame):
|
797
|
-
self.
|
798
|
-
|
799
|
-
inference_method=inference_method,
|
800
|
-
)
|
805
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
806
|
+
self._deps = self._get_dependencies()
|
801
807
|
assert isinstance(
|
802
808
|
dataset._session, Session
|
803
809
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -863,10 +869,8 @@ class LassoLarsIC(BaseTransformer):
|
|
863
869
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
864
870
|
|
865
871
|
if isinstance(dataset, DataFrame):
|
866
|
-
self.
|
867
|
-
|
868
|
-
inference_method=inference_method,
|
869
|
-
)
|
872
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
873
|
+
self._deps = self._get_dependencies()
|
870
874
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
871
875
|
transform_kwargs = dict(
|
872
876
|
session=dataset._session,
|
@@ -930,17 +934,15 @@ class LassoLarsIC(BaseTransformer):
|
|
930
934
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
931
935
|
|
932
936
|
if isinstance(dataset, DataFrame):
|
933
|
-
self.
|
934
|
-
|
935
|
-
inference_method="score",
|
936
|
-
)
|
937
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
938
|
+
self._deps = self._get_dependencies()
|
937
939
|
selected_cols = self._get_active_columns()
|
938
940
|
if len(selected_cols) > 0:
|
939
941
|
dataset = dataset.select(selected_cols)
|
940
942
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
941
943
|
transform_kwargs = dict(
|
942
944
|
session=dataset._session,
|
943
|
-
dependencies=
|
945
|
+
dependencies=self._deps,
|
944
946
|
score_sproc_imports=['sklearn'],
|
945
947
|
)
|
946
948
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1005,11 +1007,8 @@ class LassoLarsIC(BaseTransformer):
|
|
1005
1007
|
|
1006
1008
|
if isinstance(dataset, DataFrame):
|
1007
1009
|
|
1008
|
-
self.
|
1009
|
-
|
1010
|
-
inference_method=inference_method,
|
1011
|
-
|
1012
|
-
)
|
1010
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1011
|
+
self._deps = self._get_dependencies()
|
1013
1012
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1014
1013
|
transform_kwargs = dict(
|
1015
1014
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LinearRegression(BaseTransformer):
|
70
64
|
r"""Ordinary least squares Linear Regression
|
71
65
|
For more details on this class, see [sklearn.linear_model.LinearRegression]
|
@@ -280,20 +274,17 @@ class LinearRegression(BaseTransformer):
|
|
280
274
|
self,
|
281
275
|
dataset: DataFrame,
|
282
276
|
inference_method: str,
|
283
|
-
) ->
|
284
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
285
|
-
return the available package that exists in the snowflake anaconda channel
|
277
|
+
) -> None:
|
278
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
286
279
|
|
287
280
|
Args:
|
288
281
|
dataset: snowpark dataframe
|
289
282
|
inference_method: the inference method such as predict, score...
|
290
|
-
|
283
|
+
|
291
284
|
Raises:
|
292
285
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
293
286
|
SnowflakeMLException: If the session is None, raise error
|
294
287
|
|
295
|
-
Returns:
|
296
|
-
A list of available package that exists in the snowflake anaconda channel
|
297
288
|
"""
|
298
289
|
if not self._is_fitted:
|
299
290
|
raise exceptions.SnowflakeMLException(
|
@@ -311,9 +302,7 @@ class LinearRegression(BaseTransformer):
|
|
311
302
|
"Session must not specified for snowpark dataset."
|
312
303
|
),
|
313
304
|
)
|
314
|
-
|
315
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
316
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
305
|
+
|
317
306
|
|
318
307
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
319
308
|
@telemetry.send_api_usage_telemetry(
|
@@ -361,7 +350,8 @@ class LinearRegression(BaseTransformer):
|
|
361
350
|
|
362
351
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
363
352
|
|
364
|
-
self.
|
353
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
354
|
+
self._deps = self._get_dependencies()
|
365
355
|
assert isinstance(
|
366
356
|
dataset._session, Session
|
367
357
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -444,10 +434,8 @@ class LinearRegression(BaseTransformer):
|
|
444
434
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
445
435
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
446
436
|
|
447
|
-
self.
|
448
|
-
|
449
|
-
inference_method=inference_method,
|
450
|
-
)
|
437
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
438
|
+
self._deps = self._get_dependencies()
|
451
439
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
452
440
|
|
453
441
|
transform_kwargs = dict(
|
@@ -514,16 +502,40 @@ class LinearRegression(BaseTransformer):
|
|
514
502
|
self._is_fitted = True
|
515
503
|
return output_result
|
516
504
|
|
505
|
+
|
506
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
507
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
508
|
+
""" Method not supported for this class.
|
517
509
|
|
518
|
-
|
519
|
-
|
520
|
-
|
510
|
+
|
511
|
+
Raises:
|
512
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
513
|
+
|
514
|
+
Args:
|
515
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
516
|
+
Snowpark or Pandas DataFrame.
|
517
|
+
output_cols_prefix: Prefix for the response columns
|
521
518
|
Returns:
|
522
519
|
Transformed dataset.
|
523
520
|
"""
|
524
|
-
self.
|
525
|
-
|
526
|
-
|
521
|
+
self._infer_input_output_cols(dataset)
|
522
|
+
super()._check_dataset_type(dataset)
|
523
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
524
|
+
estimator=self._sklearn_object,
|
525
|
+
dataset=dataset,
|
526
|
+
input_cols=self.input_cols,
|
527
|
+
label_cols=self.label_cols,
|
528
|
+
sample_weight_col=self.sample_weight_col,
|
529
|
+
autogenerated=self._autogenerated,
|
530
|
+
subproject=_SUBPROJECT,
|
531
|
+
)
|
532
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
533
|
+
drop_input_cols=self._drop_input_cols,
|
534
|
+
expected_output_cols_list=self.output_cols,
|
535
|
+
)
|
536
|
+
self._sklearn_object = fitted_estimator
|
537
|
+
self._is_fitted = True
|
538
|
+
return output_result
|
527
539
|
|
528
540
|
|
529
541
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -614,10 +626,8 @@ class LinearRegression(BaseTransformer):
|
|
614
626
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
615
627
|
|
616
628
|
if isinstance(dataset, DataFrame):
|
617
|
-
self.
|
618
|
-
|
619
|
-
inference_method=inference_method,
|
620
|
-
)
|
629
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
630
|
+
self._deps = self._get_dependencies()
|
621
631
|
assert isinstance(
|
622
632
|
dataset._session, Session
|
623
633
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -682,10 +692,8 @@ class LinearRegression(BaseTransformer):
|
|
682
692
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
683
693
|
|
684
694
|
if isinstance(dataset, DataFrame):
|
685
|
-
self.
|
686
|
-
|
687
|
-
inference_method=inference_method,
|
688
|
-
)
|
695
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
696
|
+
self._deps = self._get_dependencies()
|
689
697
|
assert isinstance(
|
690
698
|
dataset._session, Session
|
691
699
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -747,10 +755,8 @@ class LinearRegression(BaseTransformer):
|
|
747
755
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
748
756
|
|
749
757
|
if isinstance(dataset, DataFrame):
|
750
|
-
self.
|
751
|
-
|
752
|
-
inference_method=inference_method,
|
753
|
-
)
|
758
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
759
|
+
self._deps = self._get_dependencies()
|
754
760
|
assert isinstance(
|
755
761
|
dataset._session, Session
|
756
762
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -816,10 +822,8 @@ class LinearRegression(BaseTransformer):
|
|
816
822
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
817
823
|
|
818
824
|
if isinstance(dataset, DataFrame):
|
819
|
-
self.
|
820
|
-
|
821
|
-
inference_method=inference_method,
|
822
|
-
)
|
825
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
826
|
+
self._deps = self._get_dependencies()
|
823
827
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
824
828
|
transform_kwargs = dict(
|
825
829
|
session=dataset._session,
|
@@ -883,17 +887,15 @@ class LinearRegression(BaseTransformer):
|
|
883
887
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
884
888
|
|
885
889
|
if isinstance(dataset, DataFrame):
|
886
|
-
self.
|
887
|
-
|
888
|
-
inference_method="score",
|
889
|
-
)
|
890
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
891
|
+
self._deps = self._get_dependencies()
|
890
892
|
selected_cols = self._get_active_columns()
|
891
893
|
if len(selected_cols) > 0:
|
892
894
|
dataset = dataset.select(selected_cols)
|
893
895
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
894
896
|
transform_kwargs = dict(
|
895
897
|
session=dataset._session,
|
896
|
-
dependencies=
|
898
|
+
dependencies=self._deps,
|
897
899
|
score_sproc_imports=['sklearn'],
|
898
900
|
)
|
899
901
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -958,11 +960,8 @@ class LinearRegression(BaseTransformer):
|
|
958
960
|
|
959
961
|
if isinstance(dataset, DataFrame):
|
960
962
|
|
961
|
-
self.
|
962
|
-
|
963
|
-
inference_method=inference_method,
|
964
|
-
|
965
|
-
)
|
963
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
964
|
+
self._deps = self._get_dependencies()
|
966
965
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
967
966
|
transform_kwargs = dict(
|
968
967
|
session = dataset._session,
|